In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, HeteroData
from torch_geometric.utils import to_undirected
import networkx as nx
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict, Counter
from datetime import datetime, timedelta
import pickle
import warnings
warnings.filterwarnings('ignore')


# Load the four processed CSV tables in a fixed order.
transactions_df, accounts_df, patterns_df, pattern_transactions_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        '../data/processed_transactions.csv',
        '../data/processed_accounts.csv',
        '../data/processed_patterns.csv',
        '../data/processed_pattern_transactions.csv',
    )
)

# Parse timestamps so later cells can use .dt accessors and Timedelta arithmetic.
transactions_df['timestamp'] = pd.to_datetime(transactions_df['timestamp'])
pattern_transactions_df['timestamp'] = pd.to_datetime(pattern_transactions_df['timestamp'])

# NOTE: pickle.load can execute arbitrary code; only open pickles produced by
# this project's own pipeline.
with open('../data/insights_summary.pkl', 'rb') as summary_handle:
    insights_summary = pickle.load(summary_handle)

with open('../data/network_metrics.pkl', 'rb') as metrics_handle:
    network_metrics = pickle.load(metrics_handle)

print("\n".join([
    f"Loaded {len(transactions_df):,} transactions",
    f"Loaded {len(accounts_df):,} accounts",
    f"Loaded {len(patterns_df)} patterns",
]))

Loaded 5,078,345 transactions
Loaded 518,581 accounts
Loaded 370 patterns


In [6]:
print("Creating comprehensive graph annotation system...")

class AnnotatedGraphConstructor:
    """Hold the AML input tables plus the node mappings, graphs and annotations built from them.

    ``annotations`` is a shared, growing dict. Each builder adds its own
    top-level section(s) and leaves existing sections in place. Methods can
    therefore be called, or notebook cells re-run, without dropping earlier
    sections such as ``node_mappings`` or ``graphs``.
    """

    def __init__(self, transactions_df, accounts_df, pattern_transactions_df, insights_summary):
        self.transactions_df = transactions_df
        self.accounts_df = accounts_df
        self.pattern_transactions_df = pattern_transactions_df
        self.insights = insights_summary

        self.node_mappings = {}     # entity type -> {original id: node index}
        self.account_features = {}
        self.graphs = {}            # graph name -> built graph object
        self.annotations = {}       # section name -> annotation dict
        self.metadata = {}

        self.creation_timestamp = datetime.now().isoformat()
        self.version = "1.0"

    def create_comprehensive_annotations(self):
        """Add the project, data-source and methodology sections to ``self.annotations``.

        Other existing sections (for example ``node_mappings`` or ``graphs``)
        are preserved rather than discarded. Numeric summaries are stored as
        built-in ``int``/``float`` instead of numpy scalars, so the metadata
        is plain Python data (e.g. for JSON export).

        Returns:
            dict: ``self.annotations`` (the same object, updated in place).
        """
        transactions = self.transactions_df
        accounts = self.accounts_df
        patterns = self.pattern_transactions_df

        self.annotations.update({
            'project_info': {
                'name': 'Explainable AML Detection with Graph Neural Networks',
                'version': self.version,
                'creation_date': self.creation_timestamp,
                'description': 'Multi-modal graph construction for explainable anti-money laundering detection',
                'dataset': 'IBM Synthetic AML Dataset (HI-Small)',
                'purpose': 'Regulatory-compliant explainable AI for financial crime detection'
            },
            'data_sources': {
                'transactions': {
                    'file': 'processed_transactions.csv',
                    'rows': len(transactions),
                    'columns': list(transactions.columns),
                    'date_range': f"{transactions['timestamp'].min()} to {transactions['timestamp'].max()}",
                    'ml_rate': float(transactions['is_laundering'].mean()),
                    'description': 'Standardized transaction data with temporal and amount features'
                },
                'accounts': {
                    'file': 'processed_accounts.csv',
                    'rows': len(accounts),
                    'columns': list(accounts.columns),
                    'unique_banks': int(accounts['bank_id'].nunique()),
                    'unique_entities': int(accounts['entity_id'].nunique()),
                    'description': 'Account information with entity types and banking relationships'
                },
                'patterns': {
                    'file': 'processed_pattern_transactions.csv',
                    'rows': len(patterns),
                    'pattern_types': patterns['pattern_type'].unique().tolist() if len(patterns) > 0 else [],
                    'description': 'Ground truth money laundering patterns for validation'
                }
            },
            'methodology': {
                'graph_construction_approach': 'Multi-modal heterogeneous graph construction',
                'feature_engineering': 'Behavioral pattern analysis with temporal and structural features',
                'explainability_focus': 'Graph Attention Network compatible structure for regulatory explanations',
                'validation_strategy': 'Ground truth pattern validation with ego-network analysis'
            }
        })

        return self.annotations

    def create_node_mappings_with_annotations(self):
        """Assign contiguous integer node ids to accounts, banks and transactions.

        Accounts and banks are indexed in ``sorted()`` order of the ids seen on
        either side of a transaction. Transactions keep their DataFrame index.
        The graph builders later in this notebook rebuild node order with
        ``sorted(node_mappings[...].keys())``, so the two orders must stay in sync.

        Returns:
            dict: ``self.node_mappings`` with ``'accounts'``, ``'banks'`` and
            ``'transactions'`` entries.
        """
        print("Creating annotated node mappings...")
        transactions = self.transactions_df

        accounts = set(transactions['account_origin'].unique()).union(
            transactions['account_destination'].unique())
        self.node_mappings['accounts'] = {acc: idx for idx, acc in enumerate(sorted(accounts))}

        banks = set(transactions['from_bank'].unique()).union(transactions['to_bank'].unique())
        self.node_mappings['banks'] = {bank: idx for idx, bank in enumerate(sorted(banks))}

        # Identity map: transaction nodes reuse the original DataFrame index.
        self.node_mappings['transactions'] = dict(zip(transactions.index, transactions.index))

        self.annotations['node_mappings'] = {
            'accounts': {
                'count': len(self.node_mappings['accounts']),
                'description': 'Bank account identifiers mapped to sequential integers',
                'mapping_strategy': 'Sorted alphabetical order for consistency',
                'usage': 'Primary entities in transaction flow and behavior graphs'
            },
            'banks': {
                'count': len(self.node_mappings['banks']),
                'description': 'Bank identifiers mapped to sequential integers',
                'mapping_strategy': 'Sorted alphabetical order for consistency',
                'usage': 'Institutional nodes in multi-modal integration graph'
            },
            'transactions': {
                'count': len(self.node_mappings['transactions']),
                'description': 'Transaction indices for temporal proximity graph',
                'mapping_strategy': 'Original dataframe indices preserved',
                'usage': 'Transaction-level nodes for sequence analysis'
            }
        }

        print(f"Mapped {len(self.node_mappings['accounts']):,} accounts")
        print(f"Mapped {len(self.node_mappings['banks']):,} banks")
        print(f"Mapped {len(self.node_mappings['transactions']):,} transactions")

        return self.node_mappings

# Build the annotation container, then attach the project/data-source annotations and node id maps.
annotated_constructor = AnnotatedGraphConstructor(
    transactions_df=transactions_df,
    accounts_df=accounts_df,
    pattern_transactions_df=pattern_transactions_df,
    insights_summary=insights_summary,
)
annotations = annotated_constructor.create_comprehensive_annotations()
node_mappings = annotated_constructor.create_node_mappings_with_annotations()

Creating comprehensive graph annotation system...
Creating annotated node mappings...
Mapped 515,080 accounts
Mapped 30,470 banks
Mapped 5,078,345 transactions


In [None]:
print("Engineering annotated account-level behavioral features...")


def _outgoing_account_stats(transactions_df):
    """Aggregate per-sender statistics; one row per ``account_origin`` value that sends at least once."""
    amounts = transactions_df['amount_paid']
    # Whole multiples of 100, i.e. the former per-row ``x == round(x, -2)`` test.
    frame = transactions_df.assign(_is_round=amounts.round(-2) == amounts)
    grouped = frame.groupby('account_origin')
    stats = pd.DataFrame({
        'total_outgoing': grouped.size(),
        'avg_outgoing_amount': grouped['amount_paid'].mean(),
        'total_outgoing_volume': grouped['amount_paid'].sum(),
        'unique_recipients': grouped['account_destination'].nunique(),
        'currency_diversity': grouped['payment_currency'].nunique(),
        'payment_format_diversity': grouped['payment_format'].nunique(),
        'ml_outgoing_count': grouped['is_laundering'].sum(),
    })
    span = grouped['timestamp'].max() - grouped['timestamp'].min()
    stats['outgoing_time_span'] = span.dt.total_seconds() / 3600
    # A single outgoing transaction has no span.
    stats.loc[stats['total_outgoing'] <= 1, 'outgoing_time_span'] = 0.0
    stats['round_amount_ratio'] = grouped['_is_round'].sum() / stats['total_outgoing']
    return stats


def _incoming_account_stats(transactions_df):
    """Aggregate per-receiver statistics; one row per ``account_destination`` value that receives at least once."""
    grouped = transactions_df.groupby('account_destination')
    return pd.DataFrame({
        'total_incoming': grouped.size(),
        'avg_incoming_amount': grouped['amount_paid'].mean(),
        'total_incoming_volume': grouped['amount_paid'].sum(),
        'unique_senders': grouped['account_origin'].nunique(),
        'ml_incoming_count': grouped['is_laundering'].sum(),
    })


def engineer_annotated_account_features(transactions_df, accounts_df, node_mappings, annotated_constructor):
    """Engineer one 17-value behavioral feature vector per account and record its annotation.

    Statistics come from two group-by aggregations (by sender and by receiver).
    The previous version filtered the full transaction table twice per account,
    which does not scale past a few thousand accounts.

    Args:
        transactions_df: transactions with ``account_origin``, ``account_destination``,
            ``amount_paid``, ``payment_currency``, ``payment_format``,
            ``is_laundering`` and datetime ``timestamp`` columns.
        accounts_df: accepted for interface compatibility; not used.
        node_mappings: dict whose ``'accounts'`` keys define which accounts get features
            and in what order.
        annotated_constructor: object whose ``annotations`` dict receives the
            ``'feature_engineering'`` section.

    Returns:
        tuple: (account_features, feature_names, feature_definitions), where
        ``account_features`` maps account id -> list of floats ordered as ``feature_names``.
    """
    feature_definitions = {
        'total_outgoing': 'Number of outgoing transactions from account',
        'total_incoming': 'Number of incoming transactions to account',
        'total_transactions': 'Total transaction count (outgoing + incoming)',
        'avg_outgoing_amount': 'Average amount of outgoing transactions (USD)',
        'avg_incoming_amount': 'Average amount of incoming transactions (USD)',
        'total_outgoing_volume': 'Total volume of outgoing transactions (USD)',
        'total_incoming_volume': 'Total volume of incoming transactions (USD)',
        'unique_recipients': 'Number of unique recipient accounts (fan-out indicator)',
        'unique_senders': 'Number of unique sender accounts (fan-in indicator)',
        'currency_diversity': 'Number of different currencies used',
        'payment_format_diversity': 'Number of different payment formats used',
        'transaction_velocity': 'Transactions per hour (velocity indicator)',
        'fan_out_degree': 'Outgoing connection count (laundering pattern indicator)',
        'fan_in_degree': 'Incoming connection count (collection pattern indicator)',
        'ml_rate': 'Money laundering involvement rate (0-1)',
        'round_amount_ratio': 'Proportion of round-number transactions (structuring indicator)',
        'outgoing_time_span': 'Time span of outgoing transactions (hours)'
    }

    feature_engineering_stats = {
        'total_features': len(feature_definitions),
        'categories': {
            'transaction_counts': ['total_outgoing', 'total_incoming', 'total_transactions'],
            'amount_statistics': ['avg_outgoing_amount', 'avg_incoming_amount', 'total_outgoing_volume', 'total_incoming_volume'],
            'network_properties': ['unique_recipients', 'unique_senders', 'fan_out_degree', 'fan_in_degree'],
            'behavioral_indicators': ['currency_diversity', 'payment_format_diversity', 'transaction_velocity'],
            'risk_indicators': ['ml_rate', 'round_amount_ratio'],
            'temporal_features': ['outgoing_time_span']
        },
        'interpretability_focus': 'All features designed for regulatory explanation and compliance auditing',
        'missing_value_handling': 'Zero-filled for accounts with no transactions in specific directions'
    }

    accounts = list(node_mappings['accounts'].keys())
    # Accounts with no activity in a direction get all-zero rows (the documented zero-fill).
    outgoing = _outgoing_account_stats(transactions_df).reindex(accounts, fill_value=0)
    incoming = _incoming_account_stats(transactions_df).reindex(accounts, fill_value=0)

    total_outgoing = outgoing['total_outgoing']
    total_incoming = incoming['total_incoming']
    outgoing_time_span = outgoing['outgoing_time_span']

    columns = {
        'total_outgoing': total_outgoing,
        'total_incoming': total_incoming,
        'total_transactions': total_outgoing + total_incoming,
        'avg_outgoing_amount': outgoing['avg_outgoing_amount'],
        'avg_incoming_amount': incoming['avg_incoming_amount'],
        'total_outgoing_volume': outgoing['total_outgoing_volume'],
        'total_incoming_volume': incoming['total_incoming_volume'],
        'unique_recipients': outgoing['unique_recipients'],
        'unique_senders': incoming['unique_senders'],
        'currency_diversity': outgoing['currency_diversity'],
        'payment_format_diversity': outgoing['payment_format_diversity'],
        # Spans under one hour count as one hour so velocity stays bounded.
        'transaction_velocity': total_outgoing / np.maximum(outgoing_time_span, 1),
        'fan_out_degree': outgoing['unique_recipients'],
        'fan_in_degree': incoming['unique_senders'],
        'ml_rate': (outgoing['ml_outgoing_count'] + incoming['ml_incoming_count'])
                   / np.maximum(total_outgoing + total_incoming, 1),
        'round_amount_ratio': outgoing['round_amount_ratio'],
        'outgoing_time_span': outgoing_time_span,
    }

    feature_names = list(feature_definitions.keys())
    feature_frame = pd.DataFrame(columns, index=accounts)[feature_names]
    account_features = dict(zip(accounts, feature_frame.to_numpy(dtype=float).tolist()))

    annotated_constructor.annotations['feature_engineering'] = {
        'feature_definitions': feature_definitions,
        'feature_names': feature_names,
        'engineering_stats': feature_engineering_stats,
        'total_accounts_processed': len(account_features),
        'feature_vector_length': len(feature_names),
        'normalization_note': 'Features require StandardScaler normalization before model training'
    }

    return account_features, feature_names, feature_definitions

# Compute per-account behavioral feature vectors and record their annotation.
account_features, feature_names, feature_definitions = engineer_annotated_account_features(
    transactions_df=transactions_df,
    accounts_df=accounts_df,
    node_mappings=node_mappings,
    annotated_constructor=annotated_constructor,
)

print(
    f"Engineered {len(feature_names)} annotated behavioral features per account",
    "Feature categories created with full documentation",
    sep="\n",
)

Engineering annotated account-level behavioral features...


In [None]:
print("Building annotated Transaction Flow Graph...")


def _account_ml_labels(transactions_df, account_list):
    """Return a 0/1 label per account (in ``account_list`` order): 1 if the ``is_laundering``
    values summed over every transaction it sends or receives are positive."""
    origin = transactions_df['account_origin'].to_numpy()
    destination = transactions_df['account_destination'].to_numpy()
    flags = transactions_df['is_laundering']
    self_loop = origin == destination
    involvement = (
        flags.groupby(origin).sum()
        .add(flags.groupby(destination).sum(), fill_value=0)
        # A self-transfer is matched by both group-bys but is a single involving row.
        .sub(flags[self_loop].groupby(origin[self_loop]).sum(), fill_value=0)
        .reindex(account_list, fill_value=0)
    )
    return (involvement > 0).astype(int).tolist()


def build_annotated_transaction_flow_graph(transactions_df, node_mappings, account_features, annotated_constructor, feature_definitions=None):
    """Build the account-level transaction flow graph and record its annotation.

    The graph has one directed edge per transaction whose origin and destination
    both appear in ``node_mappings['accounts']``, in the row order of
    ``transactions_df``. Node ``k`` is the k-th account in ``sorted()`` order;
    this assumes the mapping indices were assigned in that same sorted order.
    Edge attributes and node labels are computed with vectorized pandas/numpy
    operations instead of per-row ``iterrows`` and per-account table scans.

    Args:
        transactions_df: transactions with account, amount, currency, payment-format,
            datetime ``timestamp`` and ``is_laundering`` columns.
        node_mappings: dict whose ``'accounts'`` entry maps account id -> node index.
        account_features: dict account id -> raw (unscaled) feature list.
        annotated_constructor: object whose ``annotations`` dict receives
            ``annotations['graphs']['transaction_flow']``.
        feature_definitions: node-feature descriptions for the annotation. Defaults to
            ``annotated_constructor.annotations['feature_engineering']['feature_definitions']``.
            The previous version silently read a notebook global instead.

    Returns:
        tuple: (graph ``Data``, fitted node-feature ``StandardScaler``,
        payment-format ``LabelEncoder``, currency ``LabelEncoder``, annotation dict).
    """
    if feature_definitions is None:
        feature_definitions = (
            annotated_constructor.annotations
            .get('feature_engineering', {})
            .get('feature_definitions', {})
        )

    graph_annotation = {
        'graph_name': 'Transaction Flow Graph',
        'graph_type': 'Homogeneous Directed Graph',
        'purpose': 'Direct money transfer pattern detection for fan-out, fan-in, and layering schemes',
        'node_type': 'Bank Accounts',
        'edge_type': 'Direct Money Transfers',
        'construction_method': 'One edge per transaction between accounts',
        'key_applications': [
            'Fan-out pattern detection (single source to multiple destinations)',
            'Fan-in pattern detection (multiple sources to single destination)',
            'Direct layering scheme identification',
            'Account-level risk scoring',
            'Transaction flow visualization for investigators'
        ],
        'regulatory_compliance': 'Supports SAR (Suspicious Activity Report) generation with clear transaction chains'
    }

    payment_format_encoder = LabelEncoder()
    currency_encoder = LabelEncoder()
    payment_format_encoder.fit(transactions_df['payment_format'].unique())
    currency_encoder.fit(transactions_df['payment_currency'].unique())

    # LabelEncoder assigns code k to classes_[k] (sorted order), so describe the
    # codes with classes_ rather than the order of first appearance.
    payment_format_classes = payment_format_encoder.classes_.tolist()
    currency_classes = currency_encoder.classes_.tolist()

    edge_feature_definitions = {
        'amount_paid': 'Transaction amount in original currency (log-normalized recommended)',
        'amount_received': 'Transaction amount received (for currency conversion analysis)',
        'payment_format_encoded': f'Payment method encoded (0-{len(payment_format_classes)-1}): {payment_format_classes}',
        'currency_encoded': f'Currency encoded (0-{len(currency_classes)-1}): {currency_classes[:10]}...',
        'hour_of_day': 'Transaction hour (0-23) for temporal pattern analysis',
        'day_of_week': 'Day of week (0-6) for weekly pattern analysis',
        'is_laundering': 'Ground truth money laundering label (0=normal, 1=ML)'
    }

    account_index = node_mappings['accounts']
    source_nodes = transactions_df['account_origin'].map(account_index)
    target_nodes = transactions_df['account_destination'].map(account_index)
    # Keep only transactions whose endpoints are both mapped accounts.
    known = (source_nodes.notna() & target_nodes.notna()).to_numpy()
    edge_txns = transactions_df[known]

    edge_index = torch.as_tensor(
        np.vstack([
            source_nodes[known].to_numpy(dtype=np.int64),
            target_nodes[known].to_numpy(dtype=np.int64),
        ]),
        dtype=torch.long,
    )

    edge_times = edge_txns['timestamp']
    edge_attr = torch.as_tensor(
        np.column_stack([
            edge_txns['amount_paid'].to_numpy(dtype=float),
            edge_txns['amount_received'].to_numpy(dtype=float),
            payment_format_encoder.transform(edge_txns['payment_format']),
            currency_encoder.transform(edge_txns['payment_currency']),
            edge_times.dt.hour.to_numpy(dtype=float),
            edge_times.dt.dayofweek.to_numpy(dtype=float),
            edge_txns['is_laundering'].astype(int).to_numpy(),
        ]).astype(np.float32),
        dtype=torch.float,
    )

    account_list = sorted(node_mappings['accounts'].keys())
    node_features = [account_features[account] for account in account_list]
    node_labels = _account_ml_labels(transactions_df, account_list)

    scaler = StandardScaler()
    node_features_scaled = scaler.fit_transform(node_features)

    x = torch.tensor(node_features_scaled, dtype=torch.float)
    y = torch.tensor(node_labels, dtype=torch.long)

    transaction_flow_graph = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=y,
        num_nodes=len(account_list)
    )

    num_accounts = len(account_list)
    num_edges = edge_index.size(1)
    graph_stats = {
        'nodes': transaction_flow_graph.num_nodes,
        'edges': num_edges,
        'node_features': transaction_flow_graph.x.size(1),
        'edge_features': transaction_flow_graph.edge_attr.size(1),
        'ml_accounts': y.sum().item(),
        'ml_account_rate': (y.sum().item() / num_accounts) if num_accounts else 0.0,
        'avg_degree': (num_edges / num_accounts) if num_accounts else 0.0,
        'self_loops': (edge_index[0] == edge_index[1]).sum().item(),
        # Density is undefined for fewer than two nodes; report 0.
        'density': num_edges / (num_accounts * (num_accounts - 1)) if num_accounts > 1 else 0.0
    }

    graph_annotation.update({
        'node_features': feature_definitions,
        'edge_features': edge_feature_definitions,
        'statistics': graph_stats,
        'preprocessing': {
            'node_feature_scaling': 'StandardScaler applied',
            'edge_feature_encoding': 'LabelEncoder for categorical variables',
            'missing_value_handling': 'Zero-filled for inactive accounts'
        },
        'usage_recommendations': {
            'training': 'Primary graph for basic GNN training and validation',
            'explanation': 'Use for account-level risk explanations and direct transfer analysis',
            'investigation': 'Extract ego-networks around suspicious accounts for detailed analysis',
            'validation': 'Compare predictions with ground truth ML labels'
        }
    })

    annotated_constructor.annotations.setdefault('graphs', {})['transaction_flow'] = graph_annotation

    print(f"Transaction Flow Graph created with annotations:")
    print(f"  Nodes: {graph_stats['nodes']:,}")
    print(f"  Edges: {graph_stats['edges']:,}")
    print(f"  ML accounts: {graph_stats['ml_accounts']:,} ({graph_stats['ml_account_rate']:.4f})")
    print(f"  Average degree: {graph_stats['avg_degree']:.2f}")

    return transaction_flow_graph, scaler, payment_format_encoder, currency_encoder, graph_annotation

# Build the account-to-account transfer graph and register it on the constructor.
(transaction_flow_graph, feature_scaler, payment_encoder,
 currency_encoder, tf_annotation) = build_annotated_transaction_flow_graph(
    transactions_df=transactions_df,
    node_mappings=node_mappings,
    account_features=account_features,
    annotated_constructor=annotated_constructor,
)
annotated_constructor.graphs['transaction_flow'] = transaction_flow_graph

In [None]:
print("Building annotated Temporal Proximity Graph...")


def _temporal_window_edges(timestamps_ns, origins, destinations, amounts, labels,
                           time_window_hours, max_window_size):
    """Enumerate temporally close (earlier, later) transaction pairs with their edge features.

    Array arguments are aligned by position and must describe transactions
    sorted by timestamp with no missing timestamps (``timestamps_ns`` holds
    int64 nanoseconds). A pair (i, j) is emitted when
    ``0 < j - i < max_window_size`` and the gap is at most ``time_window_hours``.

    Returns:
        tuple: (sources, targets, attributes) ordered by (i, j). ``attributes``
        has columns [time_gap_hours, account_overlap, amount_similarity,
        source_is_ml, target_is_ml].
    """
    count = len(timestamps_ns)
    source_parts, target_parts, attr_parts = [], [], []
    # One vectorized pass per offset compares every i with i + offset.
    for offset in range(1, min(max_window_size, count)):
        gap_hours = (timestamps_ns[offset:] - timestamps_ns[:-offset]) / 1e9 / 3600
        in_window = gap_hours <= time_window_hours
        if not in_window.any():
            # Timestamps are sorted, so gaps only grow with the offset:
            # no larger offset can yield a pair either.
            break
        src = np.nonzero(in_window)[0]
        dst = src + offset
        overlap = ((origins[src] == origins[dst])
                   | (destinations[src] == destinations[dst])
                   | (origins[src] == destinations[dst])
                   | (destinations[src] == origins[dst]))
        similarity = 1 - np.abs(amounts[src] - amounts[dst]) / np.maximum(amounts[src] + amounts[dst], 1)
        source_parts.append(src)
        target_parts.append(dst)
        attr_parts.append(np.column_stack([
            gap_hours[in_window], overlap.astype(float), similarity, labels[src], labels[dst],
        ]))
    if not source_parts:
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64), np.empty((0, 5))
    sources = np.concatenate(source_parts)
    targets = np.concatenate(target_parts)
    attributes = np.concatenate(attr_parts)
    # Restore (i, j) lexicographic order, i.e. grouped by earlier transaction.
    order = np.lexsort((targets, sources))
    return sources[order], targets[order], attributes[order]


def build_annotated_temporal_proximity_graph(transactions_df, payment_encoder, currency_encoder, annotated_constructor, time_window_hours=24, max_window_size=1000):
    """Build the transaction-level temporal proximity graph and record its annotation.

    Transactions are sorted by ``timestamp``; node ``k`` is the k-th transaction
    in that order. Each transaction is linked to the following transactions
    (at most ``max_window_size - 1`` positions later) whose timestamp is within
    ``time_window_hours``. Each pair is stored once, from earlier to later.
    Transactions with a missing timestamp stay as nodes but get no temporal edges.

    Args:
        transactions_df: transactions with datetime ``timestamp``, account, amount,
            payment-format, currency and ``is_laundering`` columns.
        payment_encoder, currency_encoder: fitted encoders exposing ``transform``
            (as returned by the transaction flow graph builder).
        annotated_constructor: object whose ``annotations`` dict receives
            ``annotations['graphs']['temporal_proximity']``.
        time_window_hours: maximum gap, in hours, between linked transactions.
        max_window_size: neighbor-search cap; pairs with ``j - i >= max_window_size``
            are never linked even if inside the time window (default matches the
            previously hard-coded limit).

    Returns:
        tuple: (graph ``Data``, fitted node-feature ``StandardScaler``, annotation dict).
    """
    graph_annotation = {
        'graph_name': 'Temporal Proximity Graph',
        'graph_type': 'Homogeneous Undirected Graph',
        'purpose': 'Complex multi-transaction money laundering scheme detection',
        'node_type': 'Individual Transactions',
        'edge_type': 'Temporal Relationships',
        'construction_method': f'Connect transactions within {time_window_hours} hour windows',
        'time_window_hours': time_window_hours,
        'key_applications': [
            'Scatter-gather pattern detection (rapid dispersion then collection)',
            'Velocity pattern analysis (unusual transaction frequency)',
            'Coordinated attack identification (simultaneous transactions)',
            'Temporal clustering of suspicious activities',
            'Transaction sequence analysis for complex schemes'
        ],
        'regulatory_compliance': 'Enables detection of sophisticated multi-step laundering operations',
        'computational_complexity': 'O(n*k) where n=transactions, k=average neighbors in time window',
        'max_window_size': max_window_size,
        'edge_direction': ('Each pair is stored once, from the earlier to the later transaction; '
                           'add reverse edges (swapping source_is_ml/target_is_ml) for symmetric message passing')
    }

    transactions_sorted = transactions_df.sort_values('timestamp').reset_index(drop=True)

    edge_feature_definitions = {
        'time_gap_hours': f'Time difference between transactions (0-{time_window_hours} hours)',
        'account_overlap': 'Binary indicator of shared accounts between transactions',
        'amount_similarity': 'Normalized amount similarity (0-1, higher = more similar)',
        'source_is_ml': 'Whether source transaction is money laundering (ground truth)',
        'target_is_ml': 'Whether target transaction is money laundering (ground truth)'
    }

    print(f"Building temporal edges with {time_window_hours}h window...")

    timestamps = transactions_sorted['timestamp']
    # sort_values puts NaT last, so rows with a timestamp form a prefix. Rows
    # without one get no edges (they previously produced NaN time gaps).
    timed = int(timestamps.notna().sum())
    labels = transactions_sorted['is_laundering'].astype(int).to_numpy()
    sources, targets, attributes = _temporal_window_edges(
        timestamps.iloc[:timed].to_numpy(dtype='datetime64[ns]').astype(np.int64),
        transactions_sorted['account_origin'].to_numpy()[:timed],
        transactions_sorted['account_destination'].to_numpy()[:timed],
        transactions_sorted['amount_paid'].to_numpy(dtype=float)[:timed],
        labels[:timed],
        time_window_hours,
        max_window_size,
    )
    print(f"Created {sources.size:,} temporal edges")

    edge_index = torch.as_tensor(np.vstack([sources, targets]), dtype=torch.long)
    edge_attr = torch.as_tensor(attributes, dtype=torch.float)

    node_feature_definitions = {
        'amount_paid': 'Transaction amount (requires log normalization)',
        'hour_of_day': 'Transaction hour (0-23)',
        'day_of_week': 'Day of week (0-6)',
        'payment_format_encoded': 'Encoded payment method',
        'currency_encoded': 'Encoded currency type'
    }

    transaction_features = np.column_stack([
        transactions_sorted['amount_paid'].to_numpy(dtype=float),
        timestamps.dt.hour.to_numpy(dtype=float),
        timestamps.dt.dayofweek.to_numpy(dtype=float),
        payment_encoder.transform(transactions_sorted['payment_format']),
        currency_encoder.transform(transactions_sorted['payment_currency']),
    ])

    scaler_temporal = StandardScaler()
    transaction_features_scaled = scaler_temporal.fit_transform(transaction_features)

    x = torch.tensor(transaction_features_scaled, dtype=torch.float)
    y = torch.tensor(labels, dtype=torch.long)

    temporal_proximity_graph = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=y,
        num_nodes=len(transactions_sorted)
    )

    num_edges = edge_index.size(1)
    graph_stats = {
        'nodes': temporal_proximity_graph.num_nodes,
        'edges': num_edges,
        'node_features': temporal_proximity_graph.x.size(1),
        'edge_features': temporal_proximity_graph.edge_attr.size(1),
        'ml_transactions': y.sum().item(),
        'ml_transaction_rate': (y.sum().item() / len(transactions_sorted)),
        'avg_degree': (num_edges / len(transactions_sorted)),
        # No edges -> no gaps to average; report NaN explicitly.
        'avg_time_gap': torch.mean(edge_attr[:, 0]).item() if num_edges else float('nan'),
        'account_overlap_ratio': torch.mean(edge_attr[:, 1]).item() if num_edges else float('nan')
    }

    graph_annotation.update({
        'node_features': node_feature_definitions,
        'edge_features': edge_feature_definitions,
        'statistics': graph_stats,
        'preprocessing': {
            'temporal_sorting': 'Transactions sorted chronologically for window construction',
            'node_feature_scaling': 'StandardScaler applied to transaction features',
            'edge_construction_optimization': 'Limited window search to prevent quadratic complexity'
        },
        'usage_recommendations': {
            'training': 'Use for complex pattern detection requiring temporal sequence analysis',
            'explanation': 'Identify rapid transaction sequences and coordinated timing patterns',
            'investigation': 'Trace temporal relationships in suspected laundering schemes',
            'validation': 'Verify model attention on temporally clustered ML transactions'
        },
        'memory_considerations': {
            'large_graph_warning': 'High memory usage due to dense temporal connections',
            'recommended_sampling': 'Use temporal subgraphs for training efficiency',
            'batch_processing': 'Process in time-based chunks for large datasets'
        }
    })

    annotated_constructor.annotations.setdefault('graphs', {})['temporal_proximity'] = graph_annotation

    print(f"Temporal Proximity Graph created with annotations:")
    print(f"  Nodes: {graph_stats['nodes']:,}")
    print(f"  Edges: {graph_stats['edges']:,}")
    print(f"  ML transactions: {graph_stats['ml_transactions']:,} ({graph_stats['ml_transaction_rate']:.4f})")
    print(f"  Average time gap: {graph_stats['avg_time_gap']:.2f} hours")

    return temporal_proximity_graph, scaler_temporal, graph_annotation

# Build the transaction-level temporal graph (24 h window) and register it.
temporal_proximity_graph, temporal_scaler, tp_annotation = build_annotated_temporal_proximity_graph(
    transactions_df=transactions_df,
    payment_encoder=payment_encoder,
    currency_encoder=currency_encoder,
    annotated_constructor=annotated_constructor,
    time_window_hours=24,
)
annotated_constructor.graphs['temporal_proximity'] = temporal_proximity_graph

In [None]:
print("Building annotated Account Behavior Graph...")


def _similar_account_pairs(features_scaled, threshold, chunk_size):
    """Scan the strict upper triangle of the cosine-similarity matrix, ``chunk_size`` rows at a time.

    Only a ``chunk_size x n`` slice is held in memory instead of the full ``n x n`` matrix.

    Returns:
        tuple: (rows, cols, sims, sim_sum, sim_max). (rows[k], cols[k]) with
        rows[k] < cols[k] are the pairs whose similarity sims[k] exceeds
        ``threshold``, in row-major order. sim_sum / sim_max summarize every
        evaluated pair; sim_max is None when there are no pairs.
    """
    n = features_scaled.shape[0]
    rows, cols, sims = [], [], []
    sim_sum, sim_max = 0.0, None
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        # Columns before `start` belong to pairs already covered by earlier chunks.
        block = cosine_similarity(features_scaled[start:stop], features_scaled[start:])
        upper = np.arange(block.shape[1])[None, :] > np.arange(stop - start)[:, None]
        evaluated = block[upper]
        if evaluated.size:
            sim_sum += float(evaluated.sum())
            chunk_max = float(evaluated.max())
            sim_max = chunk_max if sim_max is None else max(sim_max, chunk_max)
        hit_rows, hit_cols = np.nonzero(upper & (block > threshold))
        rows.append(hit_rows + start)
        cols.append(hit_cols + start)
        sims.append(block[hit_rows, hit_cols])
    if not rows:
        empty = np.empty(0, dtype=np.int64)
        return empty, empty, np.empty(0), sim_sum, sim_max
    return np.concatenate(rows), np.concatenate(cols), np.concatenate(sims), sim_sum, sim_max


def build_annotated_account_behavior_graph(node_mappings, account_features, feature_definitions, annotated_constructor, similarity_threshold=0.7, transactions_df=None, chunk_size=128):
    """Build the account behavioral-similarity graph and record its annotation.

    Features are standardized, and every account pair whose cosine similarity
    exceeds ``similarity_threshold`` is linked in both directions. If no pair
    qualifies, the threshold is lowered once to 0.5. Similarities are computed
    in row chunks, which avoids materializing the full n x n matrix and an
    n^2/2-element Python list.

    Args:
        node_mappings: dict whose ``'accounts'`` keys define the nodes (sorted order).
        account_features: dict account id -> raw feature list.
        feature_definitions: node-feature descriptions for the annotation.
        annotated_constructor: object whose ``annotations`` dict receives
            ``annotations['graphs']['account_behavior']``.
        similarity_threshold: strict lower bound on cosine similarity for an edge.
        transactions_df: transactions used to derive 0/1 account labels. If omitted,
            the module-level ``transactions_df`` is used, as in the original notebook.
        chunk_size: rows of the similarity matrix computed per step (memory knob).

    Returns:
        tuple: (graph ``Data``, fitted ``StandardScaler``, annotation dict).

    Raises:
        ValueError: if no ``transactions_df`` is passed and none exists at module level.
    """
    if transactions_df is None:
        # Backward compatibility: earlier versions read the notebook-level global.
        transactions_df = globals().get('transactions_df')
        if transactions_df is None:
            raise ValueError("transactions_df must be provided to derive account labels")

    graph_annotation = {
        'graph_name': 'Account Behavior Graph',
        'graph_type': 'Homogeneous Undirected Graph',
        'purpose': 'Account risk profiling and criminal network identification',
        'node_type': 'Bank Accounts with Behavioral Features',
        'edge_type': 'Behavioral Similarity Connections',
        'construction_method': f'Cosine similarity > {similarity_threshold} between behavioral feature vectors',
        'similarity_threshold': similarity_threshold,
        'key_applications': [
            'Money mule network identification (accounts with similar suspicious behavior)',
            'Shell company network detection (coordinated fake entities)',
            'Criminal organization structure mapping through behavioral patterns',
            'Account-level risk scoring and clustering',
            'Identification of coordinated account activities'
        ],
        'regulatory_compliance': 'Supports network-based investigation and risk assessment for compliance teams'
    }

    account_list = sorted(node_mappings['accounts'].keys())
    num_accounts = len(account_list)

    feature_matrix = np.array([account_features[account] for account in account_list])

    scaler_behavior = StandardScaler()
    feature_matrix_scaled = scaler_behavior.fit_transform(feature_matrix)

    print(f"Computing behavioral similarities with threshold {similarity_threshold}...")

    total_pairs = num_accounts * (num_accounts - 1) // 2
    similarity_stats = {
        'total_pairs_evaluated': total_pairs,
        'edges_created': 0,
        'avg_similarity': 0,
        'max_similarity': 0
    }

    rows, cols, sims, sim_sum, sim_max = _similar_account_pairs(
        feature_matrix_scaled, similarity_threshold, chunk_size)

    if rows.size == 0:
        print(f"No edges found with threshold {similarity_threshold}, lowering to 0.5...")
        similarity_threshold = 0.5
        rows, cols, sims, _, _ = _similar_account_pairs(
            feature_matrix_scaled, similarity_threshold, chunk_size)

    # Each undirected pair becomes two directed edges, (i, j) then (j, i).
    edge_sources = np.column_stack([rows, cols]).ravel()
    edge_targets = np.column_stack([cols, rows]).ravel()
    num_directed_edges = int(edge_sources.size)

    similarity_stats.update({
        # Counted for the threshold actually used (previously stayed 0 after the fallback).
        'edges_created': num_directed_edges,
        'avg_similarity': sim_sum / total_pairs if total_pairs else 0,
        'max_similarity': sim_max if sim_max is not None else 0,
        'similarity_threshold_used': similarity_threshold,
        'connectivity_rate': num_directed_edges / (2 * total_pairs) if total_pairs > 0 else 0
    })
    edge_index = torch.as_tensor(np.vstack([edge_sources, edge_targets]), dtype=torch.long)
    edge_attr = torch.as_tensor(np.repeat(sims, 2).reshape(-1, 1), dtype=torch.float)

    x = torch.tensor(feature_matrix_scaled, dtype=torch.float)

    # Label 1 when the is_laundering sum over transactions the account sends or receives is positive.
    origin = transactions_df['account_origin'].to_numpy()
    destination = transactions_df['account_destination'].to_numpy()
    flags = transactions_df['is_laundering']
    self_loop = origin == destination
    involvement = (
        flags.groupby(origin).sum()
        .add(flags.groupby(destination).sum(), fill_value=0)
        # A self-transfer is matched by both group-bys but is a single involving row.
        .sub(flags[self_loop].groupby(origin[self_loop]).sum(), fill_value=0)
        .reindex(account_list, fill_value=0)
    )
    y = torch.tensor((involvement > 0).astype(int).to_numpy(), dtype=torch.long)

    account_behavior_graph = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=y,
        num_nodes=num_accounts
    )

    edge_feature_definitions = {
        'behavioral_similarity': f'Cosine similarity between behavioral feature vectors (range: {similarity_threshold}-1.0)'
    }

    graph_stats = {
        'nodes': account_behavior_graph.num_nodes,
        'edges': account_behavior_graph.edge_index.size(1),
        'node_features': account_behavior_graph.x.size(1),
        'edge_features': account_behavior_graph.edge_attr.size(1),
        'ml_accounts': y.sum().item(),
        'ml_account_rate': (y.sum().item() / num_accounts),
        'avg_degree': (edge_index.size(1) / num_accounts) if num_accounts > 0 else 0,
        'similarity_threshold_used': similarity_threshold,
        'connectivity_metrics': similarity_stats
    }

    graph_annotation.update({
        'node_features': feature_definitions,
        'edge_features': edge_feature_definitions,
        'statistics': graph_stats,
        'preprocessing': {
            'feature_scaling': 'StandardScaler applied to behavioral features before similarity computation',
            'similarity_metric': 'Cosine similarity for scale-invariant behavioral comparison',
            'threshold_adaptation': 'Automatic threshold reduction if no edges found with initial threshold',
            'chunked_computation': f'Similarities computed in row chunks of {chunk_size} accounts'
        },
        'usage_recommendations': {
            'training': 'Use for network-based risk assessment and community detection',
            'explanation': 'Identify similar behavioral patterns and account clustering',
            'investigation': 'Map potential criminal networks through behavioral similarity',
            'validation': 'Verify that ML accounts cluster together behaviorally'
        },
        'interpretation_guidelines': {
            'high_similarity_edges': 'Indicate accounts with very similar transaction patterns',
            'isolated_nodes': 'Accounts with unique behavioral signatures',
            'dense_clusters': 'Potential coordinated account networks',
            'ml_account_clustering': 'ML accounts should show higher interconnectivity'
        }
    })

    annotated_constructor.annotations.setdefault('graphs', {})['account_behavior'] = graph_annotation

    print(f"Account Behavior Graph created with annotations:")
    print(f"  Nodes: {graph_stats['nodes']:,}")
    print(f"  Edges: {graph_stats['edges']:,}")
    print(f"  ML accounts: {graph_stats['ml_accounts']:,} ({graph_stats['ml_account_rate']:.4f})")
    print(f"  Similarity threshold used: {similarity_threshold}")
    print(f"  Average similarity: {similarity_stats['avg_similarity']:.4f}")

    return account_behavior_graph, scaler_behavior, graph_annotation

# Build the behavioral-similarity graph over accounts and register it.
account_behavior_graph, behavior_scaler, ab_annotation = build_annotated_account_behavior_graph(
    node_mappings=node_mappings,
    account_features=account_features,
    feature_definitions=feature_definitions,
    annotated_constructor=annotated_constructor,
    similarity_threshold=0.7,
)
annotated_constructor.graphs['account_behavior'] = account_behavior_graph


In [None]:
# Cell banner: the function defined next builds the heterogeneous account + bank graph.
print("Building annotated Multi-Modal Integration Graph...")

def build_annotated_multimodal_integration_graph(transactions_df, node_mappings, account_features, feature_definitions, payment_encoder, currency_encoder, annotated_constructor):
    """Build multi-modal integration graph with comprehensive annotations"""
    
    graph_annotation = {
        'graph_name': 'Multi-Modal Integration Graph',
        'graph_type': 'Heterogeneous Directed Graph',
        'purpose': 'Comprehensive AML detection with full institutional and behavioral context',
        'node_types': ['Bank Accounts', 'Financial Institutions'],
        'edge_types': ['Account-to-Account Transfers', 'Account-to-Bank Relationships'],
        'construction_method': 'Heterogeneous graph combining multiple entity types and relationship types',
        'key_applications': [
            'Cross-institutional money laundering detection',
            'Multi-bank coordination analysis',
            'Comprehensive risk assessment with full context',
            'Currency-based laundering pattern detection',
            'Institution-level compliance monitoring'
        ],
        'regulatory_compliance': 'Complete institutional view for comprehensive SAR reporting and cross-bank analysis',
        'heterogeneous_features': 'Different feature sets for accounts vs banks with unified dimensionality'
    }
    
    account_list = sorted(node_mappings['accounts'].keys())
    bank_list = sorted(node_mappings['banks'].keys())
    
    node_types = []
    node_features_list = []
    
    account_node_mapping = {}
    bank_node_mapping = {}
    node_idx = 0
    
    # Process account nodes
    for account in account_list:
        account_node_mapping[account] = node_idx
        node_types.append(0)  # 0 = account
        features = account_features[account]
        padded_features = features + [0] * (20 - len(features))
        node_features_list.append(padded_features[:20])
        node_idx += 1
    
    # Engineer bank features
    bank_features = {}
    bank_feature_definitions = {
        'total_transactions': 'Total number of transactions involving this bank',
        'total_volume': 'Total transaction volume (USD)',
        'avg_transaction_amount': 'Average transaction amount for this bank',
        'ml_transactions': 'Number of money laundering transactions',
        'currency_diversity': 'Number of different currencies handled',
        'payment_format_diversity': 'Number of different payment formats supported'
    }
    
    for bank in bank_list:
        bank_txns = transactions_df[
            (transactions_df['from_bank'] == bank) | 
            (transactions_df['to_bank'] == bank)
        ]
        
        bank_feature = [
            len(bank_txns),
            bank_txns['amount_paid'].sum(),
            bank_txns['amount_paid'].mean() if len(bank_txns) > 0 else 0,
            bank_txns['is_laundering'].sum(),
            bank_txns['payment_currency'].nunique(),
            bank_txns['payment_format'].nunique()
        ]
        bank_features[bank] = bank_feature
    
    # Process bank nodes
    for bank in bank_list:
        bank_node_mapping[bank] = node_idx
        node_types.append(1)  # 1 = bank
        features = bank_features[bank]
        padded_features = features + [0] * (20 - len(features))
        node_features_list.append(padded_features[:20])
        node_idx += 1
    
    edge_list = []
    edge_attributes = []
    
    edge_feature_definitions = {
        'transaction_amount': 'Amount transferred (0 for institutional relationships)',
        'payment_format_encoded': 'Encoded payment format (0 for institutional relationships)',
        'currency_encoded': 'Encoded currency (0 for institutional relationships)', 
        'relationship_type': 'Edge type (0=transfer, 1=institutional_relationship)'
    }
    
    # Add transaction edges (account to account)
    for _, txn in transactions_df.iterrows():
        if (txn['account_origin'] in account_node_mapping and 
            txn['account_destination'] in account_node_mapping):
            
            source_account_idx = account_node_mapping[txn['account_origin']]
            target_account_idx = account_node_mapping[txn['account_destination']]
            
            edge_list.append([source_account_idx, target_account_idx])
            edge_attributes.append([
                txn['amount_paid'],
                payment_encoder.transform([txn['payment_format']])[0],
                currency_encoder.transform([txn['payment_currency']])[0],
                0  # transfer relationship
            ])
    
    # Add institutional edges (account to bank)
    institutional_edges = 0
    for _, txn in transactions_df.iterrows():
        if txn['account_origin'] in account_node_mapping and txn['from_bank'] in bank_node_mapping:
            source_account_idx = account_node_mapping[txn['account_origin']]
            from_bank_idx = bank_node_mapping[txn['from_bank']]
            edge_list.append([source_account_idx, from_bank_idx])
            edge_attributes.append([0, 0, 0, 1])  # institutional relationship
            institutional_edges += 1
        
        if txn['account_destination'] in account_node_mapping and txn['to_bank'] in bank_node_mapping:
            target_account_idx = account_node_mapping[txn['account_destination']]
            to_bank_idx = bank_node_mapping[txn['to_bank']]
            edge_list.append([target_account_idx, to_bank_idx])
            edge_attributes.append([0, 0, 0, 1])  # institutional relationship
            institutional_edges += 1
    
    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attributes, dtype=torch.float)
    
    scaler_multimodal = StandardScaler()
    node_features_scaled = scaler_multimodal.fit_transform(node_features_list)
    
    x = torch.tensor(node_features_scaled, dtype=torch.float)
    node_type = torch.tensor(node_types, dtype=torch.long)
    
    # Create labels (ML involvement for both accounts and banks)
    node_labels = []
    for account in account_list:
        account_ml_involvement = transactions_df[
            (transactions_df['account_origin'] == account) | 
            (transactions_df['account_destination'] == account)
        ]['is_laundering'].sum()
        node_labels.append(1 if account_ml_involvement > 0 else 0)
    
    for bank in bank_list:
        bank_ml_involvement = transactions_df[
            ((transactions_df['from_bank'] == bank) | 
             (transactions_df['to_bank'] == bank)) & 
            (transactions_df['is_laundering'] == 1)
        ].shape[0]
        node_labels.append(1 if bank_ml_involvement > 0 else 0)
    
    y = torch.tensor(node_labels, dtype=torch.long)
    
    multimodal_graph = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=y,
        node_type=node_type,
        num_nodes=len(node_features_list)
    )
    
    graph_stats = {
        'total_nodes': multimodal_graph.num_nodes,
        'account_nodes': len(account_list),
        'bank_nodes': len(bank_list),
        'total_edges': multimodal_graph.edge_index.size(1),
        'transfer_edges': multimodal_graph.edge_index.size(1) - institutional_edges,
        'institutional_edges': institutional_edges,
        'node_features': multimodal_graph.x.size(1),
        'edge_features': multimodal_graph.edge_attr.size(1),
        'ml_entities': y.sum().item(),
        'ml_entity_rate': (y.sum().item() / len(node_features_list)),
        'heterogeneous_ratio': len(bank_list) / len(account_list)
    }
    
    unified_feature_definitions = {}
    for i, feature_name in enumerate(list(feature_definitions.keys())[:17]):
        unified_feature_definitions[f'feature_{i:02d}_{feature_name}'] = feature_definitions[feature_name]
    for i, feature_name in enumerate(list(bank_feature_definitions.keys())[:6]):
        unified_feature_definitions[f'feature_{i+17:02d}_{feature_name}_bank'] = bank_feature_definitions[feature_name] + ' (bank nodes only)'
    
    graph_annotation.update({
        'node_features': unified_feature_definitions,
        'edge_features': edge_feature_definitions,
        'bank_features': bank_feature_definitions,
        'statistics': graph_stats,
        'preprocessing': {
            'feature_unification': 'Account and bank features padded to common 20-dimensional space',
            'node_type_encoding': '0=account, 1=bank for heterogeneous processing',
            'edge_type_encoding': '0=transfer, 1=institutional_relationship',
            'scaling': 'StandardScaler applied across all node types'
        },
        'usage_recommendations': {
            'training': 'Primary graph for production deployment with full institutional context',
            'explanation': 'Complete multi-entity explanations for complex investigations',
            'investigation': 'Cross-institutional analysis and bank-level risk assessment',
            'validation': 'Comprehensive validation using both account and bank-level ground truth'
        },
        'heterogeneous_considerations': {
            'node_type_handling': 'Use node_type tensor for type-specific processing in GNNs',
            'edge_type_handling': 'Filter by relationship_type for different analysis modes',
            'feature_interpretation': 'First 17 features are account-based, remaining are bank-based',
            'training_strategy': 'Consider separate loss functions for different node types'
        }
    })
    
    annotated_constructor.annotations['graphs']['multimodal_integration'] = graph_annotation
    
    print(f"Multi-Modal Integration Graph created with annotations:")
    print(f"  Total nodes: {graph_stats['total_nodes']:,}")
    print(f"  Account nodes: {graph_stats['account_nodes']:,}")
    print(f"  Bank nodes: {graph_stats['bank_nodes']:,}")
    print(f"  Transfer edges: {graph_stats['transfer_edges']:,}")
    print(f"  Institutional edges: {graph_stats['institutional_edges']:,}")
    print(f"  ML entities: {graph_stats['ml_entities']:,}")
    
    return multimodal_graph, scaler_multimodal, account_node_mapping, bank_node_mapping, graph_annotation

# Build the heterogeneous account + bank graph and register it on the constructor.
(
    multimodal_graph,
    multimodal_scaler,
    account_node_mapping,
    bank_node_mapping,
    mm_annotation,
) = build_annotated_multimodal_integration_graph(
    transactions_df,
    node_mappings,
    account_features,
    feature_definitions,
    payment_encoder,
    currency_encoder,
    annotated_constructor,
)
annotated_constructor.graphs['multimodal_integration'] = multimodal_graph

In [None]:
print("Building annotated Ground Truth Pattern Graph...")

def build_annotated_ground_truth_pattern_graph(pattern_transactions_df, node_mappings, account_features, feature_definitions, annotated_constructor):
    """Build the graph of labelled laundering-pattern transactions with its annotation.

    Nodes are the pattern accounts that also appear in ``node_mappings['accounts']``,
    sorted so node indices (and the returned mapping) are reproducible between
    runs; every node is labelled 1. Edges are the pattern transactions between
    such accounts, in input order, with attributes
    [pattern_type code, txn_sequence, total_txns_in_pattern, amount_paid].

    Returns:
        (graph, account -> node index mapping, annotation), or (None, None, None)
        when there are no pattern transactions, no known accounts or no edges.
    """

    if len(pattern_transactions_df) == 0:
        print("No pattern transactions available")
        return None, None, None

    graph_annotation = {
        'graph_name': 'Ground Truth Pattern Graph',
        'graph_type': 'Homogeneous Directed Graph with Pattern Metadata',
        'purpose': 'Model validation and explanation quality assessment using labeled money laundering patterns',
        'node_type': 'Accounts Involved in Known Money Laundering Patterns',
        'edge_type': 'Pattern-Based Transactions with Sequence Information',
        'construction_method': 'Subset of accounts and transactions from labeled ML patterns',
        'pattern_types_included': list(pattern_transactions_df['pattern_type'].unique()),
        'key_applications': [
            'GNN attention weight validation against known patterns',
            'Explanation quality assessment',
            'Pattern-specific model performance evaluation',
            'Attention mechanism calibration',
            'Regulatory explanation validation'
        ],
        'regulatory_compliance': 'Provides ground truth for validating AI explanations against expert knowledge',
        'validation_purpose': 'Critical for ensuring model explanations align with known laundering structures'
    }

    involved_accounts = set(pattern_transactions_df['account_origin'].unique())
    involved_accounts.update(pattern_transactions_df['account_destination'].unique())

    # Sorted (like the other graphs' account lists): iterating a set of string ids
    # depends on PYTHONHASHSEED, which would reshuffle node indices on every run.
    pattern_accounts = sorted(acc for acc in involved_accounts if acc in node_mappings['accounts'])

    if len(pattern_accounts) == 0:
        print("No pattern accounts found in node mappings")
        return None, None, None

    pattern_account_mapping = {acc: idx for idx, acc in enumerate(pattern_accounts)}

    pattern_encoder = LabelEncoder()
    pattern_types = pattern_transactions_df['pattern_type'].unique()
    pattern_encoder.fit(pattern_types)
    # LabelEncoder assigns codes in sorted order, so the legend must use classes_,
    # not the appearance order returned by unique().
    encoded_pattern_types = list(pattern_encoder.classes_)

    edge_feature_definitions = {
        'pattern_type_encoded': f'Money laundering pattern type (0-{len(pattern_types)-1}): {encoded_pattern_types}',
        'transaction_sequence': 'Position of transaction within the pattern (1-based)',
        'total_pattern_transactions': 'Total number of transactions in this pattern instance',
        'transaction_amount': 'Amount of this specific pattern transaction'
    }

    pattern_statistics = {
        'total_pattern_types': len(pattern_types),
        'pattern_type_distribution': pattern_transactions_df['pattern_type'].value_counts().to_dict(),
        'avg_pattern_length': pattern_transactions_df.groupby('pattern_id')['total_txns_in_pattern'].first().mean(),
        'total_pattern_instances': pattern_transactions_df['pattern_id'].nunique()
    }

    # Encode every row's pattern type once instead of calling transform per row.
    type_codes = pattern_encoder.transform(pattern_transactions_df['pattern_type'])
    rows = zip(
        pattern_transactions_df['account_origin'],
        pattern_transactions_df['account_destination'],
        type_codes,
        pattern_transactions_df['txn_sequence'],
        pattern_transactions_df['total_txns_in_pattern'],
        pattern_transactions_df['amount_paid'],
    )

    edge_list = []
    edge_attributes = []
    for origin, destination, type_code, sequence, total_in_pattern, amount in rows:
        if origin in pattern_account_mapping and destination in pattern_account_mapping:
            edge_list.append([pattern_account_mapping[origin], pattern_account_mapping[destination]])
            edge_attributes.append([type_code, sequence, total_in_pattern, amount])

    if len(edge_list) == 0:
        print("No edges found in pattern graph")
        return None, None, None

    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attributes, dtype=torch.float)

    node_features = [account_features[account] for account in pattern_accounts]

    scaler_pattern = StandardScaler()
    node_features_scaled = scaler_pattern.fit_transform(node_features)

    x = torch.tensor(node_features_scaled, dtype=torch.float)
    y = torch.ones(len(pattern_accounts), dtype=torch.long)  # All nodes are ML-involved

    pattern_graph = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=y,
        num_nodes=len(pattern_accounts)
    )

    graph_stats = {
        'nodes': pattern_graph.num_nodes,
        'edges': pattern_graph.edge_index.size(1),
        'node_features': pattern_graph.x.size(1),
        'edge_features': pattern_graph.edge_attr.size(1),
        'pattern_types': len(pattern_types),
        'ml_label_rate': 1.0,  # All accounts are ML-involved by definition
        'avg_degree': (edge_index.size(1) / len(pattern_accounts)),
        'pattern_coverage': len(pattern_accounts) / len(node_mappings['accounts'])
    }

    graph_annotation.update({
        'node_features': feature_definitions,
        'edge_features': edge_feature_definitions,
        'pattern_statistics': pattern_statistics,
        'statistics': graph_stats,
        'preprocessing': {
            'account_filtering': 'Only accounts involved in labeled patterns included',
            'pattern_encoding': 'LabelEncoder for pattern types with sequence preservation',
            'feature_consistency': 'Same behavioral features as other graphs for comparison'
        },
        'usage_recommendations': {
            'training': 'Use for specialized pattern-aware training and validation',
            'explanation': 'Gold standard for validating attention weight quality',
            'investigation': 'Reference for understanding known pattern structures',
            'validation': 'Primary validation graph for explanation quality assessment'
        },
        'validation_applications': {
            'attention_validation': 'Compare GAT attention weights with known pattern edges',
            'explanation_quality': 'Measure explanation coverage of known patterns',
            'pattern_detection': 'Verify model can identify different pattern types',
            'sequence_analysis': 'Validate temporal sequence understanding'
        }
    })

    annotated_constructor.annotations['graphs']['ground_truth_patterns'] = graph_annotation

    print(f"Ground Truth Pattern Graph created with annotations:")
    print(f"  Nodes: {graph_stats['nodes']:,}")
    print(f"  Edges: {graph_stats['edges']:,}")
    print(f"  Pattern types: {graph_stats['pattern_types']}")
    print(f"  Pattern coverage: {graph_stats['pattern_coverage']:.4f}")

    return pattern_graph, pattern_account_mapping, graph_annotation

# The builder yields (None, None, None) when no usable pattern data exists,
# so the graph is only registered when one was actually produced.
pattern_graph, pattern_account_mapping, pt_annotation = (
    build_annotated_ground_truth_pattern_graph(
        pattern_transactions_df,
        node_mappings,
        account_features,
        feature_definitions,
        annotated_constructor,
    )
)
if pattern_graph is not None:
    annotated_constructor.graphs['ground_truth_patterns'] = pattern_graph

In [None]:
print("Building annotated Ego-Network Extraction Framework...")

def create_annotated_ego_network_extractor(transaction_flow_graph, node_mappings, annotated_constructor, k_hops=2):
    """Create a k-hop ego-network extractor over the transaction flow graph.

    Args:
        transaction_flow_graph: graph exposing ``edge_index``, ``edge_attr``, ``x``,
            ``y`` and ``num_nodes``; node indices are the values of
            ``node_mappings['accounts']``.
        node_mappings: dict whose ``'accounts'`` entry maps account id -> node index.
        annotated_constructor: receives the framework annotation under
            ``annotations['frameworks']['ego_network_extraction']``.
        k_hops: default neighbourhood radius.

    Returns:
        (extract_fn, sample_ego_networks, framework_annotation), where
        ``extract_fn(target_account, k=k_hops)`` returns
        ``(ego_graph, ego_nodes, node_mapping, extraction_metadata)`` or ``None``.

    NOTE(review): the sample run reads the module-level ``transactions_df``
    global (it is not a parameter) to pick up to five laundering source accounts.
    """

    framework_annotation = {
        'framework_name': 'Ego-Network Extraction System',
        'purpose': 'Focused subgraph extraction for real-time investigation and explanation',
        'extraction_method': f'{k_hops}-hop neighborhood extraction around target accounts',
        'k_hops': k_hops,
        'source_graph': 'Transaction Flow Graph',
        'key_applications': [
            'Real-time focused investigation around suspicious accounts',
            'Manageable subgraph visualization for compliance teams',
            'Detailed explanation generation for specific accounts',
            'Local pattern analysis within account neighborhoods',
            'Scalable analysis of large transaction networks'
        ],
        'regulatory_compliance': 'Enables focused analysis for specific account investigations in SARs',
        'computational_efficiency': 'Reduces graph complexity for real-time processing'
    }

    def extract_annotated_ego_network(target_account, k=k_hops):
        """Extract the k-hop neighbourhood (edges followed in both directions).

        Returns None when the account is unknown, the neighbourhood is a single
        node, or it contains no edges.
        """

        if target_account not in node_mappings['accounts']:
            return None

        target_idx = node_mappings['accounts'][target_account]

        src = transaction_flow_graph.edge_index[0]
        dst = transaction_flow_graph.edge_index[1]

        # Breadth-first expansion that scans only the newest frontier each hop.
        # Neighbours of earlier hops are already visited, so the result equals a
        # re-scan of every visited node, at O(E) per hop instead of O(|ego| * E).
        hop_nodes = {0: {target_idx}}
        visited = {target_idx}
        frontier = torch.tensor([target_idx], dtype=torch.long, device=src.device)
        for hop in range(k):
            neighbours = set(dst[torch.isin(src, frontier)].tolist())
            neighbours.update(src[torch.isin(dst, frontier)].tolist())
            new_nodes = neighbours - visited
            hop_nodes[hop + 1] = new_nodes
            visited |= new_nodes
            frontier = torch.tensor(sorted(new_nodes), dtype=torch.long, device=src.device)

        ego_nodes = sorted(visited)

        if len(ego_nodes) <= 1:
            return None

        node_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(ego_nodes)}
        # Hop sets are disjoint, so each ego node has exactly one hop distance.
        hop_of = {node: hop for hop, nodes in hop_nodes.items() for node in nodes}

        # Keep edges whose two endpoints are both inside the ego-network (original order).
        ego_members = torch.tensor(ego_nodes, dtype=torch.long, device=src.device)
        kept = (torch.isin(src, ego_members) & torch.isin(dst, ego_members)).nonzero(as_tuple=True)[0]

        if kept.numel() == 0:
            return None

        kept_src = src[kept].tolist()
        kept_dst = dst[kept].tolist()

        ego_edge_index = torch.tensor(
            [[node_mapping[s] for s in kept_src], [node_mapping[d] for d in kept_dst]],
            dtype=torch.long,
        )
        ego_edge_attr = transaction_flow_graph.edge_attr[kept].detach().to(device='cpu', dtype=torch.float)

        # Per-edge metadata for explanation
        edge_metadata = [
            {
                'source_hop': hop_of[s],
                'target_hop': hop_of[d],
                'edge_type': 'internal' if hop_of[s] == hop_of[d] else 'cross_hop'
            }
            for s, d in zip(kept_src, kept_dst)
        ]

        ego_x = transaction_flow_graph.x[ego_nodes]
        ego_y = transaction_flow_graph.y[ego_nodes]

        # Hop distance of every ego node, in ego-node order
        node_hop_labels = torch.tensor([hop_of[node] for node in ego_nodes], dtype=torch.long)

        ego_graph = Data(
            x=ego_x,
            edge_index=ego_edge_index,
            edge_attr=ego_edge_attr,
            y=ego_y,
            hop_labels=node_hop_labels,
            num_nodes=len(ego_nodes)
        )

        extraction_metadata = {
            'target_account': target_account,
            'target_node_idx': target_idx,
            'extraction_timestamp': datetime.now().isoformat(),
            'k_hops': k,
            'total_nodes_extracted': len(ego_nodes),
            'nodes_by_hop': {hop: len(nodes) for hop, nodes in hop_nodes.items()},
            'total_edges_extracted': ego_edge_index.size(1),
            'ml_nodes_in_ego': ego_y.sum().item(),
            'ml_rate_in_ego': ego_y.float().mean().item(),
            'coverage_ratio': len(ego_nodes) / transaction_flow_graph.num_nodes,
            'hop_distribution': node_hop_labels.bincount().tolist(),
            'edge_metadata': edge_metadata
        }

        return ego_graph, ego_nodes, node_mapping, extraction_metadata

    # Test extraction system (reads the transactions_df global)
    ml_accounts = transactions_df[transactions_df['is_laundering'] == 1]['account_origin'].unique()[:5]

    sample_ego_networks = {}
    extraction_stats = {
        'total_extractions': 0,
        'successful_extractions': 0,
        'avg_nodes_per_ego': 0,
        'avg_edges_per_ego': 0,
        'avg_ml_rate_per_ego': 0
    }

    print("Testing ego-network extraction on ML accounts...")

    total_nodes = 0
    total_edges = 0
    total_ml_rates = 0

    for account in ml_accounts:
        extraction_stats['total_extractions'] += 1
        ego_result = extract_annotated_ego_network(account)

        if ego_result is not None:
            ego_graph, ego_nodes, ego_mapping, ego_metadata = ego_result
            sample_ego_networks[account] = {
                'graph': ego_graph,
                'metadata': ego_metadata
            }
            extraction_stats['successful_extractions'] += 1
            total_nodes += ego_graph.num_nodes
            total_edges += ego_graph.edge_index.size(1)
            total_ml_rates += ego_metadata['ml_rate_in_ego']

            print(f"  Account {account}: {ego_graph.num_nodes} nodes, {ego_graph.edge_index.size(1)} edges, ML rate: {ego_metadata['ml_rate_in_ego']:.3f}")

    if extraction_stats['successful_extractions'] > 0:
        extraction_stats['avg_nodes_per_ego'] = total_nodes / extraction_stats['successful_extractions']
        extraction_stats['avg_edges_per_ego'] = total_edges / extraction_stats['successful_extractions']
        extraction_stats['avg_ml_rate_per_ego'] = total_ml_rates / extraction_stats['successful_extractions']

    framework_annotation.update({
        'extraction_function': extract_annotated_ego_network,
        'sample_extractions': len(sample_ego_networks),
        'extraction_statistics': extraction_stats,
        'usage_recommendations': {
            'real_time_investigation': 'Extract ego-networks for suspicious accounts during investigation',
            'explanation_generation': 'Use for focused, interpretable explanations',
            'visualization': 'Create manageable network visualizations for compliance teams',
            'validation': 'Compare model attention with ego-network structure'
        },
        'output_components': {
            'ego_graph': 'PyTorch Geometric Data object with subgraph',
            'ego_nodes': 'List of original node indices in ego-network',
            'node_mapping': 'Mapping from original to ego-network indices',
            'extraction_metadata': 'Detailed metadata about extraction process and results'
        }
    })

    # A function defined inside another function cannot be pickled, and the
    # constructor's annotations get pickled when saved. Register a copy that names
    # the extractor instead; the returned annotation keeps the callable.
    stored_annotation = dict(framework_annotation)
    stored_annotation['extraction_function'] = extract_annotated_ego_network.__name__

    annotated_constructor.annotations['frameworks'] = annotated_constructor.annotations.get('frameworks', {})
    annotated_constructor.annotations['frameworks']['ego_network_extraction'] = stored_annotation

    return extract_annotated_ego_network, sample_ego_networks, framework_annotation

# Build the 2-hop ego-network extractor and keep its sample extractions on the constructor.
ego_network_extractor, sample_ego_networks, ego_annotation = (
    create_annotated_ego_network_extractor(
        transaction_flow_graph,
        node_mappings,
        annotated_constructor,
        k_hops=2,
    )
)
annotated_constructor.graphs['ego_networks'] = sample_ego_networks
print(f"Created annotated ego-network extractor with {len(sample_ego_networks)} sample networks")

In [None]:
def _strip_functions(obj):
    """Return a copy of nested dict/list/tuple *obj* with plain functions replaced by a name string.

    Functions defined inside other functions cannot be pickled, so one left
    inside the annotations would make ``pickle.dump`` raise. Dict subclasses are
    copied as plain dicts.
    """
    import types

    if isinstance(obj, dict):
        return {key: _strip_functions(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [_strip_functions(value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(_strip_functions(value) for value in obj)
    if isinstance(obj, types.FunctionType):
        return f'<function {obj.__qualname__}>'
    return obj


def save_annotated_graphs_with_metadata(annotated_constructor):
    """Persist the graphs, preprocessors, metadata and annotations built in this notebook.

    Writes (relative to the current working directory, ``data/`` is created if needed):
      - data/annotated_graphs.pt              torch.save of the non-None graphs
      - data/annotated_preprocessors.pt       scalers/encoders found in module globals
      - data/comprehensive_graph_metadata.pkl annotations plus mapping/feature globals
      - data/graph_annotations.json           JSON annotations, or a minimal summary on failure

    Optional objects are looked up in the module globals and skipped when absent.

    NOTE(review): inputs are read from '../data/' at the top of the notebook but
    outputs go to 'data/' — confirm this directory layout is intended.

    Returns:
        dict with the four output paths, ``graphs_count`` and ``preprocessors_count``.
    """
    # json is not among the notebook's top-level imports; import locally so the
    # JSON export (and its fallback) cannot fail with NameError.
    import json
    import os

    os.makedirs('data', exist_ok=True)

    # (key in annotated_constructor.graphs, key in the saved file), in save order
    graph_keys = [
        ('transaction_flow', 'transaction_flow_graph'),
        ('temporal_proximity', 'temporal_proximity_graph'),
        ('account_behavior', 'account_behavior_graph'),
        ('multimodal_integration', 'multimodal_integration_graph'),
        ('ground_truth_patterns', 'ground_truth_pattern_graph'),
        ('ego_networks', 'sample_ego_networks'),
    ]
    graph_save_data = {
        save_key: annotated_constructor.graphs[graph_key]
        for graph_key, save_key in graph_keys
        if annotated_constructor.graphs.get(graph_key) is not None
    }

    torch.save(graph_save_data, 'data/annotated_graphs.pt')

    # Save preprocessors (only those defined as notebook globals) with metadata
    preprocessor_save_data = {
        'preprocessor_metadata': {}
    }

    if 'feature_scaler' in globals() and feature_scaler is not None:
        preprocessor_save_data['feature_scaler'] = feature_scaler
        preprocessor_save_data['preprocessor_metadata']['feature_scaler_info'] = {
            'type': 'StandardScaler',
            'applied_to': 'Account behavioral features',
            'feature_count': len(feature_names) if 'feature_names' in globals() else 0,
            'mean_values': feature_scaler.mean_.tolist(),
            'scale_values': feature_scaler.scale_.tolist()
        }

    if 'temporal_scaler' in globals() and temporal_scaler is not None:
        preprocessor_save_data['temporal_scaler'] = temporal_scaler
        preprocessor_save_data['preprocessor_metadata']['temporal_scaler_info'] = {
            'type': 'StandardScaler', 
            'applied_to': 'Transaction-level features for temporal graph',
            'feature_count': temporal_scaler.n_features_in_
        }

    if 'behavior_scaler' in globals() and behavior_scaler is not None:
        preprocessor_save_data['behavior_scaler'] = behavior_scaler
        preprocessor_save_data['preprocessor_metadata']['behavior_scaler_info'] = {
            'type': 'StandardScaler',
            'applied_to': 'Behavioral features for similarity computation',
            'feature_count': behavior_scaler.n_features_in_
        }

    if 'multimodal_scaler' in globals() and multimodal_scaler is not None:
        preprocessor_save_data['multimodal_scaler'] = multimodal_scaler
        preprocessor_save_data['preprocessor_metadata']['multimodal_scaler_info'] = {
            'type': 'StandardScaler',
            'applied_to': 'Unified account and bank features',
            'feature_count': multimodal_scaler.n_features_in_
        }

    if 'payment_encoder' in globals() and payment_encoder is not None:
        preprocessor_save_data['payment_encoder'] = payment_encoder
        preprocessor_save_data['preprocessor_metadata']['payment_encoder_info'] = {
            'type': 'LabelEncoder',
            'classes': payment_encoder.classes_.tolist(),
            'class_count': len(payment_encoder.classes_)
        }

    if 'currency_encoder' in globals() and currency_encoder is not None:
        preprocessor_save_data['currency_encoder'] = currency_encoder
        preprocessor_save_data['preprocessor_metadata']['currency_encoder_info'] = {
            'type': 'LabelEncoder',
            'classes': currency_encoder.classes_.tolist(),
            'class_count': len(currency_encoder.classes_)
        }

    torch.save(preprocessor_save_data, 'data/annotated_preprocessors.pt')

    # Comprehensive metadata; functions are replaced by names so pickling cannot fail.
    has_annotations = hasattr(annotated_constructor, 'annotations')
    comprehensive_metadata = {
        'project_metadata': _strip_functions(annotated_constructor.annotations) if has_annotations else {},
        'creation_info': {
            'creation_timestamp': annotated_constructor.creation_timestamp if hasattr(annotated_constructor, 'creation_timestamp') else datetime.now().isoformat(),
            'version': annotated_constructor.version if hasattr(annotated_constructor, 'version') else "1.0",
            'total_graphs_created': len(annotated_constructor.graphs),
            'data_sources_processed': len(annotated_constructor.annotations.get('data_sources', {})) if has_annotations else 0,
            'total_annotations': sum(len(v) if isinstance(v, dict) else 1 for v in annotated_constructor.annotations.values()) if has_annotations else 0
        }
    }

    # Add other metadata if variables exist in global scope
    if 'node_mappings' in globals():
        comprehensive_metadata['node_mappings'] = node_mappings

    if 'account_features' in globals():
        comprehensive_metadata['account_features'] = account_features

    if 'feature_names' in globals():
        comprehensive_metadata['feature_names'] = feature_names

    if 'feature_definitions' in globals():
        comprehensive_metadata['feature_definitions'] = feature_definitions

    if 'account_node_mapping' in globals():
        comprehensive_metadata['account_node_mapping'] = account_node_mapping

    if 'bank_node_mapping' in globals():
        comprehensive_metadata['bank_node_mapping'] = bank_node_mapping

    if 'pattern_account_mapping' in globals():
        comprehensive_metadata['pattern_account_mapping'] = pattern_account_mapping
    elif 'pattern_graph' in globals() and pattern_graph is not None:
        comprehensive_metadata['pattern_account_mapping'] = None

    with open('data/comprehensive_graph_metadata.pkl', 'wb') as f:
        pickle.dump(comprehensive_metadata, f)

    # Save JSON version for human readability
    try:
        json_serializable_annotations = convert_to_json_serializable(annotated_constructor.annotations)
        with open('data/graph_annotations.json', 'w') as f:
            json.dump(json_serializable_annotations, f, indent=2, default=str)
    except Exception as e:
        print(f"Warning: Could not save JSON annotations: {e}")
        # Create minimal JSON file
        minimal_annotations = {
            'project_info': {
                'name': 'Explainable AML Detection with Graph Neural Networks',
                'version': '1.0',
                'creation_date': datetime.now().isoformat(),
                'status': 'Stage 3 Completed'
            },
            'graphs_created': list(graph_save_data.keys()),
            'total_graphs': len(graph_save_data)
        }
        with open('data/graph_annotations.json', 'w') as f:
            json.dump(minimal_annotations, f, indent=2, default=str)

    return {
        'graphs_saved': 'data/annotated_graphs.pt',
        'preprocessors_saved': 'data/annotated_preprocessors.pt',
        'metadata_saved': 'data/comprehensive_graph_metadata.pkl',
        'annotations_saved': 'data/graph_annotations.json',
        'graphs_count': len(graph_save_data),
        'preprocessors_count': len([k for k in preprocessor_save_data.keys() if k != 'preprocessor_metadata'])
    }

def _json_key(key):
    """Return a dict key that ``json.dump`` accepts (str, int, float, bool or None).

    numpy scalars are unwrapped to Python scalars; any other key type is
    stringified, because ``json.dump``'s ``default`` hook is never applied to keys.
    """
    if isinstance(key, np.generic):
        return key.item()
    if key is None or isinstance(key, (str, int, float, bool)):
        return key
    return str(key)


def convert_to_json_serializable(obj):
    """Recursively convert *obj* into values that ``json.dump`` can encode.

    - dicts, lists and tuples are converted element-wise (tuples become lists,
      dict keys are made JSON-compatible);
    - numpy bools/integers/floats become Python ``bool``/``int``/``float``
      (integers stay integers);
    - numpy arrays become (converted) nested lists;
    - Python scalars and ``None`` are returned unchanged;
    - anything else is replaced by ``str(obj)``, or a placeholder when even
      ``str`` fails.
    """
    if obj is None:
        return None
    if isinstance(obj, dict):
        return {_json_key(k): convert_to_json_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_to_json_serializable(item) for item in obj]
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return convert_to_json_serializable(obj.tolist())
    if isinstance(obj, (int, float, str, bool)):
        return obj
    try:
        return str(obj)
    except Exception:
        return "<<non-serializable object>>"