In [1]:
import learn2learn as l2l
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pickle
import sys
import argparse
import os
from run_MAML_04 import test2 as test_maml


sys.path.insert(1, "..")

from ts_dataset import TSDataset
from base_models import LSTMModel, FCN
from metrics import torch_mae as mae
import copy
from pytorchtools import EarlyStopping, to_torch
from eval_base_models import test, train, freeze_model
from torch.utils.data import Dataset, DataLoader
from ts_dataset import DomainTSDataset, SimpleDataset
from sklearn.manifold import TSNE


In [None]:
# Per-dataset layout: [window_size, task_size, input_dim] (unpacked below).
meta_info = {"POLLUTION": [5, 50, 14],
             "HR": [32, 50, 13],
             "BATTERY": [20, 50, 3]}

output_directory = "output/"
horizon = 10
output_dim = 1

dataset_name = "HR"
save_model_file = "model6.pt"
load_model_file = "model6.pt"
lower_trial = 0
upper_trial = 3
learning_rate = 0.01
meta_learning_rate = 0.005
adaptation_steps = 10
batch_size = 20
model_name = "LSTM"
is_test = 1
patience_stopping = 20
epochs = 1000
noise_level = 0.0
noise_type = "additive"

# Keyword arguments presumably intended for torch `DataLoader(**params)` — confirm in later cells.
params = {'batch_size': batch_size,
          'shuffle': True,
          'num_workers': 0}

assert model_name in ("FCN", "LSTM"), "Model was not correctly specified"
assert dataset_name in ("POLLUTION", "HR", "BATTERY")

window_size, task_size, input_dim = meta_info[dataset_name]
grid = [0., noise_level]

DATA_DIR = "../../Data/"


def _load_split(split, suffix):
    """Unpickle one split of the configured dataset.

    Reads ``{DATA_DIR}{split}-{dataset_name}-W{window_size}-T{task_size}-{suffix}.pickle``
    and closes the file handle deterministically (the previous inline
    ``pickle.load(open(...))`` calls leaked their handles).

    Args:
        split: split prefix used in the file name, e.g. "TRAIN", "VAL" or "TEST".
        suffix: file-name variant, "NOML" or "ML".

    Returns:
        Whatever object was pickled in that file.

    NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    """
    path = f"{DATA_DIR}{split}-{dataset_name}-W{window_size}-T{task_size}-{suffix}.pickle"
    with open(path, "rb") as handle:
        return pickle.load(handle)


train_data = _load_split("TRAIN", "NOML")
train_data_ML = _load_split("TRAIN", "ML")
validation_data = _load_split("VAL", "NOML")
validation_data_ML = _load_split("VAL", "ML")
test_data = _load_split("TEST", "NOML")
test_data_ML = _load_split("TEST", "ML")

In [None]:

def _flatten_head_params(module):
    """Flatten the linear head's parameters into a plain list.

    Takes the first row of the weight and the first bias entry, matching the
    single-output ``nn.Linear(120, 1)`` head (121 values in total).

    Args:
        module: a MAML wrapper/clone (or the head itself) whose ``parameters()``
            yield exactly ``(weight, bias)``.

    Returns:
        list of numpy scalars: ``weight[0, :]`` followed by ``bias[0]``.
    """
    weight, bias = (p.detach().cpu().numpy() for p in module.parameters())
    return list(np.concatenate([weight[0], bias[:1]]))


trial = 0
output_directory = "../../Models/"+dataset_name+"_"+model_name+"_MAML/"+str(trial)+"/"

save_model_file_ = output_directory + "encoder_"+save_model_file
save_model_file_2 = output_directory + save_model_file
load_model_file_ = output_directory + load_model_file

# LSTM encoder + linear head; only the head is wrapped by MAML and adapted per task.
model = LSTMModel( batch_size=batch_size, seq_len = window_size, input_dim = input_dim, n_layers = 2, hidden_dim = 120, output_dim =1)
model2 = nn.Linear(120, 1)

model.cuda()
model2.cuda()

maml = l2l.algorithms.MAML(model2, lr=learning_rate, first_order=False)
model.load_state_dict(torch.load(save_model_file_))
maml.load_state_dict(torch.load(save_model_file_2))
# The encoder is only used for inference below; eval() disables any train-only
# layers it may contain (e.g. dropout) so its features are deterministic.
model.eval()


total_tasks_test = len(test_data_ML)
error_list =  []  # NOTE(review): not filled in this cell
parameters_list = []
domain_list = []
error_mean = []
activations_list = []

input_dim = test_data_ML.x.shape[-1]
window_size = test_data_ML.x.shape[-2]
output_dim = test_data_ML.y.shape[-1]

# Task stride: in test mode sample roughly every (total/100)-th task, never less than 1.
step = max(1, total_tasks_test // 100) if is_test else 1
max_tasks = (total_tasks_test-horizon-1)

# First entry: the meta-learned (un-adapted) head.
# NOTE(review): labelled with the last file index of the test set — presumably a
# placeholder domain for the meta-initialisation; confirm.
parameters_list.append(_flatten_head_params(maml))
domain_list.append(test_data_ML.file_idx[-1])

for task in range(0, max_tasks, step):

    # Skip support/query windows that span more than one source file.
    window_file_idx = test_data_ML.file_idx[task:task+horizon+1]
    if len(np.unique(window_file_idx)) > 1:
        continue

    # Fresh copy of the meta-learned head; adapt() updates this clone, never `maml`.
    learner = maml.clone()

    x_spt, y_spt = test_data_ML[task]
    x_qry = test_data_ML.x[(task+1):(task+1+horizon)].reshape(-1, window_size, input_dim)
    y_qry = test_data_ML.y[(task+1):(task+1+horizon)].reshape(-1, output_dim)

    x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
    # NOTE(review): the query tensors are built but not used in this cell.
    x_qry = to_torch(x_qry)
    y_qry = to_torch(y_qry)

    # The encoder is not adapted, so compute its features once without an autograd graph;
    # gradients in adapt() only need the head's parameters.
    with torch.no_grad():
        features_spt = model.encoder(x_spt)

    # `_` instead of `step` so the task stride above is not clobbered.
    for _ in range(adaptation_steps):
        pred = learner(features_spt)
        error = mae(pred, y_spt)
        learner.adapt(error)

    # `error` is the support loss evaluated just before the final adapt() step.
    error_mean.append(np.mean(error.cpu().detach().numpy()))
    # Record the *adapted* head (reading maml.parameters() here always gave the meta-init).
    parameters_list.append(_flatten_head_params(learner))
    domain_list.append(test_data_ML.file_idx[task])
    activations_list.append(features_spt.cpu().numpy()[np.newaxis, :])