In [None]:
# Record the Python version (-v), the versions of the listed packages (-p) and the
# active conda environment (--conda) so the run can be reproduced later.
%load_ext watermark
%watermark -v -p numpy,pandas,torch,torchvision,PIL,sklearn,matplotlib,wandb,captum --conda

In [None]:
# Setting the experiment environment

import time
import os
import pandas as pd
import torch
import wandb
from utils.set_seed import set_seed
from utils.data_utils import prepare_dataset, eval_dataset

# Set seed for reproducibility
SEED = 0
set_seed(SEED)

# Get start time of the current experiment
start_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())

# Set the device to GPU if available
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# torch.cuda.get_device_name raises when CUDA is unavailable, so only query it on a GPU;
# otherwise the CPU fallback chosen above would crash on this print.
device_label = torch.cuda.get_device_name(DEVICE) if DEVICE.type == "cuda" else "CPU"
print(f"Device: {device_label}")

# Set wandb experiment name and notes.
# WANDB_NAME is used both as the wandb run name and in the saved-weights filename.
WANDB_NAME = "SET EXPERIMENT NAME"
WANDB_NOTES = "SET YOUR NOTE"
WANDB_NOTEBOOK_NAME = "02. torch_age_inceptionv4.ipynb"  # Change this to the name of the notebook you are running
os.environ["WANDB_NOTEBOOK_NAME"] = WANDB_NOTEBOOK_NAME

# Start a run, tracking hyperparameters
wandb.init(
    project="SET YOUR PROJECT NAME",
    # Without this the run gets a random wandb name that does not match the weights file.
    name=WANDB_NAME,
    config={
        "model": "SET YOUR MODEL NAME",
        "input_size": (299, 299),
        "batch_size": 32,
        "augment": True,  # Set augment to True to use data augmentation
        "num_augmented_images": "Inverse of Frequency",  # If HD ImbAugment is used, set num_augmented_images to "Inverse of Frequency"
        "augment_config": {"augment_prob": 1,
                           "flip_horizontal": True,
                           "flip_vertical": False,
                           "flip_prob": 0.5,
                           "random_brightness": True,
                           "brightness_factor": 0.15,
                           "random_contrast": True,
                           "contrast_factor": 0.15,
                           "random_rotation": True,
                           "rotation_factor": 3,
                           "random_translation": True,
                           "translation_factor": (0.05, 0.05),
                           "random_zoom": True,
                           "zoom_factors": (0.95, 1.05),
                           "random_erasing": True,
                           "erasing_prob": 0.15,
                           "erasing_scale": (0.05, 0.10),
                           "erasing_ratio": (0.3, 3.3),
                           },  # Set your augmentation config
        "epochs": 100,
        "learning_rate": 0.001,
        "learning_rate_scheduler": "ReduceLROnPlateau",
        "optimizer": "adam",
        "multitask": False,  # Set multitask to False to train a single task model
        "main_target": "age_in_years",
        "dropout": 0.7,
        "dense_units": 1024,
        "early_stopping": True,
        "early_stopping_patience": 20,
    },
    notes=WANDB_NOTES,
)

In [None]:
# Loading data from csv file.
# os.path.join builds a path that works on every OS; the previous "Data\ccs_dataset.csv"
# literal only worked on Windows and relied on the invalid "\c" escape being ignored.
DATASET_PATH = os.path.join("Data", "ccs_dataset.csv")
df = pd.read_csv(DATASET_PATH)
print(f"Dataset shape: {df.shape}")
df.head()

In [None]:
# Splitting data into train, validation and test sets with scikit-learn
# IMPORTANT: If using ImbAug set same seed and split size as used in pre-processing to avoid data leakage.

from sklearn.model_selection import train_test_split

# 80/10/10 stratified split. No random_state is passed, so the split depends on the global
# NumPy RNG state (presumably seeded by set_seed(SEED) above — confirm). Keep this call
# identical to the pre-processing notebook so augmented training images never come from
# the val/test rows.
train_df, temp_df = train_test_split(df, test_size=0.2, stratify=df.age_group)
test_df, val_df = train_test_split(temp_df, test_size=0.5, stratify=temp_df.age_group)

# Deleting temp_df to free memory
del temp_df

# If using HD ImbAugment, replace train_df with the augmented training set loaded from disk
if wandb.config.num_augmented_images == "Inverse of Frequency":
    del train_df
    train_df = pd.read_csv("path/to/augmented/train/dataset/generated/on/notebook/01.csv")
    # Expressions inside an `if` body are not auto-rendered by Jupyter, so display explicitly.
    display(train_df.head(10).style)

In [None]:
# Preparing datasets
# Keyword arguments common to every split; split-specific options are passed alongside.
shared_kwargs = dict(
    input_size=wandb.config.input_size,
    batch_size=wandb.config.batch_size,
    multitask=wandb.config.multitask,
    label=wandb.config.main_target,
)

uses_hd_imbaugment = wandb.config.num_augmented_images == "Inverse of Frequency"

if not uses_hd_imbaugment:
    # On-the-fly augmentation controlled by the config.
    train_dataset = prepare_dataset(
        train_df,
        shuffle=True,
        augment=wandb.config.augment,
        num_augmented_images=wandb.config.num_augmented_images,
        augment_config=wandb.config.augment_config,
        **shared_kwargs,
    )
else:
    # HD ImbAugment: train_df holds the pre-augmented set, so augmentation is switched off here.
    train_dataset = prepare_dataset(
        train_df,
        shuffle=True,
        augment=False,
        augment_config=wandb.config.augment_config,
        **shared_kwargs,
    )

# Evaluation splits: fixed order, no augmentation.
val_dataset = prepare_dataset(val_df, shuffle=False, augment=False, **shared_kwargs)
test_dataset = prepare_dataset(test_df, shuffle=False, augment=False, **shared_kwargs)

In [None]:
# Sanity-check the training pipeline output before building the model.
is_multitask = wandb.config.multitask
target_label = wandb.config.main_target
eval_dataset(train_dataset, label=target_label, multitask=is_multitask)

In [None]:
from torch import nn
from models.pytorch.architectures.InceptionV4 import (
    InceptionStem,
    InceptionA,
    InceptionB,
    InceptionC,
    ReductionA,
    ReductionB,
)
from utils.training_utils import weights_init

from torchsummary import summary


class InceptionV4(nn.Module):
    """Inception-v4 backbone followed by a two-layer dense head.

    Args:
        num_classes: Number of output units of the final linear layer
            (this notebook uses 1 for age regression).
        dropout_prob: Dropout probability applied to the pooled features.
        dense_units: Width of the hidden dense layer in the head.
    """

    def __init__(self, num_classes, dropout_prob, dense_units):
        super().__init__()

        self.stem = InceptionStem()

        # 4 x Inception-A, then reduction (channel counts as passed to the project blocks).
        # The generator builds modules in the same order as listing them one by one,
        # so state_dict keys (inception_a_blocks.0 ... .3) are unchanged.
        self.inception_a_blocks = nn.Sequential(*(InceptionA(384) for _ in range(4)))
        self.reduction_a = ReductionA(384)

        # 7 x Inception-B, then reduction.
        self.inception_b_blocks = nn.Sequential(*(InceptionB(1024) for _ in range(7)))
        self.reduction_b = ReductionB(1024)

        # 3 x Inception-C.
        self.inception_c_blocks = nn.Sequential(*(InceptionC(1536) for _ in range(3)))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(dropout_prob)
        self.fc1 = nn.Linear(1536, dense_units)
        # Non-linearity between the dense layers: without it fc2(fc1(x)) collapses into a
        # single linear map and `dense_units` adds no capacity. ReLU has no parameters,
        # so previously saved state_dicts still load.
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(dense_units, num_classes)

    def forward(self, x):
        """Map an image batch to `num_classes` outputs per sample."""
        x = self.stem(x)
        x = self.inception_a_blocks(x)
        x = self.reduction_a(x)
        x = self.inception_b_blocks(x)
        x = self.reduction_b(x)
        x = self.inception_c_blocks(x)
        x = self.avgpool(x)

        # (N, C, 1, 1) -> (N, C); fc1 expects C == 1536.
        x = torch.flatten(x, 1)
        x = self.dropout(x)

        x = self.relu(self.fc1(x))
        return self.fc2(x)


# Single-output regression head (predicts age), initialised with the project's weights_init.
model = InceptionV4(
    num_classes=1,
    dropout_prob=wandb.config.dropout,
    dense_units=wandb.config.dense_units,
).to(DEVICE).apply(weights_init)

# torchsummary's `summary` defaults to device="cuda"; pass the real device type so the
# summary also works when the notebook falls back to CPU.
summary(
    model,
    (3, wandb.config.input_size[0], wandb.config.input_size[1]),
    device=DEVICE.type,
)

In [None]:
# Define training parameters
import torch.optim as optim
from torch.cuda.amp import GradScaler

# L1 (MAE) is the optimisation loss; MSE is tracked only as a metric.
criterion = nn.L1Loss()
mse_metric = nn.MSELoss()

# Define scaler to use mixed precision training
scaler = GradScaler()

# Tunable training parameters
# Epochs
epochs = wandb.config.epochs

# Optimizer (matched case-insensitively, so "adam"/"Adam" and "rmsprop"/"RMSprop" all work).
# An unknown name fails here instead of surfacing later as a NameError on `optimizer`.
optimizer_name = str(wandb.config.optimizer).lower()
if optimizer_name == "adam":
    optimizer = optim.Adam(model.parameters(), lr=wandb.config.learning_rate)
elif optimizer_name == "rmsprop":
    optimizer = optim.RMSprop(model.parameters(), lr=wandb.config.learning_rate)
else:
    raise ValueError(f"Unsupported optimizer: {wandb.config.optimizer!r}")

# Learning rate scheduler: halve the LR after 5 epochs without validation-loss improvement.
# The training loop always calls lr_scheduler.step, so an unknown scheduler is rejected here.
if wandb.config.learning_rate_scheduler == "ReduceLROnPlateau":
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5, verbose=True
    )
else:
    raise ValueError(
        f"Unsupported learning rate scheduler: {wandb.config.learning_rate_scheduler!r}"
    )

In [None]:
# Training the model
from utils.training_utils import age_train_loop, age_valid_loop

train_losses = []
train_MSEs = []
valid_losses = []
valid_MSEs = []
early_stopping_patience = wandb.config.early_stopping_patience
best_valid_loss = float("inf")
epochs_without_improvement = 0

# Best-so-far checkpoint; create its directory up front so torch.save cannot fail
# on a fresh checkout.
best_weights_path = f"models/pytorch/weights/best_{WANDB_NAME}.pt"
os.makedirs(os.path.dirname(best_weights_path), exist_ok=True)

for t in range(epochs):
    print(
        f"-------------------------------\nEpoch {t+1}\n-------------------------------"
    )
    train_loss, train_mse = age_train_loop(
        train_dataset, model, criterion, mse_metric, optimizer, device=DEVICE, scaler=scaler
    )
    # Validate on the validation split. Using test_dataset here would leak the test set
    # into LR scheduling, early stopping and best-checkpoint selection.
    valid_loss, valid_mse = age_valid_loop(
        val_dataset, model, criterion, mse_metric, device=DEVICE
    )
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    train_MSEs.append(train_mse)
    valid_MSEs.append(valid_mse)

    print(
        "\nMetrics -> "
        f"loss: {train_loss:>5f} | "
        f"MSE: {train_mse:>5f} | "
        f"val_loss: {valid_loss:>5f} | "
        f"val_MSE: {valid_mse:>5f} | "
        f"LR: {optimizer.param_groups[0]['lr']}"
    )

    # Log metrics to wandb (the L1 loss is logged as MAE)
    wandb.log(
        {
            "epoch/epoch": t,
            "epoch/mae": train_loss,
            "epoch/mse": train_mse,
            "epoch/val_mae": valid_loss,
            "epoch/val_mse": valid_mse,
            "epoch/learning_rate": optimizer.param_groups[0]["lr"],
        }
    )

    lr_scheduler.step(valid_loss)

    # Early Stopping and Save Best Model logic
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        epochs_without_improvement = 0

        # Save the best model
        torch.save(model.state_dict(), best_weights_path)
        print("\nBest model saved.")
    else:
        epochs_without_improvement += 1
        print(f"\nEpochs without improvement: {epochs_without_improvement}")

    # Honour the `early_stopping` config flag instead of always stopping on patience.
    if wandb.config.early_stopping and epochs_without_improvement >= early_stopping_patience:
        print("\nEarly stopping triggered.")
        break

print("Done!")

In [None]:
# Plotting the training and validation losses

import matplotlib.pyplot as plt

# Training vs. validation L1 loss (MAE) per epoch; a widening gap indicates overfitting.
fig, ax = plt.subplots()
ax.plot(train_losses, label="Train loss")
ax.plot(valid_losses, label="Valid loss")
ax.set(title="Training vs. validation loss (MAE)", xlabel="Epoch", ylabel="MAE")
ax.legend()
plt.show()


In [None]:
# Plotting the training and validation MSEs

# Training vs. validation MSE per epoch (tracked as a metric, not optimised directly).
fig, ax = plt.subplots()
ax.plot(train_MSEs, label="Train MSE")
ax.plot(valid_MSEs, label="Valid MSE")
ax.set(title="Training vs. validation MSE", xlabel="Epoch", ylabel="MSE")
ax.legend()
plt.show()

In [None]:
# Test model with best weights on test dataset

# Reload the best checkpoint written by the training loop.
model.load_state_dict(
    torch.load(
        f"models/pytorch/weights/best_{WANDB_NAME}.pt",
        map_location=DEVICE,
    )
)

model.eval()
total_loss = 0.0
total_mse = 0.0
total = 0  # number of test samples seen
with torch.no_grad():
    for images, labels in test_dataset:
        images = images.to(DEVICE)
        # Flatten and cast once so both metrics compare (N,) float predictions to
        # (N,) float targets; mismatched shapes would silently broadcast in L1Loss.
        labels = labels.to(DEVICE).reshape(-1).float()

        outputs = model(images).reshape(-1)
        n_batch = labels.size(0)
        # Losses are per-batch means; weight by batch size so a short final batch is
        # not over-counted and the result is the true per-sample mean.
        total_loss += criterion(outputs, labels).item() * n_batch
        total_mse += mse_metric(outputs, labels).item() * n_batch
        total += n_batch

if total == 0:
    raise ValueError("test_dataset yielded no samples; cannot compute test metrics")

avg_test_loss = total_loss / total
avg_test_mse = total_mse / total
print(f"Average test loss: {avg_test_loss}")
print(f"Average test MSE: {avg_test_mse}")

# Log the test metrics to wandb
wandb.log({"test_mae": avg_test_loss, "test_mse": avg_test_mse})

In [None]:
# Add test results to summary and finish the experiment

# Final values shown in the wandb run summary (checkpoint filename plus test metrics).
final_summary = {
    "model_saved": f'best_{WANDB_NAME}.pt',
    "test_mae": avg_test_loss,
    "test_mse": avg_test_mse,
}
for summary_key, summary_value in final_summary.items():
    wandb.run.summary[summary_key] = summary_value

# Close the run so all logged data is flushed to wandb.
wandb.finish()