In [None]:
import os
import random
import uuid
from collections import defaultdict
from timeit import default_timer as timer

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from torch.distributions.normal import Normal
from torch.utils.data import Dataset
from sklearn.preprocessing import StandardScaler
# from torch.utils.tensorboard import SummaryWriter

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import os, pickle
import copy
import json

In [None]:
# Metric histories written by log_metrics(): tracker[split][metric_name] -> list of values.
tracker = dict(train={}, test={})

In [None]:
class TsDS(Dataset):
    """Dataset of pre-batched tasks: each stored item is one whole batch.

    Every ``(X, Y)`` pair from ``XL``/``yL`` becomes one float sample tensor and
    one label tensor. ``__getitem__`` returns a whole batch per index, so
    ``__len__`` counts stored batches.

    Args:
        XL: iterable of array-likes, one per batch.
        yL: iterable of label array-likes, aligned with ``XL``.
        flatten: if True, ``__getitem__`` flattens each sample from dim 1 onwards.
        lno: optional label column; when set, only ``labels[idx][:, lno]`` is returned.
        long: cast the selected label column to ``long`` (otherwise ``float``).
    """

    def __init__(self, XL, yL, flatten=False, lno=None, long=True):
        self.samples = []
        self.labels = []
        self.flatten = flatten
        self.lno = lno
        self.long = long
        self.scaler = StandardScaler()
        for X, Y in zip(XL, yL):
            self.samples.append(torch.tensor(X).float())
            self.labels.append(torch.tensor(Y))

    def __len__(self):
        # One item per stored batch, matching how __getitem__ indexes self.samples.
        # (This used to sum rows over all batches, so a DataLoader would request
        # indices past the last batch and hit IndexError.)
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        if self.flatten:
            sample = sample.flatten(start_dim=1)
        label = self.labels[idx]
        if self.lno is not None:
            column = label[:, self.lno]
            label = column.long() if self.long else column.float()
        return (sample, label)

    def fit(self, kind='seq'):
        """Fit ``self.scaler`` on the stored samples.

        kind='seq': fit on the last time step (``s[:, -1, :]``) of every batch;
            these rows are also cached in ``self.lastelems``. Needs 3-D samples.
        kind='flat': fit on all (2-D) samples concatenated along dim 0.
        Any other ``kind`` does nothing.
        """
        if kind == 'seq':
            self.lastelems = [torch.cat([s[:, -1, :] for s in self.samples], dim=0)]
            self.scaler.fit(torch.cat(self.lastelems, dim=0))
        elif kind == 'flat':
            self.scaler.fit(torch.cat(self.samples, dim=0))

    @staticmethod
    def _flat_shape(s):
        # (batch, seq, feat) -> (batch * seq, feat), so a 2-D scaler can be applied.
        return (s.shape[0] * s.shape[1], s.shape[2])

    def _apply_scaler(self, fn, kind):
        # Replace every stored sample with fn(sample) ('flat') or a reshaped fn ('seq').
        if kind == 'seq':
            self.samples = [
                torch.tensor(fn(s.reshape(self._flat_shape(s))).reshape(s.shape)).float()
                for s in self.samples
            ]
        elif kind == 'flat':
            self.samples = [torch.tensor(fn(s)).float() for s in self.samples]

    def scale(self, kind='flat', scaler=None):
        """Transform samples in place with ``scaler`` (default ``self.scaler``)."""
        if scaler is None:
            scaler = self.scaler
        self._apply_scaler(scaler.transform, kind)

    def unscale(self, kind='flat', scaler=None):
        """Undo ``scale`` with ``scaler.inverse_transform`` (default ``self.scaler``)."""
        if scaler is None:
            scaler = self.scaler
        self._apply_scaler(scaler.inverse_transform, kind)

In [None]:
class Accumulator:
    """Collects flat lists of metric values keyed by metric name.

    ``add`` extends the list for a key with the items of an iterable (callers
    pass e.g. ``[loss.item()]``); unseen keys start as empty lists.
    """

    def __init__(self):
        self.clear()

    def clear(self):
        """Forget every collected value."""
        self.metrics = defaultdict(list)

    def add(self, key, value):
        """Extend the list stored under ``key`` with the items of ``value``."""
        self.metrics[key].extend(value)

    def add_dict(self, dict):
        """``add`` every (key, iterable) pair of a mapping."""
        for name, values in dict.items():
            self.add(name, values)

    def mean(self, key):
        """numpy mean of the values stored under ``key``."""
        return np.mean(self.metrics[key])

    def __getitem__(self, item):
        return self.metrics[item]

    def __setitem__(self, key, value):
        self.metrics[key] = value

    def get_dict(self):
        """Deep copy of the metrics as a plain ``dict``."""
        snapshot = dict(self.metrics)
        return copy.deepcopy(snapshot)

    def items(self):
        return self.metrics.items()

    def __str__(self):
        return str(dict(self.metrics))

In [None]:
import torch
from torchmeta.toy import Sinusoid
from torchmeta.transforms import ClassSplitter
from torchmeta.utils.data import BatchMetaDataLoader


class ToTensor1D(object):
    """Transform a numpy array of any rank into a ``float32`` torch tensor.

    Unlike torchvision's ``ToTensor``, no assumption is made about the number
    of dimensions; the array is cast with ``astype("float32")`` first.
    """

    def __call__(self, array):
        as_float32 = array.astype("float32")
        return torch.tensor(as_float32)

    def __repr__(self):
        return "%s()" % type(self).__name__


def get_sine_loader(batch_size, num_steps, shots=10, test_shots=15):
    """Build a torchmeta ``BatchMetaDataLoader`` over sinusoid regression tasks.

    ``batch_size * num_steps`` tasks are generated. Each task has
    ``shots + test_shots`` points, which ``ClassSplitter`` (shuffled) splits into
    ``shots`` train and ``test_shots`` test points. Inputs and targets go
    through ``ToTensor1D``.
    """
    splitter = ClassSplitter(
        shuffle=True, num_train_per_class=shots, num_test_per_class=test_shots
    )
    to_tensor = ToTensor1D()
    sinusoid_tasks = Sinusoid(
        shots + test_shots,
        num_tasks=batch_size * num_steps,
        transform=to_tensor,
        target_transform=to_tensor,
        dataset_transform=splitter,
    )
    return BatchMetaDataLoader(
        sinusoid_tasks, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True,
    )

In [None]:
def get_numbers(name):
    """Parse fields 2 and 3 of an underscore-separated file name.

    Returns ``(g, d)``: field 2 as a string and field 3 converted to ``int``.
    Raises IndexError for names with fewer than four fields, and ValueError when
    field 3 is not an integer.
    """
    fields = name.split('_')
    return fields[2], int(fields[3])

folder_path = os.path.join("marketdata")
l = os.listdir(folder_path)

# Only files whose name contains `data_type` are used. Files go into a "train" or
# "test" list by name; tasks with d < 20 are for meta-training, the rest are
# held out for meta-testing.
data_type = "ds"
meta_train = {"train": [], "test": []}
meta_test = {"train": [], "test": []}

for file in l:
    if data_type not in file:
        continue
    type_ = "train" if "train" in file else "test"
    g, d = get_numbers(file)
    bucket = meta_train if d < 20 else meta_test
    bucket[type_].append(file)

# Pair train/test files by sorted position, then order the pairs by d.
meta_train["train"] = sorted(meta_train["train"])
meta_train["test"] = sorted(meta_train["test"])
data = sorted(zip(meta_train["train"], meta_train["test"]), key=lambda pair: get_numbers(pair[0])[1])
idx = 0  # round-robin cursor used by sample_task()

def load_task(task, folder=None):
    """Load one (train, test) pair of pickled datasets.

    task is a tuple of file names of the form (train_cs_g_d_2.pkl, test_cs_g_d_2.pkl).
    The names are resolved relative to ``folder``, which defaults to the
    module-level ``folder_path``. Each pickle must hold an object with
    ``samples`` and ``labels`` attributes.

    Returns X_train, y_train, X_test, y_test.

    NOTE: pickle.load can execute arbitrary code; only load trusted files.
    """
    if folder is None:
        folder = folder_path
    train_file, test_file = task

    def _load(file_name):
        # `with` closes the handle; the previous version leaked both file objects.
        with open(os.path.join(folder, file_name), "rb") as fh:
            return pickle.load(fh)

    train_data = _load(train_file)
    test_data = _load(test_file)
    return train_data.samples, train_data.labels, test_data.samples, test_data.labels

def sample_task():
    """Load the next task from ``data`` in round-robin order.

    Uses the module-level cursor ``idx``. It wraps to 0 once it has reached
    ``len(data)``; then ``data[idx]`` is loaded and the cursor advances by one.
    Returns whatever ``load_task`` returns: (X_train, y_train, X_test, y_test).
    """
    global idx
    if not idx < len(data):
        idx = 0
    current_task = data[idx]
    idx = idx + 1
    return load_task(current_task)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention with a residual connection and a ReLU output layer.

    Q, K and V are linearly projected (no bias) to ``dim_output``. Each is split
    into ``num_heads`` chunks along the last dim and the chunks are stacked
    along dim 0. Attention is softmax(Q K^T / sqrt(dim_output)) V, and the heads
    are re-joined along the last dim. The output is ``q_proj + attn``, followed
    by ``out + relu(fc_o(out))``.

    NOTE(review): for 3-D inputs (batch, n, dim), each head attends within its
    own batch element. For 2-D inputs (rows, dim), the heads are stacked along
    the row axis, so rows of different heads attend to each other.
    """

    def __init__(self, dim_query, dim_key, dim_value, dim_output, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_query, dim_output, bias=False)
        self.fc_k = nn.Linear(dim_key, dim_output, bias=False)
        self.fc_v = nn.Linear(dim_value, dim_output, bias=False)
        self.fc_o = nn.Linear(dim_output, dim_output)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query, key, value: tensors whose last dim is dim_query / dim_key / dim_value.
            mask: optional bool tensor, (batch, n_k) or (batch, n_k, 1). True
                marks key positions to exclude. A query whose keys are all
                masked gets NaN attention weights.
        Returns:
            Tensor with the query's leading shape and last dim ``dim_output``.
        """
        query = self.fc_q(query)
        key = self.fc_k(key)
        value = self.fc_v(value)

        # Stack heads along dim 0 (for 3-D input: (num_heads * batch, n, dim_output / num_heads)).
        query_ = torch.cat(query.chunk(self.num_heads, -1), 0)
        key_ = torch.cat(key.chunk(self.num_heads, -1), 0)
        value_ = torch.cat(value.chunk(self.num_heads, -1), 0)

        A_logits = (query_ @ key_.transpose(-2, -1)) / math.sqrt(query.shape[-1])
        if mask is not None:
            # Broadcast the key mask over query positions, then repeat it per head.
            mask = torch.stack([mask.squeeze(-1)] * query.shape[-2], -2)
            mask = torch.cat([mask] * self.num_heads, 0)
            # masked_fill is out-of-place: the result must be kept. Previously it
            # was discarded, so the mask had no effect.
            A_logits = A_logits.masked_fill(mask, -float("inf"))
        A = torch.softmax(A_logits, -1)

        outs = torch.cat((A @ value_).chunk(self.num_heads, 0), -1)
        outs = query + outs
        outs = outs + F.relu(self.fc_o(outs))
        return outs


class PMA(nn.Module):
    """Pooling by multi-head attention: ``num_seeds`` learned seed vectors attend over X.

    ``S`` has shape (1, num_seeds, dim) and is Xavier-initialised. ``forward``
    repeats it over dim 0 of X and squeezes every size-1 dim of the result.
    """

    def __init__(self, dim, num_heads, num_seeds):
        super().__init__()
        seeds = torch.Tensor(1, num_seeds, dim)
        self.S = nn.Parameter(seeds)
        nn.init.xavier_uniform_(self.S)
        self.mha = MultiHeadAttention(dim, dim, dim, dim, num_heads)

    def forward(self, X):
        # NOTE(review): squeeze() drops *every* size-1 dim, including batch when X.size(0) == 1.
        queries = self.S.repeat(X.shape[0], 1, 1)
        pooled = self.mha(queries, X, X)
        return pooled.squeeze()

## NC Model

In [None]:
def fc_stack(num_layers, input_dim, hidden_dim, output_dim):
    """Build an MLP with ReLU between linear layers.

    0 layers -> ``nn.Identity``; 1 -> ``nn.Linear(input_dim, output_dim)``.
    Otherwise: Linear(input, hidden), ReLU, then (num_layers - 2) x
    [Linear(hidden, hidden), ReLU], then Linear(hidden, output), in an
    ``nn.Sequential``. Negative counts behave like 2.
    """
    if num_layers == 0:
        return nn.Identity()
    if num_layers == 1:
        return nn.Linear(input_dim, output_dim)
    widths = [input_dim] + [hidden_dim] * max(num_layers - 1, 1)
    modules = []
    for d_in, d_out in zip(widths[:-1], widths[1:]):
        modules += [nn.Linear(d_in, d_out), nn.ReLU()]
    modules.append(nn.Linear(hidden_dim, output_dim))
    return nn.Sequential(*modules)


class CrossAttEncoder(nn.Module):
    """Cross-attention encoder of the NC model.

    Queries come from ``inputs["te_xp"]`` and keys from ``inputs["tr_xp"]``, both
    projected by the shared MLP ``mlp_qk``. Values come from
    ``inputs["tr_xyp"]``, projected by ``mlp_v``. The projections feed one
    ``MultiHeadAttention``.

    Args:
        in_dim_v: width of ``tr_xyp`` rows. The default 380 matches
            run_regression_timeseries with its defaults: 18 features * 20 steps
            flattened, + 10 labels + 10 predictions.
        in_dim_qk: width of ``te_xp`` / ``tr_xp`` rows (default 370 = 360 + 10).
        dim, depth, heads: hidden width, MLP depth and attention heads. They
            default to the module-level ``hid_dim``, ``enc_depth`` and ``num_heads``.

    NOTE(review): with 2-D (rows, width) inputs, MultiHeadAttention stacks heads
    along the row axis, so attention mixes rows across heads. Confirm this is intended.
    """

    def __init__(self, in_dim_v=380, in_dim_qk=370, dim=None, depth=None, heads=None):
        super().__init__()
        dim = hid_dim if dim is None else dim
        depth = enc_depth if depth is None else depth
        heads = num_heads if heads is None else heads

        self.mlp_v = fc_stack(depth, in_dim_v, dim, dim)
        self.mlp_qk = fc_stack(depth, in_dim_qk, dim, dim)
        self.attn = MultiHeadAttention(dim, dim, dim, dim, heads)

    def forward(self, inputs):
        q = self.mlp_qk(inputs["te_xp"])
        k = self.mlp_qk(inputs["tr_xp"])
        v = self.mlp_v(inputs["tr_xyp"])
        return self.attn(q, k, v)


class MeanPool(nn.Module):
    """Parameter-free pooling: mean over the first axis (dim 0) of the input."""

    def forward(self, x):
        return torch.mean(x, dim=0)


class NeuralComplexity1D(nn.Module):
    """NC model: cross-attention encoder followed by an MLP decoder with one output.

    Reads the module-level ``batch_size``, ``pool``, ``hid_dim``, ``num_heads``
    and ``dec_depth`` at construction time.
    """

    def __init__(self):
        super().__init__()
        self.bs = batch_size
        self.encoder = CrossAttEncoder()

        # NOTE(review): `self.pool` is built but forward() does not use it
        # (the pooling call is disabled).
        if pool == "pma":
            self.pool = PMA(dim=hid_dim, num_heads=num_heads, num_seeds=1)
        elif pool == "mean":
            self.pool = MeanPool()

        self.decoder = fc_stack(dec_depth, hid_dim, hid_dim, 1)

    def forward(self, inputs):
        encoded = self.encoder(inputs)
        return self.decoder(encoded)

## Regression Model

In [None]:
def get_learner(batch_size, layers, hidden_size, activation=None, regularizer=None, task='regression', init_dim=23, num_outputs=10, seq_len=20):
    """Build a task learner network.

    Args:
        batch_size: forwarded to the learner's constructor.
        layers, hidden_size: depth and width of the learner.
        activation: "relu", "sigmoid", "tanh" or None (identity).
        regularizer: forwarded to RegressionNeuralNetwork.
        task: "regression" (RegressionNeuralNetwork), "timeseries" (TimeSeries LSTM)
            or "classification" (not implemented).
        init_dim: input feature width. num_outputs: output width.
        seq_len: sequence length (timeseries only).

    Raises:
        ValueError: unknown ``activation`` or ``task``.
        NotImplementedError: ``task == 'classification'``.
    """
    activations = {
        "relu": nn.ReLU,
        "sigmoid": nn.Sigmoid,
        "tanh": nn.Tanh,
        None: nn.Identity,
    }
    if activation not in activations:
        raise ValueError(f"activation={activation} not implemented!")
    activation = activations[activation]

    if task == 'regression':
        return RegressionNeuralNetwork(
            batch_size,
            num_layers=layers,
            hidden_size=hidden_size,
            activation=activation,
            regularizer=regularizer,
            init_dim=init_dim,
            num_outputs=num_outputs,
        )
    if task == 'timeseries':
        # activation / regularizer are not used by the LSTM learner.
        return TimeSeries(
            batch_size,
            num_layers=layers,
            hidden_size=hidden_size,
            n_features=init_dim,
            num_outputs=num_outputs,
            seq_length=seq_len
        )
    if task == 'classification':
        # Not wired up yet. The ParallelNeuralNetwork(...) construction that used to
        # follow the raise was unreachable and has been dropped.
        raise NotImplementedError("task='classification' is not implemented yet")
    # Previously fell through and returned None, which failed later at `.to(device)`.
    raise ValueError(f"task={task} not implemented!")



class ParallelLinear(nn.Module):
    """``bs`` independent linear layers evaluated with a single einsum.

    weight: (bs, output_size, input_size); bias: (bs, 1, output_size). Both are
    initialised from ``bs`` freshly constructed ``nn.Linear`` layers.
    forward: x (bs, n, input_size) -> (bs, n, output_size).
    """

    def __init__(self, bs, input_size, output_size):
        super().__init__()
        templates = []
        for _ in range(bs):
            templates.append(nn.Linear(input_size, output_size))
        self.weight = Parameter(torch.stack([layer.weight for layer in templates]))
        self.bias = Parameter(torch.stack([layer.bias for layer in templates]).unsqueeze(1))

    def forward(self, x):
        projected = torch.einsum("bnd,bmd->bnm", x, self.weight)
        return projected + self.bias


class RegressionNeuralNetwork(nn.Module):
    """Plain MLP regressor.

    Layers: Linear(init_dim, hidden), then (num_layers - 1) x [act, Linear(hidden,
    hidden), optional Dropout when regularizer == "dropout"], then act,
    Linear(hidden, num_outputs). ``batch_size`` is accepted but unused.
    """

    def __init__(self, batch_size, num_layers, init_dim, hidden_size, activation, num_outputs, regularizer=None):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(init_dim, hidden_size)])
        for _ in range(num_layers - 1):
            self.layers.extend([activation(), nn.Linear(hidden_size, hidden_size)])
            if regularizer == "dropout":
                self.layers.append(nn.Dropout())

        self.layers.extend([activation(), nn.Linear(hidden_size, num_outputs)])
        self.activation = activation
        self.regularizer = regularizer

    def forward(self, x):
        out = x
        for layer in self.layers:
            out = layer(out)
        return out


class ParallelNeuralNetwork(nn.Module):
    """ Equivalent to running  batch_size neural networks in parallel. No weight sharing.

    Built from ``ParallelLinear`` layers with input width 1 and output width 1.
    The input is therefore expected as (bs, n, 1); a (1, n, 1) input is
    repeated ``bs`` times. ``get_measure`` / ``get_measures`` compute
    norm-based complexity measures of the stacked weights.
    """

    def __init__(self, bs, num_layers, hidden_size, activation, regularizer):
        super().__init__()
        self.bs = bs
        modules = [ParallelLinear(bs, 1, hidden_size)]
        for _ in range(num_layers - 1):
            modules.append(activation())
            modules.append(ParallelLinear(bs, hidden_size, hidden_size))
            if regularizer == "dropout":
                modules.append(nn.Dropout())
            if regularizer == "g_dropout":
                modules.append(GaussianDropout(alpha=1.0))
            if regularizer == "v_dropout":
                modules.append(VariationalDropout(alpha=1.0, dim=hidden_size))
            if regularizer == "alpha_dropout":
                modules.append(nn.AlphaDropout(p=0.5))
            if regularizer == "batchnorm":
                # Parallel batchnorm is equivalent to layernorm in this case
                modules.append(nn.LayerNorm(hidden_size, elementwise_affine=False))
        modules.append(activation())
        modules.append(ParallelLinear(bs, hidden_size, 1))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        # A single shared input batch (dim 0 == 1) is broadcast to every network.
        if x.shape[0] != self.bs:
            assert x.shape[0] == 1
            x = x.repeat(self.bs, 1, 1)
        return self.net(x)

    @staticmethod
    def l1(weight):
        """Per-network sum of |w| (all dims but 0 flattened)."""
        return weight.view(weight.shape[0], -1).abs().sum(-1)

    @staticmethod
    def l2(weight):
        """Per-network sum of w**2 (all dims but 0 flattened)."""
        return weight.view(weight.shape[0], -1).pow(2).sum(-1)

    @staticmethod
    def norm(weight, p=2, q=2):
        """Per-network L_{p,q} norm: p-norm over dim 2, then q-norm over dim 1."""
        return weight.norm(p=p, dim=2).norm(q, dim=1)

    @staticmethod
    def op_norm(weight, p=float("Inf")):
        """Per-network p-norm of the singular values (Inf: spectral, 1: trace)."""
        _, S, _ = weight.svd()
        return S.norm(p, dim=-1)

    @staticmethod
    def orthogonal_loss(weight):
        """sum |W W^T - I| over all networks (a scalar, not per network)."""
        bs, n, _ = weight.shape
        sym = torch.bmm(weight, weight.transpose(2, 1))
        # The identity is built on the weight's own device/dtype; it was hard-coded to
        # "cuda", which failed on CPU and mismatched other devices.
        eye = torch.eye(n, device=weight.device, dtype=weight.dtype).expand(bs, n, n)
        return (sym - eye).abs().sum()

    def get_measure(self, name):
        """Compute the complexity measure ``name`` over all ParallelLinear layers.

        Sum-type measures (L1, L2) include the biases; product-type measures use
        the weights only. All return one value per network, except "Orthogonal",
        which returns a scalar.

        Raises:
            ValueError: unknown ``name``.
        """
        # https://github.com/bneyshabur/generalization-bounds/blob/master/measures.py
        linears = [p for p in self.modules() if isinstance(p, ParallelLinear)]
        ws = [p.weight for p in linears]
        bs = [p.bias for p in linears]
        ps = ws + bs

        inf = float("Inf")

        if name == "L1":
            return torch.stack([self.l1(p) for p in ps]).sum(0)
        elif name == "L2":
            return torch.stack([self.l2(p) for p in ps]).sum(0)
        elif name == "L_{1,inf}":
            return torch.stack([self.norm(w, p=1, q=inf) for w in ws]).prod(0)
        elif name == "Frobenius":
            return torch.stack([self.norm(w, p=2, q=2) for w in ws]).prod(0)
        elif name == "L_{3,1.5}":
            return torch.stack([self.norm(w, p=3, q=1.5) for w in ws]).prod(0)
        elif name == "Orthogonal":
            # https://arxiv.org/abs/1609.07093
            return torch.stack([self.orthogonal_loss(w) for w in ws]).sum()
        elif name == "Spectral":
            return torch.stack([self.op_norm(w, p=inf) for w in ws]).prod(0)
        elif name == "L_1.5_op":
            return torch.stack([self.op_norm(w, p=1.5) for w in ws]).prod(0)
        elif name == "Trace":
            return torch.stack([self.op_norm(w, p=1) for w in ws]).prod(0)
        else:
            raise ValueError(f"Measure {name} is not implemented.")

    def get_measures(self):
        """Dict of the enabled measures (SVD-based ones are commented out)."""
        measure_names = [
            "L1",
            "L2",
            "L_{1,inf}",
            "Frobenius",
            "L_{3,1.5}",
            # "Spectral",
            # "L_1.5_op",
            # "Trace",
        ]
        return {name: self.get_measure(name) for name in measure_names}


class GaussianDropout(nn.Module):
    """Multiplicative Gaussian noise regularizer.

    In training mode it returns ``x * eps`` with eps ~ N(1, alpha**2), i.e. mean 1
    and std ``alpha``, sampled elementwise. In eval mode it is the identity.
    """

    def __init__(self, alpha=1.0):
        super(GaussianDropout, self).__init__()
        self.alpha = alpha

    def forward(self, x):
        # `self.training` reads the mode flag. The old `self.train()` call *switched*
        # the module into training mode and returned self (always truthy), so noise
        # was applied even after `.eval()`.
        if self.training:
            # randn_like keeps x's device/dtype instead of forcing `.cuda()`.
            epsilon = torch.randn_like(x) * self.alpha + 1
            return x * epsilon
        return x


class VariationalDropout(nn.Module):
    """Multiplicative Gaussian noise with a learned per-feature ``log_alpha``.

    ``log_alpha`` has shape (dim,) and is initialised to log(alpha), so it
    broadcasts over the last dim of the input. ``kl()`` returns the mean of a
    polynomial KL approximation in alpha.
    """

    def __init__(self, alpha=1.0, dim=None):
        super(VariationalDropout, self).__init__()

        self.dim = dim
        self.max_alpha = alpha
        log_alpha = (torch.ones(dim) * alpha).log()
        self.log_alpha = nn.Parameter(log_alpha)

    def kl(self):
        """Mean over features of the approximate KL term (polynomial in alpha)."""
        c1 = 1.16145124
        c2 = -1.50204118
        c3 = 0.58629921

        alpha = self.log_alpha.exp()

        negative_kl = (
            0.5 * self.log_alpha + c1 * alpha + c2 * alpha ** 2 + c3 * alpha ** 3
        )

        kl = -negative_kl

        return kl.mean()

    def forward(self, x):
        """
        Training mode: sample noise e with mean 1 and std alpha, return x * e.
        Eval mode: identity.
        """
        # `self.training` reads the mode; `self.train()` used to switch the module
        # into training mode and was always truthy.
        if not self.training:
            return x

        # N(0, 1) on x's device/dtype (was hard-coded `.cuda()`).
        epsilon = torch.randn_like(x)

        # Clip alpha. NOTE(review): the bound `max_alpha` is applied to log_alpha,
        # i.e. alpha <= exp(max_alpha); confirm this is intended.
        self.log_alpha.data = torch.clamp(self.log_alpha.data, max=self.max_alpha)
        alpha = self.log_alpha.exp()

        # Shift to mean 1 so the noise is multiplicative around identity. Previously
        # the noise had mean 0, which zeroed the signal in expectation.
        epsilon = 1 + epsilon * alpha

        return x * epsilon

class TimeSeries(torch.nn.Module):
    """LSTM task learner for sequence regression.

    Maps x of shape (batch, seq_length, n_features) to (batch, num_outputs). A
    batch_first LSTM runs over the sequence, all ``seq_length`` hidden outputs
    are flattened per row, and one linear layer produces the outputs. Inputs
    must therefore have exactly ``seq_length`` time steps.
    """

    def __init__(self, batch_size, num_layers=1, hidden_size=20,
                n_features=18,
                num_outputs=10,
                seq_length=20):
        super(TimeSeries, self).__init__()
        self.n_features = n_features
        self.seq_len = seq_length
        self.n_hidden = hidden_size
        self.n_layers = num_layers
        self.batch_size = batch_size

        self.lstm = torch.nn.LSTM(input_size = n_features,
                                 hidden_size = self.n_hidden,
                                 num_layers = self.n_layers,
                                 batch_first = True)

        self.linear = torch.nn.Linear(self.n_hidden*self.seq_len, num_outputs)

    def init_hidden(self):
        """Set ``self.hidden`` to zero (h0, c0), each (n_layers, batch_size, n_hidden).

        The tensors are created on the device of the model's parameters (instead
        of the global ``device``), so call this after moving the model.
        """
        param_device = next(self.parameters()).device
        hidden_state = torch.zeros(self.n_layers, self.batch_size, self.n_hidden, device=param_device)
        cell_state = torch.zeros(self.n_layers, self.batch_size, self.n_hidden, device=param_device)
        self.hidden = (hidden_state, cell_state)

    def forward(self, x):
        batch_size, _, _ = x.size()
        hidden = getattr(self, "hidden", None)
        if hidden is not None:
            if hidden[0].shape[1] != batch_size:
                # init_hidden only ever stores zeros, and the LSTM's default
                # (hidden=None) is also a zero state. Use it for other batch sizes
                # instead of failing on a shape mismatch.
                hidden = None
            else:
                hidden = tuple(state.to(x.device) for state in hidden)
        lstm_out, _ = self.lstm(x, hidden)
        flat = lstm_out.contiguous().view(batch_size, -1)
        return self.linear(flat)

## Args

In [None]:
# ---- hardware / batching ----
gpu = '7'
batch_size = 32  # NC minibatch drawn from the memory bank in train()
task_batch_size = 32

# ---- NC optimisation and schedule ----
lr = 5e-4
time_budget = 1e10  # seconds (compared against timer() in train()); effectively unlimited
task = 'sine'
nc_regularize = True
epochs = 1000
train_steps = 500  # NC regularisation of learners starts after 2 * train_steps global steps
log_steps = 500
test_steps = 250
learn_freq = 10  # fit a new learner every learn_freq training batches

# ---- inner (task learner) loop ----
inner_lr = 0.01
inner_steps = 16
nc_weight = 1.0
learner_layers = 2
learner_hidden = 40
learner_act = 'relu'

# ---- NC architecture ----
input = 'cross_att'  # NOTE: shadows the builtin `input`
enc = 'fc'
pool = 'mean'
dec = 'fc'
enc_depth = 3
dec_depth = 2
hid_dim = 512
num_heads = 8
model_path = "results/model.ckpt"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Memory Bank

In [None]:
class MemoryBank:
    """
    Memory bank class. Stores snapshots of task learners.
    get_batch() returns a random minibatch of (snapshot, gap) for NC to train on.

    Snapshots are stored as four tensors concatenated along dim 0: te_xp, tr_xp,
    tr_xyp and gap. After the first add, only the newest 1,000,000 rows are kept.
    """

    def add(self, te_xp, tr_xp, tr_xyp, gap):
        """Append one snapshot (rows along dim 0) to the bank."""
        incoming = (("te_xp", te_xp), ("tr_xp", tr_xp), ("tr_xyp", tr_xyp), ("gap", gap))
        if not hasattr(self, "te_xp"):
            # First call: adopt the tensors as-is (no trimming on the first add).
            for name, tensor in incoming:
                setattr(self, name, tensor)
            return

        for name, tensor in incoming:
            setattr(self, name, torch.cat([getattr(self, name), tensor], dim=0))

        MEMORY_LIMIT = 1_000_000
        if self.te_xp.shape[0] > MEMORY_LIMIT:
            for name, _ in incoming:
                setattr(self, name, getattr(self, name)[-MEMORY_LIMIT:])

    def get_batch(self, batch_size):
        """Sample ``batch_size`` distinct rows; return (inputs dict, gap) on ``device``.

        ``random.sample`` raises ValueError if the bank holds fewer rows than ``batch_size``.
        """
        N = self.te_xp.shape[0]
        for stored in (self.tr_xp, self.tr_xyp, self.gap):
            assert N == stored.shape[0]

        idxs = random.sample(range(N), k=batch_size)
        batch = {name: getattr(self, name)[idxs].to(device) for name in ("te_xp", "tr_xp", "tr_xyp")}
        return (batch, self.gap[idxs].to(device))

## Run Regression and Classification

In [None]:
def run_regression(batch, train=True):
    """Fit a fresh RegressionNeuralNetwork learner on one task batch.

    Runs ``inner_steps`` SGD steps. Each step first snapshots (inputs,
    predictions) for the test and train splits and, when ``nc_regularize`` is set
    and more than ``2 * train_steps`` global steps have passed, adds the NC
    model's summed output to the learner loss. It then updates the learner and
    computes the gap (test MSE - train MSE, mean over the last dim) from the
    pre-update predictions. With ``train=True`` every snapshot and gap is
    appended to ``memory_bank``.

    Returns:
        (h, meta_batch): the trained learner and the snapshot of the last inner step.
    """
    def _on_device(split):
        return batch[split][0].to(device), batch[split][1].to(device)

    x_train, y_train = _on_device("train")
    x_test, y_test = _on_device("test")

    h = get_learner(
        batch_size=x_train.shape[0],
        layers=learner_layers,
        hidden_size=learner_hidden,
        activation=learner_act,
        init_dim=23,
        num_outputs=10,
        task='regression',
    ).to(device)
    h_opt = torch.optim.SGD(h.parameters(), lr=inner_lr)
    h_crit = nn.MSELoss(reduction="none")

    for _ in range(inner_steps):
        preds_train = h(x_train)
        preds_test = h(x_test)

        meta_batch = {
            "te_xp": torch.cat([x_test, preds_test], dim=-1),
            "tr_xp": torch.cat([x_train, preds_train], dim=-1),
            "tr_xyp": torch.cat([x_train, y_train, preds_train], dim=-1),
        }

        h_loss = h_crit(preds_train.squeeze(), y_train.squeeze()).mean(-1).sum()
        if nc_regularize and global_step > train_steps * 2:
            # NC outputs are summed across tasks because h_loss is summed too.
            h_loss += model(meta_batch).sum() * nc_weight

        h_opt.zero_grad()
        h_loss.backward()
        h_opt.step()

        # Gap from the pre-update predictions, so it matches the snapshot above.
        l_test = mse_criterion(preds_test.squeeze(), y_test.squeeze())
        l_train = mse_criterion(preds_train.squeeze(), y_train.squeeze())
        gap = l_test.mean(-1) - l_train.mean(-1)

        if train:
            memory_bank.add(
                te_xp=meta_batch["te_xp"].cpu().detach(),
                tr_xp=meta_batch["tr_xp"].cpu().detach(),
                tr_xyp=meta_batch["tr_xyp"].cpu().detach(),
                gap=gap.cpu().detach(),
            )
    return h, meta_batch

In [None]:
def run_classification(batch, train=True):
    """Classification counterpart of the regression inner loop (currently not runnable).

    NOTE(review): ``get_learner(task='classification')`` raises
    ``NotImplementedError``, so this function raises right after moving the
    batch to ``device``. The loop below also uses ``CrossEntropyLoss`` (mean
    reduction, so ``.mean(-1).sum()`` acts on a scalar) for the learner, but
    ``mse_criterion`` for the gap. Revisit both when classification support is added.

    Returns:
        (h, meta_batch): the learner and the snapshot of the last inner step.
    """
    x_train, y_train = batch["train"][0].to(device), batch["train"][1].to(device)
    x_test, y_test = batch["test"][0].to(device), batch["test"][1].to(device)

    h = get_learner(
        batch_size=x_train.shape[0],
        layers= learner_layers,
        hidden_size= learner_hidden,
        activation= learner_act,
        task='classification',
    ).to(device)
    h_opt = torch.optim.SGD(h.parameters(), lr= inner_lr)
    h_crit = nn.CrossEntropyLoss()

    for _ in range( inner_steps):
        preds_train = h(x_train)
        preds_test = h(x_test)

        # Snapshot for the NC model: inputs ++ predictions (train also gets labels).
        te_xp = torch.cat([x_test, preds_test], dim=-1)
        tr_xp = torch.cat([x_train, preds_train], dim=-1)
        tr_xyp = torch.cat([x_train, y_train, preds_train], dim=-1)
        meta_batch = {"te_xp": te_xp, "tr_xp": tr_xp, "tr_xyp": tr_xyp}

        h_loss = h_crit(preds_train.squeeze(), y_train.squeeze()).mean(-1).sum()
        if  nc_regularize and global_step >  train_steps * 2:
            model_preds = model(meta_batch)
            # We sum NC outputs across tasks because h_loss is also summed.
            nc_regularization = model_preds.sum()
            h_loss += nc_regularization *  nc_weight

        h_opt.zero_grad()
        h_loss.backward()
        h_opt.step()

        # Gap uses the predictions computed before this step's update.
        l_test = mse_criterion(preds_test.squeeze(), y_test.squeeze())
        l_train = mse_criterion(preds_train.squeeze(), y_train.squeeze())
        gap = l_test.mean(-1) - l_train.mean(-1)

        if train:
            memory_bank.add(
                te_xp=te_xp.cpu().detach(),
                tr_xp=tr_xp.cpu().detach(),
                tr_xyp=tr_xyp.cpu().detach(),
                gap=gap.cpu().detach(),
            )
    return h, meta_batch

In [None]:
def run_regression_timeseries(batch, train=True):
    """Fit a fresh LSTM learner (``get_learner(task='timeseries')``) on one task batch.

    Args:
        batch: dict with "train" and "test" entries, each ``[X, y]``. X is
            presumably (rows, seq_len, n_features); TODO confirm against the data files.
        train: when True, every inner-step snapshot and gap is appended to ``memory_bank``.

    Each of the ``inner_steps`` SGD steps runs in this order:
    1. Predict on both splits.
    2. Build the NC snapshot: per-row flattened inputs ++ [labels ++] predictions.
    3. When ``nc_regularize`` is set and more than ``2 * train_steps`` global
       steps have passed, add the NC model's summed output to the learner loss.
    4. Update the learner.
    5. Compute the gap (test MSE - train MSE, mean over the last dim) from the
       pre-update predictions.

    Returns:
        (h, meta_batch): the trained learner and the snapshot of the last inner step.
    """
    x_train, y_train = batch["train"][0].to(device), batch["train"][1].to(device)
    x_test, y_test = batch["test"][0].to(device), batch["test"][1].to(device)

    h = get_learner(
        batch_size=x_train.shape[0],
        layers=1,
        hidden_size=20,
        init_dim=18,
        num_outputs=10,
        task='timeseries'
    ).to(device)
    h.init_hidden()
    h_opt = torch.optim.SGD(h.parameters(), lr=inner_lr)
    h_crit = nn.MSELoss(reduction="none")

    for _ in range(inner_steps):
        preds_train = h(x_train)
        preds_test = h(x_test)

        # Flatten each sequence to one row using the tensor's own row count. The
        # previous code used the global `batch_size`, which fails whenever a task
        # batch does not have exactly `batch_size` rows.
        x_test_flat = x_test.contiguous().view(x_test.shape[0], -1)
        x_train_flat = x_train.contiguous().view(x_train.shape[0], -1)
        te_xp = torch.cat([x_test_flat, preds_test], dim=-1)
        tr_xp = torch.cat([x_train_flat, preds_train], dim=-1)
        tr_xyp = torch.cat([x_train_flat, y_train, preds_train], dim=-1)
        meta_batch = {"te_xp": te_xp, "tr_xp": tr_xp, "tr_xyp": tr_xyp}

        h_loss = h_crit(preds_train.squeeze(), y_train.squeeze()).mean(-1).sum()
        if nc_regularize and global_step > train_steps * 2:
            model_preds = model(meta_batch)
            # We sum NC outputs across tasks because h_loss is also summed.
            nc_regularization = model_preds.sum()
            h_loss += nc_regularization * nc_weight

        h_opt.zero_grad()
        h_loss.backward()
        h_opt.step()

        # Gap from the pre-update predictions, so it matches the snapshot above.
        l_test = mse_criterion(preds_test.squeeze(), y_test.squeeze())
        l_train = mse_criterion(preds_train.squeeze(), y_train.squeeze())
        gap = l_test.mean(-1) - l_train.mean(-1)

        if train:
            memory_bank.add(
                te_xp=te_xp.cpu().detach(),
                tr_xp=tr_xp.cpu().detach(),
                tr_xyp=tr_xyp.cpu().detach(),
                gap=gap.cpu().detach(),
            )
    return h, meta_batch

## Defining the model

In [None]:
# Element-wise MSE (callers reduce it themselves) and mean absolute error,
# shared by the learner loops, test() and train().
mse_criterion = nn.MSELoss(reduction="none")
mae_criterion = nn.L1Loss()

model = NeuralComplexity1D().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
global_timestamp = timer()  # start of the time budget checked in train()
global_step = 0
accum = Accumulator()

## Logging

In [None]:
def log_metrics(type_="train", metrics={}):
    """Append each metric value to its history list in ``tracker[type_]``.

    ``metrics`` is only read, never mutated, so the shared ``{}`` default is harmless.
    """
    history = tracker[type_]
    for name, value in metrics.items():
        history.setdefault(name, []).append(value)

## Testing

In [None]:
def test(epoch, test_tasks):
    """Evaluate the NC model's gap predictions on ``test_tasks``.

    For every task batch, a fresh learner is fitted with
    ``run_regression_timeseries(..., train=False)``, so nothing is added to the
    memory bank. The learner's final train/test MSE and gap are measured, and
    the NC prediction on the last inner-step snapshot is scored against that
    gap. Logs mae, loss, Pearson R and mean losses, and records them under
    ``tracker["test"]``. Returns None.
    """
    test_accum = Accumulator()
    for batch in test_tasks:
        h, meta_batch = run_regression_timeseries(batch, train=False)

        x_train, y_train = batch["train"][0].to(device), batch["train"][1].to(device)
        x_test, y_test = batch["test"][0].to(device), batch["test"][1].to(device)
        with torch.no_grad():
            preds_train, preds_test = h(x_train), h(x_test)

            l_train = mse_criterion(preds_train.squeeze(), y_train.squeeze())
            l_test = mse_criterion(preds_test.squeeze(), y_test.squeeze())
            gap = l_test.mean(-1) - l_train.mean(-1)

            model_preds = model(meta_batch)
            pred_flat, gap_flat = model_preds.squeeze(), gap.squeeze()
            loss = mse_criterion(pred_flat, gap_flat).mean()
            mae = mae_criterion(pred_flat, gap_flat).mean()

        test_accum.add_dict(
            {
                "l_test": [l_test.mean(-1).detach().cpu()],
                "l_train": [l_train.mean(-1).detach().cpu()],
                "mae": [mae.item()],
                "loss": [loss.item()],
                "gap": [gap.squeeze().detach().cpu()],
                "pred": [model_preds.squeeze().detach().cpu()],
            }
        )

    all_gaps = torch.cat(test_accum["gap"])
    all_preds = torch.cat(test_accum["pred"])
    R = np.corrcoef(all_gaps, all_preds)[0, 1]
    mean_l_test = torch.cat(test_accum["l_test"]).mean()
    mean_l_train = torch.cat(test_accum["l_train"]).mean()
    mean_mae = test_accum.mean("mae")
    mean_loss = test_accum.mean("loss")

    logger.info(f"Test epoch {epoch}")
    logger.info(
        f"mae {mean_mae:.2e} loss {mean_loss:.2e} R {R:.3f} "
        f"l_test {mean_l_test:.2e} l_train {mean_l_train:.2e} "
    )

    log_metrics("test", {
        "mae": mean_mae,
        "loss": mean_loss,
        "R": R,
        "l_test": mean_l_test.item(),
        "l_train": mean_l_train.item(),
    })

## Training

In [None]:
def train(train_loader):
    """One NC training pass over ``train_loader`` (the inner loop, i.e. train_epoch).

    For each batch:
    - every ``learn_freq`` global steps, a fresh learner is fitted on the batch
      and its snapshots are added to ``memory_bank``;
    - the NC model takes one Adam step on a random ``batch_size`` minibatch from
      the bank;
    - every ``log_steps`` global steps, the metrics accumulated since the last
      log (mae, loss, Pearson R of gap vs prediction) are logged, recorded under
      ``tracker["train"]``, and reset.

    Calls ``quit()`` once ``time_budget`` seconds have passed since ``global_timestamp``.
    """
    global global_step
    for batch in train_loader:
        global_step += 1
        if global_step % learn_freq == 0:  # refresh the bank every learn_freq batches
            run_regression_timeseries(batch)

        meta_batch, gap = memory_bank.get_batch(batch_size)
        model_preds = model(meta_batch)
        loss = mse_criterion(model_preds.squeeze(), gap.squeeze()).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        mae = mae_criterion(model_preds.squeeze(), gap.squeeze())
        accum.add_dict(
            {
                "mae": [mae.item()],
                "loss": [loss.item()],
                "gap": [gap.squeeze().detach().cpu()],
                "pred": [model_preds.squeeze().detach().cpu()],
            }
        )

        if global_step % log_steps == 0:
            # torch.save(model.state_dict(), model_path)

            all_gaps = torch.cat(accum["gap"])
            all_preds = torch.cat(accum["pred"])
            R = np.corrcoef(all_gaps, all_preds)[0, 1]
            logger.info(f"Train Step {global_step}")
            logger.info(
                f"mae {accum.mean('mae'):.2e} loss {accum.mean('loss'):.2e} R {R:.3f} "
            )

            metrics = {
                "mae": accum.mean("mae"),
                "loss": accum.mean("loss"),
                "R": R,
            }
            log_metrics("train", metrics)
            print(metrics)
            # Start a new logging window. Without this, the logged metrics were
            # cumulative since the first step, and the gap/pred lists grew without
            # bound (and were re-concatenated at every log).
            accum.clear()

        if timer() - global_timestamp > time_budget:
            logger.info(f"Stopping at step {global_step}")
            quit()

## Data Population

In [None]:
memory_bank = MemoryBank()
populate_timestamp = timer()
# Autograd anomaly detection adds checks to every backward pass and slows training
# considerably. It was left switched on; set this to True only when debugging
# autograd errors.
detect_anomaly = False
torch.autograd.set_detect_anomaly(detect_anomaly)

# Collect (train, test) batch pairs from `task_count` tasks. Only pairs whose
# train and test batches have the same number of rows are kept.
task_count = 2
populate_loader = []
for tasks in range(task_count):
    X_train, y_train, X_test, y_test = sample_task()

    for X_tr, y_tr, X_te, y_te in zip(X_train, y_train, X_test, y_test):
        X_tr, y_tr = X_tr.float(), y_tr.float()
        X_te, y_te = X_te.float(), y_te.float()
        if X_tr.shape[0] == X_te.shape[0]:
            d = {"train": [X_tr, y_tr], "test": [X_te, y_te]}
            populate_loader.append(d)

# Seed the memory bank by fitting one learner per collected batch.
for batch in populate_loader:
    run_regression_timeseries(batch)

logger.info(f"Populate time: {timer() - populate_timestamp}")

## Main Loop

In [None]:
# Fresh metric histories for this run (log_metrics reads the global at call time).
tracker = {"train": {}, "test": {}}


def _timed(fn, *args):
    """Call ``fn(*args)`` and return ``(result, elapsed_seconds)``."""
    start = timer()
    result = fn(*args)
    return result, timer() - start


for epoch in range(epochs):
    logger.info(f"Epoch {epoch}")
    logger.info(f"Bank size: {memory_bank.te_xp.shape[0]}")

    # Evaluate first, then train; both use the populated task batches.
    out, test_elapsed = _timed(test, epoch, populate_loader)
    out, train_elapsed = _timed(train, populate_loader)
    logger.info(f"Time: train {train_elapsed:.1f} test {test_elapsed:.1f}")

    # Rewrite the full metric history after every epoch.
    with open("logs.json", "w") as f:
        json.dump(tracker, f)