# LIGN
Graph Induced Lifelong Learning for Spatial-Temporal Data

----

## Imports

In [1]:
import lign as lg
import lign.models as md
import lign.utils as utl

import torch as th
import torchvision as tv
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler

import numpy as np
import datetime
tm_now = datetime.datetime.now

----

## Preprocessing 

### Create Dataset

In [None]:
# Convert the CIFAR-100 train and test splits with the project's cifar_to_lign
# loader and save each result to a .lign file.
# This cell is an alternative to the MNIST cell: both bind `dataset_name`,
# `trans`, `dataset` and `validate`, so whichever runs last wins.
# NOTE(review): later cells use LABELS = np.arange(10) and a model whose first
# layer is Conv2d(1, 32, ...) - confirm those are adjusted before training on
# the three-channel CIFAR-100 data.
dataset_name = "CIFAR" #<<<<<

# Convert images to tensors, then normalize each of the three channels with
# mean 0.5 and std 0.5.
trans = tv.transforms.Compose([
    tv.transforms.ToTensor(),
    tv.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Training split: no `train` argument is passed, so the loader's default applies.
dataset = utl.load.cifar_to_lign("data/datasets/CIFAR100", transforms = trans)
dataset.save("data/datasets/cifar100_train.lign")

# Held-out split (train=False), used later for accuracy evaluation.
validate = utl.load.cifar_to_lign("data/datasets/CIFAR100", train=False, transforms = trans)
validate.save("data/datasets/cifar100_test.lign")

In [None]:
# Convert the MNIST train and test splits with the project's mnist_to_lign
# loader and save each result to a .lign file.
# This cell is an alternative to the CIFAR cell: both bind `dataset_name`,
# `trans`, `dataset` and `validate`, so whichever runs last wins.
dataset_name = "MNIST" #<<<<<

# Convert images to tensors, then normalize the single channel with
# mean 0.1307 and std 0.3081.
trans = tv.transforms.Compose([
    tv.transforms.ToTensor(),
    tv.transforms.Normalize((0.1307,), (0.3081,))
])

# Training split: no `train` argument is passed, so the loader's default applies.
dataset = utl.load.mnist_to_lign("data/datasets/MNIST", transforms = trans)
dataset.save("data/datasets/mnist_train.lign")

# Held-out split (train=False), used later for accuracy evaluation.
validate = utl.load.mnist_to_lign("data/datasets/MNIST", train=False, transforms = trans)
validate.save("data/datasets/mnist_test.lign")

### Load Dataset

In [None]:
# Load the CIFAR-100 .lign files written by the "Create Dataset" cell.
# Alternative to the MNIST load cell: both bind `dataset_name`, `dataset` and
# `validate`. `dataset_name` is also used in the saved-state filename.
dataset_name = "CIFAR" #<<<<<

dataset = lg.graph.GraphDataset("data/datasets/cifar100_train.lign")
validate = lg.graph.GraphDataset("data/datasets/cifar100_test.lign")

In [2]:
# Load the MNIST .lign files written by the "Create Dataset" cell.
# Alternative to the CIFAR load cell: both bind `dataset_name`, `dataset` and
# `validate`. `dataset_name` is also used in the saved-state filename.
dataset_name = "MNIST" #<<<<<

dataset = lg.graph.GraphDataset("data/datasets/mnist_train.lign")
validate = lg.graph.GraphDataset("data/datasets/mnist_test.lign")

In [None]:
# Incomplete scratch expression (apparently left from interactive exploration).
# As written it is a syntax error, and `model` is not defined until the Models
# section further down, so it is kept only as a comment.
# model.

### Cuda GPUs

In [3]:
# Pick the compute device once: CUDA when a GPU is available, otherwise the CPU.
device = th.device("cuda" if th.cuda.is_available() else "cpu")

# Only on the GPU path: release cached CUDA memory before the run starts.
if device.type == "cuda":
    th.cuda.empty_cache()

### Functions and NNs

In [4]:
def sum_neighs_data(neighs):
    """Add up the neighbors' data into a single value.

    Meant to run before post_mod is executed (pre_mod happens before this).

    Args:
        neighs: Non-empty collection supporting ``neighs[0]`` and
            ``neighs[1:]`` whose elements support ``+``.

    Returns:
        ``neighs[0] + neighs[1] + ...``, accumulated left to right with
        out-of-place addition, so the elements themselves are not modified.

    Raises:
        IndexError: If ``neighs`` is an empty sequence.
    """
    total, remaining = neighs[0], neighs[1:]
    for item in remaining:
        # Out-of-place add on purpose: `+=` could mutate neighs[0] in place.
        total = total + item
    return total

class ADDON(nn.Module):
    """Temporary layer used for training.

    Wraps ``lg.layers.ADDON`` and applies a log-softmax over dimension 1 of
    its output.
    """

    def __init__(self, in_fea, out_fea, device = 'cuda'):
        """Create the wrapped add-on layer.

        Args:
            in_fea: Forwarded to ``lg.layers.ADDON`` (presumably the input
                feature size - confirm against the library).
            out_fea: Forwarded to ``lg.layers.ADDON`` (presumably the output
                feature size - confirm against the library).
            device: Forwarded to ``lg.layers.ADDON`` as ``device``.
        """
        super().__init__()
        self.addon = lg.layers.ADDON(in_fea, out_fea, device=device)

    def forward(self, g, features):
        """Run the wrapped layer on ``(g, features)`` and return log-softmax scores."""
        scores = self.addon(g, features)
        return F.log_softmax(scores, dim=1)


### Hyperparameters
* LAMBDA: regulates how much the model relies on difference between the nodes vs the features that lead to their label when calculating pairwise loss
* DIST_VEC_SIZE: size of vector representing the mapping of the nodes by the model
* INIT_NUM_LAB: number of labels used to train the model initially with the supervised method, so that it learns the pairwise mapping
* LABELS: list of all the labels that model comes across. Labels can be appended at any time. The order of labels is initially randomized
* SUBGRAPH_SIZE: the number of nodes processed at once (spelled `SUBGRPAH_SIZE` in the code). The models don't use batches; this is the closest equivalent
* AMP_ENABLE: toggle to enable mixed precision training
* EPOCHS: Loops executed during training
* LR: Learning rate
* RETRAIN_PER: period between retraining based on number of labels seen. format: (offset, period)

In [5]:
# Hyperparameters for this run (described in the list above).
LAMBDA = 0.5
DIST_VEC_SIZE = 2 #128
INIT_NUM_LAB = 4
LABELS = np.arange(10)
SUBGRAPH_SIZE = 500
# Backward-compatible alias: the rest of the notebook (and the saved
# parameter files) use this misspelled name.
SUBGRPAH_SIZE = SUBGRAPH_SIZE
AMP_ENABLE = True
EPOCHS = 150
LR = 1e-3
RETRAIN_PER = {
    "superv": (5, 2), # (offset, period): retrain when (labels_seen + offset) % period == 0
    #"semi": (0, 15)
}

# Shuffling is currently disabled, so labels are introduced in ascending order.
#np.random.shuffle(LABELS)

---
## Models
### LIGN

[L]ifelong Learning [I]nduced by [G]raph [N]eural Networks Model (LIGN)

In [6]:
class LIGN_MNIST(nn.Module):
    """LIGN network for single-channel images.

    Two convolutions and one linear layer, each wrapped in ``lg.layers.GCN``,
    map the input features to an ``out_feats``-sized vector per node.
    """

    def __init__(self, out_feats):
        """Build the layers.

        Args:
            out_feats: Size of the output vector produced by the last layer.
        """
        super().__init__()
        # Convolutions: 1 -> 32 -> 64 channels, 3x3 kernels, stride 1.
        self.gcn1 = lg.layers.GCN(nn.Conv2d(1, 32, kernel_size=3, stride=1))
        self.gcn2 = lg.layers.GCN(nn.Conv2d(32, 64, kernel_size=3, stride=1))
        # 9216 flattened inputs; presumably 64 * 12 * 12 for 28x28 images,
        # assuming the GCN wrapper keeps the convolution's output shape - TODO confirm.
        self.gcn3 = lg.layers.GCN(nn.Linear(9216, out_feats))
        self.drop1 = nn.Dropout(0.25)
        self.drop2 = nn.Dropout(0.5)

    def forward(self, g, features):
        """Compute the output vectors for ``features`` on graph ``g``.

        Args:
            g: Graph object passed through to every GCN layer.
            features: Input features fed to the first GCN layer.

        Returns:
            The ReLU-activated output of the last layer after dropout.
        """
        hidden = F.relu(self.gcn1(g, features))
        hidden = F.relu(self.gcn2(g, hidden))
        pooled = F.max_pool2d(hidden, 2)
        # Dropout is applied before flattening everything except dimension 0.
        flat = th.flatten(self.drop1(pooled), 1)
        embedding = F.relu(self.gcn3(g, flat))
        return self.drop2(embedding)

# Build the network with a DIST_VEC_SIZE-sized output and move it to the selected device.
model = LIGN_MNIST(DIST_VEC_SIZE).to(device)

----
## Training
### Parameters

In [7]:
# Results collected by the training loop: one accuracy value and one log line per step.
accuracy, log = [], []
num_of_labels = len(LABELS)

# Optimizer over the model's parameters; the gradient scaler exists only when AMP is enabled.
opt = th.optim.Adam(model.parameters(), lr=LR)
if AMP_ENABLE:
    scaler = GradScaler()
else:
    scaler = None

# Temporary training layer, sized for the initial number of labels.
addon = ADDON(DIST_VEC_SIZE, INIT_NUM_LAB, device).to(device)


def retrain_superv(x):
    """Return True when supervised retraining is due after ``x`` labels have been seen."""
    schedule = RETRAIN_PER["superv"]  # (offset, period)
    return (x + schedule[0]) % schedule[1] == 0

#retrain_semi = lambda x: (x + RETRAIN_PER["semi"][0])%RETRAIN_PER["semi"][1] == 0

### Load State

In [None]:
# Restore model, optimizer and (when AMP is enabled) gradient-scaler state
# from a previously saved checkpoint.
# NOTE: th.load unpickles the file, so only load checkpoints from a trusted source.
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU machine can still be restored on the CPU fallback.
checkpoint = th.load('data/models/LIGN_training_cool_time.pt', map_location=device)

model.load_state_dict(checkpoint['model'])
opt.load_state_dict(checkpoint['optimizer'])

# The 'scaler' entry is only written by the save cell when AMP_ENABLE was set.
if AMP_ENABLE:
    scaler.load_state_dict(checkpoint['scaler'])

### Train Model

In [8]:
# Evaluate the current model on the first INIT_NUM_LAB labels and print the result.
acc = lg.test.accuracy(
    model,
    validate,
    dataset,
    "x",
    "labels",
    LABELS[:INIT_NUM_LAB],
    cluster=(utl.clustering.NN(), 5),
    sv_img='2d',
    device=device,
)

print(acc)

Prediction: 
tensor([2, 3, 0,  ..., 1, 2, 1], device='cuda:0')
True values: 
tensor([2, 1, 0,  ..., 1, 2, 3], device='cuda:0')
Vector output: 
tensor([[0.0000, 0.2880],
        [0.0000, 0.0669],
        [0.0800, 0.2165],
        ...,
        [0.0000, 0.1180],
        [0.0448, 0.3269],
        [0.0287, 0.2106]], device='cuda:0')
33.26918450805869


In [None]:

# Initial supervised training on the first INIT_NUM_LAB labels.
lg.train.superv(model, opt, dataset, "x", "labels", LABELS[:INIT_NUM_LAB], addon, LAMBDA, (device, scaler), epochs=EPOCHS, subgraph_size=SUBGRPAH_SIZE)

# Grow the set of visible labels one at a time; retrain on the RETRAIN_PER
# schedule and record the accuracy after every step.
for num_labels in range(INIT_NUM_LAB, num_of_labels + 1):
    # Semi-supervised retraining is currently disabled:
    # if retrain_semi(num_labels):
    #     lg.train.semi_superv(model, opt, dataset, "x", "labels", DIST_VEC_SIZE, LABELS[:num_labels], LAMBDA, (device, scaler), addon = ADDON, subgraph_size=SUBGRPAH_SIZE, epochs=EPOCHS, cluster=(utl.clustering.NN(), 5))
    semi_retrained = False

    # Evaluate the schedule once so the training decision and the log line agree.
    superv_retrained = retrain_superv(num_labels)
    if superv_retrained:
        # Presumably resizes the add-on layer's output to the current label count - confirm in lign.
        addon.addon.update_size(num_labels)
        lg.train.superv(model, opt, dataset, "x", "labels", LABELS[:num_labels], addon, LAMBDA, (device, scaler), epochs=EPOCHS, subgraph_size=SUBGRPAH_SIZE)

    acc = lg.test.accuracy(model, validate, dataset, "x", "labels", LABELS[:num_labels], cluster=(utl.clustering.NN(), 5), sv_img = '2d', device=device)

    accuracy.append(acc)
    # The semi-supervised field previously logged a hard-coded 5; it now reports
    # whether semi-supervised retraining actually ran in this step.
    log.append("Label: {}/{} -- Accuracy: {}% -- Semisurpervised Retraining: {} -- Surpervised Retraining: {}".format(num_labels, num_of_labels, round(acc, 2), semi_retrained, superv_retrained))
    print(log[-1])


### Save State

In [None]:

# Timestamp made filename-friendly: ":" becomes "-", "." is dropped, " " becomes "_".
# NOTE: `time` is a plain string here, not the standard-library module.
time = str(tm_now()).translate(str.maketrans({":": "-", ".": None, " ": "_"}))
filename = f"LIGN_{dataset_name}_training_{time}"

## Save metrics
metrics = dict(accuracy=accuracy, log=log)
utl.io.json(metrics, f"data/metrics/{filename}.json")

## Save hyperparameters
para = dict(
    LAMBDA=LAMBDA,
    DIST_VEC_SIZE=DIST_VEC_SIZE,
    INIT_NUM_LAB=INIT_NUM_LAB,
    LABELS=LABELS.tolist(),  # plain list instead of a NumPy array
    SUBGRPAH_SIZE=SUBGRPAH_SIZE,
    AMP_ENABLE=AMP_ENABLE,
    EPOCHS=EPOCHS,
    LR=LR,
    RETRAIN_PER=RETRAIN_PER,
)

utl.io.json(para, f"data/parameters/{filename}.json")

## Save model
check = dict(model=model.state_dict(), optimizer=opt.state_dict())
# The gradient scaler only exists when mixed precision is enabled.
if AMP_ENABLE:
    check["scaler"] = scaler.state_dict()

th.save(check, f"data/models/{filename}.pt")
    

---
## View
### Performance

In [None]:
# Prints the LAMBDA hyperparameter; this cell produces no performance plot or metric.
print(LAMBDA)