# LIGN
Graph Induced Lifelong Learning for Spatial-Temporal Data

----

## Imports

In [8]:
import lign as lg
import lign.models as md
import lign.utils as utl

import torch as th
import torchvision as tv
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import datetime
tm_now = datetime.datetime.now
from torch.cuda.amp import GradScaler

----

## Preprocessing 
The data needs to be formatted and set up for lign. `GraphDataset` instances are needed to load graphs into the model and track their nodes.

### Create Dataset
***Only run one of the following cells within "Create Dataset".***

***Also, run either "Create Dataset" or "Load Dataset", not both.***

In [7]:
# Per-image preprocessing: convert to a tensor, then normalise each of the
# three channels with mean 0.5 and std 0.5.
trans = tv.transforms.Compose(
    [
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Training split: convert the raw CIFAR-100 files into a lign dataset and
# write it to disk under the path the "Load Dataset" cell reads from.
dataset = utl.io.cifar_to_lign(
    "data/datasets/CIFAR100",
    transforms=trans,
)
dataset.save("data/datasets/cifar100_train.lign")

# Held-out split (train=False), saved alongside the training file.
validate = utl.io.cifar_to_lign(
    "data/datasets/CIFAR100",
    transforms=trans,
    train=False,
)
validate.save("data/datasets/cifar100_test.lign")

### Load Dataset
***Only run one of the following cells within "Load Dataset".***

***Also, run either "Create Dataset" or "Load Dataset", not both.***

In [9]:
# Reload the train / test lign datasets from the same paths the
# "Create Dataset" cell saves them to.
dataset = lg.graph.GraphDataset("data/datasets/cifar100_train.lign")
# Held-out split; in this notebook it is only referenced by the
# (currently commented-out) accuracy call in the training loop.
validate = lg.graph.GraphDataset("data/datasets/cifar100_test.lign")

### Cuda GPUs
Searches for nvidia GPUs in order to speed up training and inference

In [10]:
# Use a CUDA device when PyTorch reports one is available; otherwise use the CPU.
if th.cuda.is_available():
    device = th.device("cuda")
else:
    device = th.device("cpu")

### Functions and NNs
Possible functions and NNs used throughout the code to apply changes to the graph's data

In [11]:
def sum_neighs_data(neighs):
    """Add up the data of every neighbour.

    Args:
        neighs: a non-empty, indexable and sliceable collection whose items
            support ``+`` with each other (presumably per-neighbour tensors;
            confirm against how lign calls this function).

    Returns:
        ``neighs[0] + neighs[1] + ... + neighs[-1]``, accumulated left to right.

    Raises:
        IndexError: for an empty list/tuple, because ``neighs[0]`` is read
            unconditionally.
    """
    out = neighs[0]
    for neigh in neighs[1:]:
        # Add the individual neighbour; the previous code added the whole
        # `neighs` collection on every iteration.
        out = out + neigh
    return out

### Hyperparameters
* _LABELS needs to be adjusted to reflect the total number of labels in the model._
* LAMBDA: regulates how much the model relies on the difference between the nodes vs. the features that lead to their label when calculating pairwise loss
* DIST_VEC_SIZE: size of the vector representing the mapping of the nodes by the model
* INIT_NUM_LAB: number of labels used to train the model initially with the supervised method, so it learns the pairwise mapping
* LABELS: tracks all the labels that the model comes across. The order will be randomized
* SUBGRAPH_SIZE: represents the number of nodes processed at once. The models don't have batches; this is the closest thing to it
* AMP_ENABLE: enables mixed-precision training
* LR: Learning rate


In [12]:
# Seed for the label shuffle at the bottom of this cell. Keeping it fixed makes
# the label order (and therefore any saved checkpoint / accuracy history)
# reproducible across kernel restarts; change it to get a different order.
SEED = 0

LAMBDA = 0.0001  # weighting used when computing the pairwise loss
DIST_VEC_SIZE = 3 # 3 was picked so the graph can be drawn in a 3d grid
INIT_NUM_LAB = 10  # number of labels used for the initial supervised training
LABELS = np.arange(100)  # every label id the model will come across
SUBGRAPH_SIZE = 200  # number of nodes processed at once
# Misspelled legacy name kept as an alias because other cells reference it.
SUBGRPAH_SIZE = SUBGRAPH_SIZE
AMP_ENABLE = True  # mixed-precision training on/off
LR = 1e-3  # learning rate

# Shuffle in place with a private generator so the global numpy RNG state is
# left untouched.
np.random.RandomState(SEED).shuffle(LABELS)

---
## Model
***Either run LIGN or R-LIGN. Not both.***
### LIGN

[L]ifelong Learning [I]nduced by [G]raph [N]eural Networks Model (LIGN)

In [13]:
class LIGN_cnn(nn.Module):
    """Small convolutional network whose layers are wrapped in lign GCN layers.

    Two 5x5 convolutions (3 -> 6 -> 16 channels), each followed by ReLU and a
    2x2 max-pool, feed three linear layers (400 -> 150 -> 84 -> out_feats).
    The final activation is tanh, so every output component lies in (-1, 1).

    Args:
        out_feats: size of the output vector produced for each node.

    NOTE(review): the 16 * 5 * 5 flatten size matches 3x32x32 inputs if GCN
    applies ``module_post`` like the plain torch module would; confirm against
    ``md.layers.GCN``.
    """

    def __init__(self, out_feats):
        super().__init__()
        # Convolutional feature extractor.
        self.gcn1 = md.layers.GCN(module_post=nn.Conv2d(3, 6, 5))
        self.gcn2 = md.layers.GCN(module_post=nn.Conv2d(6, 16, 5))
        # Fully connected head, ending in the out_feats-sized output.
        self.gcn3 = md.layers.GCN(module_post=nn.Linear(16 * 5 * 5, 150))
        self.gcn4 = md.layers.GCN(module_post=nn.Linear(150, 84))
        self.gcn5 = md.layers.GCN(module_post=nn.Linear(84, out_feats))
        # Shared 2x2 pooling layer used after both convolutions.
        self.pool = md.layers.GCN(module_post=nn.MaxPool2d(2, 2))

    def forward(self, g, features):
        """Run the network for graph ``g`` on ``features``; returns the tanh-squashed output."""
        # NOTE(review): `self.pool` is a GCN layer but is called without `g`,
        # unlike gcn1..gcn5; confirm GCN.forward accepts a single argument.
        hidden = F.relu(self.gcn1(g, features))
        hidden = self.pool(hidden)
        hidden = F.relu(self.gcn2(g, hidden))
        hidden = self.pool(hidden)

        # Flatten the 16 feature maps of size 5x5 into one vector per node.
        hidden = hidden.view(-1, 16 * 5 * 5)

        hidden = F.relu(self.gcn3(g, hidden))
        hidden = F.relu(self.gcn4(g, hidden))
        logits = self.gcn5(g, hidden)

        return th.tanh(logits)


# Build the network with DIST_VEC_SIZE outputs and move it to the selected device.
model = LIGN_cnn(DIST_VEC_SIZE)
model = model.to(device)

AttributeError: module 'lign.models' has no attribute 'layers'

### R-LIGN
[R]ecurrent [L]ifelong Learning [I]nduced by [G]raph [N]eural Networks Model (R-LIGN)

In [None]:
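# TODO: R-LIGN is not implemented in this notebook yet; the lines below are an
# unfinished placeholder sketch (the set_data values are still missing).
#dataset.set_data("h", )
#dataset.set_data("c", )
####
# model = R_LIGN(DIST_VEC_SIZE)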

----
## Training
### Parameters

In [None]:
# Accuracy history; the training loop appends one entry per iteration.
accuracy = list()
# Total number of labels the training loop iterates up to.
num_of_labels = len(LABELS)

# Adam over every model parameter, with the learning rate from the
# hyperparameter cell.
opt = th.optim.Adam(model.parameters(), lr=LR)

# Gradient scaler for mixed-precision training; left as None when AMP is off.
if AMP_ENABLE:
    scaler = GradScaler()
else:
    scaler = None

### Load Pre-Trained Model
***Replace the file name with the checkpoint you want to load.***

In [None]:
# Restore the model / optimizer (and AMP scaler) state from a checkpoint with
# the same keys the "Save State" cell writes: "model", "optimizer", "scaler".
# NOTE: th.load unpickles the file, so only load checkpoints you trust.
# map_location remaps stored tensors onto the `device` selected earlier, so a
# checkpoint written on a GPU machine can still be loaded on a CPU-only one.
checkpoint = th.load('data/model/LIGN_training_model_cool.pt', map_location=device)

model.load_state_dict(checkpoint['model'])
# The optimizer variable in this notebook is `opt` (there is no `optimizer`).
opt.load_state_dict(checkpoint['optimizer'])

# "scaler" is only present when the checkpoint was saved with AMP enabled.
if AMP_ENABLE and 'scaler' in checkpoint:
    scaler.load_state_dict(checkpoint['scaler'])

### Train Model

In [14]:
# Bootstrap: supervised training on the first INIT_NUM_LAB labels.
# NOTE(review): the arguments mirror the supervised call inside the loop; the
# previous call passed only (model, dataset, labels). Confirm against
# lg.train.superv.
lg.train.superv(model, dataset, opt, "x", "labels", DIST_VEC_SIZE, LAMBDA, LABELS[:INIT_NUM_LAB], (device, scaler), subgraph_size=SUBGRPAH_SIZE)

# Introduce the labels one at a time, retraining periodically.
for num_labels in range(INIT_NUM_LAB, num_of_labels + 1):

    # Position inside a 30-label cycle: unsupervised retraining at 15,
    # supervised retraining at 0.
    retrain_track = num_labels%30
    if retrain_track == 15:
        lg.train.unsuperv(model, dataset, opt, "x", "labels", DIST_VEC_SIZE, LAMBDA, LABELS[:num_labels], (device, scaler), subgraph_size=SUBGRPAH_SIZE)
    elif retrain_track == 0:
        lg.train.superv(model, dataset, opt, "x", "labels", DIST_VEC_SIZE, LAMBDA, LABELS[:num_labels], (device, scaler), subgraph_size=SUBGRPAH_SIZE)

    # TODO: accuracy evaluation is not wired up yet (`cluster` is not defined
    # anywhere in this notebook). Re-enable once it is:
    #acc = lg.test.accuracy(model, validate, cluster, LABELS[:num_labels])
    acc = None  # placeholder so the history keeps one entry per iteration
    accuracy.append(acc)
    # f-strings so the placeholders are actually substituted.
    print(f"Label: {num_labels}/{num_of_labels} | Accuracy: {acc} | Unsurpervised Retraining: {retrain_track}:15 | " +
    f"Surpervised Retraining: {retrain_track}:0")


AttributeError: module 'lign' has no attribute 'train'

### Save State

In [None]:
# One timestamp shared by both artifacts so the accuracy file and the model
# checkpoint of a run can be matched up. The explicit format avoids the spaces
# and colons of str(datetime), which are not valid in Windows file names.
run_stamp = tm_now().strftime("%Y-%m-%d_%H-%M-%S")

# Persist the accuracy history collected by the training loop.
utl.io.json(accuracy, "data/performance/accuracy/LIGN_training_acc_"+run_stamp+".json")
## loss
# TODO(review): "loss" placeholder kept from the original; presumably a loss
# history was meant to be saved here as well.

# Checkpoint layout read back by the "Load Pre-Trained Model" cell.
check = {
    "model": model.state_dict(),
    "optimizer": opt.state_dict()
}

# The scaler only exists when mixed-precision training is enabled.
if AMP_ENABLE:
    check["scaler"] = scaler.state_dict()

th.save(check, "data/model/LIGN_training_model_"+run_stamp+".pt")

---
## View
### Performance

### Graph