# Test with pretrained models

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.nn.utils import clip_grad_norm_
import torch.nn.init as init
from sklearn.utils.class_weight import compute_class_weight
import random
import numpy as np
import pandas as pd
import os
import cpuinfo
from tqdm import tqdm
from torchinfo import summary
from PIL import Image

In [None]:
# Reproducibility: seed every random number generator used by this notebook
seed = 42
for seed_fn in (
    random.seed,                 # Python standard-library RNG
    np.random.seed,              # NumPy global RNG
    torch.manual_seed,           # PyTorch generator
    torch.cuda.manual_seed,      # current CUDA device
    torch.cuda.manual_seed_all,  # all CUDA devices
):
    seed_fn(seed)

# Disable cuDNN autotuning and request deterministic kernels
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

### File paths

In [None]:
train_path = "../data/Train"        # root folder of the training images (read with ImageFolder)
#train_path = "../data/aug_train"   # alternative training folder (disabled)
test_path = "../data/Test"          # root folder of the test images; the test set is also used for validation
input_parameter = ""                # checkpoint location to import; only referenced by the commented-out loading cell
output_parameter = "./model_parameters_efficientnetb0_224"  # directory where weights, optimizer/scheduler state and the training log CSV are exported

### Device in use

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = "cpu"

# Human-readable label for progress-bar captions (was left empty before).
# torch.device(...) also accepts the plain-string override above.
device_name = (
    torch.cuda.get_device_name(device)
    if torch.device(device).type == "cuda"
    else "CPU"
)

# os.cpu_count() returns None when the count cannot be determined,
# which would make the integer division below raise a TypeError.
cpu_count = os.cpu_count() or 1
print(f"CPU count: {cpu_count}")
num_workers = min(4, cpu_count // 2)  # Dynamically set num_workers (at most 4, half of the cores)
print(f"Using device: {device}")

### Pretrained model in use (from torchvision)

In [None]:
from torchvision.models import efficientnet_v2_s, efficientnet_b0

# Backbone with its default pretrained weights
model = efficientnet_b0(weights='DEFAULT')  # or efficientnet_v2_s(weights='DEFAULT')
img_size = 224  # adjust input image size for model
print(f"Using model {type(model).__name__}")

epochs = 50
batch_size = 16  # adjust to your memory

# NOTE(review): both optimizers below capture the parameter tensors that exist
# right now. If `model.classifier` is replaced afterwards, their param groups
# must be re-pointed at the new head, otherwise the new head is never updated.

# Phase 1 - classifier-only training while the rest of the model is frozen.
# Runs for the first `n` epochs; set n < 0 to skip this phase entirely.
n = 5
optimizer = optim.AdamW(
    [{"params": model.classifier.parameters(), "lr": 1e-2}],
    weight_decay=1e-3,
    )
# The plateau metric passed to the schedulers is validation accuracy (higher is
# better), so the mode has to be 'max'. With 'min' the learning rate was halved
# precisely while accuracy kept improving.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=2,
    )

# Phase 2 - fully unfrozen model after n epochs
optimizer2 = optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=1e-4,
    )
scheduler2 = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer2,
    mode='max',
    factor=0.2,
    patience=3,
    min_lr=1e-6,
    cooldown=0,
    threshold_mode='rel',
    threshold=0.0001,
    eps=1e-8
    )
# With freezing disabled, train with the phase-2 settings from the first epoch
if n < 0:
    optimizer = optimizer2
    scheduler = scheduler2

# loss_function = nn.CrossEntropyLoss() defined in Weighted Cross Entropy Loss cell
grad_clip = 5.0         # gradient clipping value (max gradient norm)

### Data transform/normalization and loader

In [None]:
# Calculate the per-channel mean and std of the training images.
# The images get the same resize / centre-crop as the test pipeline and are
# converted to tensors, so the statistics match what Normalize will receive.
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor()
])

dataset = datasets.ImageFolder(root=train_path, transform=transform)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Accumulate exact per-channel sums over every pixel of the dataset.
# Averaging per-batch variances (the previous approach) ignores the spread
# between batch means and therefore underestimates the true std; with
# shuffle=False on an ImageFolder the batches are mostly single-class, which
# makes that error larger.
channel_sum = 0.0
channel_sq_sum = 0.0
total_pixels = 0    # pixels per channel seen so far
total_images = 0

for images, _ in tqdm(loader, desc="Calculating stats"):
    batch_samples = images.size(0)  # Number of images in the batch
    # Flatten the spatial dims: [B, C, H, W] -> [B, C, H*W]; float64 keeps the
    # running sums accurate over many pixels
    pixels = images.view(batch_samples, images.size(1), -1).double()

    channel_sum = channel_sum + pixels.sum(dim=(0, 2))
    channel_sq_sum = channel_sq_sum + pixels.pow(2).sum(dim=(0, 2))
    total_pixels += batch_samples * pixels.size(2)
    total_images += batch_samples

# Final mean and standard deviation: Var[x] = E[x^2] - E[x]^2 (population variance)
mean = channel_sum / total_pixels
var = (channel_sq_sum / total_pixels - mean ** 2).clamp(min=0.0)  # guard against tiny negative rounding
std = torch.sqrt(var)

# Plain Python lists are what transforms.Normalize is given later
mean = mean.tolist()
std = std.tolist()

print(f"Total Images: {total_images}")
print(f"Mean: {mean}")
print(f"Std: {std}")

In [None]:
# ---- Training pipeline (augmentation), assembled from named stages ----

# Geometric stage: enlarge, reflect-pad, rotate by up to +/-45 degrees, then
# crop back to the pre-padding size and shrink to the model input size.
# (Presumably the reflect padding keeps the rotation from exposing empty corners.)
geometric_ops = [
    transforms.Resize(450),
    transforms.Pad(padding=150, padding_mode='reflect'),
    transforms.RandomRotation(45, expand=False),
    transforms.CenterCrop(450),
    transforms.Resize(img_size),
]

# Flip / colour stage
appearance_ops = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05, hue=0.02),
]

# Tensor stage: convert, normalise with the dataset statistics computed above,
# then randomly erase a small patch.
# ImageNet alternative: transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
tensor_ops = [
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
    transforms.RandomErasing(p=0.2, scale=(0.01, 0.05), ratio=(0.3, 10), value='random', inplace=False),
]

transform_train = transforms.Compose(geometric_ops + appearance_ops + tensor_ops)

# ---- Test pipeline: deterministic resize + centre crop, same normalisation ----
# ImageNet alternative: transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
test_ops = [
    transforms.Resize(img_size),
    transforms.CenterCrop((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
]
transform_test = transforms.Compose(test_ops)

In [None]:
# Datasets: augmented pipeline for training, deterministic pipeline for testing
train_dataset = datasets.ImageFolder(root=train_path, transform=transform_train)
test_dataset = datasets.ImageFolder(root=test_path, transform=transform_test)

# Both loaders share batch size and worker count; only the training data is shuffled
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Run through the test loader once and keep every batch on `device`, so the
# validation phase of each epoch can iterate this list instead of the loader
val_data = []
for images, labels in tqdm(test_loader, desc=f"Preloading Test Data to {device_name}", leave=False):
    val_data.append((images.to(device), labels.to(device)))

# Number of training samples per class index
num_train_classes = len(train_dataset.classes)
train_labels = np.asarray([label for _, label in train_dataset.samples], dtype=np.int64)
class_counts = np.bincount(train_labels, minlength=num_train_classes).tolist()

print(f"Total Classes: {num_train_classes}")
print(f"Class counts: {class_counts}")
print(f"Classes: {train_dataset.classes}")

### Weighted Cross Entropy Loss

In [None]:
# Class-balanced weights from the training label distribution (sklearn 'balanced' mode)
sample_labels = [target for _path, target in train_dataset.samples]
class_ids = np.arange(len(train_dataset.classes))
balanced_weights = compute_class_weight('balanced', classes=class_ids, y=sample_labels)

# The loss expects the weights as a float32 tensor on the training device
class_weights = torch.tensor(balanced_weights, dtype=torch.float32)
class_weights = class_weights.to(device)

# Weighted cross entropy with label smoothing of 0.1
loss_function = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

print(f"Class weights: {class_weights}")

### Model Classifier layer

In [None]:
# Replace the output layer of the pretrained model with a new classification head
num_classes = len(train_dataset.classes)
num_features = model.classifier[1].in_features  # width of the features fed to the head
model.classifier = nn.Sequential(
    nn.Dropout(0.5),

    nn.Linear(num_features, 256),
    nn.BatchNorm1d(256),
    nn.LeakyReLU(),

    nn.Linear(256, num_classes),
)

# Initialize weights and biases for the new head.
# kaiming_normal_ has no `mean`/`std` parameters (passing them raises a
# TypeError); the negative slope of the following LeakyReLU is given via `a`.
# 0.01 is nn.LeakyReLU's default slope.
for m in model.classifier:
    if isinstance(m, nn.Linear):
        init.kaiming_normal_(m.weight, a=0.01, nonlinearity='leaky_relu')
        if m.bias is not None:
            init.constant_(m.bias, 0)

# The optimizers were created before the head was swapped, so they still hold
# the parameters of the discarded head and would never update the new one.
# Re-point their (single) param groups at the current parameters. No optimizer
# state exists yet at this point, so nothing becomes stale.
optimizer2.param_groups[0]["params"] = list(model.parameters())
if optimizer is not optimizer2:  # classifier-only optimizer (freezing enabled)
    optimizer.param_groups[0]["params"] = list(model.classifier.parameters())

print(f"Classifier - Input features: {num_features}, Output classes: {num_classes}")

### Configure model parameters

In [None]:
# IF NEEDED
# Load custom weight and optimizer states
# if os.path.exists(input_parameter):
#     checkpoint = torch.load("test_weights.pth", map_location=device)
#     model.load_state_dict(checkpoint['model_state_dict'])
#     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Layer freezing: when enabled (n >= 0) only the classifier head stays trainable
if n >= 0:
    # Switch off gradients for every parameter of the model ...
    for param in model.parameters():
        param.requires_grad_(False)
    # ... then switch them back on for the classifier head, reporting each one
    for name, param in model.classifier.named_parameters():
        param.requires_grad_(True)
        print(f"Unfreezing layer: {name}")

In [None]:
# Transfer the model to the selected device and report where its parameters live
model = model.to(device)
first_param = next(model.parameters())
print(f"Model is on {first_param.device}")

In [None]:
# Print model architecture
# print(summary(model, (batch_size, 3, img_size, img_size)))

In [None]:
# # DEBUG
# print(f"Model device: {next(model.parameters()).device}")
# for images, labels in train_data:
#     print(f"Input device: {images.device}")
#     break
# print(f"Device: {device}")

## Training Epochs

In [None]:
#%%time
# Per-epoch history, later exported to CSV and plotted
arr_train_loss = []
arr_train_acc = []
arr_test_loss = []
arr_test_acc = []

# Training loop
for epoch in range(epochs):
    # Unfreeze at epoch n and switch to the full-model optimizer and scheduler
    if epoch == n:
        optimizer = optimizer2
        scheduler = scheduler2
        for param in model.parameters():
            param.requires_grad = True
        print("Unfreezing all layers")

    # Training phase
    model.train()
    # Read the LR from the optimizer itself: get_last_lr() is not reliably
    # available on ReduceLROnPlateau before its first step() on every
    # PyTorch version.
    current_lr = optimizer.param_groups[0]["lr"]
    running_loss = 0.0
    correct_train = 0
    total_train = 0

    train_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} [Train]", leave=False)
    for images, labels in train_bar:
        images, labels = images.to(device), labels.to(device)

        # --Non-Mixed Precision training--
        outputs = model(images)
        loss = loss_function(outputs, labels)

        optimizer.zero_grad()
        loss.backward()

        clip_grad_norm_(model.parameters(), grad_clip)  # Gradient clipping
        optimizer.step()
        # --Non-Mixed Precision training--

        # Calculate statistics (loss is a batch mean, so weight it by batch size)
        running_loss += loss.item() * images.size(0)
        predicted = outputs.detach().argmax(dim=1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

        # Update progress bar
        train_bar.set_postfix({
            'loss': f"{running_loss / total_train:.4f}",
            'acc': f"{100. * correct_train / total_train:.2f}%"
        })

    # Validation phase (batches in val_data already live on `device`)
    model.eval()
    val_loss = 0.0
    correct_val = 0
    total_val = 0

    val_bar = tqdm(val_data, desc=f"Epoch {epoch + 1}/{epochs} [Test]", leave=False)
    with torch.no_grad():
        for images, labels in val_bar:
            outputs = model(images)
            loss = loss_function(outputs, labels)

            val_loss += loss.item() * images.size(0)
            predicted = outputs.argmax(dim=1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

            # Update progress bar
            val_bar.set_postfix({
                'loss': f"{val_loss / total_val:.4f}",
                'acc': f"{100. * correct_val / total_val:.2f}%"
            })

    # Epoch-level metrics
    epoch_train_loss = running_loss / total_train
    epoch_train_acc = 100. * correct_train / total_train
    epoch_val_loss = val_loss / total_val
    epoch_val_acc = 100. * correct_val / total_val

    # Update learning rate.
    # Only ReduceLROnPlateau takes a metric, and the metric has to match the
    # scheduler's mode: accuracy for 'max', loss for 'min'. Feeding accuracy to
    # a 'min' scheduler would cut the LR while the model is still improving.
    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        if scheduler.mode == 'max':
            scheduler.step(correct_val / total_val)
        else:
            scheduler.step(epoch_val_loss)
    else:
        scheduler.step()

    # Print epoch summary
    print(f"Epoch {epoch + 1:>3}/{epochs} - "
          f"LR: {current_lr:.7f} | "
          f"Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.2f}% | "
          f"Test Loss: {epoch_val_loss:.4f}, Test Acc: {epoch_val_acc:.2f}%")

    # Save training and validation loss and accuracy
    arr_train_loss.append(epoch_train_loss)
    arr_train_acc.append(epoch_train_acc)
    arr_test_loss.append(epoch_val_loss)
    arr_test_acc.append(epoch_val_acc)

### Training log and data export

In [None]:
# Make sure the export directory exists
if not os.path.exists(output_parameter):
    os.makedirs(output_parameter)

# One checkpoint file each for the model weights, the optimizer state and the
# scheduler state (whichever optimizer/scheduler was active when training ended)
checkpoints = (
    ("model_weights.pth", {'model_state_dict': model.state_dict()}),
    ("optimizer_weights.pth", {'optimizer_state_dict': optimizer.state_dict()}),
    ("scheduler_weights.pth", {'scheduler_state_dict': scheduler.state_dict()}),
)
for filename, payload in checkpoints:
    torch.save(payload, os.path.join(output_parameter, filename))

# Training history as a table: one row per epoch, epochs numbered from 1
epoch_numbers = list(range(1, len(arr_train_loss) + 1))
data = dict([
    ("Epoch", epoch_numbers),
    ("Train Loss", arr_train_loss),
    ("Train Accuracy", arr_train_acc),
    ("Test Loss", arr_test_loss),
    ("Test Accuracy", arr_test_acc),
])
df = pd.DataFrame(data)

# Write the log next to the checkpoints
csv_file = os.path.join(output_parameter, "training_logs.csv")
df.to_csv(csv_file, index=False)

In [None]:
import matplotlib.pyplot as plt

# Tick positions shared by / specific to the two panels
logged_epochs = len(arr_train_loss)
epoch_ticks = np.arange(start=0, stop=logged_epochs + 1, step=10)
loss_ticks = np.arange(start=0, stop=max(arr_train_loss) + 0.5, step=0.25)
acc_ticks = np.arange(
    start=max(min(arr_train_acc)//10*10-10, 0),
    stop=max(arr_train_acc)//10*10+30,
    step=10,
)

# Description of the two side-by-side panels: loss on the left, accuracy on the right
panels = [
    {
        "position": 1,
        "train": (arr_train_loss, 'Train Loss'),
        "test": (arr_test_loss, 'Test Loss'),
        "title": 'Training and Test Loss',
        "ylabel": 'Loss',
        "yticks": loss_ticks,
        "ylim": (0, 2.5),   # fixed loss range
    },
    {
        "position": 2,
        "train": (arr_train_acc, 'Train Acc'),
        "test": (arr_test_acc, 'Test Acc'),
        "title": 'Training and Test Accuracy',
        "ylabel": 'Accuracy',
        "yticks": acc_ticks,
        "ylim": None,       # accuracy panel keeps the automatic y-range
    },
]

plt.figure(figsize=(12, 5))

for panel in panels:
    plt.subplot(1, 2, panel["position"])

    train_values, train_label = panel["train"]
    test_values, test_label = panel["test"]
    plt.plot(train_values, label=train_label, color='blue', linestyle='-')
    plt.plot(test_values, label=test_label, color='red', linestyle='-')

    plt.title(panel["title"])
    plt.xlabel('Epoch')
    plt.ylabel(panel["ylabel"])
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.xticks(epoch_ticks)
    plt.xlim(-5, logged_epochs)
    plt.yticks(panel["yticks"])
    if panel["ylim"] is not None:
        plt.ylim(*panel["ylim"])

# Save the figure to disk, then display it
plt.tight_layout()
plt.savefig('loss_acc_curves.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# iteration 7 template 2
# Updates:
# 1. Use EfficientNetB0
# 2. Implement gradual unfreezing
# 3. Export model parameters
# 4. Adjust data augmentation
# 5. Use dataset mean and std for normalization (previously ImageNet stats)
# 6. Rename validation variables to test in data export