In [None]:
import os
import torch
import pynvml
import numpy as np
from PIL import Image
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.models import resnet50
from FDResnet50 import FDResNet, resnet50_fdconv
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from train_utils import train, evaluate
from utils import print_generalized_model_summary

%matplotlib inline

In [None]:
# Pick the compute device and, on CUDA machines, report the GPU model and VRAM usage.
is_cuda = torch.cuda.is_available()
device = "cuda" if is_cuda else "cpu"
print("Device:", device)

if is_cuda:
    print("GPU Model:", torch.cuda.get_device_name(0))

    # The NVML query is informational only: a driver/library mismatch should not
    # abort a Restart-&-Run-All, and NVML must be shut down again once initialised.
    try:
        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(0)
            info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            print(f"Total VRAM:     {info.total / 1024**3:.2f} GB")
            print(f"Used VRAM:      {info.used / 1024**3:.2f} GB")
            print(f"Free VRAM:      {info.free / 1024**3:.2f} GB")
        finally:
            pynvml.nvmlShutdown()
    except pynvml.NVMLError as err:
        print(f"Could not query VRAM via NVML: {err}")

In [None]:
# Dataset root. Set the EYEPACS_DATA_ROOT environment variable to run on another
# machine; the default is the original local location.
DATA_ROOT = os.environ.get("EYEPACS_DATA_ROOT", "/media/iot/HDD2TB/eyepac-light-v2-512-jpg")

# One ImageFolder-style sub-directory per split.
train_dir = os.path.join(DATA_ROOT, "train")
val_dir = os.path.join(DATA_ROOT, "validation")
test_dir = os.path.join(DATA_ROOT, "test")

In [None]:
# Per-channel normalization statistics shared by every split.
NORM_MEAN = [0.3569, 0.2274, 0.1467]
NORM_STD = [0.2309, 0.1543, 0.1033]

# Training pipeline: random augmentations to regularize the model.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)
])

# Evaluation pipeline: deterministic. Applying random flips/rotations/crops/jitter
# to validation and test images would make every evaluation pass see different
# inputs, so reported accuracies would be noisy and not reproducible.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)
])

# Preserve the original name (the training pipeline) in case other cells refer to it.
transform = train_transform

train_dataset = ImageFolder(root=train_dir, transform=train_transform)
val_dataset = ImageFolder(root=val_dir, transform=eval_transform)
test_dataset = ImageFolder(root=test_dir, transform=eval_transform)

# Class names/count come from the training split's sub-directory names.
class_names = train_dataset.classes
num_classes = len(class_names)

print(f"Number of classes: {num_classes}")
print(f"Class names: {class_names}")

In [None]:
def denormalize_for_display(img, mean, std):
    """Undo ``transforms.Normalize`` on a CHW image tensor for plotting.

    Args:
        img: tensor laid out as (channels, height, width), as produced by
            ``ToTensor`` followed by ``Normalize``.
        mean: per-channel mean that ``Normalize`` subtracted.
        std: per-channel standard deviation that ``Normalize`` divided by.

    Returns:
        A (height, width, channels) tensor clamped to [0, 1], suitable for ``plt.imshow``.
    """
    mean_t = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    # Clamp because imshow clips float data outside [0, 1] and warns about it.
    return (img * std_t + mean_t).clamp(0.0, 1.0).permute(1, 2, 0)


def show_class_samples(dataset, class_names, num_samples=5,
                       mean=(0.3569, 0.2274, 0.1467), std=(0.2309, 0.1543, 0.1033)):
    """Plot a grid with one row per class and up to ``num_samples`` images per row.

    Images are fetched through ``dataset[idx]``, so whatever transform the dataset
    carries is applied, and they are then denormalized for display.

    Args:
        dataset: an ImageFolder-like dataset exposing ``samples`` (a list of
            ``(path, label)`` pairs) and ``class_to_idx``.
        class_names: class names to show, one grid row each.
        num_samples: maximum number of images per class (the first ones found).
        mean: per-channel mean used by the dataset's ``Normalize`` step; the
            default matches the notebook's normalization statistics.
        std: per-channel std used by the dataset's ``Normalize`` step.
    """
    if not class_names or num_samples <= 0:
        return

    n_rows = len(class_names)
    plt.figure(figsize=(3 * num_samples, 2.5 * n_rows))
    class_to_idx = dataset.class_to_idx

    for row, class_name in enumerate(class_names):
        target_idx = class_to_idx[class_name]

        # Collect the first `num_samples` indices of this class, stopping early
        # instead of scanning the whole dataset.
        chosen_indices = []
        for i, (_, label) in enumerate(dataset.samples):
            if label == target_idx:
                chosen_indices.append(i)
                if len(chosen_indices) == num_samples:
                    break

        for col, idx in enumerate(chosen_indices):
            img, _ = dataset[idx]
            plt.subplot(n_rows, num_samples, row * num_samples + col + 1)
            plt.imshow(denormalize_for_display(img, mean, std))
            plt.title(class_name)
            plt.axis('off')

    plt.tight_layout()
    plt.show()

# Preview the first five training images of each class. train_dataset carries the
# training transform, so these images include its random augmentations.
show_class_samples(train_dataset, class_names, num_samples=5)

In [None]:
batch_size = 8

# drop_last=True on the training loader: the classifier head uses BatchNorm1d,
# which raises in training mode when a batch contains a single sample. Without it,
# a training set whose size is 1 mod batch_size would crash on the final batch
# of every epoch. Shuffling means a different (<batch_size) remainder is skipped
# each epoch. Evaluation loaders keep every sample.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Per-class sample counts, read from ImageFolder's (path, label) list without
# loading any images.
train_counts = np.bincount([label for _, label in train_dataset.samples], minlength=num_classes)
val_counts = np.bincount([label for _, label in val_dataset.samples], minlength=num_classes)
test_counts = np.bincount([label for _, label in test_dataset.samples], minlength=num_classes)

print(f"Number of training samples: {len(train_dataset)}, class distribution: {dict(zip(class_names, train_counts))}")
print(f"Number of validation samples: {len(val_dataset)}, class distribution: {dict(zip(class_names, val_counts))}")
print(f"Number of test samples: {len(test_dataset)}, class distribution: {dict(zip(class_names, test_counts))}")

In [None]:
# Classifier: an ImageNet-pretrained ResNet50 backbone with a new MLP head.
# `weights="IMAGENET1K_V1"` selects exactly the weights that the deprecated
# `weights=True` form resolved to, without the deprecation warning. The legacy
# boolean form may be removed in future torchvision releases.
resnet_model = resnet50(weights="IMAGENET1K_V1")

# Freeze every pretrained parameter. The head created below is built afterwards,
# so its parameters are the only ones that require gradients.
# NOTE(review): freezing does not stop the backbone's BatchNorm layers from
# updating their running statistics while the model is in train() mode;
# confirm whether train_utils.train() accounts for this.
for param in resnet_model.parameters():
    param.requires_grad = False

# Replace the final fully connected layer with a 3-hidden-layer MLP head
# (Linear -> BatchNorm -> ReLU -> Dropout, with decreasing dropout rates).
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Sequential(
    nn.Linear(in_features, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),

    nn.Linear(512, 256),
    nn.BatchNorm1d(256),
    nn.ReLU(inplace=True),
    nn.Dropout(0.3),

    nn.Linear(256, 128),
    nn.BatchNorm1d(128),
    nn.ReLU(inplace=True),
    nn.Dropout(0.2),

    nn.Linear(128, num_classes)
)

# The backbone is pretrained (not trained from scratch), so say so.
print(f"ResNet50 model created with a frozen ImageNet-pretrained backbone and {num_classes} output classes")
print("\nResnet50 Model Summary:")
print_generalized_model_summary(resnet_model, model_name="Resnet50")

In [None]:
# Output locations for checkpoints and TensorBoard logs.
save_dir = "./model_checkpoints/resnet50"
log_dir = './runs/resnet50'
for output_dir in (save_dir, log_dir):
    os.makedirs(output_dir, exist_ok=True)

# All training hyper-parameters in one place so the run is easy to inspect and tweak.
# NOTE(review): step_size/gamma look like StepLR settings while scheduler_type is
# 'plateau'; whether train() uses them in that case is not visible here.
train_kwargs = dict(
    num_epochs=100,
    learning_rate=0.001,
    optimizer_type='adam',
    scheduler_type='plateau',
    early_stopping_patience=5,
    save_dir=save_dir,
    model_name='resnet50_trained',
    tensorboard_log_dir=log_dir,
    print_freq=100,
    save_best_only=True,
    mixed_precision=True,
    gradient_accumulation_steps=1,
    weight_decay=1e-4,
    step_size=3,
    gamma=0.1,
)

print("Starting ResNet50 Training...")
trained_resnet, training_results = train(
    model=resnet_model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    **train_kwargs,
)

# Summarize the run from the history returned by train().
best_val_acc = training_results['best_val_acc']
total_epochs = len(training_results['train_losses'])

print("\nTraining Complete!")
print(f"Best validation accuracy: {best_val_acc:.2f}%")
print(f"Total epochs: {total_epochs}")
print(f"Final train accuracy: {training_results['train_accs'][-1]:.2f}%")
print(f"Final val accuracy: {training_results['val_accs'][-1]:.2f}%")

In [None]:
# Evaluate ResNet50 on Test Set
print("Evaluating ResNet50 on Test Set...")

eval_save_dir = "./evaluation_results"

# Restore the best checkpoint before evaluating. The file name must follow the
# `model_name` passed to train() ('resnet50_trained'). The previous hard-coded
# 'resnet50_enhanced_best.pth' did not match it, so evaluation likely fell back
# silently to the model returned by train(), or picked up a stale file from an
# older run.
# NOTE(review): assumes train() writes f"{model_name}_best.pth" into save_dir;
# confirm against train_utils.
best_model_path = os.path.join(save_dir, 'resnet50_trained_best.pth')
if os.path.exists(best_model_path):
    # map_location lets a checkpoint saved on GPU be restored on the current device.
    checkpoint = torch.load(best_model_path, map_location=device)
    trained_resnet.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']+1}")
else:
    print(f"Warning: best checkpoint not found at {best_model_path}; "
          "evaluating the model as returned by train().")

eval_results = evaluate(
    model=trained_resnet,
    test_loader=test_loader,
    device=device,
    class_names=class_names,
    save_dir=eval_save_dir,
    model_name="resnet50_enhanced"
)

print(f"\nResNet50 Test Results:")
print(f"  Test Accuracy: {eval_results['accuracy']:.2f}%")
print(f"  Test Loss: {eval_results['test_loss']:.6f}")
print(f"  Evaluation Time: {eval_results['evaluation_time']:.2f} seconds")
print(f"  Total samples evaluated: {eval_results['num_samples']}")

# Per-class accuracy, when evaluate() provides it.
if 'class_accuracies' in eval_results:
    print(f"\nPer-class Accuracies:")
    for class_name, acc in zip(class_names, eval_results['class_accuracies']):
        print(f"  {class_name}: {acc:.2f}%")

print(f"\nResults saved to: {eval_save_dir}/")
print("="*60)