In [2]:
import os
import torch
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

In [3]:
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))

NVIDIA A100 80GB PCIe


In [70]:
from torch.utils.data import Dataset
import torchvision.transforms as T
import warnings

with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)

class E33OMA(Dataset):

    def __init__(self, period, species, padding, root='/home/serfani/serfani_data0/E33OMA'):
        super(E33OMA, self).__init__()
        
        self.period  = period
        self.species = species
        self.padding = padding
        self.root    = root
        
        self._get_data_index()
    
    def _get_data_index(self):
        
        for root, dirs, files in os.walk(self.root):
            
            sorted_files = sorted(files)
            list1 = [os.path.join(root, file) for file in sorted_files if file.split(".")[1] == 'aijlh1E33oma_ai']   # Velocity Fields (time, level, lat, lon)
            # list2 = [os.path.join(root, file) for file in sorted_files if file.split(".")[1] == 'cijh1E33oma_ai']    # Precipitation (time, lat, lon)
            # list3 = [os.path.join(root, file) for file in sorted_files if file.split(".")[1] == 'taijh1E33oma_ai']   # Sea Salt Src (time, lat, lon)
            # list3 = [os.path.join(root, file) for file in sorted_files if file.split(".")[1] == 'tNDaijh1E33oma_ai'] # Clay-BCB Src (time, lat, lon)
            # list5 = [os.path.join(root, file) for file in sorted_files if file.split(".")[1] == 'taijlh1E33oma_ai']  # Aerosols Mixing Ratio (time, level, lat, lon)

        # Convert `cftime.DatetimeNoLeap` to `pandas.to_datetime()`
        datetimeindex = xr.open_mfdataset(list1[:365]).indexes['time'].to_datetimeindex()

        idx = np.arange(len(datetimeindex))
        rng = np.random.default_rng(0)
        rng.shuffle(idx)
        
        if   self.period == 'train':
            self.datetimeindex = datetimeindex[idx[:12264]]
        
        elif self.period == 'val':
            self.datetimeindex = datetimeindex[idx[12264:]]

    def __getitem__(self, index):
        
        timestep = self.datetimeindex[index].strftime('%Y%m%d')
        
        ds1 = xr.open_dataset(os.path.join(self.root, f'{timestep}.aijlh1E33oma_ai.nc'))
        ds1['time'] = ds1.indexes['time'].to_datetimeindex()
        
        ds2 = xr.open_dataset(os.path.join(self.root, f'{timestep}.cijh1E33oma_ai.nc'))
        ds2['time'] = ds2.indexes['time'].to_datetimeindex()

        X1 = np.expand_dims(ds1['u'].isel(level=0).sel(time=self.datetimeindex[index]), axis=0)
        X2 = np.expand_dims(ds1['v'].isel(level=0).sel(time=self.datetimeindex[index]), axis=0)
        X3 = np.expand_dims(ds1['omega'].isel(level=0).sel(time=self.datetimeindex[index]), axis=0)

        X4 = np.expand_dims(ds2['prec'].sel(time=self.datetimeindex[index]), axis=0)

        if self.species == 'seasalt':
            # Add positive lag for target variable
            ds3 = xr.open_dataset(os.path.join(self.root, f'{timestep}.taijh1E33oma_ai.nc'))
            ds3['time'] = ds3.indexes['time'].to_datetimeindex()

            ds4 = xr.open_dataset(os.path.join(self.root, f'{timestep}.taijlh1E33oma_ai.nc'))
            ds4['time'] = ds4.indexes['time'].to_datetimeindex()

            X5 = np.expand_dims(ds3['seasalt1_ocean_src'].sel(time=self.datetimeindex[index]), axis=0)
            y  = np.expand_dims(ds4['seasalt1'].isel(level=0).sel(time=self.datetimeindex[index]), axis=0)
        
        if self.species == 'clay':
            # Add positive lag for target variable
            ds3 = xr.open_dataset(os.path.join(self.root, f'{timestep}.tNDaijh1E33oma_ai.nc'))
            ds3['time'] = ds3.indexes['time'].to_datetimeindex()

            ds4 = xr.open_dataset(os.path.join(self.root, f'{timestep}.taijlh1E33oma_ai.nc'))
            ds4['time'] = ds4.indexes['time'].to_datetimeindex()

            X5 = np.expand_dims(ds3['Clay_emission'].sel(time=self.datetimeindex[index]), axis=0)
            y  = np.expand_dims(ds4['Clay'].isel(level=0).sel(time=self.datetimeindex[index]), axis=0)

        if self.species == 'bcb':
            # Add positive lag for target variable
            ds3 = xr.open_dataset(os.path.join(self.root, f'{timestep}.tNDaijh1E33oma_ai.nc'))
            ds3['time'] = ds3.indexes['time'].to_datetimeindex()

            ds4 = xr.open_dataset(os.path.join(self.root, f'{timestep}.taijlh1E33oma_ai.nc'))
            ds4['time'] = ds4.indexes['time'].to_datetimeindex()

            X5 = np.expand_dims(ds3['BCB_biomass_src'].sel(time=self.datetimeindex[index]), axis=0)
            y  = np.expand_dims(ds4['BCB'].isel(level=0).sel(time=self.datetimeindex[index]), axis=0)


        X = np.concatenate((X1, X2, X3, X4, X5), axis=0)  # (5, 90, 144)
        
        if self.padding:
            w = X.shape[2] # width
            h = X.shape[1] # height
            
            top_pad   = self.padding - h
            right_pad = self.padding - w
            
            X = np.lib.pad(X, ((0, 0), (top_pad, 0), (0, right_pad)), mode='constant', constant_values=0)
            y = np.lib.pad(y, ((0, 0), (top_pad, 0), (0, right_pad)), mode='constant', constant_values=0)
        
        X = torch.from_numpy(X) # torch image: C x H x W
        y = torch.from_numpy(y) # torch image: C x H x W

        return X, y
        
    def __len__(self):
        return len(self.datetimeindex)

In [75]:
dataset = E33OMA(period='val', species='bcb', padding=256)
print(len(dataset))
dataiter = iter(dataset)
X, y = next(dataiter)
print(X.shape, y.shape)

5256
torch.Size([5, 256, 256]) torch.Size([1, 256, 256])


  datetimeindex = xr.open_mfdataset(list1[:365]).indexes['time'].to_datetimeindex()
  ds1['time'] = ds1.indexes['time'].to_datetimeindex()
  ds2['time'] = ds2.indexes['time'].to_datetimeindex()
  ds3['time'] = ds3.indexes['time'].to_datetimeindex()
  ds4['time'] = ds4.indexes['time'].to_datetimeindex()
