# Train a CNN

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers import *
import tensorflow.keras.backend as K
import seaborn as sns
import pickle
from src.score import *
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from collections import OrderedDict

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [3]:
def limit_mem():
    """Open a TF1-compat session configured with GPU ``allow_growth`` enabled."""
    gpu_config = tf.compat.v1.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
    # The session object is not kept; presumably creating it is enough for the
    # option to take effect for the process — TODO confirm.
    tf.compat.v1.Session(config=gpu_config)

In [4]:
# Enable GPU memory growth before any model is built.
limit_mem()

In [5]:
# Global seaborn styling for all figures: dark grid background, 'notebook' scaling context.
sns.set_style('darkgrid')
sns.set_context('notebook')

In [6]:
# Root directory of the dataset; per-variable sub-directories are appended to it below.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
DATADIR = '/data/weather-benchmark/5.625deg/'

## Create data generator

In [7]:
# Load the validation subset of the data: 2017 and 2018
# These arrays are the reference the forecasts are scored against at the end of the notebook.
# (load_test_data presumably comes from the star import of src.score — verify.)
z500_valid = load_test_data(f'{DATADIR}geopotential_500', 'z')
t850_valid = load_test_data(f'{DATADIR}temperature_850', 't')

In [8]:
# Open every NetCDF file of each variable as a single dataset, combined by their coordinates.
z = xr.open_mfdataset(f'{DATADIR}geopotential_500/*.nc', combine='by_coords')
t = xr.open_mfdataset(f'{DATADIR}temperature_850/*.nc', combine='by_coords')

In [9]:
# Merge both variables into one dataset so they can be sliced and fed to the generator together.
datasets = [z, t]
ds = xr.merge(datasets)

In [10]:
# Split by year: 1979-2016 for training (further split into train/valid below),
# 2017-2018 held out for testing.
ds_train = ds.sel(time=slice('1979', '2016'))
ds_test = ds.sel(time=slice('2017', '2018'))

In [11]:
# Variables fed to the data generator, mapped to the levels to select
# (None: the variable is used as a single level).
# Only geopotential ('z') is used, matching the single-channel CNN built below.
# To add temperature use OrderedDict({'z': None, 't': None}); the network's
# input/output channel counts would then presumably need to be 2 — verify.
dic = OrderedDict({'z': None})

In [53]:
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` yielding normalized (input, target) batches ``lead_time`` steps apart.

    Adapted from https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly

    Args:
        ds: Dataset holding the variables named in ``var_dict``.
        var_dict: Mapping ``{variable name: levels}``. ``levels`` is passed to
            ``.sel(level=...)``; if that raises ``ValueError`` the variable is given a
            generic single-entry ``level`` dimension instead.
        lead_time: Offset, in index steps along ``time``, between an input and its target.
        batch_size: Number of samples per batch (the last batch may be smaller).
        shuffle: Whether to reshuffle the sample order at the start of every epoch.
        load: Whether to load the stacked data into memory up front.
        mean: Per-level mean used for normalization; computed from the data if None.
        std: Per-level standard deviation used for normalization; computed from the data if None.
    """
    def __init__(self, ds, var_dict, lead_time, batch_size=32, shuffle=True, load=True, mean=None, std=None):

        self.ds = ds
        self.var_dict = var_dict
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.lead_time = lead_time

        data = []
        generic_level = xr.DataArray([1], coords={'level': [1]}, dims=['level'])
        for var, levels in var_dict.items():
            try:
                data.append(ds[var].sel(level=levels))
            except ValueError:
                # Selection by level failed: add a dummy level axis so that all
                # variables can be stacked along 'level'.
                data.append(ds[var].expand_dims({'level': generic_level}, 1))
        self.data = xr.concat(data, 'level').transpose('time', 'lat', 'lon', 'level')
        if load:
            print('Loading data into RAM')
            self.data.load()
        self.mean = self.data.mean(('time', 'lat', 'lon')) if mean is None else mean
        self.std = self.data.std('time').mean(('lat', 'lon')) if std is None else std
        # Normalize
        self.data = (self.data - self.mean) / self.std
        # Every sample needs a target lead_time steps later. Counting directly
        # (instead of slicing with slice(0, -lead_time)) keeps lead_time == 0 working:
        # slice(0, -0) is empty and would leave the generator with zero samples.
        self.n_samples = max(self.data.shape[0] - lead_time, 0)
        # Time stamps of the targets, i.e. the times the predictions are valid for.
        self.valid_time = self.data.isel(time=slice(lead_time, None)).time

        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.ceil(self.n_samples / self.batch_size))

    def __getitem__(self, i):
        'Generate one batch of data: inputs X and targets y, lead_time steps after X'
        idxs = self.idxs[i*self.batch_size:(i+1)*self.batch_size]
        X = self.data.isel(time=idxs).values
        y = self.data.isel(time=idxs+self.lead_time).values
        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.idxs = np.arange(self.n_samples)
        if self.shuffle:
            np.random.shuffle(self.idxs)

In [54]:
# Display the variable selection that the generators below will use.
dic

OrderedDict([('z', None)])

In [56]:
# Chronological train/valid split of the training period:
# the first 90% of the time steps are used for training, the remaining 10% for validation.
n_times = len(ds_train.time)
split = int(0.9 * n_times)
train_idxs, valid_idxs = np.split(np.arange(n_times), [split])

In [57]:
%%time
# lead_time is 5*24 index steps along time (5 days if the data is hourly — TODO confirm).
# The validation generator reuses the training mean/std so both are normalized identically.
dg_train = DataGenerator(ds_train.isel(time=train_idxs), dic, 5*24, batch_size=128)
dg_valid = DataGenerator(ds_train.isel(time=valid_idxs), dic, 5*24, batch_size=128, mean=dg_train.mean, std=dg_train.std)

Loading data into RAM
Loading data into RAM
CPU times: user 13.9 s, sys: 1min 33s, total: 1min 47s
Wall time: 35.1 s


In [58]:
# Inspect the normalization statistics computed from the training data.
dg_train.mean, dg_train.std

(<xarray.DataArray 'z' (level: 1)>
 array([54103.176], dtype=float32)
 Coordinates:
   * level    (level) int64 850, <xarray.DataArray 'z' (level: 1)>
 array([1119.1687], dtype=float32)
 Coordinates:
   * level    (level) int64 850)

In [59]:
# Test generator: normalized with the training statistics and not shuffled,
# so predictions stay in time order and line up with dg_test.valid_time.
dg_test = DataGenerator(ds_test, dic, 5*24, batch_size=1024, mean=dg_train.mean, std=dg_train.std, shuffle=False)

Loading data into RAM


## Create model class

In [60]:
class PeriodicConv2D(tf.keras.layers.Conv2D):
    """Convolution with periodic padding in second spatial dimension (lon)

    Before the 'valid' convolution is applied, the input is wrapped around along
    axis 2 (lon) and zero-padded along axis 1 (lat) by ``(kernel_size - 1) // 2``
    on each side. For odd kernel sizes this keeps the spatial size unchanged.

    Args:
        filters: Number of output filters, forwarded to ``Conv2D``.
        kernel_size: Integer side length of the square kernel.
        **kwargs: Forwarded to ``Conv2D``; padding must stay 'valid' and strides (1, 1).
    """
    def __init__(self, filters, kernel_size, **kwargs):
        assert type(kernel_size) is int, 'Periodic convolutions only works for square kernels.'
        self.pad_width = (kernel_size - 1) // 2
        super().__init__(filters, kernel_size, **kwargs)
        assert self.padding == 'valid', 'Periodic convolution only works for valid padding.'
        assert sum(self.strides) == 2, 'Periodic padding only works for stride (1, 1)'

    def __call__(self, inputs, *args, **kwargs):
        # Input: [samples, lat, lon, filters]
        if self.pad_width == 0:
            # Nothing to pad (kernel size 1 or 2). Without this guard the slice
            # [-0:] below would select the WHOLE lon axis and duplicate the input.
            return super().__call__(inputs, *args, **kwargs)
        # Periodic padding in lon direction
        inputs_padded = K.concatenate(
            [inputs[:, :, -self.pad_width:, :], inputs, inputs[:, :, :self.pad_width, :]], axis=2)
        # Zero padding in the lat direction
        inputs_padded = tf.pad(inputs_padded, [[0, 0], [self.pad_width, self.pad_width], [0, 0], [0, 0]])
        return super().__call__(inputs_padded, *args, **kwargs)

In [61]:
def build_cnn(filters, kernels, input_shape, activation='elu', dr=0):
    """Build a fully convolutional network out of stacked PeriodicConv2D layers.

    ``filters[i]`` / ``kernels[i]`` give the filter count and kernel size of layer i.
    Every layer except the last uses ``activation`` and, when ``dr > 0``, is followed
    by ``Dropout(dr)``. The last layer is created without an explicit activation.

    Returns:
        A ``keras.models.Model`` mapping an input of ``input_shape`` to the last layer's output.
    """
    inputs = Input(shape=input_shape)
    hidden = inputs
    for n_filters, kernel_size in zip(filters[:-1], kernels[:-1]):
        hidden = PeriodicConv2D(n_filters, kernel_size, activation=activation)(hidden)
        if dr > 0:
            hidden = Dropout(dr)(hidden)
    outputs = PeriodicConv2D(filters[-1], kernels[-1])(hidden)
    return keras.models.Model(inputs, outputs)

In [62]:
# Five periodic-convolution layers with 5x5 kernels; the input is (lat=32, lon=64, channels=1)
# and the final layer has a single filter, i.e. one output channel.
cnn = build_cnn([32, 64, 64, 64, 1], [5, 5, 5, 5, 5], (32, 64, 1))

In [63]:
# Adam optimizer (constructed with 1e-4 as its first argument, the learning rate) and MSE loss.
cnn.compile(keras.optimizers.Adam(1e-4), 'mse')

In [64]:
# Train for at most 100 epochs; stop early once val_loss has failed to improve
# for `patience` consecutive epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=0,
    patience=2,
    verbose=1,
    mode='auto'
)
cnn.fit_generator(dg_train, epochs=100, validation_data=dg_valid, callbacks=[early_stopping])

Epoch 1/100
Epoch 2/100
Epoch 3/100

KeyboardInterrupt: 

In [None]:
# Persist the trained weights.
# NOTE(review): hard-coded, machine-specific scratch path — adjust for your environment.
cnn.save_weights('/data/tmp/test.h5')

In [None]:
def create_predictions(model, dg):
    """Predict with `model` on the generator `dg` and return un-normalized forecasts.

    Args:
        model: Model exposing ``predict_generator``; its output's last axis must follow
            the variable/level order of ``dg.var_dict``.
        dg: Data generator providing ``mean``, ``std``, ``var_dict``, ``valid_time`` and
            ``ds``. It should not shuffle, otherwise the predictions do not line up
            with ``dg.valid_time``.

    Returns:
        The result of ``xr.merge`` over one DataArray per variable in ``dg.var_dict``,
        with dims (time, lat, lon) for variables whose levels are None and
        (time, lat, lon, level) otherwise.
    """
    # Use the model that was passed in (not a notebook-global one).
    preds = model.predict_generator(dg)
    # Unnormalize
    preds = preds * dg.std.values + dg.mean.values
    fcs = []
    # Position along the last axis of `preds` of the next variable's first level.
    lev_idx = 0
    for var, levels in dg.var_dict.items():
        if levels is None:
            # Single-level variable: occupies one channel, returned without a level dim.
            fcs.append(xr.DataArray(
                preds[:, :, :, lev_idx],
                dims=['time', 'lat', 'lon'],
                coords={'time': dg.valid_time, 'lat': dg.ds.lat, 'lon': dg.ds.lon},
                name=var
            ))
            lev_idx += 1
        else:
            # Multi-level variable: occupies len(levels) consecutive channels.
            nlevs = len(levels)
            fcs.append(xr.DataArray(
                preds[:, :, :, lev_idx:lev_idx+nlevs],
                dims=['time', 'lat', 'lon', 'level'],
                coords={'time': dg.valid_time, 'lat': dg.ds.lat, 'lon': dg.ds.lon, 'level': levels},
                name=var
            ))
            lev_idx += nlevs
    return xr.merge(fcs)

In [None]:
# Forecasts for the test period, un-normalized and labelled with their valid times.
fc = create_predictions(cnn, dg_test)

In [None]:
# Display the forecast dataset.
fc

In [None]:
# Score the geopotential forecast against the reference data loaded at the top of the notebook
# (compute_weighted_rmse presumably comes from the star import of src.score — verify).
compute_weighted_rmse(fc.z, z500_valid).values

In [None]:
# Score the temperature forecast. `fc` only contains 't' when `dic` includes it,
# so skip the score (instead of raising AttributeError) when it was not predicted.
compute_weighted_rmse(fc.t, t850_valid).values if 't' in fc else None