In [1]:
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
import time
import os
import pyproj
import sys

DATA_DIRECTORY = '/oak/stanford/groups/earlew/yuchen'

RAW_DATA_DIRECTORY = '/scratch/users/yucli/cesm_data'

# Renamed variable names 
VAR_NAMES = ["icefrac", "temp", "geopotential", "icethick", "lw_flux", "sw_flux", "ua", "va"]


In [7]:
def generate_sps_grid(grid_size=80, lat_boundary=-52.5):
    # Define the South Polar Stereographic projection (EPSG:3031)
    proj_south_pole = pyproj.Proj(proj='stere', lat_0=-90, lon_0=0, lat_ts=-70)

    # Define the geographic coordinate system (EPSG:4326)
    proj_geographic = pyproj.Proj(proj='latlong', datum='WGS84')

    # Compute the maximum radius from the South Pole in stereographic coordinates
    _, max_radius = proj_south_pole(0, lat_boundary)

    x = np.linspace(-max_radius, max_radius, grid_size)
    y = np.linspace(-max_radius, max_radius, grid_size)
    X, Y = np.meshgrid(x, y)

    lon, lat = pyproj.transform(proj_south_pole, proj_geographic, X, Y)

    output_grid = xr.Dataset(
        {
            "lat": (["y", "x"], lat),
            "lon": (["y", "x"], lon),
        },
        coords={
            "x": (["x"], x),
            "y": (["y"], y),
        }
    )

    return output_grid


def regrid_variable(ds_to_regrid, input_grid, output_grid):
    start_time = time.time()
    weight_file = f'{DATA_DIRECTORY}/cesm_lens/grids/cesm_{input_grid}_to_sps_bilinear_regridding_weights.nc'

    if os.path.exists(weight_file):
        regridder = xe.Regridder(ds_to_regrid, output_grid, 'bilinear', weights=weight_file, 
                                ignore_degenerate=True, reuse_weights=True, periodic=True)
    else:
        regridder = xe.Regridder(ds_to_regrid, output_grid, 'bilinear', filename=weight_file, 
                                ignore_degenerate=True, reuse_weights=False, periodic=True)

    ds_regridded = regridder(ds_to_regrid).load()
    end_time = time.time()
    elapsed_time = end_time - start_time

    print(f"done! (Time taken: {elapsed_time:.2f} seconds)")
    
    return ds_regridded


In [34]:

ds1 = xr.open_dataset("/scratch/users/yucli/cesm_data/normalized_inputs/icefrac_anom.nc")
#ds2 = xr.open_dataset("/scratch/users/yucli/cesm_temp_hist_regridded/temp_member_r1011i1p1f2.nc")
ds1.member_id.values

array(['r10i1181p1f1', 'r10i1231p1f1', 'r10i1251p1f1', 'r10i1281p1f1',
       'r10i1301p1f1', 'r1i1001p1f1', 'r1i1231p1f1', 'r1i1251p1f1',
       'r1i1281p1f1', 'r1i1301p1f1', 'r2i1021p1f1', 'r2i1231p1f1',
       'r2i1251p1f1', 'r2i1281p1f1', 'r2i1301p1f1', 'r3i1041p1f1',
       'r3i1231p1f1', 'r3i1251p1f1', 'r3i1281p1f1', 'r3i1301p1f1',
       'r4i1061p1f1', 'r4i1231p1f1', 'r4i1251p1f1', 'r4i1281p1f1',
       'r4i1301p1f1', 'r5i1081p1f1', 'r5i1231p1f1', 'r5i1251p1f1',
       'r5i1281p1f1', 'r5i1301p1f1', 'r6i1101p1f1', 'r6i1231p1f1',
       'r6i1251p1f1', 'r6i1281p1f1', 'r6i1301p1f1', 'r7i1121p1f1',
       'r7i1231p1f1', 'r7i1251p1f1', 'r7i1281p1f1', 'r7i1301p1f1',
       'r8i1141p1f1', 'r8i1231p1f1', 'r8i1251p1f1', 'r8i1281p1f1',
       'r8i1301p1f1', 'r9i1161p1f1', 'r9i1231p1f1', 'r9i1251p1f1',
       'r9i1281p1f1', 'r9i1301p1f1', 'r10i1191p1f2', 'r11i1231p1f2',
       'r11i1251p1f2', 'r11i1281p1f2', 'r11i1301p1f2', 'r12i1231p1f2',
       'r12i1251p1f2', 'r12i1281p1f2', 'r12i1301p1f

In [15]:
files = sorted(os.listdir("/scratch/users/yucli/cesm_temp_raw"))

# remove all non netcdf files. I'm not sure why the bash script doesn't get removed 
# by the for loop..?
files.remove('wget-ucar.cgd.cesm2le.ocn.proc.monthly_ave.TEMP.AllFiles-20241128T1755.sh')
for f in files: 
    if f[-3:] != ".nc": files.remove(f)

# get the member id tag from the filename 
# the format is ####-### (init year - realization)
member_ids = [] 

for f in files:
    member_id = f.split("-")[1][0:8]
    member_ids.append(member_id)

member_ids = np.unique(member_ids)
member_ids

array(['1011.001', '1031.002', '1051.003', '1071.004', '1091.005',
       '1111.006', '1131.007', '1151.008', '1171.009', '1191.010',
       '1231.011', '1231.012', '1231.013', '1231.014', '1231.015',
       '1231.016', '1231.017', '1231.018', '1231.019', '1231.020',
       '1251.011', '1251.012', '1251.013', '1251.014', '1251.015',
       '1251.016', '1251.017', '1251.018', '1251.019', '1251.020',
       '1281.011', '1281.012', '1281.013', '1281.014', '1281.015',
       '1281.016', '1281.017', '1281.018', '1281.019', '1281.020',
       '1301.011', '1301.012', '1301.013', '1301.014', '1301.015',
       '1301.016', '1301.017', '1301.018', '1301.019', '1301.020'],
      dtype='<U8')

In [5]:
def get_member_ids(dir="/scratch/users/yucli/cesm_temp_raw"):
    """
    Generates a list of unique member_ids for downloaded raw thetao files

    Returns: 
        (list)  list of netCDF files in dir 
        (list)  items of format ####-### (init year-realization number)
    """
    
    files = sorted(os.listdir(dir))

    # remove all non netcdf files. I'm not sure why the bash script doesn't get removed 
    # by the for loop..?
    files.remove('wget-ucar.cgd.cesm2le.ocn.proc.monthly_ave.TEMP.AllFiles-20241128T1755.sh')
    for f in files: 
        if f[-3:] != ".nc": files.remove(f)

    # get the member id tag from the filename 
    # the format is ####-### (init year - realization)
    member_ids = [] 

    for f in files:
        member_id = f.split("-")[1][0:8]
        member_ids.append(member_id)

    member_ids = np.unique(member_ids)

    return files, member_ids


DATA_DIRECTORY = '/oak/stanford/groups/earlew/yuchen'

CESM_OCEAN_GRID = xr.open_dataset(f"{DATA_DIRECTORY}/cesm_lens/grids/ocean_grid.nc")


In [11]:
files, member_ids = get_member_ids()

output_grid = generate_sps_grid()

# get each member separately 
for i,member_id in enumerate(member_ids):
    save_path = f"/scratch/users/yucli/cesm_temp_hist_regridded/temp_member_{member_id}.nc"
    if os.path.exists(save_path):
        print(f"already found existing {save_path}, skipping")
        continue

    files_subset = []
    for f in files:
        if member_id in f: files_subset.append(os.path.join("/scratch/users/yucli/cesm_temp_raw", f))
    
    ds = xr.open_mfdataset(files_subset)

    subset = ds.TEMP.isel(z_t=0, nlat=slice(0, 93))

    lat = CESM_OCEAN_GRID.lat.sel(nlat=slice(0, 93)).data
    lon = CESM_OCEAN_GRID.lon.sel(nlat=slice(0, 93)).data

    subset = subset.assign_coords(lat=(["nlat", "nlon"], lat), lon=(["nlat", "nlon"], lon))
    subset = subset.to_dataset(name="temp")

    regridded_subset = regrid_variable(subset, "ocn", output_grid)
    regridded_subset = regridded_subset.assign_attrs(member_id=member_id)

    # save 
    regridded_subset.to_netcdf(save_path)



  lon, lat = pyproj.transform(proj_south_pole, proj_geographic, X, Y)


already found existing /scratch/users/yucli/cesm_temp_hist_regridded/temp_member_1011.001.nc, skipping
done! (Time taken: 90.50 seconds)
done! (Time taken: 88.16 seconds)
done! (Time taken: 92.06 seconds)
done! (Time taken: 86.26 seconds)
done! (Time taken: 83.60 seconds)
done! (Time taken: 88.94 seconds)
done! (Time taken: 87.20 seconds)
done! (Time taken: 88.07 seconds)
done! (Time taken: 90.35 seconds)
done! (Time taken: 92.69 seconds)


KeyboardInterrupt: 

In [35]:
xr.open_dataset(f"")

Unnamed: 0,Array,Chunk
Bytes,232.50 kiB,232.50 kiB
Shape,"(93, 320)","(93, 320)"
Dask graph,1 chunks in 81 graph layers,1 chunks in 81 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 232.50 kiB 232.50 kiB Shape (93, 320) (93, 320) Dask graph 1 chunks in 81 graph layers Data type float64 numpy.ndarray",320  93,

Unnamed: 0,Array,Chunk
Bytes,232.50 kiB,232.50 kiB
Shape,"(93, 320)","(93, 320)"
Dask graph,1 chunks in 81 graph layers,1 chunks in 81 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,232.50 kiB,232.50 kiB
Shape,"(93, 320)","(93, 320)"
Dask graph,1 chunks in 81 graph layers,1 chunks in 81 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 232.50 kiB 232.50 kiB Shape (93, 320) (93, 320) Dask graph 1 chunks in 81 graph layers Data type float64 numpy.ndarray",320  93,

Unnamed: 0,Array,Chunk
Bytes,232.50 kiB,232.50 kiB
Shape,"(93, 320)","(93, 320)"
Dask graph,1 chunks in 81 graph layers,1 chunks in 81 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,232.50 kiB,232.50 kiB
Shape,"(93, 320)","(93, 320)"
Dask graph,1 chunks in 81 graph layers,1 chunks in 81 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 232.50 kiB 232.50 kiB Shape (93, 320) (93, 320) Dask graph 1 chunks in 81 graph layers Data type float64 numpy.ndarray",320  93,

Unnamed: 0,Array,Chunk
Bytes,232.50 kiB,232.50 kiB
Shape,"(93, 320)","(93, 320)"
Dask graph,1 chunks in 81 graph layers,1 chunks in 81 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,232.50 kiB,232.50 kiB
Shape,"(93, 320)","(93, 320)"
Dask graph,1 chunks in 81 graph layers,1 chunks in 81 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 232.50 kiB 232.50 kiB Shape (93, 320) (93, 320) Dask graph 1 chunks in 81 graph layers Data type float64 numpy.ndarray",320  93,

Unnamed: 0,Array,Chunk
Bytes,232.50 kiB,232.50 kiB
Shape,"(93, 320)","(93, 320)"
Dask graph,1 chunks in 81 graph layers,1 chunks in 81 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,224.78 MiB,58.12 kiB
Shape,"(1980, 93, 320)","(1, 93, 160)"
Dask graph,3960 chunks in 36 graph layers,3960 chunks in 36 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 224.78 MiB 58.12 kiB Shape (1980, 93, 320) (1, 93, 160) Dask graph 3960 chunks in 36 graph layers Data type float32 numpy.ndarray",320  93  1980,

Unnamed: 0,Array,Chunk
Bytes,224.78 MiB,58.12 kiB
Shape,"(1980, 93, 320)","(1, 93, 160)"
Dask graph,3960 chunks in 36 graph layers,3960 chunks in 36 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [6]:
class SeaIceDataset(torch.utils.data.Dataset):
    def __init__(self, data_directory, configuration, split_array, start_prediction_months, \
                split_type='train', target_shape=(80, 80), mode="regression", class_splits=None):
        self.data_directory = data_directory
        self.configuration = configuration
        self.split_array = split_array
        self.start_prediction_months = start_prediction_months
        self.split_type = split_type
        self.target_shape = target_shape
        self.class_splits = class_splits
        self.mode = mode

        # Open the HDF5 files
        self.inputs_file = h5py.File(f"{data_directory}/inputs_{configuration}.h5", 'r')

        if "sicanom" in configuration: 
            targets_configuration = "anom_regression" 
        else: 
            targets_configuration = "regression"

        self.targets_file = h5py.File(f"{data_directory}/targets_{targets_configuration}.h5", 'r')
        
        self.inputs = self.inputs_file[f"inputs_{configuration}"]
        self.targets = self.targets_file['targets_sea_ice_only']

        self.n_samples, self.n_channels, self.n_y, self.n_x = self.inputs.shape
        
        # Get indices for the specified split type
        self.indices = np.where(self.split_array == split_type)[0]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        actual_idx = self.indices[idx]
        input_data = self.inputs[actual_idx]
        target_data = self.targets[actual_idx]
        start_prediction_month = self.start_prediction_months[actual_idx]

        # Pad input_data and target_data to the target shape
        pad_y = self.target_shape[0] - self.n_y
        pad_x = self.target_shape[1] - self.n_x
        input_data = np.pad(input_data, ((0, 0), (pad_y//2, pad_y//2), (pad_x//2, pad_x//2)), mode='constant', constant_values=0)
        target_data = np.pad(target_data, ((0, 0), (pad_y//2, pad_y//2), (pad_x//2, pad_x//2)), mode='constant', constant_values=0)

        # If we are doing classification, then discretise the target data
        if self.mode == "classification":
            if self.class_splits is None:
                raise ValueError("need to specify a monotonically increasing list class_splits denoting class boundaries")

            # check if class_split is monotonically increasing
            if len(self.class_splits) > 1 and np.any(np.diff(self.class_splits) < 0): 
                raise ValueError("class_splits needs to be monotonically increasing")

            bounds = [] # bounds for classes
            for i,class_split in enumerate(self.class_splits): 
                if i == 0: 
                    bounds.append([0, class_split])
                if i == len(self.class_splits) - 1: 
                    bounds.append([class_split, 1])
                else: 
                    bounds.append([class_split, self.class_splits[i+1]])
            
            target_classes_data = np.zeros_like(target_data) 
            target_classes_data = target_classes_data[np.newaxis,:,:,:]
            target_classes_data = np.repeat(target_classes_data, len(bounds), axis=0)
            for i,bound in enumerate(bounds): 
                if i == len(bounds) - 1: 
                    target_classes_data[i,:,:,:] = np.logical_and(target_data >= bound[0], target_data <= bound[1]).astype(int)
                else:
                    target_classes_data[i,:,:,:] = np.logical_and(target_data >= bound[0], target_data < bound[1]).astype(int)
            
            target_data = target_classes_data 

        input_tensor = torch.tensor(input_data, dtype=torch.float32)
        target_tensor = torch.tensor(target_data, dtype=torch.float32)

        # Get the target months for this sample
        target_months = pd.date_range(start=start_prediction_month, end=start_prediction_month + pd.DateOffset(months=5), freq="MS")
        target_months = target_months.month.to_numpy()
        
        return input_tensor, target_tensor, target_months

    def __del__(self):
        self.inputs_file.close()
        self.targets_file.close()




In [1]:
import xarray as xr
xr.open_dataset("/scratch/users/yucli/cesm_data/temp/temp_member_00.nc")