In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
# Print the full path of every file found under the Kaggle input directory.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(root, name))

In [None]:
"""
Cat vs Dog Multi-Model Classification (IPython Notebook version)
Each model runs in independent cells, results are aggregated at the end
"""

# Import dependencies
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset
from sklearn.metrics import confusion_matrix
import seaborn as sns
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

# Select the compute device once; models and batches are moved onto it later.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f'Using device: {device}')

# Global configuration (shared across all models)
class Config:
    """Settings shared by every model trained in this notebook."""

    # --- Data ---
    DATASET_NAME = 'Aurora1609/cat_vs_dog'
    IMG_SIZE = 224          # side length (pixels) images are resized/cropped to
    BATCH_SIZE = 32
    NUM_WORKERS = 0         # DataLoader worker processes
    PIN_MEMORY = False

    # --- Labels ---
    CLASS_NAMES = ['cat', 'dog']
    NUM_CLASSES = len(CLASS_NAMES)

    # --- Optimisation ---
    NUM_EPOCHS = 20
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-4
    LR_SCHEDULER_PATIENCE = 3
    LR_SCHEDULER_FACTOR = 0.5

    # --- Outputs ---
    MODEL_SAVE_DIR = './saved_models'
    VIS_SAVE_DIR = './visualization_results'
    NUM_SAMPLE_GRID = 16    # images per correct/incorrect sample grid

cfg = Config()

# Both output folders must exist before any checkpoint or figure is written.
for _out_dir in (cfg.MODEL_SAVE_DIR, cfg.VIS_SAVE_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# Shared registry filled in by each model cell:
# model name -> best validation accuracy, training history and checkpoint path.
all_model_results = dict()

# Visualization
def plot_single_model_curve(history, model_name, save_path):
    """Save a two-panel figure (accuracy, loss) of one model's training run.

    Args:
        history: Mapping with the per-epoch sequences 'train_acc', 'val_acc',
            'train_loss' and 'val_loss'.
        model_name: Name used as the prefix of both panel titles.
        save_path: Destination file for the rendered figure.
    """
    # (title suffix, y-axis label, train key, train legend, val key, val legend)
    panels = (
        ('Accuracy Curve', 'Accuracy (%)', 'train_acc', 'Train Acc', 'val_acc', 'Val Acc'),
        ('Loss Curve', 'Loss', 'train_loss', 'Train Loss', 'val_loss', 'Val Loss'),
    )

    plt.figure(figsize=(12, 4))
    for position, (suffix, y_label, train_key, train_legend, val_key, val_legend) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history[train_key], 'b-o', label=train_legend, markersize=4)
        plt.plot(history[val_key], 'r-o', label=val_legend, markersize=4)
        plt.title(f'{model_name} - {suffix}', fontsize=12)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✅ Training curve saved: {os.path.basename(save_path)}")

def plot_confusion_matrix(all_true, all_pred, class_names, model_name, save_path, split='Val Set'):
    """Render a confusion-matrix heatmap and save it to disk.

    Args:
        all_true: Ground-truth class indices.
        all_pred: Predicted class indices, aligned with ``all_true``.
        class_names: Display names; position ``i`` names class index ``i``.
        model_name: Name used in the figure title.
        save_path: Destination file for the rendered figure.
        split: Dataset split label shown in the title.
    """
    # Pin the label set to every known class index. Without this, sklearn
    # builds the matrix only from labels that occur in the data, so a class
    # missing from both inputs would shrink the matrix and misalign the ticks.
    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(all_true, all_pred, labels=label_ids)

    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm, annot=True, fmt='d', cmap='Blues',
        xticklabels=class_names, yticklabels=class_names,
        cbar_kws={'label': 'Number of Samples'}
    )
    plt.title(f'{model_name} - Confusion Matrix ({split})', fontsize=14)
    plt.xlabel('Predicted', fontsize=12)
    plt.ylabel('True', fontsize=12)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✅ Confusion matrix saved: {os.path.basename(save_path)}")

def plot_sample_analysis(model, val_loader, class_names, device, model_name, save_dir, num_samples=16):
    """Save grids of correctly and incorrectly classified validation images.

    The loader is scanned until ``num_samples`` examples of each kind have
    been collected (or the loader is exhausted). A grid is written only for a
    kind that reached ``num_samples`` examples; otherwise a notice is printed.

    Args:
        model: Classifier returning per-class scores; it is switched to eval mode.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        class_names: Display names indexed by class index.
        device: Device the images are moved to for inference.
        model_name: Prefix for figure titles and output file names.
        save_dir: Directory receiving ``<model_name>_correct_samples.png`` and
            ``<model_name>_incorrect_samples.png``.
        num_samples: Number of images per grid.
    """
    # Statistics matching the Normalize step of the transforms, used to
    # bring images back into the displayable [0, 1] range.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # Near-square grid sized from num_samples (4x4 for the default of 16)
    # instead of a fixed 4x4 that overflows for larger requests.
    grid_cells = max(int(num_samples), 1)
    n_cols = int(np.ceil(np.sqrt(grid_cells)))
    n_rows = int(np.ceil(grid_cells / n_cols))

    def _save_grid(samples, kind, title_color):
        """Draw ``samples`` as one figure; ``kind`` is 'Correct' or 'Incorrect'."""
        file_name = f"{model_name}_{kind.lower()}_samples.png"
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
        fig.suptitle(f'{model_name} - {kind} Predictions (Val Set)', fontsize=18, y=0.95)
        cells = axes.ravel()
        for ax in cells:
            ax.axis('off')  # also hides any cell left without an image
        for ax, (img, true_idx, pred_idx) in zip(cells, samples):
            pixels = (img * std + mean).permute(1, 2, 0).numpy()
            ax.imshow(np.clip(pixels, 0, 1))
            ax.set_title(
                f"True: {class_names[true_idx]}\nPred: {class_names[pred_idx]}",
                color=title_color, fontsize=12, pad=10
            )
        plt.tight_layout()
        plt.subplots_adjust(top=0.92)
        plt.savefig(os.path.join(save_dir, file_name), dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"✅ {kind} predictions saved: {file_name}")

    model.eval()
    correct, incorrect = [], []
    with torch.no_grad():
        for imgs, lbls in val_loader:
            preds = model(imgs.to(device)).argmax(dim=1)
            for img, true_idx, pred_idx in zip(imgs, lbls.tolist(), preds.tolist()):
                bucket = correct if true_idx == pred_idx else incorrect
                # Cap each bucket and keep an independent CPU copy, so memory
                # stays bounded even when the whole loader has to be scanned
                # to find enough misclassified images.
                if len(bucket) < num_samples:
                    bucket.append((img.detach().cpu().clone(), true_idx, pred_idx))
            if len(correct) >= num_samples and len(incorrect) >= num_samples:
                break

    for samples, kind, title_color in ((correct, 'Correct', 'green'), (incorrect, 'Incorrect', 'red')):
        if len(samples) >= num_samples:
            _save_grid(samples[:max(num_samples, 0)], kind, title_color)
        else:
            print(f"⚠️ {model_name}: Not enough {kind.lower()} samples, skipping {kind.lower()} samples plot")

# Training/validation
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)`` computed over all samples seen.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    progress = tqdm(loader, desc='Training', ncols=100)
    for images, labels in progress:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the running value is a per-sample mean.
        loss_sum += batch_loss.item() * images.size(0)
        predicted = torch.max(logits, dim=1).indices
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()
        progress.set_postfix({
            'loss': f'{loss_sum/n_seen:.4f}',
            'acc': f'{100*n_correct/n_seen:.2f}%',
        })
    return loss_sum / n_seen, 100 * n_correct / n_seen

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating its weights.

    Returns:
        Tuple ``(mean_loss, accuracy_percent, true_labels, predicted_labels)``;
        the two label lists are aligned sample by sample.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    y_true, y_pred = [], []
    progress = tqdm(loader, desc='Validation', ncols=100)
    with torch.no_grad():
        for images, labels in progress:
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)
            batch_loss = criterion(logits, labels)

            # Weight the batch loss by batch size so the running value is a per-sample mean.
            loss_sum += batch_loss.item() * images.size(0)
            predicted = torch.max(logits, dim=1).indices
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
            progress.set_postfix({
                'loss': f'{loss_sum/n_seen:.4f}',
                'acc': f'{100*n_correct/n_seen:.2f}%',
            })
    return loss_sum / n_seen, 100 * n_correct / n_seen, y_true, y_pred

# Confirm that the shared configuration and helper functions above are defined.
print("✅ Cell 1: Dependencies and configuration completed!")

In [None]:
# Dataset class with error handling
class CatDogDataset(Dataset):
    """Map-style wrapper that turns dataset records into ``(image, label)`` pairs.

    Each record is expected to expose an ``'image'`` with a ``convert`` method
    and a ``'label'``. If loading a sample raises, a warning is printed and
    sample 0 is returned in its place so iteration does not stop.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def _fetch(self, index):
        """Load one record, convert it to RGB and apply the optional transform."""
        record = self.dataset[index]
        image = record['image'].convert('RGB')
        label = record['label']
        if self.transform:
            image = self.transform(image)
        return image, label

    def __getitem__(self, idx):
        try:
            return self._fetch(idx)
        except Exception as e:
            print(f"⚠️ Failed to load sample {idx}: {e}, returning fallback sample")
            # NOTE: the substitute carries sample 0's label, not the failed sample's.
            return self._fetch(0)

# Preprocessing shared by every model. Both pipelines end with the same
# tensor conversion and per-channel normalisation.
_normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

# Training: random crop/flip/rotation/colour jitter for augmentation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(cfg.IMG_SIZE, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    _normalize,
])

# Validation/test: deterministic resize only.
val_test_transform = transforms.Compose([
    transforms.Resize((cfg.IMG_SIZE, cfg.IMG_SIZE)),
    transforms.ToTensor(),
    _normalize,
])

# Load the dataset and pick out its three predefined splits.
print(f"Loading dataset: {cfg.DATASET_NAME} (with built-in train/val/test splits)...")
dataset = load_dataset(cfg.DATASET_NAME)
train_hf, val_hf, test_hf = dataset['train'], dataset['val'], dataset['test']

# Wrap each split so it yields (tensor, label) pairs.
train_dataset = CatDogDataset(train_hf, train_transform)
val_dataset = CatDogDataset(val_hf, val_test_transform)
test_dataset = CatDogDataset(test_hf, val_test_transform)

# Loader settings common to all three splits.
_loader_kwargs = dict(
    batch_size=cfg.BATCH_SIZE,
    num_workers=cfg.NUM_WORKERS,
    pin_memory=cfg.PIN_MEMORY,
)
# Only the training loader shuffles and drops the final partial batch.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

# Report split sizes.
print("\nDataset split information:")
for _split_label, _split in (
    ("Training", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    print(f"  {_split_label} samples: {len(_split)}")
print(f"  Classes: {cfg.CLASS_NAMES}")

print("✅ Cell2: Data loading complete! Shared across all models")

In [None]:
# Model building
def build_resnet18(num_classes=2, pretrained=True):
    """Build a ResNet-18 whose final layer outputs ``num_classes`` scores.

    Args:
        num_classes: Size of the replacement classification layer.
        pretrained: When true, start from torchvision's pretrained weights.

    Returns:
        The model, moved to the notebook-level ``device``.
    """
    if pretrained:
        try:
            weights = models.ResNet18_Weights.DEFAULT
        except AttributeError:
            # Older torchvision without the weights-enum API. Only this case
            # falls back; download errors and interrupts now propagate
            # instead of being retried through the legacy flag.
            model = models.resnet18(pretrained=True)
        else:
            model = models.resnet18(weights=weights)
    else:
        # No arguments means random initialisation on every torchvision version.
        model = models.resnet18()

    # Replace the classification head with one sized for this task.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(device)

# Initialize model
model_name = 'resnet18'
print(f"===== Training model: {model_name} =====")
model = build_resnet18(num_classes=cfg.NUM_CLASSES, pretrained=True)
print(f"✅ {model_name} initialized with pretrained weights")

# Training configuration
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=cfg.LEARNING_RATE, weight_decay=cfg.WEIGHT_DECAY)
# `verbose` is not passed: it is deprecated in recent PyTorch releases, and
# the per-epoch summary below already prints the current learning rate.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=cfg.LR_SCHEDULER_FACTOR, patience=cfg.LR_SCHEDULER_PATIENCE
)

# Training history
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
best_val_acc = 0.0
best_all_true = []
best_all_pred = []
# Single definition of the checkpoint path, reused for saving, reloading and reporting.
model_save_path = os.path.join(cfg.MODEL_SAVE_DIR, f"{model_name}_best.pth")

# Training loop
for epoch in range(cfg.NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{cfg.NUM_EPOCHS}")
    print("-"*50)

    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, all_true, all_pred = validate(model, val_loader, criterion, device)

    # The scheduler lowers the LR when validation loss stops improving.
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Save best model (and keep its predictions for the confusion matrix)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_all_true = all_true
        best_all_pred = all_pred
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'val_acc': best_val_acc,
            'history': history
        }, model_save_path)
        print(f"[SAVED] Best model (Val accuracy: {best_val_acc:.2f}%) → {os.path.basename(model_save_path)}")

    print(f"Epoch summary: train_loss={train_loss:.4f} train_acc={train_acc:.2f}% | "
          f"val_loss={val_loss:.4f} val_acc={val_acc:.2f}% | lr={current_lr:.6f}")

# Visualizations
print(f"\n===== Generating {model_name} visualizations =====")
curve_path = os.path.join(cfg.VIS_SAVE_DIR, f"{model_name}_training_curves.png")
plot_single_model_curve(history, model_name, curve_path)

cm_path = os.path.join(cfg.VIS_SAVE_DIR, f"{model_name}_val_confusion_matrix.png")
plot_confusion_matrix(best_all_true, best_all_pred, cfg.CLASS_NAMES, model_name, cm_path, split='Val Set')

# Restore the best-epoch weights so the sample grids describe the same model
# as the confusion matrix and the reported accuracy, not the last epoch.
# A checkpoint exists from this run only if some epoch beat the initial 0.0.
if best_val_acc > 0.0:
    best_checkpoint = torch.load(model_save_path, map_location=device)
    model.load_state_dict(best_checkpoint['model_state_dict'])

plot_sample_analysis(model, val_loader, cfg.CLASS_NAMES, device, model_name, cfg.VIS_SAVE_DIR, cfg.NUM_SAMPLE_GRID)

# Save results
all_model_results[model_name] = {
    'val_acc': best_val_acc,
    'history': history,
    'model_path': model_save_path
}

print(f"\n✅ {model_name} training completed!")
print(f"Trained models: {list(all_model_results.keys())}")

In [None]:
# Re-create the figures and registry entry for the model trained most recently.
print(f"\n===== Generating visualization analysis for {model_name} =====")

# Output locations for the two figures written directly by this cell.
curve_path = os.path.join(cfg.VIS_SAVE_DIR, f"{model_name}_training_curves.png")
cm_path = os.path.join(cfg.VIS_SAVE_DIR, f"{model_name}_val_confusion_matrix.png")

# Training curves, then the validation confusion matrix of the best epoch.
plot_single_model_curve(history, model_name, curve_path)
plot_confusion_matrix(
    best_all_true, best_all_pred, cfg.CLASS_NAMES, model_name, cm_path, split='Val Set'
)

# Grids of correctly / incorrectly classified validation images.
plot_sample_analysis(
    model, val_loader, cfg.CLASS_NAMES, device, model_name,
    cfg.VIS_SAVE_DIR, cfg.NUM_SAMPLE_GRID
)

# Record this model's outcome in the shared registry.
all_model_results[model_name] = dict(
    val_acc=best_val_acc,
    history=history,
    model_path=os.path.join(cfg.MODEL_SAVE_DIR, f"{model_name}_best.pth"),
)

print(f"\n✅ {model_name} training and analysis completed! Results saved to global dictionary")
print(f"Currently trained models: {list(all_model_results.keys())}")

In [None]:
# Display and save ResNet analysis summary
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os

%matplotlib inline

model_name = 'resnet18'
vis_dir = cfg.VIS_SAVE_DIR

# Panel title -> previously saved figure to embed in that panel.
image_paths = {
    'Accuracy': os.path.join(vis_dir, f"{model_name}_training_curves.png"),
    'Val Confusion Matrix': os.path.join(vis_dir, f"{model_name}_val_confusion_matrix.png"),
    'Correct Samples': os.path.join(vis_dir, f"{model_name}_correct_samples.png"),
    'Incorrect Samples': os.path.join(vis_dir, f"{model_name}_incorrect_samples.png")
}

# One panel per saved figure, laid out 2x2.
fig, axes = plt.subplots(2, 2, figsize=(20, 16))
axes = axes.flatten()

for ax, (title, path) in zip(axes, image_paths.items()):
    if os.path.exists(path):
        img = mpimg.imread(path)
        ax.imshow(img)
    else:
        # Keep the panel, but say which file is missing.
        ax.text(0.5, 0.5, f"Image not found: {title}\nPath: {path}",
                ha='center', va='center', fontsize=12, wrap=True)
    ax.set_title(title, fontsize=16, pad=20)
    ax.axis('off')

plt.tight_layout(pad=3.0)
plt.suptitle('ResNet18 Analysis', fontsize=20, y=0.98)

# Write the combined figure next to the individual ones.
save_path = os.path.join(vis_dir, f"{model_name}_analysis_summary.png")
plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
print(f"Summary figure saved: {save_path}")

plt.show()

# Optional: Save individual high-resolution images
def save_single_image(image_title, image_path, save_dir):
    """Re-save one existing figure as a standalone titled image.

    Args:
        image_title: Title shown on the figure; with spaces replaced by
            underscores it also becomes part of the output file name.
        image_path: Existing image file to read.
        save_dir: Directory that receives the new file.

    The output file name is prefixed with the notebook-level ``model_name``.
    Nothing is written when ``image_path`` does not exist.
    """
    if not os.path.exists(image_path):
        print(f"Image not found: {image_title}")
        return

    picture = mpimg.imread(image_path)
    title_slug = image_title.replace(' ', '_')
    target_path = os.path.join(save_dir, f"{model_name}_{title_slug}_highres.png")

    plt.figure(figsize=(12, 8))
    plt.imshow(picture)
    plt.title(f'ResNet18 - {image_title}', fontsize=18)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(target_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close()
    print(f"Individual figure saved: {target_path}")

In [None]:
def build_resnet50(num_classes=2, pretrained=True):
    """Build a ResNet-50 whose final layer outputs ``num_classes`` scores.

    Args:
        num_classes: Size of the replacement classification layer.
        pretrained: When true, start from torchvision's pretrained weights.

    Returns:
        The model, moved to the notebook-level ``device``.
    """
    if pretrained:
        try:
            weights = models.ResNet50_Weights.DEFAULT
        except AttributeError:
            # Older torchvision without the weights-enum API. Only this case
            # falls back; download errors and interrupts now propagate
            # instead of being retried through the legacy flag.
            model = models.resnet50(pretrained=True)
        else:
            model = models.resnet50(weights=weights)
    else:
        # No arguments means random initialisation on every torchvision version.
        model = models.resnet50()

    # Replace the classification head with one sized for this task.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(device)

# Initialize
model_name = 'resnet50'
print(f"===== Training model: {model_name} =====")
model = build_resnet50(num_classes=cfg.NUM_CLASSES, pretrained=True)
print(f"✅ {model_name} model initialized (pretrained weights: True)")

# Training configuration
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=cfg.LEARNING_RATE, weight_decay=cfg.WEIGHT_DECAY)
# `verbose` is not passed: it is deprecated in recent PyTorch releases, and
# the per-epoch summary below already prints the current learning rate.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=cfg.LR_SCHEDULER_FACTOR, patience=cfg.LR_SCHEDULER_PATIENCE
)

history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
best_val_acc = 0.0
best_all_true = []
best_all_pred = []
# Single definition of the checkpoint path, reused for saving, reloading and reporting.
model_save_path = os.path.join(cfg.MODEL_SAVE_DIR, f"{model_name}_best.pth")

# Training loop
for epoch in range(cfg.NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{cfg.NUM_EPOCHS}")
    print("-"*50)

    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, all_true, all_pred = validate(model, val_loader, criterion, device)

    # The scheduler lowers the LR when validation loss stops improving.
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Save best model (and keep its predictions for the confusion matrix)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_all_true = all_true
        best_all_pred = all_pred

        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'val_acc': best_val_acc,
            'history': history
        }, model_save_path)
        print(f"[SAVED] Best model (Val Acc: {best_val_acc:.2f}%) → {os.path.basename(model_save_path)}")

    print(f"Epoch Summary: Train Loss={train_loss:.4f} Train Acc={train_acc:.2f}% | "
          f"Val Loss={val_loss:.4f} Val Acc={val_acc:.2f}% | LR={current_lr:.6f}")

# Generate visualizations
print(f"\n===== Generating {model_name} visualizations =====")
curve_path = os.path.join(cfg.VIS_SAVE_DIR, f"{model_name}_training_curves.png")
plot_single_model_curve(history, model_name, curve_path)

cm_path = os.path.join(cfg.VIS_SAVE_DIR, f"{model_name}_val_confusion_matrix.png")
plot_confusion_matrix(best_all_true, best_all_pred, cfg.CLASS_NAMES, model_name, cm_path, split='Val Set')

# Restore the best-epoch weights so the sample grids describe the same model
# as the confusion matrix and the reported accuracy, not the last epoch.
# A checkpoint exists from this run only if some epoch beat the initial 0.0.
if best_val_acc > 0.0:
    best_checkpoint = torch.load(model_save_path, map_location=device)
    model.load_state_dict(best_checkpoint['model_state_dict'])

plot_sample_analysis(model, val_loader, cfg.CLASS_NAMES, device, model_name, cfg.VIS_SAVE_DIR, cfg.NUM_SAMPLE_GRID)

# Save results
all_model_results[model_name] = {
    'val_acc': best_val_acc,
    'history': history,
    'model_path': model_save_path
}

print(f"\n✅ {model_name} training and analysis completed!")
print(f"Trained models: {list(all_model_results.keys())}")

In [None]:
# Model building
def build_vgg16(num_classes=2, pretrained=True):
    """Build a VGG-16 whose last classifier layer outputs ``num_classes`` scores.

    Args:
        num_classes: Size of the replacement classification layer.
        pretrained: When true, start from torchvision's pretrained weights.

    Returns:
        The model, moved to the notebook-level ``device``.
    """
    if pretrained:
        try:
            weights = models.VGG16_Weights.DEFAULT
        except AttributeError:
            # Older torchvision without the weights-enum API. Only this case
            # falls back; download errors and interrupts now propagate
            # instead of being retried through the legacy flag.
            model = models.vgg16(pretrained=True)
        else:
            model = models.vgg16(weights=weights)
    else:
        # No arguments means random initialisation on every torchvision version.
        model = models.vgg16()

    # Swap only the last layer of the classifier stack for a task-sized one.
    in_features = model.classifier[-1].in_features
    model.classifier[-1] = nn.Linear(in_features, num_classes)
    return model.to(device)

# Initialize model
model_name = 'vgg16'
print(f"===== Training model: {model_name} =====")
model = build_vgg16(num_classes=cfg.NUM_CLASSES, pretrained=True)
print(f"✅ {model_name} model initialized (pretrained weights: True)")

# Training configuration
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=cfg.LEARNING_RATE, weight_decay=cfg.WEIGHT_DECAY)
# `verbose` is not passed: it is deprecated in recent PyTorch releases, and
# the per-epoch summary below already prints the current learning rate.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=cfg.LR_SCHEDULER_FACTOR, patience=cfg.LR_SCHEDULER_PATIENCE
)

history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
best_val_acc = 0.0
best_all_true = []
best_all_pred = []
# Single definition of the checkpoint path, reused for saving, reloading and reporting.
model_save_path = os.path.join(cfg.MODEL_SAVE_DIR, f"{model_name}_best.pth")

# Training loop
for epoch in range(cfg.NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{cfg.NUM_EPOCHS}")
    print("-"*50)

    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, all_true, all_pred = validate(model, val_loader, criterion, device)

    # The scheduler lowers the LR when validation loss stops improving.
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Save best model (and keep its predictions for the confusion matrix)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_all_true = all_true
        best_all_pred = all_pred
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'val_acc': best_val_acc,
            'history': history
        }, model_save_path)
        print(f"[SAVED] Best model (Val accuracy: {best_val_acc:.2f}%) → {os.path.basename(model_save_path)}")

    print(f"Epoch summary: train_loss={train_loss:.4f} train_acc={train_acc:.2f}% | "
          f"val_loss={val_loss:.4f} val_acc={val_acc:.2f}% | lr={current_lr:.6f}")

# Generate visualization analysis
print(f"\n===== Generating {model_name} visualization analysis =====")
curve_path = os.path.join(cfg.VIS_SAVE_DIR, f"{model_name}_training_curves.png")
plot_single_model_curve(history, model_name, curve_path)

cm_path = os.path.join(cfg.VIS_SAVE_DIR, f"{model_name}_val_confusion_matrix.png")
plot_confusion_matrix(best_all_true, best_all_pred, cfg.CLASS_NAMES, model_name, cm_path, split='Val Set')

# Restore the best-epoch weights so the sample grids describe the same model
# as the confusion matrix and the reported accuracy, not the last epoch.
# A checkpoint exists from this run only if some epoch beat the initial 0.0.
if best_val_acc > 0.0:
    best_checkpoint = torch.load(model_save_path, map_location=device)
    model.load_state_dict(best_checkpoint['model_state_dict'])

plot_sample_analysis(model, val_loader, cfg.CLASS_NAMES, device, model_name, cfg.VIS_SAVE_DIR, cfg.NUM_SAMPLE_GRID)

# Save results
all_model_results[model_name] = {
    'val_acc': best_val_acc,
    'history': history,
    'model_path': model_save_path
}

print(f"\n✅ {model_name} training and analysis completed!")
print(f"Trained models: {list(all_model_results.keys())}")

In [None]:
# Display and save VGG16 analysis summary
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os

%matplotlib inline

model_name = 'vgg16'
vis_dir = cfg.VIS_SAVE_DIR

# Panel title -> previously saved figure to embed in that panel.
image_paths = {
    'Accuracy': os.path.join(vis_dir, f"{model_name}_training_curves.png"),
    'Val confusion matrix': os.path.join(vis_dir, f"{model_name}_val_confusion_matrix.png"),
    'Correct sample': os.path.join(vis_dir, f"{model_name}_correct_samples.png"),
    'Incorrect sample': os.path.join(vis_dir, f"{model_name}_incorrect_samples.png")
}

# Create subplot layout (2x2 grid)
fig, axes = plt.subplots(2, 2, figsize=(20, 16))
axes = axes.flatten()

# Load and display images
for idx, (title, path) in enumerate(image_paths.items()):
    if os.path.exists(path):
        img = mpimg.imread(path)
        axes[idx].imshow(img)
    else:
        # Keep the panel, but say which file is missing.
        axes[idx].text(0.5, 0.5, f"Image not found: {title}\nPath: {path}",
                       ha='center', va='center', fontsize=12, wrap=True)
    axes[idx].set_title(title, fontsize=16, pad=20)
    axes[idx].axis('off')

plt.tight_layout(pad=3.0)
# Title derived from model_name: the previous hard-coded 'ResNet18' label
# was wrong for this cell, which summarises vgg16.
plt.suptitle(f'{model_name} analysis', fontsize=20, y=0.98)

# Save summary figure
save_path = os.path.join(vis_dir, f"{model_name}_analysis_summary.png")
plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
print(f"Summary saved: {save_path}")

plt.show()

# Optional: Save individual high-resolution images
def save_single_image(image_title, image_path, save_dir):
    """Re-save one existing figure as a standalone titled image.

    Args:
        image_title: Title shown on the figure; with spaces replaced by
            underscores it also becomes part of the output file name.
        image_path: Existing image file to read.
        save_dir: Directory that receives the new file.

    Both the figure title and the output file name are prefixed with the
    notebook-level ``model_name``. Nothing is written when ``image_path``
    does not exist.
    """
    if os.path.exists(image_path):
        img = mpimg.imread(image_path)
        save_path = os.path.join(save_dir, f"{model_name}_{image_title.replace(' ', '_')}_highres.png")
        plt.figure(figsize=(12, 8))
        plt.imshow(img)
        # Use the active model's name: the title was hard-coded to 'ResNet50',
        # which mislabelled figures whose file name carried another model.
        plt.title(f'{model_name} - {image_title}', fontsize=18)
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
        plt.close()
        print(f"Individual image saved: {save_path}")
    else:
        print(f"Image not found: {image_title}")

# Report the directory that holds the figures produced by this cell.
print(f"\nFiles saved in: {vis_dir}")

In [None]:
def build_convnext(num_classes=2, pretrained=True):
    """Build a ConvNeXt-Small whose classifier outputs ``num_classes`` scores.

    Args:
        num_classes: Size of the replacement classification layer.
        pretrained: When true, start from torchvision's pretrained weights.

    Returns:
        The model, moved to the notebook-level ``device``.
    """
    if pretrained:
        try:
            weights = models.ConvNeXt_Small_Weights.DEFAULT
        except AttributeError:
            # Older torchvision without the weights-enum API. Only this case
            # falls back; download errors and interrupts now propagate
            # instead of being retried through the legacy flag.
            model = models.convnext_small(pretrained=True)
        else:
            model = models.convnext_small(weights=weights)
    else:
        # No arguments means random initialisation on every torchvision version.
        model = models.convnext_small()

    # Modify final classifier layer (index 2 of the classifier stack).
    in_features = model.classifier[2].in_features
    model.classifier[2] = nn.Linear(in_features, num_classes)
    return model.to(device)

# Initialize model
model_name = 'convnext'
print(f"===== Training Model: {model_name} =====")
model = build_convnext(num_classes=cfg.NUM_CLASSES, pretrained=True)
print(f"✅ {model_name} model initialized (pretrained: True)")

# Training configuration
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=cfg.LEARNING_RATE, weight_decay=cfg.WEIGHT_DECAY)
# `verbose` is not passed: it is deprecated in recent PyTorch releases, and
# the per-epoch summary below already prints the current learning rate.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=cfg.LR_SCHEDULER_FACTOR, patience=cfg.LR_SCHEDULER_PATIENCE
)

history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
best_val_acc = 0.0
best_all_true = []
best_all_pred = []
# Single definition of the checkpoint path, reused for saving, reloading and reporting.
model_save_path = os.path.join(cfg.MODEL_SAVE_DIR, f"{model_name}_best.pth")

# Training loop
for epoch in range(cfg.NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{cfg.NUM_EPOCHS}")
    print("-"*50)

    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, all_true, all_pred = validate(model, val_loader, criterion, device)

    # The scheduler lowers the LR when validation loss stops improving.
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Save best model (and keep its predictions for the confusion matrix)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_all_true = all_true
        best_all_pred = all_pred
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'val_acc': best_val_acc,
            'history': history
        }, model_save_path)
        print(f"[SAVED] Best model (Val Acc: {best_val_acc:.2f}%) → {os.path.basename(model_save_path)}")

    print(f"Epoch Summary: Train Loss={train_loss:.4f} Train Acc={train_acc:.2f}% | "
          f"Val Loss={val_loss:.4f} Val Acc={val_acc:.2f}% | LR={current_lr:.6f}")

# Generate visualizations
print(f"\n===== Generating {model_name} Visualizations =====")
curve_path = os.path.join(cfg.VIS_SAVE_DIR, f"{model_name}_training_curves.png")
plot_single_model_curve(history, model_name, curve_path)

cm_path = os.path.join(cfg.VIS_SAVE_DIR, f"{model_name}_val_confusion_matrix.png")
plot_confusion_matrix(best_all_true, best_all_pred, cfg.CLASS_NAMES, model_name, cm_path, split='Val Set')

# Restore the best-epoch weights so the sample grids describe the same model
# as the confusion matrix and the reported accuracy, not the last epoch.
# A checkpoint exists from this run only if some epoch beat the initial 0.0.
if best_val_acc > 0.0:
    best_checkpoint = torch.load(model_save_path, map_location=device)
    model.load_state_dict(best_checkpoint['model_state_dict'])

plot_sample_analysis(model, val_loader, cfg.CLASS_NAMES, device, model_name, cfg.VIS_SAVE_DIR, cfg.NUM_SAMPLE_GRID)

# Save results
all_model_results[model_name] = {
    'val_acc': best_val_acc,
    'history': history,
    'model_path': model_save_path
}

print(f"\n✅ {model_name} training completed! Results saved")
print(f"Trained models: {list(all_model_results.keys())}")

In [None]:
def build_efficientnet(num_classes=2, pretrained=True):
    """Build an EfficientNet-B0 whose classifier outputs ``num_classes`` scores.

    Args:
        num_classes: Size of the replacement classification layer.
        pretrained: When true, start from torchvision's pretrained weights.

    Returns:
        The model, moved to the notebook-level ``device``.
    """
    if pretrained:
        try:
            weights = models.EfficientNet_B0_Weights.DEFAULT
        except AttributeError:
            # Older torchvision without the weights-enum API. Only this case
            # falls back; download errors and interrupts now propagate
            # instead of being retried through the legacy flag.
            model = models.efficientnet_b0(pretrained=True)
        else:
            model = models.efficientnet_b0(weights=weights)
    else:
        # No arguments means random initialisation on every torchvision version.
        model = models.efficientnet_b0()

    # Swap the linear layer at index 1 of the classifier stack for a task-sized one.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_classes)
    return model.to(device)

# Initialize model
model_name = 'efficientnet'
print(f"===== Training model: {model_name} =====")
model = build_efficientnet(num_classes=cfg.NUM_CLASSES, pretrained=True)
print(f"✅ {model_name} model initialized (pretrained weights: True)")

# Training configuration
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=cfg.LEARNING_RATE, weight_decay=cfg.WEIGHT_DECAY)
# `verbose` is not passed: it is deprecated in recent PyTorch releases, and
# the per-epoch summary below already prints the current learning rate.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=cfg.LR_SCHEDULER_FACTOR, patience=cfg.LR_SCHEDULER_PATIENCE
)

history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
best_val_acc = 0.0
best_all_true = []
best_all_pred = []
# Single definition of the checkpoint path, reused for saving, reloading and reporting.
model_save_path = os.path.join(cfg.MODEL_SAVE_DIR, f"{model_name}_best.pth")

# Training loop
for epoch in range(cfg.NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{cfg.NUM_EPOCHS}")
    print("-"*50)

    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, all_true, all_pred = validate(model, val_loader, criterion, device)

    # The scheduler lowers the LR when validation loss stops improving.
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Save best model (and keep its predictions for the confusion matrix)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_all_true = all_true
        best_all_pred = all_pred
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'val_acc': best_val_acc,
            'history': history
        }, model_save_path)
        print(f"[SAVED] Best model (Val accuracy: {best_val_acc:.2f}%) → {os.path.basename(model_save_path)}")

    print(f"Epoch summary: train_loss={train_loss:.4f} train_acc={train_acc:.2f}% | "
          f"val_loss={val_loss:.4f} val_acc={val_acc:.2f}% | lr={current_lr:.6f}")

# Generate visualization analysis
print(f"\n===== Generating {model_name} visualization analysis =====")
curve_path = os.path.join(cfg.VIS_SAVE_DIR, f"{model_name}_training_curves.png")
plot_single_model_curve(history, model_name, curve_path)
cm_path = os.path.join(cfg.VIS_SAVE_DIR, f"{model_name}_val_confusion_matrix.png")
plot_confusion_matrix(best_all_true, best_all_pred, cfg.CLASS_NAMES, model_name, cm_path, split='Val Set')

# Restore the best-epoch weights so the sample grids describe the same model
# as the confusion matrix and the reported accuracy, not the last epoch.
# A checkpoint exists from this run only if some epoch beat the initial 0.0.
if best_val_acc > 0.0:
    best_checkpoint = torch.load(model_save_path, map_location=device)
    model.load_state_dict(best_checkpoint['model_state_dict'])

plot_sample_analysis(model, val_loader, cfg.CLASS_NAMES, device, model_name, cfg.VIS_SAVE_DIR, cfg.NUM_SAMPLE_GRID)

# Save results to global dictionary
all_model_results[model_name] = {
    'val_acc': best_val_acc,
    'history': history,
    'model_path': model_save_path
}

print(f"\n✅ {model_name} training and analysis complete! Results saved to global dictionary")
print(f"Currently trained models: {list(all_model_results.keys())}")

In [None]:
# Display and save all analysis plots for the model
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os

%matplotlib inline

model_name = 'efficientnet'
vis_dir = cfg.VIS_SAVE_DIR

# Panel title -> PNG written earlier for this model (insertion order is the
# order the panels are drawn in).
image_paths = {
    'Accuracy': os.path.join(vis_dir, f"{model_name}_training_curves.png"),
    'Val confusion matrix': os.path.join(vis_dir, f"{model_name}_val_confusion_matrix.png"),
    'Correct sample': os.path.join(vis_dir, f"{model_name}_correct_samples.png"),
    'Incorrect sample': os.path.join(vis_dir, f"{model_name}_incorrect_samples.png"),
}

# 2x2 grid, flattened so panels can be addressed by a single index.
fig, axes = plt.subplots(2, 2, figsize=(20, 16))
axes = axes.flatten()

for idx, (title, path) in enumerate(image_paths.items()):
    if not os.path.exists(path):
        # Missing file: show a placeholder message instead of an image.
        axes[idx].text(0.5, 0.5, f"❌ {title} not found\nPath: {path}", 
                      ha='center', va='center', fontsize=12, wrap=True)
    else:
        img = mpimg.imread(path)
        axes[idx].imshow(img)
    # Title and hidden axes are the same for both cases.
    axes[idx].set_title(title, fontsize=16, pad=20)
    axes[idx].axis('off')

plt.tight_layout(pad=3.0)
plt.suptitle(f'{model_name} Analysis', fontsize=20, y=0.98)

# Write the combined figure next to the individual plots, then display it.
save_path = os.path.join(vis_dir, f"{model_name}_analysis_summary.png")
plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
print(f"✅ Summary plot saved: {save_path}")

plt.show()

# Optional: Save individual plots in high resolution
def save_single_image(image_title, image_path, save_dir):
    """Re-save one existing plot image as a standalone 300-dpi PNG.

    Args:
        image_title: Human-readable title; spaces become underscores in the
            output file name.
        image_path: Path of the image to read.
        save_dir: Directory the high-resolution copy is written to.

    Returns:
        None. Prints a message saying whether the file was saved or missing.
        The output name is prefixed with the notebook-global ``model_name``.
    """
    # Guard clause: nothing to do when the source image is absent.
    if not os.path.exists(image_path):
        print(f"❌ {image_title} not found, cannot save")
        return

    img = mpimg.imread(image_path)
    save_path = os.path.join(save_dir, f"{model_name}_{image_title.replace(' ', '_')}_highres.png")

    plt.figure(figsize=(12, 8))
    plt.imshow(img)
    plt.title(f'{model_name} - {image_title}', fontsize=18)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close()  # release the figure so repeated calls do not accumulate
    print(f"✅ Individual plot saved: {save_path}")


In [None]:
# Pick the model with the highest validation accuracy from all_model_results,
# rebuild its architecture, load its checkpoint and evaluate it on the test set.
if not all_model_results:
    print("❌ No trained models available! Please run Cell3~Cell7 to train models first")
else:
    # Find best model: rank every recorded model by validation accuracy.
    sorted_models = sorted(all_model_results.items(), key=lambda x: x[1]['val_acc'], reverse=True)
    best_model_name, best_model_result = sorted_models[0]
    print(f"===== Best Model: {best_model_name} (Val Accuracy: {best_model_result['val_acc']:.2f}%) =====")

    # Load best model
    print(f"\n1. Loading best model...")
    # weights=None: every parameter is overwritten by the checkpoint loaded
    # below (load_state_dict is strict by default), so fetching the pretrained
    # ImageNet weights first would only cost a download.
    if best_model_name == 'resnet18':
        model = models.resnet18(weights=None)
        in_features = model.fc.in_features
        model.fc = nn.Linear(in_features, cfg.NUM_CLASSES)
    elif best_model_name == 'resnet50':
        model = models.resnet50(weights=None)
        in_features = model.fc.in_features
        model.fc = nn.Linear(in_features, cfg.NUM_CLASSES)
    elif best_model_name == 'vgg16':
        model = models.vgg16(weights=None)
        in_features = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(in_features, cfg.NUM_CLASSES)
    elif best_model_name == 'convnext':
        model = models.convnext_small(weights=None)
        in_features = model.classifier[2].in_features
        model.classifier[2] = nn.Linear(in_features, cfg.NUM_CLASSES)
    elif best_model_name == 'efficientnet':
        model = models.efficientnet_b0(weights=None)
        in_features = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(in_features, cfg.NUM_CLASSES)
    else:
        # Without this branch an unknown name would silently reuse whatever
        # `model` an earlier cell left behind (or fail with a NameError).
        raise ValueError(f"Unsupported model architecture: {best_model_name!r}")

    model = model.to(device)
    checkpoint = torch.load(best_model_result['model_path'], map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f"✅ Best model loaded successfully!")

    # Evaluate on test set
    print(f"\n2. Evaluating best model on test set...")
    correct = 0
    total = 0
    all_true = []
    all_pred = []
    with torch.no_grad():
        pbar = tqdm(test_loader, desc='Test evaluation', ncols=100)
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_true.extend(labels.cpu().numpy())
            all_pred.extend(predicted.cpu().numpy())
            pbar.set_postfix({'test_acc': f'{100*correct/total:.2f}%'})

    test_acc = 100 * correct / total
    print(f"✅ Test accuracy: {test_acc:.2f}%")

    # Save test confusion matrix
    test_cm_path = os.path.join(cfg.VIS_SAVE_DIR, f"{best_model_name}_test_confusion_matrix.png")
    plot_confusion_matrix(all_true, all_pred, cfg.CLASS_NAMES, best_model_name, test_cm_path, split='Test Set')

    # Random inference on test sample
    print(f"\n3. Random inference on test sample...")
    idx = np.random.randint(0, len(test_dataset))
    img, label = test_dataset[idx]
    with torch.no_grad():
        img_tensor = img.unsqueeze(0).to(device)  # add the batch dimension
        pred = model(img_tensor).argmax(dim=1).item()

    print(f"  Test sample index: {idx}")
    print(f"  Ground truth: {cfg.CLASS_NAMES[label]}")
    print(f"  Prediction: {cfg.CLASS_NAMES[pred]}")

    # Visualize inference sample
    print(f"\n4. Visualizing inference sample...")
    # Undo the per-channel normalisation so the image displays in [0, 1].
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    img_show = img.cpu() * std + mean
    img_show = img_show.permute(1, 2, 0).numpy()  # CHW -> HWC for imshow
    img_show = np.clip(img_show, 0, 1)

    plt.figure(figsize=(8, 6))
    plt.imshow(img_show)
    plt.axis('off')
    plt.title(f'Best Model: {best_model_name}\nTest Inference - GT: {cfg.CLASS_NAMES[label]} | Pred: {cfg.CLASS_NAMES[pred]}', fontsize=14)
    infer_path = os.path.join(cfg.VIS_SAVE_DIR, f"{best_model_name}_test_infer_sample.png")
    plt.savefig(infer_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✅ Inference visualization saved: {os.path.basename(infer_path)}")

    # Only report completion when the evaluation actually ran.
    print("\n✅ Test evaluation and inference completed!")

In [None]:
class Config1:
    """Settings for evaluating on the external ImageFolder test set."""
    NUM_CLASSES = 2
    CLASS_NAMES = ['cat', 'dog']
    VIS_SAVE_DIR = '/kaggle/working/visualization_results'
    TEST_DATASET_PATH = '/kaggle/input/testdataset/test'
    # Preprocessing must match training configuration.
    # The ImageNet channel means are (0.485, 0.456, 0.406); the blue-channel
    # value was previously typed as 0.446, which shifts every test input
    # relative to the statistics used elsewhere in this notebook.
    # NOTE(review): the training transform is defined outside this cell -
    # confirm it uses the same statistics.
    TEST_TRANSFORM = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

cfg1 = Config1()

# Create the directory this cell actually writes to (cfg1.VIS_SAVE_DIR);
# creating cfg.VIS_SAVE_DIR only works when both happen to be the same folder.
os.makedirs(cfg1.VIS_SAVE_DIR, exist_ok=True)

# Load test dataset
from torchvision import datasets

test_dataset = datasets.ImageFolder(
    root=cfg1.TEST_DATASET_PATH,
    transform=cfg1.TEST_TRANSFORM
)

test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print(f"✅ Test dataset loaded successfully!")
print(f"Total samples: {len(test_dataset)}")
print(f"Class mapping: {test_dataset.class_to_idx}")
# Count from the stored label list; iterating the dataset itself would decode
# and transform every image (twice) just to read its label.
# Labels 0/1 are reported as cat/dog, matching cfg1.CLASS_NAMES order.
print(f"Cat samples: {test_dataset.targets.count(0)}")
print(f"Dog samples: {test_dataset.targets.count(1)}")

# Testing: `model` and `best_model_name` come from the best-model cell above.
print(f"\n2. Evaluating best model on test set...")
correct = 0
total = 0
all_true = []
all_pred = []
all_images = []   # kept on CPU for the sample grid drawn later
all_labels = []

with torch.no_grad():
    pbar = tqdm(test_loader, desc='Test Evaluation', ncols=100)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        all_true.extend(labels.cpu().numpy())
        all_pred.extend(predicted.cpu().numpy())
        all_images.extend(images.cpu())
        all_labels.extend(labels.cpu())
        
        pbar.set_postfix({'test_acc': f'{100*correct/total:.2f}%'})

test_acc = 100 * correct / total
print(f"✅ Test accuracy: {test_acc:.2f}%")

test_cm_path = os.path.join(cfg1.VIS_SAVE_DIR, f"{best_model_name}_test_confusion_matrix.png")
plot_confusion_matrix(all_true, all_pred, cfg1.CLASS_NAMES, best_model_name, test_cm_path, split='Test Set')

# Random sample from test dataset
print(f"\n3. Random sample inference from test dataset...")
idx = np.random.randint(0, len(test_dataset))
img, label = test_dataset[idx]

with torch.no_grad():
    img_tensor = img.unsqueeze(0).to(device)  # add the batch dimension
    pred = model(img_tensor).argmax(dim=1).item()

true_class = cfg1.CLASS_NAMES[label]
pred_class = cfg1.CLASS_NAMES[pred]

print(f"  Test sample index: {idx}")
print(f"  True class: {true_class}")
print(f"  Predicted class: {pred_class}")
print(f"  Result: {'✅ Correct' if true_class == pred_class else '❌ Incorrect'}")

# Visualization
print(f"\n4. Visualizing inference sample...")
# Read the statistics back from the transform that produced the tensors so the
# denormalisation can never drift from the normalisation.
_normalize = next(t for t in cfg1.TEST_TRANSFORM.transforms if isinstance(t, transforms.Normalize))
mean = torch.tensor(_normalize.mean).view(3, 1, 1)
std = torch.tensor(_normalize.std).view(3, 1, 1)
img_show = img.cpu() * std + mean
img_show = img_show.permute(1, 2, 0).numpy()  # CHW -> HWC for imshow
img_show = np.clip(img_show, 0, 1)

plt.figure(figsize=(8, 6))
plt.imshow(img_show)
plt.axis('off')
plt.title(
    f'Best Model: {best_model_name}\nTest Set Inference\n'
    f'True Class: {true_class} | Predicted: {pred_class}',
    fontsize=14
)
infer_path = os.path.join(cfg1.VIS_SAVE_DIR, f"{best_model_name}_test_infer_sample.png")
plt.savefig(infer_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"✅ Inference sample visualization saved: {os.path.basename(infer_path)}")

def plot_correct_incorrect_samples(all_images, all_true, all_pred, class_names, save_path):
    """Visualize correct and incorrect predictions (up to 4 samples each).

    Args:
        all_images: Indexable collection of normalised CHW image tensors.
        all_true: Ground-truth label per image.
        all_pred: Predicted label per image.
        class_names: Sequence mapping a label index to its display name.
        save_path: File the figure is written to.

    Returns:
        None. Saves the figure and prints a confirmation, or prints a warning
        and returns early when there is nothing to draw.

    Reads the notebook globals ``std``/``mean`` (denormalisation) and
    ``best_model_name`` (figure title).
    """
    correct_idxs = [i for i, (t, p) in enumerate(zip(all_true, all_pred)) if t == p]
    incorrect_idxs = [i for i, (t, p) in enumerate(zip(all_true, all_pred)) if t != p]

    # Sample each pool separately and convert to plain ints: np.random.choice
    # on an empty pool yields a float array, which cannot index a list.
    selected_idxs = []
    for pool in (correct_idxs, incorrect_idxs):
        if pool:
            picked = np.random.choice(pool, min(4, len(pool)), replace=False)
            selected_idxs.extend(int(i) for i in picked)

    if not selected_idxs:
        # plt.subplots cannot create a zero-column grid.
        print("⚠️  No samples available for visualization")
        return

    # squeeze=False keeps `axes` two-dimensional even for a single panel.
    fig, axes = plt.subplots(
        1, len(selected_idxs), figsize=(4*len(selected_idxs), 5), squeeze=False
    )

    for idx, ax in zip(selected_idxs, axes[0]):
        img = all_images[idx]
        true_label = int(all_true[idx])
        pred_label = int(all_pred[idx])

        # Undo normalisation and convert CHW -> HWC for imshow.
        img_show = img * std + mean
        img_show = img_show.permute(1, 2, 0).numpy()
        img_show = np.clip(img_show, 0, 1)

        true_class = class_names[true_label]
        pred_class = class_names[pred_label]
        status = "Correct" if true_label == pred_label else "Incorrect"
        color = 'green' if status == "Correct" else 'red'

        ax.imshow(img_show)
        ax.set_title(f"True: {true_class}\nPred: {pred_class}\n{status}", color=color, fontsize=10)
        ax.axis('off')

    plt.suptitle(f'{best_model_name} Test Set Prediction Samples', fontsize=14)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✅ Correct/incorrect samples visualization saved: {os.path.basename(save_path)}")

# Draw the correct/incorrect sample strip from the images and predictions
# collected during the test loop above, saved under cfg1.VIS_SAVE_DIR.
sample_vis_path = os.path.join(cfg1.VIS_SAVE_DIR, f"{best_model_name}_test_correct_incorrect_samples.png")
plot_correct_incorrect_samples(all_images, all_true, all_pred, cfg1.CLASS_NAMES, sample_vis_path)

print("\n✅ Test set evaluation and inference completed!")

In [None]:
# Re-run the test-set evaluation with the `model`, `test_loader`,
# `test_dataset` and `cfg1` objects left behind by the previous cells.
print(f"\n2. Evaluating best model on Test set...")
correct = 0
total = 0
all_true = []
all_pred = []
all_images = []   # kept on CPU for the sample grid drawn later

with torch.no_grad():
    pbar = tqdm(test_loader, desc='Test Evaluation', ncols=100)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        all_true.extend(labels.cpu().numpy())
        all_pred.extend(predicted.cpu().numpy())
        all_images.extend(images.cpu())
        
        pbar.set_postfix({'test_acc': f'{100*correct/total:.2f}%'})

test_acc = 100 * correct / total
print(f"✅ Test Set Accuracy: {test_acc:.2f}%")

test_cm_path = os.path.join(cfg1.VIS_SAVE_DIR, f"{best_model_name}_test_confusion_matrix.png")
plot_confusion_matrix(all_true, all_pred, cfg1.CLASS_NAMES, best_model_name, test_cm_path, split='Test Set')

# Random samples from test dataset
print(f"\n3. Random sample inference from Test dataset...")
idx = np.random.randint(0, len(test_dataset))
img, label = test_dataset[idx]

with torch.no_grad():
    img_tensor = img.unsqueeze(0).to(device)  # add the batch dimension
    pred = model(img_tensor).argmax(dim=1).item()

true_class = cfg1.CLASS_NAMES[label]
pred_class = cfg1.CLASS_NAMES[pred]

print(f"  Test sample index: {idx}")
print(f"  True class: {true_class}")
print(f"  Predicted class: {pred_class}")
print(f"  Result: {'✅ Correct' if true_class == pred_class else '❌ Incorrect'}")

# Visualization
print(f"\n4. Visualizing inference sample...")
# Read the statistics back from cfg1.TEST_TRANSFORM instead of retyping them,
# so the denormalisation always matches the normalisation that was applied.
_normalize = next(t for t in cfg1.TEST_TRANSFORM.transforms if isinstance(t, transforms.Normalize))
mean = torch.tensor(_normalize.mean).view(3, 1, 1)
std = torch.tensor(_normalize.std).view(3, 1, 1)
img_show = img.cpu() * std + mean
img_show = img_show.permute(1, 2, 0).numpy()  # CHW -> HWC for imshow
img_show = np.clip(img_show, 0, 1)

plt.figure(figsize=(8, 6))
plt.imshow(img_show)
plt.axis('off')
plt.title(
    f'Best Model: {best_model_name}\nTest Set Inference\n'
    f'True Class: {true_class} | Predicted Class: {pred_class}',
    fontsize=14
)
infer_path = os.path.join(cfg1.VIS_SAVE_DIR, f"{best_model_name}_test_infer_sample.png")
plt.savefig(infer_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"✅ Inference sample visualization saved: {os.path.basename(infer_path)}")

def plot_correct_incorrect_samples(all_images, all_true, all_pred, class_names, save_path):
    """Draw a single-row strip of up to 4 correct and 4 incorrect predictions.

    Args:
        all_images: Indexable collection of normalised CHW image tensors.
        all_true: Ground-truth label per image.
        all_pred: Predicted label per image.
        class_names: Sequence mapping a label index to its display name.
        save_path: File the figure is written to.

    Returns:
        None. Prints a warning and returns early when there are no samples.

    Reads the notebook globals ``std``/``mean`` and ``best_model_name``.
    """
    # Bucket every position by whether the prediction matched the label.
    hits = [int(pos) for pos, (t, p) in enumerate(zip(all_true, all_pred)) if t == p]
    misses = [int(pos) for pos, (t, p) in enumerate(zip(all_true, all_pred)) if t != p]

    def _sample(pool):
        """Pick at most 4 distinct indices from pool as plain ints."""
        if len(pool) == 0:
            return []
        drawn = np.random.choice(pool, min(4, len(pool)), replace=False)
        return drawn.astype(int).tolist()

    # Correct samples are drawn first, then incorrect ones.
    selected_idxs = _sample(hits) + _sample(misses)
    if len(selected_idxs) == 0:
        print("⚠️  No samples available for visualization")
        return

    panel_count = len(selected_idxs)
    fig, axes = plt.subplots(1, panel_count, figsize=(4*panel_count, 5))
    if panel_count == 1:
        axes = [axes]  # a lone Axes is not iterable

    for sample_idx, ax in zip(selected_idxs, axes):
        true_label = int(all_true[sample_idx])
        pred_label = int(all_pred[sample_idx])

        # Undo normalisation, then CHW -> HWC and clamp for display.
        restored = all_images[sample_idx] * std + mean
        img_show = np.clip(restored.permute(1, 2, 0).numpy(), 0, 1)

        true_class = class_names[true_label]
        pred_class = class_names[pred_label]
        if true_label == pred_label:
            status, color = "Correct", 'green'
        else:
            status, color = "Incorrect", 'red'

        ax.imshow(img_show)
        ax.set_title(f"True: {true_class}\nPred: {pred_class}\n{status}", color=color, fontsize=10)
        ax.axis('off')

    plt.suptitle(f'{best_model_name} Test Set Prediction Samples', fontsize=14)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✅ Correct/incorrect samples visualization saved: {os.path.basename(save_path)}")

# Draw the correct/incorrect sample strip from the images and predictions
# collected by the evaluation loop in this cell.
sample_vis_path = os.path.join(cfg1.VIS_SAVE_DIR, f"{best_model_name}_test_correct_incorrect_samples.png")
plot_correct_incorrect_samples(all_images, all_true, all_pred, cfg1.CLASS_NAMES, sample_vis_path)

print("\n✅ Best model Test set evaluation and inference completed!")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
from PIL import Image

# GPU detection: use CUDA when torch reports it as available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Configuration parameters
class Config2:
    """Hyper-parameters and paths for the CIFAR-10 ConvNeXt experiment."""

    # --- Model ---
    MODEL_NAME = 'convnext_tiny'
    NUM_CLASSES = 10
    CLASS_NAMES = [
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    ]

    # --- Data ---
    DATA_ROOT = './cifar10_data'
    IMG_SIZE = 32
    VAL_SPLIT = 0.2          # fraction of the training set held out
    BATCH_SIZE = 128
    NUM_WORKERS = 2
    PIN_MEMORY = True

    # --- Optimisation ---
    NUM_EPOCHS = 15
    LEARNING_RATE = 0.001
    WEIGHT_DECAY = 1e-4
    LR_SCHEDULER_PATIENCE = 3
    LR_SCHEDULER_FACTOR = 0.5

    # --- Visualisation output ---
    VIS_SAVE_DIR = './cifar10_visualizations'
    NUM_SAMPLE_GRID = 16     # max samples shown per section
    CM_FIGSIZE = (12, 10)
    SAMPLE_FIGSIZE = (16, 12)

cfg2 = Config2()
os.makedirs(cfg2.VIS_SAVE_DIR, exist_ok=True)

# Data preprocessing and augmentation
# NOTE(review): the third mean (0.446) differs from the usual ImageNet value
# (0.406); it is left as-is because denormalize_image in this cell uses the
# same numbers - change both together if it is a typo.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.446), (0.229, 0.224, 0.225))
])

val_test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.446), (0.229, 0.224, 0.225))
])

# Load CIFAR-10 dataset
full_train_dataset = datasets.CIFAR10(
    root=cfg2.DATA_ROOT, train=True, download=True, transform=train_transform
)
# Second view of the same training images with the deterministic transform.
# Both halves of a random_split share one underlying dataset object, so
# assigning `val_dataset.dataset.transform` would also switch augmentation off
# for the training half; a separate instance keeps the two pipelines apart.
full_val_dataset = datasets.CIFAR10(
    root=cfg2.DATA_ROOT, train=True, download=False, transform=val_test_transform
)
test_dataset = datasets.CIFAR10(
    root=cfg2.DATA_ROOT, train=False, download=True, transform=val_test_transform
)

# Split train and validation sets
val_size = int(cfg2.VAL_SPLIT * len(full_train_dataset))
train_size = len(full_train_dataset) - val_size
train_dataset, _val_split = random_split(
    full_train_dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

# Same held-out indices, read through the non-augmented view.
val_dataset = torch.utils.data.Subset(full_val_dataset, _val_split.indices)

# Create DataLoaders
train_loader = DataLoader(
    train_dataset, batch_size=cfg2.BATCH_SIZE, shuffle=True,
    num_workers=cfg2.NUM_WORKERS, pin_memory=cfg2.PIN_MEMORY
)
val_loader = DataLoader(
    val_dataset, batch_size=cfg2.BATCH_SIZE, shuffle=False,
    num_workers=cfg2.NUM_WORKERS, pin_memory=cfg2.PIN_MEMORY
)
test_loader = DataLoader(
    test_dataset, batch_size=cfg2.BATCH_SIZE, shuffle=False,
    num_workers=cfg2.NUM_WORKERS, pin_memory=cfg2.PIN_MEMORY
)

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Class names: {cfg2.CLASS_NAMES}")

# Build ConvNeXt model
def build_convnext(num_classes=10):
    """Load an ImageNet-pretrained ConvNeXt Tiny and adapt it for CIFAR-10.

    Args:
        num_classes: Output size of the replacement classification layer.

    Returns:
        The model, moved to the notebook-global ``device``.
    """
    # Use the `weights=` API (as the other model builders in this notebook
    # do) instead of the deprecated `pretrained=True` flag.
    model = models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.DEFAULT)
    # classifier[2] is the final Linear layer; swap it for a fresh head.
    in_features = model.classifier[2].in_features
    model.classifier[2] = nn.Linear(in_features, num_classes)
    return model.to(device)

model = build_convnext(num_classes=cfg2.NUM_CLASSES)
print(f"\n{cfg2.MODEL_NAME} Model initialized:")
print(f"Model structure: {model}")

# Loss function, optimizer, and scheduler
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(), lr=cfg2.LEARNING_RATE, weight_decay=cfg2.WEIGHT_DECAY
)
# Multiply the LR by `factor` after `patience` epochs without a drop in the
# monitored value. The deprecated `verbose` flag is omitted: the training
# loop already prints the current learning rate every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=cfg2.LR_SCHEDULER_FACTOR,
    patience=cfg2.LR_SCHEDULER_PATIENCE
)

# Training function
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over loader.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device each batch is moved to.

    Returns:
        Tuple ``(mean loss per sample, accuracy in percent)``.
    """
    model.train()
    loss_sum = 0.0
    hits = 0
    seen = 0

    progress = tqdm(loader, desc='Train', ncols=100)
    for batch_images, batch_labels in progress:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Forward pass.
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the final mean is per sample.
        loss_sum += batch_loss.item() * batch_images.size(0)
        guesses = torch.max(logits, 1)[1]
        seen += batch_labels.size(0)
        hits += (guesses == batch_labels).sum().item()

        progress.set_postfix({
            'loss': f'{loss_sum/seen:.4f}',
            'acc': f'{100*hits/seen:.2f}%'
        })

    epoch_loss = loss_sum / seen
    epoch_acc = 100 * hits / seen
    return epoch_loss, epoch_acc

# Validation/Test function
def validate(model, loader, criterion, device, return_preds=False):
    """Evaluate model on loader without gradient tracking.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        device: Device each batch is moved to.
        return_preds: When True, also collect per-sample predictions, labels
            and the (CPU) input images.

    Returns:
        ``(avg_loss, avg_acc)``, or
        ``(avg_loss, avg_acc, all_preds, all_labels, all_images)`` when
        ``return_preds`` is True. Accuracy is in percent.
    """
    model.eval()
    loss_sum = 0.0
    hits = 0
    seen = 0
    collected_preds, collected_labels, collected_images = [], [], []

    progress = tqdm(loader, desc='Valid/Test', ncols=100)
    with torch.no_grad():
        for batch_images, batch_labels in progress:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_images)
            batch_loss = criterion(logits, batch_labels)

            # Weight by batch size so the final mean is per sample.
            loss_sum += batch_loss.item() * batch_images.size(0)
            guesses = torch.max(logits, 1)[1]
            seen += batch_labels.size(0)
            hits += (guesses == batch_labels).sum().item()

            if return_preds:
                collected_preds.extend(guesses.cpu().numpy())
                collected_labels.extend(batch_labels.cpu().numpy())
                collected_images.extend(batch_images.cpu())

            progress.set_postfix({
                'loss': f'{loss_sum/seen:.4f}',
                'acc': f'{100*hits/seen:.2f}%'
            })

    avg_loss = loss_sum / seen
    avg_acc = 100 * hits / seen

    if not return_preds:
        return avg_loss, avg_acc
    return avg_loss, avg_acc, collected_preds, collected_labels, collected_images

# Visualization
def plot_training_curves(history, save_path):
    """Plot accuracy and loss curves side by side, then save and show them.

    Args:
        history: Dict with ``train_acc``, ``val_acc``, ``train_loss`` and
            ``val_loss`` sequences (one value per epoch).
        save_path: File the figure is written to.
    """
    # One row per panel:
    # (train key, val key, train label, val label, y-axis label, title).
    panels = (
        ('train_acc', 'val_acc', 'Train Acc', 'Val Acc', 'Accuracy (%)',
         f'{cfg2.MODEL_NAME} Training/Validation Accuracy'),
        ('train_loss', 'val_loss', 'Train Loss', 'Val Loss', 'Loss',
         f'{cfg2.MODEL_NAME} Training/Validation Loss'),
    )

    plt.figure(figsize=(12, 4))
    for slot, (train_key, val_key, train_label, val_label, y_label, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        plt.plot(history[train_key], label=train_label, marker='o', linewidth=2)
        plt.plot(history[val_key], label=val_label, marker='s', linewidth=2)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()
    print(f"✅ Training curves saved to: {save_path}")

def plot_confusion_matrix(all_labels, all_preds, class_names, save_path):
    """Plot a row-normalised confusion matrix and print a classification report.

    NOTE(review): this redefinition takes four arguments, while earlier cells
    call ``plot_confusion_matrix`` with a model name and ``split=`` as well;
    re-running those cells after this one will fail.

    Args:
        all_labels: Ground-truth class index per sample.
        all_preds: Predicted class index per sample.
        class_names: Display names, one per class index.
        save_path: File the heatmap is written to.
    """
    # Fix the label set to every known class so the matrix always matches
    # class_names, even when a class is missing from the evaluated samples.
    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    plt.figure(figsize=cfg2.CM_FIGSIZE)

    # Row-normalise; a class with no true samples keeps an all-zero row
    # instead of producing NaN from a 0/0 division.
    row_totals = cm.sum(axis=1)[:, np.newaxis]
    cm_normalized = cm.astype('float') / np.where(row_totals == 0, 1, row_totals)

    sns.heatmap(
        cm_normalized, annot=True, fmt='.2f', cmap='Blues',
        xticklabels=class_names, yticklabels=class_names
    )
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title(f'{cfg2.MODEL_NAME} Confusion Matrix (Test Set)')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()

    print("\nClassification Report (Test Set):")
    print(classification_report(
        all_labels, all_preds, labels=label_ids,
        target_names=class_names, digits=4, zero_division=0
    ))
    print(f"✅ Confusion matrix saved to: {save_path}")

def denormalize_image(tensor):
    """Invert the per-channel normalisation so an image can be displayed.

    Uses the same mean/std values as the Normalize transforms defined in this
    cell and broadcasts them over a 3-channel, channel-first tensor.
    """
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    channel_mean = torch.tensor([0.485, 0.456, 0.446]).view(3, 1, 1)
    restored = tensor * channel_std + channel_mean
    return restored

def plot_sample_analysis(all_images, all_labels, all_preds, class_names, save_path):
    """Visualize randomly chosen correct and incorrect predictions.

    The figure has two stacked sections (correct, then incorrect). Each
    section has a header line and a grid holding up to
    ``cfg2.NUM_SAMPLE_GRID`` image tiles. All tiles live in one GridSpec, so
    the incorrect section can never index past the end of the grid.

    Args:
        all_images: Indexable collection of normalised CHW image tensors.
        all_labels: Ground-truth class index per image.
        all_preds: Predicted class index per image.
        class_names: Display names, one per class index.
        save_path: File the figure is written to.
    """
    correct_idxs = [i for i, (p, l) in enumerate(zip(all_preds, all_labels)) if p == l]
    incorrect_idxs = [i for i, (p, l) in enumerate(zip(all_preds, all_labels)) if p != l]

    per_section = cfg2.NUM_SAMPLE_GRID
    n_cols = max(1, per_section // 2)
    tile_rows = max(1, -(-per_section // n_cols))  # ceiling division
    sections = (('Correct', correct_idxs), ('Incorrect', incorrect_idxs))

    fig = plt.figure(figsize=cfg2.SAMPLE_FIGSIZE)
    # Per section: one thin header row followed by `tile_rows` image rows.
    grid = fig.add_gridspec(
        len(sections) * (tile_rows + 1), n_cols,
        height_ratios=([0.25] + [1] * tile_rows) * len(sections),
    )

    for section_no, (label, idxs) in enumerate(sections):
        top = section_no * (tile_rows + 1)

        header = fig.add_subplot(grid[top, :])
        header.axis('off')
        header.text(0.5, 0.5, f'{label} Predictions (Test Set) - {len(idxs)} samples',
                    ha='center', va='center', fontsize=14)

        if not idxs:
            # Nothing to sample: fill the section with a placeholder message.
            placeholder = fig.add_subplot(grid[top + 1:top + 1 + tile_rows, :])
            placeholder.axis('off')
            placeholder.text(0.5, 0.5, f'No {label} Predictions!',
                             ha='center', va='center', fontsize=16)
            continue

        chosen = np.random.choice(idxs, min(per_section, len(idxs)), replace=False)
        for k, idx in enumerate(chosen):
            idx = int(idx)  # plain int for list indexing
            ax = fig.add_subplot(grid[top + 1 + k // n_cols, k % n_cols])

            img = denormalize_image(all_images[idx])
            img = img.permute(1, 2, 0).numpy()  # CHW -> HWC for imshow
            img = np.clip(img, 0, 1)

            ax.imshow(img)
            ax.set_title(
                f"True: {class_names[int(all_labels[idx])]}\nPred: {class_names[int(all_preds[idx])]}",
                fontsize=8,
            )
            ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()
    print(f"✅ Sample analysis saved to: {save_path}")

best_val_acc = 0.0
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': []
}
# The checkpoint path does not change between epochs; defining it up front
# also guarantees the name exists for the evaluation code after the loop.
model_save_path = os.path.join(cfg2.VIS_SAVE_DIR, f"{cfg2.MODEL_NAME}_best.pth")

print("\n" + "="*60)
print(f"CIFAR-10 {cfg2.MODEL_NAME} Training Started (Epochs: {cfg2.NUM_EPOCHS})")
print("="*60)

# Main training loop
for epoch in range(cfg2.NUM_EPOCHS):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    # The plateau scheduler monitors validation loss (mode='min').
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"Epoch {epoch+1}/{cfg2.NUM_EPOCHS} | "
          f"Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}% | "
          f"Val Loss: {val_loss:.4f} Acc: {val_acc:.2f}% | "
          f"LR: {current_lr:.6f}")

    # Checkpoint whenever validation accuracy improves.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'history': history
        }, model_save_path)
        print(f"  [Saved] Best model (Val Acc: {val_acc:.2f}%) to: {model_save_path}")

    print("-" * 60)

print(f"\nTraining Finished! Best Validation Accuracy: {best_val_acc:.2f}%")

# Test set evaluation with visualization
print("\n" + "="*60)
print(f"Test Set Evaluation for {cfg2.MODEL_NAME}")
print("="*60)

# Restore the weights saved at the best validation accuracy, so the test
# numbers do not depend on whatever the last epoch left in `model`.
checkpoint = torch.load(model_save_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# checkpoint['epoch'] is the zero-based epoch index stored at save time.
print(f"Loaded best model (trained for {checkpoint['epoch']+1} epochs)")

# return_preds=True also returns per-sample predictions, labels and images,
# which the plots below consume.
test_loss, test_acc, test_preds, test_labels, test_images = validate(
    model, test_loader, criterion, device, return_preds=True
)
print(f"\nTest Set Results: Loss = {test_loss:.4f}, Acc = {test_acc:.2f}%")

# Generate visualizations
print(f"\n===== Generating {cfg2.MODEL_NAME} Visualization Analysis (Test Set) =====")

# 1) training/validation curves
curve_path = os.path.join(cfg2.VIS_SAVE_DIR, f"{cfg2.MODEL_NAME}_training_curves.png")
plot_training_curves(history, curve_path)

# 2) confusion matrix + classification report
cm_path = os.path.join(cfg2.VIS_SAVE_DIR, f"{cfg2.MODEL_NAME}_test_confusion_matrix.png")
plot_confusion_matrix(test_labels, test_preds, cfg2.CLASS_NAMES, cm_path)

# 3) grid of correct / incorrect sample predictions
sample_path = os.path.join(cfg2.VIS_SAVE_DIR, f"{cfg2.MODEL_NAME}_test_sample_analysis.png")
plot_sample_analysis(test_images, test_labels, test_preds, cfg2.CLASS_NAMES, sample_path)

# Save results
# NOTE: this rebinds all_model_results to a fresh dict, so any entries stored
# by earlier cells are discarded.
all_model_results = {}
all_model_results[cfg2.MODEL_NAME] = {
    'val_acc': best_val_acc,
    'test_acc': test_acc,
    'test_loss': test_loss,
    'history': history,
    'model_path': model_save_path,
    'test_preds': test_preds,
    'test_labels': test_labels
}

print(f"\n✅ {cfg2.MODEL_NAME} Training + Testing + Visualization Complete!")
print(f"Visualization files saved to: {cfg2.VIS_SAVE_DIR}")
print(f"Trained models: {list(all_model_results.keys())}")
print(f"Final test accuracy: {test_acc:.2f}%")

In [None]:
def plot_sample_analysis(all_images, all_labels, all_preds, class_names, save_path):
    """Draw one row of correct and one row of incorrect test predictions.

    Each non-empty category gets its own row with up to
    ``cfg2.NUM_SAMPLE_GRID`` randomly chosen samples; an empty category is
    left out and the figure height shrinks accordingly.

    Args:
        all_images: Indexable collection of normalised CHW image tensors.
        all_labels: Ground-truth class index per image.
        all_preds: Predicted class index per image.
        class_names: Display names, one per class index.
        save_path: File the figure is written to.

    Returns:
        None. Prints a warning and returns early when there are no samples.
    """
    # Separate correct and incorrect predictions
    correct_idxs = [int(pos) for pos, (p, l) in enumerate(zip(all_preds, all_labels)) if p == l]
    incorrect_idxs = [int(pos) for pos, (p, l) in enumerate(zip(all_preds, all_labels)) if p != l]

    # Cap the number of tiles per row.
    cap = cfg2.NUM_SAMPLE_GRID
    num_correct = min(cap, len(correct_idxs))
    num_incorrect = min(cap, len(incorrect_idxs))

    # One row per non-empty category.
    total_rows = int(num_correct > 0) + int(num_incorrect > 0)
    if total_rows == 0:
        print("⚠️  No samples available for visualization")
        return

    # Scale the configured height to the number of rows actually drawn.
    fig = plt.figure(figsize=(cfg2.SAMPLE_FIGSIZE[0], cfg2.SAMPLE_FIGSIZE[1] * total_rows / 2))

    def _draw_row(pool, count, first_slot, **title_style):
        """Sample count indices from pool and draw them from first_slot on."""
        picks = np.random.choice(pool, count, replace=False).astype(int)
        for offset, idx in enumerate(picks):
            ax = plt.subplot(total_rows, count, first_slot + offset)
            img = denormalize_image(all_images[idx])
            img = img.permute(1, 2, 0).numpy()
            img = np.clip(img, 0, 1)
            ax.imshow(img)
            ax.set_title(f"True: {class_names[all_labels[idx]]}\nPred: {class_names[all_preds[idx]]}", fontsize=8, **title_style)
            ax.axis('off')

    has_correct = num_correct > 0

    # Correct predictions occupy the first row when present.
    if has_correct:
        _draw_row(correct_idxs, num_correct, 1)
        fig.suptitle(f'Correct Predictions (Test Set) - {len(correct_idxs)} total samples', fontsize=14, y=0.95)

    # Incorrect predictions go in the next row, or the first if it is alone.
    if num_incorrect > 0:
        row_offset = 1 if has_correct else 0
        _draw_row(incorrect_idxs, num_incorrect, row_offset * num_incorrect + 1, color='red')

        if has_correct:
            # The figure title is taken, so place this heading as free text.
            fig.text(0.5, 0.47, f'Incorrect Predictions (Test Set) - {len(incorrect_idxs)} total samples', 
                    fontsize=14, ha='center')
        else:
            fig.suptitle(f'Incorrect Predictions (Test Set) - {len(incorrect_idxs)} total samples', fontsize=14, y=0.95)

    plt.tight_layout(rect=[0, 0, 1, 0.92])
    plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()
    print(f"✅ Sample analysis saved to: {save_path}")

# Visualization
# Re-draw the sample grid with the plot_sample_analysis defined in this cell,
# reusing the test images/labels/predictions computed by the previous cell.
sample_path = os.path.join(cfg2.VIS_SAVE_DIR, f"{cfg2.MODEL_NAME}_test_sample_analysis.png")
plot_sample_analysis(test_images, test_labels, test_preds, cfg2.CLASS_NAMES, sample_path)

# Save results
# NOTE: this rebinds all_model_results to a fresh dict, so any entries stored
# by earlier cells are discarded.
all_model_results = {}
all_model_results[cfg2.MODEL_NAME] = {
    'val_acc': best_val_acc,
    'test_acc': test_acc,
    'test_loss': test_loss,
    'history': history,
    'model_path': model_save_path,
    'test_preds': test_preds,
    'test_labels': test_labels
}

print(f"\n✅ {cfg2.MODEL_NAME} training + testing + visualization completed!")
print(f"Visualization files saved to: {cfg2.VIS_SAVE_DIR}")
print(f"Trained models: {list(all_model_results.keys())}")
print(f"Final test accuracy: {test_acc:.2f}%")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset
from sklearn.metrics import confusion_matrix
import seaborn as sns
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

# GPU Detection: use CUDA when PyTorch reports it as available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

# Global Configuration 
class Config:
    """Constants shared by the data pipeline, the training loop and the plots."""

    # ---- Data ----
    DATASET_NAME = 'Aurora1609/cat_vs_dog'  # identifier handed to load_dataset()
    IMG_SIZE = 224                          # side length used by the crop/resize transforms
    BATCH_SIZE = 32                         # batch size of every DataLoader
    NUM_WORKERS = 0                         # DataLoader workers (0 = load in the main process)
    PIN_MEMORY = False                      # forwarded to DataLoader(pin_memory=...)

    # ---- Optimisation ----
    NUM_EPOCHS = 20                         # passes over the training loader
    NUM_CLASSES = 2                         # output size of the replaced classifier layer
    LEARNING_RATE = 1e-4                    # AdamW learning rate
    WEIGHT_DECAY = 1e-4                     # AdamW weight decay
    LR_SCHEDULER_PATIENCE = 3               # ReduceLROnPlateau patience
    LR_SCHEDULER_FACTOR = 0.5               # ReduceLROnPlateau multiplicative factor

    # ---- Output ----
    MODEL_SAVE_DIR = './saved_models'           # checkpoints are written here
    VIS_SAVE_DIR = './visualization_results'    # figures are written here
    CLASS_NAMES = ['cat', 'dog']                # display names, indexed by integer label
    NUM_SAMPLE_GRID = 16                        # samples per grid in plot_sample_analysis

cfg = Config()

# Create save directories (exist_ok=True makes re-running this cell harmless)
os.makedirs(cfg.MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(cfg.VIS_SAVE_DIR, exist_ok=True)

# Global dictionary: store results of all models (val_acc, test_acc, history).
# Keep the dictionary created by a cell that ran earlier instead of rebinding it
# to {}: resetting it here would silently drop the results of every model that
# was already trained in this session.
all_model_results = globals().get('all_model_results', {})

# Visualization utility functions
def plot_single_model_curve(history, model_name, save_path):
    """Save a two-panel figure (accuracy left, loss right) for one model.

    Args:
        history: Mapping with the keys 'train_acc', 'val_acc', 'train_loss' and
            'val_loss'; each value is the sequence drawn for that metric
            (presumably one entry per epoch, as the x-axis is labelled 'Epoch').
        model_name: Prefix used in both panel titles.
        save_path: Destination file of the figure (written at 300 dpi).
    """
    # One entry per panel:
    # (train key, val key, train legend, val legend, title suffix, y-axis label)
    panel_specs = [
        ('train_acc', 'val_acc', 'Train Acc', 'Val Acc', 'Accuracy Curve', 'Accuracy (%)'),
        ('train_loss', 'val_loss', 'Train Loss', 'Val Loss', 'Loss Curve', 'Loss'),
    ]

    plt.figure(figsize=(12, 4))
    for slot, spec in enumerate(panel_specs, start=1):
        train_key, val_key, train_legend, val_legend, heading, y_label = spec
        plt.subplot(1, 2, slot)
        plt.plot(history[train_key], 'b-o', label=train_legend, markersize=4)
        plt.plot(history[val_key], 'r-o', label=val_legend, markersize=4)
        plt.title(f'{model_name} - {heading}', fontsize=12)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    # Close the figure so nothing is displayed and it does not stay in memory.
    plt.close()
    print(f"✅ Training curve saved: {os.path.basename(save_path)}")

def plot_confusion_matrix(all_true, all_pred, class_names, model_name, save_path, split='Val Set'):
    """Save an annotated confusion-matrix heatmap for one model.

    Args:
        all_true: Ground-truth labels.
        all_pred: Predicted labels, aligned element-wise with ``all_true``.
        class_names: Display names used as tick labels on both axes.
        model_name: Prefix of the figure title.
        save_path: Destination file of the figure (written at 300 dpi).
        split: Name of the data split shown in the title.
    """
    # Pin the label order to the class indices so the matrix always has one
    # row/column per entry of class_names and lines up with the tick labels,
    # even when a class never appears in all_true/all_pred.
    # Assumes labels are the integer indices 0..len(class_names)-1, which is how
    # class_names is indexed elsewhere in this file.
    cm = confusion_matrix(all_true, all_pred, labels=list(range(len(class_names))))
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm, annot=True, fmt='d', cmap='Blues',
        xticklabels=class_names, yticklabels=class_names,
        cbar_kws={'label': 'Number of Samples'}
    )
    plt.title(f'{model_name} - Confusion Matrix ({split})', fontsize=14)
    plt.xlabel('Predicted', fontsize=12)
    plt.ylabel('True', fontsize=12)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Confusion matrix saved: {os.path.basename(save_path)}")

def _save_prediction_grid(samples, class_names, fig_title, pred_prefix, color, save_path):
    """Draw ``samples`` on a near-square grid of axes and save the figure.

    Args:
        samples: List of ``(image, true_label, predicted_label)`` tuples; each
            image is a normalized CPU tensor laid out as (3, H, W) and the
            labels are integer indices into ``class_names``.
        class_names: Display names indexed by label.
        fig_title: Title placed above the whole grid.
        pred_prefix: Word printed before the predicted class in each cell title.
        color: Colour of the per-cell titles.
        save_path: Destination file of the figure (written at 300 dpi).
    """
    # Same constants as transforms.Normalize in this file; used to undo it.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # Near-square layout; 16 samples give the 4x4 grid on a 16x16 inch figure.
    n_cols = max(1, int(np.ceil(np.sqrt(len(samples)))))
    n_rows = max(1, int(np.ceil(len(samples) / n_cols)))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    fig.suptitle(fig_title, fontsize=18, y=0.95)

    # Hide every axis first so cells without a sample stay blank.
    for ax in axes.flat:
        ax.axis('off')

    for ax, (img, true_idx, pred_idx) in zip(axes.flat, samples):
        pixels = (img * std + mean).permute(1, 2, 0).numpy()
        ax.imshow(np.clip(pixels, 0, 1))
        ax.set_title(
            f"True: {class_names[true_idx]}\n{pred_prefix}: {class_names[pred_idx]}",
            color=color, fontsize=12, pad=10
        )

    plt.tight_layout()
    plt.subplots_adjust(top=0.92)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)


def plot_sample_analysis(model, val_loader, class_names, device, model_name, save_dir, num_samples=16):
    """Save grids of correctly and incorrectly classified validation samples.

    The loader is scanned until ``num_samples`` correct and ``num_samples``
    incorrect predictions have been collected (or the loader is exhausted).
    A grid is written only for a category that reached ``num_samples``;
    otherwise a message explains why it was skipped.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        class_names: Display names indexed by integer label.
        device: Device the images are moved to for the forward pass.
        model_name: Prefix of the figure titles and output file names.
        save_dir: Directory receiving ``<model_name>_correct_samples.png`` and
            ``<model_name>_incorrect_samples.png``.
        num_samples: Number of samples shown per grid.
    """
    model.eval()
    correct, incorrect = [], []

    with torch.no_grad():
        for imgs, lbls in val_loader:
            preds = model(imgs.to(device)).argmax(dim=1)
            for img, true_idx, pred_idx in zip(imgs.cpu(), lbls.tolist(), preds.tolist()):
                bucket = correct if true_idx == pred_idx else incorrect
                # Keep at most num_samples per category, on the CPU. With an
                # accurate model the incorrect quota may never fill, and storing
                # every correct image (on the GPU) would exhaust memory.
                if len(bucket) < num_samples:
                    # clone() so a kept sample does not pin its whole batch.
                    bucket.append((img.clone(), true_idx, pred_idx))
            if len(correct) >= num_samples and len(incorrect) >= num_samples:
                break

    # Correct prediction samples
    if len(correct) >= num_samples:
        _save_prediction_grid(
            correct, class_names,
            f'{model_name} - Correct Predictions (Val Set)', 'Predict', 'green',
            os.path.join(save_dir, f"{model_name}_correct_samples.png")
        )
        print(f"Correct prediction samples saved: {model_name}_correct_samples.png")
    else:
        print(f"{model_name}: Insufficient correct samples ({len(correct)} < {num_samples}), skipping correct sample plot")

    # Incorrect prediction samples
    if len(incorrect) >= num_samples:
        _save_prediction_grid(
            incorrect, class_names,
            f'{model_name} - Incorrect Predictions (Val Set)', 'Pred', 'red',
            os.path.join(save_dir, f"{model_name}_incorrect_samples.png")
        )
        print(f"Incorrect prediction samples saved: {model_name}_incorrect_samples.png")
    else:
        print(f"{model_name}: Insufficient incorrect samples ({len(incorrect)} < {num_samples}), skipping incorrect sample plot")

# Training/validation utility functions 
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; it is switched to train mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(loss, accuracy)``: the sample-weighted average loss and the
        accuracy in percent over all samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    pbar = tqdm(loader, desc='Training', ncols=100)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the final division gives a per-sample value
        # (assumes the criterion returns the batch mean).
        running_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        pbar.set_postfix({'loss': f'{running_loss/total:.4f}', 'acc': f'{100*correct/total:.2f}%'})

    if total == 0:
        # Reachable when the dataset is smaller than one batch and the loader
        # uses drop_last=True; fail with a clear message instead of a bare
        # ZeroDivisionError from the averages below.
        raise ValueError(
            "train_one_epoch received no samples: the loader is empty "
            "(check the dataset size, batch_size and drop_last)"
        )
    return running_loss / total, 100 * correct / total

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` with gradient tracking disabled.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(loss, accuracy, true_labels, predicted_labels)``: the
        sample-weighted average loss, the accuracy in percent, and two lists
        with the true and predicted label of every sample in loader order.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    true_labels, pred_labels = [], []
    progress = tqdm(loader, desc='Validation', ncols=100)

    with torch.no_grad():
        for batch_images, batch_labels in progress:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            batch_loss = criterion(logits, batch_labels)
            predicted = torch.max(logits, 1)[1]

            # Weight by batch size so the running value is a per-sample average.
            loss_sum += batch_loss.item() * batch_images.size(0)
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()

            true_labels.extend(batch_labels.cpu().numpy())
            pred_labels.extend(predicted.cpu().numpy())

            progress.set_postfix({
                'loss': f'{loss_sum/n_seen:.4f}',
                'acc': f'{100*n_correct/n_seen:.2f}%',
            })

    return loss_sum / n_seen, 100 * n_correct / n_seen, true_labels, pred_labels

# Dataset class with error handling for stability
class CatDogDataset(Dataset):
    """Map-style wrapper turning dataset records into ``(image, label)`` pairs.

    Each record must provide an ``'image'`` entry with a ``convert`` method and
    a ``'label'`` entry. If a sample cannot be loaded, a warning is printed and
    sample 0 is returned in its place.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def _load(self, index):
        """Return the RGB image (transformed when a transform is set) and label at ``index``."""
        record = self.dataset[index]
        rgb_image = record['image'].convert('RGB')
        target = record['label']
        if self.transform:
            rgb_image = self.transform(rgb_image)
        return rgb_image, target

    def __getitem__(self, idx):
        try:
            return self._load(idx)
        except Exception as e:
            # Best-effort: substitute the first sample rather than abort the run.
            print(f"⚠️ Failed to load sample {idx}: {e}, returning fallback sample")
            return self._load(0)

# Data preprocessing (shared across all models)
# Training pipeline: random augmentation, then tensor conversion + normalization.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=cfg.IMG_SIZE, scale=(0.7, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)

# Validation/test pipeline: fixed resize only, with the same normalization.
val_test_transform = transforms.Compose(
    [
        transforms.Resize(size=(cfg.IMG_SIZE, cfg.IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)

# Load dataset (using built-in train/val/test split)
print(f"Loading dataset: {cfg.DATASET_NAME} (with built-in train/val/test split)...")
dataset = load_dataset(cfg.DATASET_NAME)
train_hf, val_hf, test_hf = dataset['train'], dataset['val'], dataset['test']

# Wrap dataset: augmentation for training only, plain resize for val/test
train_dataset = CatDogDataset(dataset=train_hf, transform=train_transform)
val_dataset = CatDogDataset(dataset=val_hf, transform=val_test_transform)
test_dataset = CatDogDataset(dataset=test_hf, transform=val_test_transform)

# Create DataLoader (single process to avoid errors)
# Only the training loader shuffles and drops the last incomplete batch.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=cfg.BATCH_SIZE,
    shuffle=True,
    num_workers=cfg.NUM_WORKERS,
    pin_memory=cfg.PIN_MEMORY,
    drop_last=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=cfg.BATCH_SIZE,
    shuffle=False,
    num_workers=cfg.NUM_WORKERS,
    pin_memory=cfg.PIN_MEMORY,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=cfg.BATCH_SIZE,
    shuffle=False,
    num_workers=cfg.NUM_WORKERS,
    pin_memory=cfg.PIN_MEMORY,
)

# Print data information
print(
    "\nDataset Split Information:",
    f"  Training set samples: {len(train_dataset)}",
    f"  Validation set samples: {len(val_dataset)}",
    f"  Test set samples: {len(test_dataset)}",
    f"  Classes: {cfg.CLASS_NAMES}",
    sep="\n",
)

# Model
def build_efficientnet(num_classes=2, pretrained=True):
    """Build an EfficientNet-B0 with a ``num_classes``-way classifier layer.

    Args:
        num_classes: Output size of the replacement final linear layer.
        pretrained: When true, start from torchvision's pretrained weights;
            otherwise use randomly initialised weights.

    Returns:
        The model, moved to the module-level ``device``.
    """
    if pretrained:
        # Feature-detect the weights enum instead of wrapping the call in a
        # bare ``except``: that also swallowed download errors and
        # KeyboardInterrupt, then retried through the legacy keyword.
        weights_enum = getattr(models, 'EfficientNet_B0_Weights', None)
        if weights_enum is not None:
            model = models.efficientnet_b0(weights=weights_enum.DEFAULT)
        else:
            # torchvision without the weights enum: legacy keyword
            model = models.efficientnet_b0(pretrained=True)
    else:
        # No arguments means "no pretrained weights" in both the legacy
        # (pretrained=False) and the current (weights=None) signature.
        model = models.efficientnet_b0()
    # Modify final layer: swap the linear layer at classifier[1] for a new head
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_classes)
    return model.to(device)

# Initialize model
model_name = 'efficientnet'
print(f"===== Starting training for model: {model_name} =====")
model = build_efficientnet(num_classes=cfg.NUM_CLASSES, pretrained=True)
print(f"{model_name} model initialized (pretrained weights: True)")

# Training configuration
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=cfg.LEARNING_RATE, weight_decay=cfg.WEIGHT_DECAY)
# No `verbose=` argument: it was deprecated in PyTorch 2.2 and newer releases
# may reject it. Learning-rate reductions are reported in the loop instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=cfg.LR_SCHEDULER_FACTOR, patience=cfg.LR_SCHEDULER_PATIENCE
)

history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
best_val_acc = 0.0
best_all_true = []
best_all_pred = []
# The checkpoint path is the same for every epoch, so compute it once. This
# also guarantees the name refers to this model even if no epoch improves.
model_save_path = os.path.join(cfg.MODEL_SAVE_DIR, f"{model_name}_best.pth")

# Training loop
for epoch in range(cfg.NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{cfg.NUM_EPOCHS}")
    print("-"*50)

    # Training
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    # Validation
    val_loss, val_acc, all_true, all_pred = validate(model, val_loader, criterion, device)

    # Learning rate scheduling (driven by validation loss, mode='min')
    previous_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    if current_lr < previous_lr:
        print(f"Learning rate reduced: {previous_lr:.6f} → {current_lr:.6f}")

    # Record history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Update best results: keep the labels/predictions of the best epoch for
    # the confusion matrix and checkpoint the weights that produced them
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_all_true = all_true
        best_all_pred = all_pred
        # Save best model
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'val_acc': best_val_acc,
            'history': history
        }, model_save_path)
        print(f"[SAVED] Best model (Val Accuracy: {best_val_acc:.2f}%) → {os.path.basename(model_save_path)}")

    # Print epoch summary
    print(f"Epoch Summary: Train Loss={train_loss:.4f} Train Acc={train_acc:.2f}% | "
          f"Val Loss={val_loss:.4f} Val Acc={val_acc:.2f}% | LR={current_lr:.6f}")

# Visualization
print(f"\n===== Generating visualization analysis for {model_name} =====")
curve_path = os.path.join(cfg.VIS_SAVE_DIR, f"{model_name}_training_curves.png")
cm_path = os.path.join(cfg.VIS_SAVE_DIR, f"{model_name}_val_confusion_matrix.png")

plot_single_model_curve(history=history, model_name=model_name, save_path=curve_path)
# The confusion matrix uses the labels/predictions kept from the best epoch.
plot_confusion_matrix(
    all_true=best_all_true,
    all_pred=best_all_pred,
    class_names=cfg.CLASS_NAMES,
    model_name=model_name,
    save_path=cm_path,
    split='Val Set',
)
plot_sample_analysis(
    model=model,
    val_loader=val_loader,
    class_names=cfg.CLASS_NAMES,
    device=device,
    model_name=model_name,
    save_dir=cfg.VIS_SAVE_DIR,
    num_samples=cfg.NUM_SAMPLE_GRID,
)

# Save results: register this model in the shared results dictionary
all_model_results.update({
    model_name: {
        'val_acc': best_val_acc,
        'history': history,
        'model_path': os.path.join(cfg.MODEL_SAVE_DIR, f"{model_name}_best.pth"),
    },
})

print(
    f"\n{model_name} training + analysis completed",
    f"Currently trained models: {list(all_model_results.keys())}",
    sep="\n",
)
