In [1]:
"""
Temporal Graph Neural Network for Anti-Money Laundering Detection
================================================================

This implementation creates a temporal graph model optimized for detecting
money laundering patterns in the SAML-D dataset with focus on improving recall
while maintaining precision.

Key Features:
- Multi-scale temporal modeling (hourly, daily, weekly)
- Dynamic node embeddings with memory mechanisms
- Attention-based temporal aggregation
- Class imbalance handling for 0.15% positive class
- Scalable architecture for 9.5M transactions
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
from torch_geometric.data import Data, Batch
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import precision_recall_curve, roc_auc_score, f1_score
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings('ignore')

In [4]:
from tqdm import tqdm
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent.parent))  # Adjust as needed
from config import DATAPATH, SAMPLE_DATAPATH

In [None]:
class TemporalGraphDataProcessor:
    """
    Processes SAML-D dataset into temporal graph format
    """
    
    def __init__(self, time_window_hours=168):  # 7 days default
        self.time_window_hours = time_window_hours
        self.scalers = {}
        self.encoders = {}
        
    def load_and_preprocess(self, df):
        """Load SAML-D dataset and perform initial preprocessing"""

        print("Loading data...")
        # Combine date and time into datetime
        df['datetime'] = pd.to_datetime(df['Date'] + ' ' + df['Time'])
        
        # Sort by datetime for temporal processing
        df = df.sort_values('datetime').reset_index(drop=True)
        
        print(f"Loaded {len(df)} transactions")
        print(f"Suspicious transactions: {df['Is_laundering'].sum()} ({df['Is_laundering'].mean()*100:.3f}%)")
        
        return df
    
    def engineer_features(self, df):
        """Create temporal and graph-specific features"""
        print("Engineering features...")
        
        # Time-based features
        df['hour'] = df['datetime'].dt.hour.astype('int8')
        df['day_of_week'] = df['datetime'].dt.dayofweek.astype('int8')
        df['day_of_month'] = df['datetime'].dt.day.astype('int8')
        df['is_weekend'] = (df['day_of_week'] >= 5).astype('int8')

        # Log transform amount to handle skewness
        df['log_amount'] = np.log1p(df['Amount']).astype('float32')
        
        # Cross-border indicator
        df['is_cross_border'] = (df['Payment_type'] == 'Cross-border').astype('int8')

        # Risky countries
        risky_countries = {'Mexico', 'Turkey', 'Morocco', 'UAE'}  # Example risky countries
        df['sender_high_risk'] = df['Sender_bank_location'].isin(risky_countries).astype('int8')
        df['receiver_high_risk'] = df['Receiver_bank_location'].isin(risky_countries).astype('int8')

        # Currency mismatch
        df['currency_mismatch'] = (df['Payment_currency'] != df['Received_currency']).astype('int8')

        # Converting Is_laundering to int8
        df['Is_laundering'] = df['Is_laundering'].astype('int8')

        # Delete unnecessary columns
        df = df.drop(columns=['Date', 'Time', 'Amount', 'Sender_bank_location', 'Receiver_bank_location', 
                              'Payment_currency', 'Received_currency', 'Laundering_type'])

        # Transaction frequency features (last 24h, 7d, 30d)
        print("Computing temporal features...")
        df = self._compute_temporal_frequencies(df)
        
        # Fan-in/Fan-out patterns
        print("Computing network features...")
        df = self._compute_network_features(df)
        
        return df
    
    def _compute_temporal_frequencies(self, df, window_hours_list=[24, 168]):
        """Compute transaction frequencies for different time windows"""
        df = df.copy()
        
        timestamps_numeric = df['datetime'].values.astype('int64') // 10**9
        senders = df['Sender_account'].values
        receivers = df['Receiver_account'].values

        for window_hours in window_hours_list:  # 1 day, 7 days
            window_nanoseconds = int(pd.Timedelta(hours=window_hours).total_seconds())
            
            sender_freq = np.zeros(len(df), dtype=int)
            receiver_freq = np.zeros(len(df), dtype=int)

            for idx in tqdm(range(len(df))):
                current_time_numeric = timestamps_numeric[idx]
                start_time_numeric = current_time_numeric - window_nanoseconds
                
                # Binary search for time window
                start_idx = np.searchsorted(timestamps_numeric, start_time_numeric, side='left')
                end_idx = idx + 1
                
                # Count in window
                if start_idx < end_idx:
                    window_senders = senders[start_idx:end_idx]
                    window_receivers = receivers[start_idx:end_idx]
                    
                    sender_freq[idx] = np.sum(window_senders == senders[idx])
                    receiver_freq[idx] = np.sum(window_receivers == receivers[idx])
            
            df[f'sender_freq_{window_hours}h'] = sender_freq
            df[f'receiver_freq_{window_hours}h'] = receiver_freq
        
        return df

    def _compute_network_features(self, df):
        """Vectorized network features calculation"""
        df = df.copy().sort_values('datetime').reset_index(drop=True)
        
        # Pre-convert to categoricals for faster comparisons
        df['Sender_account'] = df['Sender_account'].astype('category')
        df['Receiver_account'] = df['Receiver_account'].astype('category')
        
        timestamps = df['datetime'].astype('int64').values
        window_ns = pd.Timedelta(days=30).value
        
        # Pre-allocate results
        n = len(df)
        fanout_30d = np.zeros(n, dtype=int)
        fanin_30d = np.zeros(n, dtype=int)
        back_forth_transfers = np.zeros(n, dtype=int)
        
        # Convert to numpy arrays for faster access
        senders = df['Sender_account'].cat.codes.values
        receivers = df['Receiver_account'].cat.codes.values
    
        for idx in tqdm(range(n)):
            current_time = timestamps[idx]
            start_time = current_time - window_ns
            
            # Find window indices
            start_idx = np.searchsorted(timestamps, start_time, side='left')
            
            # Work with numpy arrays instead of pandas
            window_senders = senders[start_idx:idx+1]
            window_receivers = receivers[start_idx:idx+1]
            window_timestamps = timestamps[start_idx:idx+1]
            
            current_sender = senders[idx]
            current_receiver = receivers[idx]
            
            # Vectorized calculations
            sender_mask = window_senders == current_sender
            receiver_mask = window_receivers == current_receiver
            
            fanout_30d[idx] = len(np.unique(window_receivers[sender_mask]))
            fanin_30d[idx] = len(np.unique(window_senders[receiver_mask]))
            
            # Back-forth: both directions between current sender/receiver
            back_forth_mask = ((window_senders == current_sender) & 
                            (window_receivers == current_receiver)) | \
                            ((window_senders == current_receiver) & 
                            (window_receivers == current_sender))
            back_forth_transfers[idx] = np.sum(back_forth_mask)
        
        df['fanout_30d'] = fanout_30d
        df['fanin_30d'] = fanin_30d
        df['back_forth_transfers'] = back_forth_transfers
        
        return df
    
    def create_temporal_snapshots(self, df):
        """Create temporal graph snapshots"""
        print("Creating temporal graph snapshots...")
        
        # Sort by datetime
        df = df.sort_values('datetime')

        # Get all unique accounts globally
        all_accounts = list(set(df['Sender_account'].unique()) | set(df['Receiver_account'].unique()))
        global_account_to_idx = {acc: idx for idx, acc in enumerate(all_accounts)}
        global_num_nodes = len(all_accounts)
        
        # Define time windows
        start_time = df['datetime'].min().normalize()  # Start of first day
        end_time = df['datetime'].max().normalize() + pd.Timedelta(days=1)  # Start of day after last day

        snapshots = []
        current_time = start_time
        print(f"Total time range: {start_time.date()} to {end_time.date()}")

        while current_time < end_time:
            window_end = current_time + pd.Timedelta(hours=self.time_window_hours)
            print(f"Processing window: {current_time} to {window_end}")

            # Get transactions in current window
            window_mask = (df['datetime'] >= current_time) & (df['datetime'] < window_end)
            window_data = df[window_mask].copy()
            
            if len(window_data) > 0:
                # Create graph for this window
                graph_data = self._create_graph_snapshot(window_data, current_time, global_account_to_idx, global_num_nodes)
                if graph_data is not None:
                    snapshots.append(graph_data)
            
            current_time = window_end
        
        print(f"Created {len(snapshots)} temporal snapshots")
        return snapshots, global_num_nodes
    
    def _create_graph_snapshot(self, window_data, timestamp, global_account_to_idx, global_num_nodes):
        """Create a single graph snapshot"""
        if len(window_data) == 0:
            return None
        
        # Get unique accounts
        active_accounts = set(window_data['Sender_account'].tolist() + 
                           window_data['Receiver_account'].tolist())

        # Create edges (transactions)
        feature_columns = [
            'log_amount', 'hour', 'day_of_week', 'is_weekend',
            'is_cross_border', 'currency_mismatch',
            'sender_high_risk', 'receiver_high_risk'
        ]

        sender_mapped = window_data['Sender_account'].map(global_account_to_idx)
        receiver_mapped = window_data['Receiver_account'].map(global_account_to_idx)
        edge_index = np.column_stack((sender_mapped, receiver_mapped))
        edge_features = window_data[feature_columns].values
        transaction_labels = window_data['Is_laundering'].values
        
        # Create node features (account features) - zero for all, update active
        node_features = np.zeros((global_num_nodes, 5))  # 5 features: sender_freq_24h, receiver_freq_24h, fanout_30d, fanin_30d, back_forth_transfers

        for account in active_accounts:
            # Get account statistics from window
            account_data = window_data[
                (window_data['Sender_account'] == account) | 
                (window_data['Receiver_account'] == account)
            ]     
            if len(account_data) > 0:
                node_feat = [
                    account_data['sender_freq_24h'].max() if 'sender_freq_24h' in account_data.columns else 0,
                    account_data['receiver_freq_24h'].max() if 'receiver_freq_24h' in account_data.columns else 0,
                    account_data['fanout_30d'].max() if 'fanout_30d' in account_data.columns else 0,
                    account_data['fanin_30d'].max() if 'fanin_30d' in account_data.columns else 0,
                    account_data['back_forth_transfers'].max() if 'back_forth_transfers' in account_data.columns else 0,
                ]
                global_idx = global_account_to_idx[account]
                node_features[global_idx] = node_feat
        
        # Convert to tensors
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        node_features = torch.tensor(node_features, dtype=torch.float)
        edge_features = torch.tensor(edge_features, dtype=torch.float)
        transaction_labels = torch.tensor(transaction_labels, dtype=torch.float)
        
        # Create PyTorch Geometric data object
        data = Data(
            x=node_features,
            edge_index=edge_index,
            edge_attr=edge_features,
            y=transaction_labels,
            timestamp=timestamp,
            num_nodes=global_num_nodes
        )
        
        return data

In [6]:
# Temporal GNN Model for Edge Classification
class TemporalEdgeClassifier(nn.Module):
    def __init__(self, node_dim, edge_dim, hidden_dim):
        super(TemporalEdgeClassifier, self).__init__()
        self.rnn = nn.GRUCell(node_dim, hidden_dim)
        self.gnn1 = GATConv(hidden_dim, hidden_dim)
        self.gnn2 = GATConv(hidden_dim, hidden_dim)
        self.lin = nn.Linear(hidden_dim * 2 + edge_dim, 1)  # Binary classification

    def forward(self, data, h):
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        
        # Update node hidden states with RNN (using current x)
        h = self.rnn(x, h)
        
        # Apply GNN layers
        h = F.relu(self.gnn1(h, edge_index))
        h = F.relu(self.gnn2(h, edge_index))
        
        # Edge features: concat sender h, receiver h, edge_attr
        h_i = h[edge_index[0]]
        h_j = h[edge_index[1]]
        edge_input = torch.cat([h_i, h_j, edge_attr], dim=-1)
        
        # Prediction
        out = self.lin(edge_input)
        
        return out, h  # Return logits and updated h

In [7]:
def train_model(snapshots, global_num_nodes, epochs=50, device='cuda' if torch.cuda.is_available() else 'cpu'):
    node_dim = 5  # From node features
    edge_dim = 8  # From edge features: log_amount, hour, day_of_week, is_weekend, is_cross_border, currency_mismatch, sender_high_risk, receiver_high_risk
    hidden_dim = 64
    
    model = TemporalEdgeClassifier(node_dim, edge_dim, hidden_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    # Compute pos_weight for imbalance (~960:1 from ~0.1% positive)
    total_trans = sum(len(s.y) for s in snapshots)
    total_pos = sum(s.y.sum() for s in snapshots)
    pos_weight = torch.tensor([(total_trans - total_pos) / total_pos]).to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    
    # Split snapshots chronologically: 70% train, 15% val, 15% test
    n = len(snapshots)
    train_end = int(0.7 * n)
    val_end = int(0.85 * n)
    train_snaps = snapshots[:train_end]
    val_snaps = snapshots[train_end:val_end]
    test_snaps = snapshots[val_end:]
    
    print(f"Training on {len(train_snaps)} snapshots, validating on {len(val_snaps)}, testing on {len(test_snaps)}")

    train_loss_history = []
    val_loss_history = []
    val_preds, val_labels = [], []
    test_preds, test_labels = [], []
    
    for epoch in range(epochs):
        model.train()
        h = torch.zeros(global_num_nodes, hidden_dim).to(device)  # Initial hidden state
        
        train_loss = 0
        for snap in train_snaps:
            snap = snap.to(device)
            out, h = model(snap, h)
            loss = criterion(out.squeeze(), snap.y)
            train_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            h = h.detach()  # Detach history
        
        avg_train_loss = train_loss / len(train_snaps)
        train_loss_history.append(avg_train_loss)
        
        # Validation
        model.eval()
        with torch.no_grad():
            h = torch.zeros(global_num_nodes, hidden_dim).to(device)
            val_loss = 0
            for snap in val_snaps:
                snap = snap.to(device)
                out, h = model(snap, h)
                loss = criterion(out.squeeze(), snap.y)
                val_loss += loss.item()
            
            avg_val_loss = val_loss / len(val_snaps)
            val_loss_history.append(avg_val_loss)
        
        # Print loss every 10 epochs
        if (epoch + 1) % 10 == 0:
            print(f"After Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")
    
    print("Training complete. Generating predictions...")
    # Validation set    
    model.eval()
    with torch.no_grad():
        h = torch.zeros(global_num_nodes, hidden_dim).to(device)
        for snap in val_snaps:
            snap = snap.to(device)
            out, h = model(snap, h)
            preds = torch.sigmoid(out).squeeze().cpu().numpy()
            val_preds.extend(preds)
            val_labels.extend(snap.y.cpu().numpy())

    # Test set
    model.eval()
    with torch.no_grad():
        h = torch.zeros(global_num_nodes, hidden_dim).to(device)
        for snap in test_snaps:
            snap = snap.to(device)
            out, h = model(snap, h)
            preds = torch.sigmoid(out).squeeze().cpu().numpy()
            test_preds.extend(preds)
            test_labels.extend(snap.y.cpu().numpy())
    
    return {
        'train_loss_history': train_loss_history,
        'val_loss_history': val_loss_history,
        'val_preds': np.array(val_preds),
        'val_labels': np.array(val_labels),
        'test_preds': np.array(test_preds),
        'test_labels': np.array(test_labels)
    }

# Usage Example (assuming data is loaded and processed)
# processor = TemporalGraphDataProcessor()
# df = processor.load_and_preprocess(df)
# df = processor.engineer_features(df)
# snapshots, global_num_nodes = processor.create_temporal_snapshots(df)
# results = train_model(snapshots, global_num_nodes)

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the entire dataset
df = pd.read_csv(DATAPATH)

# Filter by data range
# df = df[df['Date'] < '2022-12-31']
# df = df.head(300000).copy()

In [9]:
processor = TemporalGraphDataProcessor()
df = processor.load_and_preprocess(df)
df = processor.engineer_features(df)

Loading data...
Loaded 9504852 transactions
Suspicious transactions: 9873 (0.104%)
Engineering features...
Computing temporal features...


100%|██████████| 9504852/9504852 [07:30<00:00, 21075.47it/s]
100%|██████████| 9504852/9504852 [50:29<00:00, 3137.94it/s]


Computing network features...


100%|██████████| 9504852/9504852 [4:49:29<00:00, 547.22it/s]  


In [None]:
# df = pd.read_csv('tmp/processed_data_tgnn.csv')
df.to_csv('tmp/processed_data_tgnn_full.csv', index=False)
df['datetime'] = pd.to_datetime(df['datetime'])

In [8]:
snapshots, global_num_nodes = processor.create_temporal_snapshots(df)

Creating temporal graph snapshots...
Total time range: 2022-10-07 to 2022-12-31
Processing window: 2022-10-07 00:00:00 to 2022-10-14 00:00:00
Processing window: 2022-10-14 00:00:00 to 2022-10-21 00:00:00
Processing window: 2022-10-21 00:00:00 to 2022-10-28 00:00:00
Processing window: 2022-10-28 00:00:00 to 2022-11-04 00:00:00
Processing window: 2022-11-04 00:00:00 to 2022-11-11 00:00:00
Processing window: 2022-11-11 00:00:00 to 2022-11-18 00:00:00
Processing window: 2022-11-18 00:00:00 to 2022-11-25 00:00:00
Processing window: 2022-11-25 00:00:00 to 2022-12-02 00:00:00
Processing window: 2022-12-02 00:00:00 to 2022-12-09 00:00:00
Processing window: 2022-12-09 00:00:00 to 2022-12-16 00:00:00
Processing window: 2022-12-16 00:00:00 to 2022-12-23 00:00:00
Processing window: 2022-12-23 00:00:00 to 2022-12-30 00:00:00
Processing window: 2022-12-30 00:00:00 to 2023-01-06 00:00:00
Created 13 temporal snapshots


In [17]:
results = train_model(snapshots, global_num_nodes, epochs=100)

Training on 9 snapshots, validating on 2, testing on 2
After Epoch 10: Train Loss: 0.6765, Validation Loss: 0.8470
After Epoch 20: Train Loss: 0.4913, Validation Loss: 0.6752
After Epoch 30: Train Loss: 0.4004, Validation Loss: 0.5848
After Epoch 40: Train Loss: 0.3407, Validation Loss: 0.5350
After Epoch 50: Train Loss: 0.3011, Validation Loss: 0.5738
After Epoch 60: Train Loss: 0.3020, Validation Loss: 0.6525
After Epoch 70: Train Loss: 0.3384, Validation Loss: 0.8250
After Epoch 80: Train Loss: 0.2510, Validation Loss: 0.5744
After Epoch 90: Train Loss: 0.2582, Validation Loss: 0.6248
After Epoch 100: Train Loss: 0.2237, Validation Loss: 0.5874
Training complete. Generating predictions...


In [18]:
from sklearn.metrics import confusion_matrix


# Function to compute and print confusion matrix
def compute_confusion_matrix(labels, preds, threshold=0.5):

    # Convert probabilities to binary predictions using the threshold
    binary_preds = (preds >= threshold).astype(int)

    # Compute confusion matrix
    cm = confusion_matrix(labels, binary_preds)
    
    print("Confusion Matrix:")
    print(cm)
    
    # Optional: Extract and print TP, TN, FP, FN
    tn, fp, fn, tp = cm.ravel()
    print(f"True Negatives (TN): {tn}")
    print(f"False Positives (FP): {fp}")
    print(f"False Negatives (FN): {fn}")
    print(f"True Positives (TP): {tp}")
    print(f"Precision: {tp / (tp + fp + 1e-8):.4f}")
    print(f"Recall: {tp / (tp + fn + 1e-8):.4f}")

In [19]:
compute_confusion_matrix(results['val_labels'], results['val_preds'], threshold=0.5)

Confusion Matrix:
[[379668  30472]
 [    43    484]]
True Negatives (TN): 379668
False Positives (FP): 30472
False Negatives (FN): 43
True Positives (TP): 484
Precision: 0.0156
Recall: 0.9184


In [20]:
compute_confusion_matrix(results['test_labels'], results['test_preds'], threshold=0.5)

Confusion Matrix:
[[211391  22034]
 [    25    121]]
True Negatives (TN): 211391
False Positives (FP): 22034
False Negatives (FN): 25
True Positives (TP): 121
Precision: 0.0055
Recall: 0.8288
