In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl

import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# from src.models import *
from ilan_src.models import *
from src.dataloader import *
from src.utils import *
from src.evaluation import *

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

import pickle

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Set up dataset - my way

In [8]:
# Root directory of the data. The value ends with a slash and the dataset
# cells append '/tigge/...' and '/mrms/...', so the resulting paths contain '//'.
DATADRIVE = "/home/jupyter/data/"

In [9]:
# Training dataset: TIGGE total precipitation as input, MRMS 6-hourly radar
# QPE as target, covering 2018-01 through 2019-12 with split='train'.
# NOTE(review): val_days and tp_log semantics are defined by TiggeMRMSDataset;
# tp_log is presumably the offset of a log transform - confirm in src.dataloader.
train_dataset_kwargs = dict(
    tigge_dir=f'{DATADRIVE}/tigge/32km/',
    tigge_vars=['total_precipitation'],
    mrms_dir=f'{DATADRIVE}/mrms/4km/RadarOnly_QPE_06H/',
    rq_fn=f'{DATADRIVE}/mrms/4km/RadarQuality.nc',
    data_period=('2018-01', '2019-12'),
    val_days=5,
    split='train',
    tp_log=0.01,
)
ds_train = TiggeMRMSDataset(**train_dataset_kwargs)

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


Loading data


In [10]:
# Test dataset for 2020, built from the same sources as the training set.
# The training set's mins/maxs are passed in so both datasets share one scaling.
# NOTE(review): first_days semantics are defined by TiggeMRMSDataset - confirm.
test_dataset_kwargs = dict(
    tigge_dir=f'{DATADRIVE}/tigge/32km/',
    tigge_vars=['total_precipitation'],
    mrms_dir=f'{DATADRIVE}/mrms/4km/RadarOnly_QPE_06H/',
    rq_fn=f'{DATADRIVE}/mrms/4km/RadarQuality.nc',
    data_period=('2020-01', '2020-12'),
    first_days=5,
    tp_log=0.01,
    mins=ds_train.mins,
    maxs=ds_train.maxs,
)
ds_test = TiggeMRMSDataset(**test_dataset_kwargs)

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


Loading data


In [11]:
# Walk the test set in its stored order, 32 items per batch.
sampler_test = torch.utils.data.SequentialSampler(ds_test)
dl_test = torch.utils.data.DataLoader(ds_test, sampler=sampler_test, batch_size=32)

In [24]:
# Reduced test dataset (January 2020 only, first_days=2) for quick evaluation
# runs. It reuses the training set's mins/maxs so scaling matches training.
# NOTE(review): first_days semantics are defined by TiggeMRMSDataset - confirm.
test_small_dataset_kwargs = dict(
    tigge_dir=f'{DATADRIVE}/tigge/32km/',
    tigge_vars=['total_precipitation'],
    mrms_dir=f'{DATADRIVE}/mrms/4km/RadarOnly_QPE_06H/',
    rq_fn=f'{DATADRIVE}/mrms/4km/RadarQuality.nc',
    data_period=('2020-01', '2020-01'),
    first_days=2,
    tp_log=0.01,
    mins=ds_train.mins,
    maxs=ds_train.maxs,
)
ds_test_small = TiggeMRMSDataset(**test_small_dataset_kwargs)

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


Loading data


In [25]:
# Sequential (unshuffled) loader over the reduced test set, 32 items per batch.
sampler_test_small = torch.utils.data.SequentialSampler(ds_test_small)
dl_test_small = torch.utils.data.DataLoader(
    ds_test_small,
    sampler=sampler_test_small,
    batch_size=32,
)

### Ilan's dataset

In [18]:
# Load the pre-built, pickled test dataset.
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only load files that come from a trusted source.
# The `with` block closes the file handle as soon as the dataset is loaded.
with open("/home/jupyter/data/saved_datasets/testdataset_single_forecast_only_log_trans_sample.pkl", "rb") as pkl_file:
    ds_test2 = pickle.load(pkl_file)

In [19]:
# Sequential (unshuffled) loader over the pickled dataset, 64 items per batch.
sampler_test2 = torch.utils.data.SequentialSampler(ds_test2)
dl_test2 = torch.utils.data.DataLoader(ds_test2, sampler=sampler_test2, batch_size=64)

In [22]:
# Spot check: show element [0, 0, 0] of the first component of the first
# item in ds_test, for comparison with the same lookup on ds_test2.
ds_test[0][0][0, 0, 0]

0.3974892

In [23]:
# Spot check: show element [0, 0, 0] of the first component of the first
# item in ds_test2, for comparison with the same lookup on ds_test.
ds_test2[0][0][0, 0, 0]

0.3974892

## Load model

In [14]:
# Restore the trained LeinGANGP model from a saved checkpoint
# (epoch 120 / step 324279 according to the checkpoint file name).
gan = LeinGANGP.load_from_checkpoint("/home/jupyter/data/saved_models/leingan/single_forecast/0/epoch=120-step=324279.ckpt")

In [15]:
# Take the generator out of the GAN, move it to the compute device and put it
# in evaluation mode (eval() is the same as train(False)). The trailing ';'
# suppresses the module repr in the cell output.
gen = gan.gen.to(device)
gen.eval();

## Ilan's evaluation

In [12]:
def plot_samples_per_input(cond, target, gen, k=1, samples = 3):
    """Plot input, target and generated samples for the first ``k`` batch items.

    Each row of the figure corresponds to one batch item: column 0 shows the
    conditioning input (channel 0), column 1 the target (channel 0) and the
    remaining ``samples`` columns show generator outputs, each drawn with
    fresh Gaussian noise. All panels in a row share one colour range, taken
    from that row's target and generated samples.

    Args:
        cond: Conditioning batch; indexed as ``cond[item, channel]`` and
            expected to have four dimensions (``shape[2]`` and ``shape[3]``
            set the spatial size of the noise).
        target: Target batch; indexed as ``target[item, channel]``.
        gen: Callable ``gen(cond, noise)`` returning a tensor indexed as
            ``[item, channel, ...]``; channel 0 is plotted.
        k: Number of leading batch items to plot (one row each). Must not
            exceed the batch size, otherwise indexing raises ``IndexError``.
        samples: Number of generated samples to draw and plot per row.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Draw exactly `samples` realisations (the loop count used to be hard-coded
    # to 4). Keeping the raw outputs avoids assuming a fixed 128x128 image size.
    generated = []
    with torch.no_grad():
        for _ in range(samples):
            # One noise channel at the spatial size of the conditioning input,
            # placed on the same device as `cond` so the generator call is valid.
            noise = torch.randn(
                cond.shape[0], 1, cond.shape[2], cond.shape[3]
            ).to(cond.device)
            generated.append(gen(cond, noise).detach().cpu().numpy()[:, 0])

    # squeeze=False keeps `axs` two-dimensional even for k == 1, so that
    # axs[row, col] indexing works for every k.
    fig, axs = plt.subplots(k, samples+2, figsize=(15, k*5), squeeze=False)

    for j in range(k):
        lr = cond[j, 0].detach().cpu().numpy()
        if lr.shape[0]==64:
            # NOTE(review): only the central 16x16 window of a 64-row input is
            # shown - presumably the part matching the target; confirm.
            lr = lr[24:40, 24:40]
        hr = target[j, 0].detach().cpu().numpy()
        row_samples = [pred[j] for pred in generated]

        # Colour range from this row's target and its own samples only
        # (previously stale loop variables mixed in other batch items).
        panels = [hr] + row_samples
        mn = min(float(np.min(img)) for img in panels)
        mx = max(float(np.max(img)) for img in panels)

        axs[j, 0].imshow(lr, vmin=mn, vmax=mx, cmap='gist_ncar_r')
        axs[j, 1].imshow(hr, vmin=mn, vmax=mx, cmap='gist_ncar_r')
        for i, img in enumerate(row_samples):
            axs[j, i+2].imshow(img, vmin=mn, vmax=mx, cmap='gist_ncar_r')
    plt.show()

In [17]:
# Take the first batch from the test loader and move both tensors to the
# compute device in one step.
x_sample, y_sample = (part.to(device) for part in next(iter(dl_test)))
# plot_samples_per_input(x_sample, y_sample, gen, k=16)

### Patch evaluation

In [41]:
# Patch-based evaluation of the generator on the small test loader.
# The dataset's mins/maxs and tp_log are passed along with the loader
# (presumably so gen_patch_eval can undo the scaling and log transform -
# confirm in its definition).
num_samples = 10
scores = gen_patch_eval(
    gen, dl_test_small, num_samples,
    ds_test_small.mins.tp.values, ds_test_small.maxs.tp.values,
    ds_test_small.tp_log, device,
)

batch 0 out of 4
metrics took 5.379225 seconds.
batch 1 out of 4
metrics took 5.407562 seconds.
batch 2 out of 4
metrics took 5.341427 seconds.
batch 3 out of 4
metrics took 0.946686 seconds.


In [42]:
# Display the result returned by gen_patch_eval.
scores

(0.32465395752506965,
 0.444688116088458,
 0.3218651547988402,
 <xarray.DataArray (rank: 11)>
 array([227523., 174535., 157127., 148649., 142779., 136920., 133902.,
        130989., 132315., 138849., 245884.])
 Dimensions without coordinates: rank,
 (<xarray.DataArray (forecast_probability: 9)>
  array([0.03172549, 0.16364529, 0.31069911, 0.        , 0.42108236,
         0.50461072, 0.57405935, 0.        , 0.6304834 ])
  Coordinates:
    * forecast_probability  (forecast_probability) float64 0.05 0.15 ... 0.75 0.85,
  None,
  <xarray.DataArray 'samples' (forecast_probability: 9)>
  array([1041686.,  158868.,  155826.,       0.,   53106.,   94996.,
           46537.,       0.,   98821.])
  Coordinates:
    * forecast_probability  (forecast_probability) float64 0.05 0.15 ... 0.75 0.85),
 0.68224376)

## My full field eval

In [37]:
def create_valid_predictions(model, ds_valid):
    """Run the generator on every full field of a dataset and return physical values.

    For each time step the full input field is fetched, one noise draw is
    generated and the model output (item 0, channel 0) is stored. The stacked
    predictions are then unscaled with the dataset's min/max, optionally
    passed through ``log_retrans`` and wrapped in an ``xr.DataArray``.

    Args:
        model: Callable ``model(x, noise)`` returning a tensor indexed as
            ``[item, channel, ...]``.
        ds_valid: Dataset object providing ``tigge.valid_time``,
            ``return_full_array(t)``, ``mins.tp.values``, ``maxs.tp.values``,
            ``tp_log``, ``pad_mrms`` and ``mrms.lat`` / ``mrms.lon``.

    Returns:
        ``xr.DataArray`` named ``'tp'`` with dims ``('valid_time', 'lat', 'lon')``.
    """
    # Inference only: disable autograd so no graph is kept for the (large)
    # full-field forward passes.
    preds = []
    with torch.no_grad():
        for t in tqdm.tqdm(range(len(ds_valid.tigge.valid_time))):
            # The second return value (the target) is not needed here.
            X, _ = ds_valid.return_full_array(t)
            # NOTE(review): noise gets as many channels as X (X.shape[0]), whereas
            # plot_samples_per_input always uses one noise channel - the two agree
            # only for single-channel inputs; confirm what the generator expects.
            noise = torch.randn(1, X.shape[0], X.shape[1], X.shape[2]).to(device)
            pred = model(torch.FloatTensor(X[None]).to(device), noise).to('cpu').detach().numpy()[0, 0]
            preds.append(pred)
    preds = np.array(preds)

    # Undo the min-max scaling
    preds = preds * (ds_valid.maxs.tp.values - ds_valid.mins.tp.values) + ds_valid.mins.tp.values

    # Undo the log transform, if the dataset applied one (tp_log is truthy)
    if ds_valid.tp_log:
        preds = log_retrans(preds, ds_valid.tp_log)

    # Attach coordinates: the MRMS lat/lon axes are sliced starting at
    # pad_mrms, with the same length as the prediction's spatial dimensions.
    preds = xr.DataArray(
        preds,
        dims=['valid_time', 'lat', 'lon'],
        coords={
            'valid_time': ds_valid.tigge.valid_time,
            'lat': ds_valid.mrms.lat.isel(
                lat=slice(ds_valid.pad_mrms, ds_valid.pad_mrms+preds.shape[1])
            ),
            'lon': ds_valid.mrms.lon.isel(
                lon=slice(ds_valid.pad_mrms, ds_valid.pad_mrms+preds.shape[2])
            )
        },
        name='tp'
    )
    return preds

In [38]:
def create_valid_ensemble(model, ds_valid, nens):
    """Build an ensemble by calling create_valid_predictions ``nens`` times.

    The individual predictions are concatenated along a new 'member'
    dimension and returned as a single xarray object.
    """
    members = []
    for _ in range(nens):
        members.append(create_valid_predictions(model, ds_valid))
    return xr.concat(members, 'member')

In [39]:
%%time
# Single full-field prediction pass over the small test set
# (one noise draw per time step).
det_pred = create_valid_predictions(gen, ds_test_small)

  0%|          | 0/3 [00:00<?, ?it/s]

CPU times: user 1min 45s, sys: 4.23 s, total: 1min 50s
Wall time: 29.4 s


In [40]:
# Display the DataArray returned by create_valid_predictions.
det_pred

In [None]:
%%time
# 10-member ensemble over the full test set. Each member is a separate pass
# over all time steps (110 per the progress bar), so this cell is expensive.
ens_pred = create_valid_ensemble(gen, ds_test, nens=10)

  0%|          | 0/110 [00:00<?, ?it/s]