# CHRONOS Model Inference

This notebook runs full inference on the trained CHRONOS GNN model.

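**Steps:**
1. Install dependencies
2. Upload model checkpoint and dataset
3. Define the model architecture
4. Load data and compute features
5. Load the model and run inference
6. Compute metrics, then save and download results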

## Step 1: Install Dependencies

In [None]:
!pip install torch torch-geometric pandas numpy scikit-learn -q
print('Dependencies installed!')

## Step 2: Upload Files

Upload the following files:
1. `best_model.pt` - from `checkpoints/chronos_experiment/`
2. `elliptic_txs_features.csv` - from `data/raw/elliptic/raw/`
3. `elliptic_txs_classes.csv` - from `data/raw/elliptic/raw/`
4. `elliptic_txs_edgelist.csv` - from `data/raw/elliptic/raw/`

In [None]:
import os

from google.colab import files

print('Upload best_model.pt:')
uploaded = files.upload()

# The model-loading cell opens 'best_model.pt' from the working directory, so
# flag a missing or differently named upload here instead of failing later.
if not os.path.exists('best_model.pt'):
    print("WARNING: 'best_model.pt' was not found in the working directory; "
          f"uploaded files: {sorted(uploaded)}")

In [None]:
import os

print('Upload the 3 Elliptic CSV files:')
uploaded = files.upload()

# The data-loading cell reads these three files from the working directory.
# Checking the disk (not just this upload batch) keeps the check valid when the
# files are uploaded over several runs of this cell.
missing_csvs = [
    name
    for name in ('elliptic_txs_features.csv',
                 'elliptic_txs_classes.csv',
                 'elliptic_txs_edgelist.csv')
    if not os.path.exists(name)
]
if missing_csvs:
    print(f'WARNING: still missing from the working directory: {missing_csvs}')

## Step 3: Define Model Architecture

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
import numpy as np
import pandas as pd

class CHRONOSInference(nn.Module):
    """Inference-time CHRONOS network: an MLP branch plus a 3-layer GAT branch.

    Node features are projected to ``hidden_dim``; the projection feeds both a
    two-layer MLP (``temporal``) and a stack of three ``GATConv`` layers. The
    two branch outputs are concatenated and classified into 2 logits per node.

    Args:
        in_features: Number of input features per node.
        hidden_dim: Width of the hidden representation. Must be divisible by
            ``num_heads`` because the first two GAT layers emit
            ``hidden_dim // num_heads`` channels per head and concatenate them.
        num_heads: Number of attention heads in every GAT layer.
        dropout: Dropout probability used by the MLP blocks, inside the GAT
            layers, and after each GAT layer in ``forward``.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, in_features=235, hidden_dim=256, num_heads=8, dropout=0.3):
        super().__init__()
        if hidden_dim % num_heads != 0:
            # Otherwise layer 1 would emit (hidden_dim // num_heads) * num_heads
            # channels, which no longer matches layer 2's expected input width.
            raise ValueError(
                f'hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})'
            )
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        # Kept so forward() honours the configured rate instead of a literal.
        self.dropout = dropout

        self.input_proj = nn.Linear(in_features, hidden_dim)

        self.temporal = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Layers 1-2 keep the width at hidden_dim (heads concatenated);
        # layer 3 emits hidden_dim per head, i.e. hidden_dim * num_heads total.
        self.gat_layers = nn.ModuleList([
            GATConv(hidden_dim, hidden_dim // num_heads, heads=num_heads, concat=True, dropout=dropout),
            GATConv(hidden_dim, hidden_dim // num_heads, heads=num_heads, concat=True, dropout=dropout),
            GATConv(hidden_dim, hidden_dim, heads=num_heads, concat=True, dropout=dropout),
        ])

        # NOTE(review): these norms are sized to match each GAT layer's output
        # but forward() never applies them. Confirm against the training
        # architecture whether they should be part of the forward pass.
        self._gat_norms = nn.ModuleList([
            nn.BatchNorm1d(hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.BatchNorm1d(hidden_dim * num_heads),
        ])

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, x, edge_index):
        """Return per-node logits of shape ``(num_nodes, 2)``.

        Args:
            x: Node feature matrix with ``in_features`` columns.
            edge_index: Graph connectivity, passed unchanged to every GAT layer.
        """
        h = self.input_proj(x)
        h = F.elu(h)
        t = self.temporal(h)

        gat_out = h
        for gat in self.gat_layers:
            gat_out = gat(gat_out, edge_index)
            gat_out = F.elu(gat_out)
            # Use the configured rate (was a hard-coded 0.3 that ignored `dropout`).
            gat_out = F.dropout(gat_out, p=self.dropout, training=self.training)

        # The last GAT layer concatenates num_heads blocks of hidden_dim; average them.
        gat_out = gat_out.view(-1, self.num_heads, self.hidden_dim).mean(dim=1)
        combined = torch.cat([gat_out, t], dim=-1)
        return self.classifier(combined)

    def predict(self, x, edge_index):
        """Return ``(probs, preds)`` for class index 1 without tracking gradients.

        ``probs`` is the softmax probability of class index 1 for every node and
        ``preds`` is ``probs > 0.5`` as a long tensor. The module is switched to
        eval mode and stays in eval mode after the call.
        """
        self.eval()
        with torch.no_grad():
            logits = self.forward(x, edge_index)
            probs = F.softmax(logits, dim=-1)[:, 1]
            preds = (probs > 0.5).long()
        return probs, preds

print('Model class defined!')

## Step 4: Load Data and Compute Features

In [None]:
import networkx as nx
from scipy import stats

print('Loading data...')
features_df = pd.read_csv('elliptic_txs_features.csv', header=None)
classes_df = pd.read_csv('elliptic_txs_classes.csv')
edges_df = pd.read_csv('elliptic_txs_edgelist.csv')

# Feature file layout as used here: column 0 = txId, column 1 = timestep,
# remaining columns = node features.
tx_ids = features_df[0].values
timesteps = features_df[1].values
X = features_df.iloc[:, 2:].values.astype(np.float32)

# txId -> row position in X (and therefore node index in the graph).
id_to_idx = {tx_id: idx for idx, tx_id in enumerate(tx_ids)}

# Labels: -1 = unlabelled, 0 = licit, 1 = illicit.
# The Elliptic data set encodes illicit as class '1' and licit as class '2'.
# Index 1 must be the illicit class to agree with the confusion-matrix labels
# ('Actual Illicit' on row 1) and the default positive class of the metrics
# computed later in this notebook. The previous mapping ('1' -> 0, '2' -> 1)
# was inverted relative to those.
# NOTE(review): confirm the checkpoint was trained with the same encoding.
label_codes = {'1': 1, '2': 0}
y = np.full(len(tx_ids), -1)
label_rows = classes_df['txId'].map(id_to_idx)                  # NaN when txId is not in the feature file
label_vals = classes_df['class'].astype(str).map(label_codes)   # NaN for any other class value
labelled = (label_rows.notna() & label_vals.notna()).to_numpy()
y[label_rows.to_numpy()[labelled].astype(np.int64)] = label_vals.to_numpy()[labelled].astype(np.int64)

# Edges: keep only those whose two endpoints both appear in the feature file.
src_idx = edges_df['txId1'].map(id_to_idx)
dst_idx = edges_df['txId2'].map(id_to_idx)
in_graph = (src_idx.notna() & dst_idx.notna()).to_numpy()
edge_array = np.stack([
    src_idx.to_numpy()[in_graph],
    dst_idx.to_numpy()[in_graph],
]).astype(np.int64)

# Shape (2, num_edges); stays (2, 0) when no edge survives the filter.
edge_index = torch.from_numpy(edge_array).contiguous()

# Evaluation subset: labelled nodes in timesteps 43..49.
test_mask = (timesteps >= 43) & (timesteps <= 49) & (y >= 0)
print(f'Nodes: {len(X)}, Edges: {edge_index.shape[1]}, Test samples: {test_mask.sum()}')

In [None]:
# Compute engineered features (simplified version)
print('Computing engineered features...')

# Directed graph over node indices, edges inserted in edge_index order.
G = nx.DiGraph()
G.add_edges_from(edge_index.t().tolist())

n_nodes = len(X)
node_range = range(n_nodes)
# 70 engineered slots; only columns 0, 1, 2, 5, 20 and 21 are filled below,
# every other column stays zero.
eng_features = np.zeros((n_nodes, 70), dtype=np.float32)

# Graph topology (20 features): in-degree, out-degree, total degree.
# Nodes that never appear in an edge are absent from G and count as 0.
in_deg = dict(G.in_degree())
out_deg = dict(G.out_degree())
in_counts = np.array([in_deg.get(node, 0) for node in node_range])
out_counts = np.array([out_deg.get(node, 0) for node in node_range])
eng_features[:, 0] = in_counts
eng_features[:, 1] = out_counts
eng_features[:, 2] = in_counts + out_counts

print('Computing PageRank...')
pr = nx.pagerank(G, max_iter=50)
eng_features[:, 5] = [pr.get(node, 0) for node in node_range]

# Temporal features (25 features) - raw timestep and timestep scaled by 49.
eng_features[:, 20] = timesteps
eng_features[:, 21] = timesteps / 49.0

# Normalize: column-wise z-score; the epsilon keeps all-zero columns at zero.
col_mean = eng_features.mean(axis=0)
col_std = eng_features.std(axis=0)
eng_features = (eng_features - col_mean) / (col_std + 1e-8)

# Combine: raw features first, engineered features appended on the right.
X_combined = np.concatenate([X, eng_features], axis=1)
print(f'Combined features: {X_combined.shape[1]} dimensions')

## Step 5: Load Model and Run Inference

In [None]:
# Load model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load('best_model.pt', map_location=device)
model = CHRONOSInference(in_features=X_combined.shape[1])

# Load weights with flexible matching: copy a checkpoint tensor only when a
# key with the same shape exists; everything else keeps its initial value.
model_state = checkpoint['model']
model_dict = model.state_dict()
loaded = 0
skipped = []  # (model key, reason) for every parameter left at its initial value
for key in list(model_dict.keys()):
    ckpt_key = key
    if '_gat_norms' in key:
        # Model key '_gat_norms.<i>.<param>' is looked up in the checkpoint
        # as 'gat_norms.<i>.module.<param>'.
        ckpt_key = key.replace('_gat_norms', 'gat_norms')
        parts = ckpt_key.split('.')
        if len(parts) >= 2:
            ckpt_key = f"{parts[0]}.{parts[1]}.module.{'.'.join(parts[2:])}"
    if ckpt_key not in model_state:
        skipped.append((key, f"'{ckpt_key}' not in checkpoint"))
    elif model_dict[key].shape != model_state[ckpt_key].shape:
        skipped.append((key, f'shape {tuple(model_state[ckpt_key].shape)} in checkpoint '
                             f'vs {tuple(model_dict[key].shape)} in model'))
    else:
        model_dict[key] = model_state[ckpt_key]
        loaded += 1

model.load_state_dict(model_dict)
model.to(device)
model.eval()
print(f'Loaded {loaded}/{len(model_dict)} weights')

# Skipped entries keep their random initialisation, which makes predictions
# unreliable, so list them instead of reporting only a count.
if skipped:
    print(f'WARNING: {len(skipped)} entries were NOT loaded from the checkpoint:')
    for key, reason in skipped:
        print(f'  {key}: {reason}')

In [None]:
# Run inference
print('Running inference...')
# Move both model inputs to the device the model lives on.
X_tensor = torch.tensor(X_combined, dtype=torch.float32).to(device)
edge_index = edge_index.to(device)

with torch.no_grad():
    prob_tensor, pred_tensor = model.predict(X_tensor, edge_index)

# Bring the results back to host memory as NumPy arrays for scikit-learn / pandas.
probs, preds = (result.cpu().numpy() for result in (prob_tensor, pred_tensor))
print('Inference complete!')

## Step 6: Compute Metrics and Save Results

In [None]:
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score

# Test set metrics: restrict predictions and labels to the evaluation mask.
test_probs = probs[test_mask]
test_preds = preds[test_mask]
test_labels = y[test_mask]

# labels=[0, 1] pins the matrix to 2x2 even when a class is absent from both
# labels and predictions; without it the matrix can shrink and the
# cm[0,1] / cm[1,*] lookups below raise IndexError.
cm = confusion_matrix(test_labels, test_preds, labels=[0, 1])
# Binary metrics with the default positive class (label 1).
f1 = f1_score(test_labels, test_preds)
precision = precision_score(test_labels, test_preds)
recall = recall_score(test_labels, test_preds)

print('\n=== TEST SET RESULTS ===')
print(f'Confusion Matrix:')
print(f'  TN: {cm[0,0]:5d}  FP: {cm[0,1]:5d}')
print(f'  FN: {cm[1,0]:5d}  TP: {cm[1,1]:5d}')
print(f'\nF1 Score:  {f1:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall:    {recall:.4f}')

In [None]:
# Save results
# Confusion matrix: rows are actual classes, columns are predicted classes.
row_names = ['Actual Licit', 'Actual Illicit']
col_names = ['Pred Licit', 'Pred Illicit']
cm_df = pd.DataFrame(cm, index=row_names, columns=col_names)
cm_df.to_csv('confusion_matrix.csv')

# One row per test-set node: probability, thresholded prediction, true label.
prediction_columns = {
    'probability': test_probs,
    'prediction': test_preds,
    'label': test_labels,
}
predictions_df = pd.DataFrame(prediction_columns)
predictions_df.to_csv('predictions.csv', index=False)

# Summary metrics in long form (metric name, value).
metric_values = {
    'f1_score': f1,
    'precision': precision,
    'recall': recall,
    'TP': cm[1,1],
    'TN': cm[0,0],
    'FP': cm[0,1],
    'FN': cm[1,0],
}
metrics_df = pd.DataFrame({
    'metric': list(metric_values.keys()),
    'value': list(metric_values.values()),
})
metrics_df.to_csv('test_metrics.csv', index=False)

print('Results saved!')

In [None]:
# Download results: the three CSV files written by the save step, in order.
for result_file in ('confusion_matrix.csv', 'predictions.csv', 'test_metrics.csv'):
    files.download(result_file)