In [40]:
from pathlib import Path
import warnings
warnings.filterwarnings("ignore")

import datetime
import xarray as xr
import os
import numpy as np
import yaml


import torch
from torch.utils.data import DataLoader
import lightning.pytorch as pl

In [41]:
# ---------------------------------------------------------------------------
# Notebook configuration: everything a reader might want to tune.
# ---------------------------------------------------------------------------

vrbl = 'pv'         # prefix of the index files written by this notebook
data_name = 'ERA5'  # sub-directory of Data/ that holds the source datasets

# Directory layout: the project directory is the parent of the working
# directory, and the data live in a sibling "Data" folder one level above it.
cfd = Path(os.getcwd()).parent.absolute()
data_dir = f'{cfd.parent}/Data/'  # str with trailing slash, used in f-strings below
dataset_dir = Path(f'{data_dir}{data_name}/datasets')
index_path = f'{data_dir}Index'

resolution = '1.40625'  # in degrees; the original note also lists '5.625'

start_year, end_year = 1980, 2023  # set both to None to use all available years

# If True, the dataset is reduced to a single dimension (time) by aggregating
# over the other dimensions; here it only selects the "1d"/"2d" file-name tag.
reduce_to_single_dim = False
months_to_keep = [11, 12, 1, 2, 3]
start_season, end_season = '11-15', '03-31'  # 'MM-DD' bounds of the overall time window

var_name = 'u'
# in hPa, ignored for variables without levels. Levels listed by the original
# note: 50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000
pressure_level = 10
# Key into `region_coords` below. The original note additionally lists
# 'global', 'europe' and 'northern_hemi' as dataset regions.
region = 'spv'

dims_tag = '1d' if reduce_to_single_dim else '2d'
file_name = dataset_dir / Path(
    f'{var_name}_{pressure_level}_{resolution}deg_{start_year}-{end_year}_{region}_{dims_tag}.nc'
)

# Coordinate selections applied per region (None = nothing to select on).
region_coords = {
    # u10 for Polar Vortex; descending bounds presumably match a latitude
    # axis stored north-to-south -- TODO confirm against the dataset.
    'spv': {'lat': slice(61, 59)},
    'mjo': None,  # index for MJO
}


In [42]:
# Load the gridded field and reduce it to a zonal-mean index time series.
ds_raw = xr.open_dataset(file_name)

# Spatial subset. A region whose entry is None has nothing to select on, so
# the selection is skipped instead of failing on `**None`.
region_sel = region_coords[region] or {}
ds_region = ds_raw.sel(**region_sel) if region_sel else ds_raw

# Overall time window. A year set to None leaves that side of the window open,
# as promised by the comment next to `start_year, end_year`.
time_start = None if start_year is None else np.datetime64(f'{start_year}-{start_season}')
time_end = None if end_year is None else np.datetime64(f'{end_year}-{end_season}')
ds_window = ds_region.sel({'time': slice(time_start, time_end)})

# Keep only the months listed in `months_to_keep`.
ds_season = ds_window.sel(time=ds_window['time.month'].isin(months_to_keep))

# Average over longitude; `ds` is the name the following cells rely on.
ds = ds_season.mean(dim='lon')

# Item access instead of attribute access for the dunder-named data variable.
ds['__xarray_dataarray_variable__'].to_netcdf(
    f"{dataset_dir}/{vrbl}_index_{start_year}-{end_year}_{data_name}.nc"
)

### Check regime compatibility

In [43]:
# Load the NAE regime labels and build their one-hot encoding on a
# (time, nae_regime_cat) grid. The last expression is the cell's output.
nae = xr.open_dataarray(f"{dataset_dir}/z_500_1.40625deg_1980-2023_northern_hemi_2d_NAEregimes.nc")
categories_nae = nae.values.astype(int)

# Number of categories = largest label + 1 (labels are used as row indices
# into the identity matrix below).
n_regimes = int(np.max(categories_nae)) + 1
one_hot = np.eye(n_regimes)[categories_nae]

ds_norm = (
    nae
    .expand_dims(dim={'nae_regime_cat': n_regimes})
    .transpose('time', 'nae_regime_cat')
)
ds_norm.copy(data=one_hot)

## Format according to Dataset

In [44]:
class ClimateIndicies(pl.LightningDataModule):
    """Indexable collection of (inputs, target) windows cut from an index series.

    Windows never cross season boundaries: every sample is taken from the time
    steps of a single value of the ``season`` coordinate. A sample consists of
    ``n_steps_in`` input points and ``n_steps_out`` target points, all spaced
    ``_STRIDE`` time steps apart, with the first target point located
    ``n_steps_in + n_steps_lag`` strides after the first input point.

    Parameters
    ----------
    ds : xarray object with ``time`` and ``season`` coordinates
        If it contains a variable named by the notebook-level ``var_name``,
        that variable is used; otherwise ``ds`` itself is used.
    data_info : dict
        Must hold a ``'config'`` mapping providing ``n_steps_lag``,
        ``n_steps_in`` and ``n_steps_out``.
    seasons : sequence, optional
        Season labels to keep. ``None`` keeps every season present in ``ds``.
    return_dates : bool
        If True, ``__getitem__`` additionally returns the day-of-year and the
        raw time stamps of the target points.
    """

    # Spacing, in time steps, between consecutive points of a window.
    # Presumably daily data sampled once per week -- TODO confirm.
    _STRIDE = 7

    def __init__(self, ds, data_info, seasons=None, return_dates=False):
        super().__init__()

        cfg = data_info['config']
        self.lag = cfg.get('n_steps_lag')
        self.n_in = cfg.get('n_steps_in')
        self.n_out = cfg.get('n_steps_out')
        self.stack_maps = cfg.get('stack_maps')
        self.regime_path = cfg.get('regime_path', '')
        self.data_path = cfg.get('data_path', '')
        self.strt = cfg.get('strt', '1950')
        self.return_dates = return_dates
        self.seasons = seasons

        self.inputs = []
        self.time_steps = set()
        self.output = None

        if self.seasons is not None:
            ds = ds.sel(time=ds.season.isin(self.seasons))
        else:
            self.seasons = np.unique(ds.season.values)

        # `var_name` is a notebook-level global. The lookup fails when `ds`
        # is already a DataArray or lacks that variable; `ds` is then used
        # directly. Only ordinary exceptions are caught, so KeyboardInterrupt
        # and SystemExit still propagate.
        try:
            ds = ds[var_name].squeeze()
        except Exception:
            ds = ds.squeeze()

        self.time_steps.add(len(ds.time))

        # The same series serves as the single input and as the target.
        self.inputs.append(ds)
        self.output = ds

        # Input and output shapes for the model.
        self.shapes = {
            'input': [(self.n_in, 1,) for _ in self.inputs],
            'output': (self.n_out, 1),
        }

        # Number of samples per season.
        assert len(self.time_steps) == 1, 'All variables must have the same number of time steps'
        window_span = (self.n_in + self.lag + self.n_out) * self._STRIDE
        self.n_samples_per_season = {}
        for season in self.seasons:
            season_times = self.output.time.sel(time=self.output.season == season)
            self.n_samples_per_season[season] = max(len(season_times) - window_span, 0)

        # Running totals; entry k is the number of samples in seasons 0..k.
        self.n_samples_per_season_accumulated = [
            sum(list(self.n_samples_per_season.values())[:i])
            for i in range(1, len(self.seasons) + 1)
        ]

    def __len__(self):
        """Return the total number of samples over all seasons."""
        if not self.n_samples_per_season_accumulated:
            return 0
        return int(self.n_samples_per_season_accumulated[-1])

    def __getitem__(self, idx):
        """Return sample ``idx``.

        Returns ``(inputs, output)`` or, with ``return_dates``,
        ``(inputs, output, day_of_year, time_stamps)``. ``inputs`` is a list
        with one array per input series.

        Raises
        ------
        IndexError
            If ``idx`` is outside ``[0, len(self))``. This is also what ends a
            plain ``for sample in dataset`` loop.
        """
        if not 0 <= idx < len(self):
            raise IndexError(f'sample index {idx} out of range for {len(self)} samples')

        # Locate the season that owns this global index, then make the index
        # relative to the start of that season.
        season_idx = int(np.digitize(idx, self.n_samples_per_season_accumulated))
        season = self.seasons[season_idx]
        if season_idx > 0:
            idx = idx - self.n_samples_per_season_accumulated[season_idx - 1]

        step = self._STRIDE
        in_idxs = [idx + i * step for i in range(self.n_in)]
        out_idxs = [idx + (self.n_in + self.lag + i) * step for i in range(self.n_out)]

        inputs = [
            series.sel(time=series.season == season).isel(time=in_idxs).values
            for series in self.inputs
        ]

        output_slice = self.output.sel(time=self.output.season == season).isel(time=out_idxs)
        output = output_slice.values

        if not self.return_dates:
            return (inputs, output)

        # Day of the year of every target time stamp.
        dates = [
            datetime.datetime.fromisoformat(str(stamp)[:10]).timetuple().tm_yday
            for stamp in output_slice.time.values
        ]
        return (inputs, output, dates, output_slice.time.values)

In [45]:
class DataLoader(pl.LightningDataModule):
    """Lightning data module serving train/val/test loaders over ``ClimateIndicies``.

    NOTE: this class has the same name as ``torch.utils.data.DataLoader``,
    which the notebook imports at the top, and shadows it as soon as this cell
    runs. The loader methods therefore reference the torch class through its
    fully qualified name; a bare ``DataLoader(...)`` call inside this class
    would construct this data module again instead of a torch loader.

    Parameters
    ----------
    dataset : xarray object handed to ``ClimateIndicies`` as ``ds``.
    data : dict
        Data description handed to ``ClimateIndicies`` as ``data_info``.
        Required.
    batchsize : int
        Batch size of every loader.
    **params
        ``seasons`` (mapping with 'train'/'val'/'test' entries),
        ``return_dates`` and ``combine_test`` are read; other keys are ignored.

    Raises
    ------
    ValueError
        If ``data`` is None.
    """

    def __init__(self, dataset, data = None, batchsize = 32, **params):

        super().__init__()

        self.ds = dataset
        self.bs = batchsize
        self.dataset = {'train': [], 'val': [], 'test': []}
        self.seasons = params.get('seasons', None)
        self.return_dates = params.get('return_dates', False)
        self.combine_test = params.get('combine_test', False)
        if data is None:
            # ValueError is a subclass of Exception, so existing handlers still match.
            raise ValueError("Weather variable need to be specified")
        self.data = data

    def _make_loader(self, split, shuffle):
        """Build the dataset of ``split``, store it, and wrap it in a torch loader."""
        self.dataset[split] = ClimateIndicies(
            ds=self.ds,
            data_info=self.data,
            seasons=self.seasons[split],
        )
        # torch's default samplers call len() on the dataset they are given.
        return torch.utils.data.DataLoader(
            self.dataset[split], batch_size=self.bs, shuffle=shuffle
        )

    def train_dataloader(self):
        """Return a shuffled loader over the training seasons."""
        return self._make_loader('train', shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled loader over the validation seasons."""
        return self._make_loader('val', shuffle=False)

    def test_dataloader(self):
        """Return an unshuffled loader over the test seasons."""
        return self._make_loader('test', shuffle=False)

    def access_dataset(self):
        """Return the dict of datasets built so far, keyed by split name."""
        return self.dataset


In [46]:
from deepS2S.utils.utils import statics_from_config

# Load the experiment configuration and derive the per-split season lists.
cfile = '_index_lstm'
# NOTE(review): yaml.FullLoader can construct more than plain data types;
# if the config only holds plain YAML, yaml.safe_load would be the safer call.
with open(f'{cfd}/config/config{cfile}.yaml') as config_file:  # closed on exit
    config = yaml.load(config_file, Loader=yaml.FullLoader)
config['data_root'] = str(cfd.parent.absolute()) + '/Data'

# The seasons returned here are not used; the split ranges are rebuilt below
# from the 'fine' data section of the config.
data_info, _default_seasons = statics_from_config(config)

fine_cfg = config['data']['fine']
seasons = {
    split: list(range(fine_cfg[f'{split}_start'], fine_cfg[f'{split}_end']))
    for split in ('train', 'val', 'test')
}

params = {'seasons': seasons, 'test_set_name': data_name}



In [52]:
# Display the zonal-mean series. The variable carries this dunder name in the
# file (presumably xarray's placeholder for an unnamed DataArray -- confirm).
ds['__xarray_dataarray_variable__']



In [53]:
# Collect every test-set sample into plain Python lists (one entry per sample).
dates = []         # day-of-year of the target points
daytimes = []      # raw time stamps of the target points
target_index = []
input_index = []
num_smps = 0       # running total of target points over all samples

test_set = ClimateIndicies(
    ds=ds['__xarray_dataarray_variable__'],
    data_info=data_info,
    seasons=seasons['test'],
    return_dates=True,
)

# Iteration goes through __getitem__ with 0, 1, 2, ... and stops at the first
# IndexError. The loop variables avoid the name `input`, which would shadow
# the builtin for the rest of the kernel session.
for sample_inputs, sample_target, sample_doy, sample_times in test_set:
    target_index.append(np.array(sample_target).squeeze())
    input_index.append(np.array(sample_inputs).squeeze())
    dates.append(np.array(sample_doy).squeeze())
    daytimes.append(np.array(sample_times).squeeze())
    num_smps += len(np.array(sample_target))

In [54]:
# Turn the per-sample lists into 2-D arrays of shape (n_samples, window_length)
# and store them. Each array takes its window length from its own entries, so
# inputs and targets may differ in length (n_steps_in != n_steps_out), and
# windows of length 1 (0-d entries after squeeze) are handled as well.
n_samples = len(target_index)
if n_samples == 0:
    raise ValueError('The test set produced no samples; nothing to save.')

input_index = np.stack(input_index).reshape(n_samples, -1)
target_index = np.stack(target_index).reshape(n_samples, -1)
dates = np.stack(dates).reshape(n_samples, -1)
daytimes = np.stack(daytimes).reshape(n_samples, -1)

np.savez(f"{index_path}/{vrbl}_index_{start_year}-{end_year}_{region}_testset.npz", input = input_index,
         output = target_index, dates = dates, daytimes = daytimes)
