In [None]:
import numpy as np
from Scripts.Configs.ConfigClass import Config
from Scripts.DataManager.GraphConstructor.GraphConstructor import TextGraphType
from lightning.pytorch.loggers import CSVLogger
import os
from Scripts.DataManager.GraphLoader.AmazonReviewGraphDataModule import AmazonReviewGraphDataModule


# Project configuration object, rooted at the local checkout.
# NOTE(review): this is an absolute, machine-specific path; consider sourcing
# it from an environment variable so the notebook runs on other machines.
config = Config(r'C:\Users\fardin\Projects\ColorIntelligence')

# Make CUDA kernel launches synchronous so device-side errors are reported
# at the call that caused them (debugging aid; slows training down).
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Run-wide settings shared by the cells below.
batch_size = 64
device = 'cuda'

In [None]:
# Graph type used for every review graph: dependency + tag + sequential flags combined.
tag_dep_seq = TextGraphType.DEPENDENCY | TextGraphType.TAGS | TextGraphType.SEQUENTIAL

# Keyword options for the data module, spelled out one per line for readability.
data_module_options = dict(
    shuffle=True,
    num_data_load=50000,
    device='cpu',
    batch_size=batch_size,
    graph_type=tag_dep_seq,
    load_preprocessed_data=True,
)

# NOTE(review): the meaning of the two positional ``True`` flags is defined by
# AmazonReviewGraphDataModule and is not visible from this notebook.
data_manager = AmazonReviewGraphDataModule(config, True, True, **data_module_options)

In [None]:
# Create the train / validation loaders (train first, then validation).
t_dataloader, v_dataloader = data_manager.train_dataloader(), data_manager.val_dataloader()

# Pull the first batch from each loader; X1/X2 are reused below to obtain
# graph metadata and to smoke-test the models.
(X1, y1), (X2, y2) = (next(iter(loader)) for loader in (t_dataloader, v_dataloader))

In [None]:
from torch import nn, Tensor
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, BatchNorm, SAGEConv
class HeteroGCNConv(nn.Module):
    """One GATv2 message-passing block: conv -> batch norm -> leaky ReLU -> dropout.

    The block is written for homogeneous inputs and is meant to be wrapped
    with ``to_hetero`` by its callers.

    Args:
        in_feature: Size of each input node feature vector.
        out_feature: Size of each output node feature vector. Must be a
            multiple of the number of attention heads (4), because the head
            outputs are concatenated to form the result.
        dropout: Dropout probability applied after the activation.

    Raises:
        ValueError: If ``out_feature`` is not divisible by the number of heads.
    """

    # Number of attention heads; their concatenated outputs make up ``out_feature``.
    NUM_HEADS = 4

    def __init__(self, in_feature, out_feature, dropout = 0.0) -> None:
        super().__init__()
        if out_feature % self.NUM_HEADS != 0:
            raise ValueError(
                f"out_feature ({out_feature}) must be divisible by {self.NUM_HEADS} attention heads"
            )
        # Integer division: GATv2Conv needs an int channel count (true division
        # would hand it a float and fail when building its linear layers).
        channels_per_head = out_feature // self.NUM_HEADS
        self.conv1 = GATv2Conv(in_feature, channels_per_head, heads=self.NUM_HEADS,
                               edge_dim=1, add_self_loops=False)
        # Normalise over the concatenated head outputs, i.e. ``out_feature``
        # channels, rather than a hard-coded width.
        self.batch_norm = BatchNorm(out_feature)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, edge_index: Tensor, edge_weights: Tensor) -> Tensor:
        """Apply the block to node features ``x`` over ``edge_index``.

        NOTE(review): ``edge_weights`` is accepted but not forwarded to the
        convolution even though it is built with ``edge_dim=1``. This is kept
        as in the original; confirm whether the edge attributes should be
        passed as ``edge_attr``.
        """
        x = self.conv1(x, edge_index)
        x = self.batch_norm(x)
        x = F.leaky_relu(x)
        return self.dropout(x)

In [None]:
from torch_geometric.nn import to_hetero
# Smoke test: lift a single HeteroGCNConv block (300 -> 256 features, dropout 0.2)
# to the heterogeneous schema of the validation batch and run it once.
single_block = HeteroGCNConv(300, 256, 0.2)
hetero_model = to_hetero(single_block, X2.metadata())
pre = hetero_model(X2.x_dict, X2.edge_index_dict, X2.edge_attr_dict)

In [None]:
from torch_geometric.nn import SAGEConv, to_hetero
import torch.nn.functional as F
from torch import Tensor
import torch
from typing import Dict
import torch_geometric
from torch.nn import Linear
from torch_geometric.nn import GATv2Conv, GCNConv, GCN2Conv, DenseGCNConv, dense_diff_pool, BatchNorm, global_mean_pool, global_add_pool, global_max_pool, MemPooling, SAGEConv, to_hetero, HeteroBatchNorm
from torch_geometric.nn import Sequential as GSequential
from torch_geometric.utils import to_dense_adj
from torch import nn, Tensor
import torch.nn.functional as F
from torch_geometric.data import HeteroData

class HeteroGcnGatModel1(torch.nn.Module):
    """Heterogeneous graph classifier: stacked hetero GAT blocks, memory pooling, linear head.

    The encoder is five ``HeteroGCNConv`` blocks, each lifted with
    ``to_hetero``: two at ``base_hidden_feature`` width followed by three at
    half that width. The "word" node embeddings of the last stage are pooled
    per graph with ``MemPooling`` and mapped to the outputs by a linear layer.

    Args:
        input_feature: Size of the input node feature vectors.
        out_features: Number of values produced per graph.
        metadata: Graph metadata handed to ``to_hetero`` (callers in this
            notebook pass ``X1.metadata()``).
        base_hidden_feature: Width of the first two encoder stages.
        dropout: Dropout probability forwarded to every encoder block.
    """

    def __init__(self,
                 input_feature: int, out_features: int,
                 metadata,
                 base_hidden_feature: int=256,
                 dropout=0.1):

        super().__init__()
        self.input_features = input_feature
        self.num_out_features = out_features
        self.bsh: int = base_hidden_feature
        bsh2: int = int(self.bsh) // 2

        # MemPooling configuration: attention heads and clusters kept per graph.
        pool_heads = 4
        num_clusters = 2

        # The trailing lambda only selects which intermediate results the
        # sequential container returns (outputs of stage 1 and stage 2).
        self.encoder = GSequential('x_dict, edge_index_dict, edge_weights_dict', [
            (to_hetero(HeteroGCNConv(input_feature, self.bsh, dropout), metadata), 'x_dict, edge_index_dict, edge_weights_dict ->x1'),
            (to_hetero(HeteroGCNConv(self.bsh, self.bsh, dropout), metadata), 'x1, edge_index_dict, edge_weights_dict ->x1'),
            (to_hetero(HeteroGCNConv(self.bsh, bsh2, dropout), metadata), 'x1, edge_index_dict, edge_weights_dict -> x2'),
            (to_hetero(HeteroGCNConv(bsh2, bsh2, dropout), metadata), 'x2, edge_index_dict, edge_weights_dict -> x2'),
            (to_hetero(HeteroGCNConv(bsh2, bsh2, dropout), metadata), 'x2, edge_index_dict, edge_weights_dict -> x2'),
            (lambda x1, x2: (x1, x2), 'x1, x2 -> x1, x2'),
        ])

        self.mem_pool = MemPooling(bsh2, bsh2, pool_heads, num_clusters)
        # MemPooling yields one ``bsh2``-wide vector per cluster and forward()
        # flattens them, so the head must accept ``num_clusters * bsh2`` inputs.
        self.output_layer = Linear(num_clusters * bsh2, self.num_out_features)

    def forward(self, x: HeteroData) -> Tensor:
        """Return one row of ``out_features`` values per graph in the batch ``x``."""
        # Only the last encoder stage feeds the pooling layer.
        _, x2_dict = self.encoder(x.x_dict, x.edge_index_dict, x.edge_attr_dict)

        # Pool "word" nodes per graph; the assignment matrix is not needed here.
        x_pooled, _ = self.mem_pool(x2_dict["word"], x['word'].batch)
        # Flatten the per-cluster vectors into a single feature vector per graph.
        x_pooled = x_pooled.flatten(start_dim=1)
        return self.output_layer(x_pooled)

In [None]:
# Build the classifier (300-d inputs, 1 output, 256 hidden width) on the
# training batch's metadata and move it to the target device in one step.
torch_model = HeteroGcnGatModel1(300, 1, X1.metadata(), 256, dropout=0.2).to(device)

# Sanity check: show where the parameters live, then a layer-by-layer summary
# obtained by running the model once on the first training batch.
print(next(torch_model.parameters()).device)
print(torch_geometric.nn.summary(torch_model, X1.to(device)))

In [None]:
from Scripts.Models.LightningModels.LightningModels import BinaryLightningModel
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
import lightning as L
from lightning.pytorch.tuner import Tuner

In [None]:
# Checkpointing / early stopping, both driven by validation accuracy.
# NOTE(review): patience (50) equals max_epochs (50) below, so with one
# validation check per epoch early stopping presumably never fires; confirm.
checkpoint_cb = ModelCheckpoint(save_top_k=2, mode='max', monitor='val_acc', save_last=True)
early_stop_cb = EarlyStopping(patience=50, mode='max', monitor='val_acc')
callbacks = [checkpoint_cb, early_stop_cb]

# Optimiser and loss handed to the Lightning wrapper.
optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.01, weight_decay=0.00055)
loss_fn = torch.nn.BCEWithLogitsLoss()

lightning_model = BinaryLightningModel(
    torch_model,
    optimizer,
    loss_fn,
    learning_rate=0.01,
    batch_size=batch_size,
).to(device)

# GPU trainer logging to CSV under logs/GcnGatSentiment3; sanity validation is skipped.
csv_logger = CSVLogger(save_dir='logs/', name='GcnGatSentiment3')
trainer = L.Trainer(
    callbacks=callbacks,
    max_epochs=50,
    accelerator='gpu',
    logger=csv_logger,
    num_sanity_val_steps=0,
)

In [None]:
# Run Lightning's learning-rate finder between 1e-6 and 0.1 and plot the
# loss curve with the suggested learning rate marked.
tuner = Tuner(trainer)
results = tuner.lr_find(
    lightning_model,
    datamodule=data_manager,
    min_lr=1e-6,
    max_lr=0.1,
)
fig = results.plot(suggest=True)

In [None]:
from Scripts.Models.ModelsManager.ClassifierModelManager import ClassifierModelManager

In [None]:
# Hand both the raw torch model and its Lightning wrapper to the project's
# ClassifierModelManager, which the following cells use for tuning and fitting.
model_manager = ClassifierModelManager(torch_model, lightning_model)

In [None]:
# Run the manager's tune step on the data module and keep its return value.
# NOTE(review): presumably a learning-rate search like the Tuner cell above; confirm.
suggested_lr = model_manager.tune(data_manager=data_manager)

In [None]:
# Train for up to 20 epochs on the data module.
# NOTE(review): ckpt_path='best' is forwarded to the manager; presumably it
# resumes from the best saved checkpoint, so one must already exist; confirm.
model_manager.fit(datamodule=data_manager, max_epochs=20, ckpt_path='best')

In [None]:
# Fit again using the manager's default settings.
# NOTE(review): when run after the cell above this presumably continues from
# the already-trained state rather than from scratch; confirm.
model_manager.fit(datamodule=data_manager)