In [9]:
import os
import json
import torch
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

In [2]:
# Report which CUDA device (if any) this kernel will use; prints nothing on CPU-only hosts.
cuda_ok = torch.cuda.is_available()
if cuda_ok:
    print(torch.cuda.get_device_name(0))

NVIDIA A100 80GB PCIe


In [3]:
import json

# Per-variable normalization statistics, one row per entry of `variables` below,
# each row laid out as [mean, std, min, max].
data = [
    [0.21191783249378204, 6.515537738800049, -30.78557586669922, 40.70701217651367],
    [0.34416693449020386, 5.294043064117432, -33.579959869384766, 32.54892349243164],
    [8.225244982895674e-07, 6.251675222301856e-05, -0.0008085378794930875, 0.000658102217130363],
    [2.1786141395568848, 7.301267623901367, 0.0, 475.8797607421875],
    [18.24185562133789, 48.0205078125, 0.0, 2924.923095703125],
    [27.4210262298584, 1093.1732177734375, 0.0, 529028.125],
    [0.19962824881076813, 2.6003715991973877, 0.0, 321.9514465332031],
    [1.571555495262146, 2.791754722595215, 0.0, 134.43466186523438],
    [0.053363505750894547, 0.33962109684944153, 0.0, 39.45267105102539],
    [4.951100826263428, 57.252777099609375, 0.0, 9472.2021484375],
]

variables = ['u', 'v', 'w', 'prec', 'ss_src', 'c_src', 'bc_src', 'ss_conc', 'c_conc', 'bc_conc']
assert len(variables) == len(data), "every variable needs exactly one statistics row"

result = {}
# Tuple-unpacking each row keeps the original strictness: a row that is not
# exactly [mean, std, min, max] raises instead of being silently truncated.
for variable, (mean, std, min_val, max_val) in zip(variables, data):
    result[variable] = {
        'mean': mean,
        'std': std,
        'min': min_val,
        'max': max_val
    }

# Overwrite ('w'), never append: appending a second JSON document on a re-run
# makes the file unparseable for `json.load` ("Extra data").
with open('variable_statistics.json', 'w') as jf:
    json.dump(result, jf, indent=4)


In [5]:
import warnings
# warnings.filterwarnings("ignore", message="Converting a CFTimeIndex with dates from a non-standard calendar")
# Velocity-field files (time, level, lat, lon) are named '<YYYYMMDD>.aijlh1E33oma_ai.nc'
# and are opened directly from the data root elsewhere in this notebook.
DATA_ROOT = '/home/serfani/serfani_data0/E33OMA'
N_DAYS = 365  # number of (sorted) daily files whose timestamps form the index

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message="Converting a CFTimeIndex with dates from a non-standard calendar")
    # Only the top level of DATA_ROOT is listed, so sub-directories cannot replace the
    # result. `split(".")[1:2]` also tolerates file names that contain no dot.
    list1 = sorted(
        os.path.join(DATA_ROOT, name)
        for name in os.listdir(DATA_ROOT)
        if name.split(".")[1:2] == ['aijlh1E33oma_ai']
        and os.path.isfile(os.path.join(DATA_ROOT, name))
    )
    # Convert the `cftime` time axis to a pandas DatetimeIndex; the context manager
    # closes the underlying files once the index has been extracted.
    with xr.open_mfdataset(list1[:N_DAYS]) as ds_velocity:
        datetimeindex = ds_velocity.indexes['time'].to_datetimeindex()

In [6]:
from torch.utils.data import Dataset
import torchvision.transforms as T
import warnings


class E33OMA(Dataset):
    """Map-style PyTorch dataset of E33OMA level-0 fields at single time steps.

    Each sample is ``(X, y)``:

    * ``X`` -- 5 normalized channels: level-0 ``u``, ``v`` and ``omega`` (file stream
      ``aijlh1E33oma_ai``), ``prec`` (``cijh1E33oma_ai``) and the species' source term.
    * ``y`` -- 1 normalized channel: the species' level-0 concentration
      (``taijlh1E33oma_ai``).

    Normalization uses the mean/std stored in ``variable_statistics.json`` (path relative
    to the current working directory). Each ``__getitem__`` call also stores the target
    statistics on ``self.y_mean`` / ``self.y_std`` (shape ``(1, 1, 1)``), presumably so
    callers can de-normalize predictions -- confirm with the training code.

    Parameters
    ----------
    period : {'train', 'val'}
        Which side of a fixed, seeded shuffle of the time steps to expose.
    species : {'seasalt', 'clay', 'bcb'}
        Aerosol species providing the source channel and the target.
    padding : int or None
        If truthy, zero-pad ``X`` on the top and right to ``(C, padding, padding)``.
    root : str
        Directory holding the ``<YYYYMMDD>.<stream>.nc`` files.
    pad_target : bool, default False
        Also pad ``y`` like ``X``. NOTE(review): the recorded output of the usage cell
        shows ``y`` as ``[1, 256, 256]`` although ``y`` is not padded by default --
        confirm which shape the model expects.

    Raises
    ------
    ValueError
        If ``period`` or ``species`` is not one of the supported values.
    """

    # species -> (source-file stream, source variable, concentration variable,
    #             statistics key of the source, statistics key of the concentration)
    _SPECIES = {
        'seasalt': ('taijh1E33oma_ai',   'seasalt1_ocean_src', 'seasalt1', 'ss_src', 'ss_conc'),
        'clay':    ('tNDaijh1E33oma_ai', 'Clay_emission',      'Clay',     'c_src',  'c_conc'),
        'bcb':     ('tNDaijh1E33oma_ai', 'BCB_biomass_src',    'BCB',      'bc_src', 'bc_conc'),
    }
    _PERIODS = ('train', 'val')
    _N_DAYS = 365       # number of (sorted) daily velocity files that define the time index
    _N_TRAIN = 12264    # time steps in the training split; the remainder is validation
    _STATS_FILE = 'variable_statistics.json'

    def __init__(self, period, species, padding, root='/home/serfani/serfani_data0/E33OMA',
                 pad_target=False):
        super().__init__()

        # Validate up front, before the comparatively slow directory scan and file opening.
        if period not in self._PERIODS:
            raise ValueError(f"period must be one of {self._PERIODS}, got {period!r}")
        if species not in self._SPECIES:
            raise ValueError(f"species must be one of {tuple(self._SPECIES)}, got {species!r}")

        self.period     = period
        self.species    = species
        self.padding    = padding
        self.root       = root
        self.pad_target = pad_target
        self._stats     = None  # normalization statistics, loaded lazily on first sample

        self._get_data_index()

    def _list_stream_files(self, stream):
        """Return the sorted paths of regular files directly in ``self.root`` whose
        second dot-separated name component equals ``stream``."""
        # `split('.')[1:2]` is empty (not an IndexError) for names without a dot.
        return sorted(
            os.path.join(self.root, name)
            for name in os.listdir(self.root)
            if name.split('.')[1:2] == [stream]
            and os.path.isfile(os.path.join(self.root, name))
        )

    def _get_data_index(self):
        """Set ``self.datetimeindex`` to the time steps of the requested split.

        The timestamps of the first ``_N_DAYS`` velocity files are converted to a
        ``pandas.DatetimeIndex``, shuffled with a fixed seed (0), and split into the
        first ``_N_TRAIN`` entries (train) and the rest (val).
        """
        # Velocity fields (time, level, lat, lon)
        velocity_files = self._list_stream_files('aijlh1E33oma_ai')

        # Convert `cftime.DatetimeNoLeap` to a pandas DatetimeIndex; close the files afterwards.
        with xr.open_mfdataset(velocity_files[:self._N_DAYS]) as ds:
            datetimeindex = ds.indexes['time'].to_datetimeindex()

        idx = np.arange(len(datetimeindex))
        np.random.default_rng(0).shuffle(idx)

        if self.period == 'train':
            self.datetimeindex = datetimeindex[idx[:self._N_TRAIN]]
        else:  # 'val' -- `period` was validated in __init__
            self.datetimeindex = datetimeindex[idx[self._N_TRAIN:]]

    def _load_statistics(self):
        """Return the normalization statistics dict, reading ``_STATS_FILE`` only once."""
        if getattr(self, '_stats', None) is None:
            with open(self._STATS_FILE, 'r') as jf:
                self._stats = json.load(jf)
        return self._stats

    def _read_fields(self, stream, timestep, time, names, surface):
        """Read variables ``names`` at ``time`` from ``<root>/<timestep>.<stream>.nc``.

        Returns a ``(len(names), H, W)`` array; when ``surface`` is true the first
        vertical level (``level=0``) is selected. The file is closed before returning.
        """
        path = os.path.join(self.root, f'{timestep}.{stream}.nc')
        with xr.open_dataset(path) as ds:
            ds['time'] = ds.indexes['time'].to_datetimeindex()
            fields = []
            for name in names:
                field = ds[name]
                if surface:
                    field = field.isel(level=0)
                # np.asarray loads the data while the file is still open.
                fields.append(np.asarray(field.sel(time=time)))
        return np.stack(fields, axis=0)

    def _pad(self, arr):
        """Zero-pad a ``(C, H, W)`` array on the top and right to ``(C, padding, padding)``."""
        _, h, w = arr.shape
        top_pad, right_pad = self.padding - h, self.padding - w
        if top_pad < 0 or right_pad < 0:
            raise ValueError(f"padding={self.padding} is smaller than the field size {h}x{w}")
        # `np.pad` rather than `np.lib.pad`, which NumPy 2 no longer exposes.
        return np.pad(arr, ((0, 0), (top_pad, 0), (0, right_pad)), mode='constant', constant_values=0)

    def __getitem__(self, index):
        time = self.datetimeindex[index]
        timestep = time.strftime('%Y%m%d')
        stats = self._load_statistics()
        src_stream, src_var, conc_var, src_key, conc_key = self._SPECIES[self.species]

        X = np.concatenate((
            self._read_fields('aijlh1E33oma_ai', timestep, time, ('u', 'v', 'omega'), surface=True),
            self._read_fields('cijh1E33oma_ai', timestep, time, ('prec',), surface=False),
            self._read_fields(src_stream, timestep, time, (src_var,), surface=False),
        ), axis=0)  # (5, 90, 144)
        # TODO(review): original note "Add positive lag for target variable" -- the target
        # is currently read at the same time step as the inputs.
        y = self._read_fields('taijlh1E33oma_ai', timestep, time, (conc_var,), surface=True)

        # `omega` is normalized with the statistics stored under the key 'w'.
        x_keys = ('u', 'v', 'w', 'prec', src_key)
        Xs_mean = np.array([stats[key]['mean'] for key in x_keys]).reshape(-1, 1, 1)
        Xs_std  = np.array([stats[key]['std'] for key in x_keys]).reshape(-1, 1, 1)

        self.y_mean = np.array(stats[conc_key]['mean']).reshape(-1, 1, 1)
        self.y_std  = np.array(stats[conc_key]['std']).reshape(-1, 1, 1)

        X = (X - Xs_mean) / Xs_std
        y = (y - self.y_mean) / self.y_std

        if self.padding:
            X = self._pad(X)
            if self.pad_target:
                y = self._pad(y)

        X = torch.from_numpy(X).type(torch.float32)  # torch image: C x H x W
        y = torch.from_numpy(y).type(torch.float32)  # torch image: C x H x W

        return X, y

    def __len__(self):
        return len(self.datetimeindex)

In [12]:
with warnings.catch_warnings():
    # Silence the notice emitted when the non-standard (cftime) calendar is converted.
    warnings.filterwarnings("ignore", message="Converting a CFTimeIndex with dates from a non-standard calendar")

    dataset = E33OMA(period='val', species='bcb', padding=256)
    print(len(dataset))

    # Map-style dataset: fetch the first sample by index rather than via iter()/next().
    X, y = dataset[0]
    print(X.shape, y.shape)

5256
torch.Size([5, 256, 256]) torch.Size([1, 256, 256])
