In [167]:
import os
import random
import uuid
from collections import defaultdict
from timeit import default_timer as timer

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from torch.distributions.normal import Normal
from torch.utils.data import Dataset
from sklearn.preprocessing import StandardScaler
# from torch.utils.tensorboard import SummaryWriter

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import os, pickle
import copy
import json

In [168]:
# Global metrics store: tracker[split][metric_name] is the list that log_metrics() appends to.
tracker = {"train":{}, "test":{}}

In [169]:
class TsDS(Dataset):
    """Dataset holding a list of sample tensors and a parallel list of label tensors.

    Each element of ``XL`` / ``yL`` is converted to a tensor and stored as one
    entry of ``self.samples`` / ``self.labels``. ``__getitem__`` indexes those
    lists (one entry per index), while ``__len__`` reports the total number of
    rows summed over all entries.

    Args:
        XL: iterable of array-likes, one per entry (converted to float tensors).
        yL: iterable of array-likes with the matching labels.
        flatten: if true, ``__getitem__`` flattens each sample from dim 1 onward.
        lno: optional label column to select; ``None`` returns labels unchanged.
        long: when ``lno`` is set, return the column as ``long`` (else ``float``).
    """

    def __init__(self, XL,yL,flatten=False,lno=None,long=True):
        self.flatten = flatten
        self.lno = lno
        self.long = long
        self.scaler = StandardScaler()
        self.samples = []
        self.labels = []
        for features, targets in zip(XL, yL):
            self.samples.append(torch.tensor(features).float())
            self.labels.append(torch.tensor(targets))

    def __len__(self):
        """Total number of rows (first dimension) across all stored entries."""
        return sum(entry.shape[0] for entry in self.samples)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for the ``idx``-th stored entry."""
        sample = self.samples[idx]
        if self.flatten:
            sample = sample.flatten(start_dim=1)

        label = self.labels[idx]
        if self.lno is not None:
            column = label[:, self.lno]
            label = column.long() if self.long else column.float()
        return (sample, label)

    @staticmethod
    def _map_sequences(fn, entry):
        """Apply ``fn`` to a 3-D entry viewed as 2-D rows, then restore its shape."""
        rows = entry.reshape((entry.shape[0] * entry.shape[1], entry.shape[2]))
        return torch.tensor(fn(rows).reshape(entry.shape)).float()

    def fit(self,kind='seq'):
        """Fit ``self.scaler``.

        ``'seq'`` fits on the last step (index -1 of dim 1) of every entry and
        keeps those rows in ``self.lastelems``; ``'flat'`` fits on all entries
        concatenated along dim 0. Any other ``kind`` does nothing.
        """
        if kind == 'seq':
            final_steps = torch.cat([entry[:, -1, :] for entry in self.samples], dim=0)
            self.lastelems = [final_steps]
            self.scaler.fit(torch.cat(self.lastelems, dim=0))
        elif kind == 'flat':
            self.scaler.fit(torch.cat(self.samples, dim=0))

    def scale(self,kind='flat',scaler=None):
        """Refit ``self.scaler`` via ``fit(kind)`` and transform all samples in place.

        ``self.scaler`` is refit even when an external ``scaler`` is supplied;
        the external one is only used for the transform itself.
        """
        self.fit(kind)
        if scaler is None:
            scaler = self.scaler
        if kind == 'seq':
            self.samples = [self._map_sequences(scaler.transform, entry) for entry in self.samples]
        elif kind == 'flat':
            self.samples = [torch.tensor(scaler.transform(entry)).float() for entry in self.samples]

    def unscale(self,kind='flat',scaler=None):
        """Inverse of ``scale``: apply ``inverse_transform`` to all samples in place."""
        if scaler is None:
            scaler = self.scaler
        if kind == 'seq':
            self.samples = [self._map_sequences(scaler.inverse_transform, entry) for entry in self.samples]
        elif kind == 'flat':
            self.samples = [torch.tensor(scaler.inverse_transform(entry)).float() for entry in self.samples]

In [170]:
class Accumulator:
    """Collects lists of metric values keyed by name.

    Values are stored in a ``defaultdict`` of lists, so reading an unknown key
    yields (and creates) an empty list.
    """

    def __init__(self):
        self.clear()

    def clear(self):
        """Drop everything collected so far."""
        self.metrics = defaultdict(list)

    def add(self, key, value):
        """Extend the series stored under ``key`` with the iterable ``value``."""
        series = self.metrics[key]
        series += value
        self.metrics[key] = series

    def add_dict(self, dict):
        """Call ``add`` for every ``(key, value)`` pair of the mapping."""
        for name in dict:
            self.add(name, dict[name])

    def mean(self, key):
        """Arithmetic mean (``np.mean``) of the series stored under ``key``."""
        return np.mean(self.metrics[key])

    def __getitem__(self, item):
        return self.metrics[item]

    def __setitem__(self, key, value):
        self.metrics[key] = value

    def get_dict(self):
        """Return a deep copy of the collected metrics as a plain ``dict``."""
        snapshot = {name: series for name, series in self.metrics.items()}
        return copy.deepcopy(snapshot)

    def items(self):
        return self.metrics.items()

    def __str__(self):
        return str({name: series for name, series in self.metrics.items()})

In [171]:
def get_numbers(name):
    """Split ``name`` on ``'_'`` and return ``(field 2 as str, field 3 as int)``."""
    fields = name.split('_')
    return fields[2], int(fields[3])

# Directory (relative to the working directory) that holds the pickled task files.
folder_path = os.path.join("marketdata")
l = os.listdir(folder_path)

# Only file names containing `data_type` are considered; "ds" selects the
# 'seq' scaling mode, anything else the 'flat' one.
data_type = "ds"
meta_train = {"train": [], "test": []}
meta_test = {"train": [], "test": []}
kind = "seq" if data_type == "ds" else "flat"

# Bucket every matching file by split ("train" in the name, otherwise "test")
# and by the integer in the 4th '_'-separated field of its name.
# NOTE(review): `data_type in file` is a substring match — confirm no unrelated
# file name in the folder contains "ds".
for file in l:
    if data_type in file:
        type_ = "train" if "train" in file else "test"
        g, d = get_numbers(file)
        if d < 20: # for meta-training
            meta_train[type_].append(file)
        else: # for meta-testing
            meta_test[type_].append(file)


# Train and test names are sorted independently and then zipped, so the pairing
# relies on both lists having the same length and the same relative order.
# NOTE(review): zip() silently truncates if one split has more files — confirm.
meta_train["train"] = sorted(meta_train["train"])
meta_train["test"] = sorted(meta_train["test"])

data = list(zip(meta_train["train"], meta_train["test"]))
# Order the (train_file, test_file) pairs by the integer field of the train file name.
data = sorted(data, key=lambda x: get_numbers(x[0])[1])
# Cursor into `data`, advanced by sample_task().
idx = 0

def load_task(task):
    """Load and scale the train/test datasets of one task.

    Args:
        task: tuple ``(train_file, test_file)`` of pickle file names located in
            ``folder_path``.

    Returns:
        ``(X_train, y_train, X_test, y_test)`` — the ``samples`` and ``labels``
        attributes of the two unpickled dataset objects after ``scale(kind)``.

    Notes:
        * The files are read with ``pickle``; only load files you trust, since
          unpickling can execute arbitrary code.
        * Each dataset is scaled with a scaler fitted on itself (``scale`` is
          called without an external scaler).
          NOTE(review): the test split is therefore normalised with its own
          statistics rather than the training ones — confirm this is intended.
    """
    train_file, test_file = task

    # Context managers guarantee the file handles are closed even if unpickling fails.
    with open(os.path.join(folder_path, train_file), "rb") as train_fh:
        train_data = pickle.load(train_fh)
    with open(os.path.join(folder_path, test_file), "rb") as test_fh:
        test_data = pickle.load(test_fh)

    train_data.scale(kind)
    test_data.scale(kind)
    return train_data.samples, train_data.labels, test_data.samples, test_data.labels

def sample_task():
    """Load the next task in ``data`` (round-robin) and advance the global cursor."""
    global idx
    # Wrap around once the cursor has moved past the last task.
    idx = 0 if idx >= len(data) else idx
    current, idx = data[idx], idx + 1
    return load_task(current)

# Number of (train_file, test_file) pairs available for meta-training.
print(len(data))

100


In [172]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention with residual connections.

    Queries, keys and values are linearly projected to ``dim_output`` features,
    split into ``num_heads`` chunks along the feature axis (heads are stacked
    along dim 0), attended, and merged back. The projected query is added to
    the attention output, followed by a residual ReLU feed-forward layer.

    Args:
        dim_query: feature size of ``query``.
        dim_key: feature size of ``key``.
        dim_value: feature size of ``value``.
        dim_output: feature size of the projections and of the result.
        num_heads: number of chunks the projected features are split into.
    """

    def __init__(self, dim_query, dim_key, dim_value, dim_output, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_query, dim_output, bias=False)
        self.fc_k = nn.Linear(dim_key, dim_output, bias=False)
        self.fc_v = nn.Linear(dim_value, dim_output, bias=False)
        self.fc_o = nn.Linear(dim_output, dim_output)

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` over ``key`` / ``value``.

        Args:
            query, key, value: tensors whose last dim matches the respective
                ``dim_*`` constructor argument.
            mask: optional tensor marking key positions to ignore (truthy =
                masked). Its trailing singleton dim, if any, is squeezed, and
                it is broadcast over query positions and heads. A row whose
                keys are all masked yields NaN after the softmax.

        Returns:
            Tensor with the query's leading shape and ``dim_output`` features.
        """
        query = self.fc_q(query)
        key = self.fc_k(key)
        value = self.fc_v(value)

        # Split features into heads and stack the heads along dim 0.
        query_ = torch.cat(query.chunk(self.num_heads, -1), 0)
        key_ = torch.cat(key.chunk(self.num_heads, -1), 0)
        value_ = torch.cat(value.chunk(self.num_heads, -1), 0)

        # Logits are scaled by sqrt of the full projected width (not the per-head width).
        A_logits = (query_ @ key_.transpose(-2, -1)) / math.sqrt(query.shape[-1])
        if mask is not None:
            mask = torch.stack([mask.squeeze(-1)] * query.shape[-2], -2)
            mask = torch.cat([mask] * self.num_heads, 0)
            # masked_fill is out-of-place: the result must be kept, otherwise
            # the mask is silently ignored.
            A_logits = A_logits.masked_fill(mask.to(torch.bool), -float("inf"))
        A = torch.softmax(A_logits, -1)

        # Undo the head stacking: dim-0 chunks go back side by side on the feature axis.
        outs = torch.cat((A @ value_).chunk(self.num_heads, 0), -1)
        outs = query + outs
        outs = outs + F.relu(self.fc_o(outs))
        return outs


class PMA(nn.Module):
    """Attention pooling with ``num_seeds`` learned query vectors.

    The seed vectors (Xavier-uniform initialised) act as queries that attend
    over the input set through a ``MultiHeadAttention`` block.
    """

    def __init__(self, dim, num_heads, num_seeds):
        super().__init__()
        seeds = torch.Tensor(1, num_seeds, dim)
        nn.init.xavier_uniform_(seeds)
        self.S = nn.Parameter(seeds)
        self.mha = MultiHeadAttention(dim, dim, dim, dim, num_heads)

    def forward(self, X):
        """Pool ``X`` (first dim = batch); all size-1 dims of the result are squeezed."""
        n_sets = X.size(0)
        seed_queries = self.S.repeat(n_sets, 1, 1)
        pooled = self.mha(seed_queries, X, X)
        return pooled.squeeze()

## NC Model

In [173]:
def convert_y_ohe(y_tr, n_reg=6, n_heads=4, n_classes=4):
    """Replace the class-index columns of a target matrix by one-hot blocks.

    The first ``n_reg`` columns of ``y_tr`` are kept unchanged. The following
    ``n_heads`` columns are read as integer class ids (truncated toward zero)
    and each is expanded to a one-hot vector of length ``n_classes``.

    Args:
        y_tr: 2-D tensor with at least ``n_reg + n_heads`` columns.
        n_reg: number of leading columns copied through unchanged.
        n_heads: number of class-id columns that follow them.
        n_classes: number of classes per class-id column.

    Returns:
        Tensor of shape ``(rows, n_reg + n_heads * n_classes)`` on the same
        device as ``y_tr``.

    Raises:
        IndexError: if ``y_tr`` has fewer than ``n_reg + n_heads`` columns.
        RuntimeError: (from ``F.one_hot``) if a class id is negative or
            ``>= n_classes``.
    """
    if y_tr.shape[1] < n_reg + n_heads:
        raise IndexError(
            f"expected at least {n_reg + n_heads} target columns, got {y_tr.shape[1]}"
        )

    passthrough = y_tr[:, :n_reg]
    class_ids = y_tr[:, n_reg:n_reg + n_heads].long()

    # Vectorised one-hot on y_tr's own device: no Python loop and no
    # dependency on a module-level `device`.
    one_hot = F.one_hot(class_ids, num_classes=n_classes)
    one_hot = one_hot.reshape(y_tr.shape[0], n_heads * n_classes)
    one_hot = one_hot.to(torch.get_default_dtype())
    return torch.cat((passthrough, one_hot), dim=1)

In [175]:
def fc_stack(num_layers, input_dim, hidden_dim, output_dim):
    """Build a fully connected stack with ReLU between linear layers.

    * ``num_layers == 0`` -> ``nn.Identity``
    * ``num_layers == 1`` -> a single ``nn.Linear(input_dim, output_dim)``
    * otherwise -> ``nn.Sequential`` of Linear/ReLU pairs ending in a Linear
      to ``output_dim`` (values below 2 other than 0 and 1 yield two linear layers).
    """
    if num_layers == 0:
        return nn.Identity()
    if num_layers == 1:
        return nn.Linear(input_dim, output_dim)

    # Input width followed by one hidden width per Linear+ReLU pair (at least one pair).
    widths = [input_dim] + [hidden_dim] * max(num_layers - 1, 1)
    blocks = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        blocks += [nn.Linear(fan_in, fan_out), nn.ReLU()]
    blocks.append(nn.Linear(hidden_dim, output_dim))
    return nn.Sequential(*blocks)


class CrossAttEncoder(nn.Module):
    """Cross-attention encoder: test rows attend over training rows.

    Queries come from ``inputs["te_xp"]`` and keys from ``inputs["tr_xp"]``
    (both through the shared ``mlp_qk``). Values are built from a bilinear
    interaction between the training features and the remaining per-row
    training information (one-hot targets, a constant 1, losses, predictions).

    Layer sizes are hard-coded: 360 feature columns, 55 side-information
    columns, 382 query/key input columns; the model width and depth come from
    the notebook globals ``hid_dim``, ``enc_depth`` and ``num_heads``.
    """

    def __init__(self):
        super().__init__()
        width = hid_dim
        self.bilinear = nn.Bilinear(360, 55, 392)
        self.mlp_v = fc_stack(enc_depth, 392, width, width)
        self.mlp_qk = fc_stack(enc_depth, 382, width, width)
        self.attn = MultiHeadAttention(width, width, width, width, num_heads)

    def forward(self, inputs):
        """Encode one meta-batch.

        ``inputs`` must provide ``"tr_xyp"`` (columns 0-359 features, 360-369
        targets, 370+ predictions), ``"te_xp"``, ``"tr_xp"`` and ``"tr_loss"``.
        """
        packed = inputs["tr_xyp"]
        x_tr = packed[:, :360]
        y_tr = packed[:, 360:370]
        train_pred = packed[:, 370:]

        q = self.mlp_qk(inputs["te_xp"])
        k = self.mlp_qk(inputs["tr_xp"])

        y_onehot = convert_y_ohe(y_tr)
        tr_loss = inputs["tr_loss"]
        # Constant column of ones appended to the side information fed to the bilinear layer.
        ones_col = torch.ones((y_onehot.shape[0], 1)).to(device)
        side_info = torch.cat((y_onehot, ones_col, tr_loss, train_pred), 1)
        v = self.mlp_v(self.bilinear(x_tr, side_info))

        return self.attn(q, k, v)


class MeanPool(nn.Module):
    """Parameter-free pooling layer that averages its input over dim 0."""

    def forward(self, x):
        return torch.mean(x, dim=0)


class NeuralComplexity1D(nn.Module):
    """Neural-complexity model: cross-attention encoder followed by an MLP decoder.

    Configuration is read from notebook globals (``batch_size``, ``pool``,
    ``hid_dim``, ``num_heads``, ``dec_depth``). A pooling module is created for
    ``pool`` in {"pma", "mean"} but ``forward`` does not use it; for any other
    value no ``pool`` attribute is set.
    """

    def __init__(self):
        super().__init__()
        self.bs = batch_size
        self.encoder = CrossAttEncoder()

        pool_factories = {
            "pma": lambda: PMA(dim=hid_dim, num_heads=num_heads, num_seeds=1),
            "mean": MeanPool,
        }
        make_pool = pool_factories.get(pool)
        if make_pool is not None:
            self.pool = make_pool()

        self.decoder = fc_stack(dec_depth, hid_dim, hid_dim, 1)

    def forward(self, inputs):
        """Return one scalar prediction per encoded row (pooling is skipped)."""
        encoded = self.encoder(inputs)
        return self.decoder(encoded)

## Regression Model

In [176]:
def get_learner(batch_size, layers, hidden_size, activation=None, regularizer=None, task='regression', init_dim=23, num_outputs=10, seq_len=20):
    """Build a task learner.

    Args:
        batch_size: forwarded to the learner constructor.
        layers: number of layers of the learner.
        hidden_size: hidden width of the learner.
        activation: ``"relu"``, ``"sigmoid"``, ``"tanh"`` or ``None`` (identity).
            Only used by the ``'regression'`` learner.
        regularizer: forwarded to the ``'regression'`` learner only.
        task: ``'regression'`` (``RegressionNeuralNetwork``) or
            ``'timeseries'`` (``TimeSeries``).
        init_dim: input feature size.
        num_outputs: number of outputs requested from the learner.
        seq_len: sequence length, used by the ``'timeseries'`` learner only.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: for an unknown ``activation`` or an unknown ``task``.
        NotImplementedError: for ``task == 'classification'``.
    """
    activations = {
        "relu": nn.ReLU,
        "sigmoid": nn.Sigmoid,
        "tanh": nn.Tanh,
        None: nn.Identity,
    }
    try:
        activation_cls = activations[activation]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which are also just "unknown".
        raise ValueError(f"activation={activation} not implemented!")

    if task == 'regression':
        return RegressionNeuralNetwork(
            batch_size,
            num_layers=layers,
            hidden_size=hidden_size,
            activation=activation_cls,
            regularizer=regularizer,
            init_dim=init_dim,
            num_outputs=num_outputs,
        )
    if task == 'timeseries':
        return TimeSeries(
            batch_size,
            num_layers=layers,
            hidden_size=hidden_size,
            n_features=init_dim,
            num_outputs=num_outputs,
            seq_length=seq_len
        )
    if task == 'classification':
        raise NotImplementedError("task='classification' is not implemented")

    # Fail loudly instead of returning None, which would only surface later
    # as an AttributeError at the call site.
    raise ValueError(f"task={task} not implemented!")


class RegressionNeuralNetwork(nn.Module):
    """Plain MLP stored as a flat ``nn.ModuleList``.

    Layout: ``Linear(init_dim, hidden)`` followed by ``num_layers - 1`` blocks
    of ``activation -> Linear(hidden, hidden) [-> Dropout]`` and finally
    ``activation -> Linear(hidden, num_outputs)``.

    Args:
        batch_size: accepted but not used by this class.
        num_layers: number of hidden linear layers.
        init_dim: input feature size.
        hidden_size: hidden width.
        activation: callable returning a fresh activation module (e.g. ``nn.ReLU``).
        num_outputs: output feature size.
        regularizer: ``"dropout"`` inserts ``nn.Dropout()`` after each extra hidden
            linear layer; any other value adds nothing.
    """

    def __init__(self, batch_size, num_layers, init_dim, hidden_size, activation, num_outputs, regularizer=None):
        super().__init__()
        with_dropout = regularizer == "dropout"

        stack = [nn.Linear(init_dim, hidden_size)]
        for _ in range(num_layers - 1):
            stack.append(activation())
            stack.append(nn.Linear(hidden_size, hidden_size))
            if with_dropout:
                stack.append(nn.Dropout())
        stack += [activation(), nn.Linear(hidden_size, num_outputs)]

        self.layers = nn.ModuleList(stack)
        self.activation = activation
        self.regularizer = regularizer

    def forward(self, x):
        """Feed ``x`` through every stored layer in order."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out



class TimeSeries(torch.nn.Module):
    """LSTM learner with one regression head and four 4-way classification heads.

    The LSTM output of every time step is flattened and fed to five linear
    heads; their outputs are concatenated to 6 + 4 * 4 = 22 columns.

    Args:
        batch_size: default batch size used by ``init_hidden``.
        num_layers: number of stacked LSTM layers.
        hidden_size: LSTM hidden width.
        n_features: input feature size per time step.
        num_outputs: accepted for interface compatibility; the head sizes are
            fixed (6 and 4 x 4) and do not depend on it.
        seq_length: number of time steps; fixes the input width of the heads.
    """

    def __init__(self, batch_size, num_layers=1, hidden_size=20,
                n_features=18,
                num_outputs=22,
                seq_length=20):
        super(TimeSeries, self).__init__()
        self.n_features = n_features
        self.seq_len = seq_length
        self.n_hidden = hidden_size
        self.n_layers = num_layers
        self.batch_size = batch_size

        self.lstm = torch.nn.LSTM(input_size = n_features,
                                 hidden_size = self.n_hidden,
                                 num_layers = self.n_layers,
                                 batch_first = True)

        self.linear_reg = torch.nn.Linear(self.n_hidden*self.seq_len, 6)
        self.linear_cls1 = torch.nn.Linear(self.n_hidden*self.seq_len, 4)
        self.linear_cls2 = torch.nn.Linear(self.n_hidden*self.seq_len, 4)
        self.linear_cls3 = torch.nn.Linear(self.n_hidden*self.seq_len, 4)
        self.linear_cls4 = torch.nn.Linear(self.n_hidden*self.seq_len, 4)

        # No explicit state yet: forward() then lets the LSTM use its default
        # all-zero initial state.
        self.hidden = None

    def init_hidden(self, batch_size=None):
        """Create an all-zero ``(h_0, c_0)`` state and store it in ``self.hidden``.

        Args:
            batch_size: batch dimension of the state; defaults to the
                ``batch_size`` given to the constructor.

        The state is created with the device and dtype of the module's own
        parameters instead of relying on a module-level ``device`` variable.
        """
        if batch_size is None:
            batch_size = self.batch_size
        reference = next(self.parameters())
        shape = (self.n_layers, batch_size, self.n_hidden)
        hidden_state = torch.zeros(shape, device=reference.device, dtype=reference.dtype)
        cell_state = torch.zeros(shape, device=reference.device, dtype=reference.dtype)
        self.hidden = (hidden_state, cell_state)

    def forward(self, x):
        """Run the LSTM and the five heads.

        Args:
            x: tensor of shape ``(batch, seq_length, n_features)``.

        Returns:
            Tensor of shape ``(batch, 22)``: 6 regression outputs followed by
            four blocks of 4 class logits.

        If ``init_hidden`` was never called, the LSTM's default zero state is
        used. A stored state is moved to ``x``'s device and dtype; its batch
        dimension must still match ``x`` (the LSTM raises otherwise).
        """
        batch_size, _, _ = x.size()

        hidden = getattr(self, "hidden", None)
        if hidden is not None:
            hidden = tuple(state.to(device=x.device, dtype=x.dtype) for state in hidden)

        # The final LSTM state is discarded; only the per-step outputs are used.
        lstm_out, _ = self.lstm(x, hidden)
        flat = lstm_out.contiguous().view(batch_size, -1)

        heads = (
            self.linear_reg,
            self.linear_cls1,
            self.linear_cls2,
            self.linear_cls3,
            self.linear_cls4,
        )
        return torch.cat([head(flat) for head in heads], dim=-1)
        

## Args

In [178]:
# Run configuration. These names are read as module-level globals by the cells
# below; "not read below" means no visible cell references the name.
gpu = '0'  # not read below
batch_size = 32  # row count assumed by run_regression_timeseries; also the NC batch size drawn from the memory bank
task_batch_size = 32  # not read below
lr = 0.0005  # Adam learning rate of the NC model
time_budget = 10000000000.0  # wall-clock limit (seconds, via timeit.default_timer) checked in train()
task = 'sine'  # not read below; NOTE: the module-level `for ... task in ...` loop later rebinds this name
nc_regularize = True  # add the NC prediction to the learner loss (after a warm-up, see train_steps)
epochs = 10  # iterations of the main test/train loop
train_steps = 20  # NC regularization only starts once global_step > 2 * train_steps
log_steps = 1  # log training metrics every `log_steps` global steps
test_steps = 50  # not read below
learn_freq = 5  # train() runs the task learner every `learn_freq` global steps
inner_lr = 0.0001  # SGD learning rate of the task learner
inner_steps = 2  # learner update steps per call of run_regression_timeseries
nc_weight = 1.0  # weight of the NC regularization term
learner_layers = 2  # not read below (get_learner calls hard-code layers=1)
learner_hidden = 40  # not read below (get_learner calls hard-code hidden_size=20)
learner_act = 'relu'  # not read below
input = 'cross_att'  # not read below; NOTE: shadows the builtin `input`
enc = 'fc'  # not read below
pool = 'mean'  # pooling module created by NeuralComplexity1D: 'pma' or 'mean'
dec = 'fc'  # not read below
enc_depth = 3  # depth of the encoder MLPs (fc_stack)
dec_depth = 2  # depth of the decoder MLP (fc_stack)
hid_dim = 512  # width of the NC model
num_heads = 8  # attention heads of the NC model
model_path = f"results/model.ckpt"  # only referenced by a commented-out torch.save in train()

# Use the GPU when available; tensors and models below are moved to this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Memory Bank

In [179]:
class MemoryBank:
    """
    Memory bank class. Stores snapshots of task learners.
    get_batch() returns a random minibatch of (snapshot, gap) for NC to train on.

    The five stored tensors (te_xp, tr_xp, tr_xyp, gap, l_train) grow along
    dim 0 and only exist as attributes after the first add().
    """

    def add(self, te_xp, tr_xp, tr_xyp, gap, l_train):
        """Append one snapshot; keep only the newest MEMORY_LIMIT rows."""
        incoming = {
            "te_xp": te_xp,
            "tr_xp": tr_xp,
            "tr_xyp": tr_xyp,
            "gap": gap,
            "l_train": l_train,
        }

        # First snapshot: store the tensors as they are.
        if not hasattr(self, "te_xp"):
            for name, tensor in incoming.items():
                setattr(self, name, tensor)
            return

        for name, tensor in incoming.items():
            setattr(self, name, torch.cat([getattr(self, name), tensor], dim=0))

        MEMORY_LIMIT = 1_000_000
        if self.te_xp.shape[0] > MEMORY_LIMIT:
            for name in incoming:
                setattr(self, name, getattr(self, name)[-MEMORY_LIMIT:])

    def get_batch(self, batch_size):
        """Sample ``batch_size`` distinct rows; returns ``(inputs dict, gap)`` on ``device``."""
        N = self.te_xp.shape[0]
        assert N == self.tr_xp.shape[0]
        assert N == self.tr_xyp.shape[0]
        assert N == self.gap.shape[0]

        idxs = random.sample(range(N), k=batch_size)
        # Note: the stored training losses are exposed under the key "tr_loss".
        sources = (
            ("te_xp", self.te_xp),
            ("tr_xp", self.tr_xp),
            ("tr_xyp", self.tr_xyp),
            ("tr_loss", self.l_train),
        )
        batch = {key: tensor[idxs].to(device) for key, tensor in sources}
        return (batch, self.gap[idxs].to(device))

## Run Regression and Classification

In [180]:
def run_regression_timeseries(batch, h, train=True):
    """Train a fresh time-series learner on one batch and record NC training data.

    For ``inner_steps`` SGD steps the learner predicts on the train and test
    halves of ``batch``, its per-row train losses are minimised (optionally
    with the NC model's prediction as a regularizer), and the generalization
    gap ``mean(test loss) - mean(train loss)`` per row is computed from the
    predictions made *before* that step's update.

    Args:
        batch: dict with ``"train"`` and ``"test"`` entries, each ``[X, y]``.
        h: NOTE(review): this argument is ignored — it is overwritten by a
            newly constructed learner a few lines below. Confirm whether the
            passed-in learner was meant to be reused.
        train: when true, every step's snapshot is added to ``memory_bank``.

    Returns:
        ``(h, meta_batch)``: the learner after the updates and the NC input
        dict built in the last inner step.
        NOTE(review): with ``inner_steps == 0`` ``meta_batch`` is never
        assigned and the return raises ``UnboundLocalError``.

    Reads the globals ``device``, ``inner_lr``, ``inner_steps``, ``batch_size``,
    ``nc_regularize``, ``global_step``, ``train_steps``, ``nc_weight``,
    ``model`` and ``memory_bank``.
    """
    x_train, y_train = batch["train"][0].to(device), batch["train"][1].to(device)
    x_test, y_test = batch["test"][0].to(device), batch["test"][1].to(device)

    # A brand-new learner replaces whatever was passed in as `h`.
    h = get_learner(
        batch_size=32,
        layers=1,
        hidden_size=20,
        init_dim=18,
        num_outputs=22,
        task='timeseries'
        ).to(device)
    h.init_hidden()
    h_opt = torch.optim.SGD(h.parameters(), lr= inner_lr)
    # reduction="none" keeps one loss value per element so per-row losses can be stored.
    h_crit_reg = nn.MSELoss(reduction="none")
    h_crit_cls = nn.CrossEntropyLoss(reduction="none")


    for _ in range( inner_steps):
        preds_train = h(x_train)
        preds_test = h(x_test)

        # NC inputs: flattened features concatenated with predictions (and targets for tr_xyp).
        # NOTE(review): the reshape uses the global `batch_size`, not the batch's own
        # first dimension — batches of a different size will fail or be mis-shaped.
        te_xp = torch.cat([x_test.contiguous().view(batch_size, -1), preds_test], dim=-1)
        tr_xp = torch.cat([x_train.contiguous().view(batch_size, -1), preds_train], dim=-1)
        tr_xyp = torch.cat([x_train.contiguous().view(batch_size, -1), y_train, preds_train], dim=-1)

        # Prediction columns 0-5 are regression outputs scored against target columns 0-5;
        # columns 6-9, 10-13, 14-17 and 18-21 are class logits scored against the class ids
        # in target columns 6, 7, 8 and 9.
        reg_loss_tr = h_crit_reg(preds_train[:, :6].squeeze(), y_train[:, :6].squeeze())
        cls1_loss_tr = h_crit_cls(preds_train[:, 6:10].squeeze(), y_train[:, 6].squeeze().long())
        cls2_loss_tr = h_crit_cls(preds_train[:, 10:14].squeeze(), y_train[:, 7].squeeze().long())
        cls3_loss_tr = h_crit_cls(preds_train[:, 14:18].squeeze(), y_train[:, 8].squeeze().long())
        cls4_loss_tr = h_crit_cls(preds_train[:, 18:22].squeeze(), y_train[:, 9].squeeze().long())
        # Per-row loss matrix: 6 regression columns followed by 4 classification columns.
        l_train = torch.cat((reg_loss_tr, cls1_loss_tr.unsqueeze(1),cls2_loss_tr.unsqueeze(1), cls3_loss_tr.unsqueeze(1), cls4_loss_tr.unsqueeze(1)), dim=-1)
        # NOTE(review): the classification losses are 1-D, so `.mean(-1)` already reduces
        # them to a scalar (batch mean) and `.sum()` is a no-op, whereas the regression
        # term is summed over rows. The divisor 160 is presumably 32 * 5 — confirm.
        h_loss = (reg_loss_tr.mean(-1).sum() + cls1_loss_tr.mean(-1).sum() + cls2_loss_tr.mean(-1).sum() + cls3_loss_tr.mean(-1).sum() + cls4_loss_tr.mean(-1).sum())/160
        meta_batch = {"te_xp": te_xp, "tr_xp": tr_xp, "tr_xyp": tr_xyp, "tr_loss": l_train}
        # The NC regularizer is only switched on after a warm-up of 2 * train_steps global steps.
        if  nc_regularize and global_step >  train_steps * 2:
            model_preds = model(meta_batch)
            # We sum NC outputs across tasks because h_loss is also summed.
            nc_regularization = model_preds.sum()
            h_loss += nc_regularization *  nc_weight

        h_opt.zero_grad()
        h_loss.backward()
        h_opt.step()

        # Test losses use preds_test computed before this step's parameter update.
        reg_loss_te = h_crit_reg(preds_test[:, :6].squeeze(), y_test[:, :6].squeeze())
        cls1_loss_te = h_crit_cls(preds_test[:, 6:10].squeeze(), y_test[:, 6].squeeze().long())
        cls2_loss_te = h_crit_cls(preds_test[:, 10:14].squeeze(), y_test[:, 7].squeeze().long())
        cls3_loss_te = h_crit_cls(preds_test[:, 14:18].squeeze(), y_test[:, 8].squeeze().long())
        cls4_loss_te = h_crit_cls(preds_test[:, 18:22].squeeze(), y_test[:, 9].squeeze().long())
        l_test = torch.cat((reg_loss_te, cls1_loss_te.unsqueeze(1), cls2_loss_te.unsqueeze(1), cls3_loss_te.unsqueeze(1), cls4_loss_te.unsqueeze(1)), dim=-1)
        # Same expression as the l_train built above (recomputed from the same tensors).
        l_train = torch.cat((reg_loss_tr, cls1_loss_tr.unsqueeze(1),cls2_loss_tr.unsqueeze(1), cls3_loss_tr.unsqueeze(1), cls4_loss_tr.unsqueeze(1)), dim=-1)
        # Generalization gap per row: mean test loss minus mean train loss over the 10 loss columns.
        gap = l_test.mean(-1) - l_train.mean(-1)

        if train:
            # Stored on CPU and detached so the bank holds no autograd graph.
            memory_bank.add(
                te_xp=te_xp.cpu().detach(),
                tr_xp=tr_xp.cpu().detach(),
                tr_xyp=tr_xyp.cpu().detach(),
                gap=gap.cpu().detach(),
                l_train=l_train.cpu().detach()
            )
    return h, meta_batch

## Defining the model

In [181]:
# NC model and its optimizer.
model = NeuralComplexity1D().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr= lr)
# Element-wise criteria (reduction="none"); callers reduce them explicitly.
mse_criterion = nn.MSELoss(reduction="none")
ce_criterion = nn.CrossEntropyLoss(reduction="none")
huber_criterion = nn.HuberLoss(reduction='none')
# L1Loss keeps its default reduction, so it already returns a scalar.
mae_criterion = nn.L1Loss()
# Start time used for the time_budget check in train().
global_timestamp = timer()
# Number of training batches seen so far; incremented in train().
global_step = 0
# Running training metrics; never cleared in the visible cells, so means are cumulative.
accum = Accumulator()

## Logging

In [182]:
def log_metrics(type_="train", metrics={}):
    """Append every value of ``metrics`` to the matching list in ``tracker[type_]``.

    Lists for metric names seen for the first time are created on demand.
    """
    history = tracker[type_]
    for name, value in metrics.items():
        history.setdefault(name, []).append(value)

## Testing

In [183]:
def test(epoch, test_tasks):
    """Evaluate the NC model on ``test_tasks`` and log the aggregated metrics.

    For every batch a learner is trained via ``run_regression_timeseries``
    (without writing to the memory bank), the learner's train/test losses and
    their gap are recomputed, and the NC model's prediction for the batch is
    compared to that gap (Huber loss, MAE, Pearson correlation).

    Args:
        epoch: epoch number, used only in the log message.
        test_tasks: iterable of tasks, each an iterable of batch dicts with
            ``"train"`` and ``"test"`` entries.

    Returns:
        None; the metrics are logged and appended to ``tracker["test"]``.
    """
    test_accum = Accumulator()
    for task in test_tasks:
        # This learner is replaced by the one returned from run_regression_timeseries
        # below, which builds its own learner on every call.
        h = get_learner(
        batch_size=32,
        layers=1,
        hidden_size=20,
        init_dim=18,
        num_outputs=22,
        task='timeseries'
        ).to(device)
        h.init_hidden()
        for batch in task:
            h, meta_batch = run_regression_timeseries(batch, h, train=False)

            x_train, y_train = batch["train"][0].to(device), batch["train"][1].to(device)
            x_test, y_test = batch["test"][0].to(device), batch["test"][1].to(device)
            with torch.no_grad():
                # Predictions of the learner after its inner updates.
                preds_train = h(x_train)
                preds_test = h(x_test)

                # Same column layout as in run_regression_timeseries: 6 regression
                # outputs, then four blocks of 4 class logits.
                reg_loss_tr = mse_criterion(preds_train[:, :6].squeeze(), y_train[:, :6].squeeze())
                cls1_loss_tr = ce_criterion(preds_train[:, 6:10].squeeze(), y_train[:, 6].squeeze().long())
                cls2_loss_tr = ce_criterion(preds_train[:, 10:14].squeeze(), y_train[:, 7].squeeze().long())
                cls3_loss_tr = ce_criterion(preds_train[:, 14:18].squeeze(), y_train[:, 8].squeeze().long())
                cls4_loss_tr = ce_criterion(preds_train[:, 18:22].squeeze(), y_train[:, 9].squeeze().long())
                reg_loss_te = mse_criterion(preds_test[:, :6].squeeze(), y_test[:, :6].squeeze())
                cls1_loss_te = ce_criterion(preds_test[:, 6:10].squeeze(), y_test[:, 6].squeeze().long())
                cls2_loss_te = ce_criterion(preds_test[:, 10:14].squeeze(), y_test[:, 7].squeeze().long())
                cls3_loss_te = ce_criterion(preds_test[:, 14:18].squeeze(), y_test[:, 8].squeeze().long())
                cls4_loss_te = ce_criterion(preds_test[:, 18:22].squeeze(), y_test[:, 9].squeeze().long())
                l_test = torch.cat((reg_loss_te, cls1_loss_te.unsqueeze(1), cls2_loss_te.unsqueeze(1), cls3_loss_te.unsqueeze(1), cls4_loss_te.unsqueeze(1)), dim=-1)
                l_train = torch.cat((reg_loss_tr, cls1_loss_tr.unsqueeze(1),cls2_loss_tr.unsqueeze(1), cls3_loss_tr.unsqueeze(1), cls4_loss_tr.unsqueeze(1)), dim=-1)
                gap = l_test.mean(-1) - l_train.mean(-1)

                # NOTE(review): meta_batch was built inside run_regression_timeseries from
                # predictions made before the last learner update, while `gap` above uses
                # the post-update learner — confirm this mismatch is intended.
                model_preds = model(meta_batch)
                loss = huber_criterion(model_preds.squeeze(), gap.squeeze()).mean()
                # mae_criterion already returns a scalar, so .mean() leaves it unchanged.
                mae = mae_criterion(model_preds.squeeze(), gap.squeeze()).mean()

            test_accum.add_dict(
                {
                    "l_test": [l_test.mean(-1).detach().cpu()],
                    "l_train": [l_train.mean(-1).detach().cpu()],
                    "mae": [mae.item()],
                    "loss": [loss.item()],
                    "gap": [gap.squeeze().detach().cpu()],
                    "pred": [model_preds.squeeze().detach().cpu()],
                }
            )

    # Pearson correlation between true gaps and NC predictions over all batches.
    all_gaps = torch.cat(test_accum["gap"])
    all_preds = torch.cat(test_accum["pred"])
    R = np.corrcoef(all_gaps, all_preds)[0, 1]
    mean_l_test = torch.cat(test_accum["l_test"]).mean()
    mean_l_train = torch.cat(test_accum["l_train"]).mean()


    logger.info(f"Test epoch {epoch}")
    logger.info(
        f"mae {test_accum.mean('mae'):.2e} loss {test_accum.mean('loss'):.2e} R {R:.3f} "
        f"l_test {mean_l_test:.2e} l_train {mean_l_train:.2e} "
    )

    metrics = {
        "mae": test_accum.mean("mae"),
        "loss": test_accum.mean("loss"),
        "R": R,
        "l_test": mean_l_test.item(),
        "l_train": mean_l_train.item(),
    }
    log_metrics("test", metrics)

## Training

In [184]:
def train(train_loader):
    """Run one training epoch of the NC model.

    For every batch of every task: occasionally run the task learner (which
    adds new snapshots to ``memory_bank``), then draw a random batch from the
    memory bank and take one Adam step on the Huber loss between the NC
    prediction and the stored gap. Metrics are accumulated in the global
    ``accum`` and logged every ``log_steps`` steps.

    Args:
        train_loader: iterable of tasks, each an iterable of batch dicts.

    Returns:
        None. Calls ``quit()`` once ``time_budget`` seconds have elapsed since
        ``global_timestamp``.
    """
    # This is the inner loop (basically this is the train_epoch function)
    global global_step
    for task in train_loader:
        # Passed to run_regression_timeseries, which builds its own learner on every call.
        h = get_learner(
        batch_size=32,
        layers=1,
        hidden_size=20,
        init_dim=18,
        num_outputs=22,
        task='timeseries'
        ).to(device)
        h.init_hidden()
        for batch in task:
            global_step += 1
            if global_step %  learn_freq == 0: # run the task learner every `learn_freq` global steps
                run_regression_timeseries(batch, h)

            # One NC update on a random batch of stored snapshots.
            meta_batch, gap = memory_bank.get_batch( batch_size)
            model_preds = model(meta_batch)
            loss = huber_criterion(model_preds.squeeze(), gap.squeeze()).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            mae = mae_criterion(model_preds.squeeze(), gap.squeeze())
            accum.add_dict(
                {
                    "mae": [mae.item()],
                    "loss": [loss.item()],
                    "gap": [gap.squeeze().detach().cpu()],
                    "pred": [model_preds.squeeze().detach().cpu()],
                }
            )

            if global_step % log_steps == 0:
                # torch.save(model.state_dict(), model_path)

                # `accum` is not cleared here, so these statistics cover every
                # step accumulated so far, not just the current logging window.
                all_gaps = torch.cat(accum["gap"])
                all_preds = torch.cat(accum["pred"])
                R = np.corrcoef(all_gaps, all_preds)[0, 1]
                logger.info(f"Train Step {global_step}")
                logger.info(
                    f"mae {accum.mean('mae'):.2e} loss {accum.mean('loss'):.2e} R {R:.3f} "
                )

                metrics = {
                    "mae": accum.mean("mae"),
                    "loss": accum.mean("loss"),
                    "R": R,
                }
                log_metrics("train", metrics)

            # Hard stop once the wall-clock budget is exhausted.
            if timer() - global_timestamp >  time_budget:
                logger.info(f"Stopping at step {global_step}")
                quit()

## Data Population

In [185]:
# Fresh memory bank, then fill it by running the task learner on a fixed
# number of tasks before NC training starts.
memory_bank = MemoryBank()
populate_timestamp = timer()

# Number of tasks drawn via sample_task() (in file order, round-robin).
task_count = 20
populate_loader = []
# task_loader[i] is the list of batch dicts of the i-th sampled task.
task_loader = []


for tasks in range(task_count):
    populate_loader = []
    X_train, y_train, X_test, y_test = sample_task()

    # Pair up the i-th train and test entries of the task; zip() stops at the
    # shortest of the four lists.
    for batch in zip(X_train, y_train, X_test, y_test):
        X_tr, y_tr = batch[0].float(), batch[1].float()
        X_te, y_te = batch[2].float(), batch[3].float()
        # Keep only pairs whose train and test entries have the same number of rows.
        # NOTE(review): run_regression_timeseries additionally reshapes with the
        # global batch_size, so rows != batch_size would still fail — confirm.
        if X_tr.shape[0] == X_te.shape[0]:
            d = {"train": [X_tr, y_tr],
                    "test": [X_te, y_te]}
            populate_loader.append(d)
    task_loader.append(populate_loader)

# Run the learner on every batch; with the default train=True each inner step
# adds a snapshot to memory_bank.
for i, task in enumerate(task_loader):
    # Passed to run_regression_timeseries, which builds its own learner on every call.
    h = get_learner(
        batch_size=32,
        layers=1,
        hidden_size=20,
        init_dim=18,
        num_outputs=22,
        task='timeseries'
    ).to(device)
    h.init_hidden()
    for j, batch in enumerate(task):
        run_regression_timeseries(batch, h)

# logger.info(f"Populate time: {timer() - populate_timestamp}")

## Main Loop

In [None]:
# Reset the metrics store so re-running this cell starts from empty histories.
tracker = {"train": {}, "test":{}}

# Each epoch evaluates first, then trains; both use the same task_loader.
for epoch in range(epochs):
    logger.info(f"Epoch {epoch}")
    logger.info(f"Bank size: {memory_bank.te_xp.shape[0]}")

    test_timestamp = timer()
    # test() and train() have no return statement, so `out` is always None.
    out = test(epoch, task_loader)
    test_elapsed = timer() - test_timestamp

    train_timestamp = timer()
    out = train(task_loader)
    train_elapsed = timer() - train_timestamp
    logger.info(f"Time: train {train_elapsed:.1f} test {test_elapsed:.1f}")

    # Rewrite the full metrics history after every epoch.
    with open("logs.json", "w") as f:
        json.dump(tracker, f)