In [1]:
import os
data_dir = 'data/'
model_dir = 'models/'

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [3]:
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
import scikitplot as skplt

from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from source.utils import read_data, preprocess, train, test, run_kfold_test
from source.models import BitterGCN_Baseline, BitterGCN_MixedPool, BitterGAT_Baseline, \
    BitterGAT_MixedPool, BitterGraphSAGE_Baseline, BitterGraphSAGE_MixedPool
        

In [5]:
df = read_data(data_dir)

In [6]:
# k Fold
nsplits = 10
graph_data = preprocess(df)
n=len(graph_data)
graph_data[0], graph_data[7]

(Data(x=[2, 20], edge_index=[2, 2], y=0),
 Data(x=[8, 20], edge_index=[2, 14], y=1))

### KFold

In [7]:
KFOLD_RESULSTS = pd.DataFrame(index = list(range(10)))

In [None]:
class BitterGCN_Baseline(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(BitterGCN_Baseline, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(20, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.conv4 = GCNConv(hidden_channels, hidden_channels)
        self.conv5 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, 2)



    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)
        x = x.relu()
        x = self.conv4(x, edge_index)
        x = x.relu()
        x = self.conv5(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.lin(x)
        
        return x


In [8]:
hidden_dim = 32
model = BitterGCN_Baseline(hidden_channels=hidden_dim)
print(model)

fold_test_acc, fold_test_roc = run_kfold_test(nsplits, graph_data, BitterGCN_Baseline, h=hidden_dim, lr = 0.05, b=1)

KFOLD_RESULSTS.loc[:,'ACC_GCN_Baseline'] = fold_test_acc
KFOLD_RESULSTS.loc[:,'ROC_GCN_Baseline'] = fold_test_roc

BitterGCN_Baseline(
  (conv1): GCNConv(20, 32)
  (conv2): GCNConv(32, 32)
  (conv3): GCNConv(32, 32)
  (lin): Linear(in_features=32, out_features=2, bias=True)
)
Fold 1:
Test Acc: 0.56 ROC: 0.41


KeyboardInterrupt: 

In [10]:
KFOLD_RESULSTS.mean()

ACC_GCN_Baseline    0.803
ROC_GCN_Baseline    0.852
dtype: float64

In [None]:
hidden_dim = 32
model = BitterGCN_Baseline(hidden_channels=hidden_dim)
print(model)

fold_test_acc, fold_test_roc = run_kfold_test(nsplits, graph_data, BitterGCN_Baseline, h=hidden_dim)

KFOLD_RESULSTS.loc[:,'ACC_GCN_Baseline'] = fold_test_acc
KFOLD_RESULSTS.loc[:,'ROC_GCN_Baseline'] = fold_test_roc

BitterGCN_Baseline(
  (conv1): GCNConv(20, 32)
  (conv2): GCNConv(32, 32)
  (conv3): GCNConv(32, 32)
  (lin): Linear(in_features=32, out_features=2, bias=True)
)
Fold 1:
Test Acc: 0.86 ROC: 0.86
Fold 2:
Test Acc: 0.92 ROC: 0.95
Fold 3:
Test Acc: 0.81 ROC: 0.9
Fold 4:
Test Acc: 0.81 ROC: 0.88
Fold 5:
Test Acc: 0.81 ROC: 0.88
Fold 6:
Test Acc: 0.91 ROC: 0.92
Fold 7:
Test Acc: 0.81 ROC: 0.88
Fold 8:
Test Acc: 0.8 ROC: 0.86
Fold 9:
Test Acc: 0.77 ROC: 0.85
Fold 10:
Test Acc: 0.72 ROC: 0.76


In [10]:
KFOLD_RESULSTS.mean()

ACC_GCN_Baseline    0.822
ROC_GCN_Baseline    0.874
dtype: float64

In [11]:
model = BitterGCN_MixedPool(hidden_channels=32)
print(model)

fold_test_acc, fold_test_roc = run_kfold_test(nsplits, graph_data, BitterGCN_MixedPool)

KFOLD_RESULSTS.loc[:,'ACC_GCN_MixedPool'] = fold_test_acc
KFOLD_RESULSTS.loc[:,'ROC_GCN_MixedPool'] = fold_test_roc

BitterGCN_MixedPool(
  (conv1): GCNConv(20, 32)
  (conv2): GCNConv(32, 32)
  (conv3): GCNConv(32, 32)
  (lin): Linear(in_features=32, out_features=2, bias=True)
)
Fold 1:
Test Acc: 0.86 ROC: 0.86
Fold 2:
Test Acc: 0.95 ROC: 0.98


KeyboardInterrupt: 

In [None]:
model = BitterGAT_Baseline(hidden_channels=32)
print(model)

fold_test_acc, fold_test_roc = run_kfold_test(nsplits, graph_data, BitterGAT_Baseline)

KFOLD_RESULSTS.loc[:,'ACC_GAT_Baseline'] = fold_test_acc
KFOLD_RESULSTS.loc[:,'ROC_GAT_Baseline'] = fold_test_roc

BitterGAT_Baseline(
  (conv1): GATConv(20, 32, heads=1)
  (conv2): GATConv(32, 32, heads=1)
  (conv3): GATConv(32, 32, heads=1)
  (lin): Linear(in_features=32, out_features=2, bias=True)
)
Fold 1:
Test Acc: 0.78 ROC: 0.81
Fold 2:
Test Acc: 0.81 ROC: 0.91
Fold 3:
Test Acc: 0.83 ROC: 0.85
Fold 4:
Test Acc: 0.88 ROC: 0.9
Fold 5:
Test Acc: 0.81 ROC: 0.85
Fold 6:
Test Acc: 0.84 ROC: 0.92
Fold 7:
Test Acc: 0.78 ROC: 0.93
Fold 8:
Test Acc: 0.88 ROC: 0.91
Fold 9:
Test Acc: 0.81 ROC: 0.87
Fold 10:
Test Acc: 0.7 ROC: 0.83


In [None]:
model = BitterGAT_MixedPool(hidden_channels=32)
print(model)

fold_test_acc, fold_test_roc = run_kfold_test(nsplits, graph_data, BitterGAT_MixedPool)

KFOLD_RESULSTS.loc[:,'ACC_GAT_MixedPool'] = fold_test_acc
KFOLD_RESULSTS.loc[:,'ROC_GAT_MixedPool'] = fold_test_roc

BitterGAT_MixedPool(
  (conv1): GATConv(20, 32, heads=1)
  (conv2): GATConv(32, 32, heads=1)
  (conv3): GATConv(32, 32, heads=1)
  (lin): Linear(in_features=32, out_features=2, bias=True)
)
Fold 1:
Test Acc: 0.8 ROC: 0.85
Fold 2:
Test Acc: 0.83 ROC: 0.97
Fold 3:
Test Acc: 0.84 ROC: 0.9
Fold 4:
Test Acc: 0.89 ROC: 0.89
Fold 5:
Test Acc: 0.78 ROC: 0.84
Fold 6:
Test Acc: 0.88 ROC: 0.93
Fold 7:
Test Acc: 0.86 ROC: 0.93
Fold 8:
Test Acc: 0.8 ROC: 0.91
Fold 9:
Test Acc: 0.8 ROC: 0.89
Fold 10:
Test Acc: 0.78 ROC: 0.85


In [None]:
model = BitterGraphSAGE_Baseline(hidden_channels=32)
print(model)

fold_test_acc, fold_test_roc = run_kfold_test(nsplits, graph_data, BitterGraphSAGE_Baseline)

KFOLD_RESULSTS.loc[:,'ACC_GraphSAGE_Baseline'] = fold_test_acc
KFOLD_RESULSTS.loc[:,'ROC_GraphSAGE_Baseline'] = fold_test_roc

BitterGraphSAGE_Baseline(
  (conv1): SAGEConv(20, 32)
  (conv2): SAGEConv(32, 32)
  (conv3): SAGEConv(32, 32)
  (lin): Linear(in_features=32, out_features=2, bias=True)
)
Fold 1:
Test Acc: 0.77 ROC: 0.86
Fold 2:
Test Acc: 0.89 ROC: 0.91
Fold 3:
Test Acc: 0.81 ROC: 0.92
Fold 4:
Test Acc: 0.77 ROC: 0.88
Fold 5:
Test Acc: 0.84 ROC: 0.88
Fold 6:
Test Acc: 0.84 ROC: 0.91
Fold 7:
Test Acc: 0.84 ROC: 0.93
Fold 8:
Test Acc: 0.84 ROC: 0.9
Fold 9:
Test Acc: 0.8 ROC: 0.82
Fold 10:
Test Acc: 0.72 ROC: 0.82


In [None]:
model = BitterGraphSAGE_MixedPool(hidden_channels=32)
print(model)

fold_test_acc, fold_test_roc = run_kfold_test(nsplits, graph_data, BitterGraphSAGE_MixedPool)

KFOLD_RESULSTS.loc[:,'ACC_GraphSAGE_MixedPool'] = fold_test_acc
KFOLD_RESULSTS.loc[:,'ROC_GraphSAGE_MixedPool'] = fold_test_roc

BitterGraphSAGE_MixedPool(
  (conv1): SAGEConv(20, 32)
  (conv2): SAGEConv(32, 32)
  (conv3): SAGEConv(32, 32)
  (lin): Linear(in_features=32, out_features=2, bias=True)
)
Fold 1:
Test Acc: 0.75 ROC: 0.86
Fold 2:
Test Acc: 0.89 ROC: 0.94
Fold 3:
Test Acc: 0.86 ROC: 0.9
Fold 4:
Test Acc: 0.83 ROC: 0.91
Fold 5:
Test Acc: 0.83 ROC: 0.85
Fold 6:
Test Acc: 0.84 ROC: 0.89
Fold 7:
Test Acc: 0.81 ROC: 0.93
Fold 8:
Test Acc: 0.8 ROC: 0.91
Fold 9:
Test Acc: 0.83 ROC: 0.87
Fold 10:
Test Acc: 0.81 ROC: 0.83


In [None]:
KFOLD_RESULSTS

Unnamed: 0,ACC_GCN_Baseline,ROC_GCN_Baseline,ACC_GCN_MixedPool,ROC_GCN_MixedPool,ACC_GAT_Baseline,ROC_GAT_Baseline,ACC_GAT_MixedPool,ROC_GAT_MixedPool,ACC_GraphSAGE_Baseline,ROC_GraphSAGE_Baseline,ACC_GraphSAGE_MixedPool,ROC_GraphSAGE_MixedPool
0,0.81,0.81,0.86,0.86,0.78,0.81,0.8,0.85,0.77,0.86,0.75,0.86
1,0.91,0.97,0.95,0.98,0.81,0.91,0.83,0.97,0.89,0.91,0.89,0.94
2,0.84,0.89,0.91,0.91,0.83,0.85,0.84,0.9,0.81,0.92,0.86,0.9
3,0.84,0.91,0.89,0.91,0.88,0.9,0.89,0.89,0.77,0.88,0.83,0.91
4,0.8,0.83,0.77,0.85,0.81,0.85,0.78,0.84,0.84,0.88,0.83,0.85
5,0.84,0.92,0.86,0.9,0.84,0.92,0.88,0.93,0.84,0.91,0.84,0.89
6,0.81,0.92,0.88,0.92,0.78,0.93,0.86,0.93,0.84,0.93,0.81,0.93
7,0.89,0.9,0.81,0.87,0.88,0.91,0.8,0.91,0.84,0.9,0.8,0.91
8,0.8,0.86,0.88,0.89,0.81,0.87,0.8,0.89,0.8,0.82,0.83,0.87
9,0.78,0.77,0.77,0.8,0.7,0.83,0.78,0.85,0.72,0.82,0.81,0.83


In [None]:
KFOLD_RESULSTS.mean()

ACC_GCN_Baseline           0.832
ROC_GCN_Baseline           0.878
ACC_GCN_MixedPool          0.858
ROC_GCN_MixedPool          0.889
ACC_GAT_Baseline           0.812
ROC_GAT_Baseline           0.878
ACC_GAT_MixedPool          0.826
ROC_GAT_MixedPool          0.896
ACC_GraphSAGE_Baseline     0.812
ROC_GraphSAGE_Baseline     0.883
ACC_GraphSAGE_MixedPool    0.825
ROC_GraphSAGE_MixedPool    0.889
dtype: float64