<a href="https://colab.research.google.com/github/truong1410/Gastro/blob/main/Untitled13.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Install required packages
!pip install torch torchvision Pillow scikit-learn matplotlib

# Import libraries
import torch
import torchvision
import time
from torch import nn
from torch import optim
from torch.utils import data
from PIL import Image
import os
import re
import argparse
from collections import defaultdict
import numpy as np
import logging
import csv
from torchvision import transforms, datasets, models
import sklearn.metrics as mtc
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import itertools

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define paths - modify these to match your Google Drive structure
base_path = "/content/drive/MyDrive/"  # Change this to your path
train_root_dir = os.path.join(base_path, "Gastro/train")
val_root_dir = os.path.join(base_path, "Gastro/val")
test_root_dir = os.path.join(base_path, "Gastro/test")
model_path = os.path.join(base_path, "checkpoints/")  # For saving models

# Create directory if it doesn't exist
os.makedirs(model_path, exist_ok=True)

# Parameters (you can adjust these)
max_epochs = 5
batch_size = 32
lr = 0.0005
n_classes = 22  # Number of classes

# Define filename for saving results
filename = f'results_e{max_epochs}_b{batch_size}_lr{lr}_densenet121_improved.csv'

# Define transforms
trans = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.4762, 0.3054, 0.2368], [0.3345, 0.2407, 0.2164])
    ]),
    'valid': transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.4762, 0.3054, 0.2368], [0.3345, 0.2407, 0.2164])
    ]),
    'test': transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.4762, 0.3054, 0.2368], [0.3345, 0.2407, 0.2164])
    ]),
}

# Helper functions (same as before)
def print_metrics(metrics, num_steps):
    outputs = []
    for k in metrics.keys():
        if k == 'dice_coeff' or k == 'dice' or k == 'bce':
            outputs.append('{}:{:4f}'.format(k, metrics[k]/num_steps))
        else:
            outputs.append('{}:{:2f}'.format(k, metrics[k]))
    print('{}'.format(','.join(outputs)))

def training_curve(epochs, lossesT, lossesV):
    plt.plot(epochs, lossesT, 'c', label='Training loss')
    plt.plot(epochs, lossesV, 'm', label='Validation loss')
    plt.title("Training Curve")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig('train_val_epoch_curve.png')
    plt.close()

def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix',
                          cmap=plt.cm.Blues, plt_size=[10,10]):
    plt.rcParams['figure.figsize'] = plt_size
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig('confusion_matrix.png')
    plt.close()
    print("Confusion matrix saved.")

def validate_net(model, validation_generator, device, criterion):
    num_steps = 0
    val_loss = 0
    val_metrics = defaultdict(float)
    all_labels_d = torch.tensor([], dtype=torch.long).to(device)
    all_predictions_d = torch.tensor([], dtype=torch.long).to(device)

    model.eval()
    with torch.no_grad():
        for image, labels in validation_generator:
            image, labels = image.to(device, dtype=torch.float32), labels.to(device)
            outputs = model(image)
            loss = criterion(outputs, labels)

            num_steps += image.size(0)
            val_loss += loss.item() * image.size(0)

            _, predicted = torch.max(outputs, dim=1)
            all_labels_d = torch.cat((all_labels_d, labels), 0)
            all_predictions_d = torch.cat((all_predictions_d, predicted), 0)

    y_true = all_labels_d.cpu()
    y_predicted = all_predictions_d.cpu()

    # Calculate metrics
    val_metrics['micro_precision'] = mtc.precision_score(y_true, y_predicted, average="micro")
    val_metrics['micro_recall'] = mtc.recall_score(y_true, y_predicted, average="micro")
    val_metrics['micro_f1'] = mtc.f1_score(y_true, y_predicted, average="micro")
    val_metrics['macro_precision'] = mtc.precision_score(y_true, y_predicted, average="macro")
    val_metrics['macro_recall'] = mtc.recall_score(y_true, y_predicted, average="macro")
    val_metrics['macro_f1'] = mtc.f1_score(y_true, y_predicted, average="macro")
    val_metrics['mcc'] = mtc.matthews_corrcoef(y_true, y_predicted)

    return (val_loss/num_steps), val_metrics, num_steps

def test_net(model, test_generator, device, criterion):
    num_steps = 0
    test_loss = 0
    test_metrics = defaultdict(float)
    all_labels_d = torch.tensor([], dtype=torch.long).to(device)
    all_predictions_d = torch.tensor([], dtype=torch.long).to(device)

    model.eval()
    with torch.no_grad():
        for image, labels in test_generator:
            image, labels = image.to(device, dtype=torch.float32), labels.to(device)
            outputs = model(image)
            loss = criterion(outputs, labels)

            num_steps += image.size(0)
            test_loss += loss.item() * image.size(0)

            _, predicted = torch.max(outputs, dim=1)
            all_labels_d = torch.cat((all_labels_d, labels), 0)
            all_predictions_d = torch.cat((all_predictions_d, predicted), 0)

    y_true = all_labels_d.cpu()
    y_predicted = all_predictions_d.cpu()

    # Calculate metrics
    test_metrics['micro_precision'] = mtc.precision_score(y_true, y_predicted, average="micro")
    test_metrics['micro_recall'] = mtc.recall_score(y_true, y_predicted, average="micro")
    test_metrics['micro_f1'] = mtc.f1_score(y_true, y_predicted, average="micro")
    test_metrics['macro_precision'] = mtc.precision_score(y_true, y_predicted, average="macro")
    test_metrics['macro_recall'] = mtc.recall_score(y_true, y_predicted, average="macro")
    test_metrics['macro_f1'] = mtc.f1_score(y_true, y_predicted, average="macro")
    test_metrics['mcc'] = mtc.matthews_corrcoef(y_true, y_predicted)

    # Confusion matrix and classification report
    cm = confusion_matrix(y_true, y_predicted)
    class_names = test_generator.dataset.classes
    plot_confusion_matrix(cm, classes=class_names, title='Confusion Matrix')

    print('Accuracy of the network on the %d test images: %f %%' % (num_steps, (100.0 * (y_predicted == y_true).sum() / num_steps)))
    print(classification_report(y_true, y_predicted, target_names=class_names))

    return (test_loss/num_steps), test_metrics, num_steps

# Main training class with improved model architecture
class GastroVisionTrainer:
    def __init__(self):
        # Create datasets and dataloaders
        training_dataset = datasets.ImageFolder(train_root_dir, transform=trans['train'])
        validation_dataset = datasets.ImageFolder(val_root_dir, transform=trans['valid'])
        test_dataset = datasets.ImageFolder(test_root_dir, transform=trans['test'])

        self.training_generator = data.DataLoader(training_dataset, batch_size, shuffle=True)
        self.validation_generator = data.DataLoader(validation_dataset, batch_size)
        self.test_generator = data.DataLoader(test_dataset, batch_size)

        print(f'Number of Training set images: {len(training_dataset)}')
        print(f'Number of Validation set images: {len(validation_dataset)}')
        print(f'Number of Test set images: {len(test_dataset)}')

    def create_improved_model(self):
        # Initialize model with proper weights parameter
        model = torchvision.models.densenet121(weights=torchvision.models.DenseNet121_Weights.IMAGENET1K_V1).to(device)

        # Freeze all layers initially
        for param in model.parameters():
            param.requires_grad = False

        # Get the number of features from the last layer
        num_features = model.classifier.in_features

        # Replace classifier with a properly structured sequential model
        model.classifier = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),
            nn.Linear(256, n_classes),
            nn.LogSoftmax(dim=1)
        ).to(device)

        # Unfreeze classifier layers
        for param in model.classifier.parameters():
            param.requires_grad = True

        return model

    def train_net(self):
        model = self.create_improved_model()

        # Training setup
        optimizer = optim.Adam(model.parameters(), lr, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=4, verbose=True)
        criterion = nn.NLLLoss()

        val_f1_max = 0.0
        epochs = []
        lossesT = []
        lossesV = []

        # Training loop
        for epoch in range(max_epochs):
            print(f'Epoch {epoch+1}/{max_epochs}')
            print('-' * 10)

            since = time.time()
            train_metrics = defaultdict(float)
            total_loss = 0
            num_steps = 0

            all_labels_d = torch.tensor([], dtype=torch.long).to(device)
            all_predictions_d = torch.tensor([], dtype=torch.long).to(device)

            model.train()

            # Training phase
            for image, labels in self.training_generator:
                image, labels = image.to(device, dtype=torch.float32), labels.to(device)

                optimizer.zero_grad()
                outputs = model(image)
                loss = criterion(outputs, labels)

                loss.backward()
                optimizer.step()

                num_steps += image.size(0)
                total_loss += loss.item() * image.size(0)

                _, predicted = torch.max(outputs, 1)
                all_labels_d = torch.cat((all_labels_d, labels), 0)
                all_predictions_d = torch.cat((all_predictions_d, predicted), 0)

            # Calculate training metrics
            y_true = all_labels_d.cpu()
            y_predicted = all_predictions_d.cpu()

            train_metrics['loss'] = total_loss / num_steps
            train_metrics['micro_precision'] = mtc.precision_score(y_true, y_predicted, average="micro")
            train_metrics['micro_recall'] = mtc.recall_score(y_true, y_predicted, average="micro")
            train_metrics['micro_f1'] = mtc.f1_score(y_true, y_predicted, average="micro")
            train_metrics['macro_precision'] = mtc.precision_score(y_true, y_predicted, average="macro")
            train_metrics['macro_recall'] = mtc.recall_score(y_true, y_predicted, average="macro")
            train_metrics['macro_f1'] = mtc.f1_score(y_true, y_predicted, average="macro")
            train_metrics['mcc'] = mtc.matthews_corrcoef(y_true, y_predicted)

            print('Training...')
            print(f'Train_loss: {train_metrics["loss"]:.3f}')
            print_metrics(train_metrics, num_steps)

            # Validation phase
            val_loss, val_metrics, val_num_steps = validate_net(model, self.validation_generator, device, criterion)
            scheduler.step(val_loss)

            epochs.append(epoch)
            lossesT.append(train_metrics['loss'])
            lossesV.append(val_loss)

            print('.' * 5)
            print('Validating...')
            print(f'val_loss: {val_loss:.3f}')
            print_metrics(val_metrics, val_num_steps)

            # Save results to CSV
            self.save_results_to_csv(epoch, train_metrics, val_loss, val_metrics)

            # Save best model
            if val_metrics['micro_f1'] >= val_f1_max:
                print(f'val micro f1 increased ({val_f1_max:.6f}-->{val_metrics["micro_f1"]:.6f}). Saving model')
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'loss': val_loss
                }, os.path.join(model_path, f'C_{epoch+1}_{batch_size}_improved.pth'))

                val_f1_max = val_metrics['micro_f1']

            print('-' * 10)
            time_elapsed = time.time() - since
            print(f'{time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')

        # Plot training curves
        training_curve(epochs, lossesT, lossesV)

        # Test phase
        print("Starting testing...")
        test_loss, test_metrics, test_num_steps = test_net(model, self.test_generator, device, criterion)
        print_metrics(test_metrics, test_num_steps)

        # Save test results
        self.save_test_results_to_csv(test_loss, test_metrics)

        return val_metrics, test_metrics

    def save_results_to_csv(self, epoch, train_metrics, val_loss, val_metrics):
        key_name = ['Epoch', 'Train_loss', 'Train_micro_precision', 'Train_micro_recall', 'Train_micro_f1',
                   'Train_macro_precision', 'Train_macro_recall', 'Train_macro_f1', 'Train_mcc',
                   'Val_loss', 'Val_micro_precision', 'Val_micro_recall', 'Val_micro_f1',
                   'Val_macro_precision', 'Val_macro_recall', 'Val_macro_f1', 'Val_mcc']

        train_list = [epoch]
        train_list.extend([train_metrics[k] for k in ['loss', 'micro_precision', 'micro_recall', 'micro_f1',
                                                    'macro_precision', 'macro_recall', 'macro_f1', 'mcc']])
        train_list.append(val_loss)
        train_list.extend([val_metrics[k] for k in ['micro_precision', 'micro_recall', 'micro_f1',
                                                   'macro_precision', 'macro_recall', 'macro_f1', 'mcc']])

        try:
            with open(os.path.join(base_path, filename), 'a', newline="") as f:
                wr = csv.writer(f, delimiter=",")
                if epoch == 0:
                    wr.writerow(key_name)
                wr.writerow(train_list)
        except IOError as e:
            print(f"I/O Error: {e}")

    def save_test_results_to_csv(self, test_loss, test_metrics):
        key_name = ['Test_loss', 'Test_micro_precision', 'Test_micro_recall', 'Test_micro_f1',
                   'Test_macro_precision', 'Test_macro_recall', 'Test_macro_f1', 'Test_mcc']

        test_list = [test_loss]
        test_list.extend([test_metrics[k] for k in ['micro_precision', 'micro_recall', 'micro_f1',
                                                  'macro_precision', 'macro_recall', 'macro_f1', 'mcc']])

        try:
            with open(os.path.join(base_path, filename), 'a', newline="") as f:
                wr = csv.writer(f, delimiter=",")
                wr.writerow([])
                wr.writerow(key_name)
                wr.writerow(test_list)
        except IOError as e:
            print(f"I/O Error: {e}")

# Run the training
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    logging.info(f'''Starting training:
        Epochs: {max_epochs}
        Batch Size: {batch_size}
        Learning Rate: {lr}''')

    trainer = GastroVisionTrainer()
    val_metrics, test_metrics = trainer.train_net()

    print("Training completed!")


Mounted at /content/drive
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading

Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth
100%|██████████| 30.8M/30.8M [00:00<00:00, 89.5MB/s]


Epoch 1/5
----------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Training...
Train_loss: 1.396
loss:1.396259,micro_precision:0.639274,micro_recall:0.639274,micro_f1:0.639274,macro_precision:0.239251,macro_recall:0.215110,macro_f1:0.220515,mcc:0.570320


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


.....
Validating...
val_loss: 0.830
micro_precision:0.767480,micro_recall:0.767480,micro_f1:0.767480,macro_precision:0.485796,macro_recall:0.412688,macro_f1:0.423001,mcc:0.722071
val micro f1 increased (0.000000-->0.767480). Saving model
----------
63m 17s
Epoch 2/5
----------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Training...
Train_loss: 0.818
loss:0.818130,micro_precision:0.761215,micro_recall:0.761215,micro_f1:0.761215,macro_precision:0.562770,macro_recall:0.441661,macro_f1:0.457128,mcc:0.714957


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


.....
Validating...
val_loss: 0.743
micro_precision:0.790244,micro_recall:0.790244,micro_f1:0.790244,macro_precision:0.482901,macro_recall:0.476435,macro_f1:0.472697,mcc:0.751330
val micro f1 increased (0.767480-->0.790244). Saving model
----------
29m 31s
Epoch 3/5
----------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Training...
Train_loss: 0.718
loss:0.718097,micro_precision:0.780383,micro_recall:0.780383,micro_f1:0.780383,macro_precision:0.542701,macro_recall:0.476441,macro_f1:0.494316,mcc:0.738146


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


.....
Validating...
val_loss: 0.685
micro_precision:0.791870,micro_recall:0.791870,micro_f1:0.791870,macro_precision:0.540021,macro_recall:0.476696,macro_f1:0.489054,mcc:0.752641
val micro f1 increased (0.790244-->0.791870). Saving model
----------
29m 19s
Epoch 4/5
----------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Training...
Train_loss: 0.680
loss:0.680272,micro_precision:0.778344,micro_recall:0.778344,micro_f1:0.778344,macro_precision:0.612968,macro_recall:0.481484,macro_f1:0.501725,mcc:0.735855


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


.....
Validating...
val_loss: 0.646
micro_precision:0.791870,micro_recall:0.791870,micro_f1:0.791870,macro_precision:0.462498,macro_recall:0.454960,macro_f1:0.452653,mcc:0.752057
val micro f1 increased (0.791870-->0.791870). Saving model
----------
29m 31s
Epoch 5/5
----------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Training...
Train_loss: 0.634
loss:0.634218,micro_precision:0.795269,micro_recall:0.795269,micro_f1:0.795269,macro_precision:0.639570,macro_recall:0.510668,macro_f1:0.537043,mcc:0.756184


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


.....
Validating...
val_loss: 0.634
micro_precision:0.790244,micro_recall:0.790244,micro_f1:0.790244,macro_precision:0.547270,macro_recall:0.466118,macro_f1:0.484341,mcc:0.750792
----------
29m 24s
Starting testing...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Confusion matrix, without normalization
[[103   0   2   0   0   1   0   0   1   1  17   2   0   0   0]
 [  0   3   0   0   0   0   0   0   6   0   0   1   0   0   0]
 [  3   0   7   0   0   1   0   0   0   0   7   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0  12   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   2   0   1   0   0]
 [  2   0   1   0   0   7   0   0   0   2   2   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   1   0   0   0]
 [  1   2   1   0   0   0   0   0   4   0   0   3   0   0   0]
 [  0   0   0   0   0   0   0   0  29   0   0   3   1   0   0]
 [  2   0   0   0   0   0   0   0   0   9   9   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0 141   4   0   2   0]
 [  0   0   1   0   0   0   0   0   2   0   0  88   3   3   0]
 [  0   0   0   0   0   0   0   0   3   0   1  10  26   0   0]
 [  0   0   0   0   0   0   0   0   0   0   6   8   3  68   0]
 [  0   0   0   0   0   0   0   0   0   0   0   1   0   0   0]]
Confusion matr

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
