# LIGN
Graph Induced Lifelong Learning for Spatial-Temporal Data

----

## Imports

In [1]:
import lign as lg
import lign.models as md
import lign.utils as utl

import torch as th
import torchvision as tv
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler

import numpy as np
import datetime
tm_now = datetime.datetime.now

----

## Preprocessing 

### Create Dataset

In [None]:
dataset_name = "CIFAR"  #<<<<<

# ToTensor, then per-channel (x - 0.5) / 0.5 normalisation on all 3 channels.
trans = tv.transforms.Compose(
    [
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Convert the CIFAR-100 training split into a LIGN graph and cache it on disk.
dataset = utl.load.cifar_to_lign("data/datasets/CIFAR100", transforms=trans)
dataset.save("data/datasets/cifar100_train.lign")

# Same conversion for the held-out split (train=False).
validate = utl.load.cifar_to_lign(
    "data/datasets/CIFAR100",
    train=False,
    transforms=trans,
)
validate.save("data/datasets/cifar100_test.lign")

In [None]:
dataset_name = "MNIST"  #<<<<<

# ToTensor, then single-channel normalisation.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST training-set mean/std.
trans = tv.transforms.Compose(
    [
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Convert the MNIST training split into a LIGN graph and cache it on disk.
dataset = utl.load.mnist_to_lign("data/datasets/MNIST", transforms=trans)
dataset.save("data/datasets/mnist_train.lign")

# Same conversion for the held-out split (train=False).
validate = utl.load.mnist_to_lign(
    "data/datasets/MNIST",
    train=False,
    transforms=trans,
)
validate.save("data/datasets/mnist_test.lign")

### Load Dataset

In [2]:
# dataset_name is embedded in the filenames written by the "Save State" cell.
dataset_name = "CIFAR" #<<<<<

# Reload the CIFAR-100 train/test graphs cached by the "Create Dataset" cell.
dataset = lg.graph.GraphDataset("data/datasets/cifar100_train.lign")
validate = lg.graph.GraphDataset("data/datasets/cifar100_test.lign")

In [None]:
# dataset_name is embedded in the filenames written by the "Save State" cell.
dataset_name = "MNIST" #<<<<<

# Reload the MNIST train/test graphs cached by the "Create Dataset" cell.
dataset = lg.graph.GraphDataset("data/datasets/mnist_train.lign")
validate = lg.graph.GraphDataset("data/datasets/mnist_test.lign")

### Cuda GPUs

In [3]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = th.device("cuda" if th.cuda.is_available() else "cpu")

if device.type == "cuda":
    # Hand unoccupied memory held by PyTorch's CUDA caching allocator back
    # before training starts.
    th.cuda.empty_cache()

### Functions and NNs

In [4]:
def sum_neighs_data(neighs):
    """Combine all items of ``neighs`` with ``+``, left to right.

    Author's note: this adds up the neighbors' data before post_mod is
    executed (pre_mod runs earlier). Items only need to support ``+``.

    Args:
        neighs: non-empty, indexable sequence of values to add.

    Returns:
        ``neighs[0] + neighs[1] + ... + neighs[-1]``.

    Raises:
        IndexError: if ``neighs`` is empty.
    """
    head, tail = neighs[0], neighs[1:]
    for item in tail:
        # Plain ``a + b`` (never ``+=``) so the first item is not modified in place.
        head = head + item
    return head

class ADDON(nn.Module):
    """Temporary single-layer head used during training.

    Wraps ``nn.Linear(in_fea, out_fea)`` in ``md.layers.GCN``. In this
    notebook the class itself is passed as ``addon=`` to ``lg.train.superv``.

    Args:
        in_fea: input feature size of the wrapped linear layer.
        out_fea: output feature size of the wrapped linear layer.
    """

    def __init__(self, in_fea, out_fea):
        super().__init__()
        self.gcn1 = md.layers.GCN(nn.Linear(in_fea, out_fea))

    def forward(self, g, features):
        """Apply the wrapped GCN layer to ``features`` over graph ``g``."""
        return self.gcn1(g, features)


### Hyperparameters
* LAMBDA: regulates how much the model relies on the difference between nodes versus the features that lead to their label when calculating the pairwise loss
* DIST_VEC_SIZE: size of the vector that the model maps each node to
* INIT_NUM_LAB: number of labels used to train the model initially, in the supervised setting, to learn the pairwise mapping
* LABELS: list of all the labels the model comes across. Labels can be appended at any time. The order of the labels is initially randomized
* SUBGRAPH_SIZE (spelled `SUBGRPAH_SIZE` in the code): the number of nodes processed at once. The models don't use batches; this is the closest equivalent
* AMP_ENABLE: toggle to enable mixed-precision training
* EPOCHS: number of loops (epochs) executed in each training call
* LR: learning rate
* RETRAIN_PER: retraining schedule for each mode (`superv`, `semi`), based on the number of labels seen. Format: (offset, period); retraining triggers when (labels_seen + offset) % period == 0

In [9]:
# --- loss / output representation ---
LAMBDA = 1e-3            # weight used by the pairwise loss
DIST_VEC_SIZE = 2        # size of the vector each node is mapped to

# --- label schedule ---
INIT_NUM_LAB = 4         # labels used for the initial supervised pass
LABELS = np.arange(10)   # label ids; shuffled in place at the end of this cell

# --- optimisation ---
# NOTE: "SUBGRPAH" typo kept on purpose: the training calls and the saved
# parameter JSON both refer to this exact name.
SUBGRPAH_SIZE = 1_000    # nodes processed per subgraph
AMP_ENABLE = True        # mixed-precision training via GradScaler
EPOCHS = 1_000
LR = 0.001

# (offset, period) per mode: retrain when (labels_seen + offset) % period == 0
RETRAIN_PER = dict(superv=(4, 10), semi=(0, 15))

# Randomise the order in which labels are introduced (global NumPy RNG, unseeded).
np.random.shuffle(LABELS)

---
## Models
### LIGN

[L]ifelong Learning [I]nduced by [G]raph [N]eural Networks Model (LIGN)

In [10]:
class LIGN_CIFAR(nn.Module):
    """Two conv stages plus three linear layers, each applied through a GCN wrapper.

    Every layer, including the max-pooling, is wrapped in ``md.layers.GCN`` and
    called with the graph ``g``. The first linear layer expects 16 * 5 * 5
    features per node. That matches 3-channel 32x32 inputs:
    conv5 -> 28, pool -> 14, conv5 -> 10, pool -> 5.
    This assumes the GCN wrapper preserves the wrapped layer's output shape
    (TODO confirm).

    Args:
        out_feats: size of the per-node output vector (``DIST_VEC_SIZE`` here).
    """

    # Flattened size of the conv stack's output per node (channels * H * W).
    _CONV_OUT = 16 * 5 * 5

    def __init__(self, out_feats):
        super().__init__()
        self.gcn1 = md.layers.GCN(nn.Conv2d(3, 6, 5))
        self.gcn2 = md.layers.GCN(nn.Conv2d(6, 16, 5))
        self.gcn3 = md.layers.GCN(nn.Linear(self._CONV_OUT, 150))
        self.gcn4 = md.layers.GCN(nn.Linear(150, 84))
        self.gcn5 = md.layers.GCN(nn.Linear(84, out_feats))
        self.pool = md.layers.GCN(nn.MaxPool2d(2, 2))

    def forward(self, g, features):
        """Map node ``features`` over graph ``g`` to vectors in (-1, 1).

        Raises:
            ValueError: if the conv stack does not produce ``16 * 5 * 5``
                features per node (e.g. input images of the wrong size).
        """
        x = self.pool(g, F.relu(self.gcn1(g, features)))
        x = self.pool(g, F.relu(self.gcn2(g, x)))
        # Flatten per node (keep dim 0) instead of view(-1, 400).
        # - view() would silently regroup a wrongly sized input into a
        #   different number of rows.
        # - view() also fails on an empty node set.
        x = th.flatten(x, 1)
        if x.shape[1] != self._CONV_OUT:
            raise ValueError(
                "LIGN_CIFAR expected {} features per node after the conv stack, "
                "got {} (input shape {})".format(
                    self._CONV_OUT, x.shape[1], tuple(features.shape)
                )
            )
        x = F.relu(self.gcn3(g, x))
        x = F.relu(self.gcn4(g, x))

        return th.tanh(self.gcn5(g, x))

# Build the model with a DIST_VEC_SIZE-dimensional output and move it to the selected device.
model = LIGN_CIFAR(out_feats=DIST_VEC_SIZE).to(device=device)

### R-LIGN
[R]ecurrent [L]ifelong Learning [I]nduced by [G]raph [N]eural Networks Model (R-LIGN)

In [None]:
#dataset.set_data("h", )
#dataset.set_data("c", )
####
# model = R_LIGN(DIST_VEC_SIZE)

----
## Training
### Parameters

In [11]:
# Results accumulated by the training loop (one entry per label step).
accuracy, log = [], []
num_of_labels = len(LABELS)

# Adam over all model parameters; a GradScaler only when AMP is enabled.
opt = th.optim.Adam(model.parameters(), lr=LR)
scaler = GradScaler() if AMP_ENABLE else None


def _retrain_due(num_labels, mode):
    """Return True when ``(num_labels + offset) % period == 0`` for ``RETRAIN_PER[mode]``."""
    rule = RETRAIN_PER[mode]
    return (num_labels + rule[0]) % rule[1] == 0


def retrain_superv(x):
    """Whether supervised retraining is due after ``x`` labels have been seen."""
    return _retrain_due(x, "superv")


def retrain_semi(x):
    """Whether semi-supervised retraining is due after ``x`` labels have been seen."""
    return _retrain_due(x, "semi")

### Load State

In [None]:
# Resume model / optimizer / AMP-scaler state from a checkpoint written by the
# "Save State" cell.
# NOTE(review): this path does not follow the "LIGN_<dataset>_training_<time>"
# naming used when saving; point it at the checkpoint you actually want.
# map_location=device puts the loaded tensors on the device chosen at the top
# of the notebook. Without it, a checkpoint saved on a GPU cannot be loaded on
# a CPU-only machine.
checkpoint = th.load('data/models/LIGN_training_cool_time.pt', map_location=device)

model.load_state_dict(checkpoint['model'])
opt.load_state_dict(checkpoint['optimizer'])

if AMP_ENABLE:
    scaler.load_state_dict(checkpoint['scaler'])

### Train Model

In [12]:

# Semi-supervised retraining was previously disabled by wrapping its call in a
# string literal, while the log still reported retrain_semi() as if it had run.
# It is now an explicit switch, and the log records whether it actually ran.
SEMI_SUPERV_ENABLE = False

# Keyword arguments shared by every training call below.
train_kwargs = dict(addon=ADDON, subgraph_size=SUBGRPAH_SIZE, epochs=EPOCHS)

# Initial supervised pass on the first INIT_NUM_LAB labels.
lg.train.superv(model, opt, dataset, "x", "labels", DIST_VEC_SIZE, LABELS[:INIT_NUM_LAB], LAMBDA, (device, scaler), **train_kwargs)

# Introduce labels one at a time, retraining on the RETRAIN_PER schedule, and
# evaluate on the validation graph after each step.
for num_labels in range(INIT_NUM_LAB, num_of_labels + 1):
    seen_labels = LABELS[:num_labels]

    did_semi = SEMI_SUPERV_ENABLE and retrain_semi(num_labels)
    if did_semi:
        lg.train.semi_superv(model, opt, dataset, "x", "labels", DIST_VEC_SIZE, seen_labels, LAMBDA, (device, scaler), cluster=(utl.clustering.NN(), 5), **train_kwargs)

    did_superv = retrain_superv(num_labels)
    if did_superv:
        lg.train.superv(model, opt, dataset, "x", "labels", DIST_VEC_SIZE, seen_labels, LAMBDA, (device, scaler), **train_kwargs)

    acc = lg.test.accuracy(model, validate, dataset, "x", "labels", seen_labels, cluster=(utl.clustering.NN(), 5), sv_img='2d', device=device)

    accuracy.append(acc)
    log.append("Label: {}/{}\t|\tAccuracy: {}\t|\tSemisurpervised Retraining: {}\t|\tSurpervised Retraining: {}".format(num_labels, num_of_labels, round(acc, 2), did_semi, did_superv))
    print(log[-1])


Label: 20/40	|	Accuracy: 5.8	|	Semisurpervised Retraining: False	|	Surpervised Retraining: False
Label: 21/40	|	Accuracy: 5.52	|	Semisurpervised Retraining: False	|	Surpervised Retraining: False
Label: 22/40	|	Accuracy: 5.09	|	Semisurpervised Retraining: False	|	Surpervised Retraining: False
Label: 23/40	|	Accuracy: 5.0	|	Semisurpervised Retraining: False	|	Surpervised Retraining: False
Label: 24/40	|	Accuracy: 4.75	|	Semisurpervised Retraining: False	|	Surpervised Retraining: False
Label: 25/40	|	Accuracy: 4.56	|	Semisurpervised Retraining: False	|	Surpervised Retraining: False
Label: 26/40	|	Accuracy: 4.27	|	Semisurpervised Retraining: False	|	Surpervised Retraining: True
Label: 27/40	|	Accuracy: 4.11	|	Semisurpervised Retraining: False	|	Surpervised Retraining: False
Label: 28/40	|	Accuracy: 3.89	|	Semisurpervised Retraining: False	|	Surpervised Retraining: False
Label: 29/40	|	Accuracy: 3.86	|	Semisurpervised Retraining: False	|	Surpervised Retraining: False
Label: 30/40	|	Accuracy

### Save State

In [None]:

# Filename-safe timestamp: ':' -> '-', '.' removed, ' ' -> '_'.
time = str(tm_now()).translate(str.maketrans({":": "-", ".": None, " ": "_"}))
filename = f"LIGN_{dataset_name}_training_{time}"

## Save metrics collected by the training loop
metrics = {"accuracy": accuracy, "log": log}
utl.io.json(metrics, f"data/metrics/{filename}.json")

## Save hyperparameters (LABELS converted from a NumPy array to a list for JSON)
para = {
    "LAMBDA": LAMBDA,
    "DIST_VEC_SIZE": DIST_VEC_SIZE,
    "INIT_NUM_LAB": INIT_NUM_LAB,
    "LABELS": LABELS.tolist(),
    "SUBGRPAH_SIZE": SUBGRPAH_SIZE,
    "AMP_ENABLE": AMP_ENABLE,
    "EPOCHS": EPOCHS,
    "LR": LR,
    "RETRAIN_PER": RETRAIN_PER,
}
utl.io.json(para, f"data/parameters/{filename}.json")

## Save model, optimizer and (when AMP is on) grad-scaler state
check = {"model": model.state_dict(), "optimizer": opt.state_dict()}
if AMP_ENABLE:
    check["scaler"] = scaler.state_dict()

th.save(check, f"data/models/{filename}.pt")
    

---
## View
### Performance

In [None]:
# Echo the pairwise-loss weight LAMBDA used for this run.
print(LAMBDA)