# Transaction Flow Graph

##  Install required packages.

In [None]:
import os
import torch
# Export the installed torch version as an environment variable so the shell
# commands below can substitute ${TORCH} into the wheel-index URL.
os.environ['TORCH'] = torch.__version__
# Sets PYTHONWARNINGS to "ignore" for this process and its children
# (presumably to keep warning noise out of the install output).
os.environ['PYTHONWARNINGS'] = "ignore"
print(torch.__version__)
# Install torch-scatter and torch-sparse from the PyG wheel index that matches
# the torch version printed above, then PyTorch Geometric from its Git repository.
!pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install git+https://github.com/pyg-team/pytorch_geometric.git

2.2.2
Looking in links: https://data.pyg.org/whl/torch-2.2.2.html
Collecting torch-scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
  Preparing metadata (setup.py) ... error
  error: subprocess-exited-with-error
  
  × python setup.py egg_info did not run successfully.
  │ exit code: 1
  ╰─> [52 lines of output]
      Compiling without OpenMP...
      Compiling without OpenMP...
      Compiling without OpenMP...
      Compiling without OpenMP...
      running egg_info
      creating /private/var/folders/c8/rmlp9g456b5_c_vt71ncc0nm0000gn/T/pip-pip-egg-info-9dgdonsa/torch_scatter.egg-info
      writing /private/var/folders/c8/rmlp9g456b5_c_vt71ncc0nm0000gn/T/pip-pip-egg-info-9dgdonsa/torch_scatter.egg-info/PKG-INFO
      writing dependency_links to /private/var/folders/c8/rmlp9g456b5_c_vt71ncc0nm0000gn/T/pip-pip-egg-info-9dgdons

##  Set up

In [None]:
import os
from pathlib import Path
import pandas as pd
import pickle
import json
import numpy as np

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import from_networkx
from torch.nn.utils.rnn import pad_sequence

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import classification_report, precision_score, recall_score, f1_score, accuracy_score
from utils.threshold import tune_thresholds
from utils.comparing import evaluate_multilabel_classification

PATH = './data/labeled'

In [None]:
# Load features (ensure index is lowercase)
def load_feature(file):
    """Read a feature CSV keyed by its first column, with lowercased keys.

    Args:
        file: Path or buffer accepted by ``pandas.read_csv``.

    Returns:
        A DataFrame whose index is the first CSV column, lowercased.
    """
    features = pd.read_csv(file, index_col=0)
    lowered_index = features.index.str.lower()
    features.index = lowered_index
    return features

In [None]:
# Ground-truth labels: one row per address, indexed by the lowercased address.
ground = pd.read_csv(os.path.join(PATH, "groundtruth.csv"))
ground = ground.set_index('Address')
ground.index = ground.index.str.lower()
# Every remaining column of the ground-truth table is treated as a label.
label_cols = list(ground.columns)

# Per-address feature table, also keyed by lowercased address.
graph = load_feature(os.path.join(PATH, "txn_graph_features.csv"))

# Pickled mapping iterated later via ``data.items()``.
# NOTE(review): pickle executes code on load; only use with trusted files.
with open(os.path.join(PATH, "txn.pkl"), "rb") as pkl_file:
    data = pickle.load(pkl_file)

In [None]:
class GCN(nn.Module):
    """Two-layer graph convolutional network with a graph-level linear head.

    Node features go through two ``GCNConv`` layers (ReLU after each), are
    mean-pooled per graph, and the pooled vector is mapped to ``out_channels``
    values by a linear layer. No activation is applied to the output.
    """

    def __init__(self, in_channels, hidden, out_channels):
        """Create the two convolution layers and the linear output layer."""
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.lin = nn.Linear(hidden, out_channels)

    def forward(self, data):
        """Return one output row per graph in ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.
        """
        edge_index = data.edge_index
        node_emb = data.x
        for conv in (self.conv1, self.conv2):
            node_emb = F.relu(conv(node_emb, edge_index))
        graph_emb = global_mean_pool(node_emb, data.batch)
        return self.lin(graph_emb)

In [None]:
# Build one PyG ``Data`` object per address that has a ground-truth row.
dataset = []

for i, (address, graph_data) in enumerate(data.items()):
    # Skip addresses that have no ground-truth labels.
    # NOTE(review): ``ground`` is indexed by lowercased addresses; confirm the
    # keys of ``data`` are lowercased as well, otherwise they are skipped here.
    if address not in ground.index:
        continue

    feature = graph.loc[address]  # Use txn_graph features for txn_dataset

    # Use a separate name so the ``data`` dict being iterated is not rebound
    # (rebinding it made the cell fail when re-run).
    sample = from_networkx(graph_data)

    # Repeat the same address-level feature vector for every node.
    sample.x = torch.tensor(feature.values, dtype=torch.float32).repeat(sample.num_nodes, 1)

    # ``ground`` is already indexed by address, so look the row up directly.
    # ``unsqueeze(0)`` adds a leading dimension of size 1 to the label vector.
    sample.y = torch.tensor(
        ground.loc[address, label_cols].values, dtype=torch.float32
    ).unsqueeze(0)

    # Show the first few entries (by position in ``data``) as a sanity check.
    if i < 10:
        print(sample)
    dataset.append(sample)

In [None]:
from torch_geometric.loader import DataLoader

# Hold out 20% of the graphs for testing; the seed is fixed at 42.
train_data, test_data = train_test_split(
    dataset,
    test_size=0.2,
    random_state=42,
)

# Batches of 8 graphs; only the training loader shuffles.
test_loader = DataLoader(test_data, batch_size=8, shuffle=False)
train_loader = DataLoader(train_data, batch_size=8, shuffle=True)

In [None]:
# Size the network from the loaded tables instead of a literal and a leaked
# loop variable: one input channel per feature column of ``graph`` and one
# output logit per label column, so the output width matches the targets.
model = GCN(in_channels=graph.shape[1], hidden=64, out_channels=len(label_cols))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()  # For multi-label classification

In [None]:
class EarlyStopping:
    """Track a higher-is-better score and flag when it stops improving.

    The first score becomes the baseline. A later score counts as an
    improvement unless it is below ``best_score + min_delta``. After
    ``patience`` consecutive non-improving calls ``early_stop`` is set to
    True, and it is never reset afterwards.
    """

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_score):
        """Record ``val_score`` and update the counter and stop flag."""
        if self.best_score is None:
            # First observation: just remember it.
            self.best_score = val_score
            return

        stalled = val_score < self.best_score + self.min_delta
        if not stalled:
            self.best_score = val_score
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True


In [None]:
# Stop once the tracked score has failed to improve for 3 epochs in a row.
early_stopper = EarlyStopping(patience=3)

# At most 50 epochs (1..50).
for epoch in range(1, 51):
    model.train()
    total_loss = 0
    epoch_preds = []
    epoch_labels = []

    for batch in train_loader:
        optimizer.zero_grad()
        logits = model(batch)
        batch_loss = loss_fn(logits, batch.y.float())
        batch_loss.backward()
        optimizer.step()
        total_loss += batch_loss.item()

        # Threshold the sigmoid outputs at 0.5 to get hard 0/1 predictions.
        probs = torch.sigmoid(logits)
        epoch_preds.append((probs > 0.5).int().cpu().numpy())
        epoch_labels.append(batch.y.int().cpu().numpy())

    all_preds = np.vstack(epoch_preds)
    all_labels = np.vstack(epoch_labels)
    # Element-wise accuracy over every (sample, label) entry of the training batches.
    acc = (all_preds == all_labels).mean()

    print(f"Epoch {epoch}, Loss: {total_loss:.4f}, Accuracy: {acc:.4f}")

    # The early-stopping score is this epoch's training accuracy.
    early_stopper(acc)
    if early_stopper.early_stop:
        print("🛑 Early stopping triggered.")
        break


In [None]:
# Collect sigmoid probabilities and true labels over the test set.
model.eval()
y_true, y_probs = [], []
with torch.no_grad():
    for batch in test_loader:
        # GCN.forward takes the whole batch object (same call as in training);
        # passing x / edge_index / batch separately raises a TypeError.
        out = torch.sigmoid(model(batch))
        # Convert to NumPy before stacking so np.vstack receives plain arrays.
        y_probs.append(out.cpu().numpy())
        y_true.append(batch.y.cpu().numpy())

y_true = np.vstack(y_true)
y_probs = np.vstack(y_probs)

# Project helper; its second return value is not used here.
best_thresholds, _ = tune_thresholds(y_true, y_probs)

In [None]:
# Report test-set metrics using the thresholds returned by tune_thresholds.
evaluate_multilabel_classification(
    y_true,
    y_probs,
    label_names=label_cols,
    threshold=best_thresholds,
)