In [None]:
pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.4.0-py3-none-any.whl (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.4.0


In [None]:
pip install shap

Collecting shap
  Downloading shap-0.44.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (533 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m533.5/533.5 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
Collecting slicer==0.0.7 (from shap)
  Downloading slicer-0.0.7-py3-none-any.whl (14 kB)
Installing collected packages: slicer, shap
Successfully installed shap-0.44.0 slicer-0.0.7


In [None]:
import torch
from torch_geometric.datasets import EllipticBitcoinDataset
import torch.nn as nn
import torch.nn.functional as F
# from ogb.nodeproppred import Evaluator
from sklearn import metrics as metrics
import shap

# Read Dataset

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataset = EllipticBitcoinDataset(root='data/whole_graph')
data = dataset[0]

Downloading https://data.pyg.org/datasets/elliptic/elliptic_txs_features.csv.zip
Extracting data/whole_graph/raw/elliptic_txs_features.csv.zip
Downloading https://data.pyg.org/datasets/elliptic/elliptic_txs_edgelist.csv.zip
Extracting data/whole_graph/raw/elliptic_txs_edgelist.csv.zip
Downloading https://data.pyg.org/datasets/elliptic/elliptic_txs_classes.csv.zip
Extracting data/whole_graph/raw/elliptic_txs_classes.csv.zip
Processing...
Done!


# Models

## GCN

In [None]:
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels, num_layers, dropout):
        super().__init__()
        torch.manual_seed(777)

        # Convolution layers
        if num_layers > 1:
            self.convs = nn.ModuleList([GCNConv(dataset.num_features, hidden_channels)])
            self.convs.extend([GCNConv(hidden_channels, hidden_channels) for i in range(num_layers - 2)])
            self.convs.append(GCNConv(hidden_channels, dataset.num_classes))

            # Batch normilization
            self.bns = nn.ModuleList([nn.BatchNorm1d(hidden_channels)
                                     for i in range(num_layers - 1)])
        else:
            self.convs = nn.ModuleList([GCNConv(dataset.num_features, dataset.num_classes)])
            self.bns = nn.ModuleList([])

        # Softmax layer
        self.softmax = nn.LogSoftmax(1)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    # initialize parameters
    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, edge_index, embedding=False):
        for gcn, bn in zip(self.convs, self.bns):
            x = self.dropout(torch.relu(bn(gcn(x, edge_index))))
        if embedding:
            embeddings = x.cpu().detach().numpy()
        else: embeddings = None

        x = self.convs[-1](x, edge_index)

        return self.softmax(x), embeddings

## GAT

In [None]:
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, hidden_channels, heads, num_layers, dropout):
        super().__init__()
        torch.manual_seed(777)
        self.num_layers = num_layers

        if num_layers > 1:
            # GAT layers
            self.convs = nn.ModuleList([GATConv(dataset.num_features, hidden_channels, heads)])
            self.convs.extend([GATConv(heads*hidden_channels, hidden_channels, heads) for i in range(num_layers - 2)])
            self.convs.append(GATConv(heads*hidden_channels, dataset.num_classes))

            # Batch Normilization
            self.bns = nn.ModuleList([nn.BatchNorm1d(heads*hidden_channels)
                                     for i in range(num_layers - 1)])
        else:
            self.convs = nn.ModuleList([GATConv(dataset.num_features, dataset.num_classes)])
            self.bns = nn.ModuleList([])

         # Softmax layer
        self.softmax = nn.LogSoftmax(1)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    # initialize parameters
    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, edge_index, embedding=False):
        for gat, bn in zip(self.convs, self.bns):
            x = self.dropout(torch.relu(bn(gat(x, edge_index))))
        if embedding:
            embeddings = x.cpu().detach().numpy()
        else: embeddings = None

        x = self.convs[-1](x, edge_index)

        return self.softmax(x), embeddings



## GraphSAGE

In [None]:
from torch_geometric.nn import SAGEConv

class GraphSAGENet(nn.Module):
    def __init__(self, hidden_channels, num_layers, dropout):
        super().__init__()
        torch.manual_seed(777)

        if num_layers > 1:
            # Convolution layers
            self.convs = nn.ModuleList([SAGEConv(dataset.num_features, hidden_channels)])
            self.convs.extend([SAGEConv(hidden_channels, hidden_channels) for i in range(num_layers - 2)])
            self.convs.append(SAGEConv(hidden_channels, dataset.num_classes))

             # Batch normilization
            self.bns = nn.ModuleList([nn.BatchNorm1d(hidden_channels)
                                     for i in range(num_layers - 1)])

        else:
            self.convs = nn.ModuleList([SAGEConv(dataset.num_features, dataset.num_classes)])
            self.bns = nn.ModuleList([])

        # Softmax layer
        self.softmax = nn.LogSoftmax(1)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    # initialize parameters
    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, edge_index, embedding=False):
        for gcn, bn in zip(self.convs, self.bns):
            x = self.dropout(torch.relu(bn(gcn(x, edge_index))))
        if embedding:
            embeddings = x.cpu().detach().numpy()
        else: embeddings = None

        x = self.convs[-1](x, edge_index)

        return self.softmax(x), embeddings


# Help functions

In [None]:
def train(model, data, optimizer, loss_fn, embedding):
    model.train()

    # Clear gradients.
    optimizer.zero_grad()

    # feed datas into the model
    output, _ = model(data.x, data.edge_index, embedding)

    # Get the model's predictions and labels
    pred, label = output[data.train_mask], data.y[data.train_mask].view(-1)

    loss = loss_fn(pred, label0)

    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss.item()


def test(model, data, embedding):
    model.eval()

    output, _ = model(data.x, data.edge_index, embedding)
    y_pred = output.argmax(dim=1)
    train_acc = metrics.accuracy_score(data.y[data.train_mask].cpu(), y_pred[data.train_mask].cpu())
    test_acc = metrics.accuracy_score(data.y[data.test_mask].cpu(), y_pred[data.test_mask].cpu())
    test_pre = metrics.precision_score(data.y[data.test_mask].cpu(), y_pred[data.test_mask].cpu())
    test_recall = metrics.recall_score(data.y[data.test_mask].cpu(), y_pred[data.test_mask].cpu())
    test_f1 = metrics.f1_score(data.y[data.test_mask].cpu(), y_pred[data.test_mask].cpu())

    return train_acc, test_acc, test_pre, test_recall, test_f1


def runModel(model, data, optimizer, loss_fn, embedding):
    model.reset_parameters()
    data = data.to(device)
    for epoch in range(1, 501):
        loss = train(model, data, optimizer, loss_fn, embedding)
        result = test(model, data, embedding)
        train_acc, _, _, _, _ = result
        print(f'Epoch: {epoch:02d}, '
              f'Loss: {loss:.4f}, '
              f'Train: {100*train_acc:.2f}%')

    result = test(model, data, embedding)
    _, test_acc, test_pre, test_recall, test_f1 = result
    print(f'Test Accuracy: {100*test_acc:.2f}%  '
          f'Test Precision: {100*test_pre:.2f}%  '
          f'Test Recall: {100*test_recall:.2f}%  '
          f'Test F1: {100*test_f1:.2f}%  ')


def get_embeddings(model, data, optimizer, loss_fn):
    model.train()
    # Clear gradients.
    optimizer.zero_grad()
    # feed datas into the model
    for _ in range(500):
        _, _ = model(data.x, data.edge_index, False)
    _, embeddings = model(data.x, data.edge_index, True)
    return embeddings

# Run Models

In [None]:
model_GAT = GAT(hidden_channels=64, heads=8, num_layers=2, dropout=0.3).to(device)
model_SAGE = GraphSAGENet(hidden_channels=64, num_layers=2, dropout=0.3).to(device)
model_GCN = GCN(hidden_channels=128, num_layers=2, dropout=0.3).to(device)

## GCN

In [None]:
#weight = torch.tensor([0.5655, 4.3174]).to(device)
model = model_GCN
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
loss_fn = torch.nn.CrossEntropyLoss()
runModel(model, data, optimizer, loss_fn, _)

NameError: ignored

## GAT

In [None]:
model = model_GAT
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
loss_fn = torch.nn.CrossEntropyLoss()
runModel(model, data, optimizer, loss_fn, False)

## GraphSAGE

In [None]:
model = model_SAGE
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
loss_fn = torch.nn.CrossEntropyLoss()
runModel(model, data, optimizer, loss_fn, False)

NameError: ignored

# Benchmark Models

## Use all features (without node embeddings)

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
import numpy as np

train_label_benchmark = data.y[data.train_mask].cpu().numpy()
test_label_benchmark = data.y[data.test_mask].cpu().numpy()

In [None]:
train_dataset_AF = data.x[data.train_mask].cpu().numpy()
test_dataset_AF = data.x[data.test_mask].cpu().numpy()

### SVM

In [None]:
clf_SVM = SVC(class_weight="balanced", probability=True)
clf_SVM.fit(train_dataset_AF, train_label_benchmark)
prediction = clf_SVM.predict(test_dataset_AF)

SVM_acc = metrics.accuracy_score(test_label_benchmark, prediction)
print(f'Test Accuracy: {100*SVM_acc:.2f}%')
SVM_pre = metrics.precision_score(test_label_benchmark, prediction)
print(f'Test Precision: {100*SVM_pre:.2f}%')
SVM_recall = metrics.recall_score(test_label_benchmark, prediction)
print(f'Test Recall: {100*SVM_recall:.2f}%')
SVM_f1 = metrics.f1_score(test_label_benchmark, prediction)
print(f'Test F1: {100*SVM_f1:.2f}%')

Test Accuracy: 86.00%
Test Precision: 28.29%
Test Recall: 75.25%
Test F1: 41.12%


### RandomForest

In [None]:
clf_RF = RandomForestClassifier(n_estimators=500, max_depth=15, random_state=0, class_weight="balanced")
clf_RF.fit(train_dataset_AF, train_label_benchmark)
prediction = clf_RF.predict(test_dataset_AF)

RF_acc = metrics.accuracy_score(test_label_benchmark, prediction)
print(f'Test Accuracy: {100*RF_acc:.2f}%')
RF_pre = metrics.precision_score(test_label_benchmark, prediction)
print(f'Test Precision: {100*RF_pre:.2f}%')
RF_recall = metrics.recall_score(test_label_benchmark, prediction)
print(f'Test Recall: {100*RF_recall:.2f}%')
RF_f1 = metrics.f1_score(test_label_benchmark, prediction)
print(f'Test F1: {100*RF_f1:.2f}%')

Test Accuracy: 97.76%
Test Precision: 92.26%
Test Recall: 71.56%
Test F1: 80.60%


Sort features by important scores

In [None]:
RF_important_features = np.argsort(clf_RF.feature_importances_)[::-1]
RF_important_features

array([ 46,  52,  54,  42,  40,  48, 131,  89,   4,  13,  17,  53,   1,
        64, 137,  75,  45,   3,  22,  60,  51,  66,  58,  39,   5,  28,
         9,  65,  99, 102,  30, 155, 135, 162, 100,  59,  24,   7,  41,
        76,  79, 138,  21, 141,  27,  47,  15,  95,  88, 160, 136,  80,
       153,  83,  84,  78,  77, 157, 124,  82,   0, 126,   8,  16,  23,
        19, 105, 101,  10, 158,  18,  57,  20, 143, 159, 106,   2,  63,
        90, 161, 108,  92,  91,  87, 129,  29, 144, 118, 142,  94,  93,
       130, 120,  96,  81, 123, 156, 154, 132, 125, 107, 113,  72,  12,
       104,  61, 139,  85, 145, 149, 152,  70, 114,  86, 119,  71,  73,
       103, 150, 146,  56, 151,  67,  55, 134, 140, 147, 164, 163,  11,
       133,  74, 111,  32,  97, 148,  25, 112,  34,  98,  68,  36,  31,
        26, 109, 110,  43,  33, 121, 117,  49,  62, 115, 122, 127,  50,
        44, 128, 116,  35,  37,  69,  38,   6,  14])

## Use node embeddings

### GCN embedding with SVM

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

In [None]:
embeddings = get_embeddings(model_GCN, data, optimizer, loss_fn)

In [None]:
train_dataset_NE = embeddings[data.train_mask.cpu()]
test_dataset_NE = embeddings[data.test_mask.cpu()]

In [None]:
clf_SVM_NE = SVC(class_weight="balanced")
clf_SVM_NE.fit(train_dataset_NE, train_label_benchmark)
prediction_NE = clf_SVM_NE.predict(test_dataset_NE)

SVM_NE_acc = metrics.accuracy_score(test_label_benchmark, prediction_NE)
print(f'Test Accuracy: {100*SVM_NE_acc:.2f}%')
SVM_NE_pre = metrics.precision_score(test_label_benchmark, prediction_NE)
print(f'Test Precision: {100*SVM_NE_pre:.2f}%')
SVM_NE_recall = metrics.recall_score(test_label_benchmark, prediction_NE)
print(f'Test Recall: {100*SVM_NE_recall:.2f}%')
SVM_NE_f1 = metrics.f1_score(test_label_benchmark, prediction_NE)
print(f'Test F1: {100*SVM_NE_f1:.2f}%')

Test Accuracy: 71.70%
Test Precision: 15.47%
Test Recall: 75.16%
Test F1: 25.66%


### GCN embedding with RandomForest

In [None]:
clf_RF_NE = RandomForestClassifier(n_estimators=500, max_depth=15, random_state=0, class_weight="balanced")
clf_RF_NE.fit(train_dataset_NE, train_label_benchmark)
prediction_NE = clf_RF_NE.predict(test_dataset_NE)

RF_NE_acc = metrics.accuracy_score(test_label_benchmark, prediction_NE)
print(f'Test Accuracy: {100*RF_NE_acc:.2f}%')
RF_NE_pre = metrics.precision_score(test_label_benchmark, prediction_NE)
print(f'Test Precision: {100*RF_NE_pre:.2f}%')
RF_NE_recall = metrics.recall_score(test_label_benchmark, prediction_NE)
print(f'Test Recall: {100*RF_NE_recall:.2f}%')
RF_NE_f1 = metrics.f1_score(test_label_benchmark, prediction_NE)
print(f'Test F1: {100*RF_NE_f1:.2f}%')

Test Accuracy: 84.70%
Test Precision: 23.03%
Test Recall: 57.89%
Test F1: 32.96%


### SAGE embedding with SVM

In [None]:
embeddings = get_embeddings(model_SAGE, data, optimizer, loss_fn)

In [None]:
train_dataset_NE = embeddings[data.train_mask.cpu()]
test_dataset_NE = embeddings[data.test_mask.cpu()]

In [None]:
clf = SVC(class_weight="balanced")
clf.fit(train_dataset_NE, train_label_benchmark)
prediction_NE = clf.predict(test_dataset_NE)

SVM_NE_acc = metrics.accuracy_score(test_label_benchmark, prediction_NE)
print(f'Test Accuracy: {100*SVM_NE_acc:.2f}%')
SVM_NE_pre = metrics.precision_score(test_label_benchmark, prediction_NE)
print(f'Test Precision: {100*SVM_NE_pre:.2f}%')
SVM_NE_recall = metrics.recall_score(test_label_benchmark, prediction_NE)
print(f'Test Recall: {100*SVM_NE_recall:.2f}%')
SVM_NE_f1 = metrics.f1_score(test_label_benchmark, prediction_NE)
print(f'Test F1: {100*SVM_NE_f1:.2f}%')

Test Accuracy: 64.37%
Test Precision: 13.06%
Test Recall: 79.22%
Test F1: 22.42%


### SAGE embedding with RandomForest

In [None]:
clf = RandomForestClassifier(n_estimators=500, max_depth=15, random_state=0, class_weight="balanced")
clf.fit(train_dataset_NE, train_label_benchmark)
prediction_NE = clf.predict(test_dataset_NE)

RF_NE_acc = metrics.accuracy_score(test_label_benchmark, prediction_NE)
print(f'Test Accuracy: {100*RF_NE_acc:.2f}%')
RF_NE_pre = metrics.precision_score(test_label_benchmark, prediction_NE)
print(f'Test Precision: {100*RF_NE_pre:.2f}%')
RF_NE_recall = metrics.recall_score(test_label_benchmark, prediction_NE)
print(f'Test Recall: {100*RF_NE_recall:.2f}%')
RF_NE_f1 = metrics.f1_score(test_label_benchmark, prediction_NE)
print(f'Test F1: {100*RF_NE_f1:.2f}%')

Test Accuracy: 84.49%
Test Precision: 22.34%
Test Recall: 56.05%
Test F1: 31.95%


### GAT embedding with SVM

In [None]:
embeddings = get_embeddings(model_GAT, data, optimizer, loss_fn)

In [None]:
train_dataset_NE = embeddings[data.train_mask.cpu()]
test_dataset_NE = embeddings[data.test_mask.cpu()]

In [None]:
clf = SVC(class_weight="balanced")
clf.fit(train_dataset_NE, train_label_benchmark)
prediction_NE = clf.predict(test_dataset_NE)

SVM_NE_acc = metrics.accuracy_score(test_label_benchmark, prediction_NE)
print(f'Test Accuracy: {100*SVM_NE_acc:.2f}%')
SVM_NE_pre = metrics.precision_score(test_label_benchmark, prediction_NE)
print(f'Test Precision: {100*SVM_NE_pre:.2f}%')
SVM_NE_recall = metrics.recall_score(test_label_benchmark, prediction_NE)
print(f'Test Recall: {100*SVM_NE_recall:.2f}%')
SVM_NE_f1 = metrics.f1_score(test_label_benchmark, prediction_NE)
print(f'Test F1: {100*SVM_NE_f1:.2f}%')

Test Accuracy: 74.45%
Test Precision: 16.91%
Test Recall: 74.98%
Test F1: 27.60%


### GAT embedding with RandomForest

In [None]:
clf = RandomForestClassifier(n_estimators=500, max_depth=15, random_state=0, class_weight="balanced")
clf.fit(train_dataset_NE, train_label_benchmark)
prediction_NE = clf.predict(test_dataset_NE)

RF_NE_acc = metrics.accuracy_score(test_label_benchmark, prediction_NE)
print(f'Test Accuracy: {100*RF_NE_acc:.2f}%')
RF_NE_pre = metrics.precision_score(test_label_benchmark, prediction_NE)
print(f'Test Precision: {100*RF_NE_pre:.2f}%')
RF_NE_recall = metrics.recall_score(test_label_benchmark, prediction_NE)
print(f'Test Recall: {100*RF_NE_recall:.2f}%')
RF_NE_f1 = metrics.f1_score(test_label_benchmark, prediction_NE)
print(f'Test F1: {100*RF_NE_f1:.2f}%')

Test Accuracy: 87.38%
Test Precision: 26.57%
Test Recall: 53.46%
Test F1: 35.50%


## Use all features and node embeddings

### GAT embedding and all features with SVM

In [None]:
embeddings = get_embeddings(model_GAT, data, optimizer, loss_fn)
train_dataset_NE = embeddings[data.train_mask.cpu()]
test_dataset_NE = embeddings[data.test_mask.cpu()]

In [None]:
train_dataset_AFNE = np.concatenate((train_dataset_AF, train_dataset_NE), axis=1)
test_dataset_AFNE = np.concatenate((test_dataset_AF, test_dataset_NE), axis=1)

In [None]:
clf = SVC(class_weight="balanced")
clf.fit(train_dataset_AFNE, train_label_benchmark)
prediction_AFNE = clf.predict(test_dataset_AFNE)

SVM_AFNE_acc = metrics.accuracy_score(test_label_benchmark, prediction_AFNE)
print(f'Test Accuracy: {100*SVM_AFNE_acc:.2f}%')
SVM_AFNE_pre = metrics.precision_score(test_label_benchmark, prediction_AFNE)
print(f'Test Precision: {100*SVM_AFNE_pre:.2f}%')
SVM_AFNE_recall = metrics.recall_score(test_label_benchmark, prediction_AFNE)
print(f'Test Recall: {100*SVM_AFNE_recall:.2f}%')
SVM_AFNE_f1 = metrics.f1_score(test_label_benchmark, prediction_AFNE)
print(f'Test F1: {100*SVM_AFNE_f1:.2f}%')

Test Accuracy: 81.99%
Test Precision: 22.97%
Test Recall: 75.35%
Test F1: 35.21%


### GAT embedding and all features with RandomForest

In [None]:
clf = RandomForestClassifier(n_estimators=500, max_depth=15, random_state=0, class_weight="balanced")
clf.fit(train_dataset_AFNE, train_label_benchmark)
prediction_AFNE = clf.predict(test_dataset_AFNE)

RF_AFNE_acc = metrics.accuracy_score(test_label_benchmark, prediction_AFNE)
print(f'Test Accuracy: {100*RF_AFNE_acc:.2f}%')
RF_AFNE_pre = metrics.precision_score(test_label_benchmark, prediction_AFNE)
print(f'Test Precision: {100*RF_AFNE_pre:.2f}%')
RF_AFNE_recall = metrics.recall_score(test_label_benchmark, prediction_AFNE)
print(f'Test Recall: {100*RF_AFNE_recall:.2f}%')
RF_AFNE_f1 = metrics.f1_score(test_label_benchmark, prediction_AFNE)
print(f'Test F1: {100*RF_AFNE_f1:.2f}%')

Test Accuracy: 97.55%
Test Precision: 89.02%
Test Recall: 71.10%
Test F1: 79.06%


Sort features by important scores

In [None]:
RF_important_features = np.argsort(clf.feature_importances_)[::-1]
RF_important_features

array([ 40,  54,  52,  46,  48,  42, 131, 137,   4,  89,  13,   1,  17,
        28,   3,  58,  22,  51,  64,  59,  53,  75,  60,  66,  45, 155,
        39,   5,  24,  47,  95,  76, 141, 162,  30,   9,  65, 102,  78,
        99, 100,  77,  27, 138, 160,  83, 135,  41,  21,  88,  84,  82,
        79, 143, 136, 153,  19, 158,   7, 124,  20, 118,  29, 157, 126,
        80,  10,  16,  15,   0, 106,   8, 101, 140, 120,  63,  90, 316,
        23, 105,  57,  81, 139,  73, 159,  11,  92, 161,  18,   2,  91,
       123,  67, 107, 125,  87, 108, 129,  94,  55, 133, 229, 154, 104,
       149, 223,  86,  85, 145, 193, 277, 156, 103,  31,  96, 560,  32,
       577, 119, 301, 142,  68, 621, 244,  62, 130,  61, 132,  12,  56,
       364, 144,  93, 191, 146,  71,  98, 509, 186,  72,  74,  26, 152,
       559, 163, 134, 151,  44, 405, 675,  70, 150, 164, 112, 554, 114,
       111, 567, 293,  97, 634, 491, 187,  49, 147,  25, 251, 328, 329,
       315, 389, 540, 388, 391, 113, 526,  43, 317, 445, 628,  3

# Explain

## GNNExplainer

### GCN Explain

In [None]:
from torch_geometric.nn import Set2Set
from torch_geometric.explain import GNNExplainer, Explainer, CaptumExplainer
import torch_geometric.transforms as TabError
from tqdm import tqdm, trange
import matplotlib.pyplot as plt

In [None]:
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels, num_layers, dropout):
        super().__init__()
        torch.manual_seed(777)

        # Convolution layers
        if num_layers > 1:
            self.convs = nn.ModuleList([GCNConv(dataset.num_features, hidden_channels)])
            self.convs.extend([GCNConv(hidden_channels, hidden_channels) for i in range(num_layers - 2)])
            self.convs.append(GCNConv(hidden_channels, dataset.num_classes))

            # Batch normilization
            self.bns = nn.ModuleList([nn.BatchNorm1d(hidden_channels)
                                     for i in range(num_layers - 1)])
        else:
            self.convs = nn.ModuleList([GCNConv(dataset.num_features, dataset.num_classes)])
            self.bns = nn.ModuleList([])

        # Softmax layer
        self.softmax = nn.Softmax(1)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    # initialize parameters
    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, edge_index):
        for gcn, bn in zip(self.convs, self.bns):
            x = self.dropout(torch.relu(bn(gcn(x, edge_index))))
        x = self.convs[-1](x, edge_index)
        return self.softmax(x)

In [None]:
model=GCN(hidden_channels=128, num_layers=2, dropout=0.3).to(device)
data = data.to(device)

In [None]:
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='binary_classification',
        task_level='node',
        return_type='probs',
    ),
)

explanation_GCN = explainer(data.x, data.edge_index)
path = 'feature_importance_GCN.png'
explanation_GCN.visualize_feature_importance(path, top_k=20)
print(f"Feature importance plot has been saved to '{path}'")



Feature importance plot has been saved to 'feature_importance_GCN.png'


In [None]:
GCN_node_importantScore = explanation_GCN.node_mask.sum(0).cpu().numpy()
GCN_node_importantNorm = np.linalg.norm(GCN_node_importantScore)
GCN_node_importantScore = GCN_node_importantScore / GCN_node_importantNorm
GCN_important_features = np.argsort(GCN_node_importantScore)[::-1]
GCN_important_features

NameError: ignored

### GAT Explain

In [None]:
class GAT(torch.nn.Module):
    def __init__(self, hidden_channels, heads, num_layers, dropout):
        super().__init__()
        torch.manual_seed(777)
        self.num_layers = num_layers

        if num_layers > 1:
            # GAT layers
            self.convs = nn.ModuleList([GATConv(dataset.num_features, hidden_channels, heads)])
            self.convs.extend([GATConv(heads*hidden_channels, hidden_channels, heads) for i in range(num_layers - 2)])
            self.convs.append(GATConv(heads*hidden_channels, dataset.num_classes))

            # Batch Normilization
            self.bns = nn.ModuleList([nn.BatchNorm1d(heads*hidden_channels)
                                     for i in range(num_layers - 1)])
        else:
            self.convs = nn.ModuleList([GATConv(dataset.num_features, dataset.num_classes)])
            self.bns = nn.ModuleList([])

         # Softmax layer
        self.softmax = nn.Softmax(1)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    # initialize parameters
    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, edge_index):
        for gat, bn in zip(self.convs, self.bns):
            x = self.dropout(torch.relu(bn(gat(x, edge_index))))
        x = self.convs[-1](x, edge_index)

        return self.softmax(x)


In [None]:
model=GAT(hidden_channels=64, heads=8, num_layers=2, dropout=0.3).to(device)

In [None]:
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=100),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='binary_classification',
        task_level='node',
        return_type='probs',
    ),
)

explanation_GAT = explainer(data.x, data.edge_index)
print(f'Generated explanations in {explanation_GAT.available_explanations}')

path = 'feature_importance_GAT.png'
explanation_GAT.visualize_feature_importance(path, top_k=20)
print(f"Feature importance plot has been saved to '{path}'")

Generated explanations in ['node_mask', 'edge_mask']
Feature importance plot has been saved to 'feature_importance_GAT.png'


In [None]:
GAT_node_importantScore = explanation_GAT.node_mask.sum(0).cpu().numpy()
GAT_node_importantNorm = np.linalg.norm(GAT_node_importantScore)
GAT_node_importantScore = GAT_node_importantScore / GAT_node_importantNorm
GAT_important_features = np.argsort(GAT_node_importantScore)[::-1]
GAT_important_features

In [None]:
class GraphSAGENet(nn.Module):
    def __init__(self, hidden_channels, num_layers, dropout):
        super().__init__()
        torch.manual_seed(777)

        if num_layers > 1:
            # Convolution layers
            self.convs = nn.ModuleList([SAGEConv(dataset.num_features, hidden_channels)])
            self.convs.extend([SAGEConv(hidden_channels, hidden_channels) for i in range(num_layers - 2)])
            self.convs.append(SAGEConv(hidden_channels, dataset.num_classes))

             # Batch normilization
            self.bns = nn.ModuleList([nn.BatchNorm1d(hidden_channels)
                                     for i in range(num_layers - 1)])

        else:
            self.convs = nn.ModuleList([SAGEConv(dataset.num_features, dataset.num_classes)])
            self.bns = nn.ModuleList([])

        # Softmax layer
        self.softmax = nn.Softmax(1)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    # initialize parameters
    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, edge_index):
        for gcn, bn in zip(self.convs, self.bns):
            x = self.dropout(torch.relu(bn(gcn(x, edge_index))))
        x = self.convs[-1](x, edge_index)
        return self.softmax(x)

In [None]:
model = GraphSAGENet(hidden_channels=64, num_layers=2, dropout=0.3).to(device)

In [None]:
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='binary_classification',
        task_level='node',
        return_type='probs',
    ),
)

explanation_SAGE = explainer(data.x, data.edge_index)
path = 'feature_importance_SAGE.png'
explanation_SAGE.visualize_feature_importance(path, top_k=20)
print(f"Feature importance plot has been saved to '{path}'")

path = 'subgraph.pdf'
explanation_SAGE.visualize_graph(path)
print(f"Subgraph visualization plot has been saved to '{path}'")

Feature importance plot has been saved to 'feature_importance_SAGE.png'


KeyboardInterrupt: ignored

In [None]:
SAGE_node_importantScore = explanation_SAGE.node_mask.sum(0).cpu().numpy()
SAGE_node_importantNorm = np.linalg.norm(SAGE_node_importantScore)
SAGE_node_importantScore = SAGE_node_importantScore / SAGE_node_importantNorm
SAGE_important_features = np.argsort(SAGE_node_importantScore)[::-1]
SAGE_important_features

array([157,  67, 158, 144,  79,  90, 133, 143,  61,  88,  19, 150,  80,
        68, 151, 138, 139, 100, 135, 140, 134,  74, 152,  91, 146, 112,
        86,  92,  62,  84, 113, 145,  85, 114, 102, 147, 111,   2,  99,
       155, 136,  51, 148, 149,  20,  54,  55,  32,  39,  48,  98,  78,
        12,  97,  89, 121, 163, 164,  46, 156, 154,  77, 142,  93,  26,
       104, 106,  73,  40, 103,   0, 110,  83,  52,  94, 153,  66,  25,
        87, 115, 137,  11,  31, 122,  65,   8,  28,  56, 116, 109,  44,
        43, 127, 125, 160,  49,  45,  58,  64,  96, 105,  42, 118,  82,
        60, 128,  50, 119,  16, 123,  41, 130,  29, 107, 120,  17,  18,
        21,  47,   3, 132,  22,  76, 117, 131, 108,   5, 129,  53, 124,
        15,  38, 126,  10, 162,  24,  95,   7,   1,  81,  37, 101,  75,
       141,  36,  63,  59, 159,  30, 161,  27,  13,  71,  69,  33,  34,
        57,   6,  70,   4,  35,   9,  14,  72,  23])

In [None]:
important_score = GAT_node_importantScore + GCN_node_importantScore + SAGE_node_importantScore
important_features = np.argsort(important_score)[::-1]
important_features

array([ 61, 135,  99, 147, 100, 114,   2, 136, 157,  19, 133, 102, 111,
       139, 158,  67, 138, 144,  79, 146, 134,  86,  90, 148,  91, 151,
        62,  88,  85, 152,  68, 143,  92,  80, 149,  54,  20, 155, 140,
       145, 163,  51, 154,  74, 112,  65, 164,  89, 142,  52, 113,  39,
       150,  76, 156,  24,  53,  84, 153,  78,  83,  42,   0,  55, 137,
        32, 106,  97,  46, 104,  73,  18,  98,  16,  48, 103,  58,  26,
        82,  12,  87,  96,  11,  28,  25, 118,  77, 121, 110, 105,  31,
        40,  45,  10, 109,   8,  56,  94,  66, 115, 122, 130,  43, 161,
        93, 126, 116,  44,  59,  75, 128,  47, 132,   5,  50,  60, 117,
        17, 131,  64, 107, 127,   7,  30,  49,   3,  21, 125, 162, 160,
       141, 120,  22, 119, 124, 123,  29,  27,   1, 129,  23, 108,   6,
        41,   4, 159,  38,  37,  15,  13,  95, 101,  36,  81,  69,  63,
        34,  70,  33,  57,  35,  71,  14,  72,   9])