In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
import datetime
import os
import time
import glob
import pandas as pd
import xarray as xr
from dateutil.relativedelta import relativedelta

import argparse

from utils import normalize

from mpnnlstm import NextFramePredictorS2S
from seq2seq import Seq2Seq

from torch.utils.data import Dataset, DataLoader

from ice_test import IceDataset

from graph_functions import create_static_heterogeneous_graph, create_static_homogeneous_graph

exp = 9          # selects the graph builder below: 9 -> heterogeneous, 10 -> homogeneous
device = 'cpu'
month = 4

# Seed every RNG in play (stdlib, NumPy, torch) so that model initialisation
# is reproducible under Restart Kernel -> Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Defaults
convolution_type = 'TransformerConv'
lr = 0.001
multires_training = False
truncated_backprop = 0

training_years = range(2013, 2014)
x_vars = ['siconc', 't2m', 'v10', 'u10', 'sshf']
y_vars = ['siconc']  # ['siconc', 't2m']
input_features = len(x_vars)
input_timesteps = 10
output_timesteps = 90

binary = False

# sorted() makes the input file order deterministic across filesystems.
ds = xr.open_mfdataset(sorted(glob.glob('data/hb_era5_glorys_nc/*.nc')))
# True wherever siconc is NaN at the first time step. Passed to the graph
# builders below (presumably to drop those cells — confirm in graph_functions).
mask = np.isnan(ds.siconc.isel(time=0)).values

image_shape = mask.shape
graph_structure = None  # stays None for any exp other than 9 or 10

if exp == 9:
    graph_structure = create_static_heterogeneous_graph(image_shape, 4, mask, use_edge_attrs=True, resolution=1/12, device=device)
elif exp == 10:
    graph_structure = create_static_homogeneous_graph(image_shape, 4, mask, use_edge_attrs=True, resolution=1/12, device=device)

# Full resolution datasets.
# Validation covers the single year two years after the last training year
# (the year in between is presumably reserved for testing — TODO confirm).
val_year = training_years[-1] + 2
data_val = IceDataset(ds, range(val_year, val_year + 1), month, input_timesteps, output_timesteps, x_vars, y_vars)
loader_val = DataLoader(data_val, batch_size=1, shuffle=False)

# Day-of-year mean of the target variable(s) over every time step in `ds`,
# with NaNs replaced by 0.
# NOTE(review): `ds` holds every loaded year, including the validation year,
# so if this climatology serves as a validation baseline it leaks validation
# data. Consider restricting it to `training_years` — confirm downstream use.
climatology = ds[y_vars].groupby('time.dayofyear').mean('time', skipna=True).to_array().values
climatology = torch.tensor(np.nan_to_num(climatology)).to(device)

# Set threshold
thresh = -np.inf
print(f'Threshold is {thresh}')

# Note: irrelevant if thresh = -np.inf
def dist_from_05(arr):
    """Distance from ``arr`` to the nearest of 0 and 1 (element-wise).

    Despite its name, this is *not* ``|arr - 0.5|``. For values in [0, 1] it
    equals ``0.5 - |arr - 0.5|`` (i.e. ``min(arr, 1 - arr)``). It is therefore
    0 at 0 and at 1, and reaches its maximum of 0.5 at ``arr == 0.5``.

    Works on anything that supports ``-`` and the built-in ``abs()``: Python
    floats, NumPy arrays and torch tensors.
    """
    offset_from_half = abs(arr - 0.5)
    return abs(offset_from_half - 0.5)

# Keyword arguments forwarded to the Seq2Seq constructor.
model_kwargs = {
    'hidden_size': 32,
    'dropout': 0.1,
    'n_layers': 1,
    'transform_func': dist_from_05,
    'dummy': False,
    'n_conv_layers': 3,
    'rnn_type': 'LSTM',
    'convolution_type': convolution_type,
}

# Run identifier: month, first/last training year, input/output horizon lengths.
experiment_name = (
    f'M{month}'
    f'_Y{training_years[0]}_Y{training_years[-1]}'
    f'_I{input_timesteps}O{output_timesteps}'
)

model = NextFramePredictorS2S(
    experiment_name=experiment_name,
    input_features=input_features,
    input_timesteps=input_timesteps,
    output_timesteps=output_timesteps,
    thresh=thresh,
    transform_func=dist_from_05,
    binary=binary,
    device=device,
    debug=True,
    model_kwargs=model_kwargs,
)

Threshold is -inf


In [18]:
from torch_geometric.nn import TransformerConv

# Reference parameter names/shapes of a freshly initialised TransformerConv
# (2 input channels -> 3 output channels, beta=False, no edge features).
dict(TransformerConv(2, 3, beta=False, edge_dim=None).named_parameters())

{'lin_key.weight': Parameter containing:
 tensor([[-0.2350,  0.5029],
         [-0.3364, -0.5943],
         [-0.3066, -0.3863]], requires_grad=True),
 'lin_key.bias': Parameter containing:
 tensor([ 0.0195, -0.1665, -0.5804], requires_grad=True),
 'lin_query.weight': Parameter containing:
 tensor([[ 0.0702,  0.1720],
         [ 0.3632, -0.2066],
         [ 0.6133,  0.6374]], requires_grad=True),
 'lin_query.bias': Parameter containing:
 tensor([-0.5611, -0.1725, -0.0108], requires_grad=True),
 'lin_value.weight': Parameter containing:
 tensor([[0.5547, 0.6581],
         [0.5404, 0.5012],
         [0.6790, 0.2506]], requires_grad=True),
 'lin_value.bias': Parameter containing:
 tensor([ 0.3010,  0.3542, -0.6998], requires_grad=True),
 'lin_skip.weight': Parameter containing:
 tensor([[ 0.0717, -0.1549],
         [ 0.1779, -0.3913],
         [-0.6366, -0.2701]], requires_grad=True),
 'lin_skip.bias': Parameter containing:
 tensor([ 0.0180, -0.1510, -0.1079], requires_grad=True)}

In [21]:
# Parameters of the first graph convolution in the encoder's first RNN cell,
# for comparison with the fresh TransformerConv above.
first_conv = model.model.encoder.rnns[0].conv_x_i.convolutions[0]
[param for _name, param in first_conv.named_parameters()]

[Parameter containing:
 tensor([[-2.4386e-01, -4.8817e-02,  3.8739e-02,  9.4434e-02, -1.1092e-01,
          -5.3421e-02,  6.8142e-02, -3.4405e-01],
         [-2.3320e-01, -3.2890e-01,  2.1596e-01,  3.3421e-01, -1.4628e-01,
          -2.5903e-01, -8.4080e-02, -3.1113e-01],
         [ 1.2065e-01,  1.0528e-01,  3.2905e-01, -2.1720e-02,  2.0183e-01,
          -2.5872e-01,  1.0905e-01,  3.4713e-01],
         [ 5.2814e-02,  1.9680e-02, -2.5725e-01, -3.1567e-01, -2.8564e-02,
          -2.9338e-01, -2.4633e-01, -1.2180e-01],
         [ 7.0629e-02, -1.4158e-02,  1.2214e-02,  2.9991e-01,  1.2829e-01,
           8.6204e-02,  1.6069e-01, -1.7049e-01],
         [-2.5080e-01, -1.5417e-01, -3.2472e-01,  1.6054e-02, -1.5643e-01,
          -1.0002e-01, -1.7180e-01,  1.1509e-01],
         [ 1.2246e-01,  3.3694e-02,  1.6270e-02,  1.7405e-01,  1.2550e-01,
           1.7992e-01,  2.4506e-01, -1.5682e-02],
         [-2.1549e-01,  1.5673e-01, -2.5839e-01,  3.3869e-01,  3.1862e-01,
          -1.7906e-01,  1.2