In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
import networkx as nx
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict
import pickle
import json
import os
from datetime import datetime
import gc
from tqdm import tqdm
import psutil
import warnings
warnings.filterwarnings('ignore')

In [2]:
class OptimizedAMLGraphConstructor:
    def __init__(self, batch_size=10000, checkpoint_dir='checkpoints'):
        self.batch_size = batch_size
        self.checkpoint_dir = checkpoint_dir
        self.graphs = {}
        self.annotations = {}
        self.node_mappings = {}
        self.preprocessors = {}
        self.metadata = {}
        
        os.makedirs(checkpoint_dir, exist_ok=True)
        
        self.creation_timestamp = datetime.now().isoformat()
        self.version = "1.0_optimized"
        
    def save_checkpoint(self, stage, data, step=None):
        checkpoint_name = f"{stage}_{step}" if step else stage
        checkpoint_path = os.path.join(self.checkpoint_dir, f"{checkpoint_name}.pkl")
        
        checkpoint_data = {
            'stage': stage,
            'step': step,
            'data': data,
            'timestamp': datetime.now().isoformat(),
            'memory_usage': psutil.Process().memory_info().rss / 1024 / 1024
        }
        
        with open(checkpoint_path, 'wb') as f:
            pickle.dump(checkpoint_data, f)
        
        print(f"Checkpoint saved: {checkpoint_name} ({checkpoint_data['memory_usage']:.1f}MB)")
        
    def load_checkpoint(self, stage, step=None):
        checkpoint_name = f"{stage}_{step}" if step else stage
        checkpoint_path = os.path.join(self.checkpoint_dir, f"{checkpoint_name}.pkl")
        
        if os.path.exists(checkpoint_path):
            with open(checkpoint_path, 'rb') as f:
                checkpoint_data = pickle.load(f)
            print(f"Loaded checkpoint: {checkpoint_name}")
            return checkpoint_data['data']
        return None
        
    def check_resume_point(self):
        checkpoints = []
        if os.path.exists(self.checkpoint_dir):
            for file in os.listdir(self.checkpoint_dir):
                if file.endswith('.pkl'):
                    checkpoints.append(file.replace('.pkl', ''))
        
        resume_order = [
            'data_loaded', 'node_mappings', 'account_features_batched',
            'transaction_flow_graph', 'temporal_proximity_graph',
            'account_behavior_graph', 'multimodal_integration_graph',
            'ground_truth_pattern_graph', 'ego_networks'
        ]
        
        for checkpoint in reversed(resume_order):
            if checkpoint in checkpoints:
                return checkpoint
        return None

    def monitor_memory(self, stage):
        memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
        print(f"[{stage}] Memory usage: {memory_mb:.1f}MB")
        
        if memory_mb > 14000:
            print("WARNING: High memory usage, forcing garbage collection")
            gc.collect()

In [3]:
def load_and_prepare_data():
    constructor = OptimizedAMLGraphConstructor()
    
    resume_point = constructor.check_resume_point()
    if resume_point:
        print(f"Resuming from checkpoint: {resume_point}")
        return constructor, resume_point
    
    print("Loading data from scratch...")
    
    data = constructor.load_checkpoint('data_loaded')
    if data is None:
        transactions_df = pd.read_csv('../data/processed_transactions.csv')
        accounts_df = pd.read_csv('../data/processed_accounts.csv')
        patterns_df = pd.read_csv('../data/processed_patterns.csv')
        pattern_transactions_df = pd.read_csv('../data/processed_pattern_transactions.csv')
        
        transactions_df['timestamp'] = pd.to_datetime(transactions_df['timestamp'])
        pattern_transactions_df['timestamp'] = pd.to_datetime(pattern_transactions_df['timestamp'])
        
        transactions_df = transactions_df.astype({
            'amount_paid': 'float32',
            'amount_received': 'float32',
            'is_laundering': 'int8'
        })
        
        transactions_df = transactions_df.sort_values('timestamp').reset_index(drop=True)
        
        data = {
            'transactions_df': transactions_df,
            'accounts_df': accounts_df,
            'patterns_df': patterns_df,
            'pattern_transactions_df': pattern_transactions_df
        }
        
        constructor.save_checkpoint('data_loaded', data)
    
    constructor.monitor_memory('data_loaded')
    return constructor, None

In [4]:
def create_optimized_node_mappings(constructor, data):
    node_mappings = constructor.load_checkpoint('node_mappings')
    if node_mappings is not None:
        constructor.node_mappings = node_mappings
        return
    
    print("Creating optimized node mappings...")
    transactions_df = data['transactions_df']
    
    unique_accounts = set(transactions_df['account_origin'].unique()) | set(transactions_df['account_destination'].unique())
    unique_banks = set(transactions_df['from_bank'].unique()) | set(transactions_df['to_bank'].unique())
    
    node_mappings = {
        'accounts': {acc: idx for idx, acc in enumerate(sorted(unique_accounts))},
        'banks': {bank: idx for idx, bank in enumerate(sorted(unique_banks))},
        'transactions': {idx: idx for idx in range(len(transactions_df))}
    }
    
    constructor.node_mappings = node_mappings
    constructor.save_checkpoint('node_mappings', node_mappings)
    constructor.monitor_memory('node_mappings')

def engineer_account_features_batched(constructor, data):
    account_features = constructor.load_checkpoint('account_features_batched')
    if account_features is not None:
        return account_features

    print("Engineering account features with batched processing...")
    transactions_df = data['transactions_df']
    node_mappings = constructor.node_mappings

    accounts = list(node_mappings['accounts'].keys())
    batch_size = constructor.batch_size
    total_batches = (len(accounts) + batch_size - 1) // batch_size

    account_features = {}

    transactions_indexed = transactions_df.set_index(['account_origin', 'account_destination'])

    outgoing_stats = transactions_df.groupby('account_origin').agg({
        'amount_paid': ['count', 'sum', 'mean', 'std'],
        'account_destination': 'nunique',
        'payment_currency': 'nunique',
        'payment_format': 'nunique',
        'is_laundering': 'sum',
        'timestamp': ['min', 'max']
    }).fillna(0)

    incoming_stats = transactions_df.groupby('account_destination').agg({
        'amount_paid': ['count', 'sum', 'mean'],
        'account_origin': 'nunique',
        'is_laundering': 'sum'
    }).fillna(0)

    outgoing_stats.columns = ['_'.join(col).strip() for col in outgoing_stats.columns]
    incoming_stats.columns = ['_'.join(col).strip() for col in incoming_stats.columns]

    feature_names = [
        'total_outgoing', 'total_incoming', 'total_transactions',
        'avg_outgoing_amount', 'avg_incoming_amount',
        'total_outgoing_volume', 'total_incoming_volume',
        'unique_recipients', 'unique_senders',
        'currency_diversity', 'payment_format_diversity',
        'transaction_velocity', 'fan_out_degree', 'fan_in_degree',
        'ml_rate', 'round_amount_ratio', 'outgoing_time_span'
    ]

    for batch_idx in tqdm(range(total_batches), desc="Processing account batches"):
        start_idx = batch_idx * batch_size
        end_idx = min((batch_idx + 1) * batch_size, len(accounts))
        batch_accounts = accounts[start_idx:end_idx]
        
        for account in batch_accounts:
            out_stats = outgoing_stats.loc[account] if account in outgoing_stats.index else pd.Series(0, index=outgoing_stats.columns)
            in_stats = incoming_stats.loc[account] if account in incoming_stats.index else pd.Series(0, index=incoming_stats.columns)
            
            total_outgoing = out_stats.get('amount_paid_count', 0)
            total_incoming = in_stats.get('amount_paid_count', 0)
            
            time_span_seconds = 0
            if 'timestamp_max' in out_stats and 'timestamp_min' in out_stats:
                if pd.notna(out_stats['timestamp_max']) and pd.notna(out_stats['timestamp_min']):
                    time_span_seconds = (pd.to_datetime(out_stats['timestamp_max']) - pd.to_datetime(out_stats['timestamp_min'])).total_seconds()
            
            outgoing_time_span = time_span_seconds / 3600 if time_span_seconds > 0 else 0
            transaction_velocity = total_outgoing / max(outgoing_time_span, 1)
            
            ml_involvement = out_stats.get('is_laundering_sum', 0) + in_stats.get('is_laundering_sum', 0)
            ml_rate = ml_involvement / max(total_outgoing + total_incoming, 1)
            
            round_amount_ratio = 0
            if total_outgoing > 0:
                outgoing_amounts = transactions_df[transactions_df['account_origin'] == account]['amount_paid']
                if len(outgoing_amounts) > 0:
                    round_amounts = (outgoing_amounts % 100 == 0).sum()
                    round_amount_ratio = round_amounts / len(outgoing_amounts)
            
            features = [
                float(total_outgoing),
                float(total_incoming),
                float(total_outgoing + total_incoming),
                float(out_stats.get('amount_paid_mean', 0)),
                float(in_stats.get('amount_paid_mean', 0)),
                float(out_stats.get('amount_paid_sum', 0)),
                float(in_stats.get('amount_paid_sum', 0)),
                float(out_stats.get('account_destination_nunique', 0)),
                float(in_stats.get('account_origin_nunique', 0)),
                float(out_stats.get('payment_currency_nunique', 0)),
                float(out_stats.get('payment_format_nunique', 0)),
                float(transaction_velocity),
                float(out_stats.get('account_destination_nunique', 0)),
                float(in_stats.get('account_origin_nunique', 0)),
                float(ml_rate),
                float(round_amount_ratio),
                float(outgoing_time_span)
            ]
            
            account_features[account] = features
        
        if batch_idx % 10 == 0:
            constructor.monitor_memory(f'feature_batch_{batch_idx}')
            gc.collect()

    result = {
        'account_features': account_features,
        'feature_names': feature_names
    }

    constructor.save_checkpoint('account_features_batched', result)
    return result

In [5]:
def build_optimized_transaction_flow_graph(constructor, data, account_features_data):
    graph_data = constructor.load_checkpoint('transaction_flow_graph')
    if graph_data is not None:
        constructor.graphs['transaction_flow'] = graph_data['graph']
        constructor.preprocessors.update(graph_data['preprocessors'])
        return
    
    print("Building optimized transaction flow graph...")
    transactions_df = data['transactions_df']
    node_mappings = constructor.node_mappings
    account_features = account_features_data['account_features']
    feature_names = account_features_data['feature_names']
    
    payment_encoder = LabelEncoder()
    currency_encoder = LabelEncoder()
    
    payment_encoder.fit(transactions_df['payment_format'].unique())
    currency_encoder.fit(transactions_df['payment_currency'].unique())
    
    edge_list = []
    edge_attributes = []
    
    batch_size = 100000
    total_batches = (len(transactions_df) + batch_size - 1) // batch_size
    
    for batch_idx in tqdm(range(total_batches), desc="Processing transaction batches"):
        start_idx = batch_idx * batch_size
        end_idx = min((batch_idx + 1) * batch_size, len(transactions_df))
        batch_df = transactions_df.iloc[start_idx:end_idx]
        
        for _, txn in batch_df.iterrows():
            if txn['account_origin'] in node_mappings['accounts'] and txn['account_destination'] in node_mappings['accounts']:
                source_idx = node_mappings['accounts'][txn['account_origin']]
                target_idx = node_mappings['accounts'][txn['account_destination']]
                
                edge_list.append([source_idx, target_idx])
                
                edge_attr = [
                    float(txn['amount_paid']),
                    float(txn['amount_received']),
                    float(payment_encoder.transform([txn['payment_format']])[0]),
                    float(currency_encoder.transform([txn['payment_currency']])[0]),
                    float(txn['timestamp'].hour),
                    float(txn['timestamp'].dayofweek),
                    float(txn['is_laundering'])
                ]
                edge_attributes.append(edge_attr)
        
        if batch_idx % 10 == 0:
            constructor.monitor_memory(f'tx_flow_batch_{batch_idx}')
    
    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attributes, dtype=torch.float32)
    
    account_list = sorted(node_mappings['accounts'].keys())
    node_features = []
    node_labels = []
    
    ml_accounts = set(transactions_df[transactions_df['is_laundering'] == 1]['account_origin']) | set(transactions_df[transactions_df['is_laundering'] == 1]['account_destination'])
    
    for account in account_list:
        features = account_features[account]
        node_features.append(features)
        node_labels.append(1 if account in ml_accounts else 0)
    
    scaler = StandardScaler()
    node_features_scaled = scaler.fit_transform(node_features)
    
    x = torch.tensor(node_features_scaled, dtype=torch.float32)
    y = torch.tensor(node_labels, dtype=torch.long)
    
    transaction_flow_graph = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=y,
        num_nodes=len(account_list)
    )
    
    graph_data = {
        'graph': transaction_flow_graph,
        'preprocessors': {
            'feature_scaler': scaler,
            'payment_encoder': payment_encoder,
            'currency_encoder': currency_encoder
        }
    }
    
    constructor.graphs['transaction_flow'] = transaction_flow_graph
    constructor.preprocessors.update(graph_data['preprocessors'])
    constructor.save_checkpoint('transaction_flow_graph', graph_data)
    constructor.monitor_memory('transaction_flow_complete')

def build_sampled_temporal_proximity_graph(constructor, data):
    graph_data = constructor.load_checkpoint('temporal_proximity_graph')
    if graph_data is not None:
        constructor.graphs['temporal_proximity'] = graph_data['graph']
        constructor.preprocessors.update(graph_data['preprocessors'])
        return
    
    print("Building sampled temporal proximity graph...")
    transactions_df = data['transactions_df']
    
    sample_size = min(500000, len(transactions_df))
    sampled_indices = np.random.choice(len(transactions_df), sample_size, replace=False)
    sampled_transactions = transactions_df.iloc[sorted(sampled_indices)].reset_index(drop=True)
    
    print(f"Sampled {sample_size:,} transactions from {len(transactions_df):,}")
    
    edge_list = []
    edge_attributes = []
    
    time_window_hours = 24
    max_connections_per_transaction = 50
    
    payment_encoder = constructor.preprocessors['payment_encoder']
    currency_encoder = constructor.preprocessors['currency_encoder']
    
    for i in tqdm(range(len(sampled_transactions)), desc="Building temporal edges"):
        current_txn = sampled_transactions.iloc[i]
        current_time = current_txn['timestamp']
        
        window_start = i + 1
        window_end = min(i + max_connections_per_transaction, len(sampled_transactions))
        
        connections_made = 0
        
        for j in range(window_start, window_end):
            if connections_made >= max_connections_per_transaction:
                break
                
            next_txn = sampled_transactions.iloc[j]
            next_time = next_txn['timestamp']
            
            time_diff = (next_time - current_time).total_seconds() / 3600
            
            if time_diff > time_window_hours:
                break
            
            account_overlap = 0
            if (current_txn['account_origin'] == next_txn['account_origin'] or
                current_txn['account_destination'] == next_txn['account_destination'] or
                current_txn['account_origin'] == next_txn['account_destination'] or
                current_txn['account_destination'] == next_txn['account_origin']):
                account_overlap = 1
            
            amount_similarity = 1 - min(abs(current_txn['amount_paid'] - next_txn['amount_paid']) / max(current_txn['amount_paid'] + next_txn['amount_paid'], 1), 1)
            
            edge_list.append([i, j])
            edge_attributes.append([
                float(time_diff),
                float(account_overlap),
                float(amount_similarity),
                float(current_txn['is_laundering']),
                float(next_txn['is_laundering'])
            ])
            
            connections_made += 1
        
        if i % 50000 == 0:
            constructor.monitor_memory(f'temporal_edge_{i}')
    
    if len(edge_list) == 0:
        print("No temporal edges found, creating minimal graph")
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 5), dtype=torch.float32)
    else:
        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_attributes, dtype=torch.float32)
    
    transaction_features = []
    transaction_labels = []
    
    for _, txn in sampled_transactions.iterrows():
        features = [
            float(np.log1p(txn['amount_paid'])),
            float(txn['timestamp'].hour),
            float(txn['timestamp'].dayofweek),
            float(payment_encoder.transform([txn['payment_format']])[0]),
            float(currency_encoder.transform([txn['payment_currency']])[0])
        ]
        transaction_features.append(features)
        transaction_labels.append(int(txn['is_laundering']))
    
    scaler_temporal = StandardScaler()
    transaction_features_scaled = scaler_temporal.fit_transform(transaction_features)
    
    x = torch.tensor(transaction_features_scaled, dtype=torch.float32)
    y = torch.tensor(transaction_labels, dtype=torch.long)
    
    temporal_proximity_graph = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=y,
        num_nodes=len(sampled_transactions)
    )
    
    graph_data = {
        'graph': temporal_proximity_graph,
        'preprocessors': {
            'temporal_scaler': scaler_temporal
        }
    }
    
    constructor.graphs['temporal_proximity'] = temporal_proximity_graph
    constructor.preprocessors.update(graph_data['preprocessors'])
    constructor.save_checkpoint('temporal_proximity_graph', graph_data)
    constructor.monitor_memory('temporal_proximity_complete')

In [6]:
def build_efficient_account_behavior_graph(constructor, account_features_data):
    graph_data = constructor.load_checkpoint('account_behavior_graph')
    if graph_data is not None:
        constructor.graphs['account_behavior'] = graph_data['graph']
        constructor.preprocessors.update(graph_data['preprocessors'])
        return
    
    print("Building efficient account behavior graph...")
    account_features = account_features_data['account_features']
    node_mappings = constructor.node_mappings
    
    account_list = sorted(node_mappings['accounts'].keys())
    
    feature_matrix = np.array([account_features[acc] for acc in account_list], dtype=np.float32)
    
    scaler_behavior = StandardScaler()
    feature_matrix_scaled = scaler_behavior.fit_transform(feature_matrix)
    
    batch_size = 5000
    similarity_threshold = 0.7
    edge_list = []
    edge_attributes = []
    
    total_batches = (len(account_list) + batch_size - 1) // batch_size
    
    for i in tqdm(range(total_batches), desc="Computing similarities"):
        start_i = i * batch_size
        end_i = min((i + 1) * batch_size, len(account_list))
        
        for j in range(i, total_batches):
            start_j = j * batch_size
            end_j = min((j + 1) * batch_size, len(account_list))
            
            batch_similarities = cosine_similarity(
                feature_matrix_scaled[start_i:end_i],
                feature_matrix_scaled[start_j:end_j]
            )
            
            for local_i, global_i in enumerate(range(start_i, end_i)):
                start_local_j = local_i if i == j else 0
                for local_j, global_j in enumerate(range(start_j, end_j)):
                    if global_j <= global_i:
                        continue
                    
                    similarity = batch_similarities[local_i, local_j]
                    if similarity > similarity_threshold:
                        edge_list.extend([[global_i, global_j], [global_j, global_i]])
                        edge_attributes.extend([[similarity], [similarity]])
        
        if i % 5 == 0:
            constructor.monitor_memory(f'behavior_similarity_{i}')
    
    if len(edge_list) == 0:
        print(f"No edges with threshold {similarity_threshold}, lowering to 0.5")
        similarity_threshold = 0.5
        
        for i in tqdm(range(0, min(10000, len(account_list))), desc="Computing with lower threshold"):
            for j in range(i + 1, min(10000, len(account_list))):
                similarity = cosine_similarity([feature_matrix_scaled[i]], [feature_matrix_scaled[j]])[0, 0]
                if similarity > similarity_threshold:
                    edge_list.extend([[i, j], [j, i]])
                    edge_attributes.extend([[similarity], [similarity]])
    
    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous() if edge_list else torch.empty((2, 0), dtype=torch.long)
    edge_attr = torch.tensor(edge_attributes, dtype=torch.float32) if edge_attributes else torch.empty((0, 1), dtype=torch.float32)
    
    x = torch.tensor(feature_matrix_scaled, dtype=torch.float32)
    
    ml_accounts = set()
    for graph_name, graph in constructor.graphs.items():
        if hasattr(graph, 'y') and graph_name == 'transaction_flow':
            for i, label in enumerate(graph.y):
                if label == 1:
                    ml_accounts.add(account_list[i])
    
    y = torch.tensor([1 if acc in ml_accounts else 0 for acc in account_list], dtype=torch.long)
    
    account_behavior_graph = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=y,
        num_nodes=len(account_list)
    )
    
    graph_data = {
        'graph': account_behavior_graph,
        'preprocessors': {
            'behavior_scaler': scaler_behavior
        }
    }
    
    constructor.graphs['account_behavior'] = account_behavior_graph
    constructor.preprocessors.update(graph_data['preprocessors'])
    constructor.save_checkpoint('account_behavior_graph', graph_data)
    constructor.monitor_memory('account_behavior_complete')

In [7]:
def build_optimized_multimodal_graph(constructor, data, account_features_data):
    graph_data = constructor.load_checkpoint('multimodal_integration_graph')
    if graph_data is not None:
        constructor.graphs['multimodal_integration'] = graph_data['graph']
        constructor.preprocessors.update(graph_data['preprocessors'])
        return
    
    print("Building optimized multimodal integration graph...")
    transactions_df = data['transactions_df']
    node_mappings = constructor.node_mappings
    account_features = account_features_data['account_features']
    
    account_list = sorted(node_mappings['accounts'].keys())
    bank_list = sorted(node_mappings['banks'].keys())
    
    node_features_list = []
    node_types = []
    account_node_mapping = {}
    bank_node_mapping = {}
    node_idx = 0
    
    for account in account_list:
        account_node_mapping[account] = node_idx
        node_types.append(0)
        features = account_features[account]
        padded_features = features + [0.0] * (20 - len(features))
        node_features_list.append(padded_features[:20])
        node_idx += 1
    
    bank_stats = transactions_df.groupby('from_bank').agg({
        'amount_paid': ['count', 'sum', 'mean'],
        'is_laundering': 'sum',
        'payment_currency': 'nunique',
        'payment_format': 'nunique'
    }).fillna(0)
    
    bank_stats.columns = ['_'.join(col) for col in bank_stats.columns]
    
    for bank in bank_list:
        bank_node_mapping[bank] = node_idx
        node_types.append(1)
        
        if bank in bank_stats.index:
            stats = bank_stats.loc[bank]
            bank_features = [
                float(stats.get('amount_paid_count', 0)),
                float(stats.get('amount_paid_sum', 0)),
                float(stats.get('amount_paid_mean', 0)),
                float(stats.get('is_laundering_sum', 0)),
                float(stats.get('payment_currency_nunique', 0)),
                float(stats.get('payment_format_nunique', 0))
            ]
        else:
            bank_features = [0.0] * 6
        
        padded_features = bank_features + [0.0] * (20 - len(bank_features))
        node_features_list.append(padded_features[:20])
        node_idx += 1
    
    edge_list = []
    edge_attributes = []
    
    payment_encoder = constructor.preprocessors['payment_encoder']
    currency_encoder = constructor.preprocessors['currency_encoder']
    
    sample_size = min(1000000, len(transactions_df))
    sampled_indices = np.random.choice(len(transactions_df), sample_size, replace=False)
    sampled_transactions = transactions_df.iloc[sampled_indices]
    
    for _, txn in tqdm(sampled_transactions.iterrows(), total=len(sampled_transactions), desc="Building multimodal edges"):
        if txn['account_origin'] in account_node_mapping and txn['account_destination'] in account_node_mapping:
            source_account_idx = account_node_mapping[txn['account_origin']]
            target_account_idx = account_node_mapping[txn['account_destination']]
            
            edge_list.append([source_account_idx, target_account_idx])
            edge_attributes.append([
                float(txn['amount_paid']),
                float(payment_encoder.transform([txn['payment_format']])[0]),
                float(currency_encoder.transform([txn['payment_currency']])[0]),
                0.0
            ])
            
            if txn['from_bank'] in bank_node_mapping:
                from_bank_idx = bank_node_mapping[txn['from_bank']]
                edge_list.append([source_account_idx, from_bank_idx])
                edge_attributes.append([0.0, 0.0, 0.0, 1.0])
            
            if txn['to_bank'] in bank_node_mapping:
                to_bank_idx = bank_node_mapping[txn['to_bank']]
                edge_list.append([target_account_idx, to_bank_idx])
                edge_attributes.append([0.0, 0.0, 0.0, 1.0])
    
    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attributes, dtype=torch.float32)
    
    scaler_multimodal = StandardScaler()
    node_features_scaled = scaler_multimodal.fit_transform(node_features_list)
    
    x = torch.tensor(node_features_scaled, dtype=torch.float32)
    node_type = torch.tensor(node_types, dtype=torch.long)
    
    ml_accounts = set(transactions_df[transactions_df['is_laundering'] == 1]['account_origin']) | set(transactions_df[transactions_df['is_laundering'] == 1]['account_destination'])
    ml_banks = set(transactions_df[transactions_df['is_laundering'] == 1]['from_bank']) | set(transactions_df[transactions_df['is_laundering'] == 1]['to_bank'])
    
    node_labels = []
    for account in account_list:
        node_labels.append(1 if account in ml_accounts else 0)
    for bank in bank_list:
        node_labels.append(1 if bank in ml_banks else 0)
    
    y = torch.tensor(node_labels, dtype=torch.long)
    
    multimodal_graph = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=y,
        node_type=node_type,
        num_nodes=len(node_features_list)
    )
    
    graph_data = {
        'graph': multimodal_graph,
        'preprocessors': {
            'multimodal_scaler': scaler_multimodal
        },
        'mappings': {
            'account_node_mapping': account_node_mapping,
            'bank_node_mapping': bank_node_mapping
        }
    }
    
    constructor.graphs['multimodal_integration'] = multimodal_graph
    constructor.preprocessors.update(graph_data['preprocessors'])
    constructor.save_checkpoint('multimodal_integration_graph', graph_data)
    constructor.monitor_memory('multimodal_complete')

In [8]:
def build_ground_truth_pattern_graph(constructor, data, account_features_data):
    graph_data = constructor.load_checkpoint('ground_truth_pattern_graph')
    if graph_data is not None:
        if graph_data['graph'] is not None:
            constructor.graphs['ground_truth_patterns'] = graph_data['graph']
        return
    
    print("Building ground truth pattern graph...")
    pattern_transactions_df = data['pattern_transactions_df']
    node_mappings = constructor.node_mappings
    account_features = account_features_data['account_features']
    
    if len(pattern_transactions_df) == 0:
        print("No pattern transactions available")
        constructor.save_checkpoint('ground_truth_pattern_graph', {'graph': None})
        return
    
    pattern_accounts = set(pattern_transactions_df['account_origin'].unique()) | set(pattern_transactions_df['account_destination'].unique())
    pattern_accounts = [acc for acc in pattern_accounts if acc in node_mappings['accounts']]
    
    if len(pattern_accounts) == 0:
        print("No pattern accounts found in node mappings")
        constructor.save_checkpoint('ground_truth_pattern_graph', {'graph': None})
        return
    
    pattern_account_mapping = {acc: idx for idx, acc in enumerate(pattern_accounts)}
    
    edge_list = []
    edge_attributes = []
    
    pattern_types = pattern_transactions_df['pattern_type'].unique()
    pattern_encoder = LabelEncoder()
    pattern_encoder.fit(pattern_types)
    
    for _, txn in tqdm(pattern_transactions_df.iterrows(), total=len(pattern_transactions_df), desc="Processing pattern transactions"):
        if (txn['account_origin'] in pattern_account_mapping and 
            txn['account_destination'] in pattern_account_mapping):
            
            source_idx = pattern_account_mapping[txn['account_origin']]
            target_idx = pattern_account_mapping[txn['account_destination']]
            
            edge_list.append([source_idx, target_idx])
            edge_attributes.append([
                float(pattern_encoder.transform([txn['pattern_type']])[0]),
                float(txn.get('txn_sequence', 1)),
                float(txn.get('total_txns_in_pattern', 1)),
                float(txn['amount_paid'])
            ])
    
    if len(edge_list) == 0:
        print("No edges found in pattern graph")
        constructor.save_checkpoint('ground_truth_pattern_graph', {'graph': None})
        return
    
    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attributes, dtype=torch.float32)
    
    node_features = []
    for account in pattern_accounts:
        features = account_features[account]
        node_features.append(features)
    
    scaler_pattern = StandardScaler()
    node_features_scaled = scaler_pattern.fit_transform(node_features)
    
    x = torch.tensor(node_features_scaled, dtype=torch.float32)
    y = torch.ones(len(pattern_accounts), dtype=torch.long)
    
    pattern_graph = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=y,
        num_nodes=len(pattern_accounts)
    )
    
    graph_data = {
        'graph': pattern_graph,
        'pattern_account_mapping': pattern_account_mapping
    }
    
    constructor.graphs['ground_truth_patterns'] = pattern_graph
    constructor.save_checkpoint('ground_truth_pattern_graph', graph_data)
    constructor.monitor_memory('pattern_graph_complete')

In [9]:
def create_ego_network_framework(constructor):
    ego_data = constructor.load_checkpoint('ego_networks')
    if ego_data is not None:
        constructor.graphs['ego_networks'] = ego_data
        return
    
    print("Creating ego network extraction framework...")
    
    if 'transaction_flow' not in constructor.graphs:
        print("Transaction flow graph not available for ego networks")
        constructor.save_checkpoint('ego_networks', {})
        return
    
    transaction_flow_graph = constructor.graphs['transaction_flow']
    node_mappings = constructor.node_mappings
    
    def extract_ego_network(target_account, k_hops=2):
        if target_account not in node_mappings['accounts']:
            return None
        
        target_idx = node_mappings['accounts'][target_account]
        edge_index = transaction_flow_graph.edge_index
        
        current_nodes = {target_idx}
        
        for hop in range(k_hops):
            next_nodes = set()
            for node in current_nodes:
                neighbors_out = edge_index[1][edge_index[0] == node].tolist()
                neighbors_in = edge_index[0][edge_index[1] == node].tolist()
                next_nodes.update(neighbors_out + neighbors_in)
            current_nodes.update(next_nodes)
        
        ego_nodes = sorted(list(current_nodes))
        
        if len(ego_nodes) <= 1:
            return None
        
        node_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(ego_nodes)}
        
        ego_edge_list = []
        ego_edge_attrs = []
        
        for i, (src, dst) in enumerate(edge_index.t().tolist()):
            if src in node_mapping and dst in node_mapping:
                ego_edge_list.append([node_mapping[src], node_mapping[dst]])
                ego_edge_attrs.append(transaction_flow_graph.edge_attr[i].tolist())
        
        if len(ego_edge_list) == 0:
            return None
        
        ego_edge_index = torch.tensor(ego_edge_list, dtype=torch.long).t().contiguous()
        ego_edge_attr = torch.tensor(ego_edge_attrs, dtype=torch.float32)
        
        ego_x = transaction_flow_graph.x[ego_nodes]
        ego_y = transaction_flow_graph.y[ego_nodes]
        
        ego_graph = Data(
            x=ego_x,
            edge_index=ego_edge_index,
            edge_attr=ego_edge_attr,
            y=ego_y,
            num_nodes=len(ego_nodes)
        )
        
        return ego_graph
    
    sample_ego_networks = {}
    ml_accounts = []
    
    if 'transaction_flow' in constructor.graphs:
        account_list = sorted(node_mappings['accounts'].keys())
        for i, label in enumerate(constructor.graphs['transaction_flow'].y):
            if label == 1:
                ml_accounts.append(account_list[i])
    
    for account in ml_accounts[:5]:
        ego_result = extract_ego_network(account)
        if ego_result is not None:
            sample_ego_networks[account] = ego_result
            print(f"Created ego network for {account}: {ego_result.num_nodes} nodes")
    
    ego_data = {
        'sample_networks': sample_ego_networks,
        'extractor_function': extract_ego_network
    }
    
    constructor.graphs['ego_networks'] = ego_data
    constructor.save_checkpoint('ego_networks', ego_data)
    constructor.monitor_memory('ego_networks_complete')

In [10]:
def save_all_graphs_and_metadata(constructor):
   print("Saving all graphs and comprehensive metadata...")
   
   graph_save_data = {}
   for graph_name, graph_obj in constructor.graphs.items():
       if graph_obj is not None:
           if graph_name == 'ego_networks':
               graph_save_data['sample_ego_networks'] = graph_obj.get('sample_networks', {})
           else:
               graph_save_data[f'{graph_name}_graph'] = graph_obj
   
   torch.save(graph_save_data, 'data/optimized_graphs.pt')
   
   preprocessor_data = {
       'preprocessors': constructor.preprocessors,
       'metadata': {
           'creation_timestamp': constructor.creation_timestamp,
           'version': constructor.version,
           'optimization_applied': True,
           'memory_efficient': True
       }
   }
   
   torch.save(preprocessor_data, 'data/optimized_preprocessors.pt')
   
   comprehensive_metadata = {
       'node_mappings': constructor.node_mappings,
       'creation_info': {
           'timestamp': constructor.creation_timestamp,
           'version': constructor.version,
           'graphs_created': len([g for g in constructor.graphs.values() if g is not None]),
           'optimization_level': 'high_performance'
       }
   }
   
   with open('data/optimized_metadata.pkl', 'wb') as f:
       pickle.dump(comprehensive_metadata, f)
   
   summary = {
       'total_graphs': len([g for g in constructor.graphs.values() if g is not None]),
       'optimization_applied': True,
       'memory_usage_reduced': True,
       'checkpointing_enabled': True,
       'files_created': [
           'data/optimized_graphs.pt',
           'data/optimized_preprocessors.pt', 
           'data/optimized_metadata.pkl'
       ]
   }
   
   print("All graphs saved successfully!")
   print(f"Created {summary['total_graphs']} optimized graphs")
   
   return summary



In [12]:
def run_optimized_stage3():
    print("Starting Optimized Stage 3: Graph Construction with Checkpointing")
    print("="*80)
    
    constructor, resume_point = load_and_prepare_data()
    
    if resume_point is None or resume_point == 'data_loaded':
        data = constructor.load_checkpoint('data_loaded')
        create_optimized_node_mappings(constructor, data)
        
    if resume_point is None or resume_point in ['data_loaded', 'node_mappings']:
        data = constructor.load_checkpoint('data_loaded')
        account_features_data = engineer_account_features_batched(constructor, data)
    else:
        account_features_data = constructor.load_checkpoint('account_features_batched')
    
    if resume_point is None or resume_point in ['data_loaded', 'node_mappings', 'account_features_batched']:
        data = constructor.load_checkpoint('data_loaded')
        build_optimized_transaction_flow_graph(constructor, data, account_features_data)
    
    if resume_point is None or resume_point in ['data_loaded', 'node_mappings', 'account_features_batched', 'transaction_flow_graph']:
        data = constructor.load_checkpoint('data_loaded')
        build_sampled_temporal_proximity_graph(constructor, data)
    
    if resume_point is None or resume_point in ['data_loaded', 'node_mappings', 'account_features_batched', 'transaction_flow_graph', 'temporal_proximity_graph']:
        build_efficient_account_behavior_graph(constructor, account_features_data)
    
    if resume_point is None or resume_point in ['data_loaded', 'node_mappings', 'account_features_batched', 'transaction_flow_graph', 'temporal_proximity_graph', 'account_behavior_graph']:
        data = constructor.load_checkpoint('data_loaded')
        build_optimized_multimodal_graph(constructor, data, account_features_data)
    
    if resume_point is None or resume_point in ['data_loaded', 'node_mappings', 'account_features_batched', 'transaction_flow_graph', 'temporal_proximity_graph', 'account_behavior_graph', 'multimodal_integration_graph']:
        data = constructor.load_checkpoint('data_loaded')
        build_ground_truth_pattern_graph(constructor, data, account_features_data)
    
    if resume_point is None or resume_point in ['data_loaded', 'node_mappings', 'account_features_batched', 'transaction_flow_graph', 'temporal_proximity_graph', 'account_behavior_graph', 'multimodal_integration_graph', 'ground_truth_pattern_graph']:
        create_ego_network_framework(constructor)
    
    summary = save_all_graphs_and_metadata(constructor)
    
    print("\n" + "="*80)
    print("OPTIMIZED STAGE 3 COMPLETED SUCCESSFULLY")
    print("="*80)
    print(f"Created {summary['total_graphs']} optimized graphs")
    print("Memory usage optimized with batching and sampling")
    print("Checkpointing enabled for resumable execution") 
    print("All graphs saved to optimized format")
    print("Ready for Stage 4: Explainable GAT Architecture")
    print("="*80)
    
    return constructor, summary

if __name__ == "__main__":
    constructor, summary = run_optimized_stage3()

Starting Optimized Stage 3: Graph Construction with Checkpointing
Resuming from checkpoint: temporal_proximity_graph
Loaded checkpoint: account_features_batched
Building efficient account behavior graph...


KeyError: 'accounts'