In [12]:
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl

import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from ilan_src.models import *
from ilan_src.dataloader import *
from ilan_src.utils import *

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt


## Train and Val

In [13]:
# Root of the data volume. Keep the trailing slash: the pickle output paths
# further down are built by plain concatenation (DATADRIVE + "saved_datasets/...").
DATADRIVE = "/datadrive_ssd/"

In [4]:
# Training dataset. No `mins`/`maxs` are passed here; the validation and test
# datasets built later reuse `ds_train.mins` / `ds_train.maxs`.
ds_train = TiggeMRMSDataset(
    # --- data locations ---
    tigge_dir=f'{DATADRIVE}/tigge/32km/',
    mrms_dir=f'{DATADRIVE}/mrms/4km/RadarOnly_QPE_06H/',
    rq_fn=f'{DATADRIVE}/mrms/4km/RadarQuality.nc',
    tigge_vars=['total_precipitation'],
    # Constant fields are currently switched off. The disabled arguments were:
    #   const_fn='/datadrive/tigge/32km/constants.nc', const_vars=['orog', 'lsm']
    # --- period / split ---
    data_period=('2018-01', '2019-12'),
    split='train',
    val_days=5,
    # --- remaining options (semantics defined by TiggeMRMSDataset) ---
    tp_log=0.01,
    pure_sr_ratio=None,
)

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


Loading data


In [5]:
# Validation dataset: same sources, period and options as the training
# dataset, but with split='valid' and the training dataset's `mins`/`maxs`
# passed in explicitly.
ds_valid = TiggeMRMSDataset(
    # --- data locations ---
    tigge_dir=f'{DATADRIVE}/tigge/32km/',
    mrms_dir=f'{DATADRIVE}/mrms/4km/RadarOnly_QPE_06H/',
    rq_fn=f'{DATADRIVE}/mrms/4km/RadarQuality.nc',
    tigge_vars=['total_precipitation'],
    # Constant fields are currently switched off. The disabled arguments were:
    #   const_fn='/datadrive/tigge/32km/constants.nc', const_vars=['orog', 'lsm']
    # --- period / split ---
    data_period=('2018-01', '2019-12'),
    split='valid',
    val_days=5,
    # --- statistics taken from the training dataset ---
    mins=ds_train.mins,
    maxs=ds_train.maxs,
    # --- remaining options (semantics defined by TiggeMRMSDataset) ---
    tp_log=0.01,
    pure_sr_ratio=None,
)

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


Loading data


In [6]:
# Weighted samplers: the per-sample weights come from each dataset's
# compute_weights(), and every pass draws len(dataset) indices. `replacement`
# is left at the sampler's default.
sampler_train = torch.utils.data.WeightedRandomSampler(
    weights=ds_train.compute_weights(),
    num_samples=len(ds_train),
)
sampler_valid = torch.utils.data.WeightedRandomSampler(
    weights=ds_valid.compute_weights(),
    num_samples=len(ds_valid),
)

In [7]:
# Batched loaders for the train and validation datasets; index selection is
# delegated to the weighted samplers, so `shuffle` is not set.
dl_train = torch.utils.data.DataLoader(
    dataset=ds_train,
    sampler=sampler_train,
    batch_size=16,
    num_workers=6,
)
dl_valid = torch.utils.data.DataLoader(
    dataset=ds_valid,
    sampler=sampler_valid,
    batch_size=16,
    num_workers=6,
)

In [8]:
# Sanity check: (number of training samples, number of training batches).
(len(ds_train), len(dl_train))

(42876, 2680)

In [14]:
import pickle


def _dump_pickle(obj, path):
    """Pickle ``obj`` to the file at ``path``.

    The file is opened in binary write mode inside a context manager so the
    handle is flushed and closed as soon as the dump finishes, including when
    pickling raises part-way through.
    """
    with open(path, "wb") as fh:
        pickle.dump(obj, fh)


# Persist the train/valid DataLoaders and their underlying Datasets to disk.
_dump_pickle(dl_train, DATADRIVE+"saved_datasets/trainloader_single_forecast_only_log_trans_full.pkl")
_dump_pickle(dl_valid, DATADRIVE+"saved_datasets/validloader_single_forecast_only_log_trans_full.pkl")
_dump_pickle(ds_train, DATADRIVE+"saved_datasets/traindataset_single_forecast_only_log_trans_full.pkl")
_dump_pickle(ds_valid, DATADRIVE+"saved_datasets/validdataset_single_forecast_only_log_trans_full.pkl")

## Test

In [15]:
# Test dataset covering 2020. It takes `mins`/`maxs` from the training
# dataset; no `split` or `val_days` is given, and `first_days=2` is passed
# instead.
ds_test = TiggeMRMSDataset(
    # --- data locations ---
    tigge_dir=f'{DATADRIVE}/tigge/32km/',
    mrms_dir=f'{DATADRIVE}/mrms/4km/RadarOnly_QPE_06H/',
    rq_fn=f'{DATADRIVE}/mrms/4km/RadarQuality.nc',
    tigge_vars=['total_precipitation'],
    # Constant fields are currently switched off. The disabled arguments were:
    #   const_fn='/datadrive/tigge/32km/constants.nc', const_vars=['orog', 'lsm']
    # --- period ---
    data_period=('2020-01', '2020-12'),
    first_days=2,
    # --- statistics taken from the training dataset ---
    mins=ds_train.mins,
    maxs=ds_train.maxs,
    # --- remaining options (semantics defined by TiggeMRMSDataset) ---
    tp_log=0.01,
    pure_sr_ratio=None,
)

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


Loading data


In [16]:
# Persist the test dataset. The context manager guarantees the file handle is
# flushed and closed once the dump finishes, even if pickling raises.
with open(DATADRIVE+"saved_datasets/testdataset_single_forecast_only_log_trans_first_days_2.pkl", "wb") as fh:
    pickle.dump(ds_test, fh)