<a href="https://colab.research.google.com/github/tejaspradhan/Graph-Neural-Networks/blob/main/hiv-project/gcn_hiv_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Installations

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
#@title
# Install rdkit
import sys
import os
import requests
import subprocess
import shutil
from logging import getLogger, StreamHandler, INFO
 
 
logger = getLogger(__name__)
logger.addHandler(StreamHandler())
logger.setLevel(INFO)
 
 
def install(
        chunk_size=4096,
        file_name="Miniconda3-latest-Linux-x86_64.sh",
        url_base="https://repo.continuum.io/miniconda/",
        conda_path=os.path.expanduser(os.path.join("~", "miniconda")),
        rdkit_version=None,
        add_python_path=True,
        force=False):
    """install rdkit from miniconda
    ```
    import rdkit_installer
    rdkit_installer.install()
    ```
    """
 
    python_path = os.path.join(
        conda_path,
        "lib",
        "python{0}.{1}".format(*sys.version_info),
        "site-packages",
    )
 
    if add_python_path and python_path not in sys.path:
        logger.info("add {} to PYTHONPATH".format(python_path))
        sys.path.append(python_path)
 
    if os.path.isdir(os.path.join(python_path, "rdkit")):
        logger.info("rdkit is already installed")
        if not force:
            return
 
        logger.info("force re-install")
 
    url = url_base + file_name
    python_version = "{0}.{1}.{2}".format(*sys.version_info)
 
    logger.info("python version: {}".format(python_version))
 
    if os.path.isdir(conda_path):
        logger.warning("remove current miniconda")
        shutil.rmtree(conda_path)
    elif os.path.isfile(conda_path):
        logger.warning("remove {}".format(conda_path))
        os.remove(conda_path)
 
    logger.info('fetching installer from {}'.format(url))
    res = requests.get(url, stream=True)
    res.raise_for_status()
    with open(file_name, 'wb') as f:
        for chunk in res.iter_content(chunk_size):
            f.write(chunk)
    logger.info('done')
 
    logger.info('installing miniconda to {}'.format(conda_path))
    subprocess.check_call(["bash", file_name, "-b", "-p", conda_path])
    logger.info('done')
 
    logger.info("installing rdkit")
    subprocess.check_call([
        os.path.join(conda_path, "bin", "conda"),
        "install",
        "--yes",
        "-c", "rdkit",
        "python==3.7.3",
        "rdkit" if rdkit_version is None else "rdkit=={}".format(rdkit_version)])
    logger.info("done")
 
    import rdkit
    logger.info("rdkit-{} installation finished!".format(rdkit.__version__))
 
 
if __name__ == "__main__":
    install()

add /root/miniconda/lib/python3.7/site-packages to PYTHONPATH
python version: 3.7.12
fetching installer from https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
done
installing miniconda to /root/miniconda
done
installing rdkit
done
rdkit-2020.09.1 installation finished!


In [3]:
! pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html
! pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html
! pip install -q torch-cluster -f https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html
! pip install -q torch-geometric

[K     |████████████████████████████████| 7.9 MB 4.3 MB/s 
[K     |████████████████████████████████| 3.5 MB 4.5 MB/s 
[K     |████████████████████████████████| 2.5 MB 4.2 MB/s 
[K     |████████████████████████████████| 407 kB 4.2 MB/s 
[?25h  Building wheel for torch-geometric (setup.py) ... [?25l[?25hdone


## Importing Libraries

In [4]:
import rdkit
from rdkit import Chem
from rdkit.Chem import Draw
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch 
import torch.nn.functional as F 
import torch_geometric
from torch_geometric.data import Dataset, Data
from torch_geometric.nn import GATConv, Linear, TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.loader import DataLoader
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from torch_geometric.nn import GCNConv

## Creating Custom Graph Dataset from CSV File

**Important**
Before running the next cell create a directory structure as follows : 

        |- data
              |- raw
                    |- HIV.csv (data csv file)

In [5]:
class HIVDataset(Dataset):
    def __init__(self, root, filename, test=False, transform=None, pre_transform=None):
        """
        root = Where the dataset should be stored. This folder is split
        into raw_dir (downloaded dataset) and processed_dir (processed data). 
        """
        self.test = test
        self.filename = filename
        super(HIVDataset, self).__init__(root, transform, pre_transform)
        
    @property
    def raw_file_names(self):
        """ If this file exists in raw_dir, the download is not triggered.
            (The download func. is not implemented here)  
        """
        return self.filename

    @property
    def processed_file_names(self):
        return 'not_implemented.pt'

    def download(self):
        pass

    def process(self):
        self.data = pd.read_csv(self.raw_paths[0])
        print(self.raw_paths)
        for index, mol in tqdm(self.data.iterrows(), total=self.data.shape[0]):
            mol_obj = Chem.MolFromSmiles(mol["smiles"])
            # Get node features
            node_feats = self._get_node_features(mol_obj)
            # Get edge features
            edge_feats = self._get_edge_features(mol_obj)
            # Get adjacency info
            edge_index = self._get_adjacency_info(mol_obj)
            # Get labels info
            label = self._get_labels(mol["HIV_active"])

            # Create data object
            data = Data(x=node_feats, 
                        edge_index=edge_index,
                        edge_attr=edge_feats,
                        y=label,
                        smiles=mol["smiles"]
                        ) 
            if self.test:
                torch.save(data, 
                    os.path.join(self.processed_dir, 
                                 f'data_test_{index}.pt'))
            else:
                torch.save(data, 
                    os.path.join(self.processed_dir, 
                                 f'data_{index}.pt'))

    def _get_node_features(self, mol):
        """ 
        This will return a matrix / 2d array of the shape
        [Number of Nodes, Node Feature size]
        """
        all_node_feats = []

        for atom in mol.GetAtoms():
            node_feats = []
            # Feature 1: Atomic number        
            node_feats.append(atom.GetAtomicNum())
            # Feature 2: Atom degree
            node_feats.append(atom.GetDegree())
            # Feature 3: Formal charge
            node_feats.append(atom.GetFormalCharge())
            # Feature 4: Hybridization
            node_feats.append(atom.GetHybridization())
            # Feature 5: Aromaticity
            node_feats.append(atom.GetIsAromatic())
            # Feature 6: Total Num Hs
            node_feats.append(atom.GetTotalNumHs())
            # Feature 7: Radical Electrons
            node_feats.append(atom.GetNumRadicalElectrons())
            # Feature 8: In Ring
            node_feats.append(atom.IsInRing())
            # Feature 9: Chirality
            node_feats.append(atom.GetChiralTag())

            # Append node features to matrix
            all_node_feats.append(node_feats)

        all_node_feats = np.asarray(all_node_feats)
        return torch.tensor(all_node_feats, dtype=torch.float)

    def _get_edge_features(self, mol):
        """ 
        This will return a matrix / 2d array of the shape
        [Number of edges, Edge Feature size]
        """
        all_edge_feats = []

        for bond in mol.GetBonds():
            edge_feats = []
            # Feature 1: Bond type (as double)
            edge_feats.append(bond.GetBondTypeAsDouble())
            # Feature 2: Rings
            edge_feats.append(bond.IsInRing())
            # Append node features to matrix (twice, per direction)
            all_edge_feats += [edge_feats, edge_feats]

        all_edge_feats = np.asarray(all_edge_feats)
        return torch.tensor(all_edge_feats, dtype=torch.float)

    def _get_adjacency_info(self, mol):
        """
        We could also use rdmolops.GetAdjacencyMatrix(mol)
        but we want to be sure that the order of the indices
        matches the order of the edge features
        """
        edge_indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_indices += [[i, j], [j, i]]

        edge_indices = torch.tensor(edge_indices)
        edge_indices = edge_indices.t().to(torch.long).view(2, -1)
        return edge_indices

    def _get_labels(self, label):
        label = np.asarray([label])
        return torch.tensor(label, dtype=torch.int64)

    def len(self):
        return self.data.shape[0]

    def get(self, idx):
        """ - Equivalent to __getitem__ in pytorch
            - Is not needed for PyG's InMemoryDataset
        """
        if self.test:
            data = torch.load(os.path.join(self.processed_dir, 
                                 f'data_test_{idx}.pt'))
        else:
            data = torch.load(os.path.join(self.processed_dir, 
                                 f'data_{idx}.pt'))   
        return data

In [6]:
train_dataset = HIVDataset(root='data/',filename='HIV_train_oversampled.csv')
test_dataset = HIVDataset(root='data/',filename='HIV_test.csv')

Processing...


['data/raw/HIV_train_oversampled.csv']


100%|██████████| 71634/71634 [03:49<00:00, 311.62it/s]
Done!
Processing...


['data/raw/HIV_test.csv']


100%|██████████| 3999/3999 [00:10<00:00, 391.40it/s]
Done!


In [7]:
train_dataset

HIVDataset(71634)

In [8]:
test_dataset

HIVDataset(3999)

In [9]:
train_dataset[0].x.shape
# 9 node features

torch.Size([19, 9])

In [10]:
train_dataset[0].y.shape

torch.Size([1])

In [11]:
train_dataset[0].edge_index.t().shape # 2 edge featurs 

torch.Size([40, 2])

## Constructing the GCN for this Dataset 

In [12]:
class GCN(torch.nn.Module):
    def __init__(self, feature_size, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        num_classes = 2
        self.conv1 = GCNConv(feature_size, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = gap(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x

##Training GCN


In [13]:
model = GCN(train_dataset[0].x.shape[1], 256)

In [None]:
model.to('cuda')

GCN(
  (conv1): GCNConv(9, 256)
  (conv2): GCNConv(256, 256)
  (conv3): GCNConv(256, 256)
  (lin): Linear(256, 2, bias=True)
)

In [21]:
# Defining loss and optimiser 
weights = torch.tensor([1,10],dtype=torch.float32).to('cpu')
loss_fn = torch.nn.CrossEntropyLoss(weight = weights)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=0.95)

In [16]:
BATCH_SIZE = 256
EPOCHS=80

In [17]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
def train(model, train_loader, optimizer, loss_fn):
    # Enumerate over the data
    all_preds = []
    all_labels = []
    running_loss = 0.0
    step = 0
    for _, batch in enumerate(tqdm(train_loader)):
        batch.to('cuda')  
        optimizer.zero_grad() 
        # forward pass
        pred = model(batch.x.float(), batch.edge_index, batch.batch) 
        # backward pass 
        loss = loss_fn(pred, batch.y)
        loss.backward()  
        optimizer.step()  

        # Update tracking
        running_loss += loss.item()
        step += 1
        all_preds.append(np.argmax(pred.cpu().detach().numpy(),axis=1))
        all_labels.append(batch.y.cpu().detach().numpy())
    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    # print(all_preds)
    acc= accuracy_score(all_labels, all_preds)
    return running_loss/step, acc

In [None]:
losses = []
accs = []
for epoch in range(EPOCHS):
  loss, accuracy = train(model, train_loader, optimizer,loss_fn)
  losses.append(loss)
  accs.append(accuracy)
  if epoch%10==0:
    torch.save(model.state_dict(), "/content/drive/MyDrive/GNN_HIV/models/gcn/model-hiv-"+str(epoch+170) +"epochs.pt")
  print(f"Epoch : {epoch+1} , Loss : {loss}, Accuracy : {accuracy}")

100%|██████████| 280/280 [00:44<00:00,  6.32it/s]


Epoch : 1 , Loss : 0.23617627189627716, Accuracy : 0.5997291788815367


100%|██████████| 280/280 [00:43<00:00,  6.41it/s]


Epoch : 2 , Loss : 0.23484112689537662, Accuracy : 0.6019487952648184


100%|██████████| 280/280 [00:44<00:00,  6.36it/s]


Epoch : 3 , Loss : 0.23698451928794384, Accuracy : 0.5969092888851663


100%|██████████| 280/280 [00:44<00:00,  6.31it/s]


Epoch : 4 , Loss : 0.23700520279152051, Accuracy : 0.5987938688332356


100%|██████████| 280/280 [00:43<00:00,  6.45it/s]


Epoch : 5 , Loss : 0.24143246401633536, Accuracy : 0.5925817349303404


100%|██████████| 280/280 [00:43<00:00,  6.47it/s]


Epoch : 6 , Loss : 0.2366058630070516, Accuracy : 0.598598430912695


100%|██████████| 280/280 [00:43<00:00,  6.46it/s]


Epoch : 7 , Loss : 0.23661670120699066, Accuracy : 0.6005388502666331


100%|██████████| 280/280 [00:43<00:00,  6.50it/s]


Epoch : 8 , Loss : 0.23554430050509317, Accuracy : 0.601920875561884


100%|██████████| 280/280 [00:42<00:00,  6.52it/s]


Epoch : 9 , Loss : 0.23529908955097198, Accuracy : 0.6051316413993355


100%|██████████| 280/280 [00:43<00:00,  6.48it/s]


Epoch : 10 , Loss : 0.24170771730797633, Accuracy : 0.5883658597872519


100%|██████████| 280/280 [00:42<00:00,  6.52it/s]


Epoch : 11 , Loss : 0.23450059965252876, Accuracy : 0.603847335064355


100%|██████████| 280/280 [00:43<00:00,  6.50it/s]


Epoch : 12 , Loss : 0.23658337454710687, Accuracy : 0.6008459669989111


100%|██████████| 280/280 [00:42<00:00,  6.52it/s]


Epoch : 13 , Loss : 0.23487116864749363, Accuracy : 0.6029678644219225


100%|██████████| 280/280 [00:43<00:00,  6.50it/s]


Epoch : 14 , Loss : 0.23416284838957446, Accuracy : 0.6069603819415361


100%|██████████| 280/280 [00:42<00:00,  6.53it/s]


Epoch : 15 , Loss : 0.23239324007715498, Accuracy : 0.6102130273333892


100%|██████████| 280/280 [00:42<00:00,  6.52it/s]


Epoch : 16 , Loss : 0.2325298664293119, Accuracy : 0.6101571879275205


100%|██████████| 280/280 [00:42<00:00,  6.54it/s]


Epoch : 17 , Loss : 0.23315385010625636, Accuracy : 0.6096406734232348


100%|██████████| 280/280 [00:42<00:00,  6.53it/s]


Epoch : 18 , Loss : 0.23159269403134072, Accuracy : 0.6124186838652037


100%|██████████| 280/280 [00:42<00:00,  6.51it/s]


Epoch : 19 , Loss : 0.23012178614735604, Accuracy : 0.6165647597509563


100%|██████████| 280/280 [00:42<00:00,  6.57it/s]


Epoch : 20 , Loss : 0.23064009568520955, Accuracy : 0.6150152162380992


100%|██████████| 280/280 [00:42<00:00,  6.59it/s]


Epoch : 21 , Loss : 0.23724669212741512, Accuracy : 0.5989055476449731


100%|██████████| 280/280 [00:43<00:00,  6.49it/s]


Epoch : 22 , Loss : 0.23026989208800452, Accuracy : 0.6164530809392188


100%|██████████| 280/280 [00:43<00:00,  6.49it/s]


Epoch : 23 , Loss : 0.22976858089012758, Accuracy : 0.6168299969288327


100%|██████████| 280/280 [00:42<00:00,  6.53it/s]


Epoch : 24 , Loss : 0.23102305728409972, Accuracy : 0.6137727894575201


100%|██████████| 280/280 [00:42<00:00,  6.52it/s]


Epoch : 25 , Loss : 0.2282214335032872, Accuracy : 0.6212133902895273


100%|██████████| 280/280 [00:43<00:00,  6.51it/s]


Epoch : 26 , Loss : 0.2317310606794698, Accuracy : 0.6146522600999526


100%|██████████| 280/280 [00:42<00:00,  6.53it/s]


Epoch : 27 , Loss : 0.2301212141556399, Accuracy : 0.6171510735125778


100%|██████████| 280/280 [00:42<00:00,  6.54it/s]


Epoch : 28 , Loss : 0.229222443486963, Accuracy : 0.6176675880168635


100%|██████████| 280/280 [00:42<00:00,  6.63it/s]


Epoch : 29 , Loss : 0.22974217172179903, Accuracy : 0.6216740653879442


100%|██████████| 280/280 [00:41<00:00,  6.67it/s]


Epoch : 30 , Loss : 0.23056538674448218, Accuracy : 0.6163832816818829


100%|██████████| 280/280 [00:42<00:00,  6.67it/s]


Epoch : 31 , Loss : 0.22775967552193574, Accuracy : 0.6241589189491024


100%|██████████| 280/280 [00:42<00:00,  6.65it/s]


Epoch : 32 , Loss : 0.23114927565412863, Accuracy : 0.6154479716335818


100%|██████████| 280/280 [00:41<00:00,  6.69it/s]


Epoch : 33 , Loss : 0.2272563115826675, Accuracy : 0.6266996119161292


100%|██████████| 280/280 [00:42<00:00,  6.61it/s]


Epoch : 34 , Loss : 0.2283905925495284, Accuracy : 0.6195242482619985


100%|██████████| 280/280 [00:43<00:00,  6.51it/s]


Epoch : 35 , Loss : 0.22815139016934805, Accuracy : 0.6233771672669403


100%|██████████| 280/280 [00:43<00:00,  6.50it/s]


Epoch : 36 , Loss : 0.23194828554987906, Accuracy : 0.6120976072814586


100%|██████████| 280/280 [00:43<00:00,  6.51it/s]


Epoch : 37 , Loss : 0.22581320740282534, Accuracy : 0.6287377502303375


100%|██████████| 280/280 [00:42<00:00,  6.51it/s]


Epoch : 38 , Loss : 0.23044376841613223, Accuracy : 0.6180026244520759


100%|██████████| 280/280 [00:42<00:00,  6.53it/s]


Epoch : 39 , Loss : 0.23138514746512687, Accuracy : 0.6168160370773655


100%|██████████| 280/280 [00:42<00:00,  6.52it/s]


Epoch : 40 , Loss : 0.2252715487565313, Accuracy : 0.6281235167657816


100%|██████████| 280/280 [00:42<00:00,  6.52it/s]


Epoch : 41 , Loss : 0.22583095537764686, Accuracy : 0.628388753943658


100%|██████████| 280/280 [00:42<00:00,  6.52it/s]


Epoch : 42 , Loss : 0.22288930432072707, Accuracy : 0.6348242454700282


100%|██████████| 280/280 [00:43<00:00,  6.50it/s]


Epoch : 43 , Loss : 0.22328598557838372, Accuracy : 0.6360806321020744


100%|██████████| 280/280 [00:42<00:00,  6.52it/s]


Epoch : 44 , Loss : 0.22434249339359147, Accuracy : 0.63441940977748


100%|██████████| 280/280 [00:42<00:00,  6.52it/s]


Epoch : 45 , Loss : 0.22354257905057498, Accuracy : 0.6339308149761287


100%|██████████| 280/280 [00:43<00:00,  6.48it/s]


Epoch : 46 , Loss : 0.22486281134188174, Accuracy : 0.6299243376050478


100%|██████████| 280/280 [00:42<00:00,  6.54it/s]


Epoch : 47 , Loss : 0.22071803233453205, Accuracy : 0.6413714158081358


100%|██████████| 280/280 [00:42<00:00,  6.61it/s]


Epoch : 48 , Loss : 0.2250220243952104, Accuracy : 0.6300220565653182


100%|██████████| 280/280 [00:42<00:00,  6.64it/s]


Epoch : 49 , Loss : 0.22414982324200017, Accuracy : 0.6335957785409163


100%|██████████| 280/280 [00:42<00:00,  6.65it/s]


Epoch : 50 , Loss : 0.22363525990928923, Accuracy : 0.6354105592316498


100%|██████████| 280/280 [00:42<00:00,  6.56it/s]


Epoch : 51 , Loss : 0.2222422579569476, Accuracy : 0.6385654856632326


100%|██████████| 280/280 [00:42<00:00,  6.53it/s]


Epoch : 52 , Loss : 0.2242793720215559, Accuracy : 0.6357455956668621


100%|██████████| 280/280 [00:42<00:00,  6.52it/s]


Epoch : 53 , Loss : 0.21958785434918746, Accuracy : 0.6434933132311472


100%|██████████| 280/280 [00:42<00:00,  6.54it/s]


Epoch : 54 , Loss : 0.21833432742527553, Accuracy : 0.6456850099114946


100%|██████████| 280/280 [00:43<00:00,  6.50it/s]


Epoch : 55 , Loss : 0.22065734543970653, Accuracy : 0.6416785325404137


100%|██████████| 280/280 [00:42<00:00,  6.54it/s]


Epoch : 56 , Loss : 0.2186226005532912, Accuracy : 0.6476952285227685


100%|██████████| 280/280 [00:43<00:00,  6.49it/s]


Epoch : 57 , Loss : 0.22015991077891417, Accuracy : 0.6415528938772092


100%|██████████| 280/280 [00:42<00:00,  6.54it/s]


Epoch : 58 , Loss : 0.22247198889298098, Accuracy : 0.6385515258117653


100%|██████████| 280/280 [00:42<00:00,  6.63it/s]


Epoch : 59 , Loss : 0.21715181012238774, Accuracy : 0.6490632939665522


100%|██████████| 280/280 [00:41<00:00,  6.69it/s]


Epoch : 60 , Loss : 0.2206152895199401, Accuracy : 0.6405896641259737


100%|██████████| 280/280 [00:41<00:00,  6.73it/s]


Epoch : 61 , Loss : 0.22062592070017542, Accuracy : 0.6420554485300276


100%|██████████| 280/280 [00:41<00:00,  6.71it/s]


Epoch : 62 , Loss : 0.2168914363852569, Accuracy : 0.6517994248541196


100%|██████████| 280/280 [00:41<00:00,  6.74it/s]


Epoch : 63 , Loss : 0.22083212589578968, Accuracy : 0.6427394812519195


100%|██████████| 280/280 [00:42<00:00,  6.66it/s]


Epoch : 64 , Loss : 0.22120199001261165, Accuracy : 0.642851160063657


100%|██████████| 280/280 [00:42<00:00,  6.59it/s]


Epoch : 65 , Loss : 0.22202415860124997, Accuracy : 0.639989390512885


100%|██████████| 280/280 [00:42<00:00,  6.61it/s]


Epoch : 66 , Loss : 0.22065371807132447, Accuracy : 0.6439260686266298


100%|██████████| 280/280 [00:42<00:00,  6.63it/s]


Epoch : 67 , Loss : 0.21707023456692695, Accuracy : 0.6522042605466678


100%|██████████| 280/280 [00:42<00:00,  6.67it/s]


Epoch : 68 , Loss : 0.21724167663071836, Accuracy : 0.6497054471340425


100%|██████████| 280/280 [00:42<00:00,  6.61it/s]


Epoch : 69 , Loss : 0.21423001464988503, Accuracy : 0.6606220509813776


100%|██████████| 280/280 [00:42<00:00,  6.53it/s]


Epoch : 70 , Loss : 0.22120058472667423, Accuracy : 0.6438004299634252


100%|██████████| 280/280 [00:43<00:00,  6.51it/s]


Epoch : 71 , Loss : 0.21973660380712579, Accuracy : 0.6465365608509925


100%|██████████| 280/280 [00:42<00:00,  6.55it/s]


Epoch : 72 , Loss : 0.21605543504868235, Accuracy : 0.652064662031996


100%|██████████| 280/280 [00:42<00:00,  6.53it/s]


Epoch : 73 , Loss : 0.21437632170106682, Accuracy : 0.6584024345980959


100%|██████████| 280/280 [00:42<00:00,  6.61it/s]


Epoch : 74 , Loss : 0.23886028485638755, Accuracy : 0.6143311835162074


100%|██████████| 280/280 [00:42<00:00,  6.60it/s]


Epoch : 75 , Loss : 0.22673256514327866, Accuracy : 0.6268810899852025


100%|██████████| 280/280 [00:42<00:00,  6.57it/s]


Epoch : 76 , Loss : 0.2211612473641123, Accuracy : 0.6413574559566686


100%|██████████| 280/280 [00:42<00:00,  6.52it/s]


Epoch : 77 , Loss : 0.21700722532612937, Accuracy : 0.6528882932685596


100%|██████████| 280/280 [00:42<00:00,  6.60it/s]


Epoch : 78 , Loss : 0.21758842085089003, Accuracy : 0.6524834575760113


100%|██████████| 280/280 [00:41<00:00,  6.72it/s]


Epoch : 79 , Loss : 0.21626544275454113, Accuracy : 0.6553033475723818


100%|██████████| 280/280 [00:42<00:00,  6.57it/s]

Epoch : 80 , Loss : 0.21756227474127496, Accuracy : 0.6529022531200268





In [None]:
torch.save(model.state_dict(), "/content/drive/MyDrive/GNN_HIV/models/gcn/model-hiv-250epochs.pt")

## Testing 

In [14]:
model.load_state_dict(torch.load('/content/drive/MyDrive/GNN_HIV/models/gcn/model-hiv-500epochs.pt',map_location = torch.device('cpu')))

<All keys matched successfully>

In [22]:
model.eval()
with torch.no_grad():
  test_preds = []
  test_labels = []
  running_loss = 0.0
  steps =0 
  for _, batch in enumerate(tqdm(train_loader)):
    batch.to('cpu')
    pred = model(batch.x.float(),batch.edge_index, batch.batch)
    # print(batch.y)
    loss = loss_fn(pred, batch.y)
    test_preds.append(np.argmax(pred.cpu().detach().numpy(),axis=1))
    test_labels.append(batch.y.cpu().detach().numpy())
    running_loss += loss.item()
    steps+=1
  test_preds = np.concatenate(test_preds).ravel()
  test_labels = np.concatenate(test_labels).ravel()
  acc= accuracy_score(test_labels, test_preds)
  print(f'\nTest Loss : {running_loss/steps}')

100%|██████████| 280/280 [01:45<00:00,  2.67it/s]


Test Loss : 0.13925306624067682





In [24]:
print(acc)

0.8284753050227546
