<a href="https://colab.research.google.com/github/sajidcsecu/radioGenomic/blob/main/NewUnetTrain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from torch.utils.data import  DataLoader
import torch
from NewUnetDataPreparation import PatientDataset2DUNet
from NewUnet import UNet
from NewUnetLoss import DiceBCELoss
import time
import numpy as np
import os
import random
from tqdm.auto import tqdm
from glob import glob
from sklearn.model_selection import train_test_split
import pandas as pd

class EarlyStopping:
    """Signal when a monitored loss has stopped improving.

    Call the instance once per epoch with the latest validation loss. It
    returns True once ``patience`` consecutive calls have failed to beat the
    best loss seen so far by more than ``min_delta``.
    """

    def __init__(self, patience=10, min_delta=0.001):
        """
        Args:
            patience (int): Number of consecutive non-improving calls tolerated
                before stopping is requested.
            min_delta (float): Amount by which a new loss must fall below the
                best loss seen so far to count as an improvement.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0

    def __call__(self, val_loss):
        """Record ``val_loss``; return True if training should stop, else False."""
        improved = val_loss < self.best_loss - self.min_delta
        if not improved:
            # One more epoch without a meaningful improvement.
            self.counter += 1
        else:
            # New best: remember it and restart the patience window.
            self.best_loss = val_loss
            self.counter = 0

        should_stop = self.counter >= self.patience
        if should_stop:
            print(f"⛔ Early stopping triggered after {self.patience} epochs without improvement!")
        return should_stop

class UnetTrain:
    """Helpers for a segmentation training run.

    Provides seeding, epoch timing, one-epoch train/validation loops, and
    splitting an ``images/`` + ``masks/`` directory into train/valid/test
    path lists.
    """

    def seeding(self, seed):
        """Seed the random generators used during training so runs are reproducible.

        Seeds Python's ``random``, NumPy and PyTorch. Also puts cuDNN into
        deterministic mode with benchmarking disabled.

        Args:
            seed (int): Seed value applied to every generator.
        """
        random.seed(seed)
        # Does not change hashing in the already-running interpreter; it only
        # affects Python processes launched after this point.
        os.environ["PYTHONHASHSEED"] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        # Benchmark mode auto-tunes convolution algorithms at runtime and may
        # pick different ones between runs; disable it for reproducibility.
        torch.backends.cudnn.benchmark = False

    def epoch_time(self, start_time, end_time):
        """Split the elapsed time between two ``time.time()`` stamps into (minutes, seconds).

        Args:
            start_time (float): Start timestamp in seconds.
            end_time (float): End timestamp in seconds.

        Returns:
            tuple[int, int]: Whole minutes and the remaining whole seconds.
        """
        elapsed_time = end_time - start_time
        elapsed_mins = int(elapsed_time / 60)
        elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
        return elapsed_mins, elapsed_secs

    def train(self, model, loader, optimizer, loss_fn, device):
        """Run one training epoch and return the mean per-batch loss.

        Puts ``model`` in training mode and steps ``optimizer`` after every batch.

        Args:
            model (torch.nn.Module): Model to train.
            loader: Iterable of ``(x, y)`` tensor batches, e.g. a DataLoader.
            optimizer (torch.optim.Optimizer): Optimizer over ``model``'s parameters.
            loss_fn: Callable ``loss_fn(y_pred, y)`` returning a scalar tensor.
            device: Device the batches are moved to.

        Returns:
            float: Average of the per-batch losses.

        Raises:
            ValueError: If ``loader`` yields no batches.
        """
        model.train()
        total_loss = 0.0
        num_batches = 0
        for x, y in loader:
            x = x.to(device, dtype=torch.float32)
            y = y.to(device, dtype=torch.float32)

            optimizer.zero_grad()
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # Count batches actually seen rather than trusting len(loader).
            num_batches += 1

        if num_batches == 0:
            raise ValueError("loader yielded no batches; cannot compute a mean loss")
        return total_loss / num_batches

    def evaluate(self, model, loader, loss_fn, device):
        """Compute the mean per-batch loss of ``model`` over ``loader`` without updating it.

        Puts ``model`` in eval mode and disables gradient tracking.

        Args:
            model (torch.nn.Module): Model to evaluate.
            loader: Iterable of ``(x, y)`` tensor batches, e.g. a DataLoader.
            loss_fn: Callable ``loss_fn(y_pred, y)`` returning a scalar tensor.
            device: Device the batches are moved to.

        Returns:
            float: Average of the per-batch losses.

        Raises:
            ValueError: If ``loader`` yields no batches.
        """
        model.eval()
        total_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device, dtype=torch.float32)
                y = y.to(device, dtype=torch.float32)
                y_pred = model(x)
                loss = loss_fn(y_pred, y)
                total_loss += loss.item()
                num_batches += 1

        if num_batches == 0:
            raise ValueError("loader yielded no batches; cannot compute a mean loss")
        return total_loss / num_batches

    def load_paths(self, path, split=0.2):
        """Split the files under ``path/images`` and ``path/masks`` into train/valid/test sets.

        Images and masks are paired by position after sorting. This assumes
        each image's filename sorts to the same position as its mask's.
        Validation and test each receive ``int(split * total)`` items, both
        sized from the full dataset. The rest is used for training. Splits
        use a fixed ``random_state`` so they are reproducible.

        Args:
            path (str): Directory containing ``images`` and ``masks`` subdirectories.
            split (float): Fraction of the full dataset used for each of the
                validation and test sets.

        Returns:
            tuple: ``(train_x, train_y), (valid_x, valid_y), (test_x, test_y)``,
            where ``*_x`` are image paths and ``*_y`` the matching mask paths.

        Raises:
            ValueError: If the number of images and masks differ.
        """
        images = sorted(glob(os.path.join(path, "images/*")))
        masks = sorted(glob(os.path.join(path, "masks/*")))
        if len(images) != len(masks):
            raise ValueError(
                f"Found {len(images)} images but {len(masks)} masks under {path!r}; "
                "every image needs exactly one mask."
            )
        total_size = len(images)
        print("Total Images : ", total_size)
        valid_size = int(split * total_size)
        test_size = int(split * total_size)

        # Split images and masks in a single call so the same shuffled indices
        # are applied to both lists and every image stays paired with its mask.
        train_x, valid_x, train_y, valid_y = train_test_split(
            images, masks, test_size=valid_size, random_state=42
        )
        train_x, test_x, train_y, test_y = train_test_split(
            train_x, train_y, test_size=test_size, random_state=42
        )
        print(len(train_x), len(valid_x), len(test_x))
        return (train_x, train_y), (valid_x, valid_y), (test_x, test_y)

    def test(self, model, loader, loss_fn, device):
        """Placeholder for test-set evaluation; not implemented yet (returns None)."""
        pass


if __name__ == "__main__":
    ut = UnetTrain()

    # ---- Seeding ----
    ut.seeding(42)

    # ---- Hyperparameters ----
    batch_size = 8
    num_epochs = 10
    lr = 1e-4

    # ---- Paths ----
    # os.path.join keeps paths portable across operating systems and avoids
    # backslash escape sequences inside string literals.
    path = os.path.join(".", "Segementation")
    checkpoint_path = os.path.join(path, "files", "checkpoint.pth")
    # torch.save() does not create missing parent directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    print("Check Point Path : ", checkpoint_path)

    # ---- Patient list (training from DICOMs on disk via PatientDataset2DUNet) ----
    metadata_lung1 = pd.read_csv(
        os.path.join(".", "metadata", "metadata_lung1.csv"), sep=',', index_col=False
    )
    patient_list_lung1 = metadata_lung1["Subject ID"].unique().tolist()
    # Patients flagged as erroneous; filtered out (no error if one is absent).
    error_patients = {'LUNG1-128'}
    patient_list_lung1 = np.array(
        [patient for patient in patient_list_lung1 if patient not in error_patients]
    )
    patient_list_lung1 = patient_list_lung1[:30]
    print(patient_list_lung1)
    train_patient, valid_patient = train_test_split(patient_list_lung1, test_size=0.1, random_state=42)
    train_patient, test_patient = train_test_split(train_patient, test_size=0.1, random_state=42)
    print("Number of Total Patients : ", len(patient_list_lung1))
    print("Number of Patients for Training : ", len(train_patient))
    print("Number of Patients for Validation : ", len(valid_patient))
    print("Number of Patients for Testing : ", len(test_patient))

    # ---- Datasets and loaders ----
    print("Training Loading...")
    train_dataset = PatientDataset2DUNet(train_patient, metadata_lung1, train=True)
    print("Valid Loading...")
    valid_dataset = PatientDataset2DUNet(valid_patient, metadata_lung1, train=False)
    print("Testing Loading...")
    test_dataset = PatientDataset2DUNet(test_patient, metadata_lung1, train=False)

    # os.cpu_count() may return None, which DataLoader does not accept.
    num_workers = os.cpu_count() or 0
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    valid_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    print(f"Total images in Training Dataset: {len(train_dataset)}")
    print(f"Total images in Valid Dataset: {len(valid_dataset)}")
    print(f"Total images in Testing Dataset: {len(test_dataset)}")

    # Use the GPU when available, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = UNet(1, 1)
    model = model.to(device)

    # ---- Loss function, optimizer and scheduler ----
    # AdamW applies decoupled weight decay, a regulariser against overfitting.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    # Halve the learning rate after 5 epochs without validation-loss improvement.
    # It only takes effect when stepped with the monitored metric (see loop below).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)
    loss_fn = DiceBCELoss()

    # ---- Training ----
    early_stopping = EarlyStopping(patience=10, min_delta=0.001)
    best_valid_loss = float("inf")

    for epoch in tqdm(range(num_epochs)):
        start_time = time.time()

        train_loss = ut.train(model, train_loader, optimizer, loss_fn, device)
        valid_loss = ut.evaluate(model, valid_loader, loss_fn, device)
        # Feed the monitored metric to ReduceLROnPlateau once per epoch.
        scheduler.step(valid_loss)

        # Save a checkpoint whenever the validation loss improves.
        if valid_loss < best_valid_loss:
            data_str = f"Valid loss improved from {best_valid_loss:2.4f} to {valid_loss:2.4f}. Saving checkpoint: {checkpoint_path}"
            print(data_str)

            best_valid_loss = valid_loss
            torch.save(model.state_dict(), checkpoint_path)

        end_time = time.time()
        epoch_mins, epoch_secs = ut.epoch_time(start_time, end_time)

        data_str = f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s\n'
        data_str += f'\tTrain Loss: {train_loss:.3f}\n'
        data_str += f'\t Val. Loss: {valid_loss:.3f}\n'
        print(data_str)

        # Early-stopping check.
        if early_stopping(valid_loss):
            print("🛑 Stopping training early due to no improvement.")
            break