In [None]:
# Model Visualization Notebook
# Run each cell sequentially to load model and visualize predictions

# Cell 1: Imports and Setup
import torch
import sys
from pathlib import Path

# Add project paths
project_root = Path.cwd()
sys.path.append(str(project_root / "preprocessing"))
sys.path.append(str(project_root / "model_architecture"))
sys.path.append(str(project_root / "metrics"))
sys.path.append(str(project_root / "utils"))
sys.path.append(str(project_root / "config"))

from dataloader import create_data_loaders, BurnSeverityDataset
from visualization import visualize_model_predictions, visualize_single_prediction
from unet import UNet
from resUnet import ResUNet
from attentionUnet import AttentionUNet
from config import Config

# Informational only: report which device PyTorch can use. This cell does not bind `device`.
_detected_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {_detected_device}")


In [2]:
# Cell 2: Model Loading Function
def get_model(name: str, num_classes: int):
    """Construct an (untrained) segmentation network selected by name.

    Args:
        name: Architecture identifier, matched case-insensitively against
            "unet", "resunet" and "attentionunet".
        num_classes: Passed to the network as ``n_classes``.

    Returns:
        A new model instance built with ``n_channels=Config.IN_CHANNELS``.

    Raises:
        ValueError: If ``name`` matches none of the known architectures.
    """
    key = name.lower()
    if key == "unet":
        architecture = UNet
    elif key == "resunet":
        architecture = ResUNet
    elif key == "attentionunet":
        architecture = AttentionUNet
    else:
        raise ValueError(f"Unknown model name: {name}")
    return architecture(n_channels=Config.IN_CHANNELS, n_classes=num_classes)

def load_trained_model(model_path: str, model_name: str, device: torch.device):
    """Rebuild an architecture, restore its trained weights and prepare it for inference.

    Args:
        model_path: Checkpoint file passed to ``torch.load``. Its contents are
            handed directly to ``load_state_dict``, so the file is expected to
            hold the model's state dict itself rather than a wrapper dict.
        model_name: Architecture name understood by ``get_model``.
        device: Device that tensors are mapped to while loading, and that the
            model is moved to afterwards.

    Returns:
        The model on ``device``, switched to eval mode.
    """
    print(f"Loading model from: {model_path}")

    network = get_model(model_name, num_classes=Config.NUM_CLASSES)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot run arbitrary code on load.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    network.load_state_dict(state_dict)

    # Module.to() and Module.eval() both return the module itself.
    network = network.to(device).eval()

    print(f"Successfully loaded {model_name} model")
    return network



In [None]:
# Cell 3: Load Your Specific Model
# Prefer the GPU when present; later cells reuse this `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---- LOAD MODEL HERE: edit these two values to inspect a different checkpoint.
# `model_name` must be a name accepted by get_model() (case-insensitive).
model_path = "logs/checkpoints/0.0002_v2_resunet_model.pth"
model_name = "resunet"

model = load_trained_model(model_path, model_name, device)
n_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {n_params:,}")



In [None]:
# Cell 4: Load Test Data
print("Loading test data...")
# Only the test loader is kept; the first two returned loaders are discarded.
_, _, test_loader = create_data_loaders(dataset_path=Config.DATASET_PATH, batch_size=Config.BATCH_SIZE)

# Index-addressable view of the "test" split, used for per-sample plots.
# NOTE(review): built independently of test_loader. This assumes both read the
# same split with the same preprocessing; confirm against dataloader.py.
test_dataset = BurnSeverityDataset(Config.DATASET_PATH, "test")

n_test_samples = len(test_dataset)
print(f"Test dataset size: {n_test_samples}")
print(f"Test batches: {len(test_loader)}")



In [None]:
# Cell 5: Visualize Model Predictions on Batch
# Plot predictions for a few test samples and save them under results/.
# The file prefix follows `model_name` from the model-loading cell, so saved
# figures carry the label of the architecture actually loaded instead of a
# hard-coded "resunet".
print("Visualizing model predictions on batch...")
visualize_model_predictions(
    model,
    test_loader,
    device,
    num_samples=4,
    save_dir="results",
    prefix=f"{model_name.lower()}_predictions",
)



In [None]:
# Cell 6: Visualize Single Sample Prediction (you can change the index)
# ---- LOAD SINGLE SAMPLE HERE: position in test_dataset of the sample to plot.
sample_idx = 26

# Fail fast, before any plotting, with a message that names the valid range.
if not 0 <= sample_idx < len(test_dataset):
    raise IndexError(
        f"sample_idx {sample_idx} is out of range (valid: 0..{len(test_dataset) - 1})"
    )

print(f"Visualizing single prediction for sample {sample_idx}...")
metrics = visualize_single_prediction(
    model,
    test_dataset,
    device,
    idx=sample_idx,
    save_dir="results",
)

# `metrics` is subscripted by the 'accuracy' and 'mean_confidence' keys.
# Accuracy is printed as a percentage.
print("\nSample metrics:")
print(f"Accuracy: {metrics['accuracy']:.1f}%")
print(f"Mean confidence: {metrics['mean_confidence']:.3f}")



In [None]:
# Cell 7: Compare Multiple Single Samples
print("Visualizing multiple single samples...")
# Fixed sample set for the report figures. Keep it unchanged so results stay
# comparable across models.
sample_indices = [26, 100, 430]

num_test = len(test_dataset)
for idx in sample_indices:
    if idx >= num_test:
        print(f"Index {idx} is out of range (max: {num_test-1})")
        continue
    print(f"\nSample {idx}:")
    metrics = visualize_single_prediction(model, test_dataset, device, idx=idx, save_dir="results")
    print(f"Accuracy: {metrics['accuracy']:.1f}%, Confidence: {metrics['mean_confidence']:.3f}")

