In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are reloaded
# before each cell runs, so edits to the local src/ package take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
import pandas as pd
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
from src.dataloader import *
from src.models import *
from src.trainer import *
from src.utils import *

In [4]:
# Root directory of the data drive; every dataset path below is built from it.
# NOTE(review): the value ends in '/', and the f-strings that use it add another
# '/', so the resulting paths contain '//' (normally harmless on POSIX filesystems).
DATADRIVE = '/datadrive_ssd/'

In [5]:
# Category bin values handed to the datasets as `cat_bins`: `n_bins` values
# spaced `interval` apart, starting at 0.
interval = 0.75
n_bins = 128
# Scale an integer range instead of calling np.arange with a float step and a
# float stop: with a non-integer step the number of elements np.arange returns
# can be off by one due to rounding, whereas this always yields exactly n_bins.
cat_bins = np.arange(n_bins) * interval
len(cat_bins), max(cat_bins)

(128, 95.25)

## Load data

In [6]:
# Build the 'train' split of the TIGGE/MRMS dataset for 2018-01 .. 2019-12.
# NOTE(review): the meaning of val_days and tp_log is defined in src.dataloader —
# confirm there before changing them.
train_kwargs = dict(
    tigge_dir=f'{DATADRIVE}/tigge/32km/',
    tigge_vars=['total_precipitation'],
    mrms_dir=f'{DATADRIVE}/mrms/4km/RadarOnly_QPE_06H/',
    rq_fn=f'{DATADRIVE}/mrms/4km/RadarQuality.nc',
    data_period=('2018-01', '2019-12'),
    val_days=5,
    cat_bins=cat_bins,
    split='train',
    tp_log=1,
)
ds_train = TiggeMRMSDataset(**train_kwargs)

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


Loading data


In [None]:
# Build the 'valid' split with the same sources and period as the training set.
# The mins/maxs computed on ds_train are passed in rather than recomputed
# (presumably so both splits share the same normalization — confirm in src.dataloader).
valid_kwargs = dict(
    tigge_dir=f'{DATADRIVE}/tigge/32km/',
    tigge_vars=['total_precipitation'],
    mrms_dir=f'{DATADRIVE}/mrms/4km/RadarOnly_QPE_06H/',
    rq_fn=f'{DATADRIVE}/mrms/4km/RadarQuality.nc',
    data_period=('2018-01', '2019-12'),
    val_days=5,
    cat_bins=cat_bins,
    split='valid',
    tp_log=1,
    mins=ds_train.mins,
    maxs=ds_train.maxs,
)
ds_valid = TiggeMRMSDataset(**valid_kwargs)

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


Loading data


In [None]:
# One weighted sampler per split: each draws len(dataset) indices per pass using
# the per-sample weights returned by that dataset's compute_weights().
# NOTE(review): WeightedRandomSampler samples with replacement by default, so the
# validation batches are also random and differ between passes.
sampler_train = torch.utils.data.WeightedRandomSampler(
    weights=ds_train.compute_weights(),
    num_samples=len(ds_train),
)
sampler_valid = torch.utils.data.WeightedRandomSampler(
    weights=ds_valid.compute_weights(),
    num_samples=len(ds_valid),
)

In [None]:
# Wrap each dataset in a DataLoader that yields batches of 16, with sample order
# decided by the matching weighted sampler.
dl_train = torch.utils.data.DataLoader(
    dataset=ds_train,
    batch_size=16,
    sampler=sampler_train,
)
dl_valid = torch.utils.data.DataLoader(
    dataset=ds_valid,
    batch_size=16,
    sampler=sampler_valid,
)

In [None]:
# Sanity check: fetch item 600 of the validation set (presumably an
# (input, target) pair — confirm in src.dataloader) and display both shapes.
X, y = ds_valid[600]
X.shape, y.shape

## Model

In [None]:
# Instantiate the generator network and move it to the selected device.
# NOTE(review): nout=128 equals len(cat_bins) from the binning cell; the two
# presumably have to stay in sync — confirm against src.models.
gen = Generator(
    nres=3,
    nf_in=1,
    nf=256,
    nout=128,
    activation_out='softmax',
    use_noise=False,
    spectral_norm=False,
    halve_filters_up=False,
    batch_norm=True,
)
gen = gen.to(device)

In [None]:
# Display the result of count_parameters for the generator (helper pulled in by
# one of the src.* star imports; presumably the number of trainable parameters).
count_parameters(gen)

In [None]:
# Loss and optimizer for training the generator as a classifier over the bins.
# NOTE(review): nn.CrossEntropyLoss expects unnormalized logits and applies
# log-softmax itself, while the generator is built with activation_out='softmax'.
# If that option really applies a softmax to the output, it is applied twice —
# confirm against src.models.
criterion = nn.CrossEntropyLoss()
# Adam over all generator parameters with a fixed learning rate of 1e-4.
optimizer = torch.optim.Adam(gen.parameters(), lr=1e-4)

In [None]:
# Bundle the model, optimizer, loss and both data loaders into the project's
# Trainer (defined in the src package).
trainer = Trainer(gen, optimizer, criterion, dl_train, dl_valid)

In [None]:
# Run training; the argument is presumably the number of epochs — confirm in src.trainer.
trainer.fit(10)

In [None]:
# Plot the losses recorded by the trainer during fit().
trainer.plot_losses()

In [None]:
# `pickle` is not imported explicitly anywhere in this notebook; import it here
# instead of relying on one of the `from src.* import *` lines to leak the name,
# which would raise NameError as soon as those modules stop importing it.
import pickle

# Persist the whole Trainer object to '01.trainer' in the working directory.
# NOTE(review): the trainer was constructed with the model, optimizer and both
# data loaders, so this presumably serializes all of them (including the
# datasets) — consider saving only state_dicts if the file gets too large.
with open('01.trainer', 'wb') as f:
    pickle.dump(trainer, f)