## Nah

In [None]:
import tensorflow as tf  # I am using tensorflow=2.1
import tensorflow.keras as keras
from tensorflow.keras.layers import *
from tensorflow.keras import regularizers
import numpy as np
import os
from src.utils import *
from src.score import *
from src.data_generator import *
from src.networks import *
from src.train import *

In [2]:
# Restrict the CUDA devices visible to this process via the environment.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [3]:
class PeriodicPadding2D(tf.keras.layers.Layer):
    """Pad a 4-D tensor periodically along axis 2 and with zeros along axis 1.

    Axis 2 is extended by wrapping ``pad_width`` entries around from the
    opposite edge; axis 1 (the lat direction) is then zero-padded by
    ``pad_width`` on both sides. A ``pad_width`` of 0 returns the input as is.
    """

    def __init__(self, pad_width, **kwargs):
        super().__init__(**kwargs)
        self.pad_width = pad_width

    def call(self, inputs, **kwargs):
        width = self.pad_width
        if width == 0:
            return inputs
        # Wrap-around pieces taken from the opposite edges of axis 2.
        wrap_front = inputs[:, :, -width:, :]
        wrap_back = inputs[:, :, :width, :]
        wrapped = tf.concat([wrap_front, inputs, wrap_back], axis=2)
        # Zero padding in the lat direction (axis 1) only.
        zero_paddings = [[0, 0], [width, width], [0, 0], [0, 0]]
        return tf.pad(wrapped, zero_paddings)

    def get_config(self):
        """Return the base layer config extended with ``pad_width``."""
        config = super().get_config()
        config['pad_width'] = self.pad_width
        return config

class PeriodicConv2D(tf.keras.layers.Layer):
    """Convolution on periodically padded inputs that also returns a reconstruction loss.

    The input is padded with ``PeriodicPadding2D`` and convolved with 'VALID'
    padding. The convolution output (before the bias is added) is mapped back
    with the transposed convolution using the same kernel; the L2 loss between
    the padded input and that reconstruction is returned next to the output.

    Args:
        filters: Number of output channels.
        kernel_size: Int, or a pair of equal ints (only square kernels).
        conv_kwargs: Optional dict. Recognised keys are 'use_bias'
            (default False), 'kernel_regularizer' (default None) and
            'stride' (default 1; 'strides' is accepted as an alias).
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, filters,
                 kernel_size,
                 conv_kwargs=None,
                 **kwargs, ):
        super().__init__(**kwargs)
        # Avoid a shared mutable default: every layer gets its own dict.
        conv_kwargs = {} if conv_kwargs is None else conv_kwargs
        self.filters = filters
        self.kernel_size = kernel_size
        self.conv_kwargs = conv_kwargs
        self.use_bias = conv_kwargs.get('use_bias', False)
        self.kernel_regularizer = conv_kwargs.get('kernel_regularizer', None)
        # Callers in this file use the Keras spelling 'strides'; honour both
        # so that a requested stride is not silently dropped.
        self.stride = conv_kwargs.get('stride', conv_kwargs.get('strides', 1))

        if isinstance(kernel_size, int):
            kernel_width = kernel_size
        else:
            assert kernel_size[0] == kernel_size[1], 'PeriodicConv2D only works for square kernels'
            kernel_width = kernel_size[0]
        # Scalar kernel width used for the weight shape in call(); self.kernel_size
        # keeps the value the caller passed so get_config() round-trips it.
        self._kernel_width = kernel_width
        pad_width = (kernel_width - 1) // 2
        self.padding = PeriodicPadding2D(pad_width)

    def call(self, inputs):
        """Return ``(conv_output, rrloss)`` for ``inputs``."""
        paddinginputs = self.padding(inputs)
        inputshape = paddinginputs.get_shape()
        # NOTE(review): variables are created through tf.compat.v1.get_variable
        # rather than add_weight — confirm they are tracked/trained as intended.
        weight = tf.compat.v1.get_variable(
            name='weights',
            shape=[self._kernel_width, self._kernel_width, inputshape[-1], self.filters],
            initializer=tf.keras.initializers.glorot_normal,
            regularizer=self.kernel_regularizer)
        conv = tf.nn.conv2d(paddinginputs, weight, strides=self.stride, padding='VALID')
        # Keep a handle on the bias-free output: the reconstruction uses it.
        paddinginput_back = conv
        if self.use_bias:
            bias = tf.compat.v1.get_variable(
                name='bias', shape=[self.filters],
                initializer=tf.keras.initializers.glorot_normal)
            conv += bias
        paddinginput_back = tf.nn.conv2d_transpose(
            paddinginput_back, weight, tf.shape(paddinginputs), self.stride, padding='VALID')
        rrloss = tf.reduce_mean(tf.nn.l2_loss(paddinginputs - paddinginput_back))
        return conv, rrloss

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        config = super().get_config()
        config.update({'filters': self.filters, 'kernel_size': self.kernel_size, 'conv_kwargs': self.conv_kwargs})
        return config

def convblock(inputs, filters, kernel=3, stride=1, bn_position=None, l2=0,
              use_bias=True, dropout=0, activation='relu'):
    """Apply PeriodicConv2D, an activation and optional batch norm / dropout.

    Args:
        inputs: Input tensor.
        filters: Number of convolution filters.
        kernel: Kernel size passed to ``PeriodicConv2D``.
        stride: Accepted but not used in this function.
        bn_position: 'pre', 'mid' or 'post' to place BatchNormalization before
            the convolution, before the activation or after it; anything else
            disables it.
        l2: Factor for the L2 kernel regularizer.
        use_bias: Forwarded to the convolution.
        dropout: Dropout rate; applied only when greater than 0.
        activation: 'leakyrelu' selects ``LeakyReLU``; any other value is
            passed to ``Activation``.

    Returns:
        Tuple ``(x, rrloss_p)`` of the block output and the loss value
        returned by the convolution layer.
    """
    features = inputs
    if bn_position == 'pre':
        features = BatchNormalization()(features)

    conv_options = {
        'kernel_regularizer': regularizers.l2(l2),
        'use_bias': use_bias,
    }
    conv_layer = PeriodicConv2D(filters, kernel, conv_kwargs=conv_options)
    features, rrloss_p = conv_layer(features)

    if bn_position == 'mid':
        features = BatchNormalization()(features)

    if activation == 'leakyrelu':
        features = LeakyReLU()(features)
    else:
        features = Activation(activation)(features)

    if bn_position == 'post':
        features = BatchNormalization()(features)
    if dropout > 0:
        features = Dropout(dropout)(features)
    return features, rrloss_p

def resblock(inputs, filters, kernel, bn_position=None, l2=0, use_bias=True,
             dropout=0, skip=True, activation='relu', down=False, up=False):
    """Two stacked ``convblock``s with an optional residual connection.

    Args:
        inputs: Input tensor.
        filters: Number of filters for every convolution in the block.
        kernel: Kernel size for every convolution in the block.
        bn_position, l2, use_bias, dropout, activation: Forwarded to ``convblock``.
        skip: If True, the (possibly projected) input is added to the output.
        down: If True, the main path is max-pooled first and the shortcut is a
            PeriodicConv2D with stride 2.
        up: If True, the shortcut is a PeriodicConv2D with stride 1.
            NOTE(review): the main path is not upsampled here — confirm intent.

    Returns:
        Tuple ``(x, rrloss)`` of the block output and the summed loss values
        returned by all convolutions in the block.
    """
    rrloss = tf.constant(0.)
    shortcut = inputs
    x = inputs
    if down:
        x = MaxPooling2D()(x)
    for i in range(2):
        with tf.compat.v1.variable_scope("reslayer_%d"%i):
            x, rrloss_p = convblock(
                x, filters, kernel, bn_position=bn_position, l2=l2, use_bias=use_bias,
                dropout=dropout, activation=activation
            )
            rrloss += rrloss_p
    if down or up:
        with tf.compat.v1.variable_scope("reslayer_%s"%("down" if down else "up")):
            shortcut, rrloss_p = PeriodicConv2D(
                filters,  kernel, conv_kwargs={
                    'kernel_regularizer': regularizers.l2(l2),
                    'use_bias': use_bias,
                    # PeriodicConv2D reads the key 'stride'; the former key
                    # 'strides' was ignored, leaving the shortcut unstrided.
                    'stride': 2 if down else 1
                }
            )(shortcut)
            rrloss += rrloss_p
    if skip: x = Add()([shortcut, x])
    return x, rrloss

#def build_resnet(input, filters, kernels, input_shape, bn_position=None, use_bias=True, l2=0,
def build_resnet(input, filters, kernels, bn_position=None, use_bias=True, l2=0,
                 skip=True, dropout=0, activation='relu', long_skip=False,
                 **kwargs):
    """Assemble the residual network graph on top of ``input``.

    The first entry of ``filters``/``kernels`` configures the head convblock,
    the last entry the final PeriodicConv2D, and every entry in between one
    resblock. Extra ``**kwargs`` are accepted and ignored.

    Returns:
        Tuple ``(output, rrloss)``: the float32 network output and the sum of
        the loss values returned by every convolution in the network.
    """
    shared_options = dict(
        bn_position=bn_position, l2=l2, use_bias=use_bias,
        dropout=dropout, activation=activation,
    )
    x = input
    total_rrloss = tf.constant(0.)

    # Head: first conv block to get up to shape.
    with tf.compat.v1.variable_scope("head"):
        x, head_rrloss = convblock(x, filters[0], kernels[0], **shared_options)
        total_rrloss += head_rrloss
        head_output = x

    # Residual blocks, one per inner (filters, kernel) pair.
    inner_specs = zip(filters[1:-1], kernels[1:-1])
    for block_index, (n_filters, kernel) in enumerate(inner_specs):
        with tf.compat.v1.variable_scope("resblock_%d"%block_index):
            x, block_rrloss = resblock(x, n_filters, kernel, skip=skip, **shared_options)
            total_rrloss += block_rrloss
            if long_skip:
                x = Add()([x, head_output])

    # Final convolution.
    with tf.compat.v1.variable_scope("end"):
        final_conv = PeriodicConv2D(
            filters[-1], kernels[-1],
            conv_kwargs={'kernel_regularizer': regularizers.l2(l2)},
        )
        output, end_rrloss = final_conv(x)
        total_rrloss += end_rrloss

    # This is just because of mixed precision. Can be left out for regular precision.
    output = Activation('linear', dtype='float32')(output)
    return output, total_rrloss

In [4]:
# Load the experiment configuration and override it for a quick 2015-only run.
args = load_args('../nn_configs/C/017-resnet_d3_ztt_3d.yml')
args['train_years'] = ['2015', '2015']
args['valid_years'] = ['2015', '2015']
args['test_years'] = ['2015', '2015']
# Disable the TFRecord inputs for all three splits.
args['train_tfr_files'] = None
args['test_tfr_files'] = None
# Was 'tvalid_tfr_files', which does not follow the train/test key pattern.
# NOTE(review): confirm 'valid_tfr_files' is the key load_data expects.
args['valid_tfr_files'] = None

In [5]:
# Forward the loaded config to load_data with only_test=True; presumably this
# builds just the test generator — confirm against load_data.
dg_test = load_data(**args, only_test=True)

In [7]:
# Keras input whose per-sample shape is taken from the generator's `shape` attribute.
inputs = Input(shape=dg_test.shape)

In [9]:
# Build the network graph: one head block, two residual blocks and a final
# 3-filter convolution (first/last entries of filters/kernels are head/end).
output, rrloss = build_resnet(
    inputs,
    filters=[128, 128, 128, 3],
    kernels=[7, 3, 3, 3],
    #input_shape=(32, 64, 114,),
    bn_position='post',
    dropout=0.1,   # I am currently using a combination of dropout and l2 for regularization
    l2=1e-5,       # Of course it would be great if I didn't have to use them with racecar
    activation='leakyrelu',
)

## Pre for You

In [3]:
import re
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import xarray as xr
import datetime
import pandas as pd
import pdb
from tqdm import tqdm

In [4]:
class DataGenerator(keras.utils.Sequence):
    def __init__(self, ds, var_dict, lead_time, batch_size=32, shuffle=True, load=True,
                 mean=None, std=None, output_vars=None, data_subsample=1, norm_subsample=1,
                 nt_in=1, dt_in=1, cont_time=False, fixed_time=False, multi_dt=1, verbose=0,
                 min_lead_time=None, las_kernel=None, las_gauss_std=None, normalize=True,
                 tfrecord_files=None, tfr_buffer_size=1000, tfr_num_parallel_calls=1,
                 cont_dt=1, tfr_prefetch=None, tfr_repeat=True, y_roll=None, X_roll=None,
                 discard_first=None, tp_log=None, tfr_out=False, tfr_out_idxs=None,
                 old_const=False, is_categorical=False, num_bins=50, bin_min=-5, bin_max=5,
                 predict_difference=False, adaptive_bins=None):
        """
        Data generator for WeatherBench data.
        Template from https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
        Args:
            ds: Dataset containing all variables
            var_dict: Dictionary of the form {'long_name': ('var', levels)}. The special key
                'constants' maps to a list of variable names that are broadcast over time.
            lead_time: Lead time in hours
            batch_size: Batch size
            shuffle: bool. If True, data is shuffled.
            load: bool. If True, dataset is loaded into RAM.
            mean: If None, compute mean from data.
            std: If None, compute standard deviation from data.
            output_vars: List of regex patterns matched (re.match) against the level names to
                select the output channels. If None, all channels are outputs.
            data_subsample: Only take every ith time step
            norm_subsample: Same for normalization. This is AFTER data_subsample!
            nt_in: How many time steps for input. AFTER data_subsample!
            dt_in: Interval of input time steps. AFTER data_subsample!
            normalize: bool. If True, data is replaced by (data - mean) / std.
            discard_first: If not None, number of leading time steps that are dropped.
            tp_log: If truthy, the 'tp' variable is passed through log_trans(da, tp_log).
            X_roll, y_roll: If not None, converted to a number of time steps (value // dt)
                and used as the window of a rolling mean over time for inputs / targets.
            tfrecord_files: If not None, the TFRecord dataset is set up from these files.
            is_categorical, num_bins, bin_min, bin_max: If is_categorical, num_bins + 1 bin
                edges are created between bin_min and bin_max with open outer edges.
            predict_difference: Stored flag; not supported together with tfrecord_files.
        """
        if verbose: print('DG start', datetime.datetime.now().time())
        # Store the configuration on the instance.
        self.ds = ds
        self.var_dict = var_dict
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.lead_time = lead_time
        self.nt_in = nt_in
        self.dt_in = dt_in
        self.cont_time = cont_time
        self.min_lead_time = min_lead_time
        self.fixed_time = fixed_time
        self.multi_dt = multi_dt
        self.tfrecord_files = tfrecord_files
        self.normalize = normalize
        self.tfr_num_parallel_calls = tfr_num_parallel_calls
        self.tfr_buffer_size = tfr_buffer_size
        self.cont_dt = cont_dt
        self.tfr_prefetch = tfr_prefetch
        self.tfr_repeat = tfr_repeat
        self.tfr_out = tfr_out
        self.y_roll = y_roll
        self.X_roll = X_roll
        # Hard-coded here; used together with lead_time and dt in init_time / valid_time.
        self.tfr_max_lead = 120
        self.tfr_out_idxs = tfr_out_idxs
        self.old_const = old_const
        self.is_categorical = is_categorical
        self.num_bins = num_bins
        self.bin_min = bin_min
        self.bin_max = bin_max
        self.predict_difference = predict_difference
        if self.predict_difference:
            assert self.tfrecord_files is None, 'difference does not work for tfr'
        self.adaptive_bins = adaptive_bins

        # Collect one DataArray per variable, each with a 'level' dimension, and
        # a parallel list of channel names.
        data = []
        level_names = []
        # Dummy single-entry level coordinate for variables without levels.
        generic_level = xr.DataArray([1], coords={'level': [1]}, dims=['level'])
        for long_var, params in var_dict.items():
            if long_var == 'constants':
                # Constants get a dummy level and are broadcast along time.
                for var in params:
                    data.append(ds[var].expand_dims(
                        {'level': generic_level, 'time': ds.time}, (1, 0)
                    ))
                    level_names.append(var)
            else:
                var, levels = params
                da = ds[var]
                if tp_log and var == 'tp':
                    da = log_trans(da, tp_log)
                try:
                    data.append(da.sel(level=levels))
                    level_names += [f'{var}_{level}' for level in levels]
                except ValueError:
                    # Fallback when selecting levels raises ValueError: treat the
                    # variable as single-level and give it the dummy level.
                    data.append(da.expand_dims({'level': generic_level}, 1))
                    level_names.append(var)

        # Stack all variables along 'level' and fix the axis order.
        self.data = xr.concat(data, 'level').transpose('time', 'lat', 'lon', 'level')
        if discard_first is not None:
            self.data = self.data.isel(time=slice(discard_first, None))
        self.data['level_names'] = xr.DataArray(
            level_names, dims=['level'], coords={'level': self.data.level})
        # Indices of the output channels along the level axis.
        if output_vars is None:
            self.output_idxs = range(len(self.data.level))
        else:
            self.output_idxs = [i for i, l in enumerate(self.data.level_names.values)
                                if any([bool(re.match(o, l)) for o in output_vars])]
        # Split channel indices into constants and everything else
        # (requires a 'constants' entry in var_dict).
        self.const_idxs = [i for i, l in enumerate(self.data.level_names) if l in var_dict['constants']]
        self.not_const_idxs = [i for i, l in enumerate(self.data.level_names) if l not in var_dict['constants']]

        # Subsample
        self.data = self.data.isel(time=slice(0, None, data_subsample))
        self.raw_data = self.data
        # Time step of the subsampled data in hours, taken from the first two time stamps.
        self.dt = self.data.time.diff('time')[0].values / np.timedelta64(1, 'h')
        # Convert the input interval from hours to a number of time steps.
        self.dt_in = int(self.dt_in // self.dt)
        # Number of leading time steps skipped when building sample indices (see on_epoch_end).
        self.nt_offset = (nt_in - 1) * self.dt_in

        if self.min_lead_time is None:
            self.min_nt = 1
        else:
            self.min_nt = int(self.min_lead_time / self.dt)

        # Normalize
        if verbose: print('DG normalize', datetime.datetime.now().time())
        if mean is not None:
            self.mean = mean
        else:
            self.mean = self.data.isel(time=slice(0, None, norm_subsample)).mean(
                ('time', 'lat', 'lon')).compute()
            if 'tp' in self.data.level_names:  # set tp mean to zero but not if ext
                tp_idx = list(self.data.level_names).index('tp')
                self.mean.values[tp_idx] = 0

        if std is not None:
            self.std = std
        else:
            self.std = self.data.isel(time=slice(0, None, norm_subsample)).std(
                ('time', 'lat', 'lon')).compute()
        # Record the tp_log setting on the statistics (this also writes to a
        # mean/std object passed in by the caller).
        if tp_log is not None:
            self.mean.attrs['tp_log'] = tp_log
            self.std.attrs['tp_log'] = tp_log
        if normalize:
            self.data = (self.data - self.mean) / self.std

        if verbose: print('DG load', datetime.datetime.now().time())
        if load:
            if verbose: print('Loading data into RAM')
            self.data.load()
        if verbose: print('DG done', datetime.datetime.now().time())

        if self.X_roll is not None:
            # Window length in time steps; the offset grows by the window.
            self.X_roll = int(self.X_roll // self.dt)
            self.X_rolled = self.data.rolling(time=self.X_roll).mean()
            self.nt_offset += self.X_roll

        # Build (and optionally shuffle) the sample indices; needs nt_offset and data.
        self.on_epoch_end()

        if self.y_roll is not None:
            self.y_roll = int(self.y_roll // self.dt)
            assert self.y_roll < self.nt, 'nt must be larger than y_roll'
            self.y_rolled = self.data.isel(level=self.output_idxs).rolling(time=self.y_roll).mean()

        if self.tfrecord_files is not None:
            self.is_tfr = True
            self._setup_tfrecord_ds()
        else:
            self.is_tfr = False
            self.tfr_dataset = None

        if self.is_categorical:
            self.bins = np.linspace(self.bin_min, self.bin_max, self.num_bins+1)
            self.bins[0] = -np.inf; self.bins[-1] = np.inf  # for rare out-of-bound cases.

    def on_epoch_end(self):
        """Rebuild the sample indices, shuffling them when ``self.shuffle`` is set."""
        sample_idxs = np.arange(self.nt_offset, self.n_samples)
        if self.shuffle:
            np.random.shuffle(sample_idxs)
        self.idxs = sample_idxs

    def __len__(self):
        """Return the number of batches per epoch; a partial last batch counts."""
        batches = np.ceil(len(self.idxs) / self.batch_size)
        return int(batches)
    
    @property
    def shape(self):
        """Return ``(n_lat, n_lon, n_channels)`` of one input sample.

        Non-constant channels are counted once per input time step, constant
        channels once, plus one extra channel when ``cont_time`` is set.
        """
        n_lat = len(self.data.lat)
        n_lon = len(self.data.lon)
        all_levels = self.data.level
        n_varying = len(all_levels.isel(level=self.not_const_idxs))
        n_constant = len(all_levels.isel(level=self.const_idxs))
        n_channels = n_varying * self.nt_in + n_constant + self.cont_time
        return (n_lat, n_lon, n_channels)

    @property
    def nt(self):
        """Return the lead time as a number of time steps (``lead_time / dt``)."""
        n_steps = self.lead_time / self.dt
        assert n_steps.is_integer(), "lead_time and dt not compatible."
        return int(n_steps)

    @property
    def init_time(self):
        """Return the time coordinate of the forecast initialisation times.

        The first ``nt_offset`` and the last ``nt`` time steps are excluded.
        For TFRecord data the tail is shortened further by the number of steps
        between ``lead_time`` and ``tfr_max_lead``, matching ``valid_time``.
        """
        stop = -self.nt
        if self.is_tfr:
            # Convert the hour difference to steps as a whole. The previous
            # expression divided only lead_time by dt (operator precedence),
            # mixing hours with steps and disagreeing with valid_time.
            stop -= (self.tfr_max_lead - self.lead_time) // self.dt
        return self.data.isel(time=slice(self.nt_offset, int(stop))).time

    @property
    def valid_time(self):
        """Return the time coordinate of the forecast valid (target) times."""
        first = self.nt + self.nt_offset
        last = None
        if self.multi_dt > 1:
            # Shift the window back by the steps between the full lead time
            # and one multi_dt sub-step.
            shift = self.nt - self.nt // self.multi_dt
            first = first - shift
            last = -shift
        if self.is_tfr:
            # Steps between lead_time and the maximum TFRecord lead time.
            trailing_steps = int((self.tfr_max_lead - self.lead_time) // self.dt)
            last = None if trailing_steps == 0 else -trailing_steps
        return self.data.isel(time=slice(first, last)).time

    @property
    def n_samples(self):
        """Return the number of time steps left after dropping the last ``nt``."""
        usable = self.data.isel(time=slice(0, -self.nt))
        n_times = usable.shape[0]
        return n_times

    def __getitem__(self, i):
        """Return batch ``i`` from the TFRecord iterator or from in-memory data.

        If the instance has a ``cheat`` attribute, only the last element of
        ``y`` is returned for in-memory batches.
        """
        if self.tfrecord_files is not None:
            return self._get_tfrecord_item(i)
        if not hasattr(self, 'cheat'):
            return self._get_item(i)
        X, y = self._get_item(i)
        return X, y[-1]

    def _get_item(self, i):
        """Generate one batch of data from the in-memory arrays.

        Args:
            i: Batch index into ``self.idxs``.

        Returns:
            Tuple ``(X, y)``. ``X`` and ``y`` are float32 arrays, except that for
            ``multi_dt > 1`` both are lists of arrays, and for ``is_categorical``
            ``y`` is the one-hot encoding of the binned targets.
        """
        # Time indices of the samples in this batch.
        idxs = self.idxs[i * self.batch_size:(i + 1) * self.batch_size]

        if self.cont_time:
            # Lead time in steps per sample: random in [min_nt, nt] or fixed at nt.
            if not self.fixed_time:
                nt = np.random.randint(self.min_nt, self.nt + 1, len(idxs))
            else:
                nt = np.ones(len(idxs), dtype='int') * self.nt
            # Lead time in hours / 100, broadcast to one (lat, lon) field per sample.
            ftime = (nt * self.dt / 100)[:, None, None] * np.ones((1, len(self.data.lat),
                                                                   len(self.data.lon)))
        else:
            nt = self.nt

        # Use the rolling-mean inputs when X_roll is configured.
        if self.X_roll is not None:
            X_data = self.X_rolled
        else:
            X_data = self.data

        X = X_data.isel(time=idxs).values.astype('float32')

        # Keep the constant channels aside before earlier time steps are concatenated.
        if self.multi_dt > 1: consts = X[..., self.const_idxs]

        if self.nt_in > 1:
            # Prepend the earlier input time steps (oldest first) along the channel axis.
            if self.old_const:
                # Old behaviour: all channels, read from self.data (not X_data).
                X = np.concatenate([
                                       self.data.isel(time=idxs - nt_in * self.dt_in).values
                                       for nt_in in range(self.nt_in - 1, 0, -1)
                                   ] + [X], axis=-1).astype('float32')
            else:
                # Only non-constant channels are repeated for earlier time steps.
                X = np.concatenate([
                                       X_data.isel(time=idxs - nt_in * self.dt_in).values[..., self.not_const_idxs]
                                       for nt_in in range(self.nt_in - 1, 0, -1)
                                   ] + [X], axis=-1).astype('float32')

        if self.multi_dt > 1:
            # Inputs become [non-constant channels, constants]; one target per sub-step.
            X = [X[..., self.not_const_idxs], consts]
            step = self.nt // self.multi_dt
            y = [
                self.data.isel(time=idxs + nt, level=self.output_idxs).values.astype('float32')
                for nt in np.arange(step, self.nt + step, step)
            ]
        elif self.y_roll is not None:
            # Targets from the rolling-mean outputs.
            y = self.y_rolled.isel(
                time=idxs + nt,
            ).values.astype('float32')
        elif self.tfr_out:
            # All lead steps from min_nt to nt for a single sample, with a leading batch axis.
            assert self.batch_size == 1, 'bs must be one'
            time_slice = slice(idxs[0]+self.min_nt, idxs[0]+self.nt+1)
            y = self.data.isel(time=time_slice, level=self.output_idxs).values.astype('float32')[None]
        elif self.predict_difference:
            # Target is the change between the lead time and the initial time.
            y = (
                self.data.isel(time=idxs + nt, level=self.output_idxs).values -
                self.data.isel(time=idxs, level=self.output_idxs).values
            ).astype('float32')
        else:
            y = self.data.isel(time=idxs + nt, level=self.output_idxs).values.astype('float32')

        if self.is_categorical:
            # Bin the flattened targets, restore the shape, then one-hot encode.
            y_shape = y.shape
            y = pd.cut(y.reshape(-1), self.bins, labels=False).reshape(y_shape)
            y = tf.keras.utils.to_categorical(y, num_classes=self.num_bins)

        if self.cont_time:
            # Append the lead-time field as an extra input channel.
            X = np.concatenate([X, ftime[..., None]], -1).astype('float32')
        return X, y


    def _decode(self, example_proto):
        """Parse one serialized example into ``(X, y)`` for the configured lead time.

        With ``cont_time`` a lead-time channel (hours / 100) is appended to
        ``X`` and the matching entry of ``y`` is returned; otherwise the entry
        at index ``nt - 1`` is returned.
        """
        parsed = _parse(example_proto)
        X = tf.io.parse_tensor(parsed['X'], np.float32)
        y = tf.io.parse_tensor(parsed['y'], np.float32)
        if self.tfr_out_idxs is not None:
            y = tf.gather(y, self.tfr_out_idxs, axis=-1)

        if not self.cont_time:
            y_idx = self.nt-1
            return X, y[y_idx]

        if self.fixed_time:
            y_idx = self.nt-1
        else:
            # Random lead index in [min_nt - 1, nt).
            y_idx = tf.random.uniform((), self.min_nt-1, self.nt, dtype=tf.int32)
        lead_hours = (y_idx+1) * self.dt
        scaled_lead = lead_hours / 100
        grid_ones = np.ones((len(self.data.lat), len(self.data.lon), 1))
        ftime = scaled_lead * grid_ones
        X = tf.concat([X, tf.cast(ftime, tf.float32)], -1)
        return X, y[y_idx]

    def _setup_tfrecord_ds(self):
        """Build the batched TFRecord dataset and its numpy iterator.

        ``self.tfrecord_files`` is either a list/tuple of file names or a glob
        pattern whose matches are used in sorted order. Sets
        ``self.tfr_dataset`` and ``self.tfr_dataset_np``.
        """
        # Local import: glob is not among the imports visible in this file.
        from glob import glob

        # Find all files to be used
        if isinstance(self.tfrecord_files, (list, tuple)):
            tfr_fns = list(self.tfrecord_files)
        else:
            tfr_fns = sorted(glob(self.tfrecord_files))

        dataset = tf.data.TFRecordDataset(
            tfr_fns, num_parallel_reads=self.tfr_num_parallel_calls
        ).map(self._decode)

        if self.shuffle:
            dataset = dataset.shuffle(
                buffer_size=self.tfr_buffer_size, reshuffle_each_iteration=True
            )

        self.tfr_dataset = dataset.batch(self.batch_size)
        # NOTE: self.tfr_repeat is currently not applied; the repeat() call is disabled.
        # if self.tfr_repeat:
        #     self.tfr_dataset = self.tfr_dataset.repeat()
        if self.tfr_prefetch is not None:
            self.tfr_dataset = self.tfr_dataset.prefetch(self.tfr_prefetch)
        self.tfr_dataset_np = self.tfr_dataset.as_numpy_iterator()


    def _get_tfrecord_item(self, i):
        """Return the next ``(X, y)`` batch from the TFRecord iterator; ``i`` is unused."""
        features, targets = next(self.tfr_dataset_np)
        return features, targets

    def to_tfr(self, savedir, steps_per_file=250):
        """Write all samples to numbered TFRecord files in ``savedir``.

        Args:
            savedir: Directory for the files ``000.tfrecord``, ``001.tfrecord``, ...
            steps_per_file: Number of samples written to each file.

        Raises:
            AssertionError: If ``batch_size`` is not 1.
        """
        assert self.batch_size == 1, 'bs must be one'
        writer = None
        try:
            for i, (X, y) in tqdm(enumerate(self)):
                if i % steps_per_file == 0:
                    # Close the previous file before starting the next one. The old
                    # check `i + 1 % steps_per_file == 0` never fired (precedence),
                    # so earlier writers were left open.
                    if writer is not None:
                        writer.close()
                    c = int(i // steps_per_file)
                    fn = f'{savedir}/{str(c).zfill(3)}.tfrecord'
                    print('Writing to file:', fn)
                    writer = tf.io.TFRecordWriter(fn)
                serialized_example = serialize_example(X[0], y[0])  # Remove batch dimension
                writer.write(serialized_example)
        finally:
            # Also covers an empty generator (no writer) and errors mid-way.
            if writer is not None:
                writer.close()

In [5]:
# Root directory of the 5.625 degree WeatherBench files.
datadir = '/data/stephan/WeatherBench/5.625deg/'

# Variables to load: the five pressure-level variables share one level list
# (a fresh list per entry), followed by the single-level ones and the constants.
var_dict = {
    **{
        long_name: (short_name, [50, 250, 500, 600, 700, 850, 925])
        for long_name, short_name in [
            ('geopotential', 'z'),
            ('temperature', 't'),
            ('u_component_of_wind', 'u'),
            ('v_component_of_wind', 'v'),
            ('specific_humidity', 'q'),
        ]
    },
    'toa_incident_solar_radiation': ('tisr', None),
    '2m_temperature': ('t2m', None),
    '6hr_precipitation': ('tp', None),
    'constants': ['lsm', 'orography', 'lat2d'],
}

# Generator settings passed to DataGenerator further down.
output_vars = ['z_500', 't_850', 't2m']
lead_time, data_subsample, norm_subsample = 72, 2, 30000
nt, dt = 3, 6
discard_first = 24

In [6]:
# Open one multi-file dataset per var_dict key (its directory name) and merge them.
ds = xr.merge(
    [
        xr.open_mfdataset(f'{datadir}/{var}/*.nc', combine='by_coords')
        for var in var_dict
    ],
    fill_value=0,  # For the 'tisr' NaNs
)

In [7]:
# Inclusive [first, last] year per split.
# For full training data, use ['1979', '2015']. Will use 200GB of RAM.
train_years = ['2015', '2015']
valid_years = ['2016', '2016']
test_years = ['2017', '2018']

# Cut the merged dataset into the three splits along time.
ds_train, ds_valid, ds_test = (
    ds.sel(time=slice(*years))
    for years in (train_years, valid_years, test_years)
)

In [8]:
# Training generator. mean/std are not passed, so DataGenerator computes the
# normalization statistics from this data.
dg_train = DataGenerator(
    ds_train,
    var_dict,
    lead_time,
    output_vars=output_vars,
    data_subsample=data_subsample,
    norm_subsample=norm_subsample,
    nt_in=nt,
    dt_in=dt,
    discard_first=discard_first
)
# dg_valid = DataGenerator(
#     ds_valid,
#     var_dict,
#     lead_time,
#     output_vars=output_vars,
#     data_subsample=data_subsample,
#     norm_subsample=norm_subsample,
#     nt_in=nt,
#     dt_in=dt,
#     discard_first=discard_first,
#     mean=dg_train.mean,
#     std=dg_train.std
# )

In [9]:
# Number of batches per epoch (DataGenerator.__len__).
len(dg_train)

136

In [10]:
# Fetch the first batch and print the shapes of its inputs and targets.
X, y = dg_train[0]
print(X.shape, y.shape)

(32, 32, 64, 117) (32, 32, 64, 3)


In [11]:
def create_lat_mse(lat):
    """Build a latitude-weighted squared-error function.

    Args:
        lat: Latitudes in degrees; an array-like whose cosine exposes ``.values``.

    Returns:
        Function ``lat_mse(y_true, y_pred)`` returning the element-wise squared
        error scaled by cos(lat) weights (normalised to mean 1) along axis 1.
        No reduction is applied.
    """
    cos_lat = np.cos(np.deg2rad(lat)).values
    normalised = cos_lat / cos_lat.mean()
    # Broadcastable against 4-D inputs with latitude on axis 1.
    lat_weights = normalised[None, :, None, None]

    def lat_mse(y_true, y_pred):
        diff = y_true - y_pred
        return diff**2 * lat_weights

    return lat_mse

In [12]:
# Latitude-weighted squared-error function built from the training data's latitudes.
lat_mse = create_lat_mse(dg_train.data.lat)