# Model Training: Phishing Brand Classification

This notebook covers the complete model training pipeline for phishing brand classification.

## Objectives
1. Train a deep learning model for brand classification
2. Evaluate model performance with a focus on minimizing false positives
3. Analyze errors and misclassifications
4. Optimize confidence threshold
5. Benchmark inference speed
6. Generate model interpretability visualizations

## Key Metrics
- **Accuracy**: Overall correctness
- **F1 Score**: Balance of precision and recall
- **False Positive Rate for 'others'**: Critical metric: the share of benign ('others') sites that are misclassified as one of the target brands
- **Per-class metrics**: Identify problematic classes
- **Inference speed**: Important for production deployment

In [None]:
# Import required libraries
import os
import sys
import json
import time
from pathlib import Path
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import yaml
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

# Add project root to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

from src.data.dataset import PhishingDataset, create_dataloaders
from src.data.transforms import AlbumentationsTransform, get_train_transforms, get_val_transforms
from src.models.classifier import BrandClassifier, create_model
from src.models.losses import FocalLoss
from src.utils.metrics import (
    calculate_metrics,
    compute_confusion_matrix,
    find_optimal_threshold,
    get_false_positive_analysis,
    evaluate_with_rejection,
)
from src.utils.visualization import (
    plot_confusion_matrix,
    plot_training_curves,
    plot_false_positive_analysis,
    plot_per_class_metrics,
    plot_confidence_distribution,
)

# Configuration: directory layout, every path derived from ``project_root``.
DATA_DIR = project_root / 'data' / 'raw'
PROCESSED_DIR = project_root / 'data' / 'processed'
OUTPUT_DIR = project_root / 'outputs'
FIGURES_DIR, MODELS_DIR = OUTPUT_DIR / 'figures', OUTPUT_DIR / 'models'

# Create the output folders up front (no-op when they already exist) so the
# later savefig / torch.save calls have somewhere to write.
for dir_path in (FIGURES_DIR, MODELS_DIR):
    dir_path.mkdir(parents=True, exist_ok=True)

# Device: CUDA when PyTorch reports it as available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1. Configuration

In [None]:
# Training configuration, kept as one flat dict so it can be printed here and
# stored alongside checkpoints/results unchanged.
CONFIG = dict(
    # Data
    image_size=224,
    batch_size=32,
    num_workers=4,

    # Model (architecture options: resnet50, efficientnet_b0, efficientnet_b3)
    architecture='efficientnet_b0',
    pretrained=True,
    dropout=0.3,

    # Training
    num_epochs=30,
    learning_rate=1e-3,
    weight_decay=1e-4,
    use_amp=True,  # mixed precision training

    # Loss
    use_focal_loss=True,
    focal_gamma=2.0,
    use_class_weights=True,

    # Evaluation
    confidence_threshold=0.85,

    # Random seed
    seed=42,
)

# Class names: ten target brands followed by the catch-all 'others' class.
CLASS_NAMES = [
    'amazon',
    'apple',
    'facebook',
    'google',
    'instagram',
    'linkedin',
    'microsoft',
    'netflix',
    'paypal',
    'twitter',
    'others',
]
NUM_CLASSES = len(CLASS_NAMES)
OTHERS_IDX = CLASS_NAMES.index('others')

print("Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")
print(f"\nNumber of classes: {NUM_CLASSES}", f"Others class index: {OTHERS_IDX}", sep="\n")

# Seed the torch and NumPy global RNGs from the configured seed.
# (np.random.seed stays last so the cell does not echo the torch Generator.)
torch.manual_seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])

## 2. Data Loading

In [None]:
# Load the preprocessed train / val / test split tables (in that order).
train_df, val_df, test_df = (
    pd.read_csv(PROCESSED_DIR / csv_name)
    for csv_name in ('train.csv', 'val.csv', 'test.csv')
)

print(f"Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")

# Transforms: the training pipeline and the validation pipeline come from
# separate factories; both are built for the configured image size.
train_transform = AlbumentationsTransform(get_train_transforms(image_size=CONFIG['image_size']))
val_transform = AlbumentationsTransform(get_val_transforms(image_size=CONFIG['image_size']))

# Dataloaders for all three splits. Note that `use_class_weights` is the flag
# that switches the weighted sampler on here.
train_loader, val_loader, test_loader, class_names = create_dataloaders(
    train_df=train_df,
    val_df=val_df,
    test_df=test_df,
    data_dir=str(DATA_DIR),
    train_transform=train_transform,
    val_transform=val_transform,
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
    use_weighted_sampler=CONFIG['use_class_weights'],
    class_names=CLASS_NAMES,
)

print(
    "\nDataloaders created:",
    f"  Train batches: {len(train_loader)}",
    f"  Val batches: {len(val_loader)}",
    f"  Test batches: {len(test_loader)}",
    sep="\n",
)

## 3. Model Setup

In [None]:
# Build the network from the configured architecture and move it to `device`.
model = create_model(
    architecture=CONFIG['architecture'],
    num_classes=NUM_CLASSES,
    pretrained=CONFIG['pretrained'],
    dropout=CONFIG['dropout'],
).to(device)

print(f"Model: {CONFIG['architecture']}")
print(f"Feature dimension: {model.feature_dim}")

# Single pass over the parameters to count all of them and the trainable ones.
total_params = 0
trainable_params = 0
for param in model.parameters():
    n_elements = param.numel()
    total_params += n_elements
    if param.requires_grad:
        trainable_params += n_elements

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

In [None]:
# Per-class weights come from the training dataset and live on `device`.
train_dataset = train_loader.dataset
class_weights = train_dataset.get_class_weights().to(device)

print("Class weights:")
for name, weight in zip(CLASS_NAMES, class_weights):
    # Flag the benign catch-all class in the printout.
    if name == 'others':
        marker = " <-- benign"
    else:
        marker = ""
    print(f"  {name}: {weight:.3f}{marker}")

# Loss function. The weights are only handed to the loss when class weighting
# is enabled.
# NOTE(review): the same `use_class_weights` flag also enables the weighted
# sampler in the data-loading cell, so imbalance may be compensated twice —
# confirm this is intended.
loss_weights = class_weights if CONFIG['use_class_weights'] else None
if not CONFIG['use_focal_loss']:
    criterion = nn.CrossEntropyLoss(weight=loss_weights)
    print("\nUsing Cross Entropy Loss")
else:
    criterion = FocalLoss(alpha=loss_weights, gamma=CONFIG['focal_gamma'])
    print(f"\nUsing Focal Loss (gamma={CONFIG['focal_gamma']})")

In [None]:
# AdamW over every model parameter.
optimizer = AdamW(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])

# Cosine annealing across the whole run, bottoming out at 1/100 of the base LR.
scheduler = CosineAnnealingLR(
    optimizer,
    T_max=CONFIG['num_epochs'],
    eta_min=CONFIG['learning_rate'] / 100,
)

# Mixed precision: a gradient scaler is created only when AMP is requested
# *and* we are on CUDA; `None` means plain full-precision training.
if CONFIG['use_amp'] and device.type == 'cuda':
    scaler = GradScaler()
else:
    scaler = None

print(
    f"Optimizer: AdamW (lr={CONFIG['learning_rate']})",
    "Scheduler: CosineAnnealing",
    f"Mixed precision: {scaler is not None}",
    sep="\n",
)

## 4. Training Loop

In [None]:
def train_epoch(model, loader, criterion, optimizer, device, scaler=None):
    """Train ``model`` for one full pass over ``loader``.

    Args:
        model: Network to optimise; it is switched to training mode here.
        loader: Iterable yielding ``(images, labels, extra)`` batches. The
            third element is ignored.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        device: Device the image and label tensors are moved to.
        scaler: Optional gradient scaler. When provided, the forward pass
            runs under ``autocast`` and backward/step go through the scaler;
            when ``None`` a plain full-precision step is taken.

    Returns:
        ``(avg_loss, accuracy)`` as Python floats: the mean of the per-batch
        loss values and the fraction of samples whose arg-max prediction
        equals the label.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    num_correct = 0
    num_samples = 0

    pbar = tqdm(loader, desc='Training', leave=False)
    for images, labels, _ in pbar:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        if scaler is not None:
            with autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Read the scalar once; each .item() forces a host/device sync.
        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1

        # Running counters instead of per-sample Python lists: accuracy is
        # the only thing the predictions are needed for here.
        num_correct += (outputs.argmax(dim=1) == labels).sum().item()
        num_samples += labels.size(0)

        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if num_batches == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise ValueError("train_epoch received an empty loader; nothing to train on")

    avg_loss = running_loss / num_batches
    accuracy = num_correct / num_samples

    return avg_loss, accuracy


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on every batch of ``loader`` without gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        loader: Iterable yielding ``(images, labels, extra)`` batches. The
            third element is ignored.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device the image and label tensors are moved to.

    Returns:
        ``(avg_loss, accuracy, all_labels, all_preds, all_probs)`` where
        ``avg_loss`` is the mean of the per-batch loss values, ``accuracy``
        is the fraction of arg-max predictions equal to the label (a plain
        Python float), and the three arrays are NumPy arrays holding the
        labels, arg-max predictions and softmax probabilities of all
        samples in loader order.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    running_loss = 0.0
    label_batches = []
    pred_batches = []
    prob_batches = []

    with torch.no_grad():
        for images, labels, _ in tqdm(loader, desc='Validating', leave=False):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()

            # Keep one array per batch and concatenate once at the end,
            # rather than growing per-sample Python lists.
            prob_batches.append(torch.softmax(outputs, dim=1).cpu().numpy())
            pred_batches.append(outputs.argmax(dim=1).cpu().numpy())
            label_batches.append(labels.cpu().numpy())

    if not label_batches:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise ValueError("validate received an empty loader; nothing to evaluate")

    avg_loss = running_loss / len(label_batches)
    all_labels = np.concatenate(label_batches)
    all_preds = np.concatenate(pred_batches)
    all_probs = np.concatenate(prob_batches)

    # Plain float (not a NumPy scalar) so callers can store it in a
    # checkpoint or JSON file without carrying NumPy types along.
    accuracy = float(np.mean(all_preds == all_labels))

    return avg_loss, accuracy, all_labels, all_preds, all_probs

In [None]:
# Training loop: train, validate, step the LR schedule, record history, and
# checkpoint whenever validation accuracy improves.
history = {
    'train_loss': [], 'val_loss': [],
    'train_acc': [], 'val_acc': [],
    'learning_rate': []
}

best_val_acc = 0
best_epoch = 0
# One timestamped folder per run so checkpoints from different runs never clash.
experiment_name = datetime.now().strftime('%Y%m%d_%H%M%S')
checkpoint_dir = MODELS_DIR / experiment_name
checkpoint_dir.mkdir(parents=True, exist_ok=True)

print(f"Training for {CONFIG['num_epochs']} epochs...")
print(f"Checkpoints will be saved to: {checkpoint_dir}")
print("="*60)

for epoch in range(1, CONFIG['num_epochs'] + 1):
    # Learning rate actually used for this epoch's updates. It has to be read
    # BEFORE scheduler.step(); afterwards the optimizer already holds the
    # next epoch's rate, which would shift the log/history by one epoch and
    # never record the initial rate.
    current_lr = optimizer.param_groups[0]['lr']

    # Train
    train_loss, train_acc = train_epoch(
        model, train_loader, criterion, optimizer, device, scaler
    )

    # Validate
    val_loss, val_acc, val_labels, val_preds, val_probs = validate(
        model, val_loader, criterion, device
    )

    # Advance the schedule once per epoch, after the optimizer updates.
    scheduler.step()

    # Save history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)
    history['learning_rate'].append(current_lr)

    # Print progress
    print(f"Epoch {epoch:3d}/{CONFIG['num_epochs']} | "
          f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f} | "
          f"LR: {current_lr:.6f}")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = float(val_acc)
        best_epoch = epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Stored as plain Python floats: a NumPy scalar inside the
            # checkpoint is rejected by torch.load when it runs with
            # weights_only=True (the default on newer PyTorch releases).
            'val_acc': float(val_acc),
            'val_loss': float(val_loss),
            'class_names': CLASS_NAMES,
            'config': CONFIG,
        }, checkpoint_dir / 'best_model.pt')
        print(f"  >> Saved best model (val_acc: {val_acc:.4f})")

print("="*60)
print(f"Training complete! Best val_acc: {best_val_acc:.4f} at epoch {best_epoch}")

In [None]:
# Plot training curves
# `history` holds the per-epoch loss / accuracy / learning-rate lists that the
# training loop filled in.
fig = plot_training_curves(history, figsize=(14, 4))
# plt.savefig writes pyplot's *current* figure — presumably the one the helper
# just returned; confirm if the helper ever creates more than one figure.
plt.savefig(FIGURES_DIR / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

## 5. Evaluation on Test Set

In [None]:
# Load best model
# map_location remaps the stored tensors onto the device used in this session,
# so a checkpoint written on a GPU can still be restored on a CPU-only machine
# (and vice versa) instead of failing while deserialising CUDA tensors.
# NOTE(review): newer PyTorch releases default torch.load to
# weights_only=True, which refuses non-tensor objects such as NumPy scalars —
# confirm the checkpoint only contains types that loader accepts.
checkpoint = torch.load(checkpoint_dir / 'best_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch']}")

# Evaluate the restored weights on the held-out test set; the returned arrays
# (labels, arg-max predictions, softmax probabilities) feed the analysis cells
# that follow.
test_loss, test_acc, test_labels, test_preds, test_probs = validate(
    model, test_loader, criterion, device
)

print(f"\nTest Results:")
print(f"  Loss: {test_loss:.4f}")
print(f"  Accuracy: {test_acc:.4f}")

In [None]:
# Aggregate metrics computed from the test-set labels, predictions and
# class probabilities.
metrics = calculate_metrics(
    y_true=test_labels, y_pred=test_preds, y_proba=test_probs, class_names=CLASS_NAMES
)

# Header and report body, one per line.
print("\nClassification Report:", metrics['classification_report'], sep="\n")

In [None]:
# Confusion matrix
# Built from the test-set labels and arg-max predictions.
cm = compute_confusion_matrix(test_labels, test_preds)
# normalize=True is forwarded to the plotting helper — presumably row-wise
# normalisation; confirm against plot_confusion_matrix.
fig = plot_confusion_matrix(cm, CLASS_NAMES, normalize=True, figsize=(12, 10))
# Saves pyplot's current figure, so this must run right after the plot call.
plt.savefig(FIGURES_DIR / 'confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Per-class metrics visualization
# Uses the `metrics` dict returned by calculate_metrics for the test set.
fig = plot_per_class_metrics(metrics, CLASS_NAMES)
# Saves pyplot's current figure, so this must run right after the plot call.
plt.savefig(FIGURES_DIR / 'per_class_metrics.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. False Positive Analysis

Critical analysis: How often are benign websites ('others') misclassified as target brands?

In [None]:
# False positive analysis for the 'others' (benign) class on the test set.
fp_analysis = get_false_positive_analysis(
    y_true=test_labels,
    y_pred=test_preds,
    y_proba=test_probs,
    others_class_idx=OTHERS_IDX,
    class_names=CLASS_NAMES,
)

print(
    "False Positive Analysis for 'Others' (Benign) Class:",
    "="*60,
    f"Total 'others' samples: {fp_analysis['total_others_samples']}",
    f"False positives: {fp_analysis['false_positive_count']}",
    f"False positive rate: {fp_analysis['false_positive_rate']:.2%}",
    "\nBrands 'others' was misclassified as:",
    sep="\n",
)

# Most frequent confusions first. The sort is stable, so brands with equal
# counts keep the order in which they appear in the mapping.
ranked_brands = sorted(
    fp_analysis['brand_misclassification_counts'].items(),
    key=lambda item: item[1],
    reverse=True,
)
for brand, count in ranked_brands:
    if count <= 0:
        continue  # only list brands that actually attracted benign samples
    print(f"  {brand}: {count}")

# Plot
fig = plot_false_positive_analysis(fp_analysis)
plt.savefig(FIGURES_DIR / 'false_positive_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Confidence Threshold Optimization

Find the optimal confidence threshold that minimizes false positives while maintaining good accuracy.

In [None]:
# Find optimal threshold
# The threshold is a tunable hyper-parameter, so it is selected on the
# VALIDATION split. Choosing it on the test set and then reporting test
# metrics at that threshold would leak test information into model selection
# and make the reported numbers optimistic.
# `model` holds the best-checkpoint weights at this point, so one extra
# validation pass gives probabilities that match the model being evaluated.
_, _, tuning_labels, _, tuning_probs = validate(
    model, val_loader, criterion, device
)

optimal_threshold, threshold_metrics = find_optimal_threshold(
    y_true=tuning_labels,
    y_proba=tuning_probs,
    others_class_idx=OTHERS_IDX,
    metric='f1_weighted',
    max_fp_rate=0.05,  # Maximum 5% false positive rate for 'others'
)

print(f"Optimal Confidence Threshold: {optimal_threshold:.2f}")
# These figures describe the validation split the threshold was tuned on.
print(f"\nMetrics at optimal threshold (validation set):")
for key, value in threshold_metrics.items():
    print(f"  {key}: {value:.4f}")

In [None]:
# Re-score the test set with the confidence threshold applied (per the
# section intent, low-confidence predictions are treated as 'others').
rejection_results = evaluate_with_rejection(
    y_true=test_labels, y_proba=test_probs, confidence_threshold=optimal_threshold
)

rejection_report = [
    f"\nEvaluation with Confidence Threshold ({optimal_threshold:.2f}):",
    f"  Total samples: {rejection_results['total_samples']}",
    # Accepted share is simply the complement of the rejection rate.
    f"  Accepted: {rejection_results['accepted_samples']} ({1-rejection_results['rejection_rate']:.1%})",
    f"  Rejected: {rejection_results['rejected_samples']} ({rejection_results['rejection_rate']:.1%})",
    f"  Accuracy on accepted: {rejection_results['accepted_accuracy']:.4f}",
    f"  F1 on accepted: {rejection_results['accepted_f1']:.4f}",
]
print("\n".join(rejection_report))

In [None]:
# Confidence of a prediction = its highest class probability; split the test
# samples by whether the arg-max prediction was right.
confidences = np.max(test_probs, axis=1)
correct_mask = test_preds == test_labels

fig = plot_confidence_distribution(confidences, correct_mask)
plt.savefig(FIGURES_DIR / 'confidence_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nConfidence Statistics:")
for outcome, outcome_mask in (("Correct", correct_mask), ("Incorrect", ~correct_mask)):
    outcome_conf = confidences[outcome_mask]
    print(f"  {outcome} predictions - Mean: {outcome_conf.mean():.3f}, Std: {outcome_conf.std():.3f}")

## 8. Error Analysis

In [None]:
# Positions (within the test arrays) where prediction and label disagree.
test_dataset = test_loader.dataset
misclassified_indices = np.flatnonzero(test_preds != test_labels)

print(f"Total misclassified samples: {len(misclassified_indices)}")
print(f"Error rate: {len(misclassified_indices) / len(test_labels):.2%}")

# One row per error: true class name, predicted class name, and the winning
# probability of the (wrong) prediction.
error_df = pd.DataFrame({
    'true_label': [CLASS_NAMES[k] for k in test_labels[misclassified_indices]],
    'pred_label': [CLASS_NAMES[k] for k in test_preds[misclassified_indices]],
    'confidence': list(test_probs[misclassified_indices].max(axis=1)),
})

print("\nMost common error patterns:")
# Count each (true, predicted) pair and list the ten most frequent.
pattern_counts = error_df.groupby(['true_label', 'pred_label']).size()
error_patterns = pattern_counts.sort_values(ascending=False)
print(error_patterns.head(10).to_string())

In [None]:
# Visualize some misclassified samples
from src.data.transforms import denormalize

# Show at most 8 errors (the grid below is 2 x 4), drawn without replacement.
n_samples = min(8, len(misclassified_indices))

if n_samples == 0:
    # A perfect test run has nothing to show; skip the empty figure instead
    # of saving a blank grid.
    print("No misclassified samples to display.")
else:
    sample_indices = np.random.choice(misclassified_indices, n_samples, replace=False)

    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    axes = axes.flatten()

    for ax, idx in zip(axes, sample_indices):
        # Get image; the dataset yields (image tensor, label, path)
        img_tensor, label, img_path = test_dataset[idx]
        # permute(1, 2, 0) reorders the tensor axes for imshow
        # (assumes a channels-first image tensor — TODO confirm)
        img = denormalize(img_tensor).permute(1, 2, 0).numpy()

        true_label = CLASS_NAMES[label]
        pred_label = CLASS_NAMES[test_preds[idx]]
        conf = test_probs[idx].max()

        ax.imshow(img)
        ax.set_title(f"True: {true_label}\nPred: {pred_label}\nConf: {conf:.2f}",
                     fontsize=9, color='red')
        ax.axis('off')

    # With fewer than 8 errors the remaining panels would otherwise be drawn
    # as empty axes with ticks; hide them.
    for unused_ax in axes[n_samples:]:
        unused_ax.axis('off')

    plt.suptitle('Misclassified Samples', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'misclassified_samples.png', dpi=150, bbox_inches='tight')
    plt.show()

## 9. Inference Speed Benchmark

In [None]:
from src.predict import PhishingClassifier

# Inference wrapper built from the best checkpoint of this run, using the
# tuned confidence threshold.
classifier = PhishingClassifier(
    checkpoint_path=str(checkpoint_dir / 'best_model.pt'),
    device=str(device),
    confidence_threshold=optimal_threshold,
)

# Benchmark on the first test image.
# NOTE(review): the path is taken from the CSV as-is, whereas the dataloaders
# were given `data_dir` — confirm `image_path` resolves from the working dir.
sample_path = test_df['image_path'].iloc[0]

benchmark = classifier.benchmark_inference_speed(
    sample_path, num_iterations=100, warmup_iterations=10
)

print("Inference Speed Benchmark")
print("="*40)
print(f"Device: {benchmark['device']}")
print(f"Iterations: {benchmark['num_iterations']}")
print("\nLatency (ms):")
# Labels are left-aligned to the width of "Mean:" so the values line up.
latency_rows = (
    ("Mean:", 'mean_latency_ms'),
    ("Std:", 'std_latency_ms'),
    ("P50:", 'p50_latency_ms'),
    ("P95:", 'p95_latency_ms'),
    ("P99:", 'p99_latency_ms'),
)
for row_label, row_key in latency_rows:
    print(f"  {row_label:<5} {benchmark[row_key]:.2f}")
print(f"\nThroughput: {benchmark['throughput_fps']:.1f} FPS")

## 10. Model Interpretability (GradCAM)

In [None]:
from src.interpretability import ModelExplainer

# Create explainer around the in-memory model (best-checkpoint weights).
explainer = ModelExplainer(
    model=model,
    class_names=CLASS_NAMES,
    image_size=CONFIG['image_size'],
    device=str(device),
)

# Generate explanations for a few samples.
# - Cap the request at the number of available rows: DataFrame.sample raises
#   when asked for more rows than exist (sampling without replacement).
# - Seed the draw from CONFIG so the saved gradcam_sample_*.png files show the
#   same images on every run.
n_explain = min(4, len(test_df))
sample_paths = test_df.sample(n=n_explain, random_state=CONFIG['seed'])['image_path'].tolist()

for i, img_path in enumerate(sample_paths):
    print(f"\nExplanation for sample {i+1}:")
    explanation = explainer.explain(img_path, methods=['gradcam'])
    fig = explainer.plot_explanation(explanation)
    plt.savefig(FIGURES_DIR / f'gradcam_sample_{i+1}.png', dpi=150, bbox_inches='tight')
    plt.show()

## 11. Save Final Results

In [None]:
# Collect everything produced by the run into one JSON-serialisable record.
def _json_fallback(obj):
    """Serialise objects json can't handle: use .tolist() if present, else str()."""
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    return str(obj)


training_summary = {
    'best_epoch': best_epoch,
    'best_val_acc': best_val_acc,
    'history': history,
}
test_summary = {
    'accuracy': float(test_acc),
    'loss': float(test_loss),
    'precision': float(metrics['precision']),
    'recall': float(metrics['recall']),
    'f1_score': float(metrics['f1_score']),
}
results = {
    'config': CONFIG,
    'training': training_summary,
    'test_metrics': test_summary,
    'false_positive_analysis': fp_analysis,
    'optimal_threshold': optimal_threshold,
    'threshold_metrics': threshold_metrics,
    'rejection_results': rejection_results,
    'benchmark': benchmark,
    'class_names': CLASS_NAMES,
}

# Written next to the checkpoint of the same run.
results_path = checkpoint_dir / 'results.json'
with results_path.open('w') as f:
    json.dump(results, f, indent=2, default=_json_fallback)

print(f"Results saved to: {results_path}")

In [None]:
# Final summary: assemble the report lines first, then emit them in one go.
summary_lines = [
    "="*60,
    "TRAINING SUMMARY",
    "="*60,
    f"\nModel: {CONFIG['architecture']}",
    f"Best validation accuracy: {best_val_acc:.4f} (epoch {best_epoch})",
    "\nTest Results:",
    f"  Accuracy: {test_acc:.4f}",
    f"  F1 Score: {metrics['f1_score']:.4f}",
    f"\nFalse Positive Rate ('others'): {fp_analysis['false_positive_rate']:.2%}",
    f"Optimal Confidence Threshold: {optimal_threshold:.2f}",
    f"\nInference Speed: {benchmark['mean_latency_ms']:.2f} ms ({benchmark['throughput_fps']:.1f} FPS)",
    f"\nModel saved to: {checkpoint_dir / 'best_model.pt'}",
]
print("\n".join(summary_lines))