In [3]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.nn import Linear, GRU
from torch_geometric.data import TemporalData
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TGNMemory, TransformerConv
from torch_geometric.nn.models.tgn import IdentityMessage, LastAggregator
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, recall_score, precision_score
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

In [4]:
from tqdm import tqdm
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent.parent))  # Adjust as needed
from config import DATAPATH, SAMPLE_DATAPATH

In [5]:
class TGNTimeEncoding(torch.nn.Module):
    """Time encoding module for TGN"""
    def __init__(self, out_channels):
        super().__init__()
        self.out_channels = out_channels
        self.lin = Linear(1, out_channels)

    def forward(self, t):
        # t has shape [num_events] and values are timestamps
        return self.lin(t.view(-1, 1))
    
class GraphAttentionEmbedding(torch.nn.Module):
    """Graph attention embedding module for TGN"""
    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super().__init__()
        self.time_enc = time_enc
        edge_dim = msg_dim + time_enc.out_channels
        self.conv = TransformerConv(
            in_channels, out_channels // 2, 
            heads=2, dropout=0.1, edge_dim=edge_dim
        )

    def forward(self, x, last_update, edge_index, t, msg):
        # Calculate relative time - ensure proper indexing
        src_nodes = edge_index[0]
        rel_t = last_update[src_nodes] - t
        
        # Ensure rel_t is on the same device and dtype as x
        rel_t = rel_t.to(x.device).to(x.dtype)
        rel_t_enc = self.time_enc(rel_t)
        
        # Concatenate edge attributes
        edge_attr = torch.cat([rel_t_enc, msg], dim=-1)
        
        return self.conv(x, edge_index, edge_attr)


class LinkPredictor(torch.nn.Module):
    """Link prediction module for binary classification"""
    def __init__(self, in_channels):
        super().__init__()
        self.lin_src = Linear(in_channels, in_channels)
        self.lin_dst = Linear(in_channels, in_channels)
        self.lin_final = Linear(in_channels, 1)
        self.dropout = torch.nn.Dropout(0.2)

    def forward(self, z_src, z_dst):
        h = self.lin_src(z_src) + self.lin_dst(z_dst)
        h = F.relu(h)
        h = self.dropout(h)
        return torch.sigmoid(self.lin_final(h))

In [6]:
class TGNAMLModel(torch.nn.Module):
    """Complete TGN model for AML detection"""

    def __init__(self, num_nodes, edge_dim, memory_dim=100, time_dim=100, 
                 embedding_dim=100, num_layers=2):
        super().__init__()

        self.num_nodes = num_nodes
        self.memory_dim = memory_dim
        self.time_dim = time_dim
        self.embedding_dim = embedding_dim

        # Time encoding
        self.time_enc = TGNTimeEncoding(time_dim)

        # Memory module
        self.memory = TGNMemory(
            num_nodes=num_nodes,
            raw_msg_dim=edge_dim,
            memory_dim=memory_dim,
            time_dim=time_dim,
            message_module=IdentityMessage(edge_dim, memory_dim, time_dim),
            aggregator_module=LastAggregator(),
        )

        # Graph neural network
        self.gnn = GraphAttentionEmbedding(
            in_channels=memory_dim,
            out_channels=embedding_dim,
            msg_dim=edge_dim,
            time_enc=self.time_enc
        )

        # Link predictor for classification
        self.link_pred = LinkPredictor(embedding_dim)

    def forward(self, batch):
        # Get unique nodes in the batch
        unique_nodes = torch.cat([batch.src, batch.dst]).unique()
        
        # Get memory for all unique nodes
        z_all, last_update = self.memory(unique_nodes)
        
        # Create mapping from node IDs to indices in z_all
        node_to_idx = {node.item(): idx for idx, node in enumerate(unique_nodes)}
        
        # Map batch nodes to indices
        src_indices = torch.tensor([node_to_idx[node.item()] for node in batch.src], 
                                dtype=torch.long, device=batch.src.device)
        dst_indices = torch.tensor([node_to_idx[node.item()] for node in batch.dst], 
                                dtype=torch.long, device=batch.dst.device)
        
        # Create edge index for GNN (using mapped indices)
        edge_index = torch.stack([src_indices, dst_indices], dim=0)
        
        # Apply GNN
        node_embeddings = self.gnn(z_all, last_update, edge_index, batch.t.float(), batch.msg)
        
        # Get embeddings for source and destination nodes
        z_src = node_embeddings[src_indices]
        z_dst = node_embeddings[dst_indices]

        # Predict
        return self.link_pred(z_src, z_dst)
    
    def update_memory(self, batch):
        """Update memory after each batch"""
        with torch.no_grad():
            self.memory.update_state(batch.src, batch.dst, batch.t, batch.msg)

In [33]:
class TGNTrainer:
    """Trainer class for TGN AML detection"""

    def __init__(self, model, device='cuda'):
        self.model = model.to(device)
        self.device = device
        self.optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
        # self.criterion = torch.nn.BCELoss()
        fraud_weight = 1000  # Heavy penalty for missing fraud
        pos_weight = torch.tensor([fraud_weight]).to(device)
        self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    def train_epoch(self, data_loader):
        """Train for one epoch"""
        self.model.train()
        self.model.memory.train()
        self.model.memory.reset_state()  # Reset at start of epoch

        total_loss = 0
        num_batches = 0

        # Add progress bar
        pbar = tqdm(data_loader, desc="Training")

        for batch in pbar:
            batch = batch.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass
            out = self.model(batch)

            # Calculate loss
            loss = self.criterion(out.squeeze(), batch.y)

            # Backward pass
            loss.backward()
            self.optimizer.step()

            # CRITICAL: Update memory with detached tensors to break computation graph
            self.model.update_memory(batch)

            total_loss += loss.item()
            num_batches += 1

        return total_loss / num_batches

    @torch.no_grad()
    def evaluate(self, data_loader):
        """Evaluate model"""
        self.model.eval()
        self.model.memory.eval()

        all_preds = []
        all_labels = []

        # Add progress bar
        pbar = tqdm(data_loader, desc="Evaluation")

        for batch in pbar:
            batch = batch.to(self.device)

            out = self.model(batch)

            all_preds.append(out.cpu().numpy())
            all_labels.append(batch.y.cpu().numpy())

        y_pred = np.concatenate(all_preds)
        y_true = np.concatenate(all_labels)

        # Calculate metrics
        auc = roc_auc_score(y_true, y_pred)
        ap = average_precision_score(y_true, y_pred)

        # Binary predictions for F1, precision, recall
        y_pred_binary = (y_pred > 0.5).astype(int)
        f1 = f1_score(y_true, y_pred_binary)
        precision = precision_score(y_true, y_pred_binary, zero_division=0)
        recall = recall_score(y_true, y_pred_binary)

        return {
            'auc': auc,
            'ap': ap,
            'f1': f1,
            'precision': precision,
            'recall': recall,
            'predictions': y_pred,
            'labels': y_true
        }

In [34]:
class SAMLDataProcessor:
    """Data processor for SAML-D dataset"""
    
    def __init__(self):
        self.label_encoders = {}
        self.scaler = StandardScaler()
        
    def load_and_preprocess(self, df, first_n=None):
        """Load and preprocess SAML-D dataset"""
        print("Loading SAML-D dataset...")
        if first_n is not None:
            df = df.head(first_n)

        # Basic info about dataset
        print(f"Dataset shape: {df.shape}")
        print(f"Laundering transactions: {df['Is_laundering'].sum()}")
        print(f"Percentage of laundering: {df['Is_laundering'].mean()*100:.3f}%")
        
        # Feature engineering
        df = self._engineer_features(df)

        # Clear unnecessary columns
        df = self._clean_columns(df)

        # Create temporal graph data
        temporal_data = self._create_temporal_data(df)
        
        return temporal_data, df
    
    def _engineer_features(self, df):
        """Engineer features from raw data"""
        print("Engineering features...")
        
        # Combine date and time into datetime
        df['DateTime'] = pd.to_datetime(df['Date'] + ' ' + df['Time'], format='%Y-%m-%d %H:%M:%S')
        df['timestamp'] = df['DateTime'].astype('int64') // 10**9  # Convert to unix timestamp

        # Drop columns = ['Date', 'Time', 'Laundering_type]
        df = df.drop(columns=['Date', 'Time', 'Laundering_type'])
        
        # Extract temporal features
        df['hour'] = df['DateTime'].dt.hour.astype('int8')
        df['day_of_week'] = df['DateTime'].dt.dayofweek.astype('int8')
        df['weekend'] = (df['day_of_week'] >= 5).astype('int8')
        
        # Amount features
        df['log_amount'] = np.log1p(df['Amount']).astype('float32')
        
        # Cross-border indicator
        df['is_cross_border'] = (df['Payment_type'] == 'Cross-border').astype('int8')
        
        # Currency mismatch
        df['currency_mismatch'] = (df['Payment_currency'] != df['Received_currency']).astype('int8')

        # Is laundering datatype
        df['Is_laundering'] = df['Is_laundering'].astype('int8')

        # Encode categorical features
        categorical_features = [
            'Payment_type', 'Sender_bank_location', 'Receiver_bank_location',
            'Payment_currency', 'Received_currency'
        ]
        
        for feature in categorical_features:
            if feature not in self.label_encoders:
                self.label_encoders[feature] = LabelEncoder()
                encoded_values = self.label_encoders[feature].fit_transform(df[feature].astype(str))
            else:
                encoded_values = self.label_encoders[feature].transform(df[feature].astype(str))
            df[f'{feature}_encoded'] = encoded_values

        # Check if values fit in int8 range (0-127 for positive values)
            max_encoded = encoded_values.max()
            if max_encoded <= 127:
                df[f'{feature}_encoded'] = encoded_values.astype('int8')
            elif max_encoded <= 32767:
                df[f'{feature}_encoded'] = encoded_values.astype('int16')
                print(f"Warning: {feature} has {max_encoded} unique values, using int16")
            else:
                df[f'{feature}_encoded'] = encoded_values.astype('int32')
                print(f"Warning: {feature} has {max_encoded} unique values, using int32")

        return df
    
    def _clean_columns(self, df):
        """Remove unnecessary features"""
        print("Cleaning unnecessary columns...")
        unwanted_cols = ['Payment_type', 'Sender_bank_location', 'Receiver_bank_location',
                         'Payment_currency', 'Received_currency', 'Amount']
        df = df.drop(columns=unwanted_cols)
        return df
    
    def _create_temporal_data(self, df):
        """Create temporal graph data structure"""
        print("Creating temporal graph structure...")
        
        # Create node mapping
        senders = df['Sender_account'].unique()
        receivers = df['Receiver_account'].unique()
        all_nodes = np.unique(np.concatenate([senders, receivers]))
        
        node_to_idx = {node: idx for idx, node in enumerate(all_nodes)}
        num_nodes = len(all_nodes)
        
        print(f"Number of unique accounts (nodes): {num_nodes}")
        
        # Map accounts to indices
        src_nodes = df['Sender_account'].map(node_to_idx).values
        dst_nodes = df['Receiver_account'].map(node_to_idx).values

        # Edge features (transaction features)
        edge_features = [
            'log_amount', 'Payment_type_encoded', 'hour', 'day_of_week',
            'weekend', 'is_cross_border', 'currency_mismatch'
        ]
        
        edge_attr = df[edge_features].values.astype(np.float32)
        
        # Normalize edge features
        edge_attr = self.scaler.fit_transform(edge_attr)
        
        # Sort by timestamp
        sort_idx = np.argsort(df['timestamp'].values)
        
        temporal_data = TemporalData(
            src=torch.tensor(src_nodes[sort_idx], dtype=torch.long),
            dst=torch.tensor(dst_nodes[sort_idx], dtype=torch.long),
            t=torch.tensor(df['timestamp'].values[sort_idx], dtype=torch.long),
            msg=torch.tensor(edge_attr[sort_idx], dtype=torch.float),
            y=torch.tensor(df['Is_laundering'].values[sort_idx], dtype=torch.float)
        )
        
        # Add number of nodes
        temporal_data.num_nodes = num_nodes
        
        return temporal_data

In [13]:
# Load the entire dataset
df = pd.read_csv(DATAPATH)

# Filter by data range
df = df[df['Date'] < '2023-03-31']

In [14]:
# Pre-process data
data_processor = SAMLDataProcessor()

temporal_data, df = data_processor.load_and_preprocess(df)

Loading SAML-D dataset...
Dataset shape: (5198135, 12)
Laundering transactions: 5193
Percentage of laundering: 0.100%
Engineering features...
Cleaning unnecessary columns...
Creating temporal graph structure...
Number of unique accounts (nodes): 642216


In [None]:
# Split data temporally (80% train, 20% test)
split_idx = int(0.8 * len(temporal_data.t))

train_data = TemporalData(
    src=temporal_data.src[:split_idx],
    dst=temporal_data.dst[:split_idx],
    t=temporal_data.t[:split_idx],
    msg=temporal_data.msg[:split_idx],
    y=temporal_data.y[:split_idx]
)

test_data = TemporalData(
    src=temporal_data.src[split_idx:],
    dst=temporal_data.dst[split_idx:],
    t=temporal_data.t[split_idx:],
    msg=temporal_data.msg[split_idx:],
    y=temporal_data.y[split_idx:]
)

# Create data loaders
train_loader = TemporalDataLoader(train_data, batch_size=256, num_workers=4)
test_loader = TemporalDataLoader(test_data, batch_size=256, num_workers=4)

print(f"Training batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")

Training batches: 16245
Test batches: 4062


In [40]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'
print(f"Using device: {device}")

Using device: cuda


In [41]:
import gc
def clear_cuda_memory():
    """Clear CUDA memory completely"""
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    gc.collect()
    print(f"CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    print(f"CUDA memory cached: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

In [42]:
clear_cuda_memory()

# Initialize model
model = TGNAMLModel(
    num_nodes=temporal_data.num_nodes,
    edge_dim=temporal_data.msg.size(1),
    memory_dim=32,
    time_dim=32,
    embedding_dim=32
)

print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")

CUDA memory allocated: 657.43 MB
CUDA memory cached: 674.00 MB
Model parameters: 20897


In [43]:
# Initialize trainer
trainer = TGNTrainer(model, device)

# Training loop
num_epochs = 50
train_losses = []
best_auc = 0

print("\nStarting training...")
for epoch in range(num_epochs):
    # Train
    train_loss = trainer.train_epoch(train_loader)
    train_losses.append(train_loss)

    # Print training progress without evaluation
    print(f"Epoch {epoch+1:2d}: Loss: {train_loss:.4f}")
    
    # Only evaluate every 10 epochs to save memory
    # if (epoch + 1) % 5 == 0:
    print(f"Saving model checkpoint at epoch {epoch+1}")
    torch.save(model.state_dict(), f'tgn_aml_model_epoch_{epoch+1}.pth')
    
    # Evaluate every 5 epochs
    # if (epoch + 1) % 5 == 0:
    test_metrics = trainer.evaluate(test_loader)
    
    print(f"Epoch {epoch+1:2d}: "
            f"Loss: {train_loss:.4f}, "
            f"AUC: {test_metrics['auc']:.4f}, "
            f"Precision: {test_metrics['precision']:.4f}, "
            f"Recall: {test_metrics['recall']:.4f}, "
            f"F1: {test_metrics['f1']:.4f}")
    
    if test_metrics['auc'] > best_auc:
        best_auc = test_metrics['auc']
        torch.save(model.state_dict(), 'best_tgn_aml_model.pth')


Starting training...


Training: 100%|██████████| 16245/16245 [13:22<00:00, 20.24it/s]


Epoch  1: Loss: 1.3867
Saving model checkpoint at epoch 1


Evaluation: 100%|██████████| 4062/4062 [01:02<00:00, 65.04it/s]


Epoch  1: Loss: 1.3867, AUC: 0.5000, Precision: 0.0000, Recall: 0.0000, F1: 0.0000


Training:  10%|▉         | 1607/16245 [01:19<12:04, 20.20it/s]


KeyboardInterrupt: 

In [24]:
def evaluate_saved_model(checkpoint_path, test_loader, num_nodes, edge_dim, device):
    """Load and evaluate a saved model checkpoint"""
    
    # 1. Create model with same architecture as training
    model = TGNAMLModel(
        num_nodes=num_nodes,
        edge_dim=edge_dim,
        memory_dim=32,    # Match your training config
        time_dim=32,
        embedding_dim=32
    )
    
    # 2. Load saved weights
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model = model.to(device)
    
    # 3. Create trainer with loaded model
    trainer = TGNTrainer(model, device)
    
    # 4. Standard evaluation
    print(f"Evaluating {checkpoint_path}...")
    metrics = trainer.evaluate(test_loader)
    
    return metrics

# Usage example:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Evaluate specific checkpoint
results = evaluate_saved_model(
    checkpoint_path='tgn_aml_model_epoch_10.pth',
    test_loader=test_loader,
    num_nodes=temporal_data.num_nodes,
    edge_dim=temporal_data.msg.size(1),
    device=device
)

print(f"AUC: {results['auc']:.4f}")
print(f"Recall: {results['recall']:.4f}")
print(f"Precision: {results['precision']:.4f}")

Evaluating tgn_aml_model_epoch_10.pth...


Evaluation: 100%|██████████| 7426/7426 [01:48<00:00, 68.48it/s]


AUC: 0.5000
Recall: 0.0000
Precision: 0.0000


In [None]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
# Load and preprocess data
data_processor = SAMLDataProcessor()

temporal_data, df = data_processor.load_and_preprocess(filepath)


# Split data temporally (80% train, 20% test)
split_idx = int(0.8 * len(temporal_data.t))

train_data = TemporalData(
    src=temporal_data.src[:split_idx],
    dst=temporal_data.dst[:split_idx],
    t=temporal_data.t[:split_idx],
    msg=temporal_data.msg[:split_idx],
    y=temporal_data.y[:split_idx]
)

test_data = TemporalData(
    src=temporal_data.src[split_idx:],
    dst=temporal_data.dst[split_idx:],
    t=temporal_data.t[split_idx:],
    msg=temporal_data.msg[split_idx:],
    y=temporal_data.y[split_idx:]
)

# Create data loaders
train_loader = TemporalDataLoader(train_data, batch_size=200)
test_loader = TemporalDataLoader(test_data, batch_size=200)

print(f"Training batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")

# Initialize model
model = TGNAMLModel(
    num_nodes=temporal_data.num_nodes,
    edge_dim=temporal_data.msg.size(1),
    memory_dim=100,
    time_dim=100,
    embedding_dim=100
)

print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")

# Initialize trainer
trainer = TGNTrainer(model, device)

# Training loop
num_epochs = 50
train_losses = []
best_auc = 0

print("\nStarting training...")
for epoch in range(num_epochs):
    # Train
    train_loss = trainer.train_epoch(train_loader)
    train_losses.append(train_loss)
    
    # Evaluate every 5 epochs
    if (epoch + 1) % 5 == 0:
        test_metrics = trainer.evaluate(test_loader)
        
        print(f"Epoch {epoch+1:2d}: "
                f"Loss: {train_loss:.4f}, "
                f"AUC: {test_metrics['auc']:.4f}, "
                f"AP: {test_metrics['ap']:.4f}, "
                f"F1: {test_metrics['f1']:.4f}, "
                f"Recall: {test_metrics['recall']:.4f}")
        
        if test_metrics['auc'] > best_auc:
            best_auc = test_metrics['auc']
            torch.save(model.state_dict(), 'best_tgn_aml_model.pth')


In [None]:
# Final evaluation
print("\nFinal evaluation...")
final_metrics = trainer.evaluate(test_loader)

print(f"\nFinal Test Results:")
print(f"AUC-ROC: {final_metrics['auc']:.4f}")
print(f"Average Precision: {final_metrics['ap']:.4f}")
print(f"F1 Score: {final_metrics['f1']:.4f}")
print(f"Precision: {final_metrics['precision']:.4f}")
print(f"Recall: {final_metrics['recall']:.4f}")

# Plotting results
plt.figure(figsize=(15, 5))

# Training loss
plt.subplot(1, 3, 1)
plt.plot(train_losses)
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')

# ROC Curve
from sklearn.metrics import roc_curve
plt.subplot(1, 3, 2)
fpr, tpr, _ = roc_curve(final_metrics['labels'], final_metrics['predictions'])
plt.plot(fpr, tpr, label=f'AUC = {final_metrics["auc"]:.3f}')
plt.plot([0, 1], [0, 1], 'k--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend()

# Precision-Recall Curve
from sklearn.metrics import precision_recall_curve
plt.subplot(1, 3, 3)
precision, recall, _ = precision_recall_curve(final_metrics['labels'], final_metrics['predictions'])
plt.plot(recall, precision, label=f'AP = {final_metrics["ap"]:.3f}')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.legend()

plt.tight_layout()
plt.savefig('tgn_aml_results.png', dpi=300, bbox_inches='tight')
plt.show()