In [None]:
# !pip install -r requirements.txt

In [None]:
import os

import pandas as pd
import torch
import xarray as xr
from torch.utils.data import Dataset, DataLoader

# Notebook-wide configuration, read by the Feature and GT classes below.
# Both paths are left empty here and must be filled in before running.
feature_path = ''   # root directory of the feature files
gt_path = ''        # directory holding the per-year ground-truth files
years = ['2019']    # year names used as sub-directory / file-name stems
# Forecast lead steps 1..72 (inclusive).
fcst_steps = list(range(1, 73))

In [None]:
class Feature:
    """Index and load forecast feature files, keyed by initialisation time.

    The directory layout read by this class is ``<path>/<year>/<init_time>/*``.
    Every ``<init_time>`` entry name must be parseable by ``pandas.to_datetime``;
    the parsed timestamp becomes the lookup key for that entry.
    """

    def __init__(self):
        # Configuration comes from the notebook-level globals.
        self.path = feature_path
        self.years = years
        self.fcst_steps = fcst_steps
        self.features_paths_dict = self.get_features_paths()

    def get_features_paths(self):
        """Map each initialisation time to its directory.

        Returns:
            dict: ``{pandas.Timestamp: path}``, filled year by year with the
            entries of each year visited in sorted name order.

        Raises:
            FileNotFoundError: if a ``<path>/<year>`` directory does not exist.
            ValueError: if a non-hidden entry name cannot be parsed as a
                datetime by ``pandas.to_datetime``.
        """
        init_time_path_dict = {}
        for year in self.years:
            year_dir = os.path.join(self.path, year)
            for init_time in sorted(os.listdir(year_dir)):
                # Hidden entries (.ipynb_checkpoints, .DS_Store, ...) are never
                # initialisation times; feeding them to to_datetime would abort
                # the whole indexing pass.
                if init_time.startswith('.'):
                    continue
                init_time_path_dict[pd.to_datetime(init_time)] = os.path.join(year_dir, init_time)
        return init_time_path_dict

    def get_fts(self, init_time):
        """Open the features for one initialisation time.

        All files inside the indexed directory are opened as a single xarray
        dataset, restricted to the configured ``lead_time`` steps, and the
        first entry along ``time`` is selected.

        Args:
            init_time: a key of ``features_paths_dict``.

        Raises:
            KeyError: if ``init_time`` was not found when the directories were
                indexed.
        """
        try:
            init_time_dir = self.features_paths_dict[init_time]
        except KeyError:
            # A plain .get() would return None and fail later with an opaque
            # "NoneType + str" TypeError.
            raise KeyError(f'no feature directory indexed for init_time {init_time!r}') from None
        dataset = xr.open_mfdataset(os.path.join(init_time_dir, '*'))
        return dataset.sel(lead_time=self.fcst_steps).isel(time=0)


class GT:
    """Ground-truth dataset, sliced at the valid times of a forecast.

    One ``<year>.nc`` file per configured year is opened from ``gt_path`` and
    combined into a single xarray dataset held in ``self.gts``.
    """

    def __init__(self):
        # Configuration comes from the notebook-level globals.
        self.path = gt_path
        self.years = years
        self.fcst_steps = fcst_steps
        yearly_files = []
        for year in self.years:
            yearly_files.append(os.path.join(self.path, f'{year}.nc'))
        self.gt_paths = yearly_files
        self.gts = xr.open_mfdataset(self.gt_paths)

    def parser_gt_timestamps(self, init_time):
        """Return ``init_time`` shifted by each forecast step, in hours, as a list."""
        valid_times = []
        for lead in self.fcst_steps:
            offset = pd.Timedelta(f'{lead}h')
            valid_times.append(init_time + offset)
        return valid_times

    def get_gts(self, init_time):
        """Select the ground truth at every valid time derived from ``init_time``."""
        wanted_times = self.parser_gt_timestamps(init_time)
        return self.gts.sel(time=wanted_times)


class mydataset(Dataset):
    """Torch dataset pairing features and ground truth per initialisation time.

    Sample ``i`` corresponds to the ``i``-th key of the feature index built by
    ``Feature``.
    """

    def __init__(self):
        self.ft = Feature()
        self.gt = GT()
        self.features_paths_dict = self.ft.features_paths_dict
        # Freeze the key order so integer indices map to a stable init time.
        self.init_times = [*self.features_paths_dict]

    def __getitem__(self, index):
        """Return ``(features, ground_truth)`` values for sample ``index``.

        Each xarray dataset is stacked with ``to_array()`` and only index 0
        along the resulting ``variable`` axis is kept.
        """
        when = self.init_times[index]
        ft_stack = self.ft.get_fts(when).to_array()
        ft_item = ft_stack.isel(variable=0).values
        gt_stack = self.gt.get_gts(when).to_array()
        gt_item = gt_stack.isel(variable=0).values
        return ft_item, gt_item

    def __len__(self):
        """Number of indexed initialisation times."""
        return len(self.init_times)
    
# define dataset
my_data = mydataset()
# Count the samples of the dataset that is actually fed to the loader.
# Constructing a second mydataset() just for the count would list the feature
# directories and open the ground-truth files a second time.
print('sample num:', len(my_data))
# One sample per batch, visited in random order.
train_loader = DataLoader(my_data, batch_size=1, shuffle=True)

In [None]:
import torch.nn as nn

class Model(nn.Module):
    """Single 3x3 convolution over the flattened second and third input axes.

    Args:
        num_in_ch: number of input channels of the convolution; must equal the
            product of the second and third dimensions of the input tensor.
        num_out_ch: number of output channels of the convolution.
    """

    def __init__(self, num_in_ch, num_out_ch):
        super().__init__()
        # kernel_size=3, stride=1, padding=1: the two trailing dims keep their size.
        self.conv1 = nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1)

    def forward(self, x):
        """Apply the convolution to a 5-D input.

        Args:
            x: tensor of shape ``(B, S, C, W, H)``.

        Returns:
            Tensor of shape ``(B, num_out_ch, W, H)``.

        Raises:
            ValueError: if ``x`` does not have exactly five dimensions.
        """
        B, S, C, W, H = tuple(x.shape)
        # Merge the second and third axes into one channel axis for Conv2d.
        x = x.reshape(B, S * C, W, H)
        # The conv output is already (B, num_out_ch, W, H). Forcing it into
        # (B, S, W, H) was a no-op when num_out_ch == S and a crash otherwise,
        # so it is returned as is.
        return self.conv1(x)

# define model
# NOTE(review): 24 input variables per lead step is assumed to match the
# feature files -- confirm against the data.
in_varibales = 24
in_times = len(fcst_steps)
out_varibales = 1
out_times = len(fcst_steps)
# Conv2d channel counts: every (lead step, variable) pair becomes one channel.
input_size = in_times * in_varibales
output_size = out_times * out_varibales
# Use CUDA when it is available and fall back to the CPU otherwise, so the
# cell does not fail on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model(input_size, output_size).to(device)

In [None]:
# Loss: mean squared error, averaged over all elements.
loss_func = nn.MSELoss(reduction='mean')

In [None]:
# train
import numpy as np
if __name__ == '__main__':
    # Send every batch to the device the model's parameters live on instead of
    # hard-coding CUDA, so inputs and weights can never end up on different
    # devices.
    model_device = next(model.parameters()).device
    for index, (ft_item, gt_item) in enumerate(train_loader):
        ft_item = ft_item.to(model_device).float()
        gt_item = gt_item.to(model_device).float()
        output_item = model(ft_item)
        loss = loss_func(output_item, gt_item)
        # NOTE(review): there is no backward pass or optimizer step here; the
        # loop only reports the loss of the model as initialised.
        print(10*'*', np.round(loss.item(),4), 10*'*')