In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import pickle
import higher
import copy
import sys

sys.path.insert(1,"..")
from ts_dataset import TSDataset
from base_models import LSTMModel, FCN
from metrics import torch_mae as mae
import time
import numpy as np

In [2]:
dataset_name = "HR"
dataset_name = "POLLUTION"
model_name = "LSTM"

task_size = 50
batch_size = 64
output_dim = 1

batch_size = 20
horizon = 10
meta_learning_rate = 10e-6
learning_rate = 10e-5
n_inner_iter = 1
##test

if dataset_name == "HR":
    window_size = 32
    input_dim = 13
elif dataset_name == "POLLUTION":
    window_size = 5
    input_dim = 14

def to_torch(numpy_tensor):
    
    return torch.tensor(numpy_tensor).float().cuda()


train_data = pickle.load(  open( "../../Data/TRAIN-"+dataset_name+"-W"+str(window_size)+"-T"+str(task_size)+"-NOML.pickle", "rb" ) )
train_data_ML = pickle.load( open( "../../Data/TRAIN-"+dataset_name+"-W"+str(window_size)+"-T"+str(task_size)+"-ML.pickle", "rb" ) )
validation_data = pickle.load( open( "../../Data/VAL-"+dataset_name+"-W"+str(window_size)+"-T"+str(task_size)+"-NOML.pickle", "rb" ) )
validation_data_ML = pickle.load( open( "../../Data/VAL-"+dataset_name+"-W"+str(window_size)+"-T"+str(task_size)+"-ML.pickle", "rb" ) )
test_data = pickle.load( open( "../../Data/TEST-"+dataset_name+"-W"+str(window_size)+"-T"+str(task_size)+"-NOML.pickle", "rb" ) )
test_data_ML = pickle.load( open( "../../Data/TEST-"+dataset_name+"-W"+str(window_size)+"-T"+str(task_size)+"-ML.pickle", "rb" ) )

In [3]:
if model_name == "LSTM":
    model = LSTMModel( batch_size=batch_size, seq_len = window_size, input_dim = input_dim, n_layers = 2, hidden_dim = 120, output_dim =1)

elif model_name == "FCN":
    kernels = [8,5,3] if window_size != 5 else [4,2,1]
    model = FCN(time_steps = window_size,  channels=[input_dim, 128, 128, 128] , kernels=kernels)
    
model.cuda()
meta_opt = optim.Adam(model.parameters(), lr=5e-5)

In [4]:

torch.backends.cudnn.enabled = False

In [5]:




def test(model, test_data_ML, horizon, n_inner_iter):

    model.train()
    total_tasks_test, task_size, window_size, input_dim = test_data_ML.x.shape
    
    qry_losses = []


    for task in range(0, (total_tasks_test-horizon-1), total_tasks_test//100):
        
        start_time = time.time()
        # Sample a batch of support and query images and labels.

        x_spt, y_spt = test_data_ML[task]
        x_qry = test_data_ML.x[(task+1):(task+1+horizon)].reshape(-1, window_size, input_dim)
        y_qry = test_data_ML.y[(task+1):(task+1+horizon)].reshape(-1, output_dim)

        x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
        x_qry = to_torch(x_qry)
        y_qry = to_torch(y_qry)
        
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?

        # Initialize the inner optimizer to adapt the parameters to
        # the support set.
        
        inner_opt = torch.optim.SGD(model.parameters(), lr=1e-4)

        with higher.innerloop_ctx(model, inner_opt, track_higher_grads=False) as (fnet, diffopt):
            # Optimize the likelihood of the support set by taking
            # gradient steps w.r.t. the model's parameters.
            # This adapts the model's meta-parameters to the task.
            for _ in range(n_inner_iter):
                spt_logits = fnet(x_spt)
                spt_loss = mae(spt_logits, y_spt)
                diffopt.step(spt_loss)

            # The query loss and acc induced by these parameters.
            qry_logits = fnet(x_qry).detach()
            qry_loss = mae(qry_logits, y_qry)
            
            qry_losses.append(qry_loss.detach())


    qry_losses = torch.stack(qry_losses).mean().item()

    
    print(qry_losses)
    return qry_losses

In [6]:
model.train()
batch_size = 20
n_iterations = 10

for epoch in range(n_iterations):

    n_train_iter = train_data_ML.x.shape[0] // batch_size

    for batch_idx in range(n_train_iter-1):
        start_time = time.time()
        # Sample a batch of support and query images and labels.

        x_spt, y_spt = train_data_ML[batch_idx:batch_size+batch_idx]
        x_qry, y_qry = train_data_ML[batch_idx+1 : batch_idx+batch_size+1]
        
        x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
        x_qry = to_torch(x_qry)
        y_qry = to_torch(y_qry)
        
        task_num, setsz, c_, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?

        # Initialize the inner optimizer to adapt the parameters to
        # the support set.
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(model.parameters(), lr=1e-4)

        qry_losses = []
        meta_opt.zero_grad()
        for i in range(task_num):
            with higher.innerloop_ctx(
                model, inner_opt, copy_initial_weights=False
            ) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                # higher is able to automatically keep copies of
                # your network's parameters as they are being updated.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = mae(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The final set of adapted parameters will induce some
                # final loss and accuracy on the query dataset.
                # These will be used to update the model's meta-parameters.
                qry_logits = fnet(x_qry[i])
                qry_loss = mae(qry_logits, y_qry[i])
                qry_losses.append(qry_loss.detach())

                # Update the model's meta-parameters to optimize the query
                # losses across all of the tasks sampled in this batch.
                # This unrolls through the gradient steps.
                qry_loss.backward()

        meta_opt.step()
        qry_losses = sum(qry_losses) / task_num
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 1 == 0:
            print(
                f'[Epoch {i:.2f}] Train Loss: {qry_losses:.4f}  | Time: {iter_time:.2f}'
            )
            test(model, validation_data_ML, horizon, n_inner_iter)
            test(model, test_data_ML, horizon, n_inner_iter)

[Epoch 0.00] Train Loss: 0.1001  | Time: 6.48
0.06130401790142059
0.08580785989761353
[Epoch 0.01] Train Loss: 0.1016  | Time: 6.78
0.060051001608371735
0.08433961868286133
[Epoch 0.02] Train Loss: 0.1070  | Time: 6.69
0.058799535036087036
0.0828726589679718
[Epoch 0.03] Train Loss: 0.1031  | Time: 6.91
0.05754739046096802
0.08140676468610764
[Epoch 0.04] Train Loss: 0.1028  | Time: 6.93
0.05629550665616989
0.0799437165260315
[Epoch 0.05] Train Loss: 0.1038  | Time: 6.98
0.055052828043699265
0.07848794013261795
[Epoch 0.06] Train Loss: 0.1010  | Time: 6.78
0.05381219461560249
0.07703561335802078
[Epoch 0.07] Train Loss: 0.0900  | Time: 7.35
0.052572254091501236
0.07558497786521912
[Epoch 0.08] Train Loss: 0.0840  | Time: 7.17
0.051338378340005875
0.07414201647043228
[Epoch 0.09] Train Loss: 0.0870  | Time: 6.90
0.05012376978993416
0.07271445542573929
[Epoch 0.10] Train Loss: 0.0842  | Time: 6.93
0.048933349549770355
0.07130029797554016
[Epoch 0.11] Train Loss: 0.0839  | Time: 6.69
0.04

0.05296102911233902
[Epoch 0.93] Train Loss: 0.0828  | Time: 6.52
0.05230892077088356
0.0531684011220932
[Epoch 0.94] Train Loss: 0.0813  | Time: 6.58
0.052630942314863205
0.05334977060556412
[Epoch 0.95] Train Loss: 0.0800  | Time: 6.60
0.05289890617132187
0.05350331589579582
[Epoch 0.96] Train Loss: 0.0785  | Time: 6.51
0.05312482640147209
0.05363628268241882
[Epoch 0.97] Train Loss: 0.0801  | Time: 6.55
0.053327251225709915
0.05375835299491882
[Epoch 0.98] Train Loss: 0.0831  | Time: 6.50
0.05350477620959282
0.05387134850025177
[Epoch 1.00] Train Loss: 0.0648  | Time: 6.55
0.053565818816423416
0.05392572283744812
[Epoch 1.01] Train Loss: 0.0659  | Time: 6.49
0.05353972688317299
0.05393483489751816
[Epoch 1.02] Train Loss: 0.0699  | Time: 6.99
0.053450386971235275
0.05391380563378334
[Epoch 1.03] Train Loss: 0.0693  | Time: 7.01
0.05330350622534752
0.053865887224674225
[Epoch 1.04] Train Loss: 0.0684  | Time: 6.42
0.05310969427227974
0.05379823222756386
[Epoch 1.05] Train Loss: 0.068

0.04920510575175285
0.053282901644706726
[Epoch 1.87] Train Loss: 0.0556  | Time: 6.80
0.04952835291624069
0.05343689024448395
[Epoch 1.88] Train Loss: 0.0585  | Time: 6.74
0.0498293861746788
0.053576137870550156
[Epoch 1.89] Train Loss: 0.0580  | Time: 6.55
0.05011171102523804
0.05370103940367699
[Epoch 1.90] Train Loss: 0.0629  | Time: 6.63
0.05037420243024826
0.05380837991833687
[Epoch 1.91] Train Loss: 0.0753  | Time: 6.87
0.050635963678359985
0.05390387400984764
[Epoch 1.92] Train Loss: 0.0791  | Time: 6.65
0.05089222267270088
0.053988974541425705
[Epoch 1.93] Train Loss: 0.0787  | Time: 6.80
0.05112902820110321
0.05405980721116066
[Epoch 1.94] Train Loss: 0.0764  | Time: 6.72
0.05135991796851158
0.054121941328048706
[Epoch 1.95] Train Loss: 0.0748  | Time: 6.73
0.05158663168549538
0.054177138954401016
[Epoch 1.96] Train Loss: 0.0733  | Time: 6.83
0.0518205426633358
0.05422871932387352
[Epoch 1.97] Train Loss: 0.0748  | Time: 6.90
0.052077312022447586
0.05428079515695572
[Epoch 1.

0.050380807369947433
0.05054768547415733
[Epoch 2.81] Train Loss: 0.0493  | Time: 6.63
0.050910256803035736
0.05075585097074509
[Epoch 2.82] Train Loss: 0.0488  | Time: 6.83
0.05143694207072258
0.050966329872608185
[Epoch 2.83] Train Loss: 0.0490  | Time: 6.66
0.051943015307188034
0.05117475986480713
[Epoch 2.84] Train Loss: 0.0489  | Time: 6.54
0.0524163581430912
0.05137532949447632
[Epoch 2.85] Train Loss: 0.0528  | Time: 6.64
0.052856143563985825
0.05155859887599945
[Epoch 2.86] Train Loss: 0.0533  | Time: 6.55
0.05325336009263992
0.05172676965594292
[Epoch 2.87] Train Loss: 0.0540  | Time: 6.57
0.053591422736644745
0.05187535658478737
[Epoch 2.88] Train Loss: 0.0568  | Time: 6.56
0.05389276519417763
0.05200871452689171
[Epoch 2.89] Train Loss: 0.0561  | Time: 6.52
0.05416370555758476
0.05212758481502533
[Epoch 2.90] Train Loss: 0.0607  | Time: 6.57
0.054403964430093765
0.05222916975617409
[Epoch 2.91] Train Loss: 0.0720  | Time: 6.64
0.054644882678985596
0.05232078209519386
[Epoch 

KeyboardInterrupt: 

In [None]:
import numpy as np

In [6]:
a = np.random.randint(0, 10, 3)

In [7]:
a +1

array([9, 4, 8])

In [8]:
a

array([8, 3, 7])