In [25]:
import torch_geometric

import numpy as np
import pandas as pd
import torch
from torch.nn import Linear, LayerNorm, ReLU, Dropout
import torch.nn.functional as F
from torch_geometric.nn import ChebConv, NNConv, DeepGCNLayer, GATConv, DenseGCNConv, GCNConv, GraphConv
from torch_geometric.data import Data, DataLoader
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
import scipy.sparse as sp

import warnings
warnings.filterwarnings("ignore")

# ref: https://www.kaggle.com/code/divyareddyyeruva/elliptic-gcn-pyg

#### Import dataset

In [26]:
# import data 
df_features = pd.read_csv('data/elliptic_txs_features.csv', header=None)
df_edges = pd.read_csv("data/elliptic_txs_edgelist.csv")
df_classes =  pd.read_csv("data/elliptic_txs_classes.csv")
# map unknown classes to -1
df_classes['class'] = df_classes['class'].apply(lambda x: 0 if x == "unknown" else int(x))

# merging dataframes
df_merge = df_features.merge(df_classes, how='left', right_on="txId", left_on=0)
display(df_merge.head())

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,159,160,161,162,163,164,165,166,txId,class
0,230425980,1,-0.171469,-0.184668,-1.201369,-0.12197,-0.043875,-0.113002,-0.061584,-0.162097,...,1.46133,1.461369,0.018279,-0.08749,-0.131155,-0.097524,-0.120613,-0.119792,230425980,0
1,5530458,1,-0.171484,-0.184668,-1.201369,-0.12197,-0.043875,-0.113002,-0.061584,-0.162112,...,-0.979074,-0.978556,0.018279,-0.08749,-0.131155,-0.097524,-0.120613,-0.119792,5530458,0
2,232022460,1,-0.172107,-0.184668,-1.201369,-0.12197,-0.043875,-0.113002,-0.061584,-0.162749,...,-0.979074,-0.978556,-0.098889,-0.106715,-0.131155,-0.183671,-0.120613,-0.119792,232022460,0
3,232438397,1,0.163054,1.96379,-0.646376,12.409294,-0.063725,9.782742,12.414558,-0.163645,...,0.241128,0.241406,1.072793,0.08553,-0.131155,0.677799,-0.120613,-0.119792,232438397,2
4,230460314,1,1.011523,-0.081127,-1.201369,1.153668,0.333276,1.312656,-0.061584,-0.163523,...,0.517257,0.579382,0.018279,0.277775,0.326394,1.29375,0.178136,0.179117,230460314,0


#### Split dataset masks

In [27]:
# take time step from 1 to 34 as train data
df_train = df_merge[df_merge[1] <= 34]
# take rest as test data
df_test = df_merge[df_merge[1] > 34]
display(df_train.head())
display(df_test.head())

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,159,160,161,162,163,164,165,166,txId,class
0,230425980,1,-0.171469,-0.184668,-1.201369,-0.12197,-0.043875,-0.113002,-0.061584,-0.162097,...,1.46133,1.461369,0.018279,-0.08749,-0.131155,-0.097524,-0.120613,-0.119792,230425980,0
1,5530458,1,-0.171484,-0.184668,-1.201369,-0.12197,-0.043875,-0.113002,-0.061584,-0.162112,...,-0.979074,-0.978556,0.018279,-0.08749,-0.131155,-0.097524,-0.120613,-0.119792,5530458,0
2,232022460,1,-0.172107,-0.184668,-1.201369,-0.12197,-0.043875,-0.113002,-0.061584,-0.162749,...,-0.979074,-0.978556,-0.098889,-0.106715,-0.131155,-0.183671,-0.120613,-0.119792,232022460,0
3,232438397,1,0.163054,1.96379,-0.646376,12.409294,-0.063725,9.782742,12.414558,-0.163645,...,0.241128,0.241406,1.072793,0.08553,-0.131155,0.677799,-0.120613,-0.119792,232438397,2
4,230460314,1,1.011523,-0.081127,-1.201369,1.153668,0.333276,1.312656,-0.061584,-0.163523,...,0.517257,0.579382,0.018279,0.277775,0.326394,1.29375,0.178136,0.179117,230460314,0


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,159,160,161,162,163,164,165,166,txId,class
136265,54785412,35,-0.159837,-0.030732,1.018602,-0.12197,-0.043875,-0.113002,-0.061584,-0.150191,...,-0.979074,-0.978556,0.018279,-0.08749,-0.131155,-0.097524,-0.120613,-0.119792,54785412,0
136266,69354384,35,-0.165893,-0.029572,1.018602,-0.12197,-0.043875,-0.113002,-0.061584,-0.156388,...,1.46133,1.461369,0.018279,-0.08749,-0.131155,-0.097524,-0.120613,-0.119792,69354384,0
136267,54775772,35,-0.129693,0.070098,1.573595,-0.12197,0.075226,-0.113002,-0.061584,-0.119348,...,-0.463356,-0.462939,0.018279,-0.08749,-0.131155,-0.097524,-0.120613,-0.119792,54775772,0
136268,69343934,35,-0.111789,1.29491,1.573595,0.553368,-0.043875,0.641758,-0.061584,-0.159732,...,1.46133,1.461369,0.018279,-0.08749,-0.131155,-0.097524,-0.120613,-0.119792,69343934,0
136269,70102750,35,-0.172796,-0.081127,-1.201369,-0.046932,-0.043875,-0.02914,-0.061584,-0.163571,...,-0.979074,-0.978556,0.018279,0.854508,2.146417,2.013077,-1.760926,-1.760984,70102750,0


In [76]:
num_features = 94

# split train and test features
train_features = df_train.iloc[:, 2:2+num_features].values
test_features = df_test.iloc[:, 2:2+num_features].values

# split train and test labels
train_labels = df_train.iloc[:, -1].values
test_labels = df_test.iloc[:, -1].values

print("Train features shape: ", train_features.shape)
print("Train labels shape: ", train_labels.shape)

print("Test features shape: ", test_features.shape)
print("Test labels shape: ", test_labels.shape)

Train features shape:  (136265, 94)
Train labels shape:  (136265,)
Test features shape:  (67504, 94)
Test labels shape:  (67504,)


In [77]:
# all nodes in data
nodes = df_merge[0].values
map_id = {j:i for i,j in enumerate(nodes)} # mapping nodes to indexes

edges = df_edges.copy()
edges.txId1 = edges.txId1.map(map_id)
edges.txId2 = edges.txId2.map(map_id)
edges = edges.astype(int)

edge_index = np.array(edges.values).T

edge_index = torch.tensor(edge_index, dtype=torch.long).contiguous()
edge_weight = torch.tensor([1]* edge_index.shape[1] , dtype=torch.double)
print(edge_index.shape)

torch.Size([2, 234355])


In [30]:
train_nodes = df_train[0].unique()
map_id = {j:i for i,j in enumerate(nodes)}

train_idx = [map_id[node_id] for node_id in train_nodes]
edges['txId1']


0              0
1              2
2              4
3              6
4              8
           ...  
234350    203602
234351    203603
234352    201921
234353    201480
234354    201954
Name: txId1, Length: 234355, dtype: int32

In [78]:
train_edge_mask = edges['txId1'].isin(train_idx) & edges['txId2'].isin(train_idx)
train_edge_index = edge_index[:, train_edge_mask]
train_edge_weight = edge_weight[train_edge_mask]
train_edge_index.shape

torch.Size([2, 156843])

In [79]:
# inverse mapping for test data
test_nodes = df_test[0].unique()
map_id = {j:i for i,j in enumerate(nodes)}
test_idx = [map_id[node_id] for node_id in test_nodes]

test_edge_mask = edges['txId1'].isin(test_idx) & edges['txId2'].isin(test_idx)
test_edge_index = edge_index[:, test_edge_mask]
test_edge_weight = edge_weight[test_edge_mask]
test_edge_index.shape

torch.Size([2, 77512])

In [80]:
# construct graph train data and test data
train_graph = Data(x=torch.tensor(train_features, dtype=torch.double), edge_index=train_edge_index, edge_weight=train_edge_weight, y=torch.tensor(train_labels, dtype=torch.long))
# train_graph = Data(x=torch.tensor(train_features, dtype=torch.float), edge_index=edge_index, edge_weight=edge_weight, y=torch.tensor(train_labels, dtype=torch.double))
test_graph = Data(x=torch.tensor(test_features, dtype=torch.double), edge_index=test_edge_index, edge_weight=test_edge_weight, y=torch.tensor(test_labels, dtype=torch.long))
print(train_graph)
print(test_graph)

Data(x=[136265, 94], edge_index=[2, 156843], y=[136265], edge_weight=[156843])
Data(x=[67504, 94], edge_index=[2, 77512], y=[67504], edge_weight=[77512])


#### GCN Model and Training

In [83]:
# 2-layer GCN
class GCN(torch.nn.Module):
    def __init__(self, num_features, num_classes, hidden_channels=128):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)
        
    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight
        # x: [136265, 94]   edge_index: [2, 234355]
        x = self.conv1(x, edge_index, edge_weight)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        print(edge_weight.dtype)
        x = self.conv2(x, edge_index, edge_weight)
        return F.log_softmax(x, dim=1)

In [54]:
device = torch.device('cpu')
device

device(type='cpu')

In [94]:
def train(model, data, optimizer):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.cross_entropy(out, data.y)
    loss.backward()
    optimizer.step()
    return loss

def test(model, data):
    model.eval()
    out = model(data)
    pred = out.argmax(dim=1)
    acc = accuracy_score(data.y, pred)
    return acc

In [58]:
model = GCN(num_features, 3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
num_epochs = 200
print(model)

for epoch in tqdm(range(1, num_epochs+1)):
    model = model.double()
    train_loss = train(model, train_graph.to(device), optimizer)
    # train_acc = test(model, train_graph)
    # test_acc = test(model, test_graph)
    if epoch % 10 == 0:
        print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, ')

GCN(
  (conv1): GCNConv(94, 128)
  (conv2): GCNConv(128, 3)
)


  0%|          | 0/200 [00:00<?, ?it/s]

torch.float64


  0%|          | 1/200 [00:01<03:48,  1.15s/it]

torch.float64


  1%|          | 2/200 [00:02<03:45,  1.14s/it]

torch.float64


  2%|▏         | 3/200 [00:03<03:45,  1.14s/it]

torch.float64


  2%|▏         | 4/200 [00:04<03:42,  1.13s/it]

torch.float64


  2%|▎         | 5/200 [00:05<03:46,  1.16s/it]

torch.float64


  3%|▎         | 6/200 [00:06<03:41,  1.14s/it]

torch.float64


  4%|▎         | 7/200 [00:08<03:45,  1.17s/it]

torch.float64


  4%|▍         | 8/200 [00:09<03:43,  1.16s/it]

torch.float64


  4%|▍         | 9/200 [00:10<03:42,  1.16s/it]

torch.float64


  5%|▌         | 10/200 [00:11<03:41,  1.17s/it]

Epoch: 010, Train Loss: 0.5466, 
torch.float64


  6%|▌         | 11/200 [00:12<03:44,  1.19s/it]

torch.float64


  6%|▌         | 12/200 [00:13<03:39,  1.17s/it]

torch.float64


  6%|▋         | 13/200 [00:15<03:39,  1.17s/it]

torch.float64


  7%|▋         | 14/200 [00:16<03:36,  1.16s/it]

torch.float64


  8%|▊         | 15/200 [00:17<03:35,  1.16s/it]

torch.float64


  8%|▊         | 16/200 [00:18<03:30,  1.14s/it]

torch.float64


  8%|▊         | 17/200 [00:19<03:36,  1.18s/it]

torch.float64


  9%|▉         | 18/200 [00:20<03:31,  1.16s/it]

torch.float64


 10%|▉         | 19/200 [00:22<03:33,  1.18s/it]

torch.float64


 10%|█         | 20/200 [00:23<03:34,  1.19s/it]

Epoch: 020, Train Loss: 0.5009, 
torch.float64


 10%|█         | 21/200 [00:24<03:31,  1.18s/it]

torch.float64


 11%|█         | 22/200 [00:25<03:30,  1.18s/it]

torch.float64


 12%|█▏        | 23/200 [00:26<03:25,  1.16s/it]

torch.float64


 12%|█▏        | 24/200 [00:27<03:26,  1.17s/it]

torch.float64


 12%|█▎        | 25/200 [00:29<03:22,  1.16s/it]

torch.float64


 13%|█▎        | 26/200 [00:30<03:20,  1.15s/it]

torch.float64


 14%|█▎        | 27/200 [00:31<03:18,  1.15s/it]

torch.float64


 14%|█▍        | 28/200 [00:32<03:14,  1.13s/it]

torch.float64


 14%|█▍        | 29/200 [00:33<03:20,  1.17s/it]

torch.float64


 15%|█▌        | 30/200 [00:34<03:15,  1.15s/it]

Epoch: 030, Train Loss: 0.4777, 
torch.float64


 16%|█▌        | 31/200 [00:36<03:18,  1.18s/it]

torch.float64


 16%|█▌        | 32/200 [00:37<03:16,  1.17s/it]

torch.float64


 16%|█▋        | 33/200 [00:38<03:15,  1.17s/it]

torch.float64


 17%|█▋        | 34/200 [00:39<03:09,  1.14s/it]

torch.float64


 18%|█▊        | 35/200 [00:40<03:12,  1.17s/it]

torch.float64


 18%|█▊        | 36/200 [00:41<03:07,  1.14s/it]

torch.float64


 18%|█▊        | 37/200 [00:43<03:10,  1.17s/it]

torch.float64


 19%|█▉        | 38/200 [00:44<03:08,  1.16s/it]

torch.float64


 20%|█▉        | 39/200 [00:45<03:07,  1.16s/it]

torch.float64


 20%|██        | 40/200 [00:46<03:05,  1.16s/it]

Epoch: 040, Train Loss: 0.4654, 
torch.float64


 20%|██        | 41/200 [00:47<03:07,  1.18s/it]

torch.float64


 21%|██        | 42/200 [00:48<03:03,  1.16s/it]

torch.float64


 22%|██▏       | 43/200 [00:50<03:05,  1.18s/it]

torch.float64


 22%|██▏       | 44/200 [00:51<03:02,  1.17s/it]

torch.float64


 22%|██▎       | 45/200 [00:52<03:00,  1.16s/it]

torch.float64


 23%|██▎       | 46/200 [00:53<02:55,  1.14s/it]

torch.float64


 24%|██▎       | 47/200 [00:54<02:59,  1.18s/it]

torch.float64


 24%|██▍       | 48/200 [00:55<02:55,  1.16s/it]

torch.float64


 24%|██▍       | 49/200 [00:57<02:56,  1.17s/it]

torch.float64


 25%|██▌       | 50/200 [00:58<02:54,  1.16s/it]

Epoch: 050, Train Loss: 0.4561, 
torch.float64


 26%|██▌       | 51/200 [00:59<02:53,  1.17s/it]

torch.float64


 26%|██▌       | 52/200 [01:00<02:52,  1.16s/it]

torch.float64


 26%|██▋       | 53/200 [01:01<02:54,  1.18s/it]

torch.float64


 27%|██▋       | 54/200 [01:02<02:50,  1.17s/it]

torch.float64


 28%|██▊       | 55/200 [01:04<02:54,  1.20s/it]

torch.float64


 28%|██▊       | 56/200 [01:05<02:51,  1.19s/it]

torch.float64


 28%|██▊       | 57/200 [01:06<02:48,  1.18s/it]

torch.float64


 29%|██▉       | 58/200 [01:07<02:47,  1.18s/it]

torch.float64


 30%|██▉       | 59/200 [01:08<02:44,  1.16s/it]

torch.float64


 30%|███       | 60/200 [01:09<02:43,  1.17s/it]

Epoch: 060, Train Loss: 0.4502, 
torch.float64


 30%|███       | 61/200 [01:11<02:41,  1.16s/it]

torch.float64


 31%|███       | 62/200 [01:12<02:43,  1.18s/it]

torch.float64


 32%|███▏      | 63/200 [01:13<02:39,  1.16s/it]

torch.float64


 32%|███▏      | 64/200 [01:14<02:40,  1.18s/it]

torch.float64


 32%|███▎      | 65/200 [01:15<02:36,  1.16s/it]

torch.float64


 33%|███▎      | 66/200 [01:16<02:37,  1.18s/it]

torch.float64


 34%|███▎      | 67/200 [01:18<02:37,  1.19s/it]

torch.float64


 34%|███▍      | 68/200 [01:19<02:37,  1.19s/it]

torch.float64


 34%|███▍      | 69/200 [01:20<02:33,  1.17s/it]

torch.float64


 35%|███▌      | 70/200 [01:21<02:34,  1.19s/it]

Epoch: 070, Train Loss: 0.4437, 
torch.float64


 36%|███▌      | 71/200 [01:22<02:30,  1.17s/it]

torch.float64


 36%|███▌      | 72/200 [01:24<02:30,  1.17s/it]

torch.float64


 36%|███▋      | 73/200 [01:25<02:28,  1.17s/it]

torch.float64


 37%|███▋      | 74/200 [01:26<02:28,  1.18s/it]

torch.float64


 38%|███▊      | 75/200 [01:27<02:25,  1.16s/it]

torch.float64


 38%|███▊      | 76/200 [01:28<02:28,  1.19s/it]

torch.float64


 38%|███▊      | 77/200 [01:29<02:24,  1.18s/it]

torch.float64


 39%|███▉      | 78/200 [01:31<02:23,  1.18s/it]

torch.float64


 40%|███▉      | 79/200 [01:32<02:29,  1.24s/it]

torch.float64


 40%|████      | 80/200 [01:33<02:25,  1.22s/it]

Epoch: 080, Train Loss: 0.4381, 
torch.float64


 40%|████      | 81/200 [01:34<02:24,  1.21s/it]

torch.float64


 41%|████      | 82/200 [01:36<02:20,  1.19s/it]

torch.float64


 42%|████▏     | 83/200 [01:37<02:18,  1.18s/it]

torch.float64


 42%|████▏     | 84/200 [01:38<02:17,  1.18s/it]

torch.float64


 42%|████▎     | 85/200 [01:39<02:14,  1.17s/it]

torch.float64


 43%|████▎     | 86/200 [01:40<02:15,  1.19s/it]

torch.float64


 44%|████▎     | 87/200 [01:41<02:12,  1.17s/it]

torch.float64


 44%|████▍     | 88/200 [01:43<02:12,  1.18s/it]

torch.float64


 44%|████▍     | 89/200 [01:44<02:12,  1.20s/it]

torch.float64


 45%|████▌     | 90/200 [01:45<02:10,  1.19s/it]

Epoch: 090, Train Loss: 0.4346, 
torch.float64


 46%|████▌     | 91/200 [01:46<02:09,  1.19s/it]

torch.float64


 46%|████▌     | 92/200 [01:47<02:07,  1.18s/it]

torch.float64


 46%|████▋     | 93/200 [01:49<02:09,  1.21s/it]

torch.float64


 47%|████▋     | 94/200 [01:50<02:09,  1.22s/it]

torch.float64


 48%|████▊     | 95/200 [01:51<02:05,  1.19s/it]

torch.float64


 48%|████▊     | 96/200 [01:52<02:06,  1.21s/it]

torch.float64


 48%|████▊     | 97/200 [01:53<02:02,  1.19s/it]

torch.float64


 49%|████▉     | 98/200 [01:55<02:01,  1.19s/it]

torch.float64


 50%|████▉     | 99/200 [01:56<02:00,  1.20s/it]

torch.float64


 50%|█████     | 100/200 [01:57<01:58,  1.18s/it]

Epoch: 100, Train Loss: 0.4303, 
torch.float64


 50%|█████     | 101/200 [01:58<01:59,  1.21s/it]

torch.float64


 51%|█████     | 102/200 [01:59<02:00,  1.23s/it]

torch.float64


 52%|█████▏    | 103/200 [02:01<02:00,  1.24s/it]

torch.float64


 52%|█████▏    | 104/200 [02:02<02:01,  1.27s/it]

torch.float64


 52%|█████▎    | 105/200 [02:03<02:00,  1.27s/it]

torch.float64


 53%|█████▎    | 106/200 [02:05<02:01,  1.29s/it]

torch.float64


 54%|█████▎    | 107/200 [02:06<01:59,  1.29s/it]

torch.float64


 54%|█████▍    | 108/200 [02:07<01:54,  1.24s/it]

torch.float64


 55%|█████▍    | 109/200 [02:08<01:53,  1.25s/it]

torch.float64


 55%|█████▌    | 110/200 [02:10<01:50,  1.23s/it]

Epoch: 110, Train Loss: 0.4263, 
torch.float64


 56%|█████▌    | 111/200 [02:11<01:49,  1.23s/it]

torch.float64


 56%|█████▌    | 112/200 [02:12<01:48,  1.24s/it]

torch.float64


 56%|█████▋    | 113/200 [02:13<01:49,  1.26s/it]

torch.float64


 57%|█████▋    | 114/200 [02:15<01:48,  1.26s/it]

torch.float64


 57%|█████▊    | 115/200 [02:16<01:48,  1.28s/it]

torch.float64


 58%|█████▊    | 116/200 [02:17<01:42,  1.22s/it]

torch.float64


 58%|█████▊    | 117/200 [02:18<01:41,  1.22s/it]

torch.float64


 59%|█████▉    | 118/200 [02:19<01:37,  1.18s/it]

torch.float64


 60%|█████▉    | 119/200 [02:21<01:36,  1.19s/it]

torch.float64


 60%|██████    | 120/200 [02:22<01:34,  1.18s/it]

Epoch: 120, Train Loss: 0.4231, 
torch.float64


 60%|██████    | 121/200 [02:23<01:32,  1.17s/it]

torch.float64


 61%|██████    | 122/200 [02:24<01:29,  1.14s/it]

torch.float64


 62%|██████▏   | 123/200 [02:25<01:30,  1.18s/it]

torch.float64


 62%|██████▏   | 124/200 [02:26<01:28,  1.17s/it]

torch.float64


 62%|██████▎   | 125/200 [02:28<01:28,  1.18s/it]

torch.float64


 63%|██████▎   | 126/200 [02:29<01:26,  1.17s/it]

torch.float64


 64%|██████▎   | 127/200 [02:30<01:24,  1.16s/it]

torch.float64


 64%|██████▍   | 128/200 [02:31<01:22,  1.15s/it]

torch.float64


 64%|██████▍   | 129/200 [02:32<01:22,  1.16s/it]

torch.float64


 65%|██████▌   | 130/200 [02:33<01:19,  1.14s/it]

Epoch: 130, Train Loss: 0.4203, 
torch.float64


 66%|██████▌   | 131/200 [02:34<01:19,  1.15s/it]

torch.float64


 66%|██████▌   | 132/200 [02:36<01:17,  1.15s/it]

torch.float64


 66%|██████▋   | 133/200 [02:37<01:17,  1.15s/it]

torch.float64


 67%|██████▋   | 134/200 [02:38<01:16,  1.16s/it]

torch.float64


 68%|██████▊   | 135/200 [02:39<01:15,  1.16s/it]

torch.float64


 68%|██████▊   | 136/200 [02:40<01:15,  1.18s/it]

torch.float64


 68%|██████▊   | 137/200 [02:41<01:12,  1.15s/it]

torch.float64


 69%|██████▉   | 138/200 [02:43<01:12,  1.16s/it]

torch.float64


 70%|██████▉   | 139/200 [02:44<01:10,  1.16s/it]

torch.float64


 70%|███████   | 140/200 [02:45<01:09,  1.16s/it]

Epoch: 140, Train Loss: 0.4169, 
torch.float64


 70%|███████   | 141/200 [02:46<01:08,  1.16s/it]

torch.float64


 71%|███████   | 142/200 [02:47<01:07,  1.17s/it]

torch.float64


 72%|███████▏  | 143/200 [02:48<01:06,  1.17s/it]

torch.float64


 72%|███████▏  | 144/200 [02:50<01:07,  1.21s/it]

torch.float64


 72%|███████▎  | 145/200 [02:51<01:06,  1.21s/it]

torch.float64


 73%|███████▎  | 146/200 [02:52<01:04,  1.19s/it]

torch.float64


 74%|███████▎  | 147/200 [02:53<01:03,  1.20s/it]

torch.float64


 74%|███████▍  | 148/200 [02:54<01:01,  1.18s/it]

torch.float64


 74%|███████▍  | 149/200 [02:56<01:00,  1.19s/it]

torch.float64


 75%|███████▌  | 150/200 [02:57<00:58,  1.18s/it]

Epoch: 150, Train Loss: 0.4149, 
torch.float64


 76%|███████▌  | 151/200 [02:58<00:57,  1.17s/it]

torch.float64


 76%|███████▌  | 152/200 [02:59<00:56,  1.17s/it]

torch.float64


 76%|███████▋  | 153/200 [03:00<00:55,  1.18s/it]

torch.float64


 77%|███████▋  | 154/200 [03:01<00:53,  1.16s/it]

torch.float64


 78%|███████▊  | 155/200 [03:03<00:52,  1.17s/it]

torch.float64


 78%|███████▊  | 156/200 [03:04<00:51,  1.17s/it]

torch.float64


 78%|███████▊  | 157/200 [03:05<00:50,  1.18s/it]

torch.float64


 79%|███████▉  | 158/200 [03:06<00:50,  1.19s/it]

torch.float64


 80%|███████▉  | 159/200 [03:07<00:48,  1.19s/it]

torch.float64


 80%|████████  | 160/200 [03:09<00:47,  1.19s/it]

Epoch: 160, Train Loss: 0.4122, 
torch.float64


 80%|████████  | 161/200 [03:10<00:46,  1.20s/it]

torch.float64


 81%|████████  | 162/200 [03:11<00:46,  1.22s/it]

torch.float64


 82%|████████▏ | 163/200 [03:12<00:44,  1.20s/it]

torch.float64


 82%|████████▏ | 164/200 [03:13<00:43,  1.20s/it]

torch.float64


 82%|████████▎ | 165/200 [03:15<00:41,  1.18s/it]

torch.float64


 83%|████████▎ | 166/200 [03:16<00:39,  1.17s/it]

torch.float64


 84%|████████▎ | 167/200 [03:17<00:39,  1.21s/it]

torch.float64


 84%|████████▍ | 168/200 [03:18<00:37,  1.18s/it]

torch.float64


 84%|████████▍ | 169/200 [03:19<00:37,  1.22s/it]

torch.float64


 85%|████████▌ | 170/200 [03:21<00:36,  1.21s/it]

Epoch: 170, Train Loss: 0.4099, 
torch.float64


 86%|████████▌ | 171/200 [03:22<00:35,  1.21s/it]

torch.float64


 86%|████████▌ | 172/200 [03:23<00:33,  1.20s/it]

torch.float64


 86%|████████▋ | 173/200 [03:24<00:31,  1.17s/it]

torch.float64


 87%|████████▋ | 174/200 [03:25<00:30,  1.19s/it]

torch.float64


 88%|████████▊ | 175/200 [03:26<00:29,  1.16s/it]

torch.float64


 88%|████████▊ | 176/200 [03:28<00:28,  1.19s/it]

torch.float64


 88%|████████▊ | 177/200 [03:29<00:27,  1.20s/it]

torch.float64


 89%|████████▉ | 178/200 [03:30<00:26,  1.19s/it]

torch.float64


 90%|████████▉ | 179/200 [03:31<00:25,  1.22s/it]

torch.float64


 90%|█████████ | 180/200 [03:32<00:23,  1.19s/it]

Epoch: 180, Train Loss: 0.4089, 
torch.float64


 90%|█████████ | 181/200 [03:34<00:22,  1.20s/it]

torch.float64


 91%|█████████ | 182/200 [03:35<00:21,  1.21s/it]

torch.float64


 92%|█████████▏| 183/200 [03:36<00:20,  1.19s/it]

torch.float64


 92%|█████████▏| 184/200 [03:37<00:19,  1.20s/it]

torch.float64


 92%|█████████▎| 185/200 [03:38<00:17,  1.19s/it]

torch.float64


 93%|█████████▎| 186/200 [03:40<00:16,  1.19s/it]

torch.float64


 94%|█████████▎| 187/200 [03:41<00:15,  1.18s/it]

torch.float64


 94%|█████████▍| 188/200 [03:42<00:14,  1.18s/it]

torch.float64


 94%|█████████▍| 189/200 [03:43<00:12,  1.16s/it]

torch.float64


 95%|█████████▌| 190/200 [03:44<00:11,  1.18s/it]

Epoch: 190, Train Loss: 0.4064, 
torch.float64


 96%|█████████▌| 191/200 [03:45<00:10,  1.16s/it]

torch.float64


 96%|█████████▌| 192/200 [03:47<00:09,  1.17s/it]

torch.float64


 96%|█████████▋| 193/200 [03:48<00:08,  1.17s/it]

torch.float64


 97%|█████████▋| 194/200 [03:49<00:07,  1.17s/it]

torch.float64


 98%|█████████▊| 195/200 [03:50<00:05,  1.17s/it]

torch.float64


 98%|█████████▊| 196/200 [03:51<00:04,  1.19s/it]

torch.float64


 98%|█████████▊| 197/200 [03:53<00:03,  1.19s/it]

torch.float64


 99%|█████████▉| 198/200 [03:54<00:02,  1.17s/it]

torch.float64


100%|█████████▉| 199/200 [03:55<00:01,  1.20s/it]

torch.float64


100%|██████████| 200/200 [03:56<00:00,  1.18s/it]

Epoch: 200, Train Loss: 0.4055, 





In [96]:
# test accuracy
test_acc = test(model, test_graph.to(device))
print(f"Test Accuracy: {test_acc:.4f}")

RuntimeError: index 136266 is out of bounds for dimension 0 with size 67504