# LIGN
Graph Induced Lifelong Learning for Spatial-Temporal Data

----

## Imports

In [1]:
import lign as lg
import lign.models as md
import lign.utils as utl

import torch as th
import torchvision as tv
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import datetime
tm_now = datetime.datetime.now
from torch.cuda.amp import GradScaler

----

## Preprocessing 

### Create Dataset

In [None]:
# Image preprocessing: convert to a tensor, then normalize each of the three
# channels with mean 0.5 and std 0.5.
to_tensor = tv.transforms.ToTensor()
normalize = tv.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
trans = tv.transforms.Compose([to_tensor, normalize])

# Convert the training split into a lign graph dataset and persist it.
dataset = utl.io.cifar_to_lign("data/datasets/CIFAR100", transforms=trans)
dataset.save("data/datasets/cifar100_train.lign")

# Same conversion for the held-out split (train=False), saved separately.
validate = utl.io.cifar_to_lign(
    "data/datasets/CIFAR100",
    train=False,
    transforms=trans,
)
validate.save("data/datasets/cifar100_test.lign")

### Load Dataset

In [2]:
# Locations of the serialized graphs written by the "Create Dataset" cell.
train_path = "data/datasets/cifar100_train.lign"
test_path = "data/datasets/cifar100_test.lign"

dataset = lg.graph.GraphDataset(train_path)
validate = lg.graph.GraphDataset(test_path)

### Cuda GPUs

In [3]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if th.cuda.is_available():
    device = th.device("cuda")
else:
    device = th.device("cpu")

### Functions and NNs

In [4]:
def sum_neighs_data(neighs):
    """Return the element-wise sum of all entries in ``neighs``.

    Args:
        neighs: non-empty indexable sequence of values that support ``+``
            with each other (e.g. numbers, arrays or tensors).

    Returns:
        ``neighs[0] + neighs[1] + ... + neighs[-1]``. The accumulation is
        done out of place, so the entries of ``neighs`` are not modified.

    Raises:
        IndexError: if ``neighs`` is an empty sequence.
    """
    out = neighs[0]
    for neigh in neighs[1:]:
        # Add the current neighbour only; the previous code added the whole
        # ``neighs`` sequence on every iteration.
        out = out + neigh
    return out

### Hyperparameters
* LAMBDA: regulates how much the model relies on the difference between the nodes versus the features that lead to their label when calculating the pairwise loss
* DIST_VEC_SIZE: size of the vector representing the model's mapping of the nodes
* INIT_NUM_LAB: number of labels initially used to train the model with the supervised method so that it learns the pairwise mapping
* LABELS: list of all the labels that the model comes across. Labels can be appended at any time. The order of the labels is initially randomized
* SUBGRAPH_SIZE: represents the number of nodes processed at once. The models don't have batches; this is the closest thing to them
* AMP_ENABLE: enables mixed precision training
* LR: learning rate


In [5]:
LAMBDA = 0.0001  # weighting used when calculating the pairwise loss
DIST_VEC_SIZE = 3 # 3 was picked so the graph can be drawn in a 3d grid
INIT_NUM_LAB = 10  # number of labels used for the initial supervised training
LABELS = np.arange(10)  # labels the model comes across; shuffled in place below
SUBGRAPH_SIZE = 200  # number of nodes processed at once
# Misspelled alias of SUBGRAPH_SIZE, kept because the training cells still
# reference this name.
SUBGRPAH_SIZE = SUBGRAPH_SIZE
AMP_ENABLE = True  # enable mixed precision training
LR = 1e-3  # learning rate

# Randomize the order in which labels are presented to the model.
np.random.shuffle(LABELS)

---
## Models
### LIGN

[L]ifelong Learning [I]nduced by [G]raph [N]eural Networks Model (LIGN)

In [6]:
class LIGN_CIFAR(nn.Module):
    """Network built from ``md.layers.GCN``-wrapped torch layers.

    Two wrapped ``Conv2d`` layers (3 -> 6 -> 16 channels, kernel size 5) are
    each followed by ReLU and a wrapped 2x2 ``MaxPool2d``. The result is
    flattened to ``16 * 5 * 5`` features and passed through three wrapped
    ``Linear`` layers (400 -> 150 -> 84 -> ``out_feats``). The final output is
    squashed with ``tanh``.

    Args:
        out_feats: number of output features of the last linear layer.
    """

    def __init__(self, out_feats):
        super().__init__()

        # Convolutional stages.
        self.gcn1 = md.layers.GCN(nn.Conv2d(3, 6, 5))
        self.gcn2 = md.layers.GCN(nn.Conv2d(6, 16, 5))

        # Fully connected stages.
        self.gcn3 = md.layers.GCN(nn.Linear(16 * 5 * 5, 150))
        self.gcn4 = md.layers.GCN(nn.Linear(150, 84))
        self.gcn5 = md.layers.GCN(nn.Linear(84, out_feats))

        # Pooling layer shared by both convolutional stages.
        self.pool = md.layers.GCN(nn.MaxPool2d(2, 2))

    def forward(self, g, features):
        """Run the network on graph ``g`` with node data ``features``.

        Returns the ``tanh`` of the last layer's output.
        """
        # NOTE(review): ``self.pool`` is called with a single argument, unlike
        # the other GCN layers which receive ``(g, x)`` -- confirm this matches
        # the call signature of ``md.layers.GCN``.
        hidden = self.gcn1(g, features)
        hidden = self.pool(F.relu(hidden))

        hidden = self.gcn2(g, hidden)
        hidden = self.pool(F.relu(hidden))

        # Flatten to one row of 16 * 5 * 5 values per entry.
        hidden = hidden.view(-1, 16 * 5 * 5)

        hidden = F.relu(self.gcn3(g, hidden))
        hidden = F.relu(self.gcn4(g, hidden))

        out = self.gcn5(g, hidden)
        return th.tanh(out)


# Instantiate the model with DIST_VEC_SIZE outputs and move it to the selected device.
model = LIGN_CIFAR(DIST_VEC_SIZE).to(device)

### R-LIGN
[R]ecurrent [L]ifelong Learning [I]nduced by [G]raph [N]eural Networks Model (R-LIGN)

In [None]:
#dataset.set_data("h", )
#dataset.set_data("c", )
####
# model = R_LIGN(DIST_VEC_SIZE)

----
## Training
### Parameters

In [7]:
# Running records that the training loop appends to.
accuracy, log = [], []

num_of_labels = len(LABELS)

# Adam optimizer over all model parameters.
opt = th.optim.Adam(model.parameters(), lr=LR)

# A gradient scaler is only created when mixed precision is enabled.
if AMP_ENABLE:
    scaler = GradScaler()
else:
    scaler = None

### Load Pre-Trained Model

In [None]:
# Restore model / optimizer (and AMP scaler) state from a saved checkpoint.
# The module is imported as `th` and the optimizer is named `opt` in this
# notebook; the previous `torch` / `optimizer` names were undefined.
# map_location=device lets a checkpoint saved on a GPU be loaded on a
# CPU-only machine.
checkpoint = th.load('data/models/LIGN_training_cool_time.pt', map_location=device)

model.load_state_dict(checkpoint['model'])
opt.load_state_dict(checkpoint['optimizer'])

# The scaler only exists (and is only saved) when mixed precision is enabled.
if AMP_ENABLE:
    scaler.load_state_dict(checkpoint['scaler'])

### Train Model

In [8]:
# Initial supervised training on the first INIT_NUM_LAB labels.
lg.train.superv(
    model, opt, dataset, "x", "labels", DIST_VEC_SIZE,
    LABELS[:INIT_NUM_LAB], LAMBDA, (device, scaler),
    subgraph_size=SUBGRPAH_SIZE,
)

# Grow the set of known labels one at a time and evaluate after each step.
for num_labels in range(INIT_NUM_LAB, num_of_labels + 1):

    # Retraining schedule with a period of 30 labels:
    # semi-supervised at offset 15, supervised at offset 0.
    retrain_track = num_labels % 30
    semi_superv_round = retrain_track == 15
    superv_round = retrain_track == 0

    if semi_superv_round:
        lg.train.semi_superv(
            model, opt, dataset, "x", "labels", DIST_VEC_SIZE,
            LABELS[:num_labels], LAMBDA, (device, scaler),
            subgraph_size=SUBGRPAH_SIZE,
            cluster=(utl.clustering.NN(), 3),
        )
    elif superv_round:
        lg.train.superv(
            model, opt, dataset, "x", "labels", DIST_VEC_SIZE,
            LABELS[:num_labels], LAMBDA, (device, scaler),
            subgraph_size=SUBGRPAH_SIZE,
        )

    # Evaluate on the validation set using the labels seen so far.
    acc = lg.test.accuracy(
        model, validate, dataset, "x", "labels", LABELS[:num_labels],
        cluster=(utl.clustering.NN(), 3),
        device=device,
    )
    accuracy.append(acc)

    message = "Label: {}/{} | Accuracy: {} | Unsurpervised Retraining: {} | Surpervised Retraining: {}".format(
        num_labels, num_of_labels, acc, semi_superv_round, superv_round
    )
    log.append(message)
    print(log[-1])


tensor([[[[[ 0.2784,  0.4353,  0.6235,  ...,  0.6941,  0.7176,  0.7490],
           [ 0.1373,  0.2549,  0.3882,  ...,  0.5922,  0.5686,  0.7490],
           [ 0.1686,  0.2235,  0.3020,  ...,  0.5843,  0.4588,  0.6706],
           ...,
           [ 0.5608,  0.5373,  0.4510,  ...,  0.2392,  0.1765,  0.2863],
           [ 0.2000,  0.4118,  0.5216,  ...,  0.2078,  0.2549,  0.3725],
           [-0.0196,  0.0745,  0.2235,  ...,  0.2235,  0.3333,  0.5451]],

          [[-0.0588,  0.0824,  0.1843,  ...,  0.3882,  0.5137,  0.6314],
           [-0.1686, -0.0667,  0.0039,  ...,  0.2706,  0.2941,  0.5765],
           [-0.1059, -0.0431, -0.0118,  ...,  0.2392,  0.1137,  0.4353],
           ...,
           [ 0.2392,  0.2784,  0.2314,  ...,  0.0667, -0.0039,  0.0824],
           [-0.1451,  0.1294,  0.2784,  ...,  0.0039,  0.0667,  0.1843],
           [-0.3725, -0.2314, -0.0431,  ..., -0.0196,  0.1373,  0.3725]],

          [[-0.1451, -0.0902, -0.1216,  ...,  0.0353,  0.1294,  0.3961],
           [-0.

AttributeError: module 'torch.cuda.amp' has no attribute 'autocast'

### Save State

In [None]:
# Timestamp shared by both output files so they can be matched up later.
time = tm_now()
stamp = str(time)

## Save metrics
metrics = {"accuracy": accuracy, "log": log}
metrics_path = "data/metrics/LIGN_training_" + stamp + ".json"
utl.io.json(metrics, metrics_path)

## Save model checkpoint
check = {"model": model.state_dict(), "optimizer": opt.state_dict()}
if AMP_ENABLE:
    # The scaler state is only stored when mixed precision is enabled.
    check["scaler"] = scaler.state_dict()

checkpoint_path = "data/models/LIGN_training_" + stamp + ".pt"
th.save(check, checkpoint_path)

---
## View
### Performance

### Graph