In [None]:
import os
import shutil
import numpy as np
import scipy.linalg as linalg
import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Colab only: mount Google Drive at /content/drive.
# NOTE(review): the relative paths used below ('drive/MyDrive/...') assume the
# working directory is the mount's parent (Colab's default /content) — confirm.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Copy the project's helper modules from Drive into the working directory so the
# `from utils import ...` / `from models import ...` imports in the next cell resolve.
!cp "drive/MyDrive/Work/RSTA/final_repo/utils.py" "./"
!cp "drive/MyDrive/Work/RSTA/final_repo/data_utils.py" "./"
!cp "drive/MyDrive/Work/RSTA/final_repo/model_utils.py" "./"
!cp "drive/MyDrive/Work/RSTA/final_repo/models.py" "./"

In [None]:
from utils import Config, torch2numpy, numpy2torch, get_MAEs
from data_utils import load_dataset, pad, unpad, augment_rbf_coefs
from models import RBF, SpatialTransform, NAB, RFN
from model_utils import sample_independent, sample_dependent

In [None]:
# Dataset to run, e.g. 'convdiff', 'noaa_ec', 'noaa_pt', 'sst'
# (it selects configs/config_<dataset>.txt in the next cell).
dataset = 'noaa_pt'
# Repository root on the mounted Drive; configs, data and checkpoints are resolved under it.
work_dir = 'drive/MyDrive/Work/RSTA/final_repo/'

In [None]:
# Hyper-parameters for the chosen dataset, read from <work_dir>/configs/config_<dataset>.txt.
args = Config(os.path.join(work_dir, 'configs', f'config_{dataset}.txt'))

In [None]:
# Resolve the config's data directory against the repository root (work_dir).
# The guard keeps this cell idempotent: without it, re-running the cell would
# prepend work_dir a second time and break load_dataset below.
if not str(args.data_dir).startswith(work_dir):
    args.data_dir = os.path.join(work_dir, args.data_dir)

In [None]:
# Per-dataset checkpoint directory. exist_ok=True replaces the exists()/makedirs()
# pair: the cell stays re-runnable and there is no check-then-create race.
savepath = os.path.join(work_dir, 'saved_models/', args.dataset)
os.makedirs(savepath, exist_ok=True)

In [None]:
# Newly created float tensors go to the GPU when one is available, else the CPU.
# NOTE(review): torch.set_default_tensor_type is deprecated since PyTorch 2.1
# (set_default_dtype / set_default_device replace it); kept for older torch versions.
torch.set_default_tensor_type(
    'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'
)

In [None]:
# Seed every RNG in play: Python, NumPy, torch (CPU), the current CUDA device and all CUDA devices.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(args.seed)
# Prefer deterministic cuDNN kernels over autotuned (faster but nondeterministic) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Load the dataset splits; the third returned item (presumably the test split) is unused here.
# p is indexed per node and used for pairwise distances in the RBF cell below
# (presumably node coordinates — confirm in data_utils.load_dataset).
train_set, val_set, _, p = load_dataset(args.data_dir, args.dataset)
n_nodes = len(p)

In [None]:
# Stack every frame of every training sequence into one array; it is only used
# to compute the global normalisation statistics in the next cell.
all_train = np.array([frame for seq in train_set for frame in seq])

In [None]:
# Scalar mean / population std (NumPy default ddof=0) over all training values.
mu = all_train.mean()
sig = all_train.std()

In [None]:
# Standardise both splits with the training statistics.
# NOTE: not idempotent — re-running this cell normalises the data a second time.
train_set = list(map(lambda seq: (seq - mu) / sig, train_set))
val_set = list(map(lambda seq: (seq - mu) / sig, val_set))

In [None]:
# Pad both splits to args.max_len via data_utils.pad (padding scheme defined there).
train_set_padded, val_set_padded = (
    pad(train_set, args.max_len, n_nodes),
    pad(val_set, args.max_len, n_nodes),
)

In [None]:
# Shuffle the padded training sequences in place with Python's `random` RNG.
# NOTE(review): the training loop draws batches with torch.randperm, so this
# shuffle does not change the sampling distribution; it only advances the
# `random` RNG state. If pad() returns a NumPy array rather than a list, use
# np.random.shuffle instead — random.shuffle can duplicate rows of an ndarray.
random.shuffle(train_set_padded)

In [None]:
# Get the RBF matrix inputs for every ordered node pair (i, j):
#   PHI[i, j]        = Euclidean distance between p[i] and p[j]
#   ARGS[i, j, :2]   = p[i],  ARGS[i, j, 2:] = p[j]
# Computed with broadcasting instead of an O(n_nodes^2) Python double loop; the
# per-pair values are the same as evaluating each pair separately.
# NOTE(review): assumes each p[i] holds 2 coordinates (implied by the 2+2 ARGS slots).
coords = np.asarray(p)
# Cast to float64 so PHI keeps the dtype of the np.zeros buffer used previously.
PHI = linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1).astype(np.float64)
ARGS = np.zeros([n_nodes, n_nodes, 4])
ARGS[:, :, :2] = coords[:, None, :]
ARGS[:, :, 2:] = coords[None, :, :]

In [None]:
# Move the distance matrix and pair-coordinate tensor from NumPy to torch.
PHI, ARGS = numpy2torch(PHI), numpy2torch(ARGS)

In [None]:
# Radial basis function module; kernel class and shape parameter come from the config.
rbf = RBF(eps=args.eps, cls=args.rbf)
rbf = rbf.cuda() if torch.cuda.is_available() else rbf

In [None]:
# Apply the RBF to the pairwise distances and drop any autograd history.
PHI = rbf(PHI).detach()
# Exact inverse, computed in NumPy/SciPy and converted back to torch.
invPHI = numpy2torch(linalg.inv(torch2numpy(PHI)))
# PHI^T PHI + lstsq_reg * I — presumably the regularised normal-equation matrix
# consumed by augment_rbf_coefs / sample_dependent (confirm in their sources).
lhs = PHI.t() @ PHI + args.lstsq_reg * torch.eye(PHI.shape[1])

In [None]:
# Attach RBF coefficients to every padded sequence (see data_utils.augment_rbf_coefs).
train_set_padded, val_set_padded = (
    augment_rbf_coefs(train_set_padded, PHI, lhs),
    augment_rbf_coefs(val_set_padded, PHI, lhs),
)

In [None]:
# Instantiate the three trainable modules defined in models.py.
# NOTE(review): if parameter initialisation draws from torch's global RNG (seeded
# above), construction order affects the initial weights — keep this order.
spatial = SpatialTransform(out_dim=args.n_spatial_fts)
nab = NAB(in_dim=args.n_spatial_fts)
# Input width is n_levels * n_spatial_fts plus one extra input
# (TODO confirm what the +1 channel is in models.RFN / sample_dependent).
rfn = RFN(in_dim=args.n_levels*args.n_spatial_fts+1)

In [None]:
# Move all three modules to the GPU when one is present (Module.cuda returns the module).
if torch.cuda.is_available():
    spatial, nab, rfn = spatial.cuda(), nab.cuda(), rfn.cuda()

In [None]:
# Put every module in training mode (Module.train returns the module itself).
spatial, nab, rfn = spatial.train(), nab.train(), rfn.train()

In [None]:
# One Adam optimiser over the parameters of all three modules, in module order.
params = [*spatial.parameters(), *nab.parameters(), *rfn.parameters()]
optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)

In [None]:
# Total number of scalar parameters handed to the optimiser.
# Tensor.numel() replaces np.prod(par.size()): for a 0-d parameter np.prod of the
# empty size returns the float 1.0, which turned the whole count into a float.
num_params = sum(par.numel() for par in params)
print('Number of parameters: {}'.format(num_params))

In [None]:
# Mean squared error averaged over all elements (reduction='mean' is the default).
criterion = nn.MSELoss(reduction='mean')

In [None]:
# Mini-batches per epoch. Training draws args.batch_size random sequences per step;
# validation walks the padded set in order and skips a final partial batch.
train_steps = len(train_set_padded) // args.batch_size
val_steps = len(val_set_padded) // args.batch_size
# Both counts divide the accumulated losses when the training loop reports epoch
# averages, so reject a batch size that yields no full batch up front rather than
# hitting a ZeroDivisionError at the end of the first epoch.
if train_steps < 1 or val_steps < 1:
    raise ValueError(
        'batch_size={} leaves no full batch in the training ({}) or validation ({}) set'.format(
            args.batch_size, len(train_set_padded), len(val_set_padded)))

In [None]:
#%%script false --no-raise-error
# Train spatial / nab / rfn jointly; validate and checkpoint after every epoch.
sk = args.inv_sig_coef
for epoch in tqdm(range(args.num_epochs)):
    # Inverse-sigmoid decay: sk / (sk + exp(epoch / sk)) equals sk / (sk + 1) at epoch 0
    # and shrinks towards 0 over training. Forwarded to sample_dependent (presumably a
    # scheduled-sampling / teacher-forcing probability — confirm in model_utils).
    train_sample_prob = sk / (sk + np.exp(epoch/sk))

    # (Re-)enter training mode every epoch: the validation pass below switches to eval
    # mode, and so does the checkpoint-evaluation cell if it ran earlier in this kernel.
    spatial.train()
    nab.train()
    rfn.train()
    train_error = 0.
    for step in range(train_steps):
        S = sample_independent(spatial, PHI, invPHI, ARGS, args.batch_size)
        # Uniform random batch of sequences, drawn without replacement.
        selected_idxs = torch.randperm(len(train_set_padded))[:args.batch_size]
        seqs = [train_set_padded[idx] for idx in selected_idxs]
        target, pred = sample_dependent(nab, rfn, seqs, S, PHI, lhs,
                                        args.batch_size, n_nodes,
                                        args.in_len, args.max_len,
                                        args.n_levels, args.n_spatial_fts,
                                        train_sample_prob)
        # Flatten the per-sequence outputs so the loss is one MSE over all elements.
        target = torch.cat([seq.reshape(-1) for seq in target])
        pred = torch.cat([seq.reshape(-1) for seq in pred])
        train_loss = criterion(target, pred)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        train_error += train_loss.item()

    # Validate in eval mode, as the checkpoint-evaluation cell does, so layers such as
    # dropout / batch-norm (if the models contain any) behave as at inference time.
    spatial.eval()
    nab.eval()
    rfn.eval()
    val_error = 0.
    with torch.no_grad():
        for vstep in range(val_steps):
            S = sample_independent(spatial, PHI, invPHI, ARGS, args.batch_size)
            seqs = val_set_padded[vstep*args.batch_size : (vstep+1)*args.batch_size]
            target, pred = sample_dependent(nab, rfn, seqs, S, PHI, lhs,
                                            args.batch_size, n_nodes,
                                            args.in_len, args.max_len,
                                            args.n_levels, args.n_spatial_fts)
            target = torch.cat([seq.reshape(-1) for seq in target])
            pred = torch.cat([seq.reshape(-1) for seq in pred])
            val_loss = criterion(target, pred)
            val_error += val_loss.item()

    print("Epoch: {}, Training Error: {}, Validation Error: {}".format(epoch,
                                                                       train_error/train_steps,
                                                                       val_error/val_steps))

    # One checkpoint file per module per epoch; the evaluation cell reloads these.
    torch.save(spatial.state_dict(), os.path.join(savepath, 'spatial-{0:04d}.pt'.format(epoch)))
    torch.save(nab.state_dict(), os.path.join(savepath, 'nab-{0:04d}.pt'.format(epoch)))
    torch.save(rfn.state_dict(), os.path.join(savepath, 'rfn-{0:04d}.pt'.format(epoch)))

In [None]:
#%%script false --no-raise-error
# Score every saved epoch checkpoint on the whole validation set in a single batch.
batch_size = len(val_set)
MAEs = []
for epoch in tqdm(range(args.num_epochs)):
    for name, module in (('spatial', spatial), ('nab', nab), ('rfn', rfn)):
        ckpt_path = os.path.join(savepath, '{}-{:04d}.pt'.format(name, epoch))
        # map_location='cpu' lets checkpoints written on a GPU runtime load on a
        # CPU-only runtime; load_state_dict then copies them onto the module's device.
        module.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
        module.eval()
    with torch.no_grad():
        S = sample_independent(spatial, PHI, invPHI, ARGS, batch_size)
        target, pred = sample_dependent(nab, rfn, val_set_padded, S, PHI, lhs,
                                        batch_size, n_nodes,
                                        args.in_len, args.max_len,
                                        args.n_levels, args.n_spatial_fts)
        gt = [torch2numpy(seq) for seq in target]
        pred = [torch2numpy(seq) for seq in pred]
        # One row per epoch: an MAE for each horizon in args.mae_list.
        MAEs.append(get_MAEs(gt, pred, args.mae_list, args.in_len))

In [None]:
# Validation MAE of every epoch checkpoint, one curve per horizon in args.mae_list.
MAEs = np.array(MAEs)
fig, ax = plt.subplots()
ax.plot(MAEs)
ax.legend(['mae_{}'.format(item) for item in args.mae_list])
ax.set(xlabel='Epoch (checkpoint index)', ylabel='Validation MAE',
       title='Validation MAE per checkpoint');

In [None]:
# Per-horizon minimum validation MAE and the epoch at which each minimum occurs.
print('Best validation losses and corresponding checkpoints')
print(np.min(MAEs, axis=0), np.argmin(MAEs, axis=0), sep=',')

In [None]:
# Rank checkpoints by their MAE summed over all horizons; lowest total wins.
MAE = np.sum(MAEs, axis=-1)
best_checkpoint_id = MAE.argmin()
print('Best checkpoint: {}'.format(best_checkpoint_id))

In [None]:
# Publish the winning epoch's weights as <module>-best.pt next to the per-epoch files.
for module_name in ('spatial', 'nab', 'rfn'):
    src = os.path.join(savepath, '{}-{:04d}.pt'.format(module_name, best_checkpoint_id))
    dst = os.path.join(savepath, '{}-best.pt'.format(module_name))
    shutil.copy2(src, dst)

In [None]:
# Report each horizon's MAE for the selected checkpoint.
print('Validation losses for best checkpoint')
for col, horizon in enumerate(args.mae_list):
    print('{}-step MAE: {}'.format(horizon, MAEs[best_checkpoint_id, col]))