In [1]:
# Re-import changed modules (e.g. the star-imported src.* helpers) before each cell
# runs, so edits to the project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from src.score import *
from src.data_generator import *
from src.networks import *
from src.utils import *
import os
import ast, re
import numpy as np
import xarray as xr
import tensorflow as tf
import tensorflow.keras as keras
from configargparse import ArgParser
import pickle
import pdb

In [3]:
# Experiment configuration: every tunable knob used by this notebook.
exp_id = '31-resnet_multi'

# Data / output locations. The defaults are the original machine-specific paths;
# set the environment variables to run elsewhere without editing the notebook.
datadir = os.environ.get('WEATHERBENCH_DATADIR', '/data/WeatherBench/5.625deg/')
model_save_dir = os.environ.get(
    'WEATHERBENCH_MODEL_SAVE_DIR',
    '/home/stephan/data/myWeatherBench/predictions/saved_models/',
)
pred_save_dir = os.environ.get(
    'WEATHERBENCH_PRED_SAVE_DIR',
    '/home/stephan/data/myWeatherBench/predictions/',
)

# Inputs: {directory under datadir: (short name, levels)}; 'constants' lists the
# static fields. NOTE: a later cell replaces this with a reduced z500/t850 set.
var_dict = {'geopotential': ('z', [200, 500, 850]), 'temperature': ('t', [200, 500, 850]), 'u_component_of_wind': ('u', [200, 500, 850]), 'v_component_of_wind': ('v', [200, 500, 850]), 'constants': ['lsm','orography','lat2d']}
output_vars = ['z_500', 't_850']

# Network: one (filters, kernel size) pair per conv layer. The last layer's filter
# count is the number of output channels, so it must equal len(output_vars).
filters = [128, 128, 128, 128, 128, 128, 128, 128, 2]
kernels = [7, 3, 3, 3, 3, 3, 3, 3, 3]
assert len(filters) == len(kernels), 'filters and kernels need one entry per layer'
assert filters[-1] == len(output_vars), 'last layer must emit one channel per output var'

lead_time = 72                 # passed to DataGenerator as the forecast lead time
lr = 0.5e-4                    # Adam learning rate
early_stopping_patience = 5    # epochs without val_loss improvement before stopping
data_subsample = 2             # passed to the training DataGenerator only
norm_subsample = 30000         # passed to the training DataGenerator only
epochs = 150
network_type = 'resnet'
activation = 'relu'
bn_position = 'post'           # the model summary shows BatchNorm after the ReLU
batch_size = 4

In [4]:
# Project helper brought in by one of the src.* star imports; presumably limits
# TensorFlow GPU memory allocation — TODO confirm in src/utils.
limit_mem()

In [5]:
# Override the config cell's var_dict with a reduced input set: z at 500, t at 850,
# plus three constant fields -> 5 input channels (see X.shape below).
var_dict = dict(
    geopotential=('z', [500]),
    temperature=('t', [850]),
    constants=['lat2d', 'orography', 'lsm'],
)

In [6]:
# Open every NetCDF file in each variable's directory and merge them into one Dataset.
ds = xr.merge(
    [
        xr.open_mfdataset(f'{datadir}/{var_name}/*.nc', combine='by_coords')
        for var_name in var_dict
    ]
)

In [7]:
# Chronological split: train on 2000-2015, validate on 2016, test on 2017-2018.
ds_train, ds_valid, ds_test = (
    ds.sel(time=slice(start, stop))
    for start, stop in [('2000', '2015'), ('2016', '2016'), ('2017', '2018')]
)

In [8]:
# Training generator: also the source of the mean/std used to normalize every split.
dg_train = DataGenerator(
    ds_train, var_dict, lead_time, batch_size=batch_size, output_vars=output_vars,
    data_subsample=data_subsample, norm_subsample=norm_subsample
)


def make_eval_generator(ds_split):
    """Build an unshuffled DataGenerator for an evaluation split.

    Evaluation splits must reuse the training generator's mean/std so their inputs
    are scaled exactly like the training data; sharing one constructor keeps the
    validation and test generators from drifting apart.
    """
    return DataGenerator(
        ds_split, var_dict, lead_time, batch_size=batch_size,
        mean=dg_train.mean, std=dg_train.std,
        shuffle=False, output_vars=output_vars,
    )


dg_valid = make_eval_generator(ds_valid)
dg_test = make_eval_generator(ds_test)

DG start 17:13:14.877780
DG normalize 17:13:14.892840
DG load 17:13:29.254796
Loading data into RAM
DG done 17:14:47.016230
DG start 17:14:47.020809
DG normalize 17:14:47.080753
DG load 17:14:47.099299
Loading data into RAM
DG done 17:14:51.985847
DG start 17:14:51.986091
DG normalize 17:14:51.999516
DG load 17:14:52.005391
Loading data into RAM
DG done 17:15:01.576044


In [9]:
# Pull the first training batch to inspect the input layout; the trailing dims of
# X.shape must match the input_shape given to build_resnet.
X, y = dg_train[0]
X.shape

(4, 32, 64, 5)

In [30]:
# Build the ResNet. input_shape comes from the sample batch (batch axis dropped), so
# it follows var_dict automatically instead of being hard-coded as (32, 64, 5).
# NOTE(review): this cell's saved output is a TypeError about a `run_eagerly` keyword
# reaching the Functional model constructor — presumably passed inside build_resnet;
# TODO fix there. The config's `activation` is not passed here either; confirm that
# build_resnet's default matches it.
model = build_resnet(
    filters, kernels, input_shape=X.shape[1:],
    bn_position=bn_position, use_bias=True, l2=1e-4, skip=True,
    dropout=0.1,
)

TypeError: ('Functional models may only specify `name` and `trainable` keyword arguments during initialization. Got an unexpected argument:', 'run_eagerly')

In [12]:
# Adam at the configured learning rate, minimizing mean-squared error.
model.compile(optimizer=keras.optimizers.Adam(lr), loss='mse')

In [13]:
# Layer-by-layer overview: output shapes and parameter counts.
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 32, 64, 5)]  0                                            
__________________________________________________________________________________________________
periodic_conv2d (PeriodicConv2D (None, 32, 64, 128)  31488       input_1[0][0]                    
__________________________________________________________________________________________________
re_lu (ReLU)                    (None, 32, 64, 128)  0           periodic_conv2d[0][0]            
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 32, 64, 128)  512         re_lu[0][0]                      
______________________________________________________________________________________________

In [14]:
# Confirm the generator yields float32 inputs.
X.dtype

dtype('float32')

In [15]:
# Reference prediction before saving. training=False pins inference mode so the
# dropout and batch-norm layers behave deterministically and the save/load
# comparison is meaningful.
p1 = model(X, training=False)

In [16]:
# First sample, first spatial cell: one value per output channel
# (presumably ordered like output_vars, i.e. z_500, t_850 — confirm).
p1[0,0,0]

<tf.Tensor: id=3491, shape=(2,), dtype=float32, numpy=array([1.003304  , 0.32654575], dtype=float32)>

In [17]:
# Persist the untrained model; the .h5 extension selects the HDF5 format.
keras.models.save_model(model, './test.h5')

In [18]:
# Reload from disk; the project's custom PeriodicConv2D layer must be registered by name.
a = keras.models.load_model(
    './test.h5',
    custom_objects=dict(PeriodicConv2D=PeriodicConv2D),
)

In [19]:
# Save/load round-trip check: the reloaded model must reproduce the pre-save
# prediction over the whole batch, not just one eyeballed grid cell.
p1_reloaded = a(X, training=False)
np.testing.assert_allclose(np.asarray(p1_reloaded), np.asarray(p1), rtol=1e-5, atol=1e-6)
p1_reloaded[0, 0, 0]

<tf.Tensor: id=7135, shape=(2,), dtype=float32, numpy=array([1.003304  , 0.32654575], dtype=float32)>

In [20]:
# Stop once val_loss has not improved for `early_stopping_patience` epochs, then
# roll the weights back to the best epoch seen.
callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        min_delta=0,
        patience=early_stopping_patience,
        verbose=1,
        mode='auto',
        restore_best_weights=True,
    ),
]

In [21]:
# Short smoke-test run: epochs=1 (not the configured `epochs`), so EarlyStopping
# cannot trigger here. Model.fit accepts the data generators directly (presumably
# keras.utils.Sequence subclasses); fit_generator is deprecated since TF 2.1.
history = model.fit(
    dg_train,
    epochs=1,
    validation_data=dg_valid,
    callbacks=callbacks,
)

  798/17523 [>.............................] - ETA: 33:26 - loss: 2.3522

KeyboardInterrupt: 

In [23]:
# Prediction after the (interrupted) training run. training=False pins inference mode
# so dropout/batch-norm are deterministic for the save/load comparison.
p1 = model(X, training=False)

In [24]:
# Same grid cell as before training; the values should have moved away from the
# untrained prediction.
p1[0,0,0]

<tf.Tensor: id=1367500, shape=(2,), dtype=float32, numpy=array([-2.4111865, -2.8216019], dtype=float32)>

In [25]:
# Overwrite the checkpoint with the partially trained weights (HDF5 via the .h5 extension).
keras.models.save_model(model, './test.h5')

In [26]:
# Reload the trained checkpoint, registering the custom PeriodicConv2D layer by name.
a = keras.models.load_model(
    './test.h5',
    custom_objects=dict(PeriodicConv2D=PeriodicConv2D),
)

In [27]:
# Round-trip check for the trained model: the reloaded copy must match the
# post-training prediction over the whole batch.
p1_reloaded = a(X, training=False)
np.testing.assert_allclose(np.asarray(p1_reloaded), np.asarray(p1), rtol=1e-5, atol=1e-6)
p1_reloaded[0, 0, 0]

<tf.Tensor: id=1379891, shape=(2,), dtype=float32, numpy=array([-2.4111865, -2.8216019], dtype=float32)>