# Test Sentiment Classification
### On Amazon-Review, Full Heterogeneous Graph

In [None]:
import numpy as np
from Scripts.Configs.ConfigClass import Config
from Scripts.DataManager.GraphConstructor.GraphConstructor import TextGraphType
from lightning.pytorch.loggers import CSVLogger
import os
from Scripts.DataManager.GraphLoader.AmazonReviewGraphDataModule import AmazonReviewGraphDataModule
import time
import torch

In [None]:
# Project configuration; the argument is a machine-specific absolute path and
# has to be adapted when the notebook runs on another machine.
config = Config(r'E:\Darsi\Payan Name Arshad\Second Work\ColorIntelligence2\ColorIntelligence')
# Debugging aid: makes CUDA kernel launches synchronous so that device-side
# errors are reported at the call that caused them (slows GPU execution).
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# Prefer the GPU, but fall back to the CPU when no CUDA device is available
# instead of failing later on the first `.to(device)` call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 256

In [None]:
# Alternative graph selection combining several graph types, kept for reference:
# tag_dep_seq_sent = TextGraphType.TAGS | TextGraphType.DEPENDENCY | TextGraphType.SEQUENTIAL | TextGraphType.SENTENCE | TextGraphType.SENTIMENT
tag_dep_seq_sent = TextGraphType.SENTIMENT
removals = ['sentiment', 'tag', 'general', 'sentence']

# Build the data module on the CPU and load the whole data range (0 .. -1).
# NOTE(review): the meaning of the two positional `True` flags is defined by
# AmazonReviewGraphDataModule and is not visible here - confirm before changing.
data_manager = AmazonReviewGraphDataModule(
    config,
    True,
    True,
    shuffle=False,
    start_data_load=0,
    end_data_load=-1,
    device='cpu',
    batch_size=batch_size,
    graph_type=tag_dep_seq_sent,
    load_preprocessed_data=True,
)

# Labels are loaded before the graphs, as in the original cell order.
data_manager.load_labels()
data_manager.load_graphs()

In [None]:
# Optional: change the batch size before the loaders are created.
# data_manager.update_batch_size(128)

# Create the train and validation loaders (train first, then validation).
t_dataloader, v_dataloader = data_manager.train_dataloader(), data_manager.val_dataloader()

# Pull one sample batch from each loader; X1 is reused later for its metadata.
(X1, y1), (X2, y2) = next(iter(t_dataloader)), next(iter(v_dataloader))

# X1.metadata()
# Last expression of the cell: displays the length of the sample training batch.
len(X1)

In [None]:
from Scripts.Models.LightningModels.LightningModels import HeteroBinaryLightningModel
from Scripts.Models.LossFunctions.HeteroLossFunctions import HeteroLossArgs, HeteroLoss1, HeteroLoss2
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from Scripts.Models.ModelsManager.ClassifierModelManager import ClassifierModelManager

In [None]:
# Display the metadata of the sample training batch; the same value is passed
# to the HeteroDeepGraphEmbedding2 constructor in the training cell below.
X1.metadata()

In [None]:
# Weight vector per experiment configuration: one float for each of the 11 edge
# types, 1.0 where the edge type is enabled and 0.0 elsewhere. Each configuration
# is written as the positions it enables; position 4 is enabled in all of them.
# NOTE(review): which edge type each position denotes is not visible in this
# notebook - presumably it follows the edge order in X1.metadata(); confirm.
edge_type_weights = {
    config_name: [float(position in enabled) for position in range(11)]
    for config_name, enabled in {
        'full': range(11),
        'seq': (4,),
        'dep': (0, 1, 4),
        'tag': (2, 3, 4),
        'general_sentence': (4, 5, 6, 7, 8),
        'sentence': (4, 7, 8),
        'sentiment': (4, 9, 10),
    }.items()
}

In [None]:
from Scripts.Models.GraphEmbedding.HeteroDeepGraphEmbedding2 import HeteroDeepGraphEmbedding2

In [None]:
import torch
from Scripts.Models.LightningModels.LightningModels import BaseLightningModel
from Scripts.Models.ModelsManager.ModelManager import ModelManager
import pandas as pd
import matplotlib.pyplot as plt
from typing import List
from torch_geometric.nn import summary
from lightning.pytorch.callbacks import Callback, ModelCheckpoint, EarlyStopping
from os import path

from sklearn.metrics import classification_report, f1_score, accuracy_score, precision_score, recall_score, confusion_matrix, hinge_loss

class ClassifierModelManager(ModelManager):
    """Model manager for a binary classifier.

    Adds to ``ModelManager``: the checkpoint callback used for training,
    plotting of the per-epoch metrics written to ``metrics.csv`` in the run's
    log directory, and evaluation reports based on scikit-learn metrics.
    Predictions are obtained by thresholding the model output at 0.
    """

    def __init__(self,
                 torch_model: torch.nn.Module,
                 lightning_model: BaseLightningModel,
                 model_save_dir: str = '~/Desktop',
                 log_dir: str = 'logs/',
                 log_name: str = 'model_logs',
                 device='cpu',
                 num_train_epoch = 100):
        """Forward all arguments unchanged to ``ModelManager.__init__``."""
        super().__init__(torch_model, lightning_model, model_save_dir, log_dir, log_name, device, num_train_epoch)

    def _create_callbacks(self) -> List[Callback]:
        """Return the training callbacks: keep the two best checkpoints by ``val_acc`` plus the last one."""
        return [
            ModelCheckpoint(save_top_k=2, mode='max', monitor='val_acc', save_last=True),
            # Early stopping is intentionally disabled:
            # EarlyStopping(patience=50, mode='max', monitor='val_acc')
        ]

    def draw_summary(self, dataloader):
        """Print a ``torch_geometric`` summary of the model for the first batch of ``dataloader``."""
        X, y = next(iter(dataloader))
        print(summary(self.torch_model, X.to(self.device)))

    # ------------------------------------------------------------------ helpers

    def _version_dir(self) -> str:
        """Return the directory of the current logger version (holds ``metrics.csv`` and the outputs)."""
        return path.join(self.log_dir, self.log_name, f'version_{self.logger.version}')

    def _load_epoch_metrics(self) -> pd.DataFrame:
        """Read ``metrics.csv`` and return one row per epoch holding the mean of every logged column."""
        metrics = pd.read_csv(path.join(self._version_dir(), 'metrics.csv'))
        # reset_index() turns 'epoch' back into a column and restores a 0..n-1
        # index, so the plots use the row position on the x axis.
        return metrics.groupby('epoch').mean().reset_index()

    @staticmethod
    def _as_column_list(names) -> List[str]:
        """Normalise a column selection (single name or any iterable of names) to a list."""
        # A tuple would be interpreted by pandas as ONE (multi-level) key, so the
        # selection must be a real list before indexing the DataFrame.
        return [names] if isinstance(names, str) else list(names)

    def _collect_predictions(self, dataloader):
        """Run the model over ``dataloader`` and return ``(y_true, y_pred)`` as int32 tensors.

        The lightning model is switched to eval mode and gradients are disabled.
        When the model returns a tuple, its first element is used as the output.
        """
        y_true = []
        y_pred = []
        self.lightning_model.eval()
        # Inference only: no autograd graph is needed, which saves memory.
        with torch.no_grad():
            for X, y in dataloader:
                y_p = self.torch_model(X.to(self.device))
                if isinstance(y_p, tuple):
                    y_p = y_p[0]
                # Threshold at 0 and move next to the labels' device for concatenation.
                y_pred.append((y_p > 0).to(torch.int32).to(y.device))
                y_true.append(y.to(torch.int32))
        return torch.concat(y_true), torch.concat(y_pred)

    @staticmethod
    def _iter_report_lines(y_true, y_pred,
                           give_confusion_matrix: bool,
                           give_report: bool,
                           give_f1_score: bool,
                           give_accuracy_score: bool,
                           give_precision_score: bool,
                           give_recall_score: bool,
                           give_hinge_loss: bool):
        """Yield the text of every requested metric, computed lazily and in a fixed order."""
        if give_confusion_matrix:
            yield f'confusion_matrix: \n{confusion_matrix(y_true, y_pred)}'
        if give_report:
            yield classification_report(y_true, y_pred)
        if give_f1_score:
            yield f'f1_score: {f1_score(y_true, y_pred)}'
        if give_accuracy_score:
            yield f'accuracy_score: {accuracy_score(y_true, y_pred)}'
        if give_precision_score:
            yield f'precision_score: {precision_score(y_true, y_pred)}'
        if give_recall_score:
            yield f'recall_score: {recall_score(y_true, y_pred)}'
        if give_hinge_loss:
            # NOTE(review): hinge_loss receives the thresholded 0/1 predictions here,
            # whereas scikit-learn documents its second argument as raw decision
            # scores - confirm that this value is the one intended.
            yield f'hinge_loss: {hinge_loss(y_true, y_pred)}'

    # --------------------------------------------------------------- public API

    def plot_csv_logger(self, loss_names=('train_loss', 'val_loss'), eval_names=('train_acc', 'val_acc')):
        """Show the per-epoch loss curves and accuracy curves of the current run.

        Args:
            loss_names: metric columns drawn in the loss figure.
            eval_names: metric columns drawn in the accuracy figure.
        """
        df_metrics = self._load_epoch_metrics()
        df_metrics[self._as_column_list(loss_names)].plot(grid=True, legend=True, xlabel='Epoch', ylabel='loss')
        df_metrics[self._as_column_list(eval_names)].plot(grid=True, legend=True, xlabel='Epoch', ylabel='accuracy')
        plt.show()

    def save_plot_csv_logger(self, loss_names=('train_loss', 'val_loss'), eval_names=('train_acc', 'val_acc'), name_prepend: str=""):
        """Save the per-epoch loss and accuracy curves as PNG files in the run's log directory.

        Writes ``{name_prepend}_loss_metric.png`` and ``{name_prepend}_acc_metric.png``.
        Each figure is closed after saving, so no figure is left open.

        Args:
            loss_names: metric columns drawn in the loss figure.
            eval_names: metric columns drawn in the accuracy figure.
            name_prepend: prefix of the two output file names.
        """
        df_metrics = self._load_epoch_metrics()
        version_dir = self._version_dir()
        figures = (
            (loss_names, 'loss', f'{name_prepend}_loss_metric.png'),
            (eval_names, 'accuracy', f'{name_prepend}_acc_metric.png'),
        )
        for names, ylabel, file_name in figures:
            ax = df_metrics[self._as_column_list(names)].plot(grid=True, legend=True, xlabel='Epoch', ylabel=ylabel)
            fig = ax.get_figure()
            try:
                fig.savefig(path.join(version_dir, file_name))
            finally:
                # Close this specific figure even when saving fails.
                plt.close(fig)

    def evaluate(self, eval_dataloader,
                 give_confusion_matrix: bool=True,
                 give_report: bool=True,
                 give_f1_score: bool=False,
                 give_accuracy_score: bool=False,
                 give_precision_score: bool=False,
                 give_recall_score: bool=False,
                 give_hinge_loss: bool=False):
        """Evaluate the model on ``eval_dataloader`` and print the requested metrics.

        Args:
            eval_dataloader: iterable of ``(X, y)`` batches.
            give_*: switches selecting which scikit-learn metrics are printed.
        """
        y_true, y_pred = self._collect_predictions(eval_dataloader)
        for line in self._iter_report_lines(
                y_true, y_pred,
                give_confusion_matrix=give_confusion_matrix,
                give_report=give_report,
                give_f1_score=give_f1_score,
                give_accuracy_score=give_accuracy_score,
                give_precision_score=give_precision_score,
                give_recall_score=give_recall_score,
                give_hinge_loss=give_hinge_loss):
            print(line)

    def save_evaluation(self, eval_dataloader, name_prepend: str='',
                 give_confusion_matrix: bool=True,
                 give_report: bool=True,
                 give_f1_score: bool=False,
                 give_accuracy_score: bool=False,
                 give_precision_score: bool=False,
                 give_recall_score: bool=False,
                 give_hinge_loss: bool=False):
        """Evaluate the model on ``eval_dataloader`` and append the requested metrics to a text file.

        The file is ``{name_prepend}_test_metrics.txt`` in the run's log directory;
        it is opened in append mode, so repeated calls accumulate reports.

        Args:
            eval_dataloader: iterable of ``(X, y)`` batches.
            name_prepend: prefix of the output file name.
            give_*: switches selecting which scikit-learn metrics are written.
        """
        test_metrics_path = path.join(self._version_dir(), f'{name_prepend}_test_metrics.txt')

        y_true, y_pred = self._collect_predictions(eval_dataloader)
        with open(test_metrics_path, 'at+') as f:
            for line in self._iter_report_lines(
                    y_true, y_pred,
                    give_confusion_matrix=give_confusion_matrix,
                    give_report=give_report,
                    give_f1_score=give_f1_score,
                    give_accuracy_score=give_accuracy_score,
                    give_precision_score=give_precision_score,
                    give_recall_score=give_recall_score,
                    give_hinge_loss=give_hinge_loss):
                print(line, file=f)

In [None]:
# Edge-weight configurations to train. Iterating over this explicit selection
# (instead of over every key guarded by `if k == 'dep'`) guarantees that, after
# the loop, `k` names the configuration of the model left in `model_manager`;
# previously `k` ended up as the last dictionary key, so the cell below labelled
# its output files with the wrong configuration.
configs_to_run = ['dep']
for k in configs_to_run:
    # NOTE(review): 300 / 1 / 128 are positional constructor arguments whose
    # meaning is defined by HeteroDeepGraphEmbedding2 - presumably input feature
    # size, output size and hidden size; confirm against the model class.
    graph_embedding = HeteroDeepGraphEmbedding2(300, 1, X1.metadata(), 128, dropout=0.2,
                                                # derived from the vector so both always agree
                                                edge_type_count=len(edge_type_weights[k]),
                                                edge_type_weights=edge_type_weights[k])
    graph_embedding = graph_embedding.to(device)
    # NOTE(review): this list is not passed to anything in this cell.
    callbacks = [
        ModelCheckpoint(save_top_k=5, mode='max', monitor='val_acc', save_last=True)
    ]
    # The optimizer's lr and the `learning_rate` argument carry the same value
    # and have to be changed together.
    lightning_model = HeteroBinaryLightningModel(graph_embedding,
                                    torch.optim.Adam(graph_embedding.parameters(), lr=0.0045, weight_decay=0.001),
                                        loss_func=HeteroLoss1(exception_keys=['word'], enc_factor=0.005),
                                        learning_rate=0.0045,
                                        batch_size=batch_size,
                                        user_lr_scheduler=True,
                                        min_lr=0.0003
                                        ).to(device)
    # NOTE(review): log_name is 'hetero_model_10' while model_save_dir ends in
    # 'hetero_model_7' - confirm that this mismatch is intended.
    model_manager = ClassifierModelManager(graph_embedding, lightning_model, log_name='hetero_model_10', model_save_dir=r'E:\Darsi\Payan Name Arshad\Second Work\ColorIntelligence2\ColorIntelligence\Practices\Tasks\HeterogeneousGraphs\hetero_model_7',device=device, num_train_epoch=67)
    # Resume from a previously saved checkpoint (machine-specific path).
    # NOTE(review): the checkpoint is named epoch=67 and num_train_epoch is 67;
    # confirm whether any further training is expected to happen here.
    # To train from scratch instead, call: model_manager.fit(datamodule=data_manager)
    model_manager.fit(datamodule=data_manager,ckpt_path="E:\\Darsi\\Payan Name Arshad\\Second Work\\ColorIntelligence2\\ColorIntelligence\\logs\\hetero_model_10\\version_2\\checkpoints\\epoch=67-step=17476.ckpt")

In [None]:
# NOTE(review): `k` is whatever value the training cell's loop left behind;
# verify that it names the configuration `model_manager` was actually trained
# with, because it is used below to name the output files.

# Save the per-epoch loss / accuracy curves of the run next to its metrics.csv.
model_manager.save_plot_csv_logger(
    loss_names=['train_loss', 'val_loss'],
    eval_names=['train_acc_epoch', 'val_acc_epoch'],
    name_prepend=f'tests_{k}',
)

# Put the model on the manager's own device - the device save_evaluation sends
# every batch to - instead of a hard-coded 'cuda'.
model_manager.torch_model = model_manager.torch_model.to(model_manager.device)

# Append every available metric for the test split to '<k>_test_metrics.txt'.
model_manager.save_evaluation(
    data_manager.test_dataloader(),
    name_prepend=f'{k}',
    give_confusion_matrix=True,
    give_report=True,
    give_f1_score=True,
    give_accuracy_score=True,
    give_precision_score=True,
    give_recall_score=True,
    give_hinge_loss=True,
)