In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, 
import os

In [None]:
import torch
import torchvision
from torchvision import datasets
from torchvision import transforms as T # for simplifying the transforms
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, sampler, random_split
from torchvision import models

In [None]:
# Silence every Python warning for the rest of the session.
# NOTE(review): the "ignore" filter has no category or module restriction, so it
# also hides deprecation and numerical warnings raised by the libraries used below.
import warnings
warnings.filterwarnings("ignore")

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
# Data Visualization
import plotly.express as px

In [None]:
import sys
from tqdm import tqdm
import time
import copy

In [None]:
def get_classes(data_dir):
    """Return the ``classes`` attribute of an ``ImageFolder`` built on ``data_dir``.

    Args:
        data_dir: Root directory handed to ``datasets.ImageFolder``.

    Returns:
        The class list exposed by the resulting ``ImageFolder``.
    """
    return datasets.ImageFolder(data_dir).classes

In [None]:
import os
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader, random_split
import torchvision.datasets as datasets
import numpy as np

# Seed NumPy, PyTorch (CPU) and every CUDA device with the same value so that
# repeated runs draw the same random numbers.
seed = 12
for seed_rng in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_rng(seed)

def get_data_loaders(data_dir, batch_size, image_size=224):
    """Build train / validation / test loaders from one ``ImageFolder`` directory.

    The images under ``data_dir`` are split 80/20 into training and validation
    subsets. Training samples go through random flips and a random rotation;
    validation and test samples are only resized, converted to tensors and
    normalized, so evaluation is deterministic.

    NOTE(review): the test loader iterates over *every* image in ``data_dir``,
    i.e. it overlaps the training split. Scores computed on it are not
    held-out estimates. Point it at a separate directory if one is available.

    Args:
        data_dir: Root directory in ``ImageFolder`` layout (one sub-folder per class).
        batch_size: Batch size used by all three loaders.
        image_size: Side length every image is resized to. Defaults to 224;
            resizing is a no-op for images that already have this size.

    Returns:
        ``(train_loader, val_loader, test_loader, n_train, n_val, n_test)``.
    """
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    resize = T.Resize((image_size, image_size))

    train_transform = T.Compose([
        resize,
        T.RandomHorizontalFlip(),
        T.RandomVerticalFlip(),
        T.RandomRotation(degrees=15),  # random rotation within +/-15 degrees
        T.ToTensor(),
        normalize,
    ])
    eval_transform = T.Compose([resize, T.ToTensor(), normalize])

    # Two views of the same folder: ImageFolder orders samples deterministically,
    # so an index refers to the same image in both datasets.
    # data_dir is used directly: joining it with an absolute path would discard it.
    augmented_dataset = datasets.ImageFolder(data_dir, transform=train_transform)
    eval_dataset = datasets.ImageFolder(data_dir, transform=eval_transform)

    # Calculate the sizes of the training and validation sets
    dataset_size = len(augmented_dataset)
    train_size = int(0.8 * dataset_size)
    val_size = dataset_size - train_size

    # Draw the split once, then re-point the validation indices at the
    # augmentation-free dataset so validation images are never randomly transformed.
    train_dataset, val_split = random_split(augmented_dataset, [train_size, val_size])
    val_dataset = torch.utils.data.Subset(eval_dataset, val_split.indices)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    # The "test" set is the full folder with evaluation transforms (see NOTE above).
    test_data = eval_dataset
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4)

    return train_loader, val_loader, test_loader, len(train_dataset), len(val_dataset), len(test_data)

In [None]:
# Dataset location and batch size shared by the three loaders.
batch_size = 128
data_dir = "/kaggle/input/monkeypox-aug-munim/Monkeypox_Aug/"

# get_data_loaders returns the three loaders followed by their three sample counts.
loaders_and_sizes = get_data_loaders(data_dir, batch_size)
train_loader, val_loader, test_loader = loaders_and_sizes[:3]
train_size, val_size, test_size = loaders_and_sizes[3:]

In [None]:
# Class names as ImageFolder sees them; reuse data_dir instead of repeating the path literal.
classes = get_classes(data_dir)

In [None]:
# Get Class Names
# Only sub-directories are classes; stray files in data_dir must not be counted.
class_names = sorted(entry.name for entry in os.scandir(data_dir) if entry.is_dir())
n_classes = len(class_names)

# Show
print(f"Class Names : {class_names}")
print(f"Number of Classes  : {n_classes}")

In [None]:
# Calculate class distribution: number of directory entries in each class folder.
# os.path.join keeps this correct whether or not data_dir ends with a slash.
class_dis = [len(os.listdir(os.path.join(data_dir, name))) for name in class_names]
class_dis

In [None]:
# Pie chart of the per-class counts, with the title positioned at x=0.45.
fig = px.pie(values=class_dis, names=class_names, title="Training Class Distribution")
fig.update_layout(title={'x': 0.45})
fig.show()

In [None]:
# NOTE(review): this is the same folder as data_dir, so the chart below shows the
# same distribution as the one titled "Training Class Distribution".
pred_dataset_path = "/kaggle/input/monkeypox-aug-munim/Monkeypox_Aug/"
# Calculate class distribution (os.path.join does not depend on a trailing slash)
class_dis = [len(os.listdir(os.path.join(pred_dataset_path, name))) for name in class_names]

# Visualization
fig = px.pie(names=class_names, values=class_dis, title="Prediction Class Distribution")
fig.update_layout({'title':{'x':0.45}})
fig.show()

In [None]:
# Loaders and sample counts keyed by phase name ("train" / "val").
dataloaders = dict(train=train_loader, val=val_loader)
dataset_sizes = dict(train=train_size, val=val_size)

In [None]:
# Number of batches in each loader, printed on one line.
print(*(len(loader) for loader in (train_loader, val_loader, test_loader)))

In [None]:
# Report the number of samples in each split.
for split_label, split_size in (
    ("Train dataset size:", train_size),
    ("Validation dataset size:", val_size),
    ("Test dataset size:", test_size),
):
    print(split_label, split_size)

In [None]:
# Device for the model: CUDA when PyTorch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

In [None]:
!pip install timm # kaggle doesnt have it installed by default
import timm
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)

In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init

# Freeze every parameter of the loaded model; only the new head created below is trainable.
for param in model.parameters():
    param.requires_grad = False

# Number of input features of the existing classification head.
n_inputs = model.head.in_features

# Replace the head with an MLP: five (Linear -> BatchNorm -> ReLU -> Dropout) stages
# followed by a final Linear layer with one output per class.
#
# The head deliberately ends in raw logits, with no Softmax: nn.CrossEntropyLoss
# applies log-softmax itself, so a Softmax layer here would be applied twice and
# flatten the gradients. Apply torch.softmax to the outputs when probabilities
# are needed (this includes consumers of an exported copy of this model).
hidden_sizes = (1024, 512, 256, 128, 64)
head_layers = []
in_features = n_inputs
for out_features in hidden_sizes:
    head_layers += [
        nn.Linear(in_features, out_features),
        nn.BatchNorm1d(out_features),
        nn.ReLU(),
        nn.Dropout(0.5),
    ]
    in_features = out_features
head_layers.append(nn.Linear(in_features, len(classes)))

model.head = nn.Sequential(*head_layers)

model = model.to(device)
print(model)

In [None]:
import torch.optim as optim
# AdamW over the head's parameters only, with a small weight-decay penalty.
weight_decay = 1e-5  # You can adjust this value based on your needs
optimizer = optim.AdamW(params=model.head.parameters(), lr=0.001, weight_decay=weight_decay)
# StepLR: scale the learning rate by 0.97 every 3 scheduler steps.
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.97, step_size=3)

# Cross-entropy loss, moved to the selected device.
criterion = nn.CrossEntropyLoss().to(device)

In [None]:
from sklearn.metrics import confusion_matrix, f1_score, classification_report, cohen_kappa_score, roc_auc_score, roc_curve, auc
from sklearn.preprocessing import label_binarize
import time
import copy
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
!pip install seaborn
import seaborn as sns
from sklearn.metrics import confusion_matrix, f1_score, classification_report, cohen_kappa_score, roc_auc_score, roc_curve, auc
from sklearn.preprocessing import label_binarize
import seaborn as sns

import time
import copy
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

def _safe_divide(numerator, denominator):
    """Element-wise division that returns 0 wherever the denominator is 0."""
    numerator = np.asarray(numerator, dtype=float)
    denominator = np.asarray(denominator, dtype=float)
    return np.divide(numerator, denominator, out=np.zeros_like(numerator), where=denominator != 0)


def _per_class_counts(confusion_mat):
    """Return per-class (TP, TN, FP, FN) arrays for a square confusion matrix."""
    true_positives = np.diag(confusion_mat)
    false_positives = confusion_mat.sum(axis=0) - true_positives
    false_negatives = confusion_mat.sum(axis=1) - true_positives
    true_negatives = confusion_mat.sum() - (true_positives + false_positives + false_negatives)
    return true_positives, true_negatives, false_positives, false_negatives


def _report_validation_metrics(true_labels, pred_labels, classes):
    """Print the validation metrics and return the confusion matrix they are based on."""
    label_ids = np.arange(len(classes))
    # Passing `labels` keeps the matrix len(classes) x len(classes) even when a
    # class is missing from the validation predictions.
    confusion_mat = confusion_matrix(true_labels, pred_labels, labels=label_ids)
    true_positives, true_negatives, false_positives, false_negatives = _per_class_counts(confusion_mat)

    f1 = f1_score(true_labels, pred_labels, average='weighted')
    sensitivity = _safe_divide(true_positives, true_positives + false_negatives)
    specificity = _safe_divide(true_negatives, true_negatives + false_positives)
    accuracy = true_positives.sum() / confusion_mat.sum()
    kappa = cohen_kappa_score(true_labels, pred_labels)

    print("Confusion Matrix:")
    print(confusion_mat)
    print("F1 Score: {:.4f}".format(f1))
    print("Sensitivity (Recall):", sensitivity)
    print("Specificity:", specificity)
    print("Accuracy: {:.4f}".format(accuracy))
    print("True Positives:", true_positives)
    print("True Negatives:", true_negatives)
    print("False Positives:", false_positives)
    print("False Negatives:", false_negatives)
    print("Cohen's Kappa:", kappa)

    # Print classification report (rows are labelled by class index)
    target_names = [str(i) for i in range(len(classes))]
    print(classification_report(true_labels, pred_labels, labels=label_ids, target_names=target_names))

    return confusion_mat


def _plot_history(train_values, val_values, ylabel, title):
    """Plot per-epoch train/validation curves, each with a +/-0.02 shaded band."""
    series = (
        (train_values, 'Train', '#DC8686'),
        (val_values, 'Validation', '#59CE8F'),
    )

    plt.figure(figsize=(10, 5))
    for values, label, color in series:
        sns.lineplot(x=range(1, len(values) + 1), y=values, label=label, linestyle='-', color=color)
    for values, _, color in series:
        band_center = np.array(values)
        plt.fill_between(range(1, len(values) + 1), band_center - 0.02, band_center + 0.02,
                         color=color, alpha=0.2)

    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.show()


def _plot_confusion_matrix(confusion_mat, classes):
    """Draw the confusion matrix as an annotated heatmap with class-name ticks."""
    plt.figure(figsize=(8, 8))
    sns.heatmap(confusion_mat, annot=True, fmt=".0f", cmap="GnBu", linewidths=.5, square=True, cbar=False,
                xticklabels=classes, yticklabels=classes)

    plt.title('Confusion Matrix', fontsize=16)
    plt.xlabel('Predicted label', fontsize=12)
    plt.ylabel('True label', fontsize=12)

    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0, ha='right')

    plt.tight_layout()
    plt.show()


def train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, classes, device,
                num_epochs=200, patience=500, save_weights_every=199):
    """Train ``model``, keep the weights with the best validation accuracy, and report metrics.

    Each epoch runs a training phase followed by a validation phase. The
    learning-rate scheduler is stepped once per epoch after the training phase.
    The validation labels and predictions of the best epoch are retained, so
    the printed metrics and the confusion matrix describe the weights that are
    loaded back into the returned model.

    Args:
        model: Network to train; moved to ``device`` by the caller.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch; ``None`` disables it.
        dataloaders: Mapping with ``'train'`` and ``'val'`` loaders yielding ``(inputs, labels)``.
        dataset_sizes: Mapping with the sample count of the ``'train'`` and ``'val'`` sets,
            used to average the loss and accuracy.
        classes: Sequence of class names; its length fixes the number of labels
            and its items label the confusion-matrix axes.
        device: Device the batches are moved to.
        num_epochs: Maximum number of epochs.
        patience: Stop once validation accuracy has failed to improve this many
            consecutive epochs.
        save_weights_every: Write ``model_weights_epoch_<n>.pth`` whenever the epoch
            index is a positive multiple of this value; a falsy value disables it.

    Returns:
        ``model`` with the best-validation-accuracy weights loaded.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    no_improvement_count = 0

    # Labels / predictions of the best validation epoch only.
    val_true_labels = []
    val_pred_labels = []
    train_acc_history = []
    val_acc_history = []
    train_loss_history = []
    val_loss_history = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print("-" * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train(False) is equivalent to eval()

            running_loss = 0.0
            running_corrects = 0
            epoch_true_labels = []
            epoch_pred_labels = []

            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # The loss is a batch mean, so weight it by the batch size.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

                if not is_train:
                    epoch_true_labels += labels.tolist()
                    epoch_pred_labels += preds.tolist()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print("{} Loss: {:.4f} Acc: {:.4f}".format(phase, epoch_loss, epoch_acc))

            if is_train:
                train_loss_history.append(epoch_loss)
                train_acc_history.append(epoch_acc)
                # Advance the learning-rate schedule once per epoch.
                if scheduler is not None:
                    scheduler.step()
            else:
                val_loss_history.append(epoch_loss)
                val_acc_history.append(epoch_acc)
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())  # keep the best validation accuracy model
                    val_true_labels = epoch_true_labels
                    val_pred_labels = epoch_pred_labels
                    no_improvement_count = 0
                else:
                    no_improvement_count += 1

        print()

        if no_improvement_count >= patience:
            print(f"No improvement in validation accuracy for {no_improvement_count} epochs. Early stopping...")
            break

        # Save weights after every 'save_weights_every' epochs
        if save_weights_every and epoch > 0 and epoch % save_weights_every == 0:
            save_path = f'model_weights_epoch_{epoch}.pth'
            torch.save(model.state_dict(), save_path)
            print(f'Model weights saved at epoch {epoch} to {save_path}')

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print("Best Val Acc: {:.4f}".format(best_acc))

    model.load_state_dict(best_model_wts)

    # Metrics for the best epoch; skipped when no epoch ever beat the initial weights.
    confusion_mat = None
    if val_true_labels:
        confusion_mat = _report_validation_metrics(val_true_labels, val_pred_labels, classes)
    else:
        print("No validation epoch improved on the initial weights; skipping the metric report.")

    if train_acc_history:
        _plot_history(train_acc_history, val_acc_history, 'Accuracy', 'Training and Validation Accuracy')
        _plot_history(train_loss_history, val_loss_history, 'Loss', 'Training and Validation Loss')

    if confusion_mat is not None:
        _plot_confusion_matrix(confusion_mat, classes)

    return model

In [None]:
model_ft = train_model(model, criterion, optimizer, exp_lr_scheduler, dataloaders, dataset_sizes, classes, device)

# Export a TorchScript trace of the trained model.
# The model is traced on the CPU and explicitly in eval mode: the example batch
# holds a single image, and BatchNorm1d cannot run in training mode on a batch
# of one (Dropout would also be baked into the trace as active).
model_ft = model_ft.cpu().eval()
example = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    traced_script_module = torch.jit.trace(model_ft, example)
traced_script_module.save("/kaggle/working/monkeypox_deit.pt")