# CLASSIFICATION TRAINING

In [None]:
import os
import sys

# Presumably silences TensorFlow's C++ logging in case TF is imported
# transitively by a dependency (this notebook itself only uses PyTorch).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# Make the project's `src/` directory importable (for the `classification`
# package). The path is relative to the notebook's working directory.
sys.path.append("../../../../src")

In [None]:
# general imports
from pathlib import Path
import torch
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, random_split, DataLoader, ConcatDataset
import torch.nn as nn
import time
import json
from torchvision.models import mobilenet_v3_small, efficientnet_b0, resnet18, densenet121
from torchvision.models import MobileNet_V3_Small_Weights, EfficientNet_B0_Weights, ResNet18_Weights, DenseNet121_Weights

In [None]:
from classification.annotator import manifest2classification, multilabel2dataframe
from classification.dataset import ClassificationDataset
from classification.utils import get_mean_and_std, plot_label_distribution
from classification.trainer import ClassificationTrainer, calculate_metrics, FocalLoss
from classification.visualizer import visualize, save_model_errors, plot_loss_function, generate_confusion_matrix_plot

# Data Prep

### Data Parameters

In [None]:
# type of task
task_type = "multiclass"  # or "multilabel". If you are not sure, try "multiclass"

# define the classes
classes = ["label1", "label2", "label3"]

# Label file location
# manifest file for multiclass OR json file for multilabel classification
label_file = "../../data/raw/v2_output.manifest"

# path to raw images
images_path = Path("../../data/raw/s3_v2_images")

# get the label key from the manifest file
label_key = "label-metadata"

# Compute device: prefer CUDA (e.g. on SageMaker), then Apple-Silicon MPS,
# and fall back to CPU so the notebook is portable across machines.
# NOTE(review): `device` is not passed to ClassificationTrainer below;
# confirm how the trainer selects its device.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# input image size (images are resized to input_image_size x input_image_size)
input_image_size = 224

# training batch size (validation uses twice this value)
batch_size = 32

# maximum number of training epochs
n_epochs = 5

# Focal-loss hyper-parameters, passed straight to FocalLoss(alpha=..., gamma=...)
loss_alpha = 3
loss_gamma = 2
model_loss_fn = FocalLoss(alpha=loss_alpha, gamma=loss_gamma)

# Prefix for the files written for this run (e.g. "<prefix>.pt", "<prefix>.json")
model_file_name = "classifier_project"

### Split raw data into train & validation across labels

In [None]:
# Prepare the annotations for the selected task type.
# - multiclass: manifest2classification processes the manifest; the later cells
#   read images from images_path/"train" and images_path/"validation", so no
#   annotation DataFrames are needed.
# - multilabel: multilabel2dataframe returns separate train/validation DataFrames.
if task_type == "multiclass":
    manifest2classification(label_file, images_path, label_key)
    annotations_df_train = annotations_df_validation = None
elif task_type == "multilabel":
    annotations_df_train, annotations_df_validation = multilabel2dataframe(label_file, classes)
else:
    # Fail fast: continuing would leave annotations_df_* undefined and make
    # later cells crash with a confusing NameError.
    raise ValueError(
        f"task_type must be either 'multilabel' or 'multiclass', got {task_type!r}"
    )

### Plot distribution of labels

In [None]:
# Plot the label distribution of the images under images_path
# (presumably per-split class counts; see classification.utils.plot_label_distribution).
plot_label_distribution(images_path)

### Compute the mean and std of the training data

In [None]:
# Compute normalization statistics from the training images. The images are only
# resized and converted to tensors here (no normalization yet), because these
# statistics feed transforms.Normalize in the next cell.
train_transforms = transforms.Compose([
    transforms.Resize((input_image_size, input_image_size)),
    transforms.ToTensor(),
])
train_dataset = datasets.ImageFolder(root=images_path / "train", transform=train_transforms)
# Use the configured batch size rather than a hard-coded 32. Order does not
# matter for these statistics, so there is no shuffling.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)
mean, std = get_mean_and_std(train_loader)
print(mean, std)

### Define transformations for training and validation data

In [None]:
# Preprocessing pipelines, keyed by name:
#   'base' - deterministic resize + tensor conversion + normalization; used for
#            validation and for the un-augmented copy of the training set.
#   'aug'  - random crop/flip/rotation/shift/color jitter, followed by the same
#            tensor conversion and normalization as 'base'.
# Both normalize with the training-set mean/std computed earlier in the notebook.
data_transforms = {}

data_transforms['base'] = transforms.Compose(
    [
        transforms.Resize((input_image_size, input_image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

data_transforms['aug'] = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=input_image_size, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.1),
        transforms.RandomRotation(degrees=(-10, 10)),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

### Generate Training and Validation data

In [None]:
def _build_split_dataset(split, transform_name, annotations_df):
    """Return a ClassificationDataset over images_path/<split> using data_transforms[transform_name]."""
    return ClassificationDataset(
        images_path / split,
        classes,
        task_name=task_type,
        transform=data_transforms[transform_name],
        annotations_df=annotations_df,
    )


# The training set holds every training image twice: once with only the base
# preprocessing and once with random augmentation.
base_dataset = _build_split_dataset("train", 'base', annotations_df_train)
augmented_dataset = _build_split_dataset("train", 'aug', annotations_df_train)
train_dataset = ConcatDataset([base_dataset, augmented_dataset])

# The validation set uses only the deterministic base preprocessing.
validation_dataset = _build_split_dataset("validation", 'base', annotations_df_validation)

In [None]:
# Training batches are reshuffled every epoch. Validation keeps a fixed order
# and uses twice the training batch size.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(validation_dataset, batch_size=2 * batch_size)

# Modeling

In [None]:
# Choose a (model builder, pre-trained weights) pair from the list below.
# Both entries are torchvision.models names (presumably resolved by name inside
# ClassificationTrainer; confirm there).
# ["mobilenet_v3_small", "MobileNet_V3_Small_Weights"]
# ["mobilenet_v3_large", "MobileNet_V3_Large_Weights"]
# ["alexnet", "AlexNet_Weights"]
# ["densenet121", "DenseNet121_Weights"]
# ["efficientnet_b0", "EfficientNet_B0_Weights"]
# ["efficientnet_v2_s", "EfficientNet_V2_S_Weights"]
# ["efficientnet_v2_m", "EfficientNet_V2_M_Weights"]
# ["efficientnet_v2_l", "EfficientNet_V2_L_Weights"]
# ["inception_v3", "Inception_V3_Weights"]
# ["resnet18", "ResNet18_Weights"]
# ["resnet50", "ResNet50_Weights"]
# ["vgg16", "VGG16_Weights"]
# ["vit_b_16", "ViT_B_16_Weights"]
# For more models check out https://pytorch.org/vision/stable/models.html#classification
# NOTE(review): the "Load saved model" cell rebuilds mobilenet_v3_small
# explicitly, so choosing another model also requires updating that cell.

# ["mobilenet_v3_small", "MobileNet_V3_Small_Weights"] is selected by default
model_details = ["mobilenet_v3_small", "MobileNet_V3_Small_Weights"]

### Define a Trainer for your model with customizable hyper-parameters

In [None]:
# Build the trainer. Passed positionally: the class names, the train/val
# loaders, the train/val batch sizes, the (model, weights) pair and the task type.
# `patience` and `criterion="val_f2"` presumably control early stopping / model
# selection, and the checkpoint is later loaded from f"{model_file_name}.pt"
# (see ClassificationTrainer for the exact semantics).
trainer = ClassificationTrainer(
    classes,
    train_loader,
    val_loader,
    batch_size,
    2 * batch_size,
    model_details,
    task_type,
    loss_fun=model_loss_fn,
    num_epochs=n_epochs,
    patience=3,
    criterion="val_f2",
    model_file_name_prefix=model_file_name,
)

# Training

In [None]:
# Train the model and report the elapsed wall-clock time in seconds.
# time.perf_counter is monotonic, so the measurement is unaffected by system
# clock adjustments during a long training run (time.time is not).
start_time = time.perf_counter()
metrics_dict = trainer.train()
print("Time for training:", time.perf_counter() - start_time)

# Model Performance

### Generate loss function graph

In [None]:
# Plot the loss curves recorded during training (named after model_file_name;
# see classification.visualizer.plot_loss_function). The returned dict replaces
# metrics_dict and is what gets stored under "perf_metrics" in the metadata JSON.
metrics_dict = plot_loss_function(metrics_dict, model_file_name)

### Load saved model

In [None]:
# Rebuild the architecture and load the checkpoint written by the trainer.
# Only MobileNetV3-Small is reconstructed here. Fail with a clear message instead
# of a confusing state-dict mismatch if a different model was selected above.
if model_details[0] != "mobilenet_v3_small":
    raise NotImplementedError(
        f"Reloading is only implemented for 'mobilenet_v3_small', got {model_details[0]!r}; "
        "rebuild the matching architecture and head before loading the checkpoint."
    )

# No pretrained download is needed: load_state_dict (strict by default)
# overwrites every parameter and buffer anyway.
classifier_model = mobilenet_v3_small(weights=None)
# Replace the final classifier layer so its output size matches the number of
# classes, keeping the original layer's input width.
classifier_model.classifier[3] = nn.Linear(
    in_features=classifier_model.classifier[3].in_features,
    out_features=len(classes),
    bias=True,
)
# map_location="cpu": the checkpoint may have been saved from a cuda/mps device,
# and the evaluation below runs on "cpu".
classifier_model.load_state_dict(torch.load(f"{model_file_name}.pt", map_location="cpu"))
# Put dropout/batch-norm layers in inference mode before evaluation.
classifier_model.eval()

### Generate Confusion matrix

In [None]:
# Run the reloaded model over the validation loader ("cpu" is presumably the
# evaluation device) and plot a confusion matrix named after model_file_name.
# The returned matrix is stored, stringified, in the metadata JSON.
confusion_matrix = generate_confusion_matrix_plot(val_loader, classifier_model, "cpu", classes, model_file_name)

### Save Model Performance and corresponding hyperparameters

In [None]:
# Convert validation losses to plain floats so they are JSON-serializable.
# float() accepts 0-d tensors as well as Python floats, so re-running this cell
# does not fail on values that were already converted.
metrics_dict['val_loss'] = [float(val) for val in metrics_dict['val_loss']]


def _tensor_to_json(obj):
    """json fallback: serialize any remaining torch tensors as (nested) lists; reject anything else."""
    if isinstance(obj, torch.Tensor):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


model_parameters = {
    "name": model_file_name,
    "model_details": model_details,

    "batch_size": batch_size,
    "input_image_size": input_image_size,
    "augmentation": str(data_transforms["aug"]),
    "epochs": n_epochs,
    "loss": str(model_loss_fn),
    "loss_alpha": loss_alpha,
    "loss_gamma": loss_gamma,

    # NOTE: train_dataset concatenates base + augmented copies, so this counts
    # each training image twice.
    "n_train": len(train_dataset),
    "n_val": len(validation_dataset),
    "perf_metrics": metrics_dict,

    "class_labels": classes,
    "confusion_matrix": str(confusion_matrix)
}

# Serialize before opening the file so that a serialization error does not
# truncate an existing metadata file.
metadata_json = json.dumps(model_parameters, indent=4, default=_tensor_to_json)
metadata_json_path = f"{model_file_name}.json"
with open(metadata_json_path, "w", encoding="utf-8") as outfile:
    outfile.write(metadata_json)
print(f"Saved model params to {metadata_json_path}")

# Visualize Random Correct and Incorrect Examples (One Each) from Validation Set

(Only for 'multiclass' task)

In [None]:
# Name of the class whose correct/incorrect validation examples are visualized.
viz_class_name = "label1"
# Catch typos here rather than later inside visualize().
if viz_class_name not in classes:
    raise ValueError(f"viz_class_name {viz_class_name!r} is not one of {classes}")

In [None]:
# Rebind val_loader to a loader yielding one validation image per step in random
# order (batch size 1, shuffled), for drawing random examples to visualize.
val_loader = DataLoader(validation_dataset, batch_size=1, shuffle=True)

In [None]:
# Visualize validation examples for viz_class_name using the trained trainer
# (per the section heading: one correct and one incorrect example; multiclass only).
visualize(viz_class_name, trainer, val_loader)

### Analysis of Validation Set Errors

In [None]:
# Fresh single-image, shuffled loader over the validation set for error analysis
# (a new DataLoader, so iteration starts over with a new random order).
val_loader = DataLoader(validation_dataset, batch_size=1, shuffle=True)

In [None]:
# Collect the trained model's validation-set errors (see
# classification.visualizer.save_model_errors for what is recorded/saved),
# then display the returned dict as the cell output.
error_dict = save_model_errors(trainer, val_loader)
error_dict