In [1]:
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
import torchmetrics

import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from ilan_src.models import *

In [2]:
import xarray as xr
import numpy as np
from src.dataloader import *
from src.utils import *

## Data

In [3]:
DATADRIVE = '/datadrive_ssd/'

In [4]:
ds_train = TiggeMRMSDataset(
    tigge_dir=f'{DATADRIVE}/tigge/32km/',
    tigge_vars=['total_precipitation'],
    mrms_dir=f'{DATADRIVE}/mrms/4km/RadarOnly_QPE_06H/',
    rq_fn=f'{DATADRIVE}/mrms/4km/RadarQuality.nc',
#     const_fn='/datadrive/tigge/32km/constants.nc',
#     const_vars=['orog', 'lsm'],
    data_period=('2018-01', '2018-12'),
    val_days=5,
    split='train',
    pure_sr_ratio=None, 
    tp_log=0.01
)

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


Loading data


In [5]:
ds_valid = TiggeMRMSDataset(
    tigge_dir=f'{DATADRIVE}/tigge/32km/',
    tigge_vars=['total_precipitation'],
    mrms_dir=f'{DATADRIVE}/mrms/4km/RadarOnly_QPE_06H/',
    rq_fn=f'{DATADRIVE}/mrms/4km/RadarQuality.nc',
#     const_fn='/datadrive/tigge/32km/constants.nc',
#     const_vars=['orog', 'lsm'],
    data_period=('2018-01', '2018-12'),
    val_days=5,
    split='valid',
    mins=ds_train.mins,
    maxs=ds_train.maxs,
    pure_sr_ratio=None,
    tp_log= 0.01
)

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


Loading data


In [6]:
sampler_train = torch.utils.data.WeightedRandomSampler(ds_train.compute_weights(), len(ds_train))
sampler_valid = torch.utils.data.WeightedRandomSampler(ds_valid.compute_weights(), len(ds_valid))

In [7]:
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=32, sampler=sampler_train, num_workers=6)
dl_valid = torch.utils.data.DataLoader(ds_valid, batch_size=32, sampler=sampler_valid, num_workers=6)

In [8]:
len(ds_train), len(dl_train)

(20988, 656)

In [9]:
import pickle
pickle.dump(dl_train, open("./dataset/trainloader_single_forecast_batch32.pkl", "wb"))
pickle.dump(dl_valid, open("./dataset/validloader__single_forecast_batch32.pkl", "wb"))

In [None]:
dl_train = pickle.load(open("./dataset/trainloader_single_forecast_batch32.pkl", "rb"))
dl_valid = pickle.load(open("./dataset/valdloader_single_forecast_batch32.pkl", "rb"))


## Model

In [53]:
class ForecastClassifier(LightningModule):
    def __init__(self, discriminator, channels_img, num_classes, img_size, features_d, 
                      b1 = 0.0, b2 = 0.9, lr = 1e-4, cond_idx = 0, real_idx = 1): # fill in
        super().__init__()
        self.lr, self.b1, self.b2 = lr, b1, b2
        self.disc = discriminator(channels_img, features_d, num_classes, img_size)
        self.real_idx = real_idx
        self.cond_idx = cond_idx
        self.loss = nn.BCELoss()
        self.output_sigmoid = nn.Sigmoid()
        self.train_accuracy = torchmetrics.Accuracy()
        self.valid_accuracy = torchmetrics.Accuracy()

    def forward(self, condition, high_res):
        return self.disc(condition, high_res)
    
    def training_step(self, batch, batch_idx):

        condition, high_res = batch[self.cond_idx], batch[self.real_idx]
        half_batch = condition.shape[0]//2
        condition = condition[:half_batch, :, :, :]
        targets = torch.cat((torch.ones(condition.shape[0], 1, 1, 1,  device = self.device), torch.zeros(condition.shape[0], 1, 1, 1,  device = self.device)), dim=0)
        condition = torch.cat((condition, condition), dim=0)
        preds = self.disc(condition, high_res)
        loss = self.loss(self.output_sigmoid(preds), targets)
        self.log('train_loss', loss, on_epoch=True, on_step=True, prog_bar=True, logger=True)
        self.log('train_acc_step', self.train_accuracy(self.output_sigmoid(preds), targets.int()))
        return loss
        
    def training_epoch_end(self, outs):
        self.log('train_acc_epoch', self.train_accuracy.compute())
        
    def validation_step(self, batch, batch_idx):
        condition, high_res = batch[self.cond_idx], batch[self.real_idx]
        half_batch = condition.shape[0]//2
        condition = condition[:half_batch, :, :, :]
        targets = torch.cat((torch.ones(condition.shape[0], 1, 1, 1,  device = self.device), torch.zeros(condition.shape[0], 1, 1, 1,  device = self.device)), dim=0)
#         print(targets.dtype)
        condition = torch.cat((condition, condition), dim=0)
        preds = self.disc(condition, high_res)
        loss = self.loss(self.output_sigmoid(preds), targets)
        self.log('val_loss', loss, on_epoch=True, on_step=True, prog_bar=True, logger=True)
        self.log('val_acc_step', self.valid_accuracy(self.output_sigmoid(preds), targets.int()))
    
    def validation_epoch_end(self, outs):
        self.log('val_acc_epoch', self.valid_accuracy.compute())
        
    def configure_optimizers(self):
        disc_opt = optim.Adam(self.disc.parameters(), lr=self.lr, betas=(self.b1, self.b2))
#         disc_opt = optim.SGD(self.disc.parameters(), lr=self.lr, momentum=0.9)
        return disc_opt

In [54]:
LEARNING_RATE = 1e-4
IMG_SIZE = 128
CHANNELS_IMG = 1
FEATURES_CRITIC = 32 #64
NUM_CLASSES = 2

model = ForecastClassifier(DSDiscriminator, CHANNELS_IMG, 
                           NUM_CLASSES, IMG_SIZE, FEATURES_CRITIC, 
                           lr = LEARNING_RATE)

trainer = pl.Trainer(gpus = 1)
trainer.fit(model, dl_train, dl_valid)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type            | Params
---------------------------------------------------
0 | disc           | DSDiscriminator | 4.0 M 
1 | loss           | BCELoss         | 0     
2 | output_sigmoid | Sigmoid         | 0     
3 | train_accuracy | Accuracy        | 0     
4 | valid_accuracy | Accuracy        | 0     
---------------------------------------------------
4.0 M     Trainable params
0         Non-trainable params
4.0 M     Total params
15.903    Total estimated model params size (MB)


Epoch 0:  84%|████████▍ | 657/780 [06:07<01:08,  1.79it/s, loss=0.476, v_num=87, val_loss_epoch=0.755, train_loss_step=0.438]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/124 [00:00<?, ?it/s][A
Epoch 0:  84%|████████▍ | 658/780 [06:08<01:08,  1.79it/s, loss=0.476, v_num=87, val_loss_epoch=0.755, train_loss_step=0.438]
Epoch 0:  84%|████████▍ | 659/780 [06:08<01:07,  1.79it/s, loss=0.476, v_num=87, val_loss_epoch=0.755, train_loss_step=0.438]
Epoch 0:  85%|████████▍ | 660/780 [06:08<01:07,  1.79it/s, loss=0.476, v_num=87, val_loss_epoch=0.755, train_loss_step=0.438]
Epoch 0:  85%|████████▍ | 661/780 [06:09<01:06,  1.79it/s, loss=0.476, v_num=87, val_loss_epoch=0.755, train_loss_step=0.438]
Epoch 0:  85%|████████▍ | 662/780 [06:09<01:05,  1.79it/s, loss=0.476, v_num=87, val_loss_epoch=0.755, train_loss_step=0.438]
Epoch 0:  85%|████████▌ | 663/780 [06:09<01:05,  1.79it/s, loss=0.476, v_num=87, val_loss_epoch=0.755, train_loss_step=0.438]
Epoch 0:  85%|████████▌ | 664