In [6]:
# ======================================================
# FINAL MEMORY-SAFE SPATIO-TEMPORAL TGNN
# Proper Temporal Split + Chunked Training
# ======================================================

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_recall_curve, auc

# ---------------- CONFIG ----------------
DATA_DIR = "/kaggle/input/firms-01"
EPOCHS = 10
LR = 1e-3
TRAIN_SPLIT = 0.8
CHUNK = 14  # Temporal chunk size (Truncated BPTT)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Using device:", DEVICE)
print("=" * 60)


# ======================================================
# LOAD TILE SEQUENCES
# ======================================================

tiles = []

for fname in sorted(os.listdir(DATA_DIR)):
    if fname.endswith(".npz"):
        path = os.path.join(DATA_DIR, fname)
        data = np.load(path)

        X = torch.tensor(data["X"], dtype=torch.float32)
        y = torch.tensor(data["y"], dtype=torch.float32)
        edge_index = torch.tensor(data["edge_index"], dtype=torch.long)

        tiles.append((X, y, edge_index))

print("Total tiles:", len(tiles))


# ======================================================
# GRAPH CONV
# ======================================================

class GraphConv(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x, edge_index):
        row, col = edge_index
        agg = torch.zeros_like(x)
        agg.index_add_(0, row, x[col])
        return self.linear(x + agg)


# ======================================================
# TGNN MODEL
# ======================================================

class TGNN(nn.Module):
    def __init__(self,
                 node_features=6,
                 hidden_dim=32,
                 gru_hidden_dim=64,
                 num_gnn_layers=2):
        super().__init__()

        self.input_proj = nn.Linear(node_features, hidden_dim)

        self.gnn_layers = nn.ModuleList(
            [GraphConv(hidden_dim, hidden_dim)
             for _ in range(num_gnn_layers)]
        )

        self.batch_norms = nn.ModuleList(
            [nn.BatchNorm1d(hidden_dim)
             for _ in range(num_gnn_layers)]
        )

        self.gru = nn.GRU(
            input_size=hidden_dim,
            hidden_size=gru_hidden_dim,
            num_layers=1,
            batch_first=True
        )

        self.fc = nn.Linear(gru_hidden_dim, 1)

    def forward(self, X_seq, edge_index, hidden=None):

        T, N, _ = X_seq.shape
        spatial_outputs = []

        for t in range(T):
            x = X_seq[t]

            x = F.relu(self.input_proj(x))

            for conv, bn in zip(self.gnn_layers, self.batch_norms):
                x = F.relu(bn(conv(x, edge_index)))

            spatial_outputs.append(x)

        H = torch.stack(spatial_outputs)     # (T, N, hidden)
        H = H.permute(1, 0, 2)               # (N, T, hidden)

        gru_out, hidden = self.gru(H, hidden)

        logits = self.fc(gru_out).squeeze(-1)  # (N, T)
        logits = logits.permute(1, 0)          # (T, N)

        return logits, hidden


# ======================================================
# CLASS IMBALANCE COMPUTATION
# ======================================================

all_labels = []

for X_seq, y_seq, _ in tiles:
    all_labels.append(y_seq.reshape(-1))

all_labels = torch.cat(all_labels)
pos_ratio = all_labels.mean().item()

print("Positive ratio:", pos_ratio)

pos_weight = torch.tensor([(1 - pos_ratio) / pos_ratio]).to(DEVICE)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)


# ======================================================
# MODEL & OPTIMIZER
# ======================================================

model = TGNN().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)


# ======================================================
# TRAINING LOOP (Tile-by-Tile)
# ======================================================

best_val_loss = float('inf')

for epoch in range(EPOCHS):

    model.train()
    train_loss = 0

    for X_seq, y_seq, edge_index in tiles:

        X_seq = X_seq.to(DEVICE)
        y_seq = y_seq.to(DEVICE)
        edge_index = edge_index.to(DEVICE)

        # -------- Temporal Split (NO LEAKAGE) --------
        T = X_seq.shape[0]
        split = int(TRAIN_SPLIT * T)

        X_train = X_seq[:split]
        y_train = y_seq[:split]

        X_val = X_seq[split:]
        y_val = y_seq[split:]

        hidden = None
        optimizer.zero_grad()
        total_loss = 0

        # -------- Truncated BPTT --------
        for t in range(0, X_train.shape[0], CHUNK):

            X_chunk = X_train[t:t+CHUNK]
            y_chunk = y_train[t:t+CHUNK]

            logits, hidden = model(X_chunk, edge_index, hidden)

            loss = criterion(
                logits.reshape(-1),
                y_chunk.reshape(-1)
            )

            loss.backward()
            hidden = hidden.detach()

            total_loss += loss.item()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        train_loss += total_loss

    train_loss /= len(tiles)

    # ======================================================
    # VALIDATION
    # ======================================================

    model.eval()
    val_loss = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():

        for X_seq, y_seq, edge_index in tiles:

            X_seq = X_seq.to(DEVICE)
            y_seq = y_seq.to(DEVICE)
            edge_index = edge_index.to(DEVICE)

            T = X_seq.shape[0]
            split = int(TRAIN_SPLIT * T)

            X_val = X_seq[split:]
            y_val = y_seq[split:]

            hidden = None

            for t in range(0, X_val.shape[0], CHUNK):

                X_chunk = X_val[t:t+CHUNK]
                y_chunk = y_val[t:t+CHUNK]

                logits, hidden = model(X_chunk, edge_index, hidden)

                loss = criterion(
                    logits.reshape(-1),
                    y_chunk.reshape(-1)
                )

                val_loss += loss.item()

                probs = torch.sigmoid(logits)
                all_preds.append(probs.cpu().numpy().reshape(-1))
                all_targets.append(y_chunk.cpu().numpy().reshape(-1))

                hidden = hidden.detach()

    val_loss /= len(tiles)

    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)

    precision, recall, _ = precision_recall_curve(all_targets, all_preds)
    pr_auc = auc(recall, precision)

    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss:   {val_loss:.4f}")
    print(f"Val PR-AUC: {pr_auc:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_tgnn.pt")
        print("✓ Best model saved")

print("\nTraining complete.")
print("=" * 60)

Using device: cuda
Total tiles: 130
Positive ratio: 0.004165706690400839

Epoch 1/10
Train Loss: 21.1994
Val Loss:   4.6897
Val PR-AUC: 0.0110
✓ Best model saved

Epoch 2/10
Train Loss: 14.4351
Val Loss:   5.0512
Val PR-AUC: 0.0124

Epoch 3/10
Train Loss: 13.7306
Val Loss:   4.8952
Val PR-AUC: 0.0133

Epoch 4/10
Train Loss: 13.3922
Val Loss:   4.0714
Val PR-AUC: 0.0139
✓ Best model saved

Epoch 5/10
Train Loss: 13.1973
Val Loss:   4.2223
Val PR-AUC: 0.0116

Epoch 6/10
Train Loss: 13.0181
Val Loss:   4.3511
Val PR-AUC: 0.0133

Epoch 7/10
Train Loss: 12.8538
Val Loss:   4.2940
Val PR-AUC: 0.0123

Epoch 8/10
Train Loss: 12.7519
Val Loss:   4.3835
Val PR-AUC: 0.0126

Epoch 9/10
Train Loss: 12.6334
Val Loss:   4.2684
Val PR-AUC: 0.0139

Epoch 10/10
Train Loss: 12.5690
Val Loss:   4.4889
Val PR-AUC: 0.0155

Training complete.


In [8]:
# ======================================================
# CORRECTED SPATIO-TEMPORAL TGNN (FIXED SCHEDULER)
# Proper Tile-Based Split + Memory-Safe Training
# ======================================================

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_recall_curve, auc, roc_auc_score

# ---------------- CONFIG ----------------
DATA_DIR = "/kaggle/input/firms-01"
EPOCHS = 10
LR = 1e-3
TRAIN_SPLIT = 0.8
CHUNK = 14  # Truncated BPTT window
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PATIENCE = 3

print("Using device:", DEVICE)
print("=" * 60)


# ======================================================
# FOCAL LOSS (Better for imbalance)
# ======================================================

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
    
    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        bce_loss = F.binary_cross_entropy_with_logits(
            logits, targets, reduction='none'
        )
        p_t = probs * targets + (1 - probs) * (1 - targets)
        focal_weight = (1 - p_t) ** self.gamma
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        focal_loss = alpha_t * focal_weight * bce_loss
        return focal_loss.mean()


# ======================================================
# LOAD TILE SEQUENCES
# ======================================================

print("\nLoading tiles...")
tiles = []

for fname in sorted(os.listdir(DATA_DIR)):
    if fname.endswith(".npz"):
        path = os.path.join(DATA_DIR, fname)
        data = np.load(path)

        X = torch.tensor(data["X"], dtype=torch.float32)
        y = torch.tensor(data["y"], dtype=torch.float32)
        edge_index = torch.tensor(data["edge_index"], dtype=torch.long)

        tiles.append((X, y, edge_index))

print(f"Total tiles loaded: {len(tiles)}")

# ======================================================
# PROPER TILE-BASED TRAIN/VAL SPLIT (NO LEAKAGE!)
# ======================================================

split_idx = int(TRAIN_SPLIT * len(tiles))
train_tiles = tiles[:split_idx]
val_tiles = tiles[split_idx:]

print(f"\nDataset split:")
print(f"  Train tiles: {len(train_tiles)}")
print(f"  Val tiles:   {len(val_tiles)}")

# ======================================================
# ANALYZE CLASS DISTRIBUTION
# ======================================================

print("\nAnalyzing class distribution...")
all_labels = []
for X_seq, y_seq, _ in tiles:
    all_labels.append(y_seq.reshape(-1))

all_labels = torch.cat(all_labels)
pos_ratio = all_labels.mean().item()
neg_ratio = 1 - pos_ratio

print(f"Positive ratio: {pos_ratio:.4%}")
print(f"Negative ratio: {neg_ratio:.4%}")
print(f"Imbalance ratio: {neg_ratio/pos_ratio:.2f}:1")


# ======================================================
# GRAPH CONV
# ======================================================

class GraphConv(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x, edge_index):
        row, col = edge_index
        agg = torch.zeros_like(x)
        agg.index_add_(0, row, x[col])
        return self.linear(x + agg)


# ======================================================
# TGNN MODEL
# ======================================================

class TGNN(nn.Module):
    def __init__(self,
                 node_features=6,
                 hidden_dim=64,
                 gru_hidden_dim=128,
                 num_gnn_layers=2,
                 dropout=0.3):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.gru_hidden_dim = gru_hidden_dim

        self.input_proj = nn.Linear(node_features, hidden_dim)

        self.gnn_layers = nn.ModuleList(
            [GraphConv(hidden_dim, hidden_dim)
             for _ in range(num_gnn_layers)]
        )

        self.batch_norms = nn.ModuleList(
            [nn.BatchNorm1d(hidden_dim)
             for _ in range(num_gnn_layers)]
        )

        self.gru = nn.GRU(
            input_size=hidden_dim,
            hidden_size=gru_hidden_dim,
            num_layers=1,
            batch_first=True
        )

        self.fc1 = nn.Linear(gru_hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 32)
        self.fc3 = nn.Linear(32, 1)
        
        self.dropout = nn.Dropout(dropout)

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, X_seq, edge_index, hidden=None):
        """
        Args:
            X_seq: [T, N, F] - Temporal sequence of node features
            edge_index: [2, E] - Graph structure
            hidden: [1, N, H] - GRU hidden state
        
        Returns:
            logits: [T, N] - Predictions for each timestep
            hidden: [1, N, H] - Updated hidden state
        """
        T, N, _ = X_seq.shape
        spatial_outputs = []

        # Spatial processing for each timestep
        for t in range(T):
            x = X_seq[t]
            
            x = F.relu(self.input_proj(x))
            x = self.dropout(x)

            # GNN layers with residual connections
            for i, (conv, bn) in enumerate(zip(self.gnn_layers, self.batch_norms)):
                residual = x
                x = conv(x, edge_index)
                x = bn(x)
                x = F.relu(x)
                x = self.dropout(x)
                
                if i > 0:
                    x = x + residual

            spatial_outputs.append(x)

        # Stack temporal dimension
        H = torch.stack(spatial_outputs)  # [T, N, hidden]
        H = H.permute(1, 0, 2)            # [N, T, hidden]

        # Initialize hidden state if needed
        if hidden is None:
            hidden = torch.zeros(
                1, N, self.gru_hidden_dim,
                device=X_seq.device, dtype=X_seq.dtype
            )

        # Temporal processing with GRU
        gru_out, hidden = self.gru(H, hidden)

        # Prediction head
        x = gru_out  # [N, T, gru_hidden_dim]
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        
        logits = self.fc3(x).squeeze(-1)  # [N, T]
        logits = logits.permute(1, 0)     # [T, N]

        return logits, hidden


# ======================================================
# TRAINING FUNCTION
# ======================================================

def train_epoch(model, tiles, criterion, optimizer, device, chunk_size):
    """Train for one epoch on all tiles"""
    model.train()
    total_loss = 0
    num_chunks = 0

    for X_seq, y_seq, edge_index in tiles:
        X_seq = X_seq.to(device)
        y_seq = y_seq.to(device)
        edge_index = edge_index.to(device)

        T = X_seq.shape[0]
        hidden = None

        # Process tile in chunks (Truncated BPTT)
        for t_start in range(0, T, chunk_size):
            t_end = min(t_start + chunk_size, T)
            
            X_chunk = X_seq[t_start:t_end]
            y_chunk = y_seq[t_start:t_end]

            optimizer.zero_grad()

            logits, hidden = model(X_chunk, edge_index, hidden)

            loss = criterion(
                logits.reshape(-1),
                y_chunk.reshape(-1)
            )

            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()

            total_loss += loss.item()
            num_chunks += 1
            
            # Detach hidden state to prevent backprop through entire history
            if hidden is not None:
                hidden = hidden.detach()

    return total_loss / num_chunks


# ======================================================
# EVALUATION FUNCTION
# ======================================================

def evaluate(model, tiles, criterion, device, chunk_size):
    """Evaluate model on validation tiles"""
    model.eval()
    total_loss = 0
    num_chunks = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for X_seq, y_seq, edge_index in tiles:
            X_seq = X_seq.to(device)
            y_seq = y_seq.to(device)
            edge_index = edge_index.to(device)

            T = X_seq.shape[0]
            hidden = None

            # Process in chunks
            for t_start in range(0, T, chunk_size):
                t_end = min(t_start + chunk_size, T)
                
                X_chunk = X_seq[t_start:t_end]
                y_chunk = y_seq[t_start:t_end]

                logits, hidden = model(X_chunk, edge_index, hidden)

                loss = criterion(
                    logits.reshape(-1),
                    y_chunk.reshape(-1)
                )

                total_loss += loss.item()
                num_chunks += 1

                # Collect predictions
                probs = torch.sigmoid(logits)
                all_preds.append(probs.cpu().numpy().reshape(-1))
                all_targets.append(y_chunk.cpu().numpy().reshape(-1))
                
                if hidden is not None:
                    hidden = hidden.detach()

    # Concatenate all predictions
    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)

    # Compute metrics
    avg_loss = total_loss / num_chunks
    
    try:
        roc_auc = roc_auc_score(all_targets, all_preds)
    except:
        roc_auc = 0.0
    
    try:
        precision, recall, _ = precision_recall_curve(all_targets, all_preds)
        pr_auc = auc(recall, precision)
    except:
        pr_auc = 0.0

    return {
        'loss': avg_loss,
        'roc_auc': roc_auc,
        'pr_auc': pr_auc
    }


# ======================================================
# MODEL & OPTIMIZER SETUP
# ======================================================

print("\nInitializing model...")
model = TGNN(
    node_features=6,
    hidden_dim=64,
    gru_hidden_dim=128,
    num_gnn_layers=2,
    dropout=0.3
).to(DEVICE)

num_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {num_params:,}")

# Choose loss function based on imbalance
if pos_ratio < 0.01:
    print("\nUsing Focal Loss (extreme imbalance)")
    criterion = FocalLoss(alpha=0.25, gamma=2.0)
else:
    print("\nUsing Weighted BCE Loss")
    pos_weight = torch.tensor([neg_ratio / pos_ratio]).to(DEVICE)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Fixed: Removed 'verbose' parameter (not supported in newer PyTorch)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2
)


# ======================================================
# TRAINING LOOP
# ======================================================

print("\n" + "=" * 60)
print("STARTING TRAINING")
print("=" * 60)

best_val_loss = float('inf')
patience_counter = 0
current_lr = LR

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print("-" * 60)

    # Training
    train_loss = train_epoch(
        model, train_tiles, criterion, optimizer, DEVICE, CHUNK
    )

    # Validation
    val_metrics = evaluate(
        model, val_tiles, criterion, DEVICE, CHUNK
    )

    # Learning rate scheduling
    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_metrics['loss'])
    new_lr = optimizer.param_groups[0]['lr']
    
    # Manual verbose output for LR changes
    if old_lr != new_lr:
        print(f"Learning rate reduced: {old_lr:.6f} -> {new_lr:.6f}")

    # Print metrics
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss:   {val_metrics['loss']:.4f}")
    print(f"Val ROC-AUC: {val_metrics['roc_auc']:.4f}")
    print(f"Val PR-AUC:  {val_metrics['pr_auc']:.4f}")

    # Save best model
    if val_metrics['loss'] < best_val_loss:
        best_val_loss = val_metrics['loss']
        patience_counter = 0
        
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_metrics['loss'],
            'val_roc_auc': val_metrics['roc_auc'],
            'val_pr_auc': val_metrics['pr_auc'],
        }, 'best_tgnn.pt')
        
        print("✓ Best model saved!")
    else:
        patience_counter += 1
        print(f"No improvement. Patience: {patience_counter}/{PATIENCE}")

    # Early stopping
    if patience_counter >= PATIENCE:
        print(f"\nEarly stopping triggered after {epoch+1} epochs")
        break

print("\n" + "=" * 60)
print("TRAINING COMPLETE")
print("=" * 60)
print(f"Best validation loss: {best_val_loss:.4f}")

Using device: cuda

Loading tiles...
Total tiles loaded: 130

Dataset split:
  Train tiles: 104
  Val tiles:   26

Analyzing class distribution...
Positive ratio: 0.4166%
Negative ratio: 99.5834%
Imbalance ratio: 239.06:1

Initializing model...
Total parameters: 93,889

Using Focal Loss (extreme imbalance)

STARTING TRAINING

Epoch 1/10
------------------------------------------------------------
Train Loss: 0.0031
Val Loss:   0.0020
Val ROC-AUC: 0.8522
Val PR-AUC:  0.0221
✓ Best model saved!

Epoch 2/10
------------------------------------------------------------
Train Loss: 0.0024
Val Loss:   0.0020
Val ROC-AUC: 0.8595
Val PR-AUC:  0.0223
✓ Best model saved!

Epoch 3/10
------------------------------------------------------------
Train Loss: 0.0023
Val Loss:   0.0020
Val ROC-AUC: 0.8598
Val PR-AUC:  0.0227
✓ Best model saved!

Epoch 4/10
------------------------------------------------------------
Train Loss: 0.0022
Val Loss:   0.0020
Val ROC-AUC: 0.8589
Val PR-AUC:  0.0232
No improv