# Simple Classifier vs. Multi-Task Model



## Imports and installs

In [None]:
# Clean up: uninstall any previously installed PyG / PyG-Temporal packages so the pinned
# reinstall in the following cells starts from a clean slate (presumably to avoid mixing
# wheels built against a different torch version).
%pip uninstall -y pyg pyg-lib torch-geometric torch-scatter torch-sparse \
  torch-cluster torch-spline-conv torch-geometric-temporal


In [None]:
# Pin Torch/CUDA to a combination with plenty of prebuilt wheels (torch 2.4.0, CUDA 12.1).
# --force-reinstall replaces whatever torch build the runtime shipped with.
%pip install --upgrade --force-reinstall \
  torch==2.4.0+cu121 torchvision==0.19.0+cu121 torchaudio==2.4.0+cu121 \
  --index-url https://download.pytorch.org/whl/cu121


In [None]:
# Compiled PyG add-ons built for exactly torch 2.4.0+cu121. --only-binary=:all: refuses
# source builds, so this fails fast instead of compiling when no matching wheel exists.
%pip install --no-cache-dir --only-binary=:all: \
  pyg-lib torch_scatter torch_sparse torch_cluster torch_spline_conv \
  -f https://data.pyg.org/whl/torch-2.4.0+cu121.html


In [None]:
# Pure-Python PyG core, pinned to 2.6.1 (no compilation needed).
%pip install --no-cache-dir torch_geometric==2.6.1


In [None]:
# PyG Temporal (provides A3TGCN). --no-deps stops pip from re-resolving, and possibly
# replacing, the torch / add-on versions pinned in the cells above.
%pip install --no-cache-dir --no-deps torch-geometric-temporal


In [None]:
# Sanity check for the installs above: report torch / CUDA versions and GPU availability,
# confirm the compiled PyG add-ons import, and confirm the A3TGCN layer used by the model
# below can be imported.
import torch
print("Torch:", torch.__version__, "| CUDA:", torch.version.cuda, "| GPU:", torch.cuda.is_available())

import torch_geometric, torch_scatter, torch_sparse, torch_cluster, torch_spline_conv
print("PyG:", torch_geometric.__version__,
      "| scatter:", torch_scatter.__version__,
      "| sparse:", torch_sparse.__version__)

from torch_geometric_temporal.nn.recurrent import A3TGCN
print("A3TGCN OK")


In [None]:
# Mount Google Drive (Colab only) so the project data and model checkpoints stored under
# /content/drive are reachable from this runtime.
from google.colab import drive # mount to drive
drive.mount('/content/drive')

In [None]:
# Work from the project folder on Drive: the relative 'data/...' paths used below (and,
# presumably, the local `helper_functions` import) resolve against this directory.
import os
os.chdir("/content/drive/MyDrive/Final Project")

In [None]:
import os, random
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset

from sklearn.metrics import (
    accuracy_score, f1_score, roc_auc_score, average_precision_score,
    confusion_matrix, roc_curve, precision_recall_curve
)

# Temporal GNN
from torch_geometric_temporal.nn.recurrent import A3TGCN

# helpers
from helper_functions import get_edge_data, compute_classification_metrics, compute_regression_metrics

# Reproducibility: a single global seed drives every RNG used in this notebook.
SEED = 123


def set_seed(seed=SEED):
    """Seed Python's `random`, NumPy, and PyTorch (CPU and every CUDA device) with *seed*."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


set_seed(SEED)


def seed_worker(worker_id):
    """DataLoader `worker_init_fn`: seed NumPy and `random` with SEED offset by the worker id."""
    offset_seed = worker_id + SEED
    np.random.seed(offset_seed)
    random.seed(offset_seed)


# Dedicated generator so DataLoader shuffling is reproducible independently of other torch RNG use.
g = torch.Generator()
g.manual_seed(SEED)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

## Hyperparameters

In [None]:
# Hyperparameters
HIDDEN_DIM        = 256
PERIODS           = 1
LR                = 1e-2
WEIGHT_DECAY      = 1e-5
EPOCHS            = 200
LOSS_ALPHA        = 0.5     # weight for classification loss (multi‑task only)
BATCH_SIZE        = 1       # one window per step; run_epoch squeezes dim 0, so keep this at 1
VAL_WINDOWS_COUNT = 4       # the last N windows (in window order) form the validation split

# Optimiser split LRs (trunk / clf / reg).
# A trunk LR of 0 (the old `LR * 0.`) means Adam never moves the A3TGCN weights. Both
# models would then share a frozen, randomly initialised backbone, and the regression task
# could not influence the classifier at all, which makes the multi-task vs classifier-only
# comparison vacuous. Set TRUNK_LR_SCALE = 0.0 only if a frozen trunk is really intended.
TRUNK_LR_SCALE = 0.1
LR_TRUNK = LR * TRUNK_LR_SCALE
LR_CLF   = LR
LR_REG   = LR * 5.0

PRETRAIN_EPOCHS = 50  # regression-only warm-up epochs (set 0 for classifier-only)
PATIENCE        = 20  # early-stopping patience, in epochs without val-loss improvement


## Loading data and defining datasets

In [None]:
#load edges and timestamps and mappings
edges_df = pd.read_csv("data/txs_edgelist.csv", usecols=["txId1","txId2"])
ts_df    = pd.read_csv("data/txs_features.csv", usecols=["txId","Time step"]).rename(
    columns={"txId":"txId2","Time step":"timestamp"}
)
edges_ts = edges_df.merge(ts_df, on="txId2", how="left")

verts    = pd.read_parquet("data/verts_int.parquet").astype({"txId":str})
int_map  = verts.set_index('txId')['int_id'].to_dict()

edges_ts['src'] = edges_ts['txId1'].astype(str).map(int_map)
edges_ts['dst'] = edges_ts['txId2'].astype(str).map(int_map)
edges_ts = edges_ts[['src','dst','timestamp']]
edges_ts.head()

In [None]:
class WindowDataset(Dataset):
    """One item per distinct ``window_start`` value in the scaled feature table.

    Columns prefixed ``orbit_`` are regression targets; every other column except
    ``int_id`` and ``window_start`` is an input feature. Class labels come from
    ``classes_csv`` joined through the ``txId -> int_id`` mapping in ``verts_parquet``;
    missing or unknown (class 3) labels are stored as -100 so CrossEntropyLoss ignores them.

    Args:
        features_csv: per-window node features (+ optional ``orbit_*`` targets).
        classes_csv: transaction classes with columns ``txId`` and ``class``.
        verts_parquet: ``txId`` -> ``int_id`` mapping.
        include_orbits: if True, return orbit targets and previous-window orbits (multi-task);
            if False, drop unlabelled rows and return zero-width orbit tensors (classifier only).
    """

    def __init__(self, features_csv: str, classes_csv: str, verts_parquet: str, include_orbits: bool = True):
        super().__init__()
        self.include_orbits = include_orbits

        df = pd.read_csv(features_csv)
        # Orbit targets are in columns prefixed with 'orbit_'
        self.orbit_cols = [c for c in df.columns if c.startswith('orbit_')]

        # Meta/PCA feature columns: everything except ids, window, and orbit targets
        excluded = {'int_id', 'window_start', *self.orbit_cols}
        self.meta_pca_cols = [c for c in df.columns if c not in excluded]

        self.df = df  # df is local to __init__, so no defensive copy is needed

        # Attach class labels via the txId -> int_id mapping
        classes = pd.read_csv(classes_csv, usecols=['txId', 'class']).astype({'txId': str})
        verts = pd.read_parquet(verts_parquet).astype({'txId': str})
        cls = verts.merge(classes, on='txId', how='left')[['int_id', 'class']]
        cls['class'] = cls['class'].fillna(3).astype(int)   # missing label -> unknown (3)
        cls.loc[cls['class'] == 3, 'class'] = -100          # mask unknown for the loss
        self.cls_map = dict(zip(cls['int_id'], cls['class']))

        # Sorted unique window starts; item idx is the idx-th window in this order
        self.windows = sorted(self.df['window_start'].unique())

    def __len__(self):
        return len(self.windows)

    def _window_rows(self, w):
        """Rows of window *w*, sorted by int_id (defines node order within an item)."""
        return self.df[self.df['window_start'] == w].sort_values('int_id')

    def _aligned_prev_orbits(self, int_ids, w_prev) -> torch.Tensor:
        """Orbit values from window *w_prev*, one row per id in *int_ids* (same order).

        Nodes absent from the previous window get zeros, so the result always has one row
        per current node. Without this alignment the previous window's rows (a different
        node set) would be paired with the wrong nodes, or differ in count and break the
        ``torch.cat`` in the model's forward.
        """
        prev = (self.df.loc[self.df['window_start'] == w_prev, ['int_id'] + self.orbit_cols]
                .drop_duplicates('int_id'))
        # A left merge keeps the left frame's row order, i.e. the current window's node order.
        aligned = pd.DataFrame({'int_id': np.asarray(int_ids)}).merge(prev, on='int_id', how='left')
        values = aligned[self.orbit_cols].fillna(0.0).to_numpy(dtype=np.float32)
        return torch.tensor(values, dtype=torch.float32)

    def __getitem__(self, idx: int):
        w = self.windows[idx]
        sub = self._window_rows(w)

        # Classification labels (masked unknowns = -100)
        y_clf = torch.tensor([self.cls_map[i] for i in sub['int_id']], dtype=torch.long)

        # Classifier-only mode: drop unlabelled rows within this window
        if not self.include_orbits:
            keep = (y_clf != -100)
            sub = sub.loc[keep.numpy()].copy()
            y_clf = y_clf[keep]

        # Features X (meta/pca only; orbit targets are not features)
        x_meta = torch.tensor(sub[self.meta_pca_cols].values, dtype=torch.float32)

        if self.include_orbits and self.orbit_cols:
            y_reg = torch.tensor(sub[self.orbit_cols].values, dtype=torch.float32)
            if idx > 0:
                prev_orbits = self._aligned_prev_orbits(sub['int_id'].to_numpy(), self.windows[idx - 1])
            else:
                # First window has no predecessor: start from zeros
                prev_orbits = torch.zeros_like(y_reg)
        else:
            # Zero-width placeholders that still collate cleanly in classifier-only mode
            y_reg = torch.zeros((x_meta.size(0), 0), dtype=torch.float32)
            prev_orbits = torch.zeros((x_meta.size(0), 0), dtype=torch.float32)

        # Scalar window ids (the first window reports itself as its own predecessor)
        window_start_t = torch.tensor(int(w), dtype=torch.long)
        window_prev_start_t = torch.tensor(int(self.windows[idx - 1]) if idx > 0 else int(w), dtype=torch.long)

        # Missing ids (None/NaN) become -1 instead of raising in int()
        int_ids_list = [-1 if pd.isna(i) else int(i) for i in sub['int_id'].tolist()]

        return {
            'x': x_meta,
            'y_reg': y_reg,
            'y_clf': y_clf,
            'prev_orbits': prev_orbits,
            'window_start': window_start_t,
            'window_prev_start': window_prev_start_t,
            'int_ids': int_ids_list,
        }


In [None]:
def build_dataloaders(
    features_csv: str = 'data/selected_features_windows_scaled.csv',
    classes_csv: str  = 'data/txs_classes.csv',
    verts_parquet: str= 'data/verts_int.parquet',
    batch_size: int = 1,
    val_windows_count: int = 4,
    include_orbits: bool = True,
    # Extras
    generator=None, worker_init_fn=None, collate_fn=None,
    num_workers: int = 0, pin_memory: bool = False, shuffle_train: bool = True,
):
    """Build a chronological train/validation split over the windows of a ``WindowDataset``.

    The last ``val_windows_count`` windows (in sorted window order) form the validation set
    and the rest form the training set, so validation never precedes training in time.

    Args:
        features_csv, classes_csv, verts_parquet, include_orbits: forwarded to ``WindowDataset``.
        batch_size: windows per batch. NOTE(review): windows presumably have different node
            counts, so the default collate likely only works with batch_size=1 — confirm.
        val_windows_count: number of trailing windows held out; clamped to [0, total].
        generator, worker_init_fn, collate_fn, num_workers, pin_memory: passed to both loaders.
        shuffle_train: shuffle the training windows (validation is never shuffled).

    Returns:
        ``(train_loader, val_loader)``.
    """
    dataset = WindowDataset(features_csv, classes_csv, verts_parquet, include_orbits=include_orbits)
    total = len(dataset)
    # Clamp so a negative or oversized request cannot produce out-of-range split indices.
    val_n = max(0, min(val_windows_count, total))
    train_n = total - val_n

    train_ds = Subset(dataset, list(range(train_n)))
    val_ds = Subset(dataset, list(range(train_n, total)))

    # Options shared by both loaders
    common = dict(
        worker_init_fn=worker_init_fn, generator=generator,
        collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory,
    )
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=shuffle_train, **common)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, **common)

    # Report the split actually used, not the requested count
    print(f'Total windows: {total}, Train: {train_n}, Val: {val_n} (last {val_n})')
    return train_loader, val_loader

In [None]:
# Multi-task loaders: orbit targets and previous-window orbits included
train_loader, val_loader = build_dataloaders(
    features_csv='data/selected_features_windows_scaled.csv',
    classes_csv='data/txs_classes.csv',
    verts_parquet='data/verts_int.parquet',
    batch_size=BATCH_SIZE,
    val_windows_count=VAL_WINDOWS_COUNT,
    include_orbits=True,    # Multi‑task
    worker_init_fn=seed_worker,
    generator=g,
)

# Infer model dimensions from one training window
sample = next(iter(train_loader))
IN_FEATS = sample['x'].shape[-1]
if sample['y_reg'].numel() > 0:
    NUM_ORBITS = sample['y_reg'].shape[-1]
else:
    # No orbit values in this batch: count the orbit_* columns of the feature file instead
    NUM_ORBITS = sum(1 for c in pd.read_csv('data/selected_features_windows_scaled.csv').columns
                     if c.startswith('orbit_'))
NUM_CLASSES = 2
print(f'Features={IN_FEATS}, Orbits={NUM_ORBITS}, Classes={NUM_CLASSES}')

In [None]:
class MultiTaskA3TGCN(nn.Module):
    """A3TGCN trunk shared by an orbit-regression head and a node-classification head.

    The regressor sees ``[h, prev_orbits]`` and predicts a residual that is added to
    ``prev_orbits``. Its last layer is zero-initialised, so an untrained model outputs
    ``y_reg == prev_orbits``. The classifier reads the trunk embedding ``h`` only.

    Args:
        in_channels: node feature size.
        hidden_channels: trunk embedding size.
        num_orbits: number of orbit regression targets.
        num_classes: number of classifier logits.
        periods: number of periods passed to A3TGCN.
        dropout: dropout probability on the trunk embedding during training. Applied
            functionally, so it adds no submodule and leaves ``state_dict`` keys unchanged
            (existing checkpoints still load with ``strict=True``).
    """

    def __init__(self, in_channels, hidden_channels, num_orbits, num_classes, periods=1, dropout=0.1):
        super().__init__()
        self.num_orbits = num_orbits
        self.dropout_p = float(dropout)  # previously accepted but never applied
        self.a3tgcn = A3TGCN(in_channels=in_channels, out_channels=hidden_channels, periods=periods)

        # Regressor predicts a delta, then prev_orbits is added (residual style)
        self.regressor = nn.Sequential(
            nn.Linear(hidden_channels + num_orbits, hidden_channels * 2),
            nn.ReLU(),
            nn.Linear(hidden_channels * 2, hidden_channels),
            nn.ReLU(),
            nn.LayerNorm(hidden_channels),
            nn.Linear(hidden_channels, num_orbits),
        )
        self.classifier = nn.Linear(hidden_channels, num_classes)

        # Zero-init the final regression layer so training starts at "copy prev_orbits"
        with torch.no_grad():
            last = self.regressor[-1]
            last.weight.zero_()
            last.bias.zero_()

    def forward(self, x, edge_index, edge_weight=None, prev_orbits=None):
        """Return ``(y_reg, y_clf)``: orbit predictions and class logits per node.

        ``prev_orbits`` that is None or empty (classifier-only loaders) is replaced by zeros.
        """
        # A3TGCN expects [N, F, T]; a 2-D input is treated as a single period (T=1)
        x_t = x.unsqueeze(-1) if x.dim() == 2 else x
        h = self.a3tgcn(x_t, edge_index, edge_weight)  # [N, hidden]
        # No-op in eval mode (training=False)
        h = F.dropout(h, p=self.dropout_p, training=self.training)

        if prev_orbits is None or (isinstance(prev_orbits, torch.Tensor) and prev_orbits.numel() == 0):
            prev_orbits = torch.zeros(h.size(0), self.num_orbits, device=h.device, dtype=h.dtype)

        z = torch.cat([h, prev_orbits], dim=1)
        delta = self.regressor(z)
        y_reg = prev_orbits + delta
        y_clf = self.classifier(h)
        return y_reg, y_clf


In [None]:
#instantiating model
model = MultiTaskA3TGCN(
    in_channels     = IN_FEATS,
    hidden_channels = HIDDEN_DIM,
    num_orbits      = NUM_ORBITS,
    num_classes     = NUM_CLASSES,
    periods         = PERIODS,
    dropout         = 0.1,
).to(device)

model


In [None]:
# Losses
criterion_reg = nn.SmoothL1Loss(beta=0.5, reduction='mean')
criterion_clf = nn.CrossEntropyLoss(ignore_index=-100)

#Optimiser with split parameter groups
optimizer = torch.optim.Adam([
    {'params': model.a3tgcn.parameters(), 'lr': LR_TRUNK, 'weight_decay': WEIGHT_DECAY},
    {'params': model.classifier.parameters(), 'lr': LR_CLF,   'weight_decay': WEIGHT_DECAY},
    {'params': model.regressor.parameters(),  'lr': LR_REG,   'weight_decay': 0.0},
])

## Helper Functions

In [None]:
def _ensure_dir(path: str | None):
    if path is None:
        return
    d = os.path.dirname(path)
    if d and not os.path.exists(d):
        os.makedirs(d, exist_ok=True)

def _make_checkpoint(model, optimizer, epoch: int, train_stats: dict, val_stats: dict, do_regression: bool):
    """Bundle model/optimizer state, epoch metrics, hyperparameters and RNG state into a dict.

    Hyperparameters are read from the notebook globals at call time, so later reassignments
    (e.g. PRETRAIN_EPOCHS = 0 when resuming) are what get recorded. The result is meant for
    ``torch.save``.
    """
    hparams = dict(
        HIDDEN_DIM=HIDDEN_DIM, PERIODS=PERIODS, LR=LR,
        WEIGHT_DECAY=WEIGHT_DECAY, EPOCHS=EPOCHS, LOSS_ALPHA=LOSS_ALPHA,
        BATCH_SIZE=BATCH_SIZE, VAL_WINDOWS_COUNT=VAL_WINDOWS_COUNT,
        LR_TRUNK=LR_TRUNK, LR_CLF=LR_CLF, LR_REG=LR_REG,
        PRETRAIN_EPOCHS=PRETRAIN_EPOCHS, PATIENCE=PATIENCE,
        do_regression=do_regression,
        SEED=SEED,
    )
    cuda_rng = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
    return {
        'epoch': epoch,
        'model_class': type(model).__name__,
        'backbone': 'A3TGCN',
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'train': train_stats,
        'val': val_stats,
        'hparams': hparams,
        'cpu_rng_state': torch.random.get_rng_state(),
        'cuda_rng_state': cuda_rng,
        'device': str(device),
    }

def compute_r2_per_orbit(y_true_np: np.ndarray, y_pred_np: np.ndarray):
    """Coefficient of determination for each target column.

    Args:
        y_true_np, y_pred_np: arrays of shape [N_total, num_orbits] (1-D arrays are
            treated as a single column).

    Returns:
        A list with one R^2 (float) per column, NaN where the target column is constant,
        or an empty list when either input is empty.
    """
    if min(y_true_np.size, y_pred_np.size) == 0:
        return []
    if y_true_np.ndim == 1:
        y_true_np, y_pred_np = y_true_np[:, None], y_pred_np[:, None]

    residual_ss = ((y_true_np - y_pred_np) ** 2).sum(axis=0)
    total_ss = ((y_true_np - y_true_np.mean(axis=0)) ** 2).sum(axis=0)

    # Columns with zero variance have undefined R^2 -> NaN
    r2 = np.where(total_ss > 0, 1.0 - residual_ss / (total_ss + 1e-12), np.nan).astype(float)
    return r2.tolist()

def run_epoch(model, loader, training: bool, do_regression: bool, alpha_clf: float):
    """Run one pass over *loader*, optionally updating the model.

    Each batch is one window (batch_size=1 is assumed: every tensor is squeezed on dim 0).
    The window's edges come from ``get_edge_data`` over the global ``edges_ts``; windows
    without edges or weights are skipped and do not count towards the averaged losses.

    Relies on the notebook globals ``optimizer``, ``criterion_reg``, ``criterion_clf``,
    ``device``, ``edges_ts`` and ``get_edge_data``.

    Args:
        model: network returning ``(y_reg_pred, y_clf_logits)``.
        loader: yields dicts produced by ``WindowDataset``.
        training: if True, backpropagate and step the global optimizer.
        do_regression: include the orbit regression loss and per-orbit R^2.
        alpha_clf: weight of the cross-entropy term (0 disables it).

    Returns:
        dict with 'loss', 'reg_loss', 'clf_loss' (means over processed windows; 'clf_loss'
        is the alpha-weighted CE), 'auroc', 'auprc', 'acc', 'f1' (NaN unless both classes
        appear among labelled nodes) and 'r2_by_orbit'.
    """
    model.train(training)  # identical to model.train() / model.eval()

    total, reg_sum, clf_sum = 0.0, 0.0, 0.0
    n_steps = 0  # windows actually processed; skipped windows must not dilute the means
    all_true, all_probs = [], []
    reg_true_chunks, reg_pred_chunks = [], []   # for per-orbit R^2

    with torch.set_grad_enabled(training):
        for batch in loader:
            x = batch['x'].squeeze(0).to(device, non_blocking=True)
            y_clf0 = batch['y_clf'].squeeze(0).to(device, non_blocking=True)
            y_reg = batch['y_reg'].squeeze(0).to(device, non_blocking=True)
            prev_orb = batch['prev_orbits']
            prev_orb = (prev_orb.squeeze(0).to(device, non_blocking=True) if isinstance(prev_orb, torch.Tensor) else None)

            # Remap raw classes 1 -> 0 and 2 -> 1; anything else stays masked (-100).
            # NOTE(review): if 1 = illicit and 2 = licit (Elliptic convention), the positive
            # class for F1/AUPRC is "licit" — confirm this is intended.
            y_clf = torch.full_like(y_clf0, -100)
            y_clf[y_clf0 == 1] = 0
            y_clf[y_clf0 == 2] = 1

            # Build edges for this window
            int_ids = batch['int_ids']
            id_map = {int(nid): i for i, nid in enumerate(int_ids)}
            w0 = int(batch['window_start'].item())
            ei, ew_raw = get_edge_data(w0, id_map, edges_ts)
            if ei.numel() == 0 or ew_raw.numel() == 0:
                continue

            # Min-max scale weights into [0, 1]. NOTE(review): the smallest weight maps to 0,
            # and a window whose weights are all equal ends up with all-zero weights.
            ew = (ew_raw - ew_raw.min()) / (ew_raw.max() - ew_raw.min() + 1e-6)
            ei = ei.to(device, non_blocking=True)
            ew = ew.to(device, non_blocking=True)

            y_reg_pred, y_clf_logits = model(x, ei, ew, prev_orbits=prev_orb)

            loss_reg = criterion_reg(y_reg_pred, y_reg) if do_regression else torch.tensor(0.0, device=device)

            mask = (y_clf != -100)
            # Mean CE over zero labelled nodes is 0/0 = NaN, which would poison the weights on
            # backward; such windows contribute no CE term.
            if alpha_clf > 0 and bool(mask.any()):
                loss_clf = criterion_clf(y_clf_logits, y_clf)
            else:
                loss_clf = torch.tensor(0.0, device=device)
            loss = loss_reg + alpha_clf * loss_clf

            # A loss made only of constants has no graph (no regression and no labels here);
            # skip the step rather than failing inside backward().
            if training and loss.requires_grad:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            n_steps += 1
            total += loss.item()
            reg_sum += loss_reg.item() if do_regression else 0.0
            clf_sum += (alpha_clf * loss_clf).item()   # weighted CE (0 during warm-up)

            # Classification metrics use labelled nodes only
            if mask.any():
                probs = F.softmax(y_clf_logits[mask], dim=1)[:, 1].detach().cpu().numpy()
                labels = y_clf[mask].detach().cpu().numpy()
                all_probs.extend(probs.tolist())
                all_true.extend(labels.tolist())

            if do_regression and y_reg.numel() > 0:
                reg_true_chunks.append(y_reg.detach().cpu().numpy())
                reg_pred_chunks.append(y_reg_pred.detach().cpu().numpy())

    # Aggregate classification metrics (undefined unless both classes were seen)
    if len(all_true) > 0 and len(set(all_true)) > 1:
        preds = (np.array(all_probs) >= 0.5).astype(int)
        auroc = roc_auc_score(all_true, all_probs)
        auprc = average_precision_score(all_true, all_probs)
        acc = accuracy_score(all_true, preds)
        f1 = f1_score(all_true, preds)
    else:
        auroc = auprc = acc = f1 = float('nan')

    if do_regression and reg_true_chunks:
        r2_by_orbit = compute_r2_per_orbit(np.vstack(reg_true_chunks), np.vstack(reg_pred_chunks))
    else:
        r2_by_orbit = []

    denom = max(n_steps, 1)
    return {
        'loss': total / denom,
        'reg_loss': reg_sum / denom,
        'clf_loss': clf_sum / denom,
        'auroc': auroc, 'auprc': auprc, 'acc': acc, 'f1': f1,
        'r2_by_orbit': r2_by_orbit,
    }

def _non_nan_mean(values):
    """Mean of *values* with NaNs dropped; NaN if nothing remains."""
    kept = np.array([v for v in values if v == v])  # v == v is False only for NaN
    return float(np.mean(kept)) if kept.size else float('nan')


def train_model(model, train_loader, val_loader, do_regression: bool, pretrain_epochs: int,
                alpha_clf: float, checkpoint_path: str | None = None,
                start_epoch: int = 1, max_epochs: int | None = None,
                init_eval: bool = True):
    """Train with optional regression warm-up, early stopping and best-checkpoint saving.

    Works from scratch and when resuming after a checkpoint:
      - start_epoch: epoch index to start from (use your loaded ckpt epoch+1)
      - max_epochs: number of epochs to run in THIS call (defaults to EPOCHS)
      - init_eval: if True, seeds 'best' with a validation pass of the current weights

    With ``do_regression``, epochs ``<= pretrain_epochs`` are a warm-up: the CE weight is 0
    for both training and validation. Afterwards ``alpha_clf`` is used.

    Early stopping compares validation losses only while the CE weight stays the same. When
    the weight changes (end of warm-up), the best loss and patience counter are reset.
    Otherwise the smaller regression-only warm-up loss could never be beaten, training
    would stop PATIENCE epochs later, and weights with an untrained classifier would be
    restored. Early stopping is also not triggered during warm-up itself.

    Uses the globals ``EPOCHS``, ``PATIENCE`` and ``optimizer`` (for checkpoints).

    Returns:
        dict with 'val_loss', 'state' (CPU state_dict of the best epoch), 'history' and
        'path' (checkpoint written, if any). The best weights are loaded back into *model*.
    """
    best = {'val_loss': float('inf'), 'state': None, 'history': [], 'path': None}
    epochs_no_improve = 0
    best_alpha = None  # CE weight under which best['val_loss'] was measured
    _ensure_dir(checkpoint_path)

    total_epochs = max_epochs if max_epochs is not None else EPOCHS
    end_epoch = start_epoch + total_epochs - 1

    def _in_warmup(ep):
        return do_regression and ep <= pretrain_epochs

    # Seed 'best' with a validation pass of the current state (useful when resuming)
    if init_eval:
        init_val_alpha = 0.0 if _in_warmup(start_epoch) else alpha_clf
        va0 = run_epoch(model, val_loader, training=False, do_regression=do_regression, alpha_clf=init_val_alpha)
        best['val_loss'] = va0['loss']
        best_alpha = init_val_alpha
        best['state'] = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        best['history'].append({'epoch': start_epoch - 1, 'train': None, 'val': va0})
        r2_part = (f", R2μ(val)={_non_nan_mean(va0['r2_by_orbit']):.3f}"
                   if do_regression and va0['r2_by_orbit'] else "")
        print(f"Init @ epoch {start_epoch-1}: val_loss={va0['loss']:.4f}, AUROC={va0['auroc']:.3f}, "
              f"AUPRC={va0['auprc']:.3f}, F1={va0['f1']:.3f}" + r2_part)

    for epoch in range(start_epoch, end_epoch + 1):
        warmup = _in_warmup(epoch)
        alpha_now = 0.0 if warmup else alpha_clf  # same weight for train and validation

        tr = run_epoch(model, train_loader, training=True,  do_regression=do_regression, alpha_clf=alpha_now)
        va = run_epoch(model, val_loader,   training=False, do_regression=do_regression, alpha_clf=alpha_now)

        best['history'].append({'epoch': epoch, 'train': tr, 'val': va})

        msg = (f"Epoch {epoch:3d}/{end_epoch} | "
               f"train: loss={tr['loss']:.4f} (reg={tr['reg_loss']:.4f}, clf={tr['clf_loss']:.4f}) "
               f"| val: loss={va['loss']:.4f}, AUROC={va['auroc']:.3f}, AUPRC={va['auprc']:.3f}, F1={va['f1']:.3f}")
        if do_regression:
            msg += (f" | R2μ(train)={_non_nan_mean(tr['r2_by_orbit']):.3f}, R2μ(val)={_non_nan_mean(va['r2_by_orbit']):.3f}")
        print(msg)

        # Full per-orbit R^2 for validation
        if do_regression and va['r2_by_orbit']:
            r2_str = ", ".join(f"{r:.3f}" if np.isfinite(r) else "nan" for r in va['r2_by_orbit'])
            print(f"    Val R2 per-orbit: [{r2_str}]")

        # Losses under different CE weights are not comparable: restart tracking.
        if best_alpha is not None and alpha_now != best_alpha:
            print(f"[Objective changed (alpha {best_alpha} -> {alpha_now}); resetting best val loss and patience]")
            best['val_loss'] = float('inf')
            epochs_no_improve = 0

        # Early stopping on val loss (+ save best checkpoint if improved)
        if va['loss'] + 1e-6 < best['val_loss']:
            best['val_loss'] = va['loss']
            best_alpha = alpha_now
            best['state'] = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            epochs_no_improve = 0

            if checkpoint_path is not None:
                ckpt = _make_checkpoint(model, optimizer, epoch, tr, va, do_regression=do_regression)
                torch.save(ckpt, checkpoint_path)
                best['path'] = checkpoint_path
                print(f"[Saved best checkpoint → {checkpoint_path}]")
        else:
            epochs_no_improve += 1
            # Never stop during warm-up: the classifier head has not been trained yet.
            if epochs_no_improve >= PATIENCE and not warmup:
                print(f"Early stop at epoch {epoch} (no improvement for {PATIENCE} epochs)")
                break

    # Restore the best weights
    if best['state'] is not None:
        model.load_state_dict(best['state'], strict=True)
    else:
        print("Warning: no best state captured; model left as-is.")

    return best


## Training multi-task model

In [None]:
# Train the multi-task model from scratch. Epochs <= PRETRAIN_EPOCHS use CE weight 0
# (regression only); later epochs use regression + LOSS_ALPHA * cross-entropy. Whenever the
# validation loss improves, a checkpoint is written to models/best_multi.pt, and train_model
# loads the best weights back into `model` at the end.
print('--- Training Multi‑Task (regression + classification) ---')

mt_best = train_model(
    model, train_loader, val_loader,
    do_regression=True,
    pretrain_epochs=PRETRAIN_EPOCHS,
    alpha_clf=LOSS_ALPHA,
    checkpoint_path="/content/drive/MyDrive/Final Project/models/best_multi.pt",
)

In [None]:
# --- Resume Multi-Task from checkpoint  ---

CKPT_PATH = "/content/drive/MyDrive/Final Project/models/best_multi.pt"

# The checkpoint holds non-tensor objects (metrics, hyperparameters) next to the tensors, so
# it is loaded with weights_only=False (full unpickling). Unpickling can execute arbitrary
# code, so only load checkpoints you created yourself.
# map_location=device also moves the saved RNG-state tensors, which is presumably why they
# are coerced back to CPU uint8 tensors before being restored below.
ckpt = torch.load(CKPT_PATH, map_location=device, weights_only=False)

# Rebuild model exactly as before (same dims, so strict loading succeeds)
model = MultiTaskA3TGCN(
    in_channels     = IN_FEATS,
    hidden_channels = HIDDEN_DIM,
    num_orbits      = NUM_ORBITS,
    num_classes     = NUM_CLASSES,
    periods         = PERIODS,
    dropout         = 0.1,
).to(device)
model.load_state_dict(ckpt["model_state"], strict=True)

# Recreate optimiser with same param groups (group order must match the saved state)
# NOTE(review): load_state_dict also restores the saved per-group hyperparameters (lr etc.),
# so the LR_* values below are presumably overridden by the checkpoint — confirm.
optimizer = torch.optim.Adam([
    {'params': model.a3tgcn.parameters(), 'lr': LR_TRUNK, 'weight_decay': WEIGHT_DECAY},
    {'params': model.classifier.parameters(), 'lr': LR_CLF,   'weight_decay': WEIGHT_DECAY},
    {'params': model.regressor.parameters(),  'lr': LR_REG,   'weight_decay': 0.0},
])
optimizer.load_state_dict(ckpt["optimizer_state"])

# --- Safe RNG restore ---
def _to_byte_cpu_tensor(x):
    """Coerce x into a CPU torch.uint8 tensor."""
    if isinstance(x, torch.Tensor):
        return x.detach().to('cpu', dtype=torch.uint8)
    import numpy as np
    if isinstance(x, (bytes, bytearray)):
        return torch.tensor(list(x), dtype=torch.uint8)
    if isinstance(x, (list, tuple, np.ndarray)):
        return torch.as_tensor(x, dtype=torch.uint8, device='cpu')
    return torch.as_tensor(x, dtype=torch.uint8, device='cpu')

def _restore_cpu_rng(state):
    if state is None:
        return
    try:
        torch.random.set_rng_state(_to_byte_cpu_tensor(state))
    except Exception as e:
        print(f"Warning: could not restore CPU RNG: {e}")

def _restore_cuda_rng(states):
    if not torch.cuda.is_available() or states is None:
        return
    try:
        fixed = []
        for s in states:
            fixed.append(_to_byte_cpu_tensor(s))
        torch.cuda.set_rng_state_all(fixed)
    except Exception as e:
        print(f"Warning: could not restore CUDA RNG: {e}")

# Restore RNG states
# Restore RNG states (best-effort: missing states are skipped, failures only warn)
_restore_cpu_rng(ckpt.get("cpu_rng_state"))
_restore_cuda_rng(ckpt.get("cuda_rng_state"))

# Resume right after the checkpointed epoch. Checkpoints are only written on a val-loss
# improvement, so this is the epoch after the best one, not necessarily the last epoch run.
start_epoch = int(ckpt.get("epoch", 0)) + 1
print(f"Resuming from epoch {start_epoch}")


In [None]:
print('--- Training Multi‑Task (regression + classification) ---')
# Warm-up already happened in the first run; with 0 the resumed run uses the joint
# regression + classification loss from its first epoch. Note this reassigns the global,
# which _make_checkpoint records in the saved hparams.
PRETRAIN_EPOCHS = 0         # important when resuming
EXTRA_EPOCHS    = 100       # epochs to run in this call (passed as max_epochs)
PATIENCE = 20               # read as a global by train_model for early stopping
mt_best = train_model(
    model, train_loader, val_loader,
    do_regression=True,
    pretrain_epochs=PRETRAIN_EPOCHS,
    alpha_clf=LOSS_ALPHA,
    checkpoint_path="/content/drive/MyDrive/Final Project/models/best_multi.pt",
    start_epoch=start_epoch,
    max_epochs=EXTRA_EPOCHS,
)


## Training classifier

In [None]:
# Build loaders without orbit targets
# Classifier-only loaders: with include_orbits=False, unlabelled nodes are dropped and the
# orbit tensors are zero-width placeholders.
clf_train_loader, clf_val_loader = build_dataloaders(
    features_csv      = 'data/selected_features_windows_scaled.csv',
    classes_csv       = 'data/txs_classes.csv',
    verts_parquet     = 'data/verts_int.parquet',
    batch_size        = BATCH_SIZE,
    val_windows_count = VAL_WINDOWS_COUNT,
    include_orbits    = False,          # CLASSIFIER ONLY
    worker_init_fn    = seed_worker,
    generator         = g,
)

# (Optional) infer IN_FEATS defensively from these loaders
_sample = next(iter(clf_train_loader))
IN_FEATS = _sample['x'].shape[-1]
NUM_CLASSES = 2  # as before; keep consistent with your label remap
PATIENCE        = 20

# Fresh model with the same backbone/head dims for a fair comparison. NUM_ORBITS is reused
# from the multi-task cells so the architecture matches.
clf_model = MultiTaskA3TGCN(
    in_channels     = IN_FEATS,
    hidden_channels = HIDDEN_DIM,
    num_orbits      = NUM_ORBITS,   # kept for shape; prev_orbits will be zero-filled internally
    num_classes     = NUM_CLASSES,
    periods         = PERIODS,
    dropout         = 0.1,
).to(device)

# Losses (redefined identically to the multi-task setup)
criterion_reg = nn.SmoothL1Loss(beta=0.5, reduction='mean')
criterion_clf = nn.CrossEntropyLoss(ignore_index=-100)

# IMPORTANT: reset the global optimizer used by run_epoch/train_model; otherwise they would
# keep stepping the multi-task model's parameters.
optimizer = torch.optim.Adam([
    {'params': clf_model.a3tgcn.parameters(), 'lr': LR_TRUNK, 'weight_decay': WEIGHT_DECAY},
    {'params': clf_model.classifier.parameters(), 'lr': LR_CLF,   'weight_decay': WEIGHT_DECAY},
    {'params': clf_model.regressor.parameters(),  'lr': LR_REG,   'weight_decay': 0.0},
])

In [None]:
# Train classifier-only (no regression term, no warm-up)
# Train classifier-only (no regression term, no warm-up). The best-val-loss checkpoint is
# written to models/best_classifier.pt.
print('--- Training Classifier-Only ---')
clf_best = train_model(
    clf_model, clf_train_loader, clf_val_loader,
    do_regression   = False,      # <-- key difference
    pretrain_epochs = 0,          # no pretrain when regression is off
    alpha_clf       = 1.0,        # full weight on CE
    checkpoint_path = "/content/drive/MyDrive/Final Project/models/best_classifier.pt",
)

## Evaluating models

In [None]:
## loading the best multi task
print('--- Load best MULTI-TASK for evaluation ---')

# Build val loader with orbits (the rebuilt train loader is discarded)
_, mt_val_loader = build_dataloaders(
    features_csv      = 'data/selected_features_windows_scaled.csv',
    classes_csv       = 'data/txs_classes.csv',
    verts_parquet     = 'data/verts_int.parquet',
    batch_size        = BATCH_SIZE,
    val_windows_count = VAL_WINDOWS_COUNT,
    include_orbits    = True,    # key difference
    worker_init_fn    = seed_worker,
    generator         = g,
)

# Recreate model; NUM_ORBITS is reused from the earlier multi-task cells
_sample = next(iter(mt_val_loader))
IN_FEATS   = _sample['x'].shape[-1]
NUM_CLASSES= 2

mt_model = MultiTaskA3TGCN(
    in_channels     = IN_FEATS,
    hidden_channels = HIDDEN_DIM,
    num_orbits      = NUM_ORBITS,
    num_classes     = NUM_CLASSES,
    periods         = PERIODS,
    dropout         = 0.1,
).to(device)

# Load weights (trusted, self-written checkpoint: weights_only=False allows full unpickling)
ckpt_mt = torch.load("/content/drive/MyDrive/Final Project/models/best_multi.pt",
                     map_location=device, weights_only=False)
mt_model.load_state_dict(ckpt_mt["model_state"], strict=True)
mt_model.eval()

# Point the generic names used by the evaluation cell at this model/loader
model      = mt_model
val_loader = mt_val_loader

# Use these in your run_epoch call:
EVAL_DO_REG = True
EVAL_ALPHA  = LOSS_ALPHA   # your hyperparam (e.g., 0.5)


In [None]:
from sklearn.metrics import RocCurveDisplay, PrecisionRecallDisplay

# Aggregate validation metrics for the loaded multi-task model
va = run_epoch(model, val_loader, training=False, do_regression=EVAL_DO_REG, alpha_clf=EVAL_ALPHA)

print(f"MT — Val: loss={va['loss']:.4f}, AUROC={va['auroc']:.3f}, AUPRC={va['auprc']:.3f}, F1={va['f1']:.3f}")

# Collect class-1 probabilities on labelled validation nodes for the curves / confusion matrix
model.eval()
all_true, all_pred = [], []
with torch.no_grad():
    for batch in val_loader:
        x      = batch['x'].squeeze(0).to(device)
        y_clf0 = batch['y_clf'].squeeze(0).to(device)
        prev   = batch['prev_orbits']
        prev   = (prev.squeeze(0).to(device) if isinstance(prev, torch.Tensor) else None)

        # Same remap as run_epoch: 1 -> 0, 2 -> 1, everything else masked
        y_clf = torch.full_like(y_clf0, -100); y_clf[y_clf0==1]=0; y_clf[y_clf0==2]=1

        id_map = {int(nid): i for i, nid in enumerate(batch['int_ids'])}
        w0     = int(batch['window_start'].item())
        ei, ew_raw = get_edge_data(w0, id_map, edges_ts)
        if ei.numel()==0 or ew_raw.numel()==0:
            continue
        ew = (ew_raw - ew_raw.min()) / (ew_raw.max() - ew_raw.min() + 1e-6)

        y_reg_pred, logits = model(x, ei.to(device), ew.to(device), prev_orbits=prev)
        mask = (y_clf != -100)
        if mask.sum() > 0:
            all_true.extend(y_clf[mask].cpu().numpy().tolist())
            all_pred.extend(F.softmax(logits[mask], dim=1)[:,1].cpu().numpy().tolist())

# ROC / PR curves are only defined when both classes are present
if len(all_true)>0 and len(set(all_true))>1:
    fpr, tpr, _ = roc_curve(all_true, all_pred)
    plt.figure(); plt.plot(fpr,tpr); plt.plot([0,1],[0,1],'--'); plt.title('MT ROC'); plt.xlabel('FPR'); plt.ylabel('TPR'); plt.show()

    prec, rec, _ = precision_recall_curve(all_true, all_pred)
    plt.figure(); plt.plot(rec,prec); plt.title('MT PR'); plt.xlabel('Recall'); plt.ylabel('Precision'); plt.show()

# labels=[0, 1] keeps the matrix 2x2 even when only one class occurs in truth and predictions
cm = (confusion_matrix(all_true, (np.array(all_pred)>=0.5).astype(int), labels=[0, 1])
      if len(all_true)>0 else np.array([[0,0],[0,0]]))
plt.figure(figsize=(5,4)); sns.heatmap(cm, annot=True, fmt='d'); plt.title('MT Confusion Matrix'); plt.xlabel('Pred'); plt.ylabel('True'); plt.tight_layout(); plt.show()

In [None]:
## loading best classification
print('--- Load best CLASSIFIER for evaluation ---')

# Build val loader without orbit targets (the rebuilt train loader is discarded)
_, clf_val_loader = build_dataloaders(
    features_csv      = 'data/selected_features_windows_scaled.csv',
    classes_csv       = 'data/txs_classes.csv',
    verts_parquet     = 'data/verts_int.parquet',
    batch_size        = BATCH_SIZE,
    val_windows_count = VAL_WINDOWS_COUNT,
    include_orbits    = False,   # key difference
    worker_init_fn    = seed_worker,
    generator         = g,
)

# Recreate model (same dims you trained with)
_sample = next(iter(clf_val_loader))
IN_FEATS   = _sample['x'].shape[-1]
NUM_CLASSES= 2

clf_model = MultiTaskA3TGCN(
    in_channels     = IN_FEATS,
    hidden_channels = HIDDEN_DIM,
    num_orbits      = NUM_ORBITS,   # ok: prev_orbits will be zero-filled
    num_classes     = NUM_CLASSES,
    periods         = PERIODS,
    dropout         = 0.1,
).to(device)

# Load weights (trusted, self-written checkpoint: weights_only=False allows full unpickling)
ckpt_clf = torch.load("/content/drive/MyDrive/Final Project/models/best_classifier.pt",
                      map_location=device, weights_only=False)
clf_model.load_state_dict(ckpt_clf["model_state"], strict=True)
clf_model.eval()

# Point the generic names used by the evaluation cell at this model/loader
model      = clf_model
val_loader = clf_val_loader

# Use these in your run_epoch call:
EVAL_DO_REG = False
EVAL_ALPHA  = 1.0


In [None]:
# Aggregate validation metrics for the loaded classifier-only model
va = run_epoch(model, val_loader, training=False, do_regression=EVAL_DO_REG, alpha_clf=EVAL_ALPHA)

# This cell evaluates the classifier-only model, so label it "CL" (was mislabelled "MT")
print(f"CL — Val: loss={va['loss']:.4f}, AUROC={va['auroc']:.3f}, AUPRC={va['auprc']:.3f}, F1={va['f1']:.3f}")

# Collect class-1 probabilities on labelled validation nodes for the curves / confusion matrix
model.eval()
all_true, all_pred = [], []
with torch.no_grad():
    for batch in val_loader:
        x      = batch['x'].squeeze(0).to(device)
        y_clf0 = batch['y_clf'].squeeze(0).to(device)
        prev   = batch['prev_orbits']
        prev   = (prev.squeeze(0).to(device) if isinstance(prev, torch.Tensor) else None)

        # Same remap as run_epoch: 1 -> 0, 2 -> 1, everything else masked
        y_clf = torch.full_like(y_clf0, -100); y_clf[y_clf0==1]=0; y_clf[y_clf0==2]=1

        id_map = {int(nid): i for i, nid in enumerate(batch['int_ids'])}
        w0     = int(batch['window_start'].item())
        ei, ew_raw = get_edge_data(w0, id_map, edges_ts)
        if ei.numel()==0 or ew_raw.numel()==0:
            continue
        ew = (ew_raw - ew_raw.min()) / (ew_raw.max() - ew_raw.min() + 1e-6)

        y_reg_pred, logits = model(x, ei.to(device), ew.to(device), prev_orbits=prev)
        mask = (y_clf != -100)
        if mask.sum() > 0:
            all_true.extend(y_clf[mask].cpu().numpy().tolist())
            all_pred.extend(F.softmax(logits[mask], dim=1)[:,1].cpu().numpy().tolist())

# ROC / PR curves are only defined when both classes are present
if len(all_true)>0 and len(set(all_true))>1:
    fpr, tpr, _ = roc_curve(all_true, all_pred)
    plt.figure(); plt.plot(fpr,tpr); plt.plot([0,1],[0,1],'--'); plt.title('CL ROC'); plt.xlabel('FPR'); plt.ylabel('TPR'); plt.show()

    prec, rec, _ = precision_recall_curve(all_true, all_pred)
    plt.figure(); plt.plot(rec,prec); plt.title('CL PR'); plt.xlabel('Recall'); plt.ylabel('Precision'); plt.show()

# labels=[0, 1] keeps the matrix 2x2 even when only one class occurs in truth and predictions
cm = (confusion_matrix(all_true, (np.array(all_pred)>=0.5).astype(int), labels=[0, 1])
      if len(all_true)>0 else np.array([[0,0],[0,0]]))
plt.figure(figsize=(5,4)); sns.heatmap(cm, annot=True, fmt='d'); plt.title('CL Confusion Matrix'); plt.xlabel('Pred'); plt.ylabel('True'); plt.tight_layout(); plt.show()