In [1]:
import numpy as np
import pandas as pd
import xarray as xr
import time
import os
import dask
#import torch

# Project directory under /oak (not referenced by the code visible in this section).
DATA_DIRECTORY = '/oak/stanford/groups/earlew/yuchen'

# Scratch directory holding one sub-directory of NetCDF files per variable.
RAW_DATA_DIRECTORY = '/scratch/users/yucli/cesm_data'

# Renamed variable names; each one is used both as the sub-directory name under
# RAW_DATA_DIRECTORY and as the data-variable name inside the files.
VAR_NAMES = [
    "icefrac",
    "temp",
    "geopotential",
    "icethick",
    "lw_flux",
    "sw_flux",
    "ua",
    "va",
]


In [4]:
# Lazily open every file in the icefrac directory (dask chunks of 120 time
# steps) and concatenate the datasets along the "member_id" dimension.
icefrac_dir = f"{RAW_DATA_DIRECTORY}/icefrac"
file_list = sorted(os.listdir(icefrac_dir))
ds_list = [
    xr.open_dataset(os.path.join(icefrac_dir, fname), chunks={'time': 120})
    for fname in file_list
]

merged_ds = xr.concat(ds_list, dim="member_id")
merged_ds


Unnamed: 0,Array,Chunk
Bytes,7.18 GiB,2.93 MiB
Shape,"(100, 3012, 80, 80)","(1, 120, 80, 80)"
Dask graph,2600 chunks in 301 graph layers,2600 chunks in 301 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 7.18 GiB 2.93 MiB Shape (100, 3012, 80, 80) (1, 120, 80, 80) Dask graph 2600 chunks in 301 graph layers Data type float32 numpy.ndarray",100  1  80  80  3012,

Unnamed: 0,Array,Chunk
Bytes,7.18 GiB,2.93 MiB
Shape,"(100, 3012, 80, 80)","(1, 120, 80, 80)"
Dask graph,2600 chunks in 301 graph layers,2600 chunks in 301 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [8]:
# Preview the calendar month of every time step after swapping in a pandas
# month-start time axis (Jan 1850 - Dec 2100). assign_coords returns a new
# object, so merged_ds itself is left unchanged by this cell.
monthly_index = pd.date_range("1850-01", "2100-12", freq="MS")
merged_ds.assign_coords(time=monthly_index).icefrac.time.dt.month

In [9]:

def normalize(x, m, s, var_name=None):
    """
    Standardize ``x`` as ``(x - m) / s`` and zero out entries with unusable spread.

    Parameters
    ----------
    x : array-like
        Values to normalize.
    m : array-like
        Mean, broadcastable against ``x``.
    s : array-like
        Standard deviation, broadcastable against ``x``.
    var_name : str, optional
        Name of the variable. Only ``"temp"`` changes the behavior: entries
        whose standard deviation is at or below 0.001 are also set to zero.

    Returns
    -------
    numpy.ndarray
        Normalized values, with 0 wherever the standard deviation is zero
        (or too small, for ``"temp"``).
    """
    # A standard deviation of exactly zero cannot be divided by.
    zero_out = (s == 0)

    # For SST below sea ice, the stdev is very low, so anything at or below
    # the threshold is treated like a zero stdev.
    if var_name == "temp":
        zero_out = zero_out | (s <= 0.001)

    # Silence the divide-by-zero / invalid warnings; the affected entries are
    # overwritten with 0 below.
    with np.errstate(divide='ignore', invalid='ignore'):
        scaled = (x - m) / s

    return np.where(zero_out, 0, scaled)


def normalize_data(overwrite=False, verbose=1, vars_to_normalize="all"):
    """ 
    Normalize inputs based on statistics of the training data and save. 

    For each variable, every file in ``RAW_DATA_DIRECTORY/<var>`` is opened and
    concatenated along ``member_id``, the time axis is replaced by a pandas
    month-start index (Jan 1850 - Dec 2100), and the merged dataset is written
    to ``RAW_DATA_DIRECTORY/<var>/<var>_combined.nc``. Monthly means and
    standard deviations over Jan 1850 - Dec 1979 (all members pooled) are then
    used to normalize the full record with ``normalize``.

    Outputs ``<var>_mean.nc``, ``<var>_stdev.nc`` and ``<var>_norm.nc`` are
    written to ``RAW_DATA_DIRECTORY/normalized_inputs`` through
    ``write_nc_file`` (defined outside this block).

    Parameters
    ----------
    overwrite : bool
        If False, variables whose ``<var>_norm.nc`` already exists are skipped.
        The flag is also forwarded to ``write_nc_file``.
    verbose : int
        Values >= 1 print the variable list and skip messages.
    vars_to_normalize : "all", str, or iterable of str
        ``"all"`` processes every entry of ``VAR_NAMES``; any other string is
        treated as a single variable name.
    """

    if isinstance(vars_to_normalize, str):
        # A bare variable name must not be iterated character by character
        vars_to_normalize = VAR_NAMES if vars_to_normalize == "all" else [vars_to_normalize]
    
    if verbose >= 1: print(f"Normalizing variables {vars_to_normalize}")

    # The second component must be relative: os.path.join discards everything
    # before an absolute component, which would send output to /normalized_inputs.
    save_dir = os.path.join(RAW_DATA_DIRECTORY, "normalized_inputs")
    os.makedirs(save_dir, exist_ok=True)

    full_time_index = pd.date_range("1850-01", "2100-12", freq="MS")
    climatology_period = pd.date_range("1850-01", "1979-12", freq="MS")

    for var_name in vars_to_normalize:
        save_path = os.path.join(save_dir, f"{var_name}_norm.nc")
        if os.path.exists(save_path) and not overwrite:
            if verbose >= 1: print(f"Already found normalized file for {var_name}. Skipping...")
            continue

        print(f"Normalizing {var_name}...", end=" ")

        # First make a merged dataset from the separate ones. The combined file
        # written below lives in the same directory, so it is excluded here;
        # otherwise a rerun would concatenate it as if it were another member.
        var_dir = f"{RAW_DATA_DIRECTORY}/{var_name}"
        combined_name = f"{var_name}_combined.nc"
        file_list = sorted(f for f in os.listdir(var_dir) if f != combined_name)
        ds_list = [
            xr.open_dataset(os.path.join(var_dir, file), chunks={'time': 120})
            for file in file_list
        ]

        merged_ds = xr.concat(ds_list, dim="member_id")

        # change the time index to pandas instead of cftime. assign_coords
        # returns a new object, so the result has to be kept.
        merged_ds = merged_ds.assign_coords(time=full_time_index)

        # save the merged ds before normalizing 
        merged_ds.to_netcdf(os.path.join(var_dir, combined_name))
        
        # now calculate the climatology. We define this as the period from
        # January 1850 through December 1979 across all ensemble members.
        # This means that the climate change signal, especially for the ssp
        # simulations, will be present. 
        da = merged_ds[var_name]
        print("calculating means and stdev...", end=" ")

        # Group the baseline once; both reductions pool time and members
        baseline_by_month = da.sel(time=climatology_period).groupby("time.month")
        monthly_means = baseline_by_month.mean(dim=["time", "member_id"]).load()
        monthly_stdevs = baseline_by_month.std(dim=["time", "member_id"]).load()
        print("done!")

        # Expand the 12 monthly statistics to one entry per time step
        months = da['time'].dt.month
        normalized_da = xr.apply_ufunc(
            normalize,
            da,
            monthly_means.sel(month=months),
            monthly_stdevs.sel(month=months),
            kwargs={"var_name": var_name},  # non-array argument goes through kwargs
            dask="parallelized",            # da is dask-backed (opened with chunks=)
            output_dtypes=[da.dtype]
        )
        
        normalized_ds = normalized_da.to_dataset(name=var_name)
        monthly_means_ds = monthly_means.to_dataset(name=var_name)
        monthly_stdevs_ds = monthly_stdevs.to_dataset(name=var_name)

        print("Saving...", end="")
        write_nc_file(monthly_means_ds, os.path.join(save_dir, f"{var_name}_mean.nc"), overwrite)
        write_nc_file(monthly_stdevs_ds, os.path.join(save_dir, f"{var_name}_stdev.nc"), overwrite)
        write_nc_file(normalized_ds, save_path, overwrite)
        print("done!")

    print("done! \n\n")


In [6]:
class SeaIceDataset(torch.utils.data.Dataset):
    """
    Map-style dataset serving (input, target, target months) samples from HDF5 files.

    Inputs are read from ``{data_directory}/inputs_{configuration}.h5``
    (dataset ``inputs_{configuration}``) and targets from
    ``{data_directory}/targets_anom_regression.h5`` when ``"sicanom"`` occurs in
    ``configuration``, otherwise ``targets_regression.h5`` (dataset
    ``targets_sea_ice_only``). Both files stay open until ``close()`` is called
    or the object is garbage collected.

    Requires ``torch`` and ``h5py`` to be importable in the enclosing scope.

    Parameters
    ----------
    data_directory : str
        Directory containing the HDF5 files.
    configuration : str
        Name of the input configuration; selects the file and dataset names.
    split_array : array-like of str
        Split label per sample; samples equal to ``split_type`` are served.
    start_prediction_months : sequence
        Per-sample first target month, indexable by sample index.
        Presumably month-start pandas timestamps -- TODO confirm against caller.
    split_type : str
        Split label to select (default ``'train'``).
    target_shape : tuple of int
        Spatial ``(y, x)`` shape that samples are zero-padded to.
    mode : str
        ``"regression"`` returns targets unchanged; ``"classification"``
        converts them into one binary mask per class.
    class_splits : sequence of float, optional
        Increasing class boundaries inside [0, 1]; required for classification.
    """

    def __init__(self, data_directory, configuration, split_array, start_prediction_months,
                 split_type='train', target_shape=(80, 80), mode="regression", class_splits=None):
        # Set the handles first so close()/__del__ are safe even if opening fails below
        self.inputs_file = None
        self.targets_file = None

        self.data_directory = data_directory
        self.configuration = configuration
        self.split_array = split_array
        self.start_prediction_months = start_prediction_months
        self.split_type = split_type
        self.target_shape = target_shape
        self.class_splits = class_splits
        self.mode = mode

        # Open the HDF5 files
        self.inputs_file = h5py.File(f"{data_directory}/inputs_{configuration}.h5", 'r')

        if "sicanom" in configuration: 
            targets_configuration = "anom_regression" 
        else: 
            targets_configuration = "regression"

        self.targets_file = h5py.File(f"{data_directory}/targets_{targets_configuration}.h5", 'r')
        
        self.inputs = self.inputs_file[f"inputs_{configuration}"]
        self.targets = self.targets_file['targets_sea_ice_only']

        self.n_samples, self.n_channels, self.n_y, self.n_x = self.inputs.shape
        
        # Get indices for the specified split type. np.asarray makes the
        # element-wise comparison work for plain lists too (a list compared to
        # a string is just False, which would silently select no samples).
        self.indices = np.where(np.asarray(self.split_array) == split_type)[0]

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.indices)

    def _pad_widths(self):
        """Return np.pad widths that bring (n_y, n_x) up to target_shape."""
        pad_y = self.target_shape[0] - self.n_y
        pad_x = self.target_shape[1] - self.n_x
        if pad_y < 0 or pad_x < 0:
            raise ValueError(
                f"target_shape {tuple(self.target_shape)} is smaller than the data shape ({self.n_y}, {self.n_x})"
            )
        # An odd amount of padding puts the extra row/column at the end, so the
        # padded size always equals target_shape exactly.
        return ((0, 0), (pad_y // 2, pad_y - pad_y // 2), (pad_x // 2, pad_x - pad_x // 2))

    def _class_bounds(self):
        """Validate class_splits and return the [lower, upper] bounds of each class."""
        if self.class_splits is None or len(self.class_splits) == 0:
            raise ValueError("need to specify a monotonically increasing list class_splits denoting class boundaries")

        # check if class_split is monotonically increasing
        if len(self.class_splits) > 1 and np.any(np.diff(self.class_splits) < 0): 
            raise ValueError("class_splits needs to be monotonically increasing")

        # n splits give n + 1 classes: [0, s0], [s0, s1], ..., [s_last, 1]
        edges = [0] + list(self.class_splits) + [1]
        return list(zip(edges[:-1], edges[1:]))

    def __getitem__(self, idx):
        """
        Return ``(input_tensor, target_tensor, target_months)`` for split-local index ``idx``.

        Both tensors are float32 and zero-padded to ``target_shape``. In
        classification mode the target gains a leading class axis of binary
        masks. ``target_months`` is a numpy array of calendar month numbers.

        Raises
        ------
        ValueError
            If ``target_shape`` is smaller than the data, or if ``class_splits``
            is missing or not increasing in classification mode.
        """
        actual_idx = self.indices[idx]
        input_data = self.inputs[actual_idx]
        target_data = self.targets[actual_idx]
        start_prediction_month = self.start_prediction_months[actual_idx]

        # Pad input_data and target_data to the target shape
        pad_widths = self._pad_widths()
        input_data = np.pad(input_data, pad_widths, mode='constant', constant_values=0)
        target_data = np.pad(target_data, pad_widths, mode='constant', constant_values=0)

        # If we are doing classification, then discretise the target data
        if self.mode == "classification":
            bounds = self._class_bounds()
            last = len(bounds) - 1
            # Classes are half-open [lower, upper), except the last one, which
            # also includes its upper bound.
            masks = [
                np.logical_and(
                    target_data >= lower,
                    target_data <= upper if i == last else target_data < upper,
                )
                for i, (lower, upper) in enumerate(bounds)
            ]
            target_data = np.stack(masks, axis=0).astype(target_data.dtype)

        input_tensor = torch.tensor(input_data, dtype=torch.float32)
        target_tensor = torch.tensor(target_data, dtype=torch.float32)

        # Get the target months for this sample: month starts from the start
        # month through five months later.
        target_months = pd.date_range(start=start_prediction_month, end=start_prediction_month + pd.DateOffset(months=5), freq="MS")
        target_months = target_months.month.to_numpy()
        
        return input_tensor, target_tensor, target_months

    def close(self):
        """Close both HDF5 files; safe to call repeatedly or on a half-built object."""
        for attr in ("inputs_file", "targets_file"):
            handle = getattr(self, attr, None)
            if handle is not None:
                handle.close()
                setattr(self, attr, None)

    def __del__(self):
        self.close()




In [1]:
import xarray as xr
# Open one per-member "temp" file and display its summary.
member_file = "/scratch/users/yucli/cesm_data/temp/temp_member_00.nc"
xr.open_dataset(member_file)