In [None]:
# # This Python 3 environment comes with many helpful analytics libraries installed
# # It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# # For example, here's several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# # Input data files are available in the read-only "../input/" directory
# # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# # You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import torchvision

In [None]:
from vce_dataloader import getBinaryDataLoader, getAllDataLoader, visualize_batch
from model_file import getModel, getList

In [None]:
# Training pipeline: random flips and a small random rotation for
# augmentation, then the same resize + tensor conversion used for validation.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.7),
        transforms.RandomVerticalFlip(p=0.7),
        transforms.RandomRotation(degrees=15),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

# Validation pipeline: deterministic, no augmentation.
val_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

In [None]:
# Loaders over the full label set; training gets the augmenting transform,
# validation the deterministic one.
allDL_train = getAllDataLoader(
    image_size=(224, 224),
    path_to_dataset="/kaggle/input/vce-dataset/training",
    batch_size=32,
    trans=train_transform,
)
allDL_val = getAllDataLoader(
    image_size=(224, 224),
    path_to_dataset="/kaggle/input/vce-dataset/validation",
    batch_size=32,
    trans=val_transform,
)

In [None]:
# Binary loaders built around the "Normal" class (presumably Normal vs. the
# rest; confirm against getBinaryDataLoader).
# NOTE(review): sampling=True is also passed for the validation loader; if it
# resamples or rebalances the data, validation metrics will not reflect the
# true class distribution. Confirm this is intended.
binDL_train = getBinaryDataLoader(
    target_class_name="Normal",
    path_to_dataset="/kaggle/input/vce-dataset/training",
    batch_size=32,
    sampling=True,
    trans=train_transform,
)
binDL_val = getBinaryDataLoader(
    target_class_name="Normal",
    path_to_dataset="/kaggle/input/vce-dataset/validation",
    batch_size=32,
    sampling=True,
    trans=val_transform,
)

In [None]:
# Number of batches in the binary (train, validation) loaders.
tuple(len(loader) for loader in (binDL_train, binDL_val))

In [None]:
# Number of batches in the full-label (train, validation) loaders.
tuple(len(loader) for loader in (allDL_train, allDL_val))

In [None]:
# Visual sanity check of the (augmented) training loader. visualize_batch is a
# project helper; presumably it draws an image grid with 8 images per row
# (TODO confirm against vce_dataloader).
visualize_batch(allDL_train, nrow=8)

In [None]:
def get_predictions_and_labels(model, dataloader, device=None):
    """Run `model` over `dataloader` and collect predicted and true class indices.

    The model is switched to evaluation mode and left in that mode.

    Args:
        model: network whose output holds class scores along dim 1 (the
            prediction is taken with ``torch.max(outputs, 1)``).
        dataloader: iterable yielding ``(images, labels)`` tensor batches.
        device: device the images are moved to before the forward pass.
            Defaults to the device of the model's first parameter (CPU when
            the model has no parameters), so the function no longer depends
            on a notebook-global ``device`` having been defined beforehand.

    Returns:
        Tuple ``(preds, labels)`` of NumPy arrays with one entry per sample,
        in iteration order.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()  # Set the model to evaluation mode
    all_preds = []
    all_labels = []

    with torch.no_grad():  # Disable gradient calculation
        for images, labels in dataloader:
            images = images.to(device)  # Move images to the same device as model
            outputs = model(images)
            _, preds = torch.max(outputs, 1)  # Get the predicted class
            all_preds.extend(preds.cpu().numpy())  # Collect predictions
            # Labels are only collected, never fed to the model, so they do
            # not need a round trip through `device`.
            all_labels.extend(labels.cpu().numpy())  # Collect true labels

    return np.array(all_preds), np.array(all_labels)

In [None]:
def train_model(model, optimizer, criterion, n_epochs, trainDL, valDL, device=None):
    """Train `model` for `n_epochs` epochs, validating after every epoch.

    Args:
        model: network to optimise; predictions are taken with
            ``torch.max(outputs, 1)``, so class scores must lie along dim 1.
        optimizer: optimiser, zeroed and stepped once per training batch.
        criterion: loss called as ``criterion(outputs, labels)``.
        n_epochs: number of passes over `trainDL`.
        trainDL: sized iterable (``len()`` is used for progress output) of
            ``(inputs, labels)`` training batches.
        valDL: iterable of ``(inputs, labels)`` validation batches.
        device: device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU when the model has no parameters), so
            the function does not depend on a notebook-global ``device``.

    Returns:
        Tuple ``(train_loss_history, val_loss_history, train_acc_history,
        val_acc_history)``: four lists with one float per epoch. Losses are
        the per-batch loss weighted by batch size and divided by the number of
        samples seen; accuracies are fractions in [0, 1]. A phase that saw no
        samples records NaN instead of raising ZeroDivisionError.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    train_loss_history = []
    val_loss_history = []
    train_acc_history = []
    val_acc_history = []

    # Report the loss roughly four times per epoch. With fewer than 4 batches
    # len(trainDL) // 4 is 0 and the modulo below would raise
    # ZeroDivisionError, hence the floor of 1.
    log_every = max(1, len(trainDL) // 4)
    nan = float("nan")

    for epoch in range(n_epochs):
        model.train()
        print(f'Epoch [{epoch+1}/{n_epochs}]')
        running_loss = 0.0
        correct_predictions = 0
        total_samples = 0

        for i, data in enumerate(trainDL):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight by batch size so a smaller final batch does not skew the mean.
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == labels).sum().item()
            total_samples += labels.size(0)

            if (i + 1) % log_every == 0:
                print(f"{i+1}/{len(trainDL)}: {loss.item()}")
            else:
                print("#", end="")

        train_loss = running_loss / total_samples if total_samples else nan
        train_accuracy = correct_predictions / total_samples if total_samples else nan

        model.eval()
        val_loss = 0.0
        correct_val_predictions = 0
        total_val_samples = 0

        with torch.no_grad():
            for val_inputs, val_labels in valDL:
                val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)

                val_outputs = model(val_inputs)
                loss = criterion(val_outputs, val_labels)

                val_loss += loss.item() * val_inputs.size(0)
                _, val_predicted = torch.max(val_outputs, 1)
                correct_val_predictions += (val_predicted == val_labels).sum().item()
                total_val_samples += val_labels.size(0)

        val_loss = val_loss / total_val_samples if total_val_samples else nan
        val_accuracy = correct_val_predictions / total_val_samples if total_val_samples else nan

        # Print epoch stats
        print(f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}')
        print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')
        train_loss_history.append(train_loss)
        val_loss_history.append(val_loss)
        train_acc_history.append(train_accuracy)
        val_acc_history.append(val_accuracy)

    return train_loss_history, val_loss_history, train_acc_history, val_acc_history

In [None]:
# getList() resnet50

## Direct training on 10 classes

In [None]:
# Number of epochs passed as n_epochs to the train_model calls below; also
# used (as EPOCHS-1) for the 'epoch' field of the saved checkpoints.
EPOCHS=15

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Build the backbone through the project helper getModel; presumably a
# ResNet-50 (TODO confirm whether it comes with pretrained weights).
model = getModel("resnet50")

In [None]:
# Bare expression: the cell output shows the model's repr for inspection.
model

In [None]:
# Swap the classification head for a fresh 10-way linear layer.
# NOTE(review): 2048 is assumed to match the backbone's feature width; confirm against getModel.
model.fc = nn.Linear(in_features=2048, out_features=10)

In [None]:
# Place the network on the selected device, then create the loss and an Adam
# optimiser over all of its parameters.
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
# Train the 10-class model directly on the full-label loaders and keep the
# per-epoch loss / accuracy histories.
(
    dir_train_hist,
    dir_val_hist,
    dir_train_acc,
    dir_val_acc,
) = train_model(
    model,
    optimizer,
    criterion,
    n_epochs=EPOCHS,
    trainDL=allDL_train,
    valDL=allDL_val,
)

In [None]:
# print("hi")

In [None]:
# Checkpoint of the directly-trained 10-class model and its optimiser.
PATH = "Direct10.pt"
torch.save(
    dict(
        epoch=EPOCHS - 1,  # zero-based index of the last epoch that was run
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    ),
    PATH,
)

In [None]:
# Get the predictions and true labels for the validation set (allDL_val)
preds, true_labels = get_predictions_and_labels(model, allDL_val)

In [None]:
# Compute the confusion matrix over all 10 class indices. Without an explicit
# `labels` argument, confusion_matrix infers the label set from the data, so a
# class that appears in neither true_labels nor preds would shrink the matrix
# and misalign it with the fixed 0..9 tick labels below.
class_ids = list(range(10))
conf_matrix = confusion_matrix(true_labels, preds, labels=class_ids)

# Plot the confusion matrix using seaborn
plt.figure(figsize=(6, 4))
sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues", xticklabels=class_ids, yticklabels=class_ids)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix')
plt.show()

## Binary pretraining, then fine-tuning

In [None]:
# Fresh backbone for the two-stage experiment, built by the same project
# helper as the direct model (presumably a ResNet-50; TODO confirm).
binary_model = getModel("resnet50")

In [None]:
# Stage 1 head: a 2-way linear classifier, then move the network to the device.
binary_model.fc = nn.Linear(in_features=2048, out_features=2)
binary_model = binary_model.to(device)

In [None]:
# Loss and Adam optimiser for the binary stage.
bin_criterion = nn.CrossEntropyLoss()
bin_optimizer = torch.optim.Adam(params=binary_model.parameters(), lr=0.001)

In [None]:
# Stage 1: train the 2-class model on the binary loaders.
(
    bin_train_hist,
    bin_val_hist,
    bin_train_acc,
    bin_val_acc,
) = train_model(
    binary_model,
    bin_optimizer,
    bin_criterion,
    n_epochs=EPOCHS,
    trainDL=binDL_train,
    valDL=binDL_val,
)

## Fine-tuning on all 10 classes

In [None]:
# Stage 2 head: replace only the final layer with a fresh 10-way classifier;
# the rest of the network keeps the weights learned in the binary stage.
binary_model.fc = nn.Linear(in_features=2048, out_features=10)
binary_model = binary_model.to(device)

In [None]:
# Loss and a new Adam optimiser for fine-tuning (covers the new head as well
# as the existing layers).
ft_criterion = nn.CrossEntropyLoss()
ft_optimizer = torch.optim.Adam(params=binary_model.parameters(), lr=0.001)

In [None]:
# Stage 2: fine-tune the re-headed model on the full-label loaders.
(
    train_hist,
    val_hist,
    train_acc,
    val_acc,
) = train_model(
    binary_model,
    ft_optimizer,
    ft_criterion,
    n_epochs=EPOCHS,
    trainDL=allDL_train,
    valDL=allDL_val,
)

In [None]:
# Predictions and true labels of the fine-tuned model on the validation
# loader. This rebinds preds / true_labels from the direct-training section.
preds, true_labels = get_predictions_and_labels(binary_model, allDL_val)

In [None]:
# Compute the confusion matrix over all 10 class indices. Without an explicit
# `labels` argument, confusion_matrix infers the label set from the data, so a
# class that appears in neither true_labels nor preds would shrink the matrix
# and misalign it with the fixed 0..9 tick labels below.
class_ids = list(range(10))
conf_matrix = confusion_matrix(true_labels, preds, labels=class_ids)

# Plot the confusion matrix using seaborn
plt.figure(figsize=(6, 4))
sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues", xticklabels=class_ids, yticklabels=class_ids)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix')
plt.show()

In [None]:
# Checkpoint of the fine-tuned 10-class model.
# The optimiser state saved here must belong to ft_optimizer, which trained
# these weights; bin_optimizer was built for the binary stage, so its state
# covers the discarded 2-way head and cannot be used to resume fine-tuning.
PATH = "fineTune.pt"
torch.save({
            'epoch': EPOCHS-1,
            'model_state_dict': binary_model.state_dict(),
            'optimizer_state_dict': ft_optimizer.state_dict()
            }, PATH)