In [91]:
import tensorflow as tf
import numpy as np
from keras.layers import Input, Conv2D, Convolution1D, MaxPooling2D, Dense, Dropout, \
                          Flatten, concatenate, Activation, Reshape, \
                          UpSampling2D,ZeroPadding2D
import keras
from keras import layers
from keras.regularizers import L1L2
from keras.models import Model
import keras.backend as K
import xarray as xr
import datetime

from bilinear_interp import BilinearInterpolation
from model_tester import *

def size(model):
    """Return the number of trainable parameters (individual floats) in `model`.

    Args:
        model: Object exposing a `trainable_weights` iterable whose items each
            have a `shape` attribute (e.g. a Keras model).

    Returns:
        int: Sum, over all trainable weights, of the product of each weight's
        dimensions. 0 when there are no trainable weights.
    """
    # Only the dimensions are needed, so read the static shape instead of
    # K.get_value(w).shape, which fetches the full tensor value first.
    return sum(int(np.prod(tuple(w.shape))) for w in model.trainable_weights)

def get_initial_weights(output_size):
    """Build initial `[kernel, bias]` weights for a 6-unit affine-parameter layer.

    The kernel is all zeros and the bias is the flattened 2x3 matrix
    [[1, 0, 0], [0, 1, 0]] (the identity affine transform), so a linear layer
    loaded with these weights initially outputs [1, 0, 0, 0, 1, 0] for any input.

    Args:
        output_size: Number of rows of the kernel, i.e. the number of features
            feeding the 6-unit layer.

    Returns:
        list: `[W, b]` where `W` is a float32 zero array of shape
        `(output_size, 6)` and `b` is a float32 array of shape `(6,)`.
    """
    b = np.zeros((2, 3), dtype='float32')
    b[0, 0] = 1
    b[1, 1] = 1
    W = np.zeros((output_size, 6), dtype='float32')
    # Previously this function ended in a bare `return`, so callers received
    # None and the identity initialisation was never applied.
    return [W, b.flatten()]

In [96]:
# Open the raw store once under its own name so re-running this cell always
# starts from the same input instead of from an already-filtered `ds`.
ds_full = xr.open_zarr('../data/data_full.zarr')

# Keep everything up to 2011-09-30 and everything from 2012-10-01 onwards,
# i.e. cut out the period between those two dates.
ds = xr.concat(
    [
        ds_full.sel(time=slice(None, datetime.datetime(2011, 9, 30))),
        ds_full.sel(time=slice(datetime.datetime(2012, 10, 1), None)),
    ],
    dim='time',
)

# `Dataset.drop` is deprecated for removing variables; `drop_vars` is the
# explicit replacement and takes the same list of names.
ds = ds.drop_vars([
    'Lambert_Azimuthal_Grid',
    'status_flag',
    'ceda_sic_bin',
    'era5_sic',
    'era5_sic_bin',
    'total_standard_error',
])

In [80]:
def stn(input_shape=(32, 64, 1), sampling_size=(8, 16), num_classes=10):
    """Build a convolutional encoder/decoder model with a spatial-transformer step.

    An encoder of stacked 5x5 convolutions (with two 2x2 max-pools) feeds a
    dense head that predicts 6 values. Those values, together with the raw
    input, go through `BilinearInterpolation(sampling_size)`; the result is
    then upsampled twice and merged with the encoder's skip connections.

    Args:
        input_shape: Shape passed to the Keras `Input` layer.
        sampling_size: Size argument handed to `BilinearInterpolation`.
            NOTE(review): the 2x upsampled result is concatenated with the
            once-pooled encoder features, so the two sizes must line up —
            confirm when changing the defaults.
        num_classes: Not used anywhere in this function.

    Returns:
        keras.models.Model mapping the input to a single-channel, linearly
        activated 5x5 convolution output.
    """

    def conv_pair(tensor):
        # Two consecutive 32-filter 5x5 ReLU convolutions with 'same' padding.
        for _ in range(2):
            tensor = Conv2D(32, (5, 5), activation='relu', padding='same')(tensor)
        return tensor

    def upsample_and_merge(tensor, skip):
        # 2x upsample, 2x2 convolution, then channel-wise concat with `skip`.
        merge = keras.layers.Concatenate(axis=-1)
        shrink = Conv2D(32, (2, 2), activation='relu', padding='same')
        enlarged = UpSampling2D(size=(2, 2))(tensor)
        return merge([shrink(enlarged), skip])

    inputs = Input(shape=input_shape)

    # Encoder; its deepest features also drive the localization head.
    enc_full = conv_pair(inputs)
    enc_half = conv_pair(MaxPooling2D(pool_size=(2, 2))(enc_full))
    enc_quarter = conv_pair(MaxPooling2D(pool_size=(2, 2))(enc_half))
    deepest = conv_pair(enc_quarter)

    # Dense head that regresses the 6 transform parameters.
    theta = Flatten()(deepest)
    for units in (500, 200, 100, 50):
        theta = Dense(units)(theta)
        theta = Activation('relu')(theta)
    theta = Dense(6, weights=get_initial_weights(50))(theta)

    # Resample the raw input using the predicted parameters.
    sampled = BilinearInterpolation(sampling_size)([inputs, theta])

    # Decoder with skip connections back to the encoder.
    dec_half = conv_pair(upsample_and_merge(sampled, enc_half))
    dec_full = conv_pair(upsample_and_merge(dec_half, enc_full))

    prediction = Conv2D(1, (5, 5), activation='linear', padding='same')(dec_full)

    return Model(inputs=inputs, outputs=prediction)
# NOTE(review): `model` is not assigned in this cell or in any cell above it;
# in the visible notebook it is first assigned in a later cell, so this line
# appears to rely on leftover kernel state and would likely raise NameError
# on a fresh top-to-bottom run — confirm and build the model here if intended.
size(model)

1516541

In [111]:
def convlstm_stn(num_timesteps, input_shape, with_stn=True):
    """Build a two-layer ConvLSTM model with an optional spatial-transformer branch.

    The input sequence passes through two ConvLSTM2D layers (the second one
    returns only its final state). That feature map gets a length-1 axis
    inserted at position 1 and is concatenated channel-wise with the first
    channel of the last input timestep and, when `with_stn` is True, with the
    output of a `BilinearInterpolation` layer driven by a small localization
    network. A bias-free 1x1 convolution then reduces the result to 1 channel.

    Args:
        num_timesteps: Length of the time axis of the input.
        input_shape: Per-timestep shape; everything except its last entry is
            passed to `BilinearInterpolation` as the output size
            (presumably (height, width, channels) — TODO confirm).
        with_stn: Whether to add the localization / interpolation branch.

    Returns:
        keras.models.Model taking an input of shape
        `(num_timesteps, *input_shape)`.
    """
    inp = layers.Input(shape=(num_timesteps, *input_shape))

    # First channel of the final timestep, keeping a length-1 time axis.
    last_day = inp[:, -1:, ..., :1]

    # Settings common to both recurrent layers; only `return_sequences` and
    # the regularizer strengths differ between them.
    lstm_common = dict(
        filters=64,
        kernel_size=(5, 5),
        padding="same",
        activation="relu",
        dropout=0.1,
        recurrent_dropout=0.1,
    )
    x = layers.ConvLSTM2D(
        return_sequences=True,
        kernel_regularizer=L1L2(0.001, 0.01),
        **lstm_common,
    )(inp)
    x = layers.ConvLSTM2D(
        return_sequences=False,
        kernel_regularizer=L1L2(0.001, 0.001),
        **lstm_common,
    )(x)

    if with_stn:
        # Localization network: two conv + pool stages, then a dense stack
        # that ends in the 6 transform parameters.
        theta = x
        for n_filters in (16, 1):
            theta = layers.Conv2D(
                filters=n_filters,
                kernel_size=(5, 5),
                activation='relu',
                padding='same',
            )(theta)
            theta = MaxPooling2D(pool_size=(2, 2))(theta)
        theta = Flatten()(theta)
        for units in (500, 200, 100, 50):
            theta = Dense(units)(theta)
            theta = Activation('relu')(theta)
        theta = Dense(6, weights=get_initial_weights(50))(theta)

        # NOTE(review): the whole multi-timestep `inp` is handed to the
        # interpolation layer — confirm it is meant to receive the full
        # sequence rather than a single frame.
        warped = BilinearInterpolation(input_shape[:-1])([inp, theta])

        # Add a length-1 axis so it can be concatenated with `last_day`.
        extra_branches = [tf.expand_dims(warped, 1)]
    else:
        extra_branches = []

    x = tf.expand_dims(x, 1)
    x = layers.Concatenate(axis=-1)([x, last_day] + extra_branches)

    # 1x1 convolution (no bias, no activation) down to a single channel.
    x = layers.Conv2D(
        filters=1,
        kernel_size=(1, 1),
        padding="same",
        use_bias=False,
    )(x)

    return keras.models.Model(inp, x)

# Build a small instance of the network and display its trainable-parameter count.
model = convlstm_stn(
    num_timesteps=3,
    input_shape=(51, 63, 7),
    with_stn=True,
)
size(model)

1516413

In [108]:
# Display the shape of the training inputs.
# NOTE(review): in the visible notebook `model_tester` is only created in the
# cell below this one, so this cell appears to depend on out-of-order
# execution — confirm and move it after the training setup if so.
model_tester.train_X.shape

(150, 4, 32, 32, 6)

In [112]:
# ---- Run configuration ----
save_dir = '/content/drive/MyDrive/syde770/save_dir_stn/'
model_name = 'test_stn'
filename = f'{save_dir}{model_name}'
i = 0  # not read anywhere in this cell
num_timesteps = 4

# Modifiable training hyperparameters.
epochs = 500
batch_size = 10

# ---- Data ----
model_tester = ModelTester()

# True wherever the first timestep of `ceda_sic` is NaN.
nan_mask = np.isnan(np.array(ds.ceda_sic.isel(time=0)))

input_shape = (len(ds.y), len(ds.x), len(ds.data_vars))
model_tester.preprocess_data(
    ds,
    weekly=True,
    num_timesteps=num_timesteps,
    gap=0,
    only_polynya=False,
)

# ---- Model ----
model = convlstm_stn(
    num_timesteps=num_timesteps,
    input_shape=input_shape,
    with_stn=True,
)
# The loss mask is the inverted NaN mask with a leading and a trailing axis added.
loss = masked_MSE(mask=np.expand_dims(~nan_mask, [0, -1]))
model.compile(
    loss=loss,
    optimizer=keras.optimizers.Adam(learning_rate=0.01),
)
model_tester.model = model

# ---- Callbacks (both watch the validation loss) ----
early_stopping = keras.callbacks.EarlyStopping(monitor="val_loss", patience=10)
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=5)

# ---- Training ----
# Validation uses the tester's test_X / test_Y split.
history = model_tester.model.fit(
    model_tester.train_X,
    model_tester.train_Y,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(model_tester.test_X, model_tester.test_Y),
    callbacks=[early_stopping, reduce_lr],
    verbose=1,
)

# model_tester.save(model_name, save_dir)

  x = np.divide(x1, x2, out)
  x = np.divide(x1, x2, out)
  x = np.divide(x1, x2, out)
  x = np.divide(x1, x2, out)


Epoch 1/500


2022-04-28 17:40:15.018312: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.




2022-04-28 17:43:32.211649: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


Epoch 2/500

KeyboardInterrupt: 

array([[ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       ...,
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True]])