In [None]:
import os
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATConv
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, precision_score, precision_recall_curve, auc
from sklearn.preprocessing import StandardScaler
from datetime import datetime, timedelta

# Load and preprocess data
PKL_FOLDER = "./simulated-data-transformed/data"
# NOTE(security): pd.read_pickle can execute arbitrary code while unpickling;
# only point PKL_FOLDER at files that come from a trusted source.
# os.listdir returns entries in arbitrary order, so sort the names to make
# the concatenation order (and everything derived from it) reproducible.
list_of_transaction_dfs = [
    pd.read_pickle(os.path.join(PKL_FOLDER, filename))
    for filename in sorted(os.listdir(PKL_FOLDER))
    if filename.endswith(".pkl")
]
if not list_of_transaction_dfs:
    # Fail with an explicit message instead of pd.concat's generic error.
    raise FileNotFoundError(f"No .pkl files found in {PKL_FOLDER!r}")

transactions_df = pd.concat(list_of_transaction_dfs, ignore_index=True)
transactions_df['timestamp'] = pd.to_datetime(transactions_df['TX_DATETIME'])
# Stable sort: rows sharing a timestamp keep a deterministic relative order.
transactions_df.sort_values('timestamp', inplace=True, kind='mergesort')
transactions_df['prev_id'] = transactions_df.groupby('CUSTOMER_ID')['TRANSACTION_ID'].shift(1)

# Graph construction: one directed edge from each customer's previous
# transaction to their next one.
edges = transactions_df.dropna(subset=['prev_id'])
# edge_index must hold row positions into `x`, not raw TRANSACTION_ID values:
# the IDs only coincide with positions if they happen to be 0..N-1 in sorted
# order. Derive positions explicitly from the sorted frame instead.
node_position = pd.Series(np.arange(len(transactions_df)), index=transactions_df.index)
prev_position = node_position.groupby(transactions_df['CUSTOMER_ID']).shift(1)
has_prev = prev_position.notna().to_numpy()
edge_index = torch.tensor(
    np.vstack([
        prev_position.to_numpy()[has_prev].astype(np.int64),
        node_position.to_numpy()[has_prev],
    ]),
    dtype=torch.long,
)

# Chronological train/test split (computed before scaling so the scaler can
# be fitted on training rows only).
SPLIT_DATE = datetime(2018, 8, 1)
train_mask = transactions_df['timestamp'] < SPLIT_DATE
test_mask = transactions_df['timestamp'] >= SPLIT_DATE
if not train_mask.any():
    raise ValueError(f"No transactions before {SPLIT_DATE}; cannot fit the scaler or train.")

# Node features and scaling
feature_frame = transactions_df[
    ['TX_AMOUNT', 'TX_TIME_SECONDS', 'CUSTOMER_ID_NB_TX_1DAY_WINDOW', 'CUSTOMER_ID_AVG_AMOUNT_1DAY_WINDOW']
].fillna(0)
scaler = StandardScaler()
# Fit on the training period only to avoid leaking test-period statistics
# into the features; then apply the same transform to every row.
scaler.fit(feature_frame[train_mask])
features = scaler.transform(feature_frame)
x = torch.tensor(features, dtype=torch.float)

# Labels and mask
labels = transactions_df['TX_FRAUD'].to_numpy()
y = torch.tensor(labels, dtype=torch.long)
data = Data(x=x, edge_index=edge_index, y=y)
# Convert via .to_numpy() so the masks are taken in row (positional) order
# regardless of the frame's permuted index.
data.train_mask = torch.tensor(train_mask.to_numpy(), dtype=torch.bool)
data.test_mask = torch.tensor(test_mask.to_numpy(), dtype=torch.bool)

# Graph-attention node classifier: two stacked GATConv layers
class ImprovedGTAN(torch.nn.Module):
    """Two-layer graph attention network producing per-node log-probabilities.

    The first ``GATConv`` layer uses several attention heads whose outputs are
    concatenated; the second uses a single head and maps to ``num_classes``.
    Dropout is applied to the input features and again before the second layer.

    Args:
        num_features: Size of each input node feature vector.
        num_classes: Number of output classes.
        hidden_channels: Output size of each head in the first layer.
        heads: Number of attention heads in the first layer.
        dropout: Dropout probability used before each layer (training only).
    """

    def __init__(self, num_features, num_classes, hidden_channels=16, heads=8, dropout=0.6):
        super().__init__()
        self.dropout = dropout
        # Multi-head layer; head outputs are concatenated, giving
        # hidden_channels * heads features per node.
        self.conv1 = GATConv(num_features, hidden_channels, heads=heads)
        # Single-head output layer (concat=False) producing class scores.
        self.conv2 = GATConv(hidden_channels * heads, num_classes, heads=1, concat=False)

    def forward(self, data):
        """Return log-softmax class scores for every node in ``data``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            Tensor of log-probabilities, normalised over dim 1.
        """
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# Training setup: pick the device, then build the model, loss and optimizer
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Two output classes; input width is taken from the graph's feature matrix.
model = ImprovedGTAN(num_features=data.num_features, num_classes=2)
model = model.to(device)
data = data.to(device)

# The model outputs log-probabilities, so pair it with negative log-likelihood.
criterion = torch.nn.NLLLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

# Training loop: full-batch gradient steps, loss computed on training nodes only
for epoch in range(32):
    model.train()
    optimizer.zero_grad()

    # Forward pass over the whole graph.
    out = model(data)

    # Restrict the loss to the nodes flagged for training.
    train_out = out[data.train_mask]
    train_y = data.y[data.train_mask]
    loss = criterion(train_out, train_y)

    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

# Evaluate the model on the test nodes
model.eval()
with torch.no_grad():
    # The model returns log-probabilities; exp() recovers probabilities.
    pred_proba = model(data).exp()[:, 1]  # Probabilities for class 1
    pred_labels = pred_proba > 0.5
    test_mask = data.test_mask.cpu().numpy()
    test_y = data.y.cpu().numpy()
    pred_proba = pred_proba.cpu().numpy()
    pred_labels = pred_labels.cpu().numpy()
    auc_score = roc_auc_score(test_y[test_mask], pred_proba[test_mask])
    # Keep the thresholded precision under its own name: it was previously
    # stored in `precision` and then overwritten by the PR-curve array below,
    # so it was computed but never reported.
    precision_at_threshold = precision_score(test_y[test_mask], pred_labels[test_mask])
    precision, recall, _ = precision_recall_curve(test_y[test_mask], pred_proba[test_mask])
    pr_auc = auc(recall, precision)
    print(f'Test AUC: {auc_score}, PR AUC: {pr_auc}')
    print(f'Test precision (threshold 0.5): {precision_at_threshold}')
