In [22]:
import numpy as np
# from Scripts.DataManager.DabasePreparations.AmazonReviewSentiGraph import AmazonReviewSentiGraph
# from Scripts.Models.ModelsManager.SimpleGraphClassifierModelManager import SimpleGraphClassifierModelManager
from Scripts.Configs.ConfigClass import Config
config = Config(r'C:\Users\fardin\Projects\ColorIntelligence')
from Scripts.DataManager.GraphConstructor.GraphConstructor import TextGraphType
from lightning.pytorch.loggers import CSVLogger
import os
# os.environ['TORCH_USE_CUDA_DSA']
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
device = 'cuda'
batch_size = 16

In [23]:
import pickle
from typing import List, Dict, Tuple
import networkx as nx
import pandas as pd
from torch_geometric.utils import to_networkx
from Scripts.DataManager.GraphConstructor.GraphConstructor import GraphConstructor
from torch_geometric.data import Data , HeteroData
from Scripts.Configs.ConfigClass import Config
import spacy
import torch
import numpy as np
import os

class TagDepTokenGraphConstructor(GraphConstructor):

    class _Variables(GraphConstructor._Variables):
        def __init__(self):
            super(TagDepTokenGraphConstructor._Variables, self).__init__()
            self.nlp_pipeline: str = ''
    def __init__(self, texts: List[str], save_path: str, config: Config,
                 lazy_construction=True, load_preprocessed_data=False, naming_prepend='' , use_compression=True, use_sentence_nodes=False, use_general_node=True):

        super(TagDepTokenGraphConstructor, self)\
            .__init__(texts, self._Variables(), save_path, config, lazy_construction, load_preprocessed_data,
                      naming_prepend , use_compression)
        self.settings = {"dep_token_weight" : 1, "token_token_weight" : 2, "tag_token_weight" : 1, "general_token_weight" : 1, "general_sentence_weight" : 1, "sentence_token_weight" : 1}
        self.use_sentence_nodes = use_sentence_nodes
        self.use_general_node = use_general_node
        self.var.nlp_pipeline = self.config.spacy.pipeline
        self.var.graph_num = len(self.raw_data)
        self.nlp = spacy.load(self.var.nlp_pipeline)
        self.dependencies = self.nlp.get_pipe("parser").labels
        self.tags = self.nlp.get_pipe("tagger").labels
            
    def setup(self, load_preprocessed_data = True):
        self.load_preprocessed_data = True
        if load_preprocessed_data:
            self.load_var()
            self.num_data_load = self.var.graph_num if self.num_data_load > self.var.graph_num else self.num_data_load
            if not self.lazy_construction:
                for i in range(self.num_data_load):
                    if i%100 == 0:
                        print(f' {i} graph loaded')
                    if (self._graphs[i] is None) and (i in self.var.graphs_name):
                        self.load_data_compressed(i)
                    else:
                        self._graphs[i] = self.to_graph(self.raw_data[i])
                        self.var.graphs_name[i] = f'{self.naming_prepend}_{i}'
                        self.save_data_compressed(i)
                self.var.save_to_file(os.path.join(self.save_path, f'{self.naming_prepend}_var.txt'))
        else:
            if not self.lazy_construction:
                save_start = 0
                self.num_data_load = len(self.raw_data) if self.num_data_load > len(self.raw_data) else self.num_data_load
                for i in range(self.num_data_load):
                    if i not in self._graphs:
                        if i % 100 == 0:
                            self.save_data_range(save_start, i)
                            save_start = i
                            print(f'i: {i}')
                        self._graphs[i] = self.to_graph(self.raw_data[i])
                        self.var.graphs_name[i] = f'{self.naming_prepend}_{i}'
                self.save_data_range(save_start, self.num_data_load)
            self.var.save_to_file(os.path.join(self.save_path, f'{self.naming_prepend}_var.txt'))


    def to_graph(self, text: str):
        doc = self.nlp(text)
        if len(doc) < 2:
            return
        if self.use_sentence_nodes:
            self.__create_graph_with_sentences(doc)
        else:
            self.__create_graph(doc,use_general_node=self.use_general_node)
    
    def __find_dep_index(self , dependency : str):
        for dep_idx in range(len(self.dependencies)):
            if self.dependencies[dep_idx] == dependency:
                return dep_idx
        return -1 # means not found
    def __build_initial_dependency_vectors(self , dep_length : int):
        return torch.zeros((dep_length, self.nlp.vocab.vectors_length), dtype=torch.float32)
    def __find_tag_index(self , tag : str):
        for tag_idx in range(len(self.tags)):
            if self.tags[tag_idx] == tag:
                return tag_idx
        return -1 # means not found
    def __build_initial_tag_vectors(self , tags_length : int):
        return torch.zeros((tags_length, self.nlp.vocab.vectors_length), dtype=torch.float32)
    def __build_initial_general_vector(self):
        return torch.zeros((1 , self.nlp.vocab.vectors_length), dtype=torch.float32)    
    def __create_graph_with_sentences(self , doc , for_compression=False):
        graph = self.create_graph(doc,for_compression,False)
        # TODO : add sentence nodes here
        pass
    def __create_graph(self , doc , for_compression=False, use_general_node=True):
        # nodes size is dependencies + tokens
        data = HeteroData()
        dep_length = len(self.dependencies)
        tag_length = len(self.tags)
        if for_compression:
            data['dep'].x = torch.full((dep_length,),-1, dtype=torch.float32)
            data['word'].x = torch.full((len(doc),),-1, dtype=torch.float32)
            data['tag'].x = torch.full((tag_length,), -1, dtype=torch.float32)
            if use_general_node:
                data['general'].x = torch.full((1,),-1, dtype=torch.float32)
        else:
            data['dep'].x = self.__build_initial_dependency_vectors(dep_length)
            data['word'].x = torch.zeros((len(doc) , self.nlp.vocab.vectors_length), dtype=torch.float32)
            data['tag'].x = self.__build_initial_tag_vectors(tag_length)
            if use_general_node:
                data['general'].x = self.__build_initial_general_vector()
        word_dep_edge_index = []
        dep_word_edge_index = []
        word_tag_edge_index = []
        tag_word_edge_index = []
        word_word_edge_index = []
        word_general_edge_index = []
        general_word_edge_index = []
        word_dep_edge_attr = []
        dep_word_edge_attr = []
        word_tag_edge_attr = []
        tag_word_edge_attr = []
        word_word_edge_attr = []
        word_general_edge_attr = []
        general_word_edge_attr = []
        for token in doc:
            token_id = self.nlp.vocab.strings[token.lemma_]
            if token_id in self.nlp.vocab.vectors:
                if for_compression:
                    data['word'].x[token.i] = torch.tensor(token_id , dtype=torch.float32)
                else:
                    data['word'].x[token.i] = torch.tensor(self.nlp.vocab.vectors[token_id])
            # adding dependency edges
            if token.dep_ != 'ROOT':
                dep_idx = self.__find_dep_index(token.dep_)
                if dep_idx != -1:
                    word_dep_edge_index.append([token.head.i , dep_idx])
                    word_dep_edge_attr.append(self.settings["dep_token_weight"])
                    dep_word_edge_index.append([dep_idx , token.i])
                    dep_word_edge_attr.append(self.settings["dep_token_weight"])
            # adding tag edges
            tag_idx = self.__find_tag_index(token.tag_)
            if tag_idx != -1:
                word_tag_edge_index.append([token.i , tag_idx])
                word_tag_edge_attr.append(self.settings["tag_token_weight"])
                tag_word_edge_index.append([tag_idx , token.i])
                tag_word_edge_attr.append(self.settings["tag_token_weight"])
            # adding sequence edges
            if token.i != len(doc) - 1:
                # using zero vectors for edge features
                word_word_edge_index.append([token.i , token.i + 1])
                word_word_edge_attr.append(self.settings["token_token_weight"])
                word_word_edge_index.append([token.i + 1 , token.i])
                word_word_edge_attr.append(self.settings["token_token_weight"])
            # adding general node edges
            if use_general_node:
                word_general_edge_index.append([token.i , 0])
                word_general_edge_attr.append(self.settings["general_token_weight"])
                general_word_edge_index.append([0 , token.i])
                general_word_edge_attr.append(self.settings["general_token_weight"])
        data['dep' , 'dep_word' , 'word'].edge_index = torch.transpose(torch.tensor(dep_word_edge_index, dtype=torch.long) , 0 , 1)
        data['word' , 'word_dep' , 'dep'].edge_index = torch.transpose(torch.tensor(word_dep_edge_index, dtype=torch.long) , 0 , 1)
        data['tag', 'tag_word', 'word'].edge_index = torch.transpose(torch.tensor(tag_word_edge_index, dtype=torch.long) , 0 , 1)
        data['word', 'word_tag', 'tag'].edge_index = torch.transpose(torch.tensor(word_tag_edge_index, dtype=torch.long) , 0 , 1)
        data['word' , 'seq' , 'word'].edge_index = torch.transpose(torch.tensor(word_word_edge_index, dtype=torch.long) , 0 , 1)
        data['dep' , 'dep_word' , 'word'].edge_attr = dep_word_edge_attr
        data['word' , 'word_dep' , 'dep'].edge_attr = word_dep_edge_attr
        data['tag', 'tag_word', 'word'].edge_attr = tag_word_edge_attr
        data['word', 'word_tag', 'tag'].edge_attr = word_tag_edge_attr
        data['word' , 'seq' , 'word'].edge_attr = word_word_edge_attr
        if use_general_node:
            data['general' , 'general_word' , 'word'].edge_index = torch.transpose(torch.tensor(general_word_edge_index, dtype=torch.long) , 0 , 1)
            data['word' , 'word_general' , 'general'].edge_index = torch.transpose(torch.tensor(word_general_edge_index, dtype=torch.long) , 0 , 1)
            data['general' , 'general_word' , 'word'].edge_attr = general_word_edge_attr
            data['word' , 'word_general' , 'general'].edge_attr = word_general_edge_attr
        return data
    def draw_graph(self , idx : int):
        # TODO : do this part if needed
        pass
    def to_graph_indexed(self, text: str):
        doc = self.nlp(text)
        if len(doc) < 2:
            return
        if self.use_sentence_nodes:
            return self.__create_graph_with_sentences(doc , for_compression=True)
        else:
            return self.__create_graph(doc,for_compression=True,use_general_node=self.use_general_node)
    def convert_indexed_nodes_to_vector_nodes(self, graph):
        print(graph)
        if self.use_sentence_nodes:
            # TODO : complete this part with sentences later
            pass
        else:
            words = torch.zeros((len(graph['word'].x) , self.nlp.vocab.vectors_length), dtype=torch.float32)
            for i in range(len(graph['word'].x)):
                if graph['word'].x[i] in self.nlp.vocab.vectors:
                    words[i] = torch.tensor(self.nlp.vocab.vectors[graph['word'].x[i]])
                else:
                    words[i] = torch.zeros((self.nlp.vocab.vectors_length) , dtype=torch.float32)
            graph['word'].x = words
            graph['dep'].x = self.__build_initial_dependency_vectors(len(self.dependencies))
            graph['tag'].x = self.__build_initial_tag_vectors(len(self.tags))
            if self.use_general_node:
                graph['general'].x = self.__build_initial_general_vector()
        return graph
 

In [24]:
from Scripts.DataManager.GraphConstructor.GraphConstructor import TextGraphType
test_type = TextGraphType.CO_OCCURRENCE | TextGraphType.DEPENDENCY | TextGraphType.SEQUENTIAL
print((TextGraphType.CO_OCCURRENCE in test_type), (TextGraphType.DEPENDENCY in test_type), (TextGraphType.SEQUENTIAL in test_type), (TextGraphType.TAGS in test_type))
test_type = test_type - (TextGraphType.TAGS | TextGraphType.CO_OCCURRENCE)
print((TextGraphType.CO_OCCURRENCE in test_type), (TextGraphType.DEPENDENCY in test_type), (TextGraphType.SEQUENTIAL in test_type), (TextGraphType.TAGS in test_type))

True True True False
False True True False


In [28]:
from copy import copy
import numpy as np
from os import path
from typing import Dict

import pandas as pd
from torch_geometric.loader import DataLoader

from Scripts.Configs.ConfigClass import Config

from Scripts.DataManager.GraphConstructor.CoOccurrenceGraphConstructor import CoOccurrenceGraphConstructor
from Scripts.DataManager.GraphConstructor.TagDepTokenGraphConstructor import TagDepTokenGraphConstructor
from Scripts.DataManager.GraphConstructor.DependencyGraphConstructor import DependencyGraphConstructor
from Scripts.DataManager.GraphConstructor.TagsGraphConstructor import TagsGraphConstructor

from Scripts.DataManager.GraphConstructor.GraphConstructor import GraphConstructor, TextGraphType
from Scripts.DataManager.GraphLoader.GraphDataModule import GraphDataModule
from torch.utils.data.dataset import random_split
import torch
from Scripts.DataManager.Datasets.GraphConstructorDataset import GraphConstructorDataset

class AmazonReviewGraphDataModule(GraphDataModule):

    def __init__(self, config: Config, has_val: bool, has_test: bool, test_size=0.2, val_size=0.2, num_workers=2,
                 drop_last=True, train_data_path='', test_data_path='', graphs_path='', batch_size = 32,
                 device='cpu', shuffle = False, num_data_load=-1,
                 graph_type: TextGraphType = TextGraphType.CO_OCCURRENCE | TextGraphType.DEPENDENCY | TextGraphType.TAGS, *args, **kwargs):
        # kwargs['num_workers'] = num_workers
        # kwargs['batch_size'] = batch_size
        # kwargs['shuffle'] = shuffle
        super(AmazonReviewGraphDataModule, self)\
            .__init__(config, device, has_val, has_test, test_size, val_size, *args, **kwargs)

        self.device = device
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.drop_last = drop_last
        self.graph_type = graph_type
        self.train_data_path = 'data/Amazon-Review/train_sm.csv' if train_data_path == '' else train_data_path
        self.test_data_path = 'data/Amazon-Review/test_sm.csv' if test_data_path == '' else test_data_path
        self.train_df: pd.DataFrame = pd.DataFrame()
        self.test_df: pd.DataFrame = pd.DataFrame()
        self.labels = None
        self.dataset = None
        self.shuffle = shuffle
        self.num_node_features = 0
        self.num_classes = 0
        self.df: pd.DataFrame = pd.DataFrame()
        self.__train_dataset, self.__val_dataset, self.__test_dataset = None, None, None
        self.train_df = pd.read_csv(path.join(self.config.root, self.train_data_path))
        self.test_df = pd.read_csv(path.join(self.config.root, self.test_data_path))
        self.train_df.columns = ['Polarity', 'Title', 'Review']
        self.test_df.columns = ['Polarity', 'Title', 'Review']
        self.train_df = self.train_df[['Polarity', 'Review']]
        self.test_df = self.test_df[['Polarity', 'Review']]
        self.df = pd.concat([self.train_df, self.test_df])
        self.num_data_load = num_data_load if num_data_load>0 else self.df.shape[0]
        self.num_data_load = num_data_load if self.num_data_load < self.df.shape[0] else self.df.shape[0] 
        self.df = self.df.iloc[:self.num_data_load]
        self.df.index = np.arange(0, self.num_data_load)
        
        
        
        self.graph_constructors = self.__set_graph_constructors(self.graph_type)
        graph_constructor = self.graph_constructors[TextGraphType.CO_OCCURRENCE]
        graph_constructor.setup()
        print(f'self.num_data_load: {self.num_data_load}')
        labels = self.df['Polarity'][:self.num_data_load]
        labels = labels.apply(lambda p: 0 if p == 1 else 1).to_numpy()
        labels = torch.from_numpy(labels)
        self.labels = labels.to(torch.float32).view(-1, 1).to(self.device)
        graph_constructor = self.graph_constructors[TextGraphType.CO_OCCURRENCE]
        
        print(f'self.labels.shape: {self.labels.shape}')
        self.dataset = GraphConstructorDataset(graph_constructor, self.labels)
        sample_graph = graph_constructor.get_first()
        self.num_node_features = sample_graph.num_features
        self.num_classes = len(torch.unique(self.labels))
        self.__train_dataset, self.__val_dataset, self.__test_dataset =\
            random_split(self.dataset, [1-self.val_size-self.test_size, self.val_size, self.test_size])
        self.__train_dataloader =  DataLoader(self.__train_dataset, batch_size=self.batch_size, drop_last=self.drop_last, shuffle=self.shuffle, num_workers=0, persistent_workers=False)
        self.__test_dataloader =  DataLoader(self.__test_dataset, batch_size=self.batch_size, num_workers=0, persistent_workers=False)
        self.__val_dataloader =  DataLoader(self.__val_dataset, batch_size=self.batch_size, num_workers=0, persistent_workers=False)
        
    def prepare_data(self):
        pass
        
    def setup(self, stage: str):
        pass

    def teardown(self, stage: str) -> None:
        pass

    def train_dataloader(self):
        return self.__train_dataloader

    def test_dataloader(self):
        return self.__test_dataloader

    def val_dataloader(self):
        return self.__val_dataloader 

    def __set_graph_constructors(self, graph_type: TextGraphType):
        graph_type = copy(graph_type)
        graph_constructors: Dict[TextGraphType, GraphConstructor] = {}
        
        if TextGraphType.CO_OCCURRENCE in graph_type:
            graph_constructors[TextGraphType.CO_OCCURRENCE] = self.__get_co_occurrence_graph()
            graph_type = graph_type - TextGraphType.CO_OCCURRENCE
        
        tag_dep_seq = TextGraphType.DEPENDENCY | TextGraphType.TAGS | TextGraphType.SEQUENTIAL
        if tag_dep_seq in graph_type:
            graph_constructors[tag_dep_seq] = self.__get_dep_and_tag_graph()
            graph_type = graph_type - tag_dep_seq
        
        if TextGraphType.DEPENDENCY in graph_type:
            graph_constructors[TextGraphType.DEPENDENCY] = self.__get_dependency_graph()
            graph_type = graph_type - (TextGraphType.DEPENDENCY | TextGraphType.SEQUENTIAL)
        
        if TextGraphType.TAGS in graph_type:
            graph_constructors[TextGraphType.TAGS] = self.__get_tag_graph()
            graph_type = graph_type - (TextGraphType.TAGS | TextGraphType.SEQUENTIAL)
        
        if TextGraphType.SEQUENTIAL in graph_type:
            graph_constructors[TextGraphType.SEQUENTIAL] = self.__get_Sequential_graph()
            graph_type = graph_type - TextGraphType.SEQUENTIAL
            
        return graph_constructors
    
    

    def __get_co_occurrence_graph(self):
        print(f'self.num_data_load: {self.num_data_load}')
        return CoOccurrenceGraphConstructor(self.df['Review'][:self.num_data_load], 'data/GraphData/AmazonReview', self.config, lazy_construction=False, load_preprocessed_data=True, naming_prepend='graph', num_data_load=self.num_data_load, device=self.device)
    
    def __get_dependency_graph(self):
        print(f'self.num_data_load: {self.num_data_load}')
        return DependencyGraphConstructor(self.df['Review'][:self.num_data_load], 'data/GraphData/AmazonReview', self.config, lazy_construction=False, load_preprocessed_data=True, naming_prepend='graph', num_data_load=self.num_data_load, device=self.device)
    
    def __get_tag_graph(self):
        print(f'self.num_data_load: {self.num_data_load}')
        return TagsGraphConstructor(self.df['Review'][:self.num_data_load], 'data/GraphData/AmazonReview', self.config, lazy_construction=False, load_preprocessed_data=True, naming_prepend='graph', num_data_load=self.num_data_load, device=self.device)
    
    def __get_Sequential_graph(self):
        print(f'self.num_data_load: {self.num_data_load}')
        return DependencyGraphConstructor(self.df['Review'][:self.num_data_load], 'data/GraphData/AmazonReview', self.config, lazy_construction=False, load_preprocessed_data=True, naming_prepend='graph', num_data_load=self.num_data_load, device=self.device)
    
    def __get_dep_and_tag_graph(self):
        print(f'self.num_data_load: {self.num_data_load}')
        return TagDepTokenGraphConstructor(self.df['Review'][:self.num_data_load], 'data/GraphData/AmazonReview2', self.config, lazy_construction=False, load_preprocessed_data=False, naming_prepend='graph', num_data_load=self.num_data_load, device=self.device)
    
    
    def zero_rule_baseline(self):
        return f'zero_rule baseline: {(len(self.labels[self.labels>0.5])* 100.0 / len(self.labels))  : .2f}%'

In [30]:
tag_dep_seq = TextGraphType.DEPENDENCY | TextGraphType.TAGS | TextGraphType.SEQUENTIAL
data_manager = AmazonReviewGraphDataModule(config, True, True, shuffle=True, num_data_load = 100, device='cpu', batch_size=batch_size, graph_type=tag_dep_seq)

self.num_data_load: 100


TypeError: TagDepTokenGraphConstructor.__init__() got an unexpected keyword argument 'num_data_load'

In [None]:
t_dataloader = data_manager.train_dataloader()
v_dataloader = data_manager.val_dataloader()

In [None]:
X1, y1 = next(iter(t_dataloader))
X2, y2 = next(iter(v_dataloader))
print(X1[0].x.device)
print(X2[0].x.device)

In [None]:
class HeteroGCNConv(nn.Module):
    def __init__(self, in_feature, out_feature, dropout = 0.57) -> None:
        super().__init__()
        self.conv1 = GATv2Conv(in_feature, out_feature, edge_dim=1, add_self_loops=False)
        self.conv2 = GATv2Conv(out_feature, out_feature, edge_dim=1, add_self_loops=False)
        
        self.batch_norm = BatchNorm(out_feature)
        self.dropout= nn.Dropout(dropout)
        
    def forward(self, x: Tensor, edge_index: Tensor, edge_weights: Tensor) -> Tensor:
        x = F.relu(self.conv1(x, edge_index, edge_attr=edge_weights))
        x = self.conv2(x, edge_index, edge_attr=edge_weights)
        x = self.batch_norm(x)
        x = F.relu(x)
        x = self.dropout(x)
        return x


In [None]:
from torch_geometric.nn import SAGEConv, to_hetero
import torch.nn.functional as F
from torch import Tensor
import torch
import torch_geometric
from torch.nn import Linear
from torch_geometric.nn import GATv2Conv, GCNConv, GCN2Conv, DenseGCNConv, dense_diff_pool, BatchNorm, global_mean_pool, global_add_pool, global_max_pool, MemPooling, SAGEConv, to_hetero, HeteroBatchNorm
from torch_geometric.nn import Sequential as GSequential
from torch_geometric.utils import to_dense_adj
from torch import nn, Tensor
import torch.nn.functional as F

class HeteroGcnGatModel1(torch.nn.Module):
    def __init__(self, 
                 input_feature: int, out_features: int,
                 metadata,
                 base_hidden_feature: int=256,
                 dropout=0.1):
        
        super(GNN, self).__init__()
        self.input_features = input_feature
        self.num_out_features = out_features
        self.bsh: int = base_hidden_feature
        bsh2: int = int(self.bsh/2)
        bsh4: int = int(self.bsh/4)
        bsh8: int = int(self.bsh/8)
        
        self.encoder = GSequential('x_dict, edge_index, edge_weights', [
            (to_hetero(HeteroGCNConv(input_feature, self.bsh), metadata), 'x_dict, edge_index, edge_weights ->x1'),
            (to_hetero(HeteroGCNConv(self.bsh, self.bsh), metadata), 'x1, edge_index, edge_weights ->x1'),
            (to_hetero(HeteroGCNConv(self.bsh, self.bsh), metadata), 'x1, edge_index, edge_weights ->x1'),
            (to_hetero(HeteroGCNConv(self.bsh, bsh2), metadata), 'x1, edge_index, edge_weights -> x2'),
            (to_hetero(HeteroGCNConv(bsh2, bsh2), metadata), 'x2, edge_index, edge_weights -> x2'),
            (to_hetero(HeteroGCNConv(bsh2, bsh2), metadata), 'x2, edge_index, edge_weights -> x2'),
            (to_hetero(HeteroGCNConv(bsh2, bsh4), metadata), 'x2, edge_index, edge_weights -> x3'),
            (to_hetero(HeteroGCNConv(bsh4, bsh4), metadata), 'x3, edge_index, edge_weights -> x3'),
            (to_hetero(HeteroGCNConv(bsh4, bsh4), metadata), 'x3, edge_index, edge_weights -> x3'),
            (to_hetero(HeteroGCNConv(bsh4, bsh8), metadata), 'x3, edge_index, edge_weights -> x4'),
            (to_hetero(HeteroGCNConv(bsh8, bsh8), metadata), 'x4, edge_index, edge_weights -> x4'),
            (to_hetero(HeteroGCNConv(bsh8, bsh8), metadata), 'x4, edge_index, edge_weights -> x4'),
            (lambda x1, x2, x3, x4: (x1, x2, x3, x4), 'x1, x2, x3, x4 -> x1, x2, x3, x4')
        ])
        
        self.attention = GSequential('x3, x4, edge_index, edge_weights', [
            (GATv2Conv(bsh8, bsh8, 2, dropout=dropout), 'x4, edge_index ->x4'),
            (BatchNorm(bsh4), 'x4->x4'),
            (nn.ReLU(), 'x4->x4'),
            
            (GCN2Conv(bsh4, 0.5, 0.1, 2), 'x4, x3, edge_index, edge_weights->x3'),
            (BatchNorm(bsh4), 'x3->x3'),
            (nn.ReLU(), 'x3->x3'),
            (GCNConv(bsh4, bsh4), 'x3, edge_index, edge_weights -> x3'),
            (BatchNorm(bsh4), 'x3->x3'),
            (nn.ReLU(), 'x3->x3'),
            
            (GATv2Conv(bsh4, bsh4, 2, dropout=dropout), 'x3, edge_index ->x3'),
            (BatchNorm(bsh2), 'x3->x3'),
            (nn.ReLU(), 'x3->x3'),
            (lambda x3, x4: (x3, x4), 'x3, x4 -> x3, x4')
        ])
        
        self.decoder = GSequential('x1, x2, x3, edge_index, edge_weights', [
            
            (GCN2Conv(bsh2, 0.5, 0.1, 2), 'x3, x2, edge_index, edge_weights->x2'),
            (BatchNorm(bsh2), 'x2->x2'),
            (nn.ReLU(), 'x2->x2'),
            (nn.Dropout(dropout), 'x2->x2'),
            (GCNConv(bsh2, bsh2), 'x2, edge_index, edge_weights -> x2'),
            (BatchNorm(bsh2), 'x2->x2'),
            (nn.ReLU(), 'x2->x2'),
            (nn.Dropout(dropout), 'x2->x2'),
            (GCNConv(bsh2, self.bsh), 'x2, edge_index->x2'),
            (BatchNorm(self.bsh), 'x2->x2'),
            (nn.ReLU(), 'x2->x2'),
            (nn.Dropout(dropout), 'x2->x2'),
            
            (GCN2Conv(self.bsh, 0.5, 0.1, 2), 'x2, x1, edge_index, edge_weights->x1'),
            (BatchNorm(self.bsh), 'x1->x1'),
            (nn.ReLU(), 'x1->x1'),
            (nn.Dropout(dropout), 'x1->x1'),
            (GCNConv(self.bsh, self.bsh), 'x1, edge_index, edge_weights ->x1'),
            (BatchNorm(self.bsh), 'x1->x1'),
            (nn.ReLU(), 'x1->x1'),
            (nn.Dropout(dropout), 'x1->x1'),
            (GCNConv(self.bsh, self.bsh), 'x1, edge_index, edge_weights ->x1'),
            (BatchNorm(self.bsh), 'x1->x1'),
            (nn.ReLU(), 'x1->x1'),
            (nn.Dropout(dropout), 'x1->x1'),
            (GCNConv(self.bsh, self.input_features), 'x1, edge_index, edge_weights -> x1'),
        ])
        
                
        self.user_emb = torch.nn.Embedding(data['user'].num_nodes, self.input_features)
        self.movie_emb = torch.nn.Embedding(data['movie'].num_nodes, self.input_features)
        self.movie_lin = torch.nn.Linear(20, self.input_features)
        
    def forward(self, x: HeteroData) -> Tensor:
        x_dict = {
          "user": self.user_emb(x["user"].node_id),
          "movie": self.movie_lin(x["movie"].x) + self.movie_emb(x["movie"].node_id)
        }

        x1_dict, x2_dict, x3_dict, x4_dict = self.encoder(x_dict, x.edge_index_dict, x.edge_attr_dict)
        print(f'x4_dict["movie"]: {x4_dict["movie"].shape}, x_dict["movie"]: {x_dict["movie"].shape}')
        x_att, x4 = self.attention(x3_dict["movie"], x4_dict["movie"], 
                                   x.edge_index_dict[('movie','rev_rates','user')],
                                   x.edge_attr_dict[('movie','rev_rates','user')])
        x_dec = self.decoder(x1_dict["movie"], x2_dict["movie"], x_att, 
                             x.edge_index_dict[('movie','rev_rates','user')],
                             x.edge_attr_dict[('movie','rev_rates','user')])
        return x_dict #x_dec

In [None]:
torch_model = GcnGatModel1(300, 1, 1024, dropout=0.2)
torch_model = torch_model.to(device)
# print(next(iter(torch_model.parameters())).device)
pre = torch_model(X1.to(device))

In [None]:
import torch
print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))

In [None]:
# from Scripts.Models.BaseModels.GcnGatModel1 import GcnGatModel1
from Scripts.Models.LightningModels.LightningModels import BinaryLightningModel
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
import lightning as L
from lightning.pytorch.tuner import Tuner

In [None]:
callbacks = [
    ModelCheckpoint(save_top_k=2, mode='max', monitor='val_acc', save_last=True),
    EarlyStopping(patience=50, mode='max', monitor='val_acc')
]

In [None]:
lightning_model = BinaryLightningModel(torch_model,
                                 torch.optim.Adam(torch_model.parameters(), lr=0.00017, weight_decay=0.00055),
                                       torch.nn.BCEWithLogitsLoss(),
                                       learning_rate=0.00017,
                                       batch_size=batch_size,
                                       ).to(device)

In [None]:
trainer = L.Trainer(
            callbacks=callbacks,
            max_epochs=5000,
            accelerator='gpu',
            logger=CSVLogger(save_dir='logs/', name='GcnGatSentiment3'),
            num_sanity_val_steps=0)

In [None]:
tuner = Tuner(trainer)
results = tuner.lr_find(lightning_model, datamodule=data_manager, min_lr=0.00001,max_lr=0.03)

In [None]:
fig = results.plot(suggest=True)

In [None]:
# trainer.fit(lightning_model, datamodule=data_manager)
trainer.fit(lightning_model,
            datamodule=data_manager
            # train_dataloaders=data_manager.train_dataloader(),
            # val_dataloaders=data_manager.val_dataloader(),
            # logger=CSVLogger(save_dir='logs/', name='sample_model'),
            # default_root_dir="~/Desktop"
            )

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
'''From torch lightning tutorials'''
def plot_csv_logger(csv_path, loss_names=['train_loss', 'val_loss'], eval_names=['train_acc', 'val_acc']):
    metrics = pd.read_csv(csv_path)
    
    aggregation_metrics = []
    agg_col = 'epoch'
    for i, dfg in metrics.groupby(agg_col):
        agg = dict(dfg.mean())
        agg[agg_col] = i
        aggregation_metrics.append(agg)
    
    df_metrics = pd.DataFrame(aggregation_metrics)
    df_metrics[loss_names].plot(grid=True, legend=True, xlabel='Epoch', ylabel='loss')
    df_metrics[eval_names].plot(grid=True, legend=True, xlabel='Epoch', ylabel='accuracy')
    plt.show()
    

In [None]:
plot_csv_logger(r'C:\Users\fardin\Projects\ColorIntelligence\logs\GcnGatSentiment3\version_4\metrics.csv')