In [None]:
# !pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu118
# !pip install torch_geometric
# # Optional dependencies:
# !pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.5.0+cu118.html
# !pip install pandas numpy torch torch_geometric networkx matplotlib

In [18]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, SAGEConv
from torch_geometric.data import HeteroData
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from datetime import datetime
from tqdm import tqdm
import warnings
import os
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

import multiprocessing
from collections import defaultdict, Counter

In [19]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

import torch.backends.cudnn as cudnn
cudnn.benchmark = True

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

Using device: cuda


## Class to create Graph

In [20]:
def process_chunk(df_chunk):
    user_currency_counter = defaultdict(Counter)
    user_location_counter = defaultdict(Counter)
    
    for _, row in df_chunk.iterrows():
        sender = row['Sender_account']
        receiver = row['Receiver_account']
        
        # For sender
        user_currency_counter[sender][row['Payment_currency']] += 1
        user_location_counter[sender][row['Sender_bank_location']] += 1
        
        # For receiver
        user_currency_counter[receiver][row['Received_currency']] += 1
        user_location_counter[receiver][row['Receiver_bank_location']] += 1
    
    return user_currency_counter, user_location_counter

In [None]:
class HetroGraphDataset:
    def __init__(self, csv_file, graph_path, force_recreate=False):
        self.path = csv_file
        self.graph_path = graph_path
        self.graph = HeteroData()
        if os.path.exists(self.graph_path) and not force_recreate:
            print(f"Graph file found at {self.graph_path}. Loading existing graph...")
            self.graph = torch.load(self.graph_path)
            self.n_transactions = self.graph['transaction'].num_nodes
            self.user_map = {i: idx for idx, i in enumerate(range(self.graph['user'].num_nodes))}
            print("Graph loaded successfully.")
        else:
            print(f"Creating graph from CSV (force_recreate={force_recreate})...")
            self.df = self._read_csv_file()
            self.n_transactions = len(self.df)
            self.user_map = {}
            self.create_mapping()
            print("Start Creating Graph...")
            steps = ["Initializing Graph Nodes", "Adding Edges"]
            with tqdm(total=len(steps), desc="Graph Creation Progress") as pbar:
                self.initalize_graph_nodes()
                pbar.update(1)
                self.add_edges()
                pbar.update(1)
            print("Graph creation completed.")
            self.save_graph(self.graph_path)

    def _read_csv_file(self):
        print("Reading the Data ...")
        dtype = {
            'Sender_account': 'int64',
            'Receiver_account': 'int64',
            'Sender_bank_location': 'int8',
            'Receiver_bank_location': 'int8',
            'Payment_currency': 'int8',
            'Received_currency': 'int8',
            'Amount': 'float32',
            'Payment_type': 'int8',
            'Year': 'int16',
            'Month': 'int8',
            'Day': 'int8',
            'Is_laundering': 'int8',
            'Laundering_type': 'int8'
        }
        return pd.read_csv(self.path, dtype=dtype)

    def create_mapping(self):
        print("Computing most frequent currency and location for each user...")
        
        # Create mapping from account numbers to user indices
        all_accounts = set(self.df['Sender_account'].unique()).union(set(self.df['Receiver_account'].unique()))
        self.user_map = {acc: i for i, acc in enumerate(sorted(all_accounts))}
        print(f"Number of users = {len(self.user_map)}")
        print(f"Number of transactions = {self.n_transactions}")
        
        # Determine number of processes
        num_processes = multiprocessing.cpu_count()
        
        # Split dataframe into chunks
        indices = np.array_split(self.df.index, num_processes)
        chunks = [self.df.loc[idx] for idx in indices]
        
        # Use Pool to map over chunks with progress bar
        with multiprocessing.Pool() as pool:
            with tqdm(total=len(chunks), desc="Processing chunks") as pbar:
                results = list(pool.imap(process_chunk, chunks))
        
        # Combine partial counters
        total_user_currency_counter = defaultdict(Counter)
        total_user_location_counter = defaultdict(Counter)
        
        for chunk_currency_counter, chunk_location_counter in results:
            for user, counter in chunk_currency_counter.items():
                total_user_currency_counter[user].update(counter)
            for user, counter in chunk_location_counter.items():
                total_user_location_counter[user].update(counter)
        
        # Find most frequent currency and location for each user
        self.user_currency = {}
        self.user_location = {}
        for account in self.user_map:
            currency_counter = total_user_currency_counter[account]
            location_counter = total_user_location_counter[account]
            
            if currency_counter:
                most_frequent_currency = currency_counter.most_common(1)[0][0]
            else:
                most_frequent_currency = 0
            if location_counter:
                most_frequent_location = location_counter.most_common(1)[0][0]
            else:
                most_frequent_location = 0
            self.user_currency[account] = most_frequent_currency
            self.user_location[account] = most_frequent_location

    def initalize_graph_nodes(self):
        self.graph['transaction'].num_nodes = self.n_transactions
        payment_type_values = self.df['Payment_type'].values
        payment_type_onehot = torch.nn.functional.one_hot(
            torch.from_numpy(payment_type_values).to(dtype=torch.int64),
            num_classes=7  # Adjust if different
        )
        amount = torch.from_numpy(self.df['Amount'].values).unsqueeze(1)
        transaction_features = torch.cat([amount, payment_type_onehot], dim=1).to(dtype=torch.float32)
        self.graph['transaction'].x = transaction_features
        self.graph['transaction'].date = [
            f"{int(row['Year'])}-{int(row['Month']):02d}-{int(row['Day']):02d}"
            for _, row in self.df.iterrows()
        ]
        self.graph['transaction'].is_laundering = torch.from_numpy(self.df['Is_laundering'].values).to(dtype=torch.int8)
        self.graph['user'].num_nodes = len(self.user_map)
        user_currency_values = [self.user_currency[acc] for acc in self.user_map]
        user_currency_onehot = torch.nn.functional.one_hot(
            torch.from_numpy(np.array(user_currency_values)).to(dtype=torch.int64),
            num_classes=13  # Adjust if different
        )
        user_location_values = [self.user_location[acc] for acc in self.user_map]
        user_location_onehot = torch.nn.functional.one_hot(
            torch.from_numpy(np.array(user_location_values)).to(dtype=torch.int64),
            num_classes=18  # Adjust if different
        )
        user_features = torch.cat([user_currency_onehot, user_location_onehot], dim=1).to(dtype=torch.float32)
        self.graph['user'].x = user_features

    def add_edges(self):
        sender_to_transaction = torch.tensor(
            [[self.user_map[row['Sender_account']], i] for i, (_, row) in enumerate(self.df.iterrows())],
            dtype=torch.long
        ).t()
        self.graph['user', 'sends', 'transaction'].edge_index = sender_to_transaction
        transaction_to_receiver = torch.tensor(
            [[i, self.user_map[row['Receiver_account']]] for i, (_, row) in enumerate(self.df.iterrows())],
            dtype=torch.long
        ).t()
        self.graph['transaction', 'received_by', 'user'].edge_index = transaction_to_receiver

    def save_graph(self, path):
        print(f"Saving graph to {path}...")
        torch.save(self.graph, path)
        print(f"Graph saved successfully to {path}")

    def info_about_graph(self):
        print("Heterogeneous Graph Summary:")
        print(self.graph)
        
        print("\nNode counts:")
        print(f"user: {self.graph['user'].num_nodes}")
        print(f"Transactions: {self.graph['transaction'].num_nodes}")
        
        print("\nEdge types and counts:")
        for edge_type in self.graph.edge_types:
            print(f"{edge_type}: {self.graph[edge_type].edge_index.shape[1]} edges")

In [None]:
# Main execution
dataset = HetroGraphDataset('/kaggle/input/complete-dataset/data.csv', graph_path='/kaggle/input/complete-graph/complete-graph.pt', force_recreate=True)  # Force recreate to update features
data = dataset.graph

# Verify feature shapes
print(f"Transaction features shape: {data['transaction'].x.shape}")
print(f"User features shape: {data['user'].x.shape}")
dataset.info_about_graph()

Graph file found at /kaggle/input/complete-graph/complete-graph.pt. Loading existing graph...


  self.graph = torch.load(self.graph_path)


Graph loaded successfully.
Transaction features shape: torch.Size([9504852, 8])
User features shape: torch.Size([855460, 31])
Heterogeneous Graph Summary:
HeteroData(
  transaction={
    num_nodes=9504852,
    x=[9504852, 8],
    date=[9504852],
    is_laundering=[9504852],
  },
  user={
    num_nodes=855460,
    x=[855460, 31],
  },
  (user, sends, transaction)={ edge_index=[2, 9504852] },
  (transaction, received_by, user)={ edge_index=[2, 9504852] }
)

Node counts:
user: 855460
Transactions: 9504852

Edge types and counts:
('user', 'sends', 'transaction'): 9504852 edges
('transaction', 'received_by', 'user'): 9504852 edges
