In [None]:
## Run in py39_2 conda env
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from tqdm import tqdm
import pathlib
import math
import sklearn
import torch_optimizer as optim
from torch.optim.lr_scheduler import MultiStepLR, ReduceLROnPlateau
from metrics import *
from torch_geometric.nn import GCNConv, GATConv, global_max_pool as gmp, global_add_pool as gap,global_mean_pool as gep,global_sort_pool
from torch_geometric.utils import dropout_adj
import networkx as nx

import biographs as bg
from Bio import SeqIO
from Bio.PDB.PDBParser import PDBParser

In [None]:

# Compute device for this notebook: the first CUDA GPU when torch can see one,
# otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)


In [None]:
# ftrs = np.load("cas13a_embedding.npy", allow_pickle=True)
# print(ftrs.shape)

In [None]:
import MDAnalysis as mda

# NOTE: machine-specific absolute path — adjust for your setup.
workdir = '/Users/souviksinha/Desktop/Palermo_Lab/LabWork/GNN_classifier/PPI_GNN/Chinmai_proj/win0_vs_win16/'

# Window index used in the trajectory name (below-US_window<N>.dcd) and in the
# output folder (frame_<N>/). Set to 1 to regenerate the frame_1 set.
window = 16

# Trajectory (DCD) plus the topology it belongs to (PSF or PDB file).
dcd_file = os.path.join(workdir, f"below-US_window{window}.dcd")
topology_file = os.path.join(workdir, "SYS_nowat-ref.pdb")

out_dir = os.path.join(workdir, f"frame_{window}")
os.makedirs(out_dir, exist_ok=True)

# Universe = topology + trajectory.
u = mda.Universe(topology_file, dcd_file)

# The selection does not depend on the frame: the AtomGroup's coordinates track
# the current trajectory step, so it is built once outside the loop.
ag = u.select_atoms("protein")

# Write the protein atoms of every frame to its own PDB file.
for ts in u.trajectory:
    pdb_filename = os.path.join(out_dir, f"frame_{ts.frame}.pdb")
    with mda.Writer(pdb_filename, bonds=None, n_atoms=ag.n_atoms) as pdb:
        pdb.write(ag)

    print(f"Saved PDB file: {pdb_filename}")


In [None]:
from bio_embeddings.embed import SeqVecEmbedder

workdir = '/Users/souviksinha/Desktop/Palermo_Lab/LabWork/GNN_classifier/PPI_GNN/Chinmai_proj/win0_vs_win16_win56/'
# Residue name -> one-letter symbol.
# Besides the 20 standard names, this maps common AMBER/CHARMM variant names
# (histidine tautomers, disulfide/deprotonated cysteine, protonated Asp/Glu,
# neutral Lys). Without them those residues would be silently dropped from the
# sequence, even though MDAnalysis' "protein" selection keeps them in the frames.
ressymbl = {
    'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F', 'GLY': 'G',
    'HIE': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N',
    'PRO': 'P', 'GLN': 'Q', 'ARG': 'R', 'SER': 'S', 'THR': 'T', 'VAL': 'V',
    'TRP': 'W', 'TYR': 'Y',
    # Variant residue names
    'HIS': 'H', 'HID': 'H', 'HIP': 'H', 'HSD': 'H', 'HSE': 'H', 'HSP': 'H',
    'CYX': 'C', 'CYM': 'C', 'ASH': 'D', 'GLH': 'E', 'LYN': 'K',
}
# get structure from a pdb file
# Uses biopython
def get_structure(file):
    """Parse a PDB file with Biopython and return the resulting Structure.

    Args:
        file: path to a PDB file.

    Returns:
        Bio.PDB Structure whose id is the file name without directory or
        extension (e.g. ``SYS_nowat-ref``).
    """
    parser = PDBParser()
    # Biopython expects a string identifier here (not the builtin ``id``).
    structure_id = os.path.splitext(os.path.basename(file))[0]
    return parser.get_structure(structure_id, file)

# Function to get sequence from pdb structure
# Uses structure made using biopython
def get_sequence(structure, table=None):
    """Build a one-letter sequence from a Biopython structure.

    Walks every model -> chain -> residue in iteration order and maps each
    residue name through ``table``. Residues whose names are not in the table
    are skipped (not converted), so the result can be shorter than the number
    of residues in the structure.

    Args:
        structure: Bio.PDB Structure (or any nested iterable of the same shape
            whose innermost items have ``get_resname()``).
        table: residue-name -> one-letter mapping; defaults to the
            module-level ``ressymbl``.

    Returns:
        The sequence as a string.
    """
    if table is None:
        table = ressymbl
    return "".join(
        table[residue.get_resname()]
        for model in structure
        for chain in model
        for residue in chain
        if residue.get_resname() in table
    )

# Define the protein sequence
# seq = 'MKVTKVGGISHKKYTSEGRLVKSESEENRTDERLSALLNMRLDMYIKNPSSTETKENQKRIGKLKKFFSNKMVYLKDNTLSLKNGKKENIDREYSETDILESDVRDKKNFAVLKKIYLNENVNSEELEVFRNDIKKKLNKINSLKYSFEKNKANYQKINENNIEKVEGKSKRNIIYDYYRESAKRDAYVSNVKEAFDKLYKEEDIAKLVLEIENLTKLEKYKIREFYHEIIGRKNDKENFAKIIYEEIQNVNNMKELIEKVPDMSELKKSQVFYKYYLDKEELNDKNIKYAFCHFVEIEMSQLLKNYVYKRLSNISNDKIKRIFEYQNLKKLIENKLLNKLDTYVRNCGKYNYYLQDGEIATSDFIARNRQNEAFLRNIIGVSSVAYFSLRNILETENENDITGRMRGKTVKNNKGEEKYVSGEVDKIYNENKKNEVKENLKMFYSYDFNMDNKNEIEDFFANIDEAISSIRHGIVHFNLELEGKDIFAFKNIAPSEISKKMFQNEINEKKLKLKIFRQLNSANVFRYLEKYKILNYLKRTRFEFVNKNIPFVPSFTKLYSRIDDLKNSLGIYWKTPKTNDDNKTKEIIDAQIYLLKNIYYGEFLNYFMSNNGNFFEISKEIIELNKNDKRNLKTGFYKLQKFEDIQEKIPKEYLANIQSLYMINAGNQDEEEKDTYIDFIQKIFLKGFMTYLANNGRLSLIYIGSDEETNTSLAEKKQEFDKFLKKYEQNNNIKIPYEINEFLREIKLGNILKYTERLNMFYLILKLLNHKELTNLKGSLEKYQSANKEEAFSDQLELINLLNLDNNRVTEDFELEADEIGKFLDFNGNKVKDNKELKKFDTNKIYFDGENIIKHRAFYNIKKYGMLNLLEKIADKAGYKISIEELKKYSNKKNEIEKNHKMQENLHRKYARPRKDEKFTDEDYESYKQAIENIEEYTHLKNKVEFNELNLLQGLLLRILHRLVGYTSIWERDLRFRLKGEFPENQYIEEIFNFENKKNVKYKGGQIVEKYIKFYKELHQNDEVKINKYSSANIKVLKQEKKDLYIANYIAAFNYIPHAEISLLEVLENLRKLLSYDRKLKNAVMKSVVDILKEYGFVATFKIGADKKIGIQTLESEKIVHLKNLKKKKLMTDRNSEELCKLVKIMFEYKMEEKKSEN'
seq = get_sequence(
    get_structure(os.path.join(workdir, "SYS_nowat-ref.pdb"))
)

# Pretrained SeqVec model from bio_embeddings.
embedder = SeqVecEmbedder()
embedding = embedder.embed(seq)

# Sum over the leading axis of the embedding and keep the result as a NumPy
# array (assumes SeqVec returns (layers, len(seq), 1024) — TODO confirm).
protein_embd = (
    torch.tensor(embedding)
    .sum(dim=0)
    .cpu()
    .numpy()
)
print(protein_embd.shape)

In [None]:
from torch_geometric.data import Dataset, Dataset, download_url, Data,  Batch

# The 20 standard amino acids as one-letter codes; the list position is the
# column index used by the one-hot encoding.
pro_res_table = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']

# Residue name -> one-letter symbol, including common AMBER/CHARMM variant
# names so those residues are not dropped from the sequence.
ressymbl = {
    'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F', 'GLY': 'G',
    'HIE': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N',
    'PRO': 'P', 'GLN': 'Q', 'ARG': 'R', 'SER': 'S', 'THR': 'T', 'VAL': 'V',
    'TRP': 'W', 'TYR': 'Y',
    # Variant residue names
    'HIS': 'H', 'HID': 'H', 'HIP': 'H', 'HSD': 'H', 'HSE': 'H', 'HSP': 'H',
    'CYX': 'C', 'CYM': 'C', 'ASH': 'D', 'GLH': 'E', 'LYN': 'K',
}


class ProteinDataset(Dataset):
    """PyG dataset of residue-contact graphs, one graph per PDB frame.

    Expected layout::

        <root>/raw/*.pdb          input frames
        <root>/processed/*.pt     written by process(), one file per kept frame

    Every graph uses the module-level ``protein_embd`` array as node features
    (the same reference-sequence embedding for all frames). Only
    ``edge_index`` differs between graphs; it is built from that frame's
    biographs residue network.

    Args:
        root: dataset root directory (must contain ``raw/``).
        transform: optional callable applied to a graph in ``__getitem__``.
        pre_transform: optional callable applied to each graph in
            ``process()`` before it is saved.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        # Forward the caller's transforms to the PyG base class. The base
        # constructor may run process(), which fills self.data_prot.
        super(ProteinDataset, self).__init__(root, transform=transform,
                                             pre_transform=pre_transform)
        # Kept for code that reads `.data`; the graphs themselves are in data_prot.
        self.data = self.processed_paths
        if "data_prot" not in self.__dict__:
            # process() is skipped when every processed file already exists.
            # Reload the saved graphs so __len__/__getitem__ still work.
            # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True,
            # which may refuse to unpickle PyG Data objects — confirm locally.
            self.data_prot = [torch.load(path) for path in self.processed_paths
                              if os.path.exists(path)]

    @property
    def raw_file_names(self):
        """Sorted names of the ``.pdb`` files in ``<root>/raw``.

        PyG joins these names onto ``raw_dir`` to build ``raw_paths``, so they
        must be bare file names, not full paths. Non-PDB files (e.g.
        ``.DS_Store``) are excluded. Otherwise they would produce
        processed-file entries that are never written, forcing re-processing
        on every load.
        """
        return sorted(name for name in os.listdir(self.raw_dir)
                      if pathlib.Path(name).suffix == ".pdb")

    @property
    def processed_file_names(self):
        """One ``<stem>.pt`` per raw PDB file, e.g. ``frame_12.pdb -> frame_12.pt``."""
        return [os.path.splitext(name)[0] + '.pt' for name in self.raw_file_names]

    def download(self):
        # Nothing to download: the PDB files must already be in raw_dir.
        pass

    def process(self):
        """Build and save one graph per raw PDB file.

        Files that Biopython cannot parse are reported and skipped. Frames
        whose contact matrix has more nodes than ``protein_embd`` has rows are
        also reported and skipped, because their edge indices would point past
        the node features.
        """
        self.data = self.processed_paths

        data_list = []
        count = 0
        for file in tqdm(self.raw_paths):
            if pathlib.Path(file).suffix != ".pdb":
                continue

            try:
                # Parse first so that unreadable PDB files are reported and skipped.
                self._get_structure(file)
            except Exception as exc:
                print('except', file, exc)
                continue

            # Node features: the shared reference-sequence embedding.
            node_feats = torch.tensor(protein_embd)
            mat = self._get_adjacency(file)

            if mat.shape[0] > node_feats.size(0):
                print(f'skip {file}: mat size {mat.shape} exceeds node features size {node_feats.size()}')
                continue

            # Reuse the matrix computed above instead of rebuilding the network.
            edge_index = self._get_edgeindex(file, mat)

            print(f'Node features size :{node_feats.size()}')
            print(f'mat size :{mat.shape}')

            data = Data(x=node_feats, edge_index=edge_index)
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            count += 1
            data_list.append(data)
            stem = os.path.splitext(os.path.basename(file))[0]
            torch.save(data, os.path.join(self.processed_dir, stem + '.pt'))

        self.data_prot = data_list
        print(count)

    def __len__(self):
        # Number of graphs actually held (skipped frames are not counted).
        return len(self.data_prot)

    def __getitem__(self, idx):
        """Return graph ``idx``, passed through ``transform`` when one was given."""
        data = self.data_prot[idx]
        return data if self.transform is None else self.transform(data)

    def len(self):
        # PyG's size hook, kept consistent with __len__.
        return len(self.data_prot)

    def get(self, idx):
        # PyG's item hook (untransformed graph).
        return self.data_prot[idx]

    # Biographs returns the residue network as a networkx.Graph object.
    def _get_adjacency(self, file):
        """Return the dense adjacency matrix (numpy array) of the residue network of ``file``.

        Uses biographs' default contact cutoff; ``molecule.network(cutoff=8)``
        would widen it.
        """
        molecule = bg.Pmolecule(file)
        network = molecule.network()
        # Row/column order follows network.nodes().
        return np.asarray(nx.adjacency_matrix(network).todense())

    # get adjacency matrix in coo format to pass in GCNN model
    def _get_edgeindex(self, file, adjacency_mat=None):
        """Convert an adjacency matrix to a ``[2, num_edges]`` long tensor (COO).

        Args:
            file: PDB path; only used to build the matrix when
                ``adjacency_mat`` is None.
            adjacency_mat: precomputed adjacency matrix for ``file``.
        """
        m = self._get_adjacency(file) if adjacency_mat is None else np.asarray(adjacency_mat)
        # Row 0: source indices, row 1: target indices of every non-zero entry.
        rows, cols = np.nonzero(m > 0)
        return torch.tensor(np.stack([rows, cols]), dtype=torch.long)

    # get structure from a pdb file
    # Uses biopython
    def _get_structure(self, file):
        """Parse ``file`` with Biopython; the structure id is the file stem."""
        parser = PDBParser()
        structure_id = os.path.splitext(os.path.basename(file))[0]
        return parser.get_structure(structure_id, file)

    # Function to get sequence from pdb structure
    # Uses structure made using biopython
    def _get_sequence(self, structure):
        """One-letter sequence of ``structure``; residues not in ``ressymbl`` are skipped."""
        return "".join(
            ressymbl[residue.get_resname()]
            for model in structure
            for chain in model
            for residue in chain
            if residue.get_resname() in ressymbl
        )

    # One hot encoding for symbols
    def _get_one_hot_symbftrs(self, sequence):
        """One-hot matrix of shape ``[len(sequence), 20]`` (column order = ``pro_res_table``)."""
        one_hot_symb = np.zeros((len(sequence), len(pro_res_table)))
        for row, res in enumerate(sequence):
            one_hot_symb[row, pro_res_table.index(res)] = 1
        return torch.tensor(one_hot_symb, dtype=torch.float)

    # Residue features calculated from pcp_dict
    def _get_res_ftrs(self, sequence):
        """Stack ``pcp_dict[res]`` for every residue into a float tensor."""
        # NOTE(review): `pcp_dict` is not defined in this notebook; presumably it
        # comes from `from metrics import *` — confirm before using this method.
        res_ftrs_out = np.array([pcp_dict[res] for res in sequence])
        return torch.tensor(res_ftrs_out, dtype=torch.float)

    # total features after concatenating one_hot_symbftrs and res_ftrs
    def _get_node_ftrs(self, sequence):
        """Concatenate one-hot and residue features column-wise."""
        one_hot_symb = self._get_one_hot_symbftrs(sequence)
        res_ftrs_out = self._get_res_ftrs(sequence)
        return torch.cat((one_hot_symb, res_ftrs_out), dim=1)


In [None]:
# molecule = bg.Pmolecule("/Users/souviksinha/Desktop/Palermo_Lab/LabWork/GNN_classifier/PPI_GNN/Chinmai_proj/frame_1/frame_100.pdb")
# network = molecule.network()

In [None]:
# network.nodes[:10]
# print(network)
# network.nodes

In [None]:
# prot_graphs = ProteinDataset(root="frame_1/")
# prot_graphs = ProteinDataset(root="frame_16/")

In [None]:
# print(trainset.indices)
# trainloader = DataLoader(dataset=trainset, batch_size=64, num_workers=0)
# print(trainloader.batch_size)
# for batch in trainloader:
# #     print('a')
#     data, labels = batch
#     print(labels)

In [None]:
import os
import glob
import torch
from torch.utils.data import Dataset
from torch_geometric.data import DataLoader, Data

class ManuallyLabelledDataset(Dataset):
    """Map-style dataset over processed graph files, labelled by directory.

    Every ``frame*.pt`` file in ``label0_dir`` gets label 0, those in
    ``label1_dir`` label 1 and those in ``label2_dir`` label 2. Each directory
    listing is sorted, so indices (and therefore a seeded random split) are
    the same across runs and filesystems.
    """

    def __init__(self, label0_dir, label1_dir, label2_dir):
        # glob() returns files in filesystem-dependent order; sort for stable indices.
        self.label0_files = sorted(glob.glob(os.path.join(label0_dir, "frame*.pt")))
        self.label1_files = sorted(glob.glob(os.path.join(label1_dir, "frame*.pt")))
        self.label2_files = sorted(glob.glob(os.path.join(label2_dir, "frame*.pt")))

        self.files = self.label0_files + self.label1_files + self.label2_files
        self.labels = [0] * len(self.label0_files) + [1] * len(self.label1_files) + [2] * len(self.label2_files)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load graph ``idx`` from disk and attach its class label as ``data.y``."""
        file_path = self.files[idx]
        label = self.labels[idx]

        # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True,
        # which may refuse to unpickle PyG Data objects — confirm locally.
        data = torch.load(file_path)

        # Class index as a length-1 long tensor; batching concatenates these
        # into one label per graph, the form CrossEntropyLoss expects.
        data.y = torch.tensor([label], dtype=torch.long)

        return data

# Processed-graph directories for class 0, 1 and 2.
label0_dir = "frame_1/processed"
label1_dir = "frame_16/processed"
label2_dir = "frame_56/processed"

# Create dataset instance
dataset = ManuallyLabelledDataset(label0_dir=label0_dir, label1_dir=label1_dir, label2_dir=label2_dir)

# Split into training (80%) and testing (20%) sets
print("Size is:", len(dataset))

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
print(train_size, test_size)

# Seeded generator so the split is reproducible (given a stable file order in
# `dataset`).
split_seed = 42
trainset, testset = torch.utils.data.random_split(
    dataset, [train_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Create DataLoaders. The final batch may hold fewer than batch_size graphs,
# because drop_last is not set, so a partial last batch is kept.
trainloader = DataLoader(dataset=trainset, batch_size=16, shuffle=True)
testloader = DataLoader(dataset=testset, batch_size=16, shuffle=False)

# Inspect the first training batch (remove `break` to print every batch).
for step, batch in enumerate(trainloader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {batch.num_graphs}')
    print(f'Node feature shape: {batch.x.shape}')
    print(f'Edge index shape: {batch.edge_index.shape}')
    print(f'Batch tensor: {batch.batch}')
    print(f'Label tensor: {batch.y}')
    print(batch)
    print()
    break


How to read the batch printout above. The example values come from a batch of 10 graphs, each with 1153 residue nodes.

1. **Node feature shape:** `torch.Size([11530, 1024])`
   The node features of every node in the batch, stacked along the first axis. There are 10 graphs × 1153 nodes = 11530 rows, each a 1024-dimensional feature vector. With `batch_size=16`, a full batch has 16 × 1153 = 18448 rows.
2. **Batch tensor:** `tensor([0, 0, 0, ..., 9, 9, 9])`
   A 1-dimensional tensor with one entry per node, giving the index of the graph that node belongs to.
3. **Edge index shape:** `torch.Size([2, 145596])`
   The edges of all graphs in the batch. The first dimension (2) holds the source and target node indices of each edge. The second dimension (145596) is the total number of edges in the batch.

In [None]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool


class GCN(torch.nn.Module):
    """Graph classifier: two GCNConv layers, mean pooling, then a 3-layer MLP head.

    Args:
        num_features_pro: width of each input node feature vector.
        hidden_channels: output width of the second GCNConv layer.
        num_classes: number of logits produced per graph.
    """

    def __init__(self, num_features_pro, hidden_channels, num_classes):
        super(GCN, self).__init__()
        # The first convolution keeps the input width; the second narrows it.
        self.conv1 = GCNConv(num_features_pro, num_features_pro)
        self.conv2 = GCNConv(num_features_pro, hidden_channels)
        # Dense head: hidden_channels -> 128 -> 64 -> num_classes.
        self.fc1 = Linear(hidden_channels, 128)
        self.fc2 = Linear(128, 64)
        self.out = Linear(64, num_classes)
        self.relu = torch.nn.LeakyReLU()
        self.dropout = torch.nn.Dropout(0.2)

    def forward(self, x, edge_index, batch):
        """Return raw class logits, one row per graph in ``batch``."""
        # Message passing with a plain ReLU after each convolution.
        for conv in (self.conv1, self.conv2):
            x = conv(x, edge_index).relu()

        # Readout: average the node embeddings of each graph.
        h = global_mean_pool(x, batch)

        # Dense layers, each followed by LeakyReLU and dropout.
        for dense in (self.fc1, self.fc2):
            h = self.dropout(self.relu(dense(h)))

        return self.out(h)

# 1024-wide node features, 256 hidden channels, 3 output classes.
model = GCN(
    num_features_pro=1024,
    hidden_channels=256,
    num_classes=3,
)
print(model)

In [None]:
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool


class GAT(torch.nn.Module):
    """Graph classifier: two GATConv layers, mean pooling, linear head.

    Args:
        in_channels: width of each input node feature vector.
        hidden_channels: per-head output width of both attention layers.
        num_classes: number of logits produced per graph.
        heads: attention heads in the first layer. Their outputs are
            concatenated, so the second layer receives
            ``hidden_channels * heads`` features.
    """

    def __init__(self, in_channels, hidden_channels, num_classes, heads=4):
        super(GAT, self).__init__()

        # First attention layer; input width comes from the constructor
        # argument, not from a notebook-level global.
        self.gat1 = GATConv(in_channels, hidden_channels, heads=heads, concat=True)
        # Second attention layer with a single head (concat has no effect when
        # heads=1; output width is hidden_channels).
        self.gat2 = GATConv(hidden_channels * heads, hidden_channels, heads=1, concat=True)

        # Maps the pooled graph representation to class logits.
        self.fc1 = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        """Return raw class logits, shape ``[num_graphs, num_classes]``.

        No softmax is applied; ``torch.nn.CrossEntropyLoss`` expects logits.
        """
        x = F.relu(self.gat1(x, edge_index))
        x = F.relu(self.gat2(x, edge_index))

        # One vector per graph: [batch_size, hidden_channels].
        x = global_mean_pool(x, batch)

        return self.fc1(x)


# GAT hyper-parameters. num_features_pro is the node-feature width fed to the
# first attention layer.
num_features_pro = 1024
hidden_channels = 256
num_classes = 3
heads = 4  # attention heads in the first GAT layer

model = GAT(
    in_channels=num_features_pro,
    hidden_channels=hidden_channels,
    num_classes=num_classes,
    heads=heads,
)
print(model)


In [None]:
# Optional: restore a previously trained checkpoint instead of training.
# GCN weights are saved as best_model.pt and GAT weights as best_model_GAT.pt,
# so pick the file that matches the architecture of the current `model`.
checkpoint_path = 'best_model_GAT.pt' if type(model).__name__ == 'GAT' else 'best_model.pt'
if os.path.exists(checkpoint_path):
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
    print(f'No checkpoint found at {checkpoint_path}; model weights left unchanged.')

In [None]:
# Keep the model on the same device as the batches moved in train()/test().
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
# Create the LR scheduler once. A scheduler rebuilt inside train() restarts its
# epoch counter on every call and never reaches its milestones.
scheduler = MultiStepLR(optimizer, milestones=[5, 8], gamma=0.5)


def train(epoch):
    """Run one optimisation pass over `trainloader` and print the mean batch loss."""
    model.train()
    total_loss = 0.0
    for data in trainloader:  # Iterate in batches over the training dataset.
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)  # Forward pass -> logits.
        loss = criterion(out, data.y.long())
        optimizer.zero_grad()  # Clear gradients from the previous step.
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    scheduler.step()  # One scheduler step per epoch.
    avg_loss = total_loss / max(len(trainloader), 1)
    print(f'Epoch {epoch}  [==============================] - train_loss : {avg_loss:.4f}')


def test(loader):
    """Return the fraction of graphs in `loader` whose predicted class is correct."""
    model.eval()
    correct = 0
    with torch.no_grad():  # Inference only: no autograd graph is needed.
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            pred = out.argmax(dim=1)  # Class with the highest logit.
            correct += int((pred == data.y).sum())
    return correct / max(len(loader.dataset), 1)


for epoch in range(1, 10):
    train(epoch)
    train_acc = test(trainloader)
    test_acc = test(testloader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

In [None]:
# # This is to check that if all batches have same number of data
# c = 0
# for data in trainloader:  # Iterate in batches over the training dataset.
#         out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
# #         print(out)
#         label = data.y.long().to(device)
#         print(label)
#         print(c)
#         c += 1

In [None]:
import torch
from torch.optim.lr_scheduler import MultiStepLR
from torch_geometric.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

# Per-architecture settings, chosen from the type of the current `model`
# instead of commenting lines in and out.
is_gat = type(model).__name__ == 'GAT'
patience = 5 if is_gat else 20                 # epochs without test-accuracy gain before stopping
lr_milestones = [6, 9] if is_gat else [18, 22]
num_epochs = 9 if is_gat else 24               # i.e. range(1, 10) / range(1, 25)
checkpoint_path = 'best_model_GAT.pt' if is_gat else 'best_model.pt'

# Keep the model on the same device as the batches.
model = model.to(device)

# Define optimizer, loss function and LR schedule. The scheduler is created
# once so its milestones are counted across epochs.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
scheduler = MultiStepLR(optimizer, milestones=lr_milestones, gamma=0.5)

# Early stopping state
best_test_acc = 0
epochs_no_improve = 0
early_stop = False
checkpoint_saved = False


def train(epoch):
    """Run one optimisation pass over `trainloader` and print the mean batch loss."""
    model.train()

    total_loss = 0
    for data in trainloader:  # Iterate in batches over the training dataset.
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)  # Logits, [num_graphs, num_classes].
        loss = criterion(out, data.y.long())
        optimizer.zero_grad()  # Clear gradients from the previous step.
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    scheduler.step()  # One scheduler step per epoch.
    avg_loss = total_loss / max(len(trainloader), 1)
    print(f'Epoch {epoch}  [==============================] - train_loss : {avg_loss:.4f}')


def evaluate(loader):
    """Return (accuracy, precision, recall, F1) of `model` on `loader`.

    Precision, recall and F1 are macro-averaged over classes. With micro
    averaging they would all equal accuracy for single-label multi-class
    predictions. ROC-AUC is not reported: it needs class probabilities and
    every class present in the evaluated split.
    """
    model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        for data in loader:  # Iterate in batches over the dataset.
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            pred = out.argmax(dim=1)  # Class with the highest logit.

            y_true.extend(data.y.cpu().numpy())
            y_pred.extend(pred.cpu().numpy())

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
    recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    return accuracy, precision, recall, f1


for epoch in range(1, num_epochs + 1):
    train(epoch)

    train_acc, train_prec, train_recall, train_f1 = evaluate(trainloader)
    test_acc, test_prec, test_recall, test_f1 = evaluate(testloader)

    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
    print(f'Precision: Train: {train_prec:.4f}, Test: {test_prec:.4f}')
    print(f'Recall: Train: {train_recall:.4f}, Test: {test_recall:.4f}')
    print(f'F1-Score: Train: {train_f1:.4f}, Test: {test_f1:.4f}')

    # Early stopping logic.
    # NOTE(review): model selection uses the test split, so the reported test
    # accuracy is optimistically biased; a separate validation split would avoid this.
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        epochs_no_improve = 0
        torch.save(model.state_dict(), checkpoint_path)
        checkpoint_saved = True
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print('Early stopping!')
            early_stop = True
            break

# Restore the best checkpoint of this run, whether or not early stopping fired,
# so `model` ends up holding the best-scoring weights.
if checkpoint_saved:
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
