# CHRONOS Training with Logging

Train the CHRONOS model on the Elliptic dataset and generate training curves.

**Output:**
- Training/validation loss curves
- F1/Precision/Recall over epochs
- Best model checkpoint

In [None]:
!pip install torch torch-geometric pandas numpy scikit-learn matplotlib -q
print('Dependencies installed!')

In [None]:
from pathlib import Path

from google.colab import files

# File names the data-loading cell reads from the working directory.
REQUIRED_FILES = (
    'elliptic_txs_features.csv',
    'elliptic_txs_classes.csv',
    'elliptic_txs_edgelist.csv',
)

print('Upload the 3 Elliptic CSV files:')
uploaded = files.upload()

# Check the disk rather than the `uploaded` dict so that files which were
# already present (e.g. from an earlier run of this cell) still count.
missing = [name for name in REQUIRED_FILES if not Path(name).is_file()]
if missing:
    raise FileNotFoundError(
        f'Missing required file(s) in the working directory: {", ".join(missing)}'
    )

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.data import Data
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, precision_score, recall_score
import matplotlib.pyplot as plt

# Pick the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Device: {device}')

In [None]:
# Load data
# Reads the three Elliptic CSV files from the working directory and builds the
# graph `data` object (features, labels, edges, split masks) on `device`.
print('Loading data...')
features_df = pd.read_csv('elliptic_txs_features.csv', header=None)
classes_df = pd.read_csv('elliptic_txs_classes.csv')
edges_df = pd.read_csv('elliptic_txs_edgelist.csv')

# Column 0 is the transaction id, column 1 the timestep, the rest are features.
tx_ids = features_df[0].values
timesteps = features_df[1].values
X = features_df.iloc[:, 2:].values.astype(np.float32)

# Transaction id -> row index in X (node index in the graph).
id_to_idx = {tx_id: idx for idx, tx_id in enumerate(tx_ids)}

# Labels: raw class '1' -> 0, raw class '2' -> 1, anything else stays -1
# (treated as unlabeled by the masks below).
# Vectorised with Series.map instead of a per-row iterrows() loop, which is
# very slow on files with hundreds of thousands of rows.
y = np.full(len(tx_ids), -1)
label_node = classes_df['txId'].map(id_to_idx)
label_code = classes_df['class'].astype(str).map({'1': 0, '2': 1})
has_label = label_node.notna() & label_code.notna()
y[label_node[has_label].to_numpy(dtype=np.int64)] = (
    label_code[has_label].to_numpy(dtype=np.int64)
)

# Edges: keep only those whose two endpoints both appear in the feature file.
# Stacking the two index rows yields shape (2, num_edges) even when no edge
# survives, whereas transposing an empty Python list would give a 1-D tensor.
src_idx = edges_df['txId1'].map(id_to_idx)
dst_idx = edges_df['txId2'].map(id_to_idx)
both_known = src_idx.notna() & dst_idx.notna()
edge_array = np.stack([
    src_idx[both_known].to_numpy(dtype=np.int64),
    dst_idx[both_known].to_numpy(dtype=np.int64),
])
edge_index = torch.from_numpy(edge_array).contiguous()

# Masks
# Temporal split over labeled nodes only: timesteps 1-34 train, 35-42 val,
# 43-49 test.
labeled = y >= 0
train_mask = torch.tensor((timesteps >= 1) & (timesteps <= 34) & labeled, dtype=torch.bool)
val_mask = torch.tensor((timesteps >= 35) & (timesteps <= 42) & labeled, dtype=torch.bool)
test_mask = torch.tensor((timesteps >= 43) & (timesteps <= 49) & labeled, dtype=torch.bool)

# Data object
data = Data(
    x=torch.tensor(X, dtype=torch.float32),
    y=torch.tensor(y, dtype=torch.long),
    edge_index=edge_index,
    train_mask=train_mask,
    val_mask=val_mask,
    test_mask=test_mask
).to(device)

print(f'Nodes: {data.num_nodes}, Train: {train_mask.sum()}, Val: {val_mask.sum()}, Test: {test_mask.sum()}')

In [None]:
# Simple GNN Model
class SimpleGNN(nn.Module):
    """Two-layer graph-attention node classifier.

    A linear projection maps the input features to ``hidden_dim``; two
    ``GATConv`` layers follow (``concat=False``, so the layer output width
    stays ``hidden_dim`` regardless of ``num_heads``), with dropout between
    them, and a final linear layer produces per-node class logits.

    Args:
        in_features: Number of input features per node.
        hidden_dim: Width of the hidden representation.
        num_heads: Number of attention heads in each ``GATConv`` layer.
        dropout: Dropout probability applied between the two GAT layers.
        num_classes: Number of output logits per node.

    Raises:
        ValueError: If ``dropout`` is outside ``[0, 1]``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_dim: int = 64,
        num_heads: int = 4,
        dropout: float = 0.3,
        num_classes: int = 2,
    ):
        # Validate before building any layers so a bad value fails fast.
        if not 0.0 <= dropout <= 1.0:
            raise ValueError(f'dropout must be in [0, 1], got {dropout}')
        super().__init__()
        self.dropout = dropout
        self.proj = nn.Linear(in_features, hidden_dim)
        self.gat1 = GATConv(hidden_dim, hidden_dim, heads=num_heads, concat=False)
        self.gat2 = GATConv(hidden_dim, hidden_dim, heads=num_heads, concat=False)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class logits for every node.

        Args:
            x: Node feature tensor whose last dimension is ``in_features``.
            edge_index: Edge index tensor passed through to ``GATConv``
                (presumably shape ``(2, num_edges)`` -- confirm with caller).

        Returns:
            Tensor of logits whose last dimension is ``num_classes``.
        """
        h = F.relu(self.proj(x))
        h = F.relu(self.gat1(h, edge_index))
        # Dropout is only active in training mode (self.training).
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = F.relu(self.gat2(h, edge_index))
        return self.classifier(h)

# Build the network sized to the dataset's feature count and move it to the
# selected device, then report how many parameters it has.
model = SimpleGNN(data.num_features)
model = model.to(device)
param_count = sum(weight.numel() for weight in model.parameters())
print(f'Parameters: {param_count:,}')

In [None]:
# Training setup
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)

# Class imbalance: give the 5x weight to whichever class is rarer in the
# training labels, and report F1/precision/recall for that same class.
# A fixed weight of [1.0, 5.0] and sklearn's default pos_label=1 both target
# class index 1, which is only right if index 1 happens to be the rare class.
train_true = data.y[data.train_mask]
val_true = data.y[data.val_mask]
minority_class = int(torch.bincount(train_true, minlength=2).argmin())
class_weights = torch.ones(2, device=device)
class_weights[minority_class] = 5.0
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Ground-truth labels never change, so move them to the CPU once, not per epoch.
train_true_np = train_true.cpu().numpy()
val_true_np = val_true.cpu().numpy()

# Logging
history = {
    'epoch': [], 'train_loss': [], 'val_loss': [],
    'train_f1': [], 'val_f1': [],
    'val_precision': [], 'val_recall': []
}

# Training loop
num_epochs = 100
# Start below any attainable F1 so the first epoch always writes a checkpoint;
# starting at 0 with a strict '>' would never save if validation F1 stayed 0.
best_val_f1 = float('-inf')

for epoch in range(1, num_epochs + 1):
    # Train: one full-batch gradient step on the training nodes.
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    train_loss = criterion(out[data.train_mask], train_true)
    train_loss.backward()
    optimizer.step()

    # Eval: re-run the forward pass in eval mode with the updated weights.
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        val_loss = criterion(out[data.val_mask], val_true)
        pred = out.argmax(dim=1)
        train_pred = pred[data.train_mask].cpu().numpy()
        val_pred = pred[data.val_mask].cpu().numpy()

    # zero_division=0 keeps the score at 0 (without a warning) when the model
    # predicts no positives at all.
    metric_kwargs = {'pos_label': minority_class, 'zero_division': 0}
    train_f1 = f1_score(train_true_np, train_pred, **metric_kwargs)
    val_f1 = f1_score(val_true_np, val_pred, **metric_kwargs)
    val_prec = precision_score(val_true_np, val_pred, **metric_kwargs)
    val_rec = recall_score(val_true_np, val_pred, **metric_kwargs)

    # Log
    # NOTE: train_loss comes from the training-mode pass before the step,
    # val_loss from the eval-mode pass after it.
    history['epoch'].append(epoch)
    history['train_loss'].append(train_loss.item())
    history['val_loss'].append(val_loss.item())
    history['train_f1'].append(train_f1)
    history['val_f1'].append(val_f1)
    history['val_precision'].append(val_prec)
    history['val_recall'].append(val_rec)

    # Keep the weights from the epoch with the best validation F1.
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), 'best_model_training.pt')

    if epoch % 10 == 0:
        print(f'Epoch {epoch}: Train F1={train_f1:.4f}, Val F1={val_f1:.4f}')

print(f'\nBest Val F1: {best_val_f1:.4f}')

In [None]:
# Plot training curves
# Each panel is described as (y-axis label, title, [(history key, legend label), ...])
# and drawn left to right: loss, F1, then validation precision/recall.
panel_specs = [
    ('Loss', 'Training Loss', [('train_loss', 'Train'), ('val_loss', 'Val')]),
    ('F1 Score', 'F1 Score', [('train_f1', 'Train'), ('val_f1', 'Val')]),
    ('Score', 'Precision & Recall', [('val_precision', 'Precision'), ('val_recall', 'Recall')]),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
epoch_axis = history['epoch']

for ax, (y_label, title, curves) in zip(axes, panel_specs):
    for history_key, legend_label in curves:
        ax.plot(epoch_axis, history[history_key], label=legend_label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150)
plt.show()

print('Saved training_curves.png')

In [None]:
# Save history
# Write the per-epoch log as a CSV, one column per history key.
history_df = pd.DataFrame(history)
history_df.to_csv('training_history.csv', index=False)

# Download
# Send the plot, the metric log and the best checkpoint to the browser.
for artifact in ('training_curves.png', 'training_history.csv', 'best_model_training.pt'):
    files.download(artifact)