<a href="https://colab.research.google.com/github/sajidcsecu/radioGenomic/blob/main/NewUnetTrain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from itertools import chain
import torch
from NewUnetDataPreparation import PatientDataset
from NewUnet import UNet
from NewUnetLoss import DiceBCELoss
import time
import numpy as np
import os
import random
from tqdm.auto import tqdm
from glob import glob
from sklearn.model_selection import train_test_split
import cv2


class UnetTrain:
    """Helpers for training and validating a U-Net segmentation model.

    Bundles RNG seeding, epoch timing, the per-epoch train/validation loops
    and the routine that splits image/mask file paths into train, validation
    and test subsets.
    """

    def seeding(self, seed):
        """Seed the Python, NumPy and PyTorch random number generators.

        Args:
            seed: Integer seed applied to every generator.
        """
        random.seed(seed)
        # Exported for any process started after this point.
        os.environ["PYTHONHASHSEED"] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # Both CUDA calls are silently ignored when CUDA is unavailable.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # Auto-tuning may select different kernels between runs, which
        # undermines the deterministic flag set above.
        torch.backends.cudnn.benchmark = False

    def epoch_time(self, start_time, end_time):
        """Split the elapsed time between two timestamps into minutes and seconds.

        Args:
            start_time: Start timestamp in seconds.
            end_time: End timestamp in seconds.

        Returns:
            Tuple ``(minutes, seconds)`` of whole (truncated) integers.
        """
        elapsed_time = end_time - start_time
        elapsed_mins = int(elapsed_time / 60)
        elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
        return elapsed_mins, elapsed_secs

    def train(self, model, loader, optimizer, loss_fn, device):
        """Run one optimisation pass over ``loader``.

        Args:
            model: Module to optimise; switched to training mode.
            loader: Iterable yielding ``(x, y)`` tensor batches.
            optimizer: Optimizer stepped once per batch.
            loss_fn: Callable ``loss_fn(y_pred, y)`` returning a scalar tensor.
            device: Device the batches are moved to (as float32).

        Returns:
            Mean of the per-batch loss values.

        Raises:
            ValueError: If ``loader`` yields no batches.
        """
        model.train()
        total_loss = 0.0
        # Count batches instead of calling len(loader) so that iterables
        # without __len__ work and an empty loader is reported clearly.
        num_batches = 0
        for x, y in loader:
            x = x.to(device, dtype=torch.float32)
            y = y.to(device, dtype=torch.float32)

            optimizer.zero_grad()
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train() received a loader that yielded no batches")
        return total_loss / num_batches

    def evaluate(self, model, loader, loss_fn, device):
        """Compute the mean loss over ``loader`` without updating the model.

        Args:
            model: Module to evaluate; switched to evaluation mode.
            loader: Iterable yielding ``(x, y)`` tensor batches.
            loss_fn: Callable ``loss_fn(y_pred, y)`` returning a scalar tensor.
            device: Device the batches are moved to (as float32).

        Returns:
            Mean of the per-batch loss values.

        Raises:
            ValueError: If ``loader`` yields no batches.
        """
        model.eval()
        total_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device, dtype=torch.float32)
                y = y.to(device, dtype=torch.float32)
                y_pred = model(x)
                loss = loss_fn(y_pred, y)

                total_loss += loss.item()
                num_batches += 1

        if num_batches == 0:
            raise ValueError("evaluate() received a loader that yielded no batches")
        return total_loss / num_batches

    def load_paths(self, path, split=0.2, limit=10):
        """Collect image/mask paths and split them into train/valid/test sets.

        Files are read from ``<path>/images`` and ``<path>/masks`` and paired
        by their position after sorting (assumes matching images and masks
        sort into the same order -- TODO confirm against the data layout).

        Args:
            path: Directory containing the ``images`` and ``masks`` folders.
            split: Fraction of the samples used for the validation set and,
                separately, for the test set.
            limit: Maximum number of samples to use; ``None`` uses every file.
                The default of 10 keeps the previously hard-coded cap.

        Returns:
            ``(train_x, train_y), (valid_x, valid_y), (test_x, test_y)`` where
            each ``*_x`` is a list of image paths and ``*_y`` the masks.

        Raises:
            ValueError: If the number of images and masks differ.
        """
        images = sorted(glob(os.path.join(path, "images/*")))
        masks = sorted(glob(os.path.join(path, "masks/*")))
        if limit is not None:
            images = images[:limit]
            masks = masks[:limit]
        if len(images) != len(masks):
            raise ValueError(
                f"Found {len(images)} images but {len(masks)} masks in {path!r}"
            )

        total_size = len(images)
        print("Total Images : ", total_size)
        valid_size = int(split * total_size)
        test_size = int(split * total_size)

        # Split images and masks in a single call so that each pair is
        # guaranteed to land in the same subset.
        train_x, valid_x, train_y, valid_y = train_test_split(
            images, masks, test_size=valid_size, random_state=42
        )
        train_x, test_x, train_y, test_y = train_test_split(
            train_x, train_y, test_size=test_size, random_state=42
        )
        print(len(train_x), len(valid_x), len(test_x))
        return (train_x, train_y), (valid_x, valid_y), (test_x, test_y)

    def test(self, model, loader, loss_fn, device):
        """Placeholder for test-set evaluation; not implemented (returns None)."""
        pass


if __name__ == "__main__":
    # Script entry point: trains a UNet on the image/mask pairs found under
    # data_path and keeps the checkpoint with the lowest validation loss.
    ut = UnetTrain()

    # Seeding
    ut.seeding(42)

    # Hyperparameters
    batch_size = 2
    num_epochs = 50
    lr = 1e-4

    # Paths
    path = "F:\\Idiot Developer\\radioGenomic\\Segementation"
    data_path = os.path.join(path, "data", "full data")
    checkpoint_path = os.path.join(path, "files", "checkpoint.pth")

    print(data_path)

    # Dataset and loaders (the test split is returned but not used here).
    (train_x, train_y), (valid_x, valid_y), (test_x, test_y) = ut.load_paths(data_path)
    print(train_x)
    print(train_y)

    train_dataset = PatientDataset(train_x, train_y)
    valid_dataset = PatientDataset(valid_x, valid_y)

    # os.cpu_count() may return None, which DataLoader does not accept;
    # fall back to loading in the main process.
    num_workers = os.cpu_count() or 0

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    valid_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    # Make device agnostic code
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = UNet(1, 1)
    model = model.to(device)

    # Loss function and optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # The `verbose` flag is deprecated in recent PyTorch releases, so
    # learning-rate changes are reported explicitly in the loop below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)
    loss_fn = DiceBCELoss()

    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    # Training the model
    best_valid_loss = float("inf")

    for epoch in tqdm(range(num_epochs)):
        start_time = time.time()

        train_loss = ut.train(model, train_loader, optimizer, loss_fn, device)
        valid_loss = ut.evaluate(model, valid_loader, loss_fn, device)

        # ReduceLROnPlateau only lowers the learning rate when it is stepped
        # with the monitored metric after each epoch.
        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(valid_loss)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after != lr_before:
            print(f"Reducing learning rate from {lr_before:.4e} to {lr_after:.4e}.")

        # Saving the model
        if valid_loss < best_valid_loss:
            data_str = f"Valid loss improved from {best_valid_loss:2.4f} to {valid_loss:2.4f}. Saving checkpoint: {checkpoint_path}"
            print(data_str)

            best_valid_loss = valid_loss
            torch.save(model.state_dict(), checkpoint_path)

        end_time = time.time()
        epoch_mins, epoch_secs = ut.epoch_time(start_time, end_time)

        data_str = f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s\n'
        data_str += f'\tTrain Loss: {train_loss:.3f}\n'
        data_str += f'\t Val. Loss: {valid_loss:.3f}\n'
        print(data_str)