In [None]:
import os

import pandas as pd
import seaborn as sn
import torch
from IPython.core.display import display
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger


In [None]:
# Reference : https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/cifar10-baseline.html

# Folder passed to the dataset helper as `data_folder`.
DATA_FOLDER = "./data"
# Batch size handed to both the data loaders and the model handler.
BATCH_SIZE = 512
# Half of the logical CPU count. `os.cpu_count()` may return None when the
# count cannot be determined, so fall back to 1 (which yields 0 workers)
# instead of failing with a TypeError on `None / 2`.
NUM_WORKERS = (os.cpu_count() or 1) // 2
print(NUM_WORKERS)

In [None]:
from utility import cifar10Utility, imageAugmentationUtility

# While iterating on the helper module it can be refreshed without a kernel restart:
# import importlib
# imageAugmentationUtility = importlib.reload(imageAugmentationUtility)

# Transform collections are built from the values returned by
# cifar10Utility.get_mean() and cifar10Utility.get_std().
train_transforms, test_transforms = imageAugmentationUtility.get_cifar10_train_and_test_transforms(
    cifar10Utility.get_mean(),
    cifar10Utility.get_std(),
)

# Train / validation / test datasets rooted at DATA_FOLDER.
train_dataset, validation_dataset, test_dataset = cifar10Utility.get_datasets(
    train_transforms_collection=train_transforms,
    test_transforms_collection=test_transforms,
    data_folder=DATA_FOLDER,
)
print(f"Images in train_dataset are :{len(train_dataset)}, validation_dataset: {len(validation_dataset)}, and test_dataset: {len(test_dataset)}")

# One loader per dataset, all sharing the batch size and worker count from the config cell.
train_loader, validation_loader, test_loader = cifar10Utility.get_dataloaders(
    train_dataset=train_dataset,
    validation_dataset=validation_dataset,
    test_dataset=test_dataset,
    num_workers=NUM_WORKERS,
    batch_size=BATCH_SIZE,
)
print(f"Batches count in train data loader are :{len(train_loader)}, validation loader: {len(validation_loader)},and test data loader: {len(test_loader)}")


In [None]:
from utility import commonUtility
from utility import imageVisualizationUtility

In [None]:
# Fetch 8 training images with their label indexes via the helper, translate
# the indexes into label names, then print the names and plot the images.
images, labels = commonUtility.get_random_images_from_data_loader(
    train_loader,
    images_count=8,
)
labels = cifar10Utility.get_labels_names(labels_indexes=labels)
print(labels)
imageVisualizationUtility.show(images, labels)

# Alternative rendering through the torchvision grid function:
# imageVisualizationUtility.show(torchvision.utils.make_grid(images), labels="-".join(labels))

In [None]:
# Display one sample from the test dataset, picked at a random index.
import random

sample_index = random.randint(0, len(test_dataset) - 1)
image, label = test_dataset[sample_index]
# Replace the label index with its name before plotting.
label = cifar10Utility.get_labels_names(labels_indexes=label)
imageVisualizationUtility.show(image, label)

In [None]:
from models import modelHandler
import torchmetrics
import torchmetrics.classification

# While iterating on the module:
# modelHandler = importlib.reload(modelHandler)

# The handler receives the same batch size that the data loaders were built with.
model_handler = modelHandler.ModelHandler(batch_size=BATCH_SIZE)

In [None]:
# Obtain a model instance from the handler and print its summary.
model = model_handler.get_lightning_model_instance()
model_handler.show_model_summary(model)

In [None]:
# Import from `pytorch_lightning`, the same package that provides Trainer and
# CSVLogger above, so the notebook depends on a single Lightning distribution
# and the TensorBoard logger matches the Trainer if the alternative below is used.
from pytorch_lightning.loggers import TensorBoardLogger

trainer = Trainer(
    max_epochs=30,
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limiting for iPython runs
    # Alternative that logs to both back ends:
    # logger=[TensorBoardLogger(save_dir="logs/"), CSVLogger(save_dir="logs/")],
    logger=CSVLogger(save_dir="logs/"),
    auto_lr_find=True,
    callbacks=[LearningRateMonitor(logging_interval="step"), TQDMProgressBar(refresh_rate=10)],
)

# Find the learning rate
# result = trainer.tune(model, train_loader)
lr_finder = trainer.tuner.lr_find(model, train_loader, validation_loader, num_training=200)
# `new_lr` is read by the following cell when it configures the model.
new_lr = lr_finder.suggestion()
print(f"Suggested LR: {new_lr}")
fig = lr_finder.plot(suggest=True)
fig.show()

In [None]:
# Apply the learning rate proposed by the LR finder. The finder may not be able
# to compute a suggestion, in which case `new_lr` is None; keep the model's
# existing setting rather than overwriting it with None.
if new_lr is not None:
    model.hparams.lr = new_lr
else:
    print("LR finder returned no suggestion; keeping the model's current learning rate")

# Train with validation, then evaluate on the held-out test loader.
trainer.fit(model, train_loader, validation_loader)
trainer.test(model, test_loader)

In [None]:
# Location of the CSV inside the trainer logger's directory.
metrics_path = f"{trainer.logger.log_dir}/metrics.csv"
print(metrics_path)

# Load the log, discard the "step" and "valid_loss" columns, and index rows by epoch.
metrics = (
    pd.read_csv(metrics_path)
    .drop(columns=["step", "valid_loss"])
    .set_index("epoch")
)

# Show the table without all-NaN columns, then plot the metrics as lines.
display(metrics.dropna(axis=1, how="all"))
sn.relplot(data=metrics, kind="line")

In [None]:
# File name under which the model's state dict is saved and later reloaded.
MODEL_NAME = "lightning_resnet18.pth"

In [None]:
# Write the state dict twice: first into this run's log directory, then at the top level.
for checkpoint_path in (f"{trainer.logger.log_dir}/{MODEL_NAME}", MODEL_NAME):
    torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Build a second model instance from the top-level weights file.
# The per-run copy would be at:
# model_location = f"{trainer.logger.log_dir}/lightning_resnet18.pth"

new_loaded_model = model_handler.get_lightning_model_instance(saved_model=MODEL_NAME)
model_handler.show_model_summary(new_loaded_model)

# Evaluation mode for inference (disable dropout, randomness, etc.)
new_loaded_model = new_loaded_model.eval()

In [None]:

# Pull one batch of images and labels from the test loader via the helper.
batch_images, batch_labels = commonUtility.get_random_images_batch_and_labels_from_data_loader(test_loader)

# Split the batch by whether the reloaded model's prediction matches the label
# (max_image_count=10). Each result has the form:
#  {"images" : images, "predicted_labels" : predicted_labels, "actual_labels" : actual_labels}
non_matched_results, matched_results = commonUtility.get_images_for_matched_and_non_matched_model_predictions(
    new_loaded_model,
    batch_images,
    batch_labels,
    max_image_count=10,
)

In [None]:
print("For Matched results")

# Translate predicted and actual label indexes into names.
predicted_labels_names = cifar10Utility.get_labels_names(
    labels_indexes=matched_results["predicted_labels"],
)
actual_labels_names = cifar10Utility.get_labels_names(
    labels_indexes=matched_results["actual_labels"],
)

# Combine both name lists into the captions used for plotting.
matched_combined_labels = commonUtility.combine_labels(
    predicted_labels_names,
    actual_labels_names,
)
imageVisualizationUtility.show(matched_results["images"], matched_combined_labels)

In [None]:
print("For Non-Matched results")

# Translate predicted and actual label indexes into names.
predicted_labels_names = cifar10Utility.get_labels_names(
    labels_indexes=non_matched_results["predicted_labels"],
)
actual_labels_names = cifar10Utility.get_labels_names(
    labels_indexes=non_matched_results["actual_labels"],
)

# Combine both name lists into the captions used for plotting.
non_matched_combined_labels = commonUtility.combine_labels(
    predicted_labels_names,
    actual_labels_names,
)
imageVisualizationUtility.show(non_matched_results["images"], non_matched_combined_labels)

In [None]:
from utility import gradcamUtility

In [None]:
print("For Matched results")

# Grad-CAM overlays for the matched predictions, targeting the last element of `layer3`.
# param image_weight: The final result is image_weight * img + (1-image_weight) * mask
heatmap_overlaid_images = gradcamUtility.create_grad_cam_overlaid_images(
    new_loaded_model.model,
    [new_loaded_model.model.layer3[-1]],
    images=matched_results["images"],
    predictions_labels=matched_results["predicted_labels"],
    actual_labels=matched_results["actual_labels"],
    image_weight=0.98,
)
imageVisualizationUtility.show(heatmap_overlaid_images, matched_combined_labels)

In [None]:
print("For Non-Matched results")

# Grad-CAM overlays for the non-matched predictions, targeting the last element of `layer3`.
# param image_weight: The final result is image_weight * img + (1-image_weight) * mask
heatmap_overlaid_images = gradcamUtility.create_grad_cam_overlaid_images(
    new_loaded_model.model,
    [new_loaded_model.model.layer3[-1]],
    images=non_matched_results["images"],
    predictions_labels=non_matched_results["predicted_labels"],
    actual_labels=non_matched_results["actual_labels"],
    image_weight=0.98,
)
imageVisualizationUtility.show(heatmap_overlaid_images, non_matched_combined_labels)