#### Fardin Rastakhiz @2023

In [1]:
from Scripts.Configs.ConfigClass import Config
from Scripts.DataManager.GraphConstructor.GraphConstructor import TextGraphType
import os
from Scripts.DataManager.GraphLoader.AGGraphDataModule import AGGraphDataModule
import torch
from torch.utils.flop_counter import FlopCounterMode

config = Config(r'C:\Users\fardin\Projects\ColorIntelligence')
# config = Config(r'E:\Darsi\Payan Name Arshad\Second Work\ColorIntelligence2\ColorIntelligence')
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 128




In [2]:
from Scripts.Models.GraphEmbedding.HeteroDeepGraphEmbedding4 import HeteroDeepGraphEmbedding4
from Scripts.Models.GraphEmbedding.HeteroDeepGraphEmbedding5 import HeteroDeepGraphEmbedding5
from Scripts.Models.GraphEmbedding.HeteroDeepGraphEmbedding6 import HeteroDeepGraphEmbedding6

In [3]:
from Scripts.Models.LightningModels.LightningModels import HeteroMultiClassLightningModel
from Scripts.Models.LossFunctions.HeteroLossFunctions import MulticlassHeteroLoss1, MulticlassHeteroLoss2, MulticlassHeteroLoss3
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
import lightning as L
from lightning.pytorch.tuner import Tuner
from Scripts.Models.ModelsManager.ClassifierModelManager import ClassifierModelManager

In [4]:

import torch.nn.functional as F
from torch import Tensor
import torch
from torch.nn import Linear
from torch_geometric.nn import BatchNorm, MemPooling, to_hetero, PairNorm
from torch_geometric.data import HeteroData
from Scripts.Models.BaseModels.HeteroGat import HeteroGat
from Scripts.Models.BaseModels.HeteroLinear import HeteroLinear

class HeteroDeepGraphEmbedding6(torch.nn.Module):
    
    def __init__(self,
                 input_feature: int, out_features: int,
                 metadata,
                 hidden_feature: int=256,
                 device = 'cpu',
                 dropout=0.1,
                 edge_type_count=9,
                 edge_type_weights=-1,
                 active_keys = ['dep', 'tag', 'word', 'sentence', 'general', 'sentiment'],
                 num_pooling_classes=1
                 ):

        super(HeteroDeepGraphEmbedding6, self).__init__()
        self.input_features = input_feature
        self.num_out_features = out_features
        self.hidden_feature: int = hidden_feature
        self.edge_type_count = edge_type_count
        self.edge_type_weights = torch.nn.Parameter(torch.tensor([1]* self.edge_type_count if edge_type_weights==-1 else  edge_type_weights).to(torch.float32), requires_grad=False)

        self.part_weight_norm = torch.nn.LayerNorm((self.edge_type_count,))
        self.norm = PairNorm()
        self.drop = torch.nn.Dropout(0.2)
        self.active_keys = active_keys
        
        self.hetero_linear1 = to_hetero(HeteroLinear(self.input_features,self.hidden_feature, use_dropout=False, use_batch_norm=True), metadata)
        
        self.hetero_gat_1 = to_hetero(HeteroGat(self.hidden_feature, self.hidden_feature, dropout, num_heads=2), metadata)
        self.hetero_gat_2 = to_hetero(HeteroGat(self.hidden_feature, self.hidden_feature, dropout, num_heads=2), metadata)
        
        self.hetero_linear_2 = to_hetero(HeteroLinear(self.hidden_feature, self.input_features, dropout, use_batch_norm=True), metadata)
        
        self.num_pooling_classes = torch.nn.Parameter(torch.tensor(num_pooling_classes).to(torch.int32), requires_grad=False)
        self.mem_pool = MemPooling(self.hidden_feature, self.hidden_feature, 2, self.num_pooling_classes)
        
        self.linear_1 = Linear(self.hidden_feature* self.num_pooling_classes, self.hidden_feature)
        self.linear_2 = Linear(self.hidden_feature, self.hidden_feature)
        self.batch_norm_1 = BatchNorm(self.hidden_feature)
        
        self.output_layer = Linear(self.hidden_feature, self.num_out_features)
        
        self.dep_embedding = torch.nn.Embedding(45, self.input_features)
        self.tag_embedding = torch.nn.Embedding(50, self.input_features)
        self.dep_unembedding = torch.nn.Linear(self.hidden_feature, 45)
        self.tag_unembedding = torch.nn.Linear(self.hidden_feature, 50)
        
        self.pw1 = torch.nn.Parameter(torch.tensor(self.edge_type_weights, dtype=torch.float32), requires_grad=False)
        
        self.x_batches = None
        self.x_batches_cpu = None
        self.x_dict_cpu_1 = None
        self.x_dict_cpu_2 = None
        

    def forward(self, x: HeteroData) -> Tensor:
        self.x_batches = {k:x[k].batch for k in self.active_keys}
        # self.x_batches_cpu = {k:self.x_batches[k].to('cpu') for k in self.active_keys}
        x_dict, edge_attr_dict, edge_index_dict = self.preprocess_data(x)
        edge_attr_dict = self.update_weights(edge_attr_dict, self.pw1)
        x_dict = self.hetero_linear1(x_dict)
        x_dict = self.hetero_gat_1(x_dict, edge_index_dict, edge_attr_dict)
        # self.x_dict_cpu_1 = {k: x_dict[k].to('cpu') for k in x_dict}
        self.normalize(x_dict, self.x_batches)
        x_dict = self.hetero_gat_2(x_dict, edge_index_dict, edge_attr_dict)
        # self.x_dict_cpu_2 = {k: x_dict[k].to('cpu') for k in x_dict}
        x_pooled, S = self.mem_pool(x_dict['word'], self.x_batches['word'])
                
        x_pooled = x_pooled.view(x_pooled.shape[0], -1)
        x_pooled = F.relu(self.linear_1(x_pooled))
        x_pooled = F.relu(self.batch_norm_1(self.linear_2(x_pooled)))
        out = self.output_layer(x_pooled)
        
        x_dict_out = self.hetero_linear_2(x_dict)
        x_dict_out['dep'] = self.dep_unembedding(x_dict['dep'])
        x_dict_out['tag'] = self.tag_unembedding(x_dict['tag'])
        
        return out, x_dict_out

    def preprocess_data(self, x):
        x_dict = {key: x.x_dict[key] for key in x.x_dict}
        if 'dep' in x_dict:
            x_dict['dep'] = self.dep_embedding(x_dict['dep'])
        if 'tag' in x_dict:
            x_dict['tag'] = self.tag_embedding(x_dict['tag'])

        edge_attr_dict = x.edge_attr_dict
        edge_index_dict = x.edge_index_dict

        return x_dict, edge_attr_dict, edge_index_dict

    def normalize(self, x_dict, x_batches):
        for k in self.active_keys:
            vecs = x_dict[k]
            if k not in x_batches:
                print('k is not in x_batches')
                continue
            batches = x_batches[k]
            if batches is None:
                print('batches is none')
                continue
            if len(batches) == 0:
                print('batches is empty')
                continue
            
            x_dict[k] = self.norm(vecs, batches)
        return x_dict

    def update_weights(self, edge_attr_dict, part_weights):
        for i, key in enumerate(edge_attr_dict):
            edge_attr = edge_attr_dict[key]
            if edge_attr is None or edge_attr == ('word', 'seq', 'word'):
                continue
            edge_attr_dict[key]= edge_attr * part_weights[i]
        return edge_attr_dict

    def get_scale_same(self, scale:float, attributes: Tensor):
        if attributes is None or len(attributes) == 0:
            return
        attributes = scale * torch.ones_like(attributes)
        return attributes

In [5]:
import pandas as pd
test_df = pd.read_csv(r"data\AG\test.csv", header=None)
g_test_df = test_df.groupby(0)

In [6]:
mega_texts = {}
for g in g_test_df:
    mega_text = ''
    for i in range(g[1].shape[0]):
        mega_text += g[1].iloc[i][1] + " " + g[1].iloc[i][2] + ". \n"
    print(g[0], g[1].iloc[0][2])
    mega_texts[g[0]] = mega_text

1 Canadian Press - VANCOUVER (CP) - The sister of a man who died after a violent confrontation with police has demanded the city's chief constable resign for defending the officer involved.
2 Michael Phelps won the gold medal in the 400 individual medley and set a world record in a time of 4 minutes 8.26 seconds.
3 Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.
4 SPACE.com - TORONTO, Canada -- A second\team of rocketeers competing for the  #36;10 million Ansari X Prize, a contest for\privately funded suborbital space flight, has officially announced the first\launch date for its manned rocket.


In [7]:
from Scripts.DataManager.GraphConstructor.SentimentGraphConstructor import SentimentGraphConstructor


sgc = SentimentGraphConstructor(test_df[1][:10], 'sentiment', config, load_preprocessed_data=False, naming_prepend='graph', start_data_load=0, end_data_load=4, use_sentence_nodes=True , use_general_node=True)

In [8]:
docs = {}
for k in [1, 2]:
    docs[k] = sgc.nlp(mega_texts[k])

In [15]:
len(docs[2])

85280

In [48]:
mega_graphs = {}
for k in [1, 2, 3, 4]:
    mega_graphs[k] = sgc.to_graph(mega_texts[k])

In [44]:
meta_data = mega_graphs[3].metadata()

In [25]:
edge_type_weights = [1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.]
graph_embedding = HeteroDeepGraphEmbedding6(300, 4, meta_data, 32, dropout=0.2, edge_type_count=11, edge_type_weights=edge_type_weights)
graph_embedding = graph_embedding.to(device)
lightning_model = HeteroMultiClassLightningModel.load_from_checkpoint(r'logs\hetero_model_18_AG\version_26\checkpoints\epoch=74-step=35100.ckpt', model=graph_embedding, num_classes=4)
lightning_model.eval()

  self.pw1 = torch.nn.Parameter(torch.tensor(self.edge_type_weights, dtype=torch.float32), requires_grad=False)


HeteroMultiClassLightningModel(
  (model): HeteroDeepGraphEmbedding6(
    (part_weight_norm): LayerNorm((11,), eps=1e-05, elementwise_affine=True)
    (norm): PairNorm()
    (drop): Dropout(p=0.2, inplace=False)
    (hetero_linear1): GraphModule(
      (linear): ModuleDict(
        (dep): Linear(in_features=300, out_features=32, bias=True)
        (tag): Linear(in_features=300, out_features=32, bias=True)
        (word): Linear(in_features=300, out_features=32, bias=True)
        (sentence): Linear(in_features=300, out_features=32, bias=True)
        (general): Linear(in_features=300, out_features=32, bias=True)
        (sentiment): Linear(in_features=300, out_features=32, bias=True)
      )
      (batch_norm): ModuleDict(
        (dep): BatchNorm(32)
        (tag): BatchNorm(32)
        (word): BatchNorm(32)
        (sentence): BatchNorm(32)
        (general): BatchNorm(32)
        (sentiment): BatchNorm(32)
      )
      (dropout): Dropout(p=0.2, inplace=False)
    )
    (hetero_gat_

In [33]:
import torch_geometric

In [49]:
mega_graphs

{1: HeteroData(
   dep={
     length=45,
     x=[45],
   },
   tag={
     length=50,
     x=[50],
   },
   word={ x=[85856, 300] },
   sentence={ x=[1957, 300] },
   general={ x=[1, 300] },
   sentiment={ x=[2, 300] },
   (dep, dep_word, word)={
     edge_index=[2, 83899],
     edge_attr=[83899],
   },
   (word, word_dep, dep)={
     edge_index=[2, 83899],
     edge_attr=[83899],
   },
   (tag, tag_word, word)={
     edge_index=[2, 85856],
     edge_attr=[85856],
   },
   (word, word_tag, tag)={
     edge_index=[2, 85856],
     edge_attr=[85856],
   },
   (word, seq, word)={
     edge_index=[2, 171710],
     edge_attr=[171710],
   },
   (general, general_sentence, sentence)={
     edge_index=[2, 1957],
     edge_attr=[1957],
   },
   (sentence, sentence_general, general)={
     edge_index=[2, 1957],
     edge_attr=[1957],
   },
   (word, word_sentence, sentence)={
     edge_index=[2, 85856],
     edge_attr=[85856],
   },
   (sentence, sentence_word, word)={
     edge_index=[2, 85856],


In [57]:
mega_X = torch_geometric.data.Batch.from_data_list([mega_graphs[i].to(device) for i in [1, 2, 3, 4]])

In [58]:
with torch.no_grad():
    y_pred = lightning_model(mega_X)
    y_pred = y_pred[0]
    print(torch.argmax(y_pred, dim=1))

tensor([3, 1, 2, 3], device='cuda:0')


In [42]:
with torch.no_grad():
    y_pred = lightning_model(mega_X)
    y_pred = y_pred[0]
    print(torch.argmax(y_pred, dim=1))


tensor([1, 2], device='cuda:0')


In [29]:
with torch.no_grad():
    y_pred = lightning_model(mega_X)
    y_pred = y_pred[0]
    print(torch.argmax(y_pred, dim=1))


tensor([3, 1], device='cuda:0')
