# Import for Code

In [None]:
import torch  # PyTorch library
import torch.nn as nn  # Module for neural networks
import torch.optim as optim  # Optimization algorithms
import torchvision.transforms as transforms  # Module for image transformations
from torchvision import datasets  # Image datasets
from torch.utils.data import DataLoader  # Data loaders 
import matplotlib.pyplot as plt  # Visualization tool
import numpy as np  # Mathematical and array manipulation tool
from collections import Counter  # Tool for counting elements
import torchvision.models as models  # Pre-trained models
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score  # Evaluation metrics
from sklearn.model_selection import train_test_split  # Dataset splitting
import optuna  # Hyperparameter optimization tool
import os  # Operating system-related tool
import csv  # Tool for handling CSV files

# Plotting 
This section visualizes the training and validation losses per epoch and the validation accuracy per epoch. Additionally, it plots the micro-average AUROC (Area Under the Receiver Operating Characteristic curve), computed on the validation set after each epoch of training.

In [None]:
# Plotting the training/validation curves
def _save_line_plot(series, title, ylabel, filename):
    """Plot each ``(values, label)`` pair in *series* on one figure and save it.

    The figure is created, written to *filename* and then closed explicitly,
    so repeated calls (one per epoch) do not accumulate open figures.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    for values, label in series:
        ax.plot(values, label=label)
    ax.set_title(title)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(ylabel)
    ax.legend()
    fig.savefig(filename)
    # Close exactly the figure we created. Calling plt.clf() after plt.close()
    # would silently create a fresh empty figure that is never released.
    plt.close(fig)


def draw_train_val_curve(train_losses, val_losses, val_accuracies, val_micro_aurocs):
    """Save the per-epoch training curves as PNG files in the working directory.

    Three files are (over)written on every call:

    * ``train_val_losses_p_epoch.png`` - training and validation loss
    * ``val_acc_p_epoch.png`` - validation accuracy
    * ``val_micro_auroc_p_epoch.png`` - micro-average AUROC

    Args:
        train_losses: Sequence of per-epoch training losses.
        val_losses: Sequence of per-epoch validation losses.
        val_accuracies: Sequence of per-epoch validation accuracies. Each entry
            may be a Python/NumPy number or a single-element tensor (on any
            device); entries are converted with ``float()`` before plotting.
        val_micro_aurocs: Sequence of per-epoch micro-average AUROC values.
    """
    _save_line_plot(
        [(train_losses, 'Training Loss'), (val_losses, 'Validation Loss')],
        title="Training and Validation Losses per Epoch",
        ylabel='Loss',
        filename='train_val_losses_p_epoch.png',
    )

    # float() works for plain numbers and for 0-d tensors alike (it copies a
    # CUDA scalar to the host), so callers are not forced to pass tensors.
    val_accuracies_cpu = [float(acc) for acc in val_accuracies]
    _save_line_plot(
        [(val_accuracies_cpu, 'Validation Accuracy')],
        title='Validation Accuracy per Epoch',
        ylabel='Accuracy',
        filename='val_acc_p_epoch.png',
    )

    # NOTE(review): the label/title say "(Training)" although the parameter is
    # named val_micro_aurocs; the strings are kept unchanged here - confirm the
    # intended wording.
    _save_line_plot(
        [(val_micro_aurocs, 'Micro-average AUROC (Training)')],
        title='Micro-average AUROC per Epoch (Training)',
        ylabel='AUROC',
        filename='val_micro_auroc_p_epoch.png',
    )


# Select HyperParameter 

In [None]:
# Runtime configuration: expose only GPU index 0 to this process and limit
# PyTorch to a single CPU thread.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
torch.set_num_threads(1)

# Training hyperparameters used by the cells below.
Model_Name = 'Model_name'  # placeholder identifier for the run
Image_Size = 227           # side length (pixels) of the square network input
Batch_Size = 8             # samples per mini-batch
Num_Epochs = 5             # full passes over the training set
Learning_Rate = 1e-4       # optimizer step size

# DataPreprocessing
Manipulation and transformation of the raw data into a usable format before it is analyzed.

In [None]:
# Preprocessing pipelines for the training and validation splits.
# Both resize every image to Image_Size x Image_Size and convert it to a tensor.
# The size is taken from the Image_Size hyperparameter (instead of a repeated
# literal) so that changing the hyperparameter actually changes the input size.
# NOTE(review): no mean/std normalization is applied here even though the model
# cell loads ImageNet-pretrained weights - confirm whether that is intended.
train_transform = transforms.Compose([
    transforms.Resize((Image_Size, Image_Size)),
    transforms.ToTensor(),
])

val_transform = transforms.Compose([
    transforms.Resize((Image_Size, Image_Size)),
    transforms.ToTensor(),
])

# DataLoad
Data loading is the process of reading data from storage and loading it into memory. It involves the following steps:
Data Preparation: converting the data into a usable format.
Data Loading: reading the data into memory.
Mini-Batch Formation: handling a large dataset all at once can be inefficient, so the data is divided into smaller batches.
Data Shuffling: shuffling the training data prevents the model from learning spurious patterns based on the order of the data.


In [None]:
# Root folders of the two splits (placeholders - replace with real paths).
# ImageFolder derives the class list from the sub-directories of each root.
train_dir = "your train_dataset"
val_dir = "your val_dataset"

train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
val_dataset = datasets.ImageFolder(root=val_dir, transform=val_transform)

# The class list (and therefore the classifier width) comes from the training split.
class_names = train_dataset.classes
num_classes = len(class_names)

# Only the training batches are shuffled; validation keeps the dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=Batch_Size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=Batch_Size, shuffle=False)

# Model Load
Example: ResNet-50

In [None]:
# Build a ResNet-50 initialised with the IMAGENET1K_V1 pretrained weights.
weights = models.ResNet50_Weights.IMAGENET1K_V1
model = models.resnet50(weights=weights)

# Replace the final fully connected layer with a new one that has one output
# per class of this dataset; its input width is taken from the layer it replaces.
classifier = model.fc
last_layer_in_features = classifier.in_features
model.fc = nn.Linear(in_features=last_layer_in_features, out_features=num_classes)

# Loss and optimizer: cross-entropy over class logits, Adam over all parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=Learning_Rate)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)

# Model Train, Val

In [None]:
# Best value seen so far for each validation metric, plus a counter per metric
# of how many consecutive epochs have passed without improving it.
best_val_loss = float('inf')
best_val_acc = 0.0
best_val_auroc = 0.0
p_acc_counter = 0
p_loss_counter = 0
p_auroc_counter = 0

patience = 300  # Number of epochs to wait for improvement before stopping
patience_counter = 0  # Epochs without validation-loss improvement (drives early stopping)
train_losses = []
val_losses = []
val_accuracies = []
all_preds = []   # Predictions of the most recent validation pass
all_labels = []  # Ground-truth labels of the most recent validation pass
val_micro_aurocs = []
log_file_path = 'training_log.txt'
csv_file_path = 'training_metrics.csv'

for epoch in range(Num_Epochs):
    # ------------------------------ training ------------------------------
    model.train()
    train_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight it by the batch size so the
        # division below yields a per-sample average even if the last batch is short.
        train_loss += loss.item() * images.size(0)
    train_loss = train_loss / len(train_loader.dataset)

    # ----------------------------- validation -----------------------------
    model.eval()
    val_loss = 0.0
    val_corrects = 0
    val_labels = []
    val_probas = []
    # Start fresh every epoch: otherwise these lists would mix predictions made
    # by earlier (stale) model states and grow without bound.
    all_preds = []
    all_labels = []
    with torch.no_grad():  # No gradients are needed for evaluation
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)  # Index of the largest logit per sample
            probabilities = torch.nn.functional.softmax(outputs, dim=1)

            # Accumulate as a plain int so the accuracy below is a Python float.
            val_corrects += torch.sum(preds == labels).item()
            batch_labels = labels.cpu().numpy()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch_labels)
            val_labels.extend(batch_labels)
            val_probas.extend(probabilities.cpu().numpy())
    val_probas = np.array(val_probas)
    val_labels = np.array(val_labels)

    val_loss = val_loss / len(val_loader.dataset)
    val_accuracy = val_corrects / len(val_loader.dataset)

    # Micro-average one-vs-rest AUROC over the validation set.
    n_score_columns = val_probas.shape[1]
    if n_score_columns == 2:
        # For a binary target roc_auc_score expects a 1-D score: the
        # probability of the class with the greater label.
        val_auroc = roc_auc_score(val_labels, val_probas[:, 1])
    else:
        # Passing the full label set keeps the call valid when some class is
        # absent from the validation split.
        val_auroc = roc_auc_score(
            val_labels,
            val_probas,
            average='micro',
            multi_class='ovr',
            labels=np.arange(n_score_columns),
        )

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    # Stored as a 0-d tensor so consumers that call tensor methods such as
    # .cpu() on the entries keep working.
    val_accuracies.append(torch.tensor(val_accuracy, dtype=torch.float64))
    val_micro_aurocs.append(val_auroc)
    draw_train_val_curve(train_losses, val_losses, val_accuracies, val_micro_aurocs)

    # Append this epoch's metrics to the CSV file.
    # Column order: epoch, train_loss, val_auroc, val_loss, val_accuracy.
    with open(csv_file_path, 'a', newline='') as csv_file:
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow([epoch + 1, train_loss, val_auroc, val_loss, val_accuracy])

    # Checkpoint the weights whenever a validation metric reaches a new best,
    # and track how long each metric has gone without improving.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'valloss.pth')
        p_loss_counter = 0
    else:
        p_loss_counter += 1
    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        torch.save(model.state_dict(), 'valacc.pth')
        p_acc_counter = 0
    else:
        p_acc_counter += 1
    if val_auroc > best_val_auroc:
        best_val_auroc = val_auroc
        torch.save(model.state_dict(), 'auroc.pth')
        p_auroc_counter = 0
    else:
        p_auroc_counter += 1

    # Always keep the weights of the latest epoch, including the one that
    # triggers early stopping.
    torch.save(model.state_dict(), 'epoch.pth')

    # Log training progress
    with open(log_file_path, 'a') as log_file:
        log_file.write(f"Epoch {epoch+1}/{Num_Epochs}, Train Loss: {train_loss:.4f}, Val auroc:{val_auroc}, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy}, patience counter(acc,loss,auroc): {p_acc_counter},{p_loss_counter},{p_auroc_counter}\n")

    # Early stopping on validation loss. The counter was previously never
    # updated, so this branch could not fire.
    patience_counter = p_loss_counter
    if patience_counter >= patience:
        print("Stopping early due to no improvement in validation loss.")
        break
