## 1. Import and Setup

In [1]:
import sys, os
sys.path.append(os.path.abspath("../src"))

# Torch & GNN
import torch
from torch_geometric.nn import SAGEConv

# Custom modules
from data_utils_temporal import load_and_preprocess_elliptic_temporal
from train_temporal import temporal_split
from model_sage import GraphSAGENet
from train_utils import set_seed, train_full, save_feature_experiment
from config import SEEDS

# Feature engineering control
from features.feature_config import FEATURE_CONFIGS
from features.feature_runner import apply_engineered_features

print("✓ All modules imported successfully.")


✓ All modules imported successfully.


## 2. Load and Preprocess the Elliptic Dataset
This loads the data, maps transaction IDs to node indices, filters isolated nodes, and normalizes node features. Uses temporal splitting.


In [2]:
from normalization import normalize_base_features_only
from features.feature_pipeline import generate_all_features
from features.feature_utils import prepare_graph_and_timestamps
from data_utils_temporal import load_and_preprocess_elliptic_temporal
from train_temporal import temporal_split
from features.feature_config import FEATURE_CONFIGS
from train_utils import set_seed
from config import SEEDS
import torch
import os

# 1. Load PyG temporal data (with node_times)
data, node_times = load_and_preprocess_elliptic_temporal("../elliptic_bitcoin_dataset", normalize=False)
print(f"✓ Temporal PyG Data object loaded: {data}")

# 2. Prepare NetworkX graph + timestamps
G_nx, node_timestamps = prepare_graph_and_timestamps(
    edgelist_path="../elliptic_bitcoin_dataset/elliptic_txs_edgelist.csv",
    features_path="../elliptic_bitcoin_dataset/elliptic_txs_features.csv",
    num_nodes=data.num_nodes
)

print(f"✓ NetworkX Graph: {G_nx.number_of_nodes()} nodes, {G_nx.number_of_edges()} edges")
print(f"✓ Timestamps found for {sum(len(ts) > 0 for ts in node_timestamps)} nodes")

# 3. Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ABLATORS = {
    k: v for k, v in FEATURE_CONFIGS.items()
}
FEATURED_DATA = {}

# 4. Loop through ablations + seeds
for ablation_name, feature_flags in ABLATORS.items():
    print(f"\n== Config: {ablation_name} ==")
    for seed in SEEDS:
        print(f"→ Seed {seed}")
        set_seed(seed)

        # 4.1 Temporal split (train up to time step 35)
        train_idx, val_idx = temporal_split(data, node_times=node_times, train_until=35)

        # 4.2 Feature engineering
        data_exp = data.clone()
        X, feature_names = generate_all_features(data_exp, G_nx, node_timestamps, **feature_flags)
        data_exp.x = torch.tensor(X, dtype=torch.float)

        # Optional: monitor temporal feature stats
        temporal_names = {"temporal_lag", "component_size"}
        temporal_indices = [i for i, name in enumerate(feature_names) if name in temporal_names]
        if temporal_indices:
            print("→ Temporal features (raw)")
            print("  Means:", data_exp.x[:, temporal_indices].mean(0))
            print("  Stds :", data_exp.x[:, temporal_indices].std(0))

        # 4.3 Normalize base features only
        data_exp.x = normalize_base_features_only(data_exp.x, feature_names)

        # 4.4 Finalize
        data_exp = data_exp.to(device)
        FEATURED_DATA[(ablation_name, seed)] = {
            "data": data_exp,
            "train_idx": train_idx,
            "val_idx": val_idx,
            "feature_flags": feature_flags,
        }

        # 4.5 Save to disk
        save_dir = f"../data/SAGE_temporal/{ablation_name}/seed_{seed}"
        os.makedirs(save_dir, exist_ok=True)
        torch.save(data_exp, os.path.join(save_dir, "data_exp.pt"))
        torch.save(train_idx, os.path.join(save_dir, "train_idx.pt"))
        torch.save(val_idx, os.path.join(save_dir, "val_idx.pt"))
        print(f"✓ Saved to {save_dir}")
        print(f"✓ Final shape: {data_exp.x.shape}")


✓ Temporal PyG Data object loaded: Data(x=[203769, 166], edge_index=[2, 234355], y=[203769])
Loaded directed graph with 203769 nodes and 234355 edges.
Graph type: <class 'networkx.classes.digraph.DiGraph'>
Features DataFrame shape: (203769, 167)
Number of features (excluding txId and time_step): 165
Timestamps found for 203769 out of 203769 nodes.
✓ Time step range: 1.0 to 49.0
✓ NetworkX Graph: 203769 nodes, 234355 edges
✓ Timestamps found for 203769 nodes

== Config: base ==
→ Seed 42
✓ Saved to ../data/SAGE_temporal/base/seed_42
✓ Final shape: torch.Size([203769, 166])
→ Seed 123
✓ Saved to ../data/SAGE_temporal/base/seed_123
✓ Final shape: torch.Size([203769, 166])
→ Seed 777
✓ Saved to ../data/SAGE_temporal/base/seed_777
✓ Final shape: torch.Size([203769, 166])
→ Seed 2023
✓ Saved to ../data/SAGE_temporal/base/seed_2023
✓ Final shape: torch.Size([203769, 166])
→ Seed 31415
✓ Saved to ../data/SAGE_temporal/base/seed_31415
✓ Final shape: torch.Size([203769, 166])

== Config: base+st

Computing temporal features: 100%|██████████| 203769/203769 [00:00<00:00, 407350.75it/s]


→ Temporal features (raw)
  Means: tensor([25.1560])
  Stds : tensor([15.1722])
✓ Saved to ../data/SAGE_temporal/base+basic_temporal/seed_42
✓ Final shape: torch.Size([203769, 168])
→ Seed 123


Computing temporal features: 100%|██████████| 203769/203769 [00:00<00:00, 271939.50it/s]


→ Temporal features (raw)
  Means: tensor([25.1560])
  Stds : tensor([15.1722])
✓ Saved to ../data/SAGE_temporal/base+basic_temporal/seed_123
✓ Final shape: torch.Size([203769, 168])
→ Seed 777


Computing temporal features: 100%|██████████| 203769/203769 [00:00<00:00, 232821.70it/s]


→ Temporal features (raw)
  Means: tensor([25.1560])
  Stds : tensor([15.1722])
✓ Saved to ../data/SAGE_temporal/base+basic_temporal/seed_777
✓ Final shape: torch.Size([203769, 168])
→ Seed 2023


Computing temporal features: 100%|██████████| 203769/203769 [00:00<00:00, 357435.39it/s]


→ Temporal features (raw)
  Means: tensor([25.1560])
  Stds : tensor([15.1722])
✓ Saved to ../data/SAGE_temporal/base+basic_temporal/seed_2023
✓ Final shape: torch.Size([203769, 168])
→ Seed 31415


Computing temporal features: 100%|██████████| 203769/203769 [00:00<00:00, 316850.29it/s]


→ Temporal features (raw)
  Means: tensor([25.1560])
  Stds : tensor([15.1722])
✓ Saved to ../data/SAGE_temporal/base+basic_temporal/seed_31415
✓ Final shape: torch.Size([203769, 168])

== Config: base+basic_temporal+typology ==
→ Seed 42


Computing temporal features: 100%|██████████| 203769/203769 [00:00<00:00, 556535.36it/s]


→ Temporal features (raw)
  Means: tensor([25.1560])
  Stds : tensor([15.1722])
✓ Saved to ../data/SAGE_temporal/base+basic_temporal+typology/seed_42
✓ Final shape: torch.Size([203769, 170])
→ Seed 123


Computing temporal features: 100%|██████████| 203769/203769 [00:00<00:00, 361131.87it/s]


→ Temporal features (raw)
  Means: tensor([25.1560])
  Stds : tensor([15.1722])
✓ Saved to ../data/SAGE_temporal/base+basic_temporal+typology/seed_123
✓ Final shape: torch.Size([203769, 170])
→ Seed 777


Computing temporal features: 100%|██████████| 203769/203769 [00:00<00:00, 288168.30it/s]


→ Temporal features (raw)
  Means: tensor([25.1560])
  Stds : tensor([15.1722])
✓ Saved to ../data/SAGE_temporal/base+basic_temporal+typology/seed_777
✓ Final shape: torch.Size([203769, 170])
→ Seed 2023


Computing temporal features: 100%|██████████| 203769/203769 [00:00<00:00, 341305.73it/s]


→ Temporal features (raw)
  Means: tensor([25.1560])
  Stds : tensor([15.1722])
✓ Saved to ../data/SAGE_temporal/base+basic_temporal+typology/seed_2023
✓ Final shape: torch.Size([203769, 170])
→ Seed 31415


Computing temporal features: 100%|██████████| 203769/203769 [00:00<00:00, 210136.31it/s]


→ Temporal features (raw)
  Means: tensor([25.1560])
  Stds : tensor([15.1722])
✓ Saved to ../data/SAGE_temporal/base+basic_temporal+typology/seed_31415
✓ Final shape: torch.Size([203769, 170])

== Config: all ==
→ Seed 42


Computing temporal features: 100%|██████████| 203769/203769 [00:00<00:00, 486592.81it/s]


→ Temporal features (raw)
  Means: tensor([25.1560])
  Stds : tensor([15.1722])
✓ Saved to ../data/SAGE_temporal/all/seed_42
✓ Final shape: torch.Size([203769, 174])
→ Seed 123


Computing temporal features: 100%|██████████| 203769/203769 [00:00<00:00, 570928.89it/s]


→ Temporal features (raw)
  Means: tensor([25.1560])
  Stds : tensor([15.1722])
✓ Saved to ../data/SAGE_temporal/all/seed_123
✓ Final shape: torch.Size([203769, 174])
→ Seed 777


Computing temporal features: 100%|██████████| 203769/203769 [00:00<00:00, 551132.41it/s]


→ Temporal features (raw)
  Means: tensor([25.1560])
  Stds : tensor([15.1722])
✓ Saved to ../data/SAGE_temporal/all/seed_777
✓ Final shape: torch.Size([203769, 174])
→ Seed 2023


Computing temporal features: 100%|██████████| 203769/203769 [00:00<00:00, 555331.18it/s]


→ Temporal features (raw)
  Means: tensor([25.1560])
  Stds : tensor([15.1722])
✓ Saved to ../data/SAGE_temporal/all/seed_2023
✓ Final shape: torch.Size([203769, 174])
→ Seed 31415


Computing temporal features: 100%|██████████| 203769/203769 [00:00<00:00, 550570.03it/s]


→ Temporal features (raw)
  Means: tensor([25.1560])
  Stds : tensor([15.1722])
✓ Saved to ../data/SAGE_temporal/all/seed_31415
✓ Final shape: torch.Size([203769, 174])


In [3]:
import torch
import os

ROOT_DIR = "../data/SAGE_temporal"
BASE_DIM = 166
SEEDS = [42, 123, 777, 2023, 31415]

print("\n📊 ENGINEERED FEATURE INSPECTION (from disk)")

# Loop over ablation folders
for ablation_name in sorted(os.listdir(ROOT_DIR)):
    ablation_path = os.path.join(ROOT_DIR, ablation_name)
    if not os.path.isdir(ablation_path):
        continue

    for seed in SEEDS:
        seed_dir = f"seed_{seed}"
        data_path = os.path.join(ablation_path, seed_dir, "data_exp.pt")
        if not os.path.exists(data_path):
            print(f"⛔ Missing file: {data_path}")
            continue

        data = torch.load(data_path, map_location="cpu", weights_only=False)
        x = data.x

        if x.shape[1] <= BASE_DIM:
            print(f"\n {ablation_name} | Seed: {seed}")
            print(" No engineered features found.")
            continue

        engineered = x[:, BASE_DIM:]
        print(f"\n {ablation_name} | Seed: {seed}")
        print(f" Engineered feature count: {engineered.shape[1]}")

        means = engineered.mean(dim=0)
        stds = engineered.std(dim=0)

        for i in range(engineered.shape[1]):
            m = means[i].item()
            s = stds[i].item()
            sample_vals = engineered[:5, i].numpy()
            print(f"Feature {BASE_DIM + i:>3} | Mean: {m:8.4f} | Std: {s:8.4f} | Sample: {sample_vals}")



📊 ENGINEERED FEATURE INSPECTION (from disk)

 all | Seed: 42
 Engineered feature count: 8
Feature 166 | Mean:   1.1501 | Std:   3.9111 | Sample: [1. 0. 1. 1. 0.]
Feature 167 | Mean:   1.1501 | Std:   1.8947 | Sample: [  1.   3. 112.   0.  50.]
Feature 168 | Mean:   0.0000 | Std:   4.3634 | Sample: [   0.   -3. -111.    1.  -50.]
Feature 169 | Mean:   0.0138 | Std:   0.0973 | Sample: [0.         0.         0.00063211 0.         0.        ]
Feature 170 | Mean:   4.2315 | Std:   6.0100 | Sample: [ 0.6931467 14.914124   4.727387   0.        17.727533 ]
Feature 171 | Mean:   0.4457 | Std:   0.3437 | Sample: [0.49999976 0.         0.00884956 0.999999   0.        ]
Feature 172 | Mean:  25.1560 | Std:  15.1722 | Sample: [ 1. 43. 15. 48.  1.]
Feature 173 | Mean:   8.4067 | Std:   0.3653 | Sample: [7.991254  8.373092  7.8188324 8.97221   7.991254 ]

 all | Seed: 123
 Engineered feature count: 8
Feature 166 | Mean:   1.1501 | Std:   3.9111 | Sample: [1. 0. 1. 1. 0.]
Feature 167 | Mean:   1.1501 

## 3. Training Multiseed Models and Ablation

In [4]:
from model_sage import GraphSAGENet
from train_utils import train_full, save_feature_experiment
import os
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


for (ablation_name, seed), bundle in FEATURED_DATA.items():
    print(f"\n=== Training GraphSAGE (Temporal Split): {ablation_name} | Seed: {seed} ===")

    data_exp = bundle["data"]
    train_idx = bundle["train_idx"]
    val_idx = bundle["val_idx"]
    feature_flags = bundle["feature_flags"]

    # 1. Initialize model
    model = GraphSAGENet(
        in_channels=data_exp.x.shape[1],
        hidden_channels=64,
        out_channels=2,
        dropout=0.1
    ).to(device)

    # 2. Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

    # 3. Train the model (use default CrossEntropyLoss)
    model, *results = train_full(
        model=model,
        data=data_exp,
        train_idx=train_idx,
        val_idx=val_idx,
        optimizer=optimizer,
        num_epochs=300,
        patience=30
    )

    # 4. Save model + logs
    output_dir = f"../model_features/temporal/SAGE_ablation/{ablation_name}/seed_{seed}"
    os.makedirs(output_dir, exist_ok=True)

    save_feature_experiment(
        output_dir=output_dir,
        model=model,
        results=results,
        seed=seed,
        config={
            "model": "SAGE",
            "ablation_name": ablation_name,
            "feature_flags": feature_flags,
            "feature_dim": data_exp.x.shape[1],
            "dropout": 0.1,
            "num_epochs": 300,
            "patience": 30,
            "lr": 0.001,
            "weight_decay": 5e-4,
            "class_weights": None  # explicitly noting unweighted
        },
        val_idx=val_idx
    )



=== Training GraphSAGE (Temporal Split): base | Seed: 42 ===
Epoch 001 | Loss: 0.6544 | Val Acc: 0.7222 | Val F1: 0.5062 | F1 Illicit: 0.8328
Epoch 002 | Loss: 0.5410 | Val Acc: 0.7699 | Val F1: 0.5436 | F1 Illicit: 0.8650
Epoch 003 | Loss: 0.4701 | Val Acc: 0.7900 | Val F1: 0.5611 | F1 Illicit: 0.8781
Epoch 004 | Loss: 0.4188 | Val Acc: 0.8054 | Val F1: 0.5761 | F1 Illicit: 0.8879
Epoch 005 | Loss: 0.3808 | Val Acc: 0.8158 | Val F1: 0.5866 | F1 Illicit: 0.8944
Epoch 006 | Loss: 0.3536 | Val Acc: 0.8237 | Val F1: 0.5931 | F1 Illicit: 0.8994
Epoch 007 | Loss: 0.3322 | Val Acc: 0.8331 | Val F1: 0.5996 | F1 Illicit: 0.9054
Epoch 008 | Loss: 0.3141 | Val Acc: 0.8406 | Val F1: 0.6057 | F1 Illicit: 0.9100
Epoch 009 | Loss: 0.3000 | Val Acc: 0.8465 | Val F1: 0.6103 | F1 Illicit: 0.9137
Epoch 010 | Loss: 0.2873 | Val Acc: 0.8532 | Val F1: 0.6150 | F1 Illicit: 0.9178
Epoch 011 | Loss: 0.2773 | Val Acc: 0.8594 | Val F1: 0.6202 | F1 Illicit: 0.9216
Epoch 012 | Loss: 0.2683 | Val Acc: 0.8644 | Va

## 4. Evaluation

In [1]:
import torch, os

base_dir = "../model_features/temporal/SAGE_ablation"  # adjust as needed

for ablation in sorted(os.listdir(base_dir)):
    for seed in [42, 123, 777, 2023, 31415]:
        path = os.path.join(base_dir, ablation, f"seed_{seed}", "model.pth")
        if not os.path.exists(path): continue
        state_dict = torch.load(path, map_location="cpu")
        in_dim = state_dict["conv1.lin_l.weight"].shape[1]
        print(f"{ablation}/seed_{seed} → expects in_channels: {in_dim}")


all/seed_42 → expects in_channels: 174
all/seed_123 → expects in_channels: 174
all/seed_777 → expects in_channels: 174
all/seed_2023 → expects in_channels: 174
all/seed_31415 → expects in_channels: 174
base/seed_42 → expects in_channels: 166
base/seed_123 → expects in_channels: 166
base/seed_777 → expects in_channels: 166
base/seed_2023 → expects in_channels: 166
base/seed_31415 → expects in_channels: 166
base+basic_temporal/seed_42 → expects in_channels: 168
base+basic_temporal/seed_123 → expects in_channels: 168
base+basic_temporal/seed_777 → expects in_channels: 168
base+basic_temporal/seed_2023 → expects in_channels: 168
base+basic_temporal/seed_31415 → expects in_channels: 168
base+basic_temporal+typology/seed_42 → expects in_channels: 170
base+basic_temporal+typology/seed_123 → expects in_channels: 170
base+basic_temporal+typology/seed_777 → expects in_channels: 170
base+basic_temporal+typology/seed_2023 → expects in_channels: 170
base+basic_temporal+typology/seed_31415 → expects

In [2]:
import os
import json
import torch

base_dir = "../model_features/temporal/SAGE_ablation"
data_dir = "../data/SAGE_temporal"
seeds = [42, 123, 777, 2023, 31415]

for ablation in sorted(os.listdir(base_dir)):
    ablation_model_path = os.path.join(base_dir, ablation)
    ablation_data_path = os.path.join(data_dir, ablation)

    if not os.path.isdir(ablation_model_path):
        continue

    print(f"\n→ Processing ablation: {ablation}")

    for seed in seeds:
        seed_dir = f"seed_{seed}"
        config_path = os.path.join(ablation_model_path, seed_dir, "config.json")
        model_path = os.path.join(ablation_model_path, seed_dir, "model.pth")
        data_path = os.path.join(ablation_data_path, seed_dir, "data_exp.pt")

        if not all(os.path.exists(p) for p in [config_path, model_path, data_path]):
            continue

        with open(config_path, "r") as f:
            config = json.load(f)

        updated = False
        data = torch.load(data_path, map_location="cpu", weights_only=False)
        feature_dim = data.x.shape[1]

        if config.get("feature_dim") != feature_dim:
            config["feature_dim"] = feature_dim
            updated = True

        # Patch SAGE-specific fields
        if config.get("model") == "SAGE":
            if config.get("hidden_channels") != 64:
                config["hidden_channels"] = 64
                updated = True
            if config.get("out_channels") != 2:
                config["out_channels"] = 2
                updated = True
            if config.get("dropout") != 0.1:
                config["dropout"] = 0.1
                updated = True

        if updated:
            with open(config_path, "w") as f:
                json.dump(config, f, indent=4)
            print(f"✓ Patched config for {ablation} / seed {seed}")
        else:
            print(f"✓ Already complete: {ablation} / seed {seed}")



→ Processing ablation: all
✓ Patched config for all / seed 42
✓ Patched config for all / seed 123
✓ Patched config for all / seed 777
✓ Patched config for all / seed 2023
✓ Patched config for all / seed 31415

→ Processing ablation: base
✓ Patched config for base / seed 42
✓ Patched config for base / seed 123
✓ Patched config for base / seed 777
✓ Patched config for base / seed 2023
✓ Patched config for base / seed 31415

→ Processing ablation: base+basic_temporal
✓ Patched config for base+basic_temporal / seed 42
✓ Patched config for base+basic_temporal / seed 123
✓ Patched config for base+basic_temporal / seed 777
✓ Patched config for base+basic_temporal / seed 2023
✓ Patched config for base+basic_temporal / seed 31415

→ Processing ablation: base+basic_temporal+typology
✓ Patched config for base+basic_temporal+typology / seed 42
✓ Patched config for base+basic_temporal+typology / seed 123
✓ Patched config for base+basic_temporal+typology / seed 777
✓ Patched config for base+basic_t

In [None]:
import os
from features.evaluate_ablation_model import evaluate_ablation_model
from evaluation_pipeline import run_inference_all_seeds
from analysis_utils import load_metrics_across_seeds, log_metrics_to_csv
from features.feature_utils import make_model_class_from_config

# === Config ===
seeds = [42, 123, 777, 2023, 31415]
base_dir = "../model_features/temporal/SAGE_ablation"
data_exp_root = "../data/SAGE_temporal"
node_id_csv_path = "../elliptic_bitcoin_dataset/elliptic_txs_features.csv"

model_tag = "SAGE-Temporal"
split = "temporal"

# === Evaluate all ablations ===
for ablation in sorted(os.listdir(base_dir)):
    model_dir = os.path.join(base_dir, ablation)
    if not os.path.isdir(model_dir):
        continue

    print(f"\n→ Evaluating ablation: {ablation}")
    model_name = f"{model_tag}: {ablation}"
    ablation_data_dir = os.path.join(data_exp_root, ablation)
    config_path = os.path.join(model_dir, "seed_42", "config.json")
    model_class = make_model_class_from_config(config_path)

    # Run full evaluation
    evaluate_ablation_model(
        model_dir=model_dir,
        model_class=model_class,
        model_name=model_name,
        seeds=seeds,
        node_id_csv_path=node_id_csv_path,
        data_dir=ablation_data_dir
    )

    # Run inference (for logging CSV)
    y_true_all, y_pred_all, y_proba_all, seed_metrics = run_inference_all_seeds(
        model_dir=model_dir,
        model_class=model_class,
        data_dir=ablation_data_dir,
        seeds=seeds
    )
    val_acc_list = load_metrics_across_seeds(model_dir, ["val_acc"])["val_acc"]

    # Save to CSV
    log_metrics_to_csv(
        model_name=model_tag,
        split_name=split,
        ablation=ablation,
        seeds=seeds,
        val_acc_list=val_acc_list,
        seed_metrics=seed_metrics,
        is_feature=True  # append _feature to filename
    )
