# Phase 4: Post-Training Analysis

This notebook analyzes the trained Lead Scout Model to verify:
1. **Loss Convergence** - Are train/val losses decreasing?
2. **Evaluation Metrics** - Precision, Recall, F1-score, and ROC/AUC
3. **Attention Patterns** - Before vs After training comparison
4. **Confusion Matrix** - False positives/negatives analysis

In [None]:
# ============================================================
# Cell 1: Setup
# ============================================================
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc

from src.tokenizer import SalesTokenizer
from src.model.lead_scout import LeadScoutModel
from src.data.dataset import LeadDataset
from torch.utils.data import DataLoader, random_split

# Fix the NumPy and PyTorch global RNG seeds so repeated runs start from the same state
np.random.seed(42)
torch.manual_seed(42)

print("‚úÖ Imports successful")

## 1. Load Training History & Plot Loss Curves

In [None]:
# ============================================================
# Cell 2: Load Training History
# ============================================================
# Reads the JSON training history; `history` is None when the file is missing,
# which later cells use to skip history-dependent work.
history_path = '../checkpoints/training_history.json'

try:
    with open(history_path, 'r') as history_file:
        history = json.load(history_file)
except FileNotFoundError:
    history = None
    print("‚ö†Ô∏è No training history found. Run train.py first.")
else:
    # Epoch count is taken from the length of the 'train_loss' series
    print(f"‚úÖ Loaded training history with {len(history['train_loss'])} epochs")

In [None]:
# ============================================================
# Cell 3: Plot Training Curves
# ============================================================
# Draws a 2x2 grid of per-epoch curves (loss, accuracy, precision/recall, F1)
# from `history`, then prints the last-epoch values. Skipped entirely when no
# history was loaded.
if history:
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    epochs = range(1, len(history['train_loss']) + 1)

    # One entry per panel: (grid position, y-axis label, title, curves), where
    # each curve is (history key, matplotlib format string, legend label).
    panel_specs = [
        ((0, 0), 'Loss', 'Loss Curve (Lower is Better)', [
            ('train_loss', 'b-', 'Train Loss'),
            ('val_loss', 'r-', 'Val Loss'),
        ]),
        ((0, 1), 'Accuracy', 'Accuracy Curve (Higher is Better)', [
            ('train_acc', 'b-', 'Train Acc'),
            ('val_acc', 'r-', 'Val Acc'),
        ]),
        ((1, 0), 'Score', 'Precision & Recall Curves', [
            ('train_precision', 'b-', 'Train Precision'),
            ('val_precision', 'r-', 'Val Precision'),
            ('train_recall', 'b--', 'Train Recall'),
            ('val_recall', 'r--', 'Val Recall'),
        ]),
        ((1, 1), 'F1 Score', 'F1 Score Curve (Higher is Better)', [
            ('train_f1', 'b-', 'Train F1'),
            ('val_f1', 'r-', 'Val F1'),
        ]),
    ]

    for grid_pos, y_label, panel_title, curves in panel_specs:
        panel = axes[grid_pos]
        for metric_key, line_fmt, legend_label in curves:
            panel.plot(epochs, history[metric_key], line_fmt,
                       label=legend_label, linewidth=2)
        panel.set_xlabel('Epoch')
        panel.set_ylabel(y_label)
        panel.set_title(panel_title)
        panel.legend()
        panel.grid(True, alpha=0.3)

    # Annotate the loss panel only when the training loss actually went down
    # between the first and last epoch (needs at least two epochs).
    if len(history['train_loss']) > 1:
        loss_decrease = history['train_loss'][0] - history['train_loss'][-1]
        if loss_decrease > 0:
            axes[0, 0].annotate(f'Loss decreased by {loss_decrease:.4f}',
                                xy=(0.5, 0.95), xycoords='axes fraction',
                                fontsize=10, color='green')

    plt.tight_layout()
    plt.show()

    # Last-epoch value of a history series
    final_value = lambda metric_key: history[metric_key][-1]

    print("\n=== Final Epoch Metrics ===")
    print(f"Train - Loss: {final_value('train_loss'):.4f}, Acc: {final_value('train_acc'):.4f}, "
          f"P: {final_value('train_precision'):.4f}, R: {final_value('train_recall'):.4f}, F1: {final_value('train_f1'):.4f}")
    print(f"Val   - Loss: {final_value('val_loss'):.4f}, Acc: {final_value('val_acc'):.4f}, "
          f"P: {final_value('val_precision'):.4f}, R: {final_value('val_recall'):.4f}, F1: {final_value('val_f1'):.4f}")

## 2. Load Trained Model & Evaluate on Validation Data

In [None]:
# ============================================================
# Cell 4: Load Trained Model
# ============================================================
# Builds the model, then tries to restore weights from the checkpoint on CPU.
# `trained_model_available` gates every later evaluation cell.
model_path = '../checkpoints/lead_scout_best.pth'

# NOTE(review): these hyperparameters presumably have to match the ones used
# when the checkpoint was saved - confirm against train.py.
model = LeadScoutModel(vocab_size=17, embed_dim=128, num_heads=4,
                       num_layers=3, dropout=0.1)

try:
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
except FileNotFoundError:
    trained_model_available = False
    print("‚ö†Ô∏è No trained model found. Run train.py first.")
else:
    # Switch to inference mode before any evaluation
    model.eval()
    print(f"‚úÖ Loaded trained model from {model_path}")
    trained_model_available = True

In [None]:
# ============================================================
# Cell 5: Prepare Validation Data
# ============================================================
# Rebuilds an 80/20 random split of the lead dataset and wraps the 20% part
# in an unshuffled DataLoader.
if trained_model_available:
    dataset = LeadDataset('../data/leads_raw.csv')

    # Re-seed right before splitting so the partition is reproducible.
    # NOTE(review): this only matches the training split if train.py seeds
    # and splits the same way - confirm.
    torch.manual_seed(42)
    n_samples = len(dataset)
    train_size = int(0.8 * n_samples)
    val_size = n_samples - train_size
    _, val_dataset = random_split(dataset, [train_size, val_size])

    val_loader = DataLoader(val_dataset, shuffle=False, batch_size=32)
    print(f"‚úÖ Loaded {len(val_dataset)} validation samples")

In [None]:
# ============================================================
# Cell 6: Run Inference & Collect Predictions
# ============================================================
# Runs the model over every validation batch without gradients and gathers
# raw scores, thresholded predictions and labels as flat NumPy arrays.
if trained_model_available:
    prob_values, pred_values, label_values = [], [], []

    with torch.no_grad():
        for batch_tokens, batch_labels in val_loader:
            batch_scores = model(batch_tokens)
            # NOTE(review): thresholding at 0.5 assumes the model already
            # emits probabilities rather than logits - confirm.
            batch_preds = (batch_scores > 0.5).float()

            # Flatten to 1-D so any batch shape contributes one value per sample
            prob_values.extend(batch_scores.reshape(-1).numpy())
            pred_values.extend(batch_preds.reshape(-1).numpy())
            label_values.extend(batch_labels.reshape(-1).numpy())

    all_preds = np.array(pred_values)
    all_labels = np.array(label_values)
    all_probs = np.array(prob_values)

    print(f"‚úÖ Collected {len(all_preds)} predictions")

## 3. Confusion Matrix & Classification Report

In [None]:
# ============================================================
# Cell 7: Confusion Matrix
# ============================================================
# Plots the 2x2 confusion matrix for the validation predictions and prints
# the per-class precision/recall/F1 report.
if trained_model_available:
    class_ids = [0, 1]
    class_names = ['No Reply', 'Reply']

    # Work on integer class ids so they line up with `class_ids` exactly.
    y_true = all_labels.astype(int)
    y_pred = all_preds.astype(int)

    # Pin the label set: without it, a split (or a degenerate model) that
    # contains only one class yields a 1x1 matrix that no longer matches the
    # two tick labels, and classification_report rejects the two target names.
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                xticklabels=class_names,
                yticklabels=class_names)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title('Confusion Matrix on Validation Set')
    plt.tight_layout()
    plt.show()

    # Print classification report
    print("\n=== Classification Report ===")
    print(classification_report(y_true, y_pred,
                                labels=class_ids,
                                target_names=class_names))

## 4. ROC Curve & AUC Score

In [None]:
# ============================================================
# Cell 8: ROC Curve
# ============================================================
# Computes the ROC curve and its area from the raw validation scores, plots it
# against the chance diagonal, and prints a coarse verdict on the AUC.
if trained_model_available:
    fpr, tpr, thresholds = roc_curve(all_labels, all_probs)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})')
    # Diagonal = expected curve of a classifier that guesses at random
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random Classifier')
    ax.set(
        xlim=[0.0, 1.0],
        ylim=[0.0, 1.05],
        xlabel='False Positive Rate',
        ylabel='True Positive Rate',
        title='Receiver Operating Characteristic (ROC) Curve',
    )
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    print(f"\nAUC Score: {roc_auc:.4f}")
    # Three bands: above 0.7, above 0.5, and everything else
    auc_verdict = (
        "‚úÖ Model has learned meaningful patterns (AUC > 0.7)" if roc_auc > 0.7
        else "‚ö†Ô∏è Model is better than random, but could improve" if roc_auc > 0.5
        else "‚ùå Model is not better than random guessing"
    )
    print(auc_verdict)

## 5. Attention Pattern Analysis (Before vs After Training)

In [None]:
# ============================================================
# Cell 9: Get Attention Weights from Trained Model
# ============================================================
# Captures the attention weights produced by the first transformer block of
# the trained model for the first five validation samples.
if trained_model_available:
    # Get a sample batch
    sample_tokens, sample_labels = next(iter(val_loader))
    sample_tokens = sample_tokens[:5]  # First 5 samples
    sample_labels = sample_labels[:5]

    # Filled by the hook on every forward pass through the attention layer
    attention_weights = []

    def hook_fn(module, input, output):
        """Store the second element of the attention layer's output.

        NOTE(review): assumes the layer returns (attn_output, attn_weights)
        and that the weights are not None - confirm against the model code.
        """
        attention_weights.append(output[1].detach())

    # Register hook on first transformer block's attention layer
    hook = model.transformer_blocks[0].attention.register_forward_hook(hook_fn)

    try:
        with torch.no_grad():
            _ = model(sample_tokens)
    finally:
        # Always detach the hook, even if the forward pass fails; otherwise a
        # re-run of this cell would stack a second hook on the same layer.
        hook.remove()

    if not attention_weights:
        raise RuntimeError("Attention hook did not fire; no weights were captured")

    trained_attn = attention_weights[0]
    print(f"‚úÖ Captured attention weights: {trained_attn.shape}")

In [None]:
# ============================================================
# Cell 10: Get Attention Weights from Untrained Model
# ============================================================
# Builds a freshly initialised model with the same hyperparameters and
# captures its first-block attention weights on the same sample tokens, as the
# "before training" baseline.
if trained_model_available:
    # Create fresh untrained model
    untrained_model = LeadScoutModel(
        vocab_size=17,
        embed_dim=128,
        num_heads=4,
        num_layers=3,
        dropout=0.1
    )
    untrained_model.eval()

    # Filled by the hook on every forward pass through the attention layer
    untrained_attention_weights = []

    def untrained_hook_fn(module, input, output):
        """Store the second element of the attention layer's output.

        NOTE(review): assumes the layer returns (attn_output, attn_weights)
        and that the weights are not None - confirm against the model code.
        """
        untrained_attention_weights.append(output[1].detach())

    # Register hook
    untrained_hook = untrained_model.transformer_blocks[0].attention.register_forward_hook(untrained_hook_fn)

    try:
        # Forward pass with same tokens as the trained model
        with torch.no_grad():
            _ = untrained_model(sample_tokens)
    finally:
        # Always detach the hook, even if the forward pass fails
        untrained_hook.remove()

    if not untrained_attention_weights:
        raise RuntimeError("Attention hook did not fire; no weights were captured")

    untrained_attn = untrained_attention_weights[0]
    print(f"‚úÖ Captured untrained attention weights: {untrained_attn.shape}")

In [None]:
# ============================================================
# Cell 11: Compare Attention Patterns
# ============================================================
# Side-by-side heatmaps of the first sample's attention matrix from the
# untrained and the trained model, restricted to tokens whose id is not 0.
if trained_model_available:
    # Get token names for visualization
    tokenizer = SalesTokenizer()

    # Visualize first sample
    sample_idx = 0
    token_ids = sample_tokens[sample_idx].numpy()

    # Positions of the tokens to show (id 0 is skipped; presumably padding).
    # Selecting rows/columns by these positions keeps the tick labels aligned
    # with the matrix wherever the skipped tokens sit, not only when they
    # trail the sequence; for trailing padding this equals [:seq_len, :seq_len].
    keep_positions = np.flatnonzero(token_ids != 0)
    token_names = [
        tokenizer.id_to_token.get(int(token_ids[pos]), f'[{token_ids[pos]}]')
        for pos in keep_positions
    ]
    seq_len = len(token_names)
    keep_grid = np.ix_(keep_positions, keep_positions)

    # NOTE(review): assumes the captured weights are (batch, seq, seq) - confirm
    attn_untrained = untrained_attn[sample_idx].numpy()[keep_grid]
    attn_trained = trained_attn[sample_idx].numpy()[keep_grid]

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    heatmap_panels = [
        (axes[0], attn_untrained, 'BEFORE Training (Random Weights)'),
        (axes[1], attn_trained, 'AFTER Training (Learned Weights)'),
    ]
    for panel_ax, attn_matrix, panel_title in heatmap_panels:
        sns.heatmap(attn_matrix, ax=panel_ax, cmap='viridis',
                    xticklabels=token_names, yticklabels=token_names,
                    annot=True, fmt='.2f', cbar=True)
        panel_ax.set_title(panel_title, fontsize=14)
        panel_ax.set_xlabel('Key (Attended To)')
        panel_ax.set_ylabel('Query (Attending)')

    label_text = 'Reply' if sample_labels[sample_idx].item() == 1 else 'No Reply'
    fig.suptitle(f'Attention Pattern Comparison (Label: {label_text})', fontsize=16, y=1.02)

    plt.tight_layout()
    plt.show()

    print("\nüìä Key Observations:")
    print("- BEFORE: Attention is often uniform or focused on special tokens ([START], [END])")
    print("- AFTER: Attention should show more meaningful patterns between features")

## 6. Summary & Recommendations

In [None]:
# ============================================================
# Cell 12: Training Quality Summary
# ============================================================
# Text summary of four checks: training-loss decrease, val/train loss ratio,
# ROC AUC band, and the last-epoch validation precision/recall/F1.
# Relies on `roc_auc` computed in the ROC cell.
if trained_model_available and history:
    rule = "=" * 60
    print(rule)
    print("               TRAINING QUALITY ASSESSMENT")
    print(rule)

    # 1) Did the training loss go down from the first to the last epoch?
    loss_decrease = history['train_loss'][0] - history['train_loss'][-1]
    print("\n1. Loss Convergence:")
    print(
        f"   ‚úÖ Training loss decreased by {loss_decrease:.4f}"
        if loss_decrease > 0
        else "   ‚ùå Training loss did not decrease"
    )

    # 2) Flag overfitting when the final val loss exceeds 1.5x the final train loss
    val_loss_final = history['val_loss'][-1]
    train_loss_final = history['train_loss'][-1]
    print("\n2. Overfitting Check:")
    print(
        "   ‚ö†Ô∏è Potential overfitting (Val loss >> Train loss)"
        if val_loss_final > train_loss_final * 1.5
        else "   ‚úÖ No significant overfitting detected"
    )

    # 3) Same AUC bands as the ROC cell: above 0.7, above 0.5, otherwise
    print("\n3. Model Discriminative Power:")
    print(f"   AUC Score: {roc_auc:.4f}")
    print(
        "   ‚úÖ Good discriminative power" if roc_auc > 0.7
        else "   ‚ö†Ô∏è Moderate - consider more training or hyperparameter tuning" if roc_auc > 0.5
        else "   ‚ùå Poor - model may not have learned meaningful patterns"
    )

    # 4) Validation metrics recorded for the last epoch
    print("\n4. Final Validation Metrics:")
    print(f"   Precision: {history['val_precision'][-1]:.4f}")
    print(f"   Recall:    {history['val_recall'][-1]:.4f}")
    print(f"   F1 Score:  {history['val_f1'][-1]:.4f}")

    print("\n" + rule)