In [None]:
import os
import logging
import shutil
import numpy as np
import scipy.linalg as linalg
import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Colab-only: mount Google Drive at /content/drive so the project files
# referenced by the following cells are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Copy the project's helper modules from Drive into the current working
# directory so that the imports in the next cell can resolve them.
# NOTE(review): the source paths are relative, so this presumably relies on
# the working directory being /content (where Drive is mounted) — confirm.
!cp "drive/MyDrive/Work/RSTA/final_repo/utils.py" "./"
!cp "drive/MyDrive/Work/RSTA/final_repo/data_utils.py" "./"
!cp "drive/MyDrive/Work/RSTA/final_repo/model_utils.py" "./"
!cp "drive/MyDrive/Work/RSTA/final_repo/models.py" "./"

In [None]:
from utils import Config, torch2numpy, numpy2torch, get_MAEs
from data_utils import load_dataset, pad, unpad, augment_rbf_coefs
from models import RBF, SpatialTransform, NAB, RFN
from model_utils import sample_independent, sample_dependent

In [None]:
# Project root on the mounted Drive; the config file, the data directory and
# the saved-model directory are all resolved relative to it.
work_dir = 'drive/MyDrive/Work/RSTA/final_repo/'
# Dataset key; it selects which configs/config_<dataset>.txt file is loaded.
dataset = 'noaa_pt' # e.g., 'convdiff', 'noaa_ec', 'noaa_pt', 'sst'

In [None]:
# Load the run configuration from <work_dir>/configs/config_<dataset>.txt.
args = Config(os.path.join(work_dir, 'configs', 'config_{}.txt'.format(dataset)))

In [None]:
# Re-root the configured data directory under work_dir.
# NOTE: this overwrites args.data_dir, so re-running the cell without
# reloading `args` prepends work_dir a second time (unless the configured
# path is absolute, in which case os.path.join returns it unchanged).
args.data_dir = os.path.join(work_dir, args.data_dir)

In [None]:
# Directory that holds this dataset's checkpoints (the '*-best.pt' files
# loaded later in the notebook).
savepath = os.path.join(work_dir, 'saved_models/', args.dataset)
# exist_ok=True makes this one idempotent call: there is no window between
# an existence check and the creation, and an already-present directory is
# not an error.
os.makedirs(savepath, exist_ok=True)

In [None]:
# New tensors default to float32, created on the GPU when one is available
# and on the CPU otherwise.
# NOTE(review): torch.set_default_tensor_type is reported as deprecated in
# recent PyTorch releases — confirm against the installed version.
torch.set_default_tensor_type(
    'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'
)

In [None]:
# Seed every random number generator used here (Python, NumPy, PyTorch CPU
# and CUDA) with the configured seed, in the same order as before.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(args.seed)

# Trade cuDNN speed for reproducibility: deterministic mode on, benchmark
# mode off.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# Load the dataset: the training split, the test split and `p` (used below as
# the per-node coordinates). The second returned value is discarded here.
# NOTE(review): presumably the discarded value is a validation split —
# confirm against data_utils.load_dataset.
train_set, _, test_set, p = load_dataset(args.data_dir, args.dataset)
# `p` has one entry per node.
n_nodes = len(p)

In [None]:
# Flatten all training sequences into a single array of frames so that
# global statistics can be computed over the whole training split.
all_train = np.array([frame for sequence in train_set for frame in sequence])

In [None]:
# Normalisation statistics taken over every value in the training frames.
mu = all_train.mean()
sig = all_train.std()

In [None]:
# Pairwise node geometry for the RBF matrix, built with NumPy broadcasting
# instead of an n_nodes x n_nodes Python double loop.
coords = np.asarray(p)

# PHI[i, j] = Euclidean distance between node i and node j. Writing into a
# pre-allocated float64 array keeps the dtype the loop version produced.
PHI = np.zeros([n_nodes, n_nodes])
PHI[...] = linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)

# ARGS[i, j] = (p[i], p[j]): the coordinates of both nodes side by side.
ARGS = np.zeros([n_nodes, n_nodes, 4])
ARGS[:, :, :2] = coords[:, None, :]
ARGS[:, :, 2:] = coords[None, :, :]

In [None]:
# Hand both geometry arrays over to PyTorch via the project's converter.
# NOTE: this rebinds PHI and ARGS, so the cell expects NumPy inputs each
# time it runs.
PHI, ARGS = numpy2torch(PHI), numpy2torch(ARGS)

In [None]:
# Build the RBF module from the configured `eps` and kernel selector.
# NOTE(review): presumably `eps` is the shape parameter and `cls` names the
# kernel family — confirm in models.RBF.
rbf = RBF(eps=args.eps, cls=args.rbf)
# Move the module to the GPU when one is available.
if torch.cuda.is_available():
    rbf = rbf.cuda()

In [None]:
# Apply the RBF module to the distance matrix and detach the result from
# the autograd graph.
PHI = rbf(PHI).detach()
# Explicit inverse of the RBF matrix, computed on the NumPy/SciPy side.
invPHI = numpy2torch(linalg.inv(torch2numpy(PHI)))
# Regularised normal-equations matrix: PHI^T PHI + lstsq_reg * I.
lhs = PHI.t() @ PHI + args.lstsq_reg * torch.eye(PHI.shape[1])

In [None]:
# Instantiate the three networks whose trained weights are loaded further
# down; all layer sizes come from the config.
spatial = SpatialTransform(out_dim=args.n_spatial_fts)
nab = NAB(in_dim=args.n_spatial_fts)
# RFN input width: n_levels * n_spatial_fts, plus 1.
rfn = RFN(in_dim=args.n_levels*args.n_spatial_fts+1)

In [None]:
# Move all three networks to the GPU when one is present.
if torch.cuda.is_available():
    spatial, nab, rfn = spatial.cuda(), nab.cuda(), rfn.cuda()

In [None]:
# Standardise each test sequence with the training-set mean and std.
# NOTE: this rebinds `test_set`, so re-running the cell standardises the
# data a second time.
test_set = list(map(lambda item: (item - mu) / sig, test_set))
# Run the project's pad() helper, then augment_rbf_coefs() on its output.
# NOTE(review): presumably this pads to args.max_len and attaches RBF
# coefficients — confirm in data_utils.
test_set_padded = augment_rbf_coefs(
    pad(test_set, args.max_len, n_nodes), PHI, lhs
)

In [None]:
# Restore the best checkpoint of each network and switch it to eval mode.
# Without a GPU, remap the stored tensors to the CPU so that checkpoints
# saved on a CUDA machine still load; with a GPU, torch.load's default
# placement is kept unchanged.
_map_location = None if torch.cuda.is_available() else torch.device('cpu')
for _module, _filename in (
    (spatial, 'spatial-best.pt'),
    (nab, 'nab-best.pt'),
    (rfn, 'rfn-best.pt'),
):
    checkpoint = torch.load(os.path.join(savepath, _filename),
                            map_location=_map_location)
    _module.load_state_dict(checkpoint)
    # eval() switches the module in place, so the notebook-level names
    # (spatial, nab, rfn) do not need to be rebound.
    _module.eval()

In [None]:
# Evaluate the whole test set as a single batch.
batch_size = len(test_set)
# Inference only: no gradients are tracked inside this block.
with torch.no_grad():
    # NOTE(review): presumably S holds the spatial features that do not
    # depend on the input sequence — confirm in model_utils.
    S = sample_independent(spatial, PHI, invPHI, ARGS, batch_size)
    # Yields the reference sequences (`target`) and the model's predictions
    # (`pred`); both are converted to NumPy in the next cell.
    target, pred = sample_dependent(nab, rfn, test_set_padded, S, PHI, lhs, 
                                    batch_size, n_nodes, 
                                    args.in_len, args.max_len, 
                                    args.n_levels, args.n_spatial_fts)

In [None]:
# Bring targets and predictions back to NumPy for the metric computation.
# NOTE: `pred` is rebound here, so the cell expects torch inputs each time
# it runs.
gt = list(map(torch2numpy, target))
pred = list(map(torch2numpy, pred))

In [None]:
# Compute the MAE for each entry of args.mae_list and print one line per
# entry, labelled with that entry.
MAEs = get_MAEs(gt, pred, args.mae_list, args.in_len)
for i, mae in enumerate(MAEs):
    print('{}-step MAE: {}'.format(args.mae_list[i], mae))