# Artificial Association Neural Networks
## This notebook walks through only a basic image-classification training run (MNIST) to illustrate how AAN works.

<img src="img/Fig3.png" width="700" height="200"/>

This neural network goes through the following three steps:

1. Feature extraction from each domain
2. Association of the extracted features
3. Utilization of the associated information to perform various subtasks and a main task
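The three steps can be sketched in plain Python as follows. This is a toy illustration, not the AAN implementation: the function names are hypothetical, lists of floats stand in for tensors, and an element-wise mean stands in for the learned association step.

```python
from typing import Callable, Dict, List

Vector = List[float]


def extract_features(inputs: Dict[str, Vector],
                     extractors: Dict[str, Callable[[Vector], Vector]]) -> List[Vector]:
    """Step 1: each domain's input goes through that domain's own extractor."""
    return [extractors[domain](x) for domain, x in inputs.items()]


def associate(features: List[Vector]) -> Vector:
    """Step 2: fuse the per-domain features into one hidden state.
    An element-wise mean stands in for the learned association network."""
    n = len(features)
    return [sum(values) / n for values in zip(*features)]


def run_tasks(h: Vector, heads: Dict[str, Callable[[Vector], object]]) -> Dict[str, object]:
    """Step 3: every task head (subtasks and the main task) reads the same hidden state."""
    return {name: head(h) for name, head in heads.items()}


# Toy usage with two hypothetical domains and one "main task" head (argmax).
extractors = {
    "image": lambda x: [sum(x), max(x)],
    "text": lambda x: [float(len(x)), min(x)],
}
features = extract_features({"image": [0.0, 0.5, 1.0], "text": [2.0, 3.0]}, extractors)
h = associate(features)
outputs = run_tasks(h, {"main": lambda v: max(range(len(v)), key=v.__getitem__)})
print(h, outputs)  # [1.75, 1.5] {'main': 0}
```

The point of the structure is that steps 2 and 3 never need to know which domains produced the features.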





In [1]:
from models.artificial_association_networks import ArtificialAssociationNeuralNetworks
from data_structures.neurodataloader import NeuroDataset, createNeuroDataloader
from data_structures.neuronode import NeuroNode
from data_structures.batch_neurotree import BatchNeuroTree



## If CUDA is not available and you are running on the CPU, set `device` in `config/option.py` (imported below as `config.option`) accordingly.



In [2]:


import copy
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Use the GPU if PyTorch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device : ', device)

# Seed the random number generators (the data split and weight initialization depend on them).
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True




Device :  cuda


In [3]:
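from collections import defaultdict
import pickle

# Note: this re-binds `device` to the value set in config/option.py,
# overriding the one detected in the previous cell.
from config.option import device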


## The key idea of this model is to generalize the learning process to arbitrary input data without changing the structure of the neural network.
![n2](img/Fig4.png)
#### In other words, the goal is a model that can embed any input X into a hidden state h. NeuroNode and NeuroTree are the data structures that let the number of convolutions and the layer structure of the network vary with the data.
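A minimal sketch of that idea, assuming a hypothetical `TreeNode` class rather than the repository's `NeuroNode`: one recursive function embeds trees of different shapes, so the number of combination steps follows the data instead of a fixed layer stack. Summation followed by a ReLU is only a placeholder for the learned combination.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class TreeNode:
    """Hypothetical stand-in for NeuroNode: leaves carry an input x,
    internal nodes only hold children."""
    x: Optional[List[float]] = None
    children: List["TreeNode"] = field(default_factory=list)


def embed(node: TreeNode, dim: int = 2) -> List[float]:
    """Compute a hidden state h bottom-up. The recursion depth, i.e. how many
    'layers' are applied, is set by the shape of the tree, not by the model."""
    if not node.children:
        # Leaf: toy "feature extractor" that pads or truncates x to `dim` values.
        return (list(node.x or []) + [0.0] * dim)[:dim]
    child_states = [embed(child, dim) for child in node.children]
    # Internal node: sum the children's states, then apply a ReLU.
    summed = [sum(values) for values in zip(*child_states)]
    return [max(0.0, v) for v in summed]


def depth(node: TreeNode) -> int:
    return 1 + max((depth(child) for child in node.children), default=0)


shallow = TreeNode(children=[TreeNode(x=[1.0, -2.0]), TreeNode(x=[0.5, 3.0])])
deep = TreeNode(children=[TreeNode(children=[TreeNode(x=[1.0, -2.0])])])
print(depth(shallow), embed(shallow))  # 2 [1.5, 1.0]
print(depth(deep), embed(deep))        # 3 [1.0, 0.0]
```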

<img src="img/Fig5a.png" width="130" height="130"/>
On the left is a multi-layer perceptron (MLP); on the right is the corresponding NeuroTree.

<img src="img/Fig5b.png" width="350" height="150"/>
On the left is a recurrent neural network (RNN); on the right is the corresponding NeuroTree.

<img src="img/Fig5c.png" width="300" height="150"/>
On the left is a recursive neural network (RvNN); on the right is the corresponding NeuroTree.

<img src="img/Fig5d.png" width="300" height="130"/>
On the left is a convolutional neural network (CNN); on the right is the corresponding NeuroTree.

<img src="img/Fig5e.png" width="300" height="130"/>
On the left is a graph neural network (GNN); on the right is the corresponding NeuroTree.
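To make the RNN correspondence concrete, here is a hedged sketch of one plausible way to express an RNN as a tree (hypothetical `Node` class, toy scalar weights; not the paper's exact construction). A sequence becomes a left-deep chain, and evaluating that chain bottom-up with one shared cell reproduces an ordinary RNN loop exactly.

```python
import math
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    """Hypothetical tree node (not the repository's NeuroNode API)."""
    x: Optional[float] = None
    children: List["Node"] = field(default_factory=list)


W_X, W_H = 0.5, 0.8  # toy scalar weights


def cell(h_prev: float, x: float) -> float:
    """Toy scalar RNN cell: h_t = tanh(W_X * x_t + W_H * h_{t-1})."""
    return math.tanh(W_X * x + W_H * h_prev)


def sequence_to_chain(xs: List[float]) -> Node:
    """Build a left-deep chain: the node for step t has the node for step t-1
    and the leaf x_t as children, so tree depth grows with the sequence length."""
    if not xs:
        raise ValueError("sequence must be non-empty")
    node = Node(children=[Node(x=xs[0])])
    for x in xs[1:]:
        node = Node(children=[node, Node(x=x)])
    return node


def evaluate(node: Node) -> float:
    """Evaluate a chain bottom-up with the same cell at every internal node."""
    if len(node.children) == 1:  # first time step: h_0 = 0
        return cell(0.0, node.children[0].x)
    prev, leaf = node.children
    return cell(evaluate(prev), leaf.x)


def rnn_loop(xs: List[float]) -> float:
    """Reference: the usual RNN loop over the sequence."""
    h = 0.0
    for x in xs:
        h = cell(h, x)
    return h


xs = [1.0, -0.5, 2.0]
print(evaluate(sequence_to_chain(xs)) == rnn_loop(xs))  # True
```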


### In addition to the structures above, NeuroTree also allows connections that pass information across layers, not only between adjacent ones.

<img src="img/Fig5f.png" width="300" height="150"/> <img src="img/Fig5g.png" width="200" height="150"/> 

The details are in the paper.


### 1: Defining the DataLoader and NeuroTree Builder Functions

In this section, we load MNIST, split the training set into training and validation subsets, and define the function that converts each image into a NeuroTree.

In [4]:
from datas.image.load import MNIST_DATA

# Load the MNIST training and test splits from the local data directory.
image_train, image_test = MNIST_DATA('./datas/image')

# Hold out 10,000 of the 60,000 training images for validation.
image_train, image_valid = torch.utils.data.random_split(image_train, [50000, 10000])





In [5]:
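# random_split returns Subset objects: index the underlying tensors with the subset
# indices and add a channel dimension, (N, 28, 28) -> (N, 1, 28, 28).
train_x = {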
    'image' : image_train.dataset.data[image_train.indices].unsqueeze(1)
}

train_y = {
    'image' : image_train.dataset.targets[image_train.indices]
}

valid_x = {
    'image' : image_valid.dataset.data[image_valid.indices].unsqueeze(1)
}

valid_y = {
    'image' : image_valid.dataset.targets[image_valid.indices]
}



test_x = {
    'image' : image_test.data.unsqueeze(1)
}

test_y = {
    'image' : image_test.targets
}


maintask_map = {
    'image' : 'classification'
}


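# NeuroTree building: wrap each image in a three-level tree (root -> node -> leaf).
# `mt` is not used by this builder.
def image2neurotree(data, mt):
    # Leaf: the image tensor scaled from [0, 255] to [0, 1], tagged with its domain.
    leaf = NeuroNode(data.to(device, dtype=torch.float) / 255, 'image', None, None, [])
    # Intermediate node with the leaf as its only child.
    node = NeuroNode(None, None, None, None, [])
    node.insert(leaf)
    # Root of the tree.
    root = NeuroNode(None, None, None, None, [])
    root.insert(node)
    return root


# Map each domain to its NeuroTree builder (data -> NeuroTree).
xmt2neurotree = {
    'image': image2neurotree
}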
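`image2neurotree` wraps every image in a three-level tree: a root, one intermediate node, and a leaf holding the scaled pixels. The sketch below mirrors that shape with a hypothetical `ToyNode` class (it does not reproduce `NeuroNode`'s real constructor arguments) and shows the domain-to-builder lookup that `xmt2neurotree` provides.

```python
from typing import Callable, Dict, List, Optional


class ToyNode:
    """Hypothetical mimic of NeuroNode: optional data, optional domain tag, children."""

    def __init__(self, x: Optional[List[float]] = None, domain: Optional[str] = None):
        self.x = x
        self.domain = domain
        self.children: List["ToyNode"] = []

    def insert(self, child: "ToyNode") -> None:
        self.children.append(child)


def toy_image2tree(pixels: List[int]) -> ToyNode:
    """Same shape as image2neurotree: root -> node -> leaf."""
    leaf = ToyNode([p / 255 for p in pixels], "image")  # scale [0, 255] to [0, 1]
    node = ToyNode()
    node.insert(leaf)
    root = ToyNode()
    root.insert(node)
    return root


# Domain name -> tree builder, playing the role of xmt2neurotree.
builders: Dict[str, Callable[[List[int]], ToyNode]] = {"image": toy_image2tree}


def depth(node: ToyNode) -> int:
    return 1 + max((depth(child) for child in node.children), default=0)


tree = builders["image"]([0, 51, 255])
leaf = tree.children[0].children[0]
print(depth(tree), leaf.domain, leaf.x)  # 3 image [0.0, 0.2, 1.0]
```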



In [6]:
train_dataset = NeuroDataset(train_x, train_y, maintask_map, xmt2neurotree)
valid_dataset = NeuroDataset(valid_x, valid_y, maintask_map, xmt2neurotree)
test_dataset = NeuroDataset(test_x, test_y, maintask_map, xmt2neurotree)



In [7]:
train_dataloader = createNeuroDataloader(train_dataset, batch_size=100, prefetch_factor=None)
valid_dataloader = createNeuroDataloader(valid_dataset, batch_size=100, prefetch_factor=None)
test_dataloader = createNeuroDataloader(test_dataset, batch_size=100, prefetch_factor=None)


### 2: Defining Feature Extraction Models

In this section, we define the feature-extraction model for the image domain, a LeNet-5-style CNN.

In [8]:

class LeNet_5(nn.Module):
    """LeNet-5-style CNN that maps a 1x28x28 image to a 128-d feature vector."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1)     # 28x28 -> 24x24
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)    # 12x12 -> 8x8
        self.conv3 = nn.Conv2d(16, 120, kernel_size=4, stride=1)  # 4x4 -> 1x1
        self.parameter_init()

    def parameter_init(self):
        # Xavier-initialize weight tensors; 1-D parameters (biases) keep their defaults.
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param, gain=1.414)

    def forward(self, batch_tree: BatchNeuroTree):
        # Collect this domain's leaf images from the batched NeuroTree.
        batch_x = batch_tree.getX()
        x = torch.stack(batch_x).to(device, dtype=torch.float)

        x = torch.tanh(self.conv1(x))
        x = F.avg_pool2d(x, 2, 2)
        x = torch.tanh(self.conv2(x))
        x = F.avg_pool2d(x, 2, 2)
        x = torch.tanh(self.conv3(x))
        x = x.view(-1, 120)
        # Zero-pad the 120 features to 128, the hidden size used throughout this notebook.
        zeros = torch.zeros(x.shape[0], 8, device=x.device)
        return torch.cat([x, zeros], dim=-1)

    def setting_pretrained_model(self, model):
        # Reuse the convolution layers of an already-trained LeNet_5.
        self.conv1 = model.conv1
        self.conv2 = model.conv2
        self.conv3 = model.conv3
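The padding to 128 follows from the layer sizes. A quick check of the spatial arithmetic, using the standard output-size formula for convolution and pooling (no dilation), is below. The classic LeNet-5 takes 32 × 32 inputs and uses a 5 × 5 kernel in its third convolution; the 4 × 4 kernel here appears to be the adjustment that still reduces a 28 × 28 MNIST image to 1 × 1.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output height/width of a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                     # MNIST images are 1 x 28 x 28
size = conv_out(size, 5)      # conv1 (5x5)       -> 24
size = conv_out(size, 2, 2)   # avg_pool2d(2, 2)  -> 12
size = conv_out(size, 5)      # conv2 (5x5)       -> 8
size = conv_out(size, 2, 2)   # avg_pool2d(2, 2)  -> 4
size = conv_out(size, 4)      # conv3 (4x4)       -> 1
features = 120 * size * size  # x.view(-1, 120) keeps one 120-d vector per image
padded = features + 8         # concatenating 8 zeros -> 128
print(size, features, padded)  # 1 120 128
```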



In [9]:

image2vec = LeNet_5().to(device)

fe_networks = {
    # 'class': astcls2vec,
    # 'Num': astnum2vec,
    'image' : image2vec
}

### 3: Defining Subtasks and Maintask

In this section, we define the task heads. This example uses no subtasks; the main task is 10-class digit classification.

In [10]:
class Classification(torch.nn.Module):
    """Linear classification head that outputs log-probabilities over `class_count` classes."""

    def __init__(self, hidden_dim, class_count):
        super().__init__()
        self.task_name = 'classification'
        self.hidden_dim = hidden_dim
        self.class_count = class_count

        # The head ends in LogSoftmax, so the matching loss is NLLLoss.
        self.loss_function = torch.nn.NLLLoss()
        self.recognition_layer = nn.Sequential(
            nn.Linear(self.hidden_dim, class_count),
            nn.LogSoftmax(dim=-1),
        )

    def forward(self, hiddens, tree):
        if isinstance(hiddens, list):
            hiddens = torch.stack(hiddens, dim=0)
        return self.recognition_layer(hiddens)

    def loss(self, batch_domains, batch_outputs, batch_targets):
        preds = torch.stack(batch_outputs, dim=0)
        targets = torch.stack(batch_targets, dim=0).to(device, dtype=torch.long)

        loss = self.loss_function(preds, targets)
        predictions = torch.argmax(preds, dim=1)

        # Count correct predictions and samples separately for each domain.
        corrects = defaultdict(int)
        counts = defaultdict(int)
        for domain, label, prediction in zip(batch_domains, targets, predictions):
            if label == prediction:
                corrects[domain] += 1
            counts[domain] += 1
        return loss, corrects, counts
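The head ends in LogSoftmax and is trained with NLLLoss; together these compute the usual cross-entropy. The Lightning wrapper further down calls `F.cross_entropy` on the model's main-task outputs, presumably this head's log-probabilities, which gives the same value because `log_softmax` leaves log-probabilities unchanged. A plain-Python check of both facts (the helper names are illustrative, not PyTorch functions):

```python
import math
from typing import List


def log_softmax(z: List[float]) -> List[float]:
    """log p_i = z_i - log(sum_j exp(z_j)), computed with the max-shift for stability."""
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - lse for v in z]


def nll(log_probs: List[float], target: int) -> float:
    """Per-sample negative log-likelihood, as nn.NLLLoss computes it (before averaging)."""
    return -log_probs[target]


def cross_entropy(logits: List[float], target: int) -> float:
    """Cross-entropy on raw scores = log_softmax followed by NLL."""
    return nll(log_softmax(logits), target)


logits = [2.0, 0.5, -1.0]
log_probs = log_softmax(logits)  # what the LogSoftmax layer outputs

# LogSoftmax + NLLLoss gives the usual cross-entropy of the raw scores.
print(math.isclose(nll(log_probs, 0), cross_entropy(logits, 0)))     # True
# Re-applying log_softmax to log-probabilities changes nothing, so cross-entropy
# computed on the head's outputs equals the NLL loss on those outputs.
print(math.isclose(cross_entropy(log_probs, 0), nll(log_probs, 0)))  # True
```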


In [11]:

# sub_classification = Classification(128, 70).to(device)
subtask_networks = {
    # 'parent_pred': sub_classification
}

main_classification = Classification(128, 10).to(device)
# self.main_autoencoder = AutoEncoder(128, self.dran).to(device)
maintask_networks = {
    'classification': main_classification,
}



### 4: Defining Artificial Association Networks (AAN)

In this section, we define the Artificial Association Networks (AAN) model, which integrates the models defined above.

In [12]:



aan_model = ArtificialAssociationNeuralNetworks(
    128, 128,
    fe_networks, {}, 
    subtask_networks, maintask_networks,    
    version='gaau').to(device)




In [13]:
# import wandb
# import random

# lr = 0.02
# model_name = "GAAU-128-128"
# dataset = "MNIST"
# epochs = 10

# # start a new wandb run to track this script
# wandb.init(
#     # set the wandb project where this run will be logged
#     project="artificial-association-networks-test",
    
#     # track hyperparameters and run metadata
#     config={
#     "learning_rate": lr,
#     "architecture": model_name,
#     "dataset": dataset,
#     "epochs": epochs,
#     }
# )

In [14]:
import pytorch_lightning as pl

In [15]:

class PylightningArtificialAssociationNeuralNetworks(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.training_result = 0
        self.training_count = 0
        self.training_loss = 0
        self.valid_result = 0
        self.valid_count = 0
        self.valid_loss = 0
        
        self.model = model


    def forward(self, data, mt):
        # Delegate to the wrapped AAN model.
        return self.model(data, mt, node_level=True)

    def training_step(self, batch, batch_idx):
        x, y, mt, domain = batch
        maintask_outputs, h_root, batchNeuroTree = self.model(x, mt)
        y_hat = torch.stack(maintask_outputs, dim = 0)
        y_target = torch.stack(y, dim = 0)

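        # F.cross_entropy applies log_softmax internally. Assuming y_hat holds the
        # Classification head's LogSoftmax outputs, re-applying log_softmax leaves them
        # unchanged, so this equals F.nll_loss(y_hat, y_target).
        loss = F.cross_entropy(y_hat, y_target)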
        
        predictions = torch.argmax(y_hat, 1)
        pred = torch.sum(predictions == y_target)
        self.training_result += pred.item()
        self.training_count += len(y)
        self.training_loss += loss.item()

        
        # logs metrics for each training_step,
        # and the average across the epoch, to the progress bar and logger
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def on_train_epoch_end(self):
        # Print the accuracy accumulated over this epoch, then reset the counters.

        if self.training_count != 0:
            print('[Training]', (self.training_result, self.training_count), self.training_result/self.training_count)
        if self.valid_count != 0:
            print('[Validation]', (self.valid_result, self.valid_count), self.valid_result/self.valid_count)
        # wandb.log(
        #     {
        #     "acc": self.training_result/self.training_count, 
        #     "val": self.valid_result/self.valid_count,
        #     "training_loss" : self.training_loss,
        #     "valid_loss" : self.valid_loss
        #     }
        # )
        self.training_result = 0
        self.training_count = 0
        self.training_loss = 0
        self.valid_result = 0
        self.valid_count = 0            
        self.valid_loss = 0
        
        return
            

    def validation_step(self, batch, batch_idx):
        x, y, mt, domain = batch
        maintask_outputs, h_root, batchNeuroTree = self.model(x, mt)
        y_hat = torch.stack(maintask_outputs, dim = 0)
        y_target = torch.stack(y, dim = 0)

        loss = F.cross_entropy(y_hat, y_target)
        predictions = torch.argmax(y_hat, 1)
        pred = torch.sum(predictions == y_target)
        self.valid_result += pred.item()
        self.valid_count += len(y)
        self.valid_loss += loss.item()

        return pred        
   
    
    
    def test_step(self, batch, batch_idx):
        x, y, mt, domain = batch
        maintask_outputs, h_root, batchNeuroTree = self.model(x, mt)
        y_hat = torch.stack(maintask_outputs, dim = 0)
        y_target = torch.stack(y, dim = 0)
        predictions = torch.argmax(y_hat, 1)
        pred = torch.sum(predictions == y_target)
        return pred
        

            

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

In [16]:
pyaan = PylightningArtificialAssociationNeuralNetworks(aan_model)

In [17]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# In a notebook, only the value of the cell's last expression is displayed.
count_parameters(main_classification)
count_parameters(aan_model.multi_main_task_networks), count_parameters(aan_model.multi_sub_task_networks)


count_parameters(aan_model.multi_feature_extraction_networks), count_parameters(aan_model.multi_restoration_networks)

count_parameters(aan_model.ran.rnn.X2H)

count_parameters(aan_model.ran.rnn.H2H)
count_parameters(aan_model.ran.rnn)
count_parameters(aan_model.ran.gnn)
count_parameters(aan_model.ran)
count_parameters(aan_model.multi_feature_extraction_networks)
count_parameters(pyaan)


134415
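The reported total can be cross-checked by hand for the parts whose shapes are visible in this notebook: a Conv2d layer has c_out·c_in·k² weights plus c_out biases, and a Linear layer has d_out·d_in weights plus d_out biases. Assuming `image2vec` and `main_classification` are each registered once inside `aan_model`, the remaining parameters belong to its other submodules (presumably mostly the association network `aan_model.ran`).

```python
def conv2d_params(c_in: int, c_out: int, k: int) -> int:
    """Conv2d with a k x k kernel: c_out * c_in * k * k weights + c_out biases."""
    return c_out * c_in * k * k + c_out


def linear_params(d_in: int, d_out: int) -> int:
    """Linear layer: d_out * d_in weights + d_out biases."""
    return d_out * d_in + d_out


lenet = (conv2d_params(1, 6, 5)        # conv1: 156
         + conv2d_params(6, 16, 5)     # conv2: 2,416
         + conv2d_params(16, 120, 4))  # conv3: 30,840
head = linear_params(128, 10)          # Classification(128, 10): 1,290 (LogSoftmax has none)
total_reported = 134415                # count_parameters(pyaan) above
rest = total_reported - lenet - head
print(lenet, head, rest)               # 33412 1290 99713
```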

In [18]:
trainer = pl.Trainer(max_epochs = 10)
trainer.fit(pyaan, train_dataloader, valid_dataloader)
# wandb.finish()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                                | Params
--------------------------------------------------------------
0 | model | ArtificialAssociationNeuralNetworks | 134 K 
--------------------------------------------------------------
134 K     Trainable params
0         Non-trainable params
134 K     Total params
0.538     Total estimated model params size (MB)


[Training] (42354, 50000) 0.84708
[Validation] (9209, 10200) 0.902843137254902

[Training] (46851, 50000) 0.93702
[Validation] (9438, 10000) 0.9438

[Training] (47670, 50000) 0.9534
[Validation] (9556, 10000) 0.9556

[Training] (48151, 50000) 0.96302
[Validation] (9626, 10000) 0.9626

[Training] (48441, 50000) 0.96882
[Validation] (9665, 10000) 0.9665

[Training] (48641, 50000) 0.97282
[Validation] (9700, 10000) 0.97

[Training] (48764, 50000) 0.97528
[Validation] (9726, 10000) 0.9726

[Training] (48878, 50000) 0.97756
[Validation] (9749, 10000) 0.9749

[Training] (48971, 50000) 0.97942
[Validation] (9762, 10000) 0.9762

`Trainer.fit` stopped: `max_epochs=10` reached.


[Training] (49038, 50000) 0.98076
[Validation] (9773, 10000) 0.9773
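In the first epoch the validation tally is (9209, 10200) rather than out of 10,000. A likely cause: before training starts, Lightning runs a sanity check over a few validation batches (`num_sanity_val_steps`, 2 by default), `validation_step` adds those 2 × 100 images (evaluated with the untrained weights) to `valid_result` and `valid_count`, and the counters are reset only in `on_train_epoch_end`. The toy class below is a hypothetical, plain-Python simulation of that bookkeeping, not Lightning code. In the real module, resetting the validation counters in an `on_validation_epoch_start` hook, skipping accumulation while `self.trainer.sanity_checking` is true, or passing `num_sanity_val_steps=0` to `pl.Trainer` should each avoid the extra 200 samples; this is an untested suggestion.

```python
class ValidationTally:
    """Toy model of the validation counters in the Lightning wrapper above."""

    def __init__(self, reset_at_validation_start: bool):
        self.reset_at_validation_start = reset_at_validation_start
        self.valid_count = 0

    def run_validation(self, n_batches: int, batch_size: int) -> None:
        if self.reset_at_validation_start:
            self.valid_count = 0            # e.g. in an on_validation_epoch_start hook
        for _ in range(n_batches):
            self.valid_count += batch_size  # what validation_step does per batch

    def end_train_epoch(self) -> int:
        seen, self.valid_count = self.valid_count, 0  # the notebook resets here
        return seen


def first_epoch_validation_count(reset_at_validation_start: bool) -> int:
    tally = ValidationTally(reset_at_validation_start)
    tally.run_validation(n_batches=2, batch_size=100)    # sanity check before training
    tally.run_validation(n_batches=100, batch_size=100)  # 10,000 validation images
    return tally.end_train_epoch()


print(first_epoch_validation_count(False))  # 10200, as in the log above
print(first_epoch_validation_count(True))   # 10000
```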
