In [1]:
import numpy as np
import pandas as pd
import xarray as xr
import os
import sys
import time

# --- Paths -------------------------------------------------------------------
# Project data root (Oak group storage).
DATA_DIRECTORY = '/oak/stanford/groups/earlew/yuchen'
# Raw CESM data staged on scratch.
RAW_DATA_DIRECTORY = '/scratch/users/yucli/cesm_data'

# --- Variables ---------------------------------------------------------------
# Renamed (short) variable names, in a fixed order.
VAR_NAMES = [
    "icefrac",
    "icethick",
    "temp",
    "geopotential",
    "psl",
    "lw_flux",
    "sw_flux",
    "ua",
]


In [6]:
import os
import torch
from torch.utils.data import Dataset
import xarray as xr
import numpy as np

class CESM_SeaIceDataset(Dataset):
    """Map-style dataset over per-member CESM input/target netCDF files.

    For every ensemble member ``m`` the files ``inputs_member_{m}.nc`` and
    ``targets_member_{m}.nc`` are read from ``data_dir``; each must hold a
    ``data`` variable indexed by ``start_prediction_month``. One sample is one
    (member, start_prediction_month position) pair, and each item is a dict
    ``{"input": float32 tensor, "target": float32 tensor}`` (optionally passed
    through ``transform``).
    """

    def __init__(self, data_dir, ensemble_members, transform=None):
        """
        Param:
            data_dir (str): directory holding the model-ready input and target files
            ensemble_members (list): ensemble member ids (ripf notation)
            transform (callable, optional): applied to each sample dict before it is returned
        """
        self.data_dir = data_dir
        self.ensemble_members = ensemble_members
        self.transform = transform

        # Flat global index: sample k -> (member id, position along start_prediction_month).
        # Only the input files are opened here, and only to count their start months.
        self.samples = []
        for member_id in ensemble_members:
            with xr.open_dataset(self._member_path("inputs", member_id)) as inputs_ds:
                n_starts = len(inputs_ds["start_prediction_month"])
            self.samples.extend((member_id, pos) for pos in range(n_starts))

    def _member_path(self, kind, member_id):
        """Path of the ``kind`` ("inputs" or "targets") file for one ensemble member."""
        return os.path.join(self.data_dir, f"{kind}_member_{member_id}.nc")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        member_id, start_pos = self.samples[idx]

        def load_slice(kind):
            # .values materialises the slice while the file is still open.
            with xr.open_dataset(self._member_path(kind, member_id)) as ds:
                return ds["data"].isel(start_prediction_month=start_pos).values

        # The input is read before the target, matching the key order of the dict.
        sample = {
            key: torch.tensor(load_slice(kind), dtype=torch.float32)
            for key, kind in (("input", "inputs"), ("target", "targets"))
        }
        return self.transform(sample) if self.transform else sample

# Model-ready pairs: one inputs_member_<id>.nc and one targets_member_<id>.nc per member.
data_dir = "/scratch/users/yucli/model-ready_cesm_data/data_pairs_setting1"

# Member ids are taken from the *input* files only, matched on the exact file-name
# pattern. Splitting every directory entry on "_" and "." would raise IndexError on
# unrelated entries (e.g. .ipynb_checkpoints), truncate ids that contain ".", and
# list members that only have a targets file.
MEMBER_FILE_PREFIX, MEMBER_FILE_SUFFIX = "inputs_member_", ".nc"
ensemble_members = np.array(sorted(
    name[len(MEMBER_FILE_PREFIX):-len(MEMBER_FILE_SUFFIX)]
    for name in os.listdir(data_dir)
    if name.startswith(MEMBER_FILE_PREFIX) and name.endswith(MEMBER_FILE_SUFFIX)
))
if ensemble_members.size == 0:
    # Fail here with a clear message instead of an opaque empty-sampler error below.
    raise FileNotFoundError(f"no {MEMBER_FILE_PREFIX}*{MEMBER_FILE_SUFFIX} files found in {data_dir}")

dataset = CESM_SeaIceDataset(data_dir, ensemble_members)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)


In [6]:
class SeaIceDataset(torch.utils.data.Dataset):
    """Sea-ice (input, target) samples read from pre-built HDF5 files.

    Inputs are read from dataset ``inputs_{configuration}`` in
    ``{data_directory}/inputs_{configuration}.h5``. Targets are read from dataset
    ``targets_sea_ice_only`` in ``targets_anom_regression.h5`` when
    ``configuration`` contains "sicanom", otherwise in ``targets_regression.h5``.
    Only rows whose ``split_array`` entry equals ``split_type`` are exposed.

    Each item is ``(input_tensor, target_tensor, target_months)`` where
      * input and target are zero-padded in their last two axes to exactly
        ``target_shape`` and returned as float32 tensors;
      * in ``mode == "classification"`` the target is one-hot encoded along a new
        leading class axis using the bins implied by ``class_splits``;
      * ``target_months`` is an array of the calendar month numbers (1-12) of the
        ``N_TARGET_MONTHS`` consecutive months starting at the sample's start month.

    NOTE(review): ``h5py`` is not imported in this cell; it must be imported
    elsewhere in the notebook before construction.
    NOTE(review): the HDF5 handles are opened in ``__init__``. h5py handles are
    generally not safe to share with forked DataLoader workers; confirm this is
    used with ``num_workers=0`` or reopen the files per worker.
    """

    # Number of consecutive months covered by each sample's ``target_months``.
    N_TARGET_MONTHS = 6

    def __init__(self, data_directory, configuration, split_array, start_prediction_months,
                 split_type='train', target_shape=(80, 80), mode="regression", class_splits=None):
        """
        Param:
            data_directory (str): directory holding the inputs_*.h5 / targets_*.h5 files
            configuration (str): input configuration name; selects the input file and dataset
            split_array (array-like): per-row split label, aligned with the HDF5 rows
            start_prediction_months (indexable of dates): per-row start month, aligned with the HDF5 rows
            split_type (str): which split label to expose
            target_shape (tuple[int, int]): (y, x) shape inputs/targets are zero-padded to
            mode (str): "regression" or "classification"
            class_splits (sequence of float, optional): increasing class boundaries; required for classification
        """
        self.data_directory = data_directory
        self.configuration = configuration
        self.split_array = split_array
        self.start_prediction_months = start_prediction_months
        self.split_type = split_type
        self.target_shape = target_shape
        self.class_splits = class_splits
        self.mode = mode

        # Pre-set the handles so close()/__del__ are safe even if an open below fails.
        self.inputs_file = None
        self.targets_file = None

        self.inputs_file = h5py.File(f"{data_directory}/inputs_{configuration}.h5", 'r')

        # Configurations containing "sicanom" use the anom_regression targets file.
        targets_configuration = "anom_regression" if "sicanom" in configuration else "regression"
        self.targets_file = h5py.File(f"{data_directory}/targets_{targets_configuration}.h5", 'r')

        self.inputs = self.inputs_file[f"inputs_{configuration}"]
        self.targets = self.targets_file['targets_sea_ice_only']

        self.n_samples, self.n_channels, self.n_y, self.n_x = self.inputs.shape

        # np.asarray so a plain list of labels is compared element-wise;
        # `list == str` would be a single False and silently select no samples.
        self.indices = np.where(np.asarray(self.split_array) == split_type)[0]

    def __len__(self):
        return len(self.indices)

    def _pad_to_target_shape(self, array):
        """Zero-pad the last two axes of a 3-D array from (n_y, n_x) to ``target_shape``.

        When a pad amount is odd, the extra row/column goes on the bottom/right so
        the result always has exactly ``target_shape`` in its last two axes.

        Raises:
            ValueError: if ``target_shape`` is smaller than the data grid.
        """
        pad_y = self.target_shape[0] - self.n_y
        pad_x = self.target_shape[1] - self.n_x
        if pad_y < 0 or pad_x < 0:
            raise ValueError(
                f"target_shape {tuple(self.target_shape)} is smaller than the data grid ({self.n_y}, {self.n_x})"
            )
        return np.pad(
            array,
            ((0, 0), (pad_y // 2, pad_y - pad_y // 2), (pad_x // 2, pad_x - pad_x // 2)),
            mode='constant', constant_values=0,
        )

    def _class_bounds(self):
        """Return the [lower, upper] bin edges implied by ``class_splits``.

        Splits s_0 <= ... <= s_k give bins [0, s_0], [s_0, s_1], ..., [s_k, 1].

        Raises:
            ValueError: if ``class_splits`` is missing or not monotonically increasing.
        """
        if self.class_splits is None:
            raise ValueError("need to specify a monotonically increasing list class_splits denoting class boundaries")
        if len(self.class_splits) > 1 and np.any(np.diff(self.class_splits) < 0):
            raise ValueError("class_splits needs to be monotonically increasing")
        edges = [0, *self.class_splits, 1]
        return [[lower, upper] for lower, upper in zip(edges[:-1], edges[1:])]

    def _to_one_hot(self, target_data):
        """One-hot encode ``target_data`` along a new leading class axis.

        Every bin is half-open [lower, upper) except the last, which also includes
        its upper edge. Values outside all bins get an all-zero class vector.
        The result keeps ``target_data``'s dtype.
        """
        bounds = self._class_bounds()
        last = len(bounds) - 1
        one_hot = [
            np.logical_and(
                target_data >= lower,
                (target_data <= upper) if i == last else (target_data < upper),
            )
            for i, (lower, upper) in enumerate(bounds)
        ]
        return np.stack(one_hot).astype(target_data.dtype)

    def _target_months(self, start_prediction_month):
        """Return the calendar month numbers of ``N_TARGET_MONTHS`` months starting at the given date."""
        # Snap to the first of the month. Otherwise date_range(freq="MS") rolls a
        # mid-month start forward and returns one month too few.
        first_month = pd.Timestamp(start_prediction_month).to_period("M").to_timestamp()
        months = pd.date_range(start=first_month, periods=self.N_TARGET_MONTHS, freq="MS")
        return months.month.to_numpy()

    def __getitem__(self, idx):
        actual_idx = self.indices[idx]
        input_data = self._pad_to_target_shape(self.inputs[actual_idx])
        target_data = self._pad_to_target_shape(self.targets[actual_idx])

        # For classification, discretise the (padded) target into one-hot classes.
        if self.mode == "classification":
            target_data = self._to_one_hot(target_data)

        input_tensor = torch.tensor(input_data, dtype=torch.float32)
        target_tensor = torch.tensor(target_data, dtype=torch.float32)
        target_months = self._target_months(self.start_prediction_months[actual_idx])

        return input_tensor, target_tensor, target_months

    def close(self):
        """Close the HDF5 files. Safe to call more than once.

        ``self.inputs`` and ``self.targets`` are unusable afterwards.
        """
        for attr in ("inputs_file", "targets_file"):
            handle = getattr(self, attr, None)
            if handle is not None:
                handle.close()
                setattr(self, attr, None)

    def __del__(self):
        self.close()


