# E5: Build Heterogeneous Graph (HeteroData)

**Milestone:** E5 - Heterogeneous Graph Construction  
**Objective:** Build PyG HeteroData from Elliptic++ CSV files with temporal constraints

**Node Types:**
- Transaction: 203,769 nodes (93 local features)
- Address: 823,942 nodes (52 features)

**Edge Types:**
- tx → tx: Transaction flows
- addr → tx: Address inputs to transaction
- tx → addr: Transaction outputs to address
- addr → addr: Address-to-address connections

**Output:**
- `hetero_graph.pt`: PyG HeteroData object
- `hetero_graph_summary.json`: Graph statistics
- `node_mappings.json`: ID mappings

## Setup

In [1]:
# Install dependencies
!pip install -q torch torch-geometric pandas numpy tqdm

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m93.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m73.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m32.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3

In [2]:
import pandas as pd
import numpy as np
import torch
from pathlib import Path
import json
from torch_geometric.data import HeteroData
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

print(f"PyTorch version: {torch.__version__}")
print(f"GPU available: {torch.cuda.is_available()}")

PyTorch version: 2.6.0+cu124
GPU available: True


## Configuration

In [3]:
# Paths
DATA_ROOT = Path('/kaggle/input/elliptic-dataset')  # Adjust to your Kaggle dataset path
OUTPUT_DIR = Path('/kaggle/working')

# Graph construction settings
TOP_K_ADDRESSES = 100000  # Use top 100K most active addresses (MVP)
USE_ALL_ADDRESSES = False  # Set True to use all 823K addresses (requires ~12GB RAM)

# Temporal splits (same as E3)
TRAIN_FRAC = 0.6
VAL_FRAC = 0.2
TEST_FRAC = 0.2

print(f"Data root: {DATA_ROOT}")
print(f"Output dir: {OUTPUT_DIR}")
print(f"Address strategy: {'All addresses' if USE_ALL_ADDRESSES else f'Top {TOP_K_ADDRESSES:,}'}")

Data root: /kaggle/input/elliptic-dataset
Output dir: /kaggle/working
Address strategy: Top 100,000


## Load Transaction Nodes

In [4]:
print("Loading transaction nodes...")

# Load features and classes
tx_features = pd.read_csv(DATA_ROOT / 'txs_features.csv')
tx_classes = pd.read_csv(DATA_ROOT / 'txs_classes.csv')

# Merge
tx_data = tx_features.merge(tx_classes, on='txId', how='left')

print(f"  Transactions loaded: {len(tx_data):,}")

# Create ID mappings
tx_ids = tx_data['txId'].values
tx_id_to_idx = {tx_id: idx for idx, tx_id in enumerate(tx_ids)}
tx_idx_to_id = {idx: tx_id for idx, tx_id in enumerate(tx_ids)}

# Extract LOCAL features only (AF1-AF93)
local_features = [col for col in tx_data.columns if 'Local_feature' in col]
print(f"  Features: {len(local_features)} (Local only)")

# Extract features
tx_x = torch.FloatTensor(tx_data[local_features].values)
tx_x = torch.nan_to_num(tx_x, nan=0.0, posinf=0.0, neginf=0.0)

# Normalize
tx_x = (tx_x - tx_x.mean(dim=0)) / (tx_x.std(dim=0) + 1e-8)
tx_x = torch.nan_to_num(tx_x, nan=0.0)

# Extract timestamps
tx_timestamps = torch.LongTensor(tx_data['Time step'].values)

# Extract labels (1=illicit→1, 2=licit→0, 3/NaN=unknown→-1)
y_raw = tx_data['class'].fillna(3).astype(int).values
tx_y = torch.LongTensor(np.where(y_raw == 1, 1, np.where(y_raw == 2, 0, -1)))

print(f"  Labeled: {(tx_y >= 0).sum():,} / {len(tx_y):,}")
print(f"  Fraud: {(tx_y == 1).sum():,}, Legit: {(tx_y == 0).sum():,}")
print(f"  Feature shape: {tx_x.shape}")

Loading transaction nodes...
  Transactions loaded: 203,769
  Features: 93 (Local only)
  Labeled: 46,564 / 203,769
  Fraud: 4,545, Legit: 42,019
  Feature shape: torch.Size([203769, 93])


## Load Address Nodes

In [5]:
print("Loading address nodes...")

# Try combined file first (more efficient)
combined_file = DATA_ROOT / 'wallets_features_classes_combined.csv'

if combined_file.exists():
    print("  Using combined wallet file...")
    addr_data = pd.read_csv(combined_file)
else:
    print("  Loading separate files...")
    addr_features = pd.read_csv(DATA_ROOT / 'wallets_features.csv')
    addr_classes = pd.read_csv(DATA_ROOT / 'wallets_classes.csv')
    addr_data = addr_features.merge(addr_classes, on='address', how='left')

print(f"  Total addresses: {len(addr_data):,}")

# Select top K most active if not using all
if not USE_ALL_ADDRESSES and TOP_K_ADDRESSES:
    addr_data = addr_data.nlargest(TOP_K_ADDRESSES, 'total_txs')
    print(f"  Selected top {TOP_K_ADDRESSES:,} most active addresses")

# Create ID mappings
addr_ids = addr_data['address'].values
addr_id_to_idx = {addr_id: idx for idx, addr_id in enumerate(addr_ids)}
addr_idx_to_id = {idx: addr_id for idx, addr_id in enumerate(addr_ids)}

# Extract features (exclude ID, timestamp, class)
feature_cols = [col for col in addr_data.columns 
                if col not in ['address', 'Time step', 'class']]
print(f"  Features: {len(feature_cols)}")

# Extract features
addr_x = torch.FloatTensor(addr_data[feature_cols].values)
addr_x = torch.nan_to_num(addr_x, nan=0.0, posinf=0.0, neginf=0.0)

# Normalize
addr_x = (addr_x - addr_x.mean(dim=0)) / (addr_x.std(dim=0) + 1e-8)
addr_x = torch.nan_to_num(addr_x, nan=0.0)

# Extract timestamps
addr_timestamps = torch.LongTensor(addr_data['Time step'].values)

# Extract labels
y_raw = addr_data['class'].fillna(3).astype(int).values
addr_y = torch.LongTensor(np.where(y_raw == 1, 1, np.where(y_raw == 2, 0, -1)))

print(f"  Labeled: {(addr_y >= 0).sum():,} / {len(addr_y):,}")
print(f"  Fraud: {(addr_y == 1).sum():,}, Legit: {(addr_y == 0).sum():,}")
print(f"  Feature shape: {addr_x.shape}")

Loading address nodes...
  Using combined wallet file...
  Total addresses: 1,268,260
  Selected top 100,000 most active addresses
  Features: 55
  Labeled: 31,754 / 100,000
  Fraud: 3,880, Legit: 27,874
  Feature shape: torch.Size([100000, 55])


## Load Edges

In [6]:
def load_edges(edge_type, src_mapping, dst_mapping):
    """
    Load edges for a specific type.
    
    Args:
        edge_type: One of ['tx-tx', 'addr-tx', 'tx-addr', 'addr-addr']
        src_mapping: Source node ID to index mapping
        dst_mapping: Destination node ID to index mapping
    
    Returns:
        edge_index: [2, E] tensor
    """
    file_map = {
        'tx-tx': 'txs_edgelist.csv',
        'addr-tx': 'AddrTx_edgelist.csv',
        'tx-addr': 'TxAddr_edgelist.csv',
        'addr-addr': 'AddrAddr_edgelist.csv'
    }
    
    print(f"\nLoading {edge_type} edges...")
    
    edges_df = pd.read_csv(DATA_ROOT / file_map[edge_type])
    cols = list(edges_df.columns)
    src_col, dst_col = cols[0], cols[1]
    
    print(f"  Total edges in file: {len(edges_df):,}")
    
    # Filter to valid nodes
    valid = (edges_df[src_col].isin(src_mapping) & 
             edges_df[dst_col].isin(dst_mapping))
    
    src_idx = edges_df.loc[valid, src_col].map(src_mapping).values
    dst_idx = edges_df.loc[valid, dst_col].map(dst_mapping).values
    
    edge_index = torch.LongTensor(np.vstack([src_idx, dst_idx]))
    
    print(f"  Valid edges: {edge_index.shape[1]:,}")
    
    return edge_index

In [7]:
# Load all edge types
print("="*70)
print("LOADING EDGES")
print("="*70)

edge_index_tx_tx = load_edges('tx-tx', tx_id_to_idx, tx_id_to_idx)
edge_index_addr_tx = load_edges('addr-tx', addr_id_to_idx, tx_id_to_idx)
edge_index_tx_addr = load_edges('tx-addr', tx_id_to_idx, addr_id_to_idx)

# addr-addr edges may be very large - optional
try:
    edge_index_addr_addr = load_edges('addr-addr', addr_id_to_idx, addr_id_to_idx)
except Exception as e:
    print(f"\nSkipping addr-addr edges: {e}")
    edge_index_addr_addr = torch.LongTensor([[],[]])  # Empty edges

LOADING EDGES

Loading tx-tx edges...
  Total edges in file: 234,355
  Valid edges: 234,355

Loading addr-tx edges...
  Total edges in file: 477,117
  Valid edges: 53,059

Loading tx-addr edges...
  Total edges in file: 837,124
  Valid edges: 80,489

Loading addr-addr edges...
  Total edges in file: 2,868,964
  Valid edges: 54,173


## Create Temporal Splits

In [8]:
print("\nCreating temporal splits...")

# Sort timestamps and find boundaries
sorted_times = np.sort(np.unique(tx_timestamps.numpy()))
n_timesteps = len(sorted_times)

train_end_idx = int(n_timesteps * TRAIN_FRAC)
val_end_idx = int(n_timesteps * (TRAIN_FRAC + VAL_FRAC))

train_time_end = sorted_times[train_end_idx - 1]
val_time_end = sorted_times[val_end_idx - 1]

# Create masks (only for labeled transactions)
labeled = tx_y >= 0
train_mask = (tx_timestamps <= train_time_end) & labeled
val_mask = ((tx_timestamps > train_time_end) & (tx_timestamps <= val_time_end)) & labeled
test_mask = (tx_timestamps > val_time_end) & labeled

print(f"  Train: {train_mask.sum():,} (time <= {train_time_end})")
print(f"  Val:   {val_mask.sum():,} (time <= {val_time_end})")
print(f"  Test:  {test_mask.sum():,}")


Creating temporal splits...
  Train: 26,381 (time <= 29)
  Val:   8,999 (time <= 39)
  Test:  11,184


## Build HeteroData

In [9]:
print("\n" + "="*70)
print("BUILDING HETERODATA")
print("="*70)

# Initialize HeteroData
data = HeteroData()

# Add transaction nodes
data['transaction'].x = tx_x
data['transaction'].y = tx_y
data['transaction'].timestamp = tx_timestamps
data['transaction'].train_mask = train_mask
data['transaction'].val_mask = val_mask
data['transaction'].test_mask = test_mask

# Add address nodes
data['address'].x = addr_x
data['address'].y = addr_y
data['address'].timestamp = addr_timestamps

# Add edges
data['transaction', 'to', 'transaction'].edge_index = edge_index_tx_tx
data['address', 'to', 'transaction'].edge_index = edge_index_addr_tx
data['transaction', 'to', 'address'].edge_index = edge_index_tx_addr
data['address', 'to', 'address'].edge_index = edge_index_addr_addr

print("\nHeteroData Summary:")
print(data)

print("\nNode Statistics:")
print(f"  Transactions: {data['transaction'].num_nodes:,}")
print(f"  Addresses: {data['address'].num_nodes:,}")

print("\nEdge Statistics:")
for edge_type in data.edge_types:
    src, rel, dst = edge_type
    print(f"  {src} -> {dst}: {data[edge_type].num_edges:,}")


BUILDING HETERODATA

HeteroData Summary:
HeteroData(
  transaction={
    x=[203769, 93],
    y=[203769],
    timestamp=[203769],
    train_mask=[203769],
    val_mask=[203769],
    test_mask=[203769],
  },
  address={
    x=[100000, 55],
    y=[100000],
    timestamp=[100000],
  },
  (transaction, to, transaction)={ edge_index=[2, 234355] },
  (address, to, transaction)={ edge_index=[2, 53059] },
  (transaction, to, address)={ edge_index=[2, 80489] },
  (address, to, address)={ edge_index=[2, 54173] }
)

Node Statistics:
  Transactions: 203,769
  Addresses: 100,000

Edge Statistics:
  transaction -> transaction: 234,355
  address -> transaction: 53,059
  transaction -> address: 80,489
  address -> address: 54,173


## Save Outputs

In [10]:
print("\n" + "="*70)
print("SAVING OUTPUTS")
print("="*70)

# Save HeteroData
torch.save(data, OUTPUT_DIR / 'hetero_graph.pt')
print(f"\nSaved HeteroData: {OUTPUT_DIR / 'hetero_graph.pt'}")

# Save summary
summary = {
    'num_nodes': {
        'transaction': data['transaction'].num_nodes,
        'address': data['address'].num_nodes
    },
    'num_edges': {
        f"{src}_to_{dst}": data[src, rel, dst].num_edges
        for src, rel, dst in data.edge_types
    },
    'num_labeled': {
        'transaction': (data['transaction'].y >= 0).sum().item(),
        'address': (data['address'].y >= 0).sum().item()
    },
    'temporal_range': {
        'transaction': [
            data['transaction'].timestamp.min().item(),
            data['transaction'].timestamp.max().item()
        ],
        'address': [
            data['address'].timestamp.min().item(),
            data['address'].timestamp.max().item()
        ]
    },
    'feature_dims': {
        'transaction': data['transaction'].x.shape[1],
        'address': data['address'].x.shape[1]
    },
    'splits': {
        'train': data['transaction'].train_mask.sum().item(),
        'val': data['transaction'].val_mask.sum().item(),
        'test': data['transaction'].test_mask.sum().item()
    }
}

with open(OUTPUT_DIR / 'hetero_graph_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)
print(f"Saved summary: {OUTPUT_DIR / 'hetero_graph_summary.json'}")

# Save node mappings (first 1000 of each for reference)
mappings = {
    'tx_id_to_idx_sample': {str(k): int(v) for k, v in list(tx_id_to_idx.items())[:1000]},
    'addr_id_to_idx_sample': {str(k): int(v) for k, v in list(addr_id_to_idx.items())[:1000]}
}

with open(OUTPUT_DIR / 'node_mappings_sample.json', 'w') as f:
    json.dump(mappings, f, indent=2)
print(f"Saved mappings: {OUTPUT_DIR / 'node_mappings_sample.json'}")

print("\n" + "="*70)
print("E5 MILESTONE COMPLETE!")
print("="*70)
print("\nNext: Download hetero_graph.pt for E6 (TRD-HHGTN training)")


SAVING OUTPUTS

Saved HeteroData: /kaggle/working/hetero_graph.pt
Saved summary: /kaggle/working/hetero_graph_summary.json
Saved mappings: /kaggle/working/node_mappings_sample.json

E5 MILESTONE COMPLETE!

Next: Download hetero_graph.pt for E6 (TRD-HHGTN training)


## Validation Checks

In [11]:
print("\n" + "="*70)
print("VALIDATION CHECKS")
print("="*70)

# Check 1: Node counts match
print("\n1. Node Counts:")
print(f"   Transactions: {data['transaction'].num_nodes:,} (expected: {len(tx_ids):,})")
print(f"   Addresses: {data['address'].num_nodes:,} (expected: {len(addr_ids):,})")
assert data['transaction'].num_nodes == len(tx_ids), "Transaction count mismatch!"
assert data['address'].num_nodes == len(addr_ids), "Address count mismatch!"
print("   ✓ PASS")

# Check 2: Edge indices are valid
print("\n2. Edge Index Validity:")
for edge_type in data.edge_types:
    src, rel, dst = edge_type
    edge_index = data[edge_type].edge_index
    if edge_index.shape[1] > 0:  # Only check non-empty edges
        src_max = edge_index[0].max().item()
        dst_max = edge_index[1].max().item()
        src_nodes = data[src].num_nodes
        dst_nodes = data[dst].num_nodes
        print(f"   {src} -> {dst}: src_max={src_max} < {src_nodes}, dst_max={dst_max} < {dst_nodes}")
        assert src_max < src_nodes, f"{src} index out of bounds!"
        assert dst_max < dst_nodes, f"{dst} index out of bounds!"
print("   ✓ PASS")

# Check 3: Split sizes
print("\n3. Split Integrity:")
train_count = data['transaction'].train_mask.sum().item()
val_count = data['transaction'].val_mask.sum().item()
test_count = data['transaction'].test_mask.sum().item()
total_labeled = (data['transaction'].y >= 0).sum().item()
print(f"   Train: {train_count:,}")
print(f"   Val: {val_count:,}")
print(f"   Test: {test_count:,}")
print(f"   Total: {train_count + val_count + test_count:,} (labeled: {total_labeled:,})")
assert train_count + val_count + test_count == total_labeled, "Split count mismatch!"
print("   ✓ PASS")

# Check 4: No NaN in features
print("\n4. Feature Quality:")
tx_nans = torch.isnan(data['transaction'].x).sum().item()
addr_nans = torch.isnan(data['address'].x).sum().item()
print(f"   Transaction NaNs: {tx_nans}")
print(f"   Address NaNs: {addr_nans}")
assert tx_nans == 0, "Transaction features contain NaN!"
assert addr_nans == 0, "Address features contain NaN!"
print("   ✓ PASS")

print("\n" + "="*70)
print("ALL VALIDATION CHECKS PASSED!")
print("="*70)


VALIDATION CHECKS

1. Node Counts:
   Transactions: 203,769 (expected: 203,769)
   Addresses: 100,000 (expected: 100,000)
   ✓ PASS

2. Edge Index Validity:
   transaction -> transaction: src_max=203768 < 203769, dst_max=203766 < 203769
   address -> transaction: src_max=99999 < 100000, dst_max=202785 < 203769
   transaction -> address: src_max=202791 < 203769, dst_max=99999 < 100000
   address -> address: src_max=99999 < 100000, dst_max=99987 < 100000
   ✓ PASS

3. Split Integrity:
   Train: 26,381
   Val: 8,999
   Test: 11,184
   Total: 46,564 (labeled: 46,564)
   ✓ PASS

4. Feature Quality:
   Transaction NaNs: 0
   Address NaNs: 0
   ✓ PASS

ALL VALIDATION CHECKS PASSED!
