In [11]:
from pathlib import Path
import warnings
warnings.filterwarnings("ignore")

import datetime
import xarray as xr
import os
import numpy as np
import xesmf as xe
import yaml


import torch
from torch.utils.data import DataLoader
import lightning.pytorch as pl

In [12]:
def adapt_inactive_phase(ds, threshold=1):
    """Relabel weak-amplitude days as phase 0 ("inactive").

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain ``phase_number`` and ``amplitude`` variables defined on
        the same coordinates.
    threshold : float, default 1
        Entries with ``amplitude <= threshold`` get ``phase_number = 0``.
        NaN amplitudes compare False and therefore keep their original phase.

    Returns
    -------
    xarray.Dataset
        A new dataset with ``phase_number`` replaced; ``ds`` is not modified.
    """
    inactive = ds["amplitude"] <= threshold
    # `where` keeps values where the condition is True and substitutes 0 elsewhere.
    # Negating `<=` (instead of using `>`) keeps NaN-amplitude entries unchanged.
    phase = ds["phase_number"].where(~inactive, 0)
    return ds.assign(phase_number=phase)

In [None]:
# Experiment configuration: which index to build and where the data lives.
vrbl = 'MJO'  # 'MJO' (RMM-based index) or 'pv' (u at 10 hPa, polar-vortex index)
m_cat = '_9cat'  # '_9cat': weak-amplitude MJO days are relabelled as phase 0 (adapt_inactive_phase)

data_name = 'ERA5'

dataset_dir = Path(f'/mnt/beegfs/home/bommer1/WiOSTNN/Version1/data/{data_name}/datasets')
index_path = '/mnt/beegfs/home/bommer1/WiOSTNN/Data/Index'

resolution = '1.40625'  # in degrees; '5.625' is also available
# Both years must be integers: later cells build np.arange(start_year, end_year + 1).
if 'ERA5' in data_name:
    start_year, end_year = 1980, 2023
else:
    start_year, end_year = 1836, 1980

reduce_to_single_dim = False  # if True, the dataset will be reduced to a single dimension (time) by aggregating over the other dimensions
months_to_keep = [11, 12, 1, 2, 3]  # November to March
start_season, end_season = '11-15', '03-31'  # MM-DD bounds of each winter season

if vrbl == 'MJO':
    var_name = 'SST'  # only used as an optional Dataset variable name in ClimateIndicies
    region = 'mjo'
    pressure_level = ''  # the MJO index has no pressure level
    days = 7  # not read elsewhere in this notebook; mirrors the 7-step stride in ClimateIndicies
    var_index = 'phase_number'  # alternative: 'rmm2'
elif vrbl == 'pv':
    var_name = 'u'
    pressure_level = 10  # in hPa
    region = 'spv'
    # Later cells name saved index files after var_index; for pv use the variable key.
    var_index = vrbl
    file_name = dataset_dir / Path(f'{var_name}_{pressure_level}_{resolution}deg_{start_year}-{end_year}_{region}_{"1d" if reduce_to_single_dim else "2d"}.nc')
else:
    raise ValueError(f"Unknown vrbl {vrbl!r}; expected 'MJO' or 'pv'.")

region_coords = {
    'spv': {'lat': slice(61, 59)},  # u10 for Polar Vortex
    'mjo': None,  # index for MJO
    }


In [14]:
# Load the raw index for the selected variable, then keep only the winter window.
if vrbl == 'MJO':
    # NOAA RMM-based MJO index file for the configured period.
    index_name = index_path / Path(f'mjo_rmm_index_noaa_{start_year}-{start_season}_{end_year}-{end_season}.nc')
    ds = xr.open_dataset(index_name)
else:
    file_name = '/mnt/beegfs/home/bommer1/WiOSTNN/Version1/data/ERA5/datasets/dai/u_10_1.40625deg_1980-2023_spv_2d.nc'
    ds = xr.open_dataset(file_name).sel(**region_coords[region])

# Identical time/month filtering for both variables.
season_window = slice(np.datetime64(f'{start_year}-{start_season}'),
                      np.datetime64(f'{end_year}-{end_season}'))
ds = ds.sel(time=season_window)
ds = ds.sel(time=ds['time.month'].isin(months_to_keep))

if vrbl == 'MJO':
    if m_cat == '_9cat':
        ds = adapt_inactive_phase(ds)
else:
    # Average over longitude within the selected latitude band and store the result.
    ds = ds.mean(dim='lon')
    ds.__xarray_dataarray_variable__.to_netcdf(f"{dataset_dir}/{vrbl}_index_{start_year}-{end_year}_{data_name}_s.nc")

In [15]:
# Assign season coordinates: split the record into winters running from
# start_season of one year to end_season of the next, resample each winter to
# a complete daily calendar and tag every day with a 1-based season number.
# Season bounds come from the config cell.
if vrbl == 'MJO':
    season_chunks = []
    for season_number, year in enumerate(range(start_year, end_year), start=1):
        winter = ds.sel(time=slice(np.datetime64(f'{year}-{start_season}'),
                                   np.datetime64(f'{year + 1}-{end_season}')))
        # Daily resampling inserts calendar days absent from the source as NaN.
        winter = winter.resample(time='1D').mean('time')
        season_coord = np.full(winter.sizes['time'], season_number)
        season_chunks.append(winter.assign_coords(season=('time', season_coord)))

    if not season_chunks:
        raise ValueError(f'No season between {start_year} and {end_year}; check start_year/end_year.')

    # Concatenate once instead of growing the dataset inside the loop (quadratic copying).
    ds = xr.concat(season_chunks, dim='time')[var_index]
    del season_chunks
    print('Season assignment done.')

    # MJO phase numbers carry the category suffix; other MJO indices are named after var_index.
    if var_index == 'phase_number':
        out_stem = f'{vrbl}{m_cat}'
    else:
        out_stem = var_index
    ds.to_netcdf(f"{dataset_dir}/{out_stem}_index_{start_year}-{end_year}_{data_name}_s.nc")

Season assignment done.


In [16]:
# Count how often each label value (including NaN) occurs in the saved index.
np.unique(ds.values, return_counts=True)

(array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8., nan]),
 array([2335,  328,  419,  514,  425,  361,  442,  529,  493,   55]))

In [17]:
# if vrbl == 'MJO':
#     start_season, end_season = '11-15', '03-30'
#     yrs = np.arange(ds.isel(time=0).time.dt.year.values, ds.isel(time=-1).time.dt.year.values+1)
#     for i in range(len(yrs)-1):
#         year = yrs[i]
#         ds_sel = ds.sel({'time':slice(np.datetime64(f'{year}-{start_season}'),np.datetime64(f'{year+1}-{end_season}'))})
#         if i > 0:
#             dst = xr.concat((dst,ds_sel), dim ='time')
#         if i ==0:
#             dst = ds_sel
#             del ds_sel

In [18]:
# Load the NAE regime labels and one-hot encode them along a new category axis.
nae = xr.open_dataarray("/mnt/beegfs/home/bommer1/WiOSTNN/Version1/data/ERA5/datasets/z_500_1.40625deg_1980-2023_northern_hemi_2d_NAEregimes.nc")
# Casting NaN labels to int yields meaningless category numbers, so fail early.
assert bool(nae.notnull().all()), 'NAE regime labels contain missing values'
categories_nae = nae.values.astype(int)
n_nae_cat = int(categories_nae.max()) + 1
ds_norm = nae.expand_dims(dim={'nae_regime_cat': n_nae_cat}).transpose('time', 'nae_regime_cat')
# Row k of the identity matrix is the one-hot vector for category k; keep the
# result bound instead of discarding it after display.
nae_onehot = ds_norm.copy(data=np.eye(n_nae_cat)[categories_nae])
nae_onehot

In [19]:
# Both label series should cover the same number of days.
categories = ds.values
for label, label_values in ((vrbl, categories), ("nae", categories_nae)):
    print(f"{label}: {label_values.shape}")
# np.eye(max(categories)+1)[categories]

MJO: (5901,)
nae: (5901,)


In [20]:
# ds.shape
# # categories = ds.values.astype(int)
# # categories = (categories- np.ones(categories.shape)).astype(int)
# # ds.values = categories
# ds_norm = ds.expand_dims(dim={'mjo_cat': int(np.max(categories)+1)}).transpose('time', 'mjo_cat')
# categories = ds_norm.values.astype(int)
# ds_norm.copy(data=np.eye(max(categories)+1)[categories])

In [21]:
# Reload the saved index file to check that it round-trips. The filename follows
# the naming used when it was written: MJO phase numbers carry the m_cat suffix,
# other MJO indices are named after var_index, and pv is named after vrbl.
if vrbl != 'MJO':
    saved_stem = vrbl
elif var_index == 'phase_number':
    saved_stem = f"{vrbl}{m_cat}"
else:
    saved_stem = var_index
saved_index = xr.load_dataarray(f"{dataset_dir}/{saved_stem}_index_{start_year}-{end_year}_{data_name}_s.nc")
saved_index


## Format the index data as a Dataset

In [22]:
class ClimateIndicies(pl.LightningDataModule):
    """Map-style dataset yielding strided input/target windows from an index time series.

    Windows never cross a season boundary: ``ds`` must carry a ``season``
    coordinate along ``time``, and every sample is drawn from a single season.
    For a sample starting at position ``t`` within its season, inputs are taken
    at ``t + i*step`` (``i < n_steps_in``) and targets at
    ``t + (n_steps_in + n_steps_lag + i)*step`` (``i < n_steps_out``).

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Index data with a ``time`` dimension and a ``season`` coordinate.
    data_info : dict
        Must contain a ``'config'`` mapping with ``n_steps_lag``, ``n_steps_in``
        and ``n_steps_out``. ``stack_maps``, ``regime_path``, ``data_path`` and
        ``strt`` are stored but not used by this class.
    seasons : sequence, optional
        Season labels to keep. When None, every season present in ``ds`` is used.
    return_dates : bool, default False
        If True, items are ``(inputs, output, day_of_year, timestamps)`` for the
        target steps; otherwise ``(inputs, output)``.
    var_name : str, optional
        Variable to select when ``ds`` is a Dataset. When None, a module-level
        ``var_name`` global is used if one exists.
    step : int, default 7
        Stride, in time steps, between consecutive window entries.
    """

    def __init__(self, ds, data_info, seasons=None, return_dates=False, var_name=None, step=7):
        super().__init__()
        config = data_info['config']
        self.lag = config.get('n_steps_lag')
        self.n_in = config.get('n_steps_in')
        self.n_out = config.get('n_steps_out')
        self.stack_maps = config.get('stack_maps')
        self.regime_path = config.get('regime_path', '')
        self.data_path = config.get('data_path', '')
        self.strt = config.get('strt', '1950')
        self.return_dates = return_dates
        self.seasons = seasons
        self.step = step

        self.inputs = []
        self.time_steps = set()
        self.output = None

        if self.seasons is not None:
            ds = ds.sel(time=ds.season.isin(self.seasons))
        else:
            self.seasons = np.unique(ds.season.values)

        ds = self._select_variable(ds, var_name).squeeze()

        self.time_steps.add(len(ds.time))
        self.inputs.append(ds)
        self.output = ds

        # Model-facing shapes: one (n_steps, 1) entry per input variable.
        self.shapes = {
            'input': [(self.n_in, 1) for _ in self.inputs],
            'output': (self.n_out, 1),
        }

        assert len(self.time_steps) == 1, 'All variables must have the same number of time steps'

        # NOTE(review): a window spans (n_in + lag + n_out - 1) * step + 1 positions,
        # so this count leaves the last `step` valid start positions of each season
        # unused. Kept unchanged so sample counts stay identical.
        span = (self.n_in + self.lag + self.n_out) * self.step
        self.n_samples_per_season = {}
        for season in self.seasons:
            season_times = self.output.time.sel(time=self.output.season == season)
            self.n_samples_per_season[season] = max(len(season_times) - span, 0)

        # Running totals of samples; __getitem__ uses them to map a flat index to a season.
        self.n_samples_per_season_accumulated = np.cumsum(
            list(self.n_samples_per_season.values()), dtype=int).tolist()

    @staticmethod
    def _select_variable(ds, var_name):
        """Pick the data variable to use from a Dataset; DataArrays pass through unchanged."""
        if not isinstance(ds, xr.Dataset):
            return ds
        if var_name is None:
            # Backward compatibility: fall back to the notebook-level `var_name` if defined.
            var_name = globals().get('var_name')
        if var_name in ds.variables:
            return ds[var_name]
        if len(ds.data_vars) == 1:
            # e.g. files written from an unnamed DataArray (`__xarray_dataarray_variable__`).
            return ds[next(iter(ds.data_vars))]
        return ds

    def __len__(self):
        """Total number of samples over all selected seasons."""
        return sum(self.n_samples_per_season.values())

    def __getitem__(self, idx):
        """Return the sample at flat index ``idx``.

        Raises IndexError outside ``[0, len(self))``; this also terminates plain
        ``for`` iteration over the dataset.
        """
        n_total = len(self)
        if not 0 <= idx < n_total:
            raise IndexError(f'index {idx} is out of range for {n_total} samples')

        season_idx = int(np.digitize(idx, self.n_samples_per_season_accumulated))
        season = self.seasons[season_idx]
        if season_idx > 0:
            idx = idx - self.n_samples_per_season_accumulated[season_idx - 1]

        in_idxs = [idx + i * self.step for i in range(self.n_in)]
        out_idxs = [idx + (self.n_in + self.lag + i) * self.step for i in range(self.n_out)]

        inputs = [d.sel(time=d.season == season).isel(time=in_idxs).values for d in self.inputs]

        output_slice = self.output.sel(time=self.output.season == season).isel(time=out_idxs)
        output = output_slice.values

        if not self.return_dates:
            return inputs, output

        # Day of year of each target step, parsed from the ISO date prefix of the timestamp.
        dates = [datetime.datetime.fromisoformat(str(t)[:10]).timetuple().tm_yday
                 for t in output_slice.time.values]
        return inputs, output, dates, output_slice.time.values

In [23]:
class DataLoader(pl.LightningDataModule):
    """LightningDataModule that builds ClimateIndicies splits and wraps them in torch loaders.

    NOTE: this class shadows ``torch.utils.data.DataLoader`` imported at the top
    of the notebook, so the loaders below are built via the fully qualified
    ``torch.utils.data.DataLoader``. The bare name would resolve to this class.

    Parameters
    ----------
    dataset : xarray object
        Passed as ``ds`` to every ClimateIndicies split.
    data : dict
        ``data_info`` passed to ClimateIndicies; required.
    batchsize : int, default 32
        Batch size of every loader.
    **params
        ``seasons`` (mapping with 'train'/'val'/'test' season lists) is required by
        the loader methods; ``return_dates`` and ``combine_test`` are stored but
        not used here.
    """

    def __init__(self, dataset, data = None, batchsize = 32, **params):
        super().__init__()
        if data is None:
            raise ValueError("Weather variable need to be specified")

        self.ds = dataset
        self.bs = batchsize
        self.dataset = {'train': [], 'val': [], 'test': []}
        self.seasons = params.get('seasons', None)
        self.return_dates = params.get('return_dates', False)
        self.combine_test = params.get('combine_test', False)
        self.data = data

    def _build_loader(self, split, shuffle):
        """Create the ClimateIndicies for `split`, cache it in self.dataset and wrap it in a torch loader."""
        if self.seasons is None:
            raise ValueError(f"No 'seasons' were given; cannot build the '{split}' split.")
        self.dataset[split] = ClimateIndicies(
            ds=self.ds,
            data_info=self.data,
            seasons=self.seasons[split],
        )
        return torch.utils.data.DataLoader(self.dataset[split], batch_size=self.bs, shuffle=shuffle)

    def train_dataloader(self):
        return self._build_loader('train', shuffle=True)

    def val_dataloader(self):
        return self._build_loader('val', shuffle=False)

    def test_dataloader(self):
        return self._build_loader('test', shuffle=False)

    def access_dataset(self):
        return self.dataset


In [24]:
from utils import statics_from_config

# Load the experiment config and derive the season split used for the test set.
cfd = '/mnt/beegfs/home/bommer1/WiOSTNN/Experiments/light'
cfile = '_1980_index'
# Context manager so the file handle is closed after parsing.
with open(f'{cfd}/config/convlstm_config{cfile}.yaml') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)  # trusted local config file

# statics_from_config also returns seasons, but they are rebuilt below from the
# 'fine' section of the config, so only data_info is kept.
data_info = statics_from_config(config)[0]

fine_cfg = config['data']['fine']
# Season numbers per split; each `_end` bound is exclusive (range semantics).
seasons = {
    split: list(range(fine_cfg[f'{split}_start'], fine_cfg[f'{split}_end']))
    for split in ('train', 'val', 'test')
}

params = {'seasons': seasons, 'test_set_name': data_name}

# Fine_data = DataLoader(ds, data_info, config['data']['bs'], **params)

# test_loader = Fine_data.test_dataloader()

In [25]:
# Iterate over the test seasons once, collecting per-sample inputs, targets,
# target day-of-year values and target timestamps.
dates = []         # day of year of each target step
daytimes = []      # timestamps of each target step
target_index = []
input_index = []
num_smps = 0       # total number of target steps collected

test_set = ClimateIndicies(
    ds=ds,
    data_info=data_info,
    seasons=seasons['test'],
    return_dates=True,
)

# Loop names avoid shadowing the builtin `input` and the config-level `days`.
for sample_inputs, sample_output, target_doy, target_times in test_set:
    target_index.append(np.array(sample_output).squeeze())
    input_index.append(np.array(sample_inputs).squeeze())
    dates.append(np.array(target_doy).squeeze())
    daytimes.append(np.array(target_times).squeeze())
    num_smps += len(np.array(sample_output))

In [26]:
def _stack_samples(samples):
    """Stack per-sample 1-D arrays (or scalars) into an (n_samples, n_steps) array."""
    if len(samples) == 0:
        raise ValueError('The test set produced no samples to stack.')
    return np.stack([np.atleast_1d(sample) for sample in samples])


# Each list holds one entry per sample. Stacking keeps each array's own per-sample
# length (n_steps_in for inputs, n_steps_out for targets/dates) rather than
# reshaping everything by n_steps_in.
input_index = _stack_samples(input_index)
target_index = _stack_samples(target_index)
dates = _stack_samples(dates)
daytimes = _stack_samples(daytimes)

# MJO phase numbers: '<vrbl>_9cat' or '<vrbl>'; other MJO indices: var_index; pv: vrbl.
if vrbl != 'MJO':
    testset_stem = vrbl
elif var_index == 'phase_number':
    testset_stem = f"{vrbl}_9cat" if m_cat == '_9cat' else vrbl
else:
    testset_stem = var_index

np.savez(f"{index_path}/{testset_stem}_index_{start_year}-{end_year}_{region}_testset.npz",
         input=input_index, output=target_index, dates=dates, daytimes=daytimes)


In [27]:
# Shape of the stacked target array saved above.
target_index.shape
# (574, 6) -- NOTE(review): differs from this cell's recorded output (532, 6);
# presumably left over from an earlier run/config. Confirm or remove.

(532, 6)