author: Yagnik Poshiya
github: @yagnikposhiya
organization: Tvisi

In [1]:
# mount google drive to the current session
# NOTE: google.colab is only importable inside a Colab runtime; this cell (and the
# hard-coded '/content/drive/...' paths used later in the notebook) will fail on a local kernel
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
os.getcwd() # check which one is the current working directory (displayed as the cell output)

'/content'

In [3]:
# list the project folder on Google Drive; it is expected to contain the 'data' directory the Config paths point into
os.listdir('/content/drive/MyDrive/Notebook_Testing') # list all directories or files available in the specified path

['data', 'train.ipynb']

In [None]:
# install weights-and-biases using pip
!pip install wandb

Collecting wandb
  Downloading wandb-0.17.5-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl.metadata (1.8 kB)
Collecting gitpython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.43-py3-none-any.whl.metadata (13 kB)
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-2.12.0-py2.py3-none-any.whl.metadata (9.8 kB)
Collecting setproctitle (from wandb)
  Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.9 kB)
Collecting gitdb<5,>=4.0.1 (from gitpython!=3.1.29,>=1.0.0->wandb)
  Downloading gitdb-4.0.11-py3-none-any.whl.metadata (1.2 kB)
Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb)
  Downloading smmap-5.0.1-py3-none-any.whl.metadata (4.3 kB)
Downloading wandb-0.17.5-py3-none-ma

In [5]:
# install pytorch_lightning using pip
!pip install lightning

Collecting lightning
  Downloading lightning-2.3.3-py3-none-any.whl.metadata (35 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.11.6-py3-none-any.whl.metadata (5.2 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.4.0.post0-py3-none-any.whl.metadata (19 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.3.3-py3-none-any.whl.metadata (21 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch<4.0,>=2.0.0->lightning)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch<4.0,>=2.0.0->lightning)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch<4.0,>=2.0.0->lightning)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Co

In [6]:
# install torchinfo for printing model summary using pip
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


### Import required python libraries

In [7]:
import os
import cv2
import glob
import json
import wandb
import torch
import shutil
import random
import subprocess
import matplotlib
import numpy as np
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F
import matplotlib.pyplot as plt

from PIL import Image
from typing import Any
from torchinfo import summary
from collections import Counter
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import Dataset, DataLoader, random_split

### Define required functions & classes

In [8]:
def check_gpu_config() -> None:
    """
    This function is used to check, whether GPUs are available.

    When CUDA is available it prints the number of visible GPUs, the name of every
    GPU, and the output of `nvidia-smi`; otherwise it prints that the CPU will be used.

    Parameters:
    - (None)

    Returns:
    - (None)
    """

    if not torch.cuda.is_available():
        print("- CUDA is not available. Using CPU instead.")
        return

    num_gpus = torch.cuda.device_count() # get total number of available gpus
    print("- Number of GPUs available: {}".format(num_gpus))

    for i in range(num_gpus):
        # pass the index: without it every iteration reports the current device's name only
        gpu_name = torch.cuda.get_device_name(i)
        print("- GPU name: {}".format(gpu_name))

    try:
        # fixed command with no shell features needed, so run it without a shell
        result = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
    except FileNotFoundError as error: # nvidia-smi binary is not on PATH
        print("- Error message: \n{}".format(error))
        return

    if result.returncode == 0:
        print(result.stdout) # output after successful execution of command
    else:
        print("- Error message: \n{}".format(result.stderr)) # output after failed execution of command

In [9]:
class Config:
    """
    Central place for the paths, Weights & Biases settings and training hyper-parameters used in this notebook.
    """

    def __init__(self, cwd: str = '/content/drive/MyDrive/Notebook_Testing/') -> None:
        """
        Parameters:
        - cwd (str): Base directory that all data/model paths are derived from. Defaults to the
          Google Drive folder this notebook was written for; pass e.g. os.getcwd() to run elsewhere.
        """

        # current working directory
        self.CWD = cwd

        # training and validation set paths (empty by default)
        self.TRAINSET_PATH = '' # set training set path
        self.VALIDSET_PATH = '' # set validation set path
        self.TRAIN_IMAGE_DIR = '' # set train/images directory path
        self.TRAIN_MASK_DIR = '' # set train/masks directory path

        # other relevant paths
        self.INPUT_JSON_FILE_PATH = os.path.join(self.CWD,'data/raw/json_projects') # set path to the directory which contains one or more than one json files
        self.SAMPLE_JSON_FILE_PATH = os.path.join(self.CWD,'data/raw/json_projects/File1.json')  # set path to the json file for understanding json file structure
        self.RAW_IMAGE_DIR = os.path.join(self.CWD,'data/raw/images') # set raw image directory
        self.BASE_DATA_PATH = os.path.join(self.CWD,'data') # set base data folder path
        self.PATH_TO_SAVE_TRAINED_MODEL = os.path.join(self.CWD,'saved_models') # set path to save trained model (derived from CWD like the paths above)

        # weight and biases config
        self.ENTITY = 'neuralninjas' # set team/organization name for wandb account
        self.PROJECT = 'tvisi-agri-count' # set project name
        self.REINIT = True # set boolean value for reinitialization
        self.ANONYMOUS = 'allow' # set anonymous value type
        self.LOG_MODEL = 'all' # set log model type

        # model training parameters
        self.BATCH_SIZE = 16 # set batch size for model training
        self.MAX_EPOCHS = 2 # set maximum epochs for model training
        self.NUM_CLASSES = 2 # set number of classes contained in the mask images (in segmentation case)
        self.LEARNING_RATE = 0.001 # set learning rate
        self.TRANSFORM = True # set boolean value for applying augmentation techniques to the training set


    def printConfiguration(self) -> None:
        """
        This function is used to print all configuration related to paths and model training params

        Parameters:
        - (None)

        Returns:
        - (None)
        """

        print("-----------------------------------------------------------")
        print("-----------------------CONFIGURATIONS----------------------")
        print("-----------------------------------------------------------")
        print("\n",
              f"- Current working directory: {self.CWD}\n",
              f"- Trainset path: {self.TRAINSET_PATH}\n",
              f"- Validset path: {self.VALIDSET_PATH}\n",
              f"- Train image directory: {self.TRAIN_IMAGE_DIR}\n",
              f"- Train mask directory: {self.TRAIN_MASK_DIR}\n",
              f"- Input JSON file path: {self.INPUT_JSON_FILE_PATH}\n",
              f"- Sample JSON file path: {self.SAMPLE_JSON_FILE_PATH}\n",
              f"- Raw image directory: {self.RAW_IMAGE_DIR}\n",
              f"- Base data path: {self.BASE_DATA_PATH}\n",
              f"- Path to save trained model: {self.PATH_TO_SAVE_TRAINED_MODEL}\n",
              f"- Batch size: {self.BATCH_SIZE}\n",
              f"- Maximum epochs: {self.MAX_EPOCHS}\n",
              f"- Number of classes: {self.NUM_CLASSES}\n",
              f"- Learning rate: {self.LEARNING_RATE}\n",
              f"- Tranformation/Augmentation: {self.TRANSFORM}\n")

In [14]:
class FileDoesNotExist(FileNotFoundError): # custom error raised when an expected file is missing
    """
    Raised when an expected file is missing.

    Subclasses FileNotFoundError (instead of BaseException) so ordinary
    ``except Exception`` / ``except OSError`` handlers catch it; handlers for
    ``BaseException`` or ``FileDoesNotExist`` itself keep working.
    """

class DirectoryDoesNotExist(FileNotFoundError): # custom error raised when an expected directory is missing
    """
    Raised when an expected directory is missing.

    Subclasses FileNotFoundError for the same reason as FileDoesNotExist.
    """

def look_at_json_structure(json_file_path: str, image_key: str = '9.jpeg136245') -> None:
    """
    This function is used to understand the JSON file structure and what information it contains.

    It prints the top-level keys, the image keys under '_via_img_metadata', and the filename,
    region count and first region (class name and polygon X/Y points) of one image record.

    Parameters:
    - json_file_path (str): Input json file path which contains mask region XY co-ordinates
    - image_key (str): Key under '_via_img_metadata' whose record is printed; defaults to the
      sample key '9.jpeg136245' used when this notebook was written

    Returns:
    - (None)

    Raises:
    - KeyError: if '_via_img_metadata' or image_key is not present in the file
    """

    with open(json_file_path, 'r') as input_file: # close the file as soon as it is parsed
        input_data = json.load(input_file)

    """
    Major keys available in json file:
    dict_keys(['_via_settings', '_via_img_metadata', '_via_attributes', '_via_data_format_version', '_via_image_id_list'])

    Out of these keys '_via_img_metadata' key contains mask region information.
    """
    print("- Keys available in json data: \n{}\n".format(input_data.keys())) # all major keys available in the json data

    metadata = input_data['_via_img_metadata']
    print("- Keys in value of '_via_img_metadata' key: \n{}\n".format(metadata.keys())) # image keys that have mask regions

    record = metadata[image_key] # e.g. dict_keys(['filename', 'size', 'regions', 'file_attributes'])
    print("- Keys in value of '{}' key: \n{}\n".format(image_key, record.keys()))

    print("- Filename: {}".format(record['filename'])) # filename
    regions = record['regions'] # list of annotated regions
    print("- Regions: {}".format(len(regions))) # total of regions available in an image

    if not regions: # an image without annotations has no first region to show
        return

    first_region = regions[0]
    print("- First region information:")
    print("-- Class name: {}".format(first_region['region_attributes']['name'].rstrip('\n'))) # drop the trailing newline stored in the class name
    print("-- X co-ordinates: \n{}".format(first_region['shape_attributes']['all_points_x'])) # X co-ordinates
    print("-- Y co-ordinates: \n{}\n".format(first_region['shape_attributes']['all_points_y'])) # Y co-ordinates

def _ensure_processed_dirs(base_data_path: str) -> str:
    """Create <base_data_path>/processed/train/{images,masks} when missing and return the train directory path."""

    trainset_path = f'{base_data_path}/processed/train' # set trainset path
    for sub_dir in ('images', 'masks'):
        # create each leaf directly: images/ and masks/ must exist even when train/ already does
        os.makedirs(f'{trainset_path}/{sub_dir}', exist_ok=True)
    return trainset_path


def _polygon_to_contour(x_points: list, y_points: list) -> np.ndarray:
    """Turn VIA polygon coordinate lists into an OpenCV contour array of shape (N, 1, 2) and dtype int32."""

    points = np.array(list(zip(x_points, y_points)), dtype=np.float64) # float first so fractional coordinates are accepted
    return np.rint(points).astype(np.int32).reshape(-1, 1, 2) # OpenCV drawing functions expect int32 points


def _build_mask(regions: list, height: int, width: int, kernel: np.ndarray) -> np.ndarray:
    """Rasterise the polygon regions of one image into a uint8 mask: 0 = background, 1 = annotated region."""

    mask = np.zeros((height, width)) # create zero-array of size (height,width) for mask

    for region in regions:
        # assumes polygon regions (all_points_x / all_points_y) — TODO confirm the project has no rect/circle shapes
        shape_attributes = region['shape_attributes']
        contour = _polygon_to_contour(shape_attributes['all_points_x'], shape_attributes['all_points_y'])
        if len(contour) == 0: # a region without points has nothing to draw
            continue
        cv2.drawContours(mask, [contour], -1, 255, -1) # thickness -1 fills the polygon

    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=2) # opening (erode then dilate) to separate touching objects

    return np.uint8(mask / 255) # label the pixel; either 0 (background) or 1 (Peanut Seed)


def createMasks(json_file_path: str, raw_image_dir:str, base_data_path:str) -> None:
    """
    This function is used to create mask image from the existing information available in the
    json file related to mask regions in each image.

    Steps: copy every raw image into <base_data_path>/processed/train/images, keep only the
    annotated images whose (height, width, channels) equals the most common shape (others are
    deleted from the processed directory), then write one mask per kept image into
    <base_data_path>/processed/train/masks under the image's filename.

    Parameters:
    - json_file_path (str): JSON file whose values are VIA image records with 'filename' and
      'regions' keys (e.g. the file written by merge_json_files)
    - raw_image_dir (str): Directory path contains raw images mentioned in the json file
    - base_data_path (str): Path for data directory

    Returns:
    - (None)

    Raises:
    - DirectoryDoesNotExist: if raw_image_dir does not exist
    - FileDoesNotExist: if an image referenced in the JSON is not in the processed images directory
    - ValueError: if an image cannot be decoded by OpenCV
    """

    with open(json_file_path, 'r') as input_file: # close the file once parsed
        input_data = json.load(input_file)

    if not os.path.exists(raw_image_dir): # check if images directory exist or not
        raise DirectoryDoesNotExist(raw_image_dir)

    trainset_path = _ensure_processed_dirs(base_data_path)
    trainset_images_path = f'{trainset_path}/images' # set trainset images path
    trainset_masks_path = f'{trainset_path}/masks' # set trainset masks path

    # copy every raw image into the processed train/images directory
    for filename in os.listdir(raw_image_dir):
        full_filename = os.path.join(raw_image_dir, filename)
        if os.path.exists(full_filename):
            shutil.copy(full_filename, trainset_images_path)
        else:
            raise FileDoesNotExist(full_filename)
    # NOTE(review): raw images that have no JSON entry are copied but never get a mask, so the
    # images/ and masks/ directories can end up with different file counts — confirm this is intended

    # read each annotated image once; the dict also de-duplicates filenames repeated across JSON keys
    shapes = {} # image_path -> (height, width, channels)
    for value in input_data.values():
        image_path = f'{trainset_images_path}/{value["filename"]}'
        if not os.path.exists(image_path):
            raise FileDoesNotExist(image_path)
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None: # cv2.imread returns None instead of raising for unreadable files
            raise ValueError(f'Could not read image: {image_path}')
        shapes[image_path] = image.shape

    if not shapes: # nothing annotated; there is no most common shape to compute
        print('No annotated images found in the JSON file; no masks were created.')
        return

    # most common full (height, width, channels) tuple; per-dimension modes could combine into a shape no image has
    most_common_shape = Counter(shapes.values()).most_common(1)[0][0]

    kept_paths = set() # set for O(1) membership tests below
    for image_path, shape in shapes.items():
        if shape == most_common_shape:
            kept_paths.add(image_path)
        else:
            os.remove(image_path) # if file has different shape than common shape then remove file from processed directory
            print(f'File has been removed from processed dataset.: {image_path}')

    kernel = np.ones((3, 3), np.uint8) # kernel for morphological opening; identical for every image

    for value in input_data.values():
        fname = value['filename']
        image_path = f'{trainset_images_path}/{fname}'

        if image_path not in kept_paths:
            print(f'File is not available in the dataset.: {image_path}')
            continue

        height, width = shapes[image_path][:2] # reuse the shape read above instead of decoding the image again
        mask = _build_mask(value['regions'], height, width, kernel)

        # NOTE(review): the mask keeps the image's filename (e.g. '.jpeg'); JPEG is lossy and can alter
        # 0/1 label values near edges — consider a lossless format such as PNG (other cells glob '*.jpeg')
        cv2.imwrite(f'{trainset_masks_path}/{fname}', mask)

    print("Mask image are created successfully.")

class SegmentationDataset(Dataset):
    """
    Dataset of (image, mask) pairs read from two directories.

    Images and masks are paired by position after sorting each directory's file names, so both
    directories must hold the same number of files in matching sort order (createMasks writes
    each mask under its image's filename).
    """

    def __init__(self, images_dir:str, masks_dir:str, transform:bool=False) -> None:
        """
        Parameters:
        - images_dir (str): Directory containing the input images
        - masks_dir (str): Directory containing the mask images
        - transform (bool): Stored for augmentation; currently NOT applied by __getitem__

        Raises:
        - ValueError: if the two directories contain a different number of files, since
          position-based pairing would then silently misalign images and masks
        """
        self.images_dir = images_dir # set path to directory contains input images
        self.masks_dir = masks_dir # set path to directory contains mask images
        self.transform = transform # set boolean value for transform/augmentation

        self.image_files = sorted(os.listdir(images_dir)) # generate a list of input images
        self.mask_files = sorted(os.listdir(masks_dir)) # generate a list of mask images

        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f'Found {len(self.image_files)} images in {images_dir} but {len(self.mask_files)} masks in {masks_dir}; '
                'image/mask pairs would be misaligned.'
            )

    def __len__(self) -> int:
        return len(self.image_files) # return total number of data samples available in the dataset

    def __getitem__(self, index) -> Any:
        """
        Return the (image, mask) pair at `index`.

        Returns:
        - image (torch.FloatTensor): (3, H, W) BGR pixel values in [0, 255] (not normalized)
        - mask (torch.LongTensor): (H, W) class ids read from the grayscale mask

        Raises:
        - ValueError: if OpenCV cannot read the image or the mask
        """
        image_path = os.path.join(self.images_dir, self.image_files[index]) # generate path for a single image
        mask_path = os.path.join(self.masks_dir, self.mask_files[index]) # generate path for a single mask image

        image = cv2.imread(image_path, cv2.IMREAD_COLOR) # (H, W, 3) BGR
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) # (H, W)
        if image is None or mask is None: # cv2.imread returns None instead of raising
            raise ValueError(f'Could not read image/mask pair: {image_path}, {mask_path}')

        image = np.transpose(image, (2, 0, 1)) # HWC -> CHW, e.g. (720,1280,3) -> (3,720,1280)

        # No one-hot encoding is needed for the mask: the class ids are used directly once the
        # number of classes is set in the model.
        # NOTE: self.transform is not applied yet. If random flips (np.flip) are added, call .copy()
        # afterwards, because torch.tensor does not accept arrays with negative strides.

        image = torch.tensor(image, dtype=torch.float32) # convert to torch tensor
        mask = torch.tensor(mask, dtype=torch.long) # convert to torch tensor

        return image, mask # return image and mask in the form of tuple

class SegmentationDataModule(pl.LightningDataModule):
    """
    LightningDataModule that splits one image/mask directory pair into train and validation subsets.
    """

    def __init__(self, train_image_dir:str, train_mask_dir:str, batch_size:int=32, transform:bool=False, val_split:float=0.2, test_split:float=0.1, seed:int=42) -> None:
        """
        Parameters:
        - train_image_dir (str): Directory containing the input images
        - train_mask_dir (str): Directory containing the mask images
        - batch_size (int): Batch size used by both dataloaders
        - transform (bool): Forwarded to SegmentationDataset; the split shares one underlying
          dataset, so it would affect validation samples as well
        - val_split (float): Fraction of samples held out for validation, in [0, 1)
        - test_split (float): Stored only; no test split is created at the moment
        - seed (int): Seed for the train/validation split so the split is reproducible

        Raises:
        - ValueError: if val_split is outside [0, 1)
        """
        super().__init__()
        if not 0.0 <= val_split < 1.0:
            raise ValueError(f'val_split must be in [0, 1), got {val_split}')

        self.train_image_dir = train_image_dir # set path to directory contains input images for training
        self.train_mask_dir = train_mask_dir # set path to directory contains input mask images for training
        self.batch_size = batch_size # set batch size
        self.transform = transform # set boolean value for transformation/augmentation
        self.val_split = val_split # set validation set split ratio
        self.test_split = test_split # set test set split ratio
        self.seed = seed # set seed for the random split
        self.train_dataset = None # populated by setup()
        self.val_dataset = None # populated by setup()

    def setup(self, stage=None) -> None:
        """
        Build the dataset and split it into train/validation subsets, only once.

        Lightning may call setup() for several stages on the same instance; splitting again
        would reshuffle samples between train and validation, so later calls are no-ops.
        """
        if self.val_dataset is not None:
            return

        full_dataset = SegmentationDataset(self.train_image_dir, self.train_mask_dir, self.transform)
        val_size = int(len(full_dataset) * self.val_split) # calculate validation set size
        train_size = len(full_dataset) - val_size # calculate train set size

        generator = torch.Generator().manual_seed(self.seed) # seeded generator -> reproducible split
        self.train_dataset, self.val_dataset = random_split(full_dataset, [train_size, val_size], generator=generator)

    def train_dataloader(self) -> Any:
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True) # return train set

    def val_dataloader(self) -> Any:
        return DataLoader(self.val_dataset, batch_size=self.batch_size) # return validation set

In [11]:
# matplotlib.use('TkAgg') # or 'Qt5Agg' or any other backend that supports interactive display
# # by default backend: FigureCanvasAgg

class DirectoryDoesNotExist(FileNotFoundError): # define a custom class for exception and raise if directory does not exist/available
    """
    Raised when an expected directory is missing.

    NOTE(review): a class with this name is also defined in another cell; whichever cell runs
    last wins. Subclassing FileNotFoundError (rather than BaseException) lets ordinary
    ``except Exception`` / ``except OSError`` handlers catch it.
    """

def merge_json_files(input_files:list, output_file:str) -> None:
    """
    This function is used to merge json files which have same file structure.

    The '_via_img_metadata' dictionaries of all input files are combined into one dictionary
    (when the same key appears in several files, the later file wins) and written as compact
    JSON to output_file.

    Parameters:
    - input_files (list): List of JSON files
    - output_file (str): Path to the JSON where merged JSON data will be stored

    Returns:
    - (None)

    Raises:
    - DirectoryDoesNotExist: if the directory part of output_file does not exist
    """

    if len(input_files) == 1:
        print(f'Only {len(input_files)} json file is detected.')
    else:
        print(f'Merging {len(input_files)} json files.')

    merged_data = {} # image key -> VIA image record
    for file in input_files:
        with open(file, 'r') as input_file: # close each file once parsed
            metadata = json.load(input_file)['_via_img_metadata'] # extract values of _via_img_metadata key
        merged_data.update(metadata) # later files overwrite duplicate keys

    # a bare filename has an empty directory part, which means the current directory
    directory_path = os.path.dirname(output_file) or '.'

    if not os.path.exists(directory_path): # check that directory exists or not
        raise DirectoryDoesNotExist(directory_path)

    with open(output_file, 'w') as outfile: # open json file in writing mode
        json.dump(merged_data, outfile) # write merged json data (no indentation)

    if len(input_files) > 1: # if more than one json files are available
        print(f'{len(input_files)} json files are merged successfully.')


def show_mask_image_pair(image_dir:str, mask_dir:str) -> None:
    """
    This function is used to visualize image-mask pair.

    The user is asked repeatedly whether to show a pair. On 'Y' a random .jpeg image from
    image_dir is shown next to the mask with the same filename in mask_dir. On 'N', or when no
    more input is available, the function returns.

    Parameters:
    - image_dir (str): Path to image directory in processed data directory
    - mask_dir (str): Path to mask directory in processed data directory

    Returns:
    - (None)
    """

    image_list = sorted(glob.glob(f'{image_dir}/*.jpeg')) # sorted for a stable order; glob order is arbitrary
    if not image_list: # nothing to sample from; prompting would be pointless
        print(f'No .jpeg images found in: {image_dir}')
        return None

    while True:
        try:
            user_choice = input('Do you want to visualize mask-image pairs? [Y/N]: ') # ask user for his/her binary choice
        except (EOFError, KeyboardInterrupt): # no interactive input or user interrupted: stop instead of re-prompting forever
            return None

        choice = user_choice.strip().lower()
        if choice == 'n': # if user enters either N or n
            return None
        if choice != 'y':
            print('Enter \'Y\' or \'N\'') # ask user to enter either 'Y' or 'N'
            continue

        image_path = random.choice(image_list) # pick a random image
        image_name = os.path.basename(image_path) # extract image name from the path
        # pair the mask by filename (createMasks saves each mask under its image's name);
        # pairing two independently globbed lists by index could show the wrong mask
        mask_path = os.path.join(mask_dir, image_name)

        image = cv2.imread(image_path, cv2.IMREAD_COLOR) # read an image (BGR)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) # read a mask
        if image is None or mask is None: # cv2.imread returns None instead of raising
            print(f'Could not read image/mask pair: {image_path}, {mask_path}')
            continue

        try:
            _, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(40,40)) # canvas with 1 row and 2 cols, 40x40 inches

            ax1.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) # OpenCV loads BGR; matplotlib expects RGB
            ax1.set_title(f'{image_name}') # set image name as a title
            ax1.axis('off') # do not visualize an image with an axis

            ax2.imshow(mask, cmap='gray') # visualize a mask
            ax2.set_title(f'{image_name}') # the mask shares the image's filename
            ax2.axis('off') # do not visualize an image with an axis

            plt.show() # display all plots/graphs
        except Exception as error: # stay best-effort, but report the problem instead of hiding it
            print(f'Could not display image/mask pair: {error}')

def get_user_choice(start:int, end:int) -> int:
    """
    Prompt repeatedly until the user types an integer inside [start, end].

    Parameters:
    - start (int): Lower bound of the accepted range (inclusive)
    - end (int): Upper bound of the accepted range (inclusive)

    Returns:
    - (int): The first entered integer that lies within the range
    """

    prompt = f'Enter an integer number between {start} and {end}: '
    error_message = f'Invalid number. Enter an integer number between {start} and {end}'

    while True:
        try:
            candidate = int(input(prompt))
        except ValueError: # text that is not an integer
            print(error_message)
            continue

        if candidate < start or candidate > end: # integer outside the allowed range
            print(error_message)
            continue

        return candidate

def available_models() -> tuple:
    """
    This is used to provide a list of neural net architectures available for training on the existing dataset(s).

    When only one architecture exists the user is not prompted and index 0 is returned.

    Parameters:
    - (None)

    Returns:
    - (list,int): Returns tuple contains list of available neural net archs and user choice; (available_nn_arch, user_choice)
    """

    models = ['DNA-Segment'] # list of available models

    print('Select any one neural net architecture from the list given below')
    for index, model_name in enumerate(models):
        print(f'{index}_________{model_name}') # print list of available models with the integer model number

    if len(models) == 1:
        # plain int, like get_user_choice returns (np.uint8 would wrap around on arithmetic)
        return models, 0
    return models, get_user_choice(0, len(models) - 1) # get user choice

def available_optimizers() -> tuple:
    """
    This function is used to provide list of optimizers available for selected neural net architectures.

    When only one optimizer exists the user is not prompted and index 0 is returned.

    Parameters:
    - (None)

    Returns:
    - (list,int): Returns tuple contains list of available optimizers and user choice; (available optimizers, user_choice)
    """

    optimizers = ['Adam',
                  'AdamW',
                  'RMSProp',
                  'SGD'] # list of available optimizers

    print('Select any one optimizer from the list given below')
    for index, optimizer_name in enumerate(optimizers):
        print(f'{index}_________{optimizer_name}') # print list of available optimizers with the integer optimizer number

    if len(optimizers) == 1:
        # only one optimizer: no need to ask; plain int like get_user_choice returns
        return optimizers, 0
    return optimizers, get_user_choice(0, len(optimizers) - 1) # get user choice

def save_trained_model(model: Any, path:str, model_prefix:str, optimizer:str, epochs:int) -> None:
    """
    Save the weights (state_dict) of a trained neural net architecture as a .pth file.

    The file is named '<prefix>_<optimizer>_<epochs>_<counter>.pth'. Spaces in the prefix
    become underscores and round brackets are removed. <counter> is the smallest non-negative
    integer that does not collide with an existing file, so earlier saves are never overwritten.

    Parameters:
    - model (any): object exposing state_dict() (e.g. a torch.nn.Module) whose weights are saved
    - path (str): directory in which the model file is written; created if it does not exist
    - model_prefix (str): model name used as the file-name prefix
    - optimizer (str): name of the optimizer selected by the user
    - epochs (int): total number of epochs

    Returns:
    - (None)
    """

    os.makedirs(path, exist_ok=True) # create the directory if needed; exist_ok avoids a check-then-create race

    # make the prefix file-name friendly: spaces -> underscores, brackets removed
    model_prefix = model_prefix.replace(' ', '_').replace('(', '').replace(')', '')

    counter = 0
    while True:
        model_file_name = f'{model_prefix}_{optimizer}_{epochs}_{counter}.pth' # candidate file name
        model_file_path = os.path.join(path, model_file_name)
        if not os.path.exists(model_file_path):
            break
        counter += 1 # name already taken; try the next suffix

    torch.save(model.state_dict(), model_file_path) # save only the weights, not the module object
    print(f'Trained model is successfully saved  at: \n{model_file_path}')

### DNASegment Model

In [12]:
"""
class-0: Background class
class-1: Peanut Seed
"""

"""
Which metric to focus on:
[1] Dice co-efficient:

    -   It measures the overlap between the predicted and ground-truth regions. It is particularly useful
        for binary and multi-class segmentation tasks and is sensitive to the presence of small objects.

    -   High Dice co-efficient values indicate good segmentation performance.

    -   You should focus on the Dice coefficient for each class separately to ensure that the model
        performs well across all classes, especially in cases where class imbalance is a concern.

[2] mean Intersection over Union (IoU):

    -   It measures the average overlap between the predicted and ground-truth regions across all classes.
        It is a more stringent metric than the Dice co-efficient: for a single class IoU = Dice / (2 - Dice),
        so IoU <= Dice, and the same false positives and false negatives lower IoU more than they lower Dice.

    -   High mean Intersection over Union values indicate that the model segments the images accurately
        across all classes.

    -   You should focus on the mean Intersection over Union to get an overall sense of the model's
        segmentation performance across all classes.

Class-Specific performance:
    -   If you are particularly concerned about the performance on individual classes (e.g. due to
        class imbalance or specific importance of certain classes), focus on the Dice co-efficient
        for each class. This will help you identify any classes where the model might be
        underperforming.

    -   If you want an overall measure of segmentation quality that takes into account the performance
        across all classes, focus on the mean IoU. This will give you a comprehensive view of how well
        the model segments the images on average.
"""

# Focal Loss for Dense Object Detection: https://arxiv.org/abs/1708.02002
class FocalLoss(nn.Module):
    """
    Multi-class focal loss built on top of F.cross_entropy.

    FL = alpha * (1 - p_t) ** gamma * CE, where CE is the per-element cross-entropy and
    p_t = exp(-CE) is the predicted probability of the target class.

    Parameters:
    - alpha (float): constant multiplier applied to every element's loss
    - gamma (float): focusing parameter; larger values down-weight well-classified elements
    - reduction (str): 'mean' (average), 'sum', or 'none' (per-element losses returned as-is)

    Raises:
    - ValueError: if reduction is not one of 'mean', 'sum', 'none'
    """

    _VALID_REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self,alpha=1.0,gamma=2.0,reduction='mean') -> None:
        super(FocalLoss, self).__init__()
        if reduction not in self._VALID_REDUCTIONS:
            # reject typos early: an unknown value would otherwise behave like 'none' and only
            # fail later when a non-scalar loss is back-propagated
            raise ValueError(f"reduction must be one of {self._VALID_REDUCTIONS}, got {reduction!r}")
        self.alpha = alpha # assign alpha value
        self.gamma = gamma # assign gamma value
        self.reduction = reduction # assign reduction value; 'mean', 'none', 'sum'

    def forward(self, predictions, targets):
        """
        Parameters:
        - predictions: raw logits in the layout F.cross_entropy expects (class dimension at dim 1)
        - targets: class indices matching predictions without the class dimension

        Returns:
        - reduced focal loss (scalar) for 'mean'/'sum', per-element losses for 'none'
        """
        ce_loss = F.cross_entropy(predictions, targets, reduction='none') # per-element cross-entropy
        p_t = torch.exp(-ce_loss) # probability of the target class, since CE = -log(p_t)
        focal_loss = self.alpha * (1 - p_t) ** self.gamma * ce_loss # down-weight easy (high p_t) elements

        if self.reduction == 'mean': # the losses are averaged over all elements
            return focal_loss.mean()
        if self.reduction == 'sum': # the losses are summed over all elements
            return focal_loss.sum()
        return focal_loss # 'none': no reduction applied

# V-Net: Fully CNN for Volumetric Medical Image Segmentation: https://arxiv.org/abs/1606.04797
class BinaryDiceLoss(nn.Module):
    """
    Soft Dice loss for binary segmentation.

    loss = 1 - (2 * sum(P * T) + smooth) / (sum(P) + sum(T) + smooth), computed over all
    elements of the batch at once, where P are foreground probabilities and T binary targets.

    Accepted prediction layouts:
    - same number of dimensions as targets: raw foreground logits, converted with a sigmoid;
    - one extra dimension at dim 1 with a single channel, e.g. (N, 1, H, W): sigmoid of it;
    - one extra dimension at dim 1 with two channels, e.g. (N, 2, H, W): softmax over dim 1,
      and the probability of class 1 is used as the foreground probability. This is the
      layout of two-class logits that are also fed to cross-entropy.

    Parameters:
    - smooth (float): additive smoothing term that avoids division by zero on empty masks

    Raises:
    - ValueError: if predictions have an extra class dimension with more than two channels
    """

    def __init__(self, smooth=1) -> None:
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth # assign smooth param value

    def forward(self, predictions, targets):
        if predictions.dim() == targets.dim() + 1:
            num_channels = predictions.size(1)
            if num_channels == 1:
                probs = torch.sigmoid(predictions.squeeze(1)) # single foreground logit per element
            elif num_channels == 2:
                probs = torch.softmax(predictions, dim=1)[:, 1] # probability of class 1 (foreground)
            else:
                raise ValueError(f'BinaryDiceLoss expects 1 or 2 channels at dim 1, got {num_channels}')
        else:
            probs = torch.sigmoid(predictions) # predictions already aligned with targets

        probs = probs.reshape(-1) # flatten to 1-D; reshape also handles non-contiguous input
        targets = targets.reshape(-1).to(probs.dtype) # flatten and make the binary targets floating point
        intersection = (probs * targets).sum() # soft intersection
        union = probs.sum() + targets.sum() # sum of both "areas"
        dice_coeff = (2. * intersection + self.smooth) / (union + self.smooth) # compute dice co-efficient
        return 1 - dice_coeff # compute dice loss

class CombinedLoss(nn.Module):
    """
    Weighted sum of focal loss and binary Dice loss: alpha * focal + beta * dice.

    Adjust alpha and beta based on the performance:
    [1] Class imbalance: increase alpha to give more weight to the focal loss, which handles
        class imbalance well.
    [2] Overall segmentation accuracy: increase beta to give more weight to the (binary) Dice
        loss, which measures overlap directly.

    Parameters:
    - alpha (float): weight of the focal loss term
    - beta (float): weight of the Dice loss term
    """

    def __init__(self, alpha=0.5, beta=0.5) -> None:
        super(CombinedLoss, self).__init__()
        # sub-losses are registered in the same order as before (focal first, then dice)
        self.focal_loss = FocalLoss()
        self.dice_loss = BinaryDiceLoss()
        self.alpha, self.beta = alpha, beta

    def forward(self, predictions, targets):
        # both sub-losses receive the same raw predictions and targets
        weighted_focal = self.alpha * self.focal_loss(predictions, targets)
        weighted_dice = self.beta * self.dice_loss(predictions, targets)
        return weighted_focal + weighted_dice

# class PatchEmbedding(nn.Module):
#     def __init__(self, img_height, img_width, patch_size, in_channels, embed_dim):
#         super(PatchEmbedding, self).__init__()
#         self.patch_size = patch_size # initialize patch size
#         self.num_patches_h = img_height // patch_size # calculate the number of patches based on the image height
#         self.num_patches_w = img_width // patch_size # calculate the number of patches based on the image width
#         self.num_patches = self.num_patches_h * self.num_patches_w # calculate the total number of patches
#         self.embed_dim = embed_dim # initialize embedding dimension

#         self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim) # linear layer to project each patch to embedding dimension

#     def forward(self, x):
#         B, C, H, W = x.shape # get the batch size, channels, height and width of the input
#         x = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size) # unfold the input tensor to extract patches
#         x = x.contiguous().view(B, C, -1, self.patch_size, self.patch_size) # rearrange the patches to be in the contiguous block of memory
#         x = x.permute(0, 2, 3, 4, 1).contiguous().view(B, -1, self.patch_size * self.patch_size * C) # permute the dimension to bring the patch dimension to the front and flatten the patch into single vector
#         x = self.proj(x) # project the flattened patches to the embedding dimension
#         return x # return the embedded patches

# class PatchEmbedding(nn.Module):
#     def __init__(self, in_channels=3, patch_size=16, embed_dim=768):
#         super().__init__()
#         self.patch_size = patch_size
#         self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) # change the kernel_size and stride to match the patch_size
#         self.norm = nn.LayerNorm(embed_dim)

#     def forward(self, x):
#         B, C, H, W = x.shape # get the batch size, channels, height and width of the input
#         # the following line is no longer needed as the conv2d layer automatically extracts patches
#         # x = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size) # unfold the input tensor to extract patches
#         x = self.projection(x) # apply the convolution to extract patches and embed them
#         x = x.flatten(2).transpose(1, 2) # flatten the patches and transpose to get the embedding dimension as the last dimension
#         x = self.norm(x) # apply layer normalization
#         return x

class PatchEmbedding(nn.Module):
    """
    Split an image into non-overlapping square patches and linearly embed each patch.

    A Conv2d with kernel_size == stride == patch_size performs the patch split and the linear
    projection in one operation; its output grid is then flattened into a token sequence.

    Parameters:
    - img_height, img_width (int): expected input size, used to compute num_patches
    - patch_size (int): side length of each square patch
    - in_channels (int): number of input image channels
    - embed_dim (int): size of each patch embedding
    """

    def __init__(self, img_height=720, img_width=1280, patch_size=16, in_channels=3, embed_dim=768):
        super(PatchEmbedding, self).__init__()
        self.img_height, self.img_width = img_height, img_width
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        patches_per_column = img_height // patch_size
        patches_per_row = img_width // patch_size
        self.num_patches = patches_per_column * patches_per_row # tokens produced for an img_height x img_width input

        self.projection = nn.Conv2d(in_channels=in_channels, out_channels=embed_dim,
                                    kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # assumes batched input (B, in_channels, H, W)
        patch_grid = self.projection(x) # (B, embed_dim, H // patch_size, W // patch_size)
        return patch_grid.flatten(start_dim=2).transpose(1, 2) # (B, num_tokens, embed_dim)

class PositionalEncoding(nn.Module):
    """
    Learnable absolute position embeddings added to a token sequence.

    The embeddings are a (1, num_patches, embed_dim) parameter initialised to zeros and
    broadcast over the batch dimension in forward().

    Parameters:
    - num_patches (int): sequence length the embeddings are allocated for
    - embed_dim (int): size of each token embedding
    """

    def __init__(self, num_patches, embed_dim):
        super(PositionalEncoding, self).__init__()
        initial_embeddings = torch.zeros(1, num_patches, embed_dim)
        self.position_embeddings = nn.Parameter(initial_embeddings) # learnable

    def forward(self, x):
        return self.position_embeddings + x # broadcast over the batch dimension

class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head scaled dot-product attention with a fused query/key/value projection.

    Called with one argument it performs self-attention over x. When `context` is given it
    performs cross-attention: queries are projected from x, keys/values from context, using
    the corresponding slices of the same fused `qkv` projection. No parameters are added,
    so the state_dict layout is unchanged.

    Parameters:
    - embed_dim (int): token embedding size; must be divisible by num_heads
    - num_heads (int): number of attention heads

    Raises:
    - ValueError: if embed_dim is not divisible by num_heads
    """

    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(f'embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})')
        self.embed_dim = embed_dim # total embedding size across heads
        self.num_heads = num_heads # initialize the number of heads
        self.head_dim = embed_dim // num_heads # set dimension of each attention head
        self.scale = self.head_dim ** -0.5 # 1/sqrt(head_dim) scaling for attention scores

        self.qkv = nn.Linear(embed_dim, embed_dim * 3) # fused projection; outputs laid out as [q | k | v]
        self.fc = nn.Linear(embed_dim, embed_dim) # project the concatenated outputs of the attention heads

    def forward(self, x, context=None):
        """
        Parameters:
        - x: (B, N, embed_dim) tokens that produce the queries
        - context (optional): (B, M, embed_dim) tokens that produce keys/values; defaults to x

        Returns:
        - (B, N, embed_dim) attended tokens
        """
        B, N, D = x.shape
        if context is None:
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2] # each (B, num_heads, N, head_dim)
        else:
            E = self.embed_dim
            context_batch, M = context.shape[0], context.shape[1]
            weight, bias = self.qkv.weight, self.qkv.bias
            # queries from x use the first E output rows; keys/values from context use the rest
            q = F.linear(x, weight[:E], bias[:E]).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
            kv = F.linear(context, weight[E:], bias[E:]).reshape(context_batch, M, 2, self.num_heads, self.head_dim)
            kv = kv.permute(2, 0, 3, 1, 4)
            k, v = kv[0], kv[1] # each (B, num_heads, M, head_dim)

        attn = (q @ k.transpose(-2, -1)) * self.scale # compute the attention scores
        attn = attn.softmax(dim=-1) # apply softmax to get attention weights

        out = (attn @ v).transpose(1, 2).reshape(B, N, D) # merge heads back into the embedding dimension
        return self.fc(out) # project the concatenated outputs of the attention heads

class TransformerEncoderLayer(nn.Module):
    """
    Pre-norm transformer encoder block:
        x = x + MHSA(LayerNorm(x))
        x = x + MLP(LayerNorm(x))
    where MLP is Linear(embed_dim -> mlp_dim) -> GELU -> Linear(mlp_dim -> embed_dim).

    Parameters:
    - embed_dim (int): token embedding size
    - num_heads (int): number of attention heads
    - mlp_dim (int): hidden size of the MLP
    """

    def __init__(self, embed_dim, num_heads, mlp_dim):
        super(TransformerEncoderLayer, self).__init__()
        # sub-modules are created in the original order, so initialisation and state_dict keys match
        self.norm1 = nn.LayerNorm(embed_dim) # normalises the input of the attention sub-block
        self.norm2 = nn.LayerNorm(embed_dim) # normalises the input of the MLP sub-block
        self.mhsa = MultiHeadSelfAttention(embed_dim, num_heads)
        expand = nn.Linear(embed_dim, mlp_dim)
        activation = nn.GELU()
        project = nn.Linear(mlp_dim, embed_dim)
        self.mlp = nn.Sequential(expand, activation, project)

    def forward(self, x):
        attended = x + self.mhsa(self.norm1(x)) # attention sub-block with residual connection
        return attended + self.mlp(self.norm2(attended)) # MLP sub-block with residual connection

class VisionTransformerEncoder(nn.Module):
    """
    ViT-style encoder: patch embedding + learnable positional encoding + a stack of
    TransformerEncoderLayer blocks. The output of every layer is collected for use as a
    skip connection.

    Parameters:
    - img_height, img_width (int): input spatial size (determines the number of patches)
    - patch_size (int): side length of each square patch
    - in_channels (int): number of channels of the input image
    - embed_dim (int): token embedding size
    - num_layers (int): number of encoder layers
    - num_heads (int): attention heads per layer
    - mlp_dim (int): hidden size of each layer's MLP
    """

    def __init__(self, img_height, img_width, patch_size, in_channels, embed_dim, num_layers, num_heads, mlp_dim):
        super(VisionTransformerEncoder, self).__init__()
        # forward the constructor arguments: a bare PatchEmbedding() would always use its own
        # defaults regardless of the image size, patch size, channels and embed_dim given here
        self.patch_embedding = PatchEmbedding(img_height=img_height,
                                              img_width=img_width,
                                              patch_size=patch_size,
                                              in_channels=in_channels,
                                              embed_dim=embed_dim)
        self.positional_encoding = PositionalEncoding(self.patch_embedding.num_patches, embed_dim) # initialize positional embedding
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads, mlp_dim) for _ in range(num_layers)
        ]) # initialize a list of transformer encoder layers

    def forward(self, x):
        """
        Returns:
        - (tokens, skip_connections): output of the last layer and a list with the output of
          every encoder layer in layer order (when num_layers >= 1 the last entry is `tokens`)
        """
        tokens = self.positional_encoding(self.patch_embedding(x)) # embed patches, then add positions
        skip_connections = []
        for layer in self.encoder_layers:
            tokens = layer(tokens) # apply each transformer encoder layer
            skip_connections.append(tokens) # keep every layer's output for skip connections
        return tokens, skip_connections

class TransformerDecoderLayer(nn.Module):
    """
    Pre-norm transformer decoder block:
        x = x + SelfAttention(LayerNorm(x))
        x = x + CrossAttention(LayerNorm(x), encoder_output)
        x = x + MLP(LayerNorm(x))

    The cross-attention module is called with encoder_output as a second positional argument,
    so MultiHeadSelfAttention.forward must accept a context (key/value source) in that position.

    Parameters:
    - embed_dim (int): token embedding size
    - num_heads (int): number of attention heads
    - mlp_dim (int): hidden size of the MLP
    """

    def __init__(self, embed_dim, num_heads, mlp_dim):
        super(TransformerDecoderLayer, self).__init__()
        # sub-modules are created in the original order, so initialisation and state_dict keys match
        self.norm1 = nn.LayerNorm(embed_dim) # before self-attention
        self.norm2 = nn.LayerNorm(embed_dim) # before cross-attention
        self.norm3 = nn.LayerNorm(embed_dim) # before the MLP
        self.self_attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.cross_attn = MultiHeadSelfAttention(embed_dim, num_heads)
        hidden_in = nn.Linear(embed_dim, mlp_dim)
        activation = nn.GELU()
        hidden_out = nn.Linear(mlp_dim, embed_dim)
        self.mlp = nn.Sequential(hidden_in, activation, hidden_out)

    def forward(self, x, encoder_output):
        after_self = x + self.self_attn(self.norm1(x)) # self-attention with residual connection
        after_cross = after_self + self.cross_attn(self.norm2(after_self), encoder_output) # cross-attention with residual
        return after_cross + self.mlp(self.norm3(after_cross)) # MLP with residual connection

class VisionTransformerDecoder(nn.Module):
    """
    Stack of TransformerDecoderLayer blocks applied in sequence; every layer receives the
    same encoder_output.

    Parameters:
    - embed_dim (int): token embedding size
    - num_layers (int): number of decoder layers
    - num_heads (int): attention heads per layer
    - mlp_dim (int): hidden size of each layer's MLP
    """

    def __init__(self, embed_dim, num_layers, num_heads, mlp_dim):
        super(VisionTransformerDecoder, self).__init__()
        layers = [TransformerDecoderLayer(embed_dim, num_heads, mlp_dim) for _ in range(num_layers)]
        self.decoder_layers = nn.ModuleList(layers)

    def forward(self, x, encoder_output):
        hidden = x
        for decoder_layer in self.decoder_layers:
            hidden = decoder_layer(hidden, encoder_output) # each layer refines the previous output
        return hidden

class UpsampleBlock(nn.Module):
    """
    Learnable 2x spatial upsampling using a transposed convolution (kernel 2, stride 2).

    Parameters:
    - in_channels (int): channels of the incoming feature map
    - out_channels (int): channels of the upsampled feature map
    """

    def __init__(self, in_channels, out_channels):
        super(UpsampleBlock, self).__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        upsampled = self.upconv(x) # (B, in_channels, H, W) -> (B, out_channels, 2H, 2W)
        return upsampled

class ConvBlock(nn.Module):
    """
    Two 3x3 convolutions (padding 1, so the spatial size is preserved), each followed by ReLU.

    Parameters:
    - in_channels (int): channels of the input feature map
    - out_channels (int): channels produced by both convolutions
    """

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) # first 3x3 convolution
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) # second 3x3 convolution

    def forward(self, x):
        x = F.relu(self.conv1(x)) # apply the first convolution layer followed by ReLU activation
        x = F.relu(self.conv2(x)) # apply the second convolution layer followed by ReLU activation
        return x

class DNASegmentModel(pl.LightningModule):
    """
    Two-stage transformer segmentation model ("DNA-Segment") wrapped as a LightningModule.

    Each stage consists of a VisionTransformerEncoder, a ConvBlock bottleneck, a
    VisionTransformerDecoder and six UpsampleBlocks; a final 1x1 convolution maps to
    num_classes logits. Training and validation optimise CombinedLoss (focal + Dice) and log
    mean IoU plus per-class Dice coefficients.

    Parameters:
    - img_height, img_width, patch_size, in_channels, embed_dim, num_layers, num_heads, mlp_dim:
      forwarded to both encoder stages (embed_dim/num_layers/num_heads/mlp_dim also to the decoders)
    - num_classes (int): number of output segmentation classes
    - optimizer (str): 'adam', 'adamw', 'rmsprop' or 'sgd' (case-insensitive)
    - learning_rate (float): learning rate of the selected optimizer (default 1e-4)
    - weight_decay (float): weight decay of the selected optimizer (default 0.1)
    """

    def __init__(self, img_height, img_width, patch_size, in_channels, embed_dim, num_layers, num_heads, mlp_dim, num_classes, optimizer,
                 learning_rate=1e-4, weight_decay=0.1):
        super(DNASegmentModel, self).__init__()

        self.optimizer = optimizer # optimizer name selected by the user; resolved in configure_optimizers()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        # stage-1: encoder, bottleneck, and decoder
        self.stage1_encoder = VisionTransformerEncoder(img_height, img_width, patch_size, in_channels, embed_dim, num_layers, num_heads, mlp_dim)
        self.stage1_bottleneck = ConvBlock(embed_dim, embed_dim * 2)
        self.stage1_decoder = VisionTransformerDecoder(embed_dim, num_layers, num_heads, mlp_dim)

        # stage-2: encoder, bottleneck, and decoder
        self.stage2_encoder = VisionTransformerEncoder(img_height, img_width, patch_size, in_channels, embed_dim, num_layers, num_heads, mlp_dim)
        self.stage2_bottleneck = ConvBlock(embed_dim, embed_dim * 2)
        self.stage2_decoder = VisionTransformerDecoder(embed_dim, num_layers, num_heads, mlp_dim)

        # upsampling blocks for stage 1 and stage 2 (identical channel schedules)
        self.upconv_blocks_stage1 = self._build_upsampling_blocks(embed_dim)
        self.upconv_blocks_stage2 = self._build_upsampling_blocks(embed_dim)

        # final convolution to get the desired number of classes
        self.final_conv = nn.Conv2d(embed_dim // 32, num_classes, kernel_size=1)

        # built once instead of on every step; it has no parameters, so state_dict is unchanged
        self.loss_fn = CombinedLoss(alpha=0.5, beta=0.5)

    @staticmethod
    def _build_upsampling_blocks(embed_dim):
        """Six 2x UpsampleBlocks with channels 2*E -> E -> E/2 -> E/4 -> E/8 -> E/16 -> E/32."""
        channels = [embed_dim * 2] + [embed_dim // (2 ** i) for i in range(6)]
        return nn.ModuleList([UpsampleBlock(c_in, c_out) for c_in, c_out in zip(channels[:-1], channels[1:])])

    def forward(self, x):
        # NOTE(review): the tensor shapes in this method do not appear to line up:
        # - stage1_encoder returns tokens shaped (B, num_patches, embed_dim) (PatchEmbedding
        #   flattens and transposes its Conv2d output), but stage1_bottleneck is a ConvBlock whose
        #   nn.Conv2d layers expect a (B, embed_dim, H', W') feature map;
        # - the bottleneck outputs embed_dim * 2 channels while stage1_decoder works on embed_dim;
        # - stage2_encoder is built for `in_channels` input channels but receives
        #   torch.cat((x, x1), dim=1).
        # The logic is kept unchanged here. A token<->grid reshape and matching channel counts
        # are needed before this model can run end to end. TODO fix the architecture.

        # stage 1
        x1, skip_connections1 = self.stage1_encoder(x)
        x1 = self.stage1_bottleneck(x1)
        x1 = self.stage1_decoder(x1, skip_connections1[-1])

        # upsample and concatenate skip connections of stage 1
        for i, upconv_block in enumerate(self.upconv_blocks_stage1):
            x1 = upconv_block(x1)
            if i < len(skip_connections1):
                x1 = torch.cat((x1, skip_connections1[-(i+1)]), dim=1)

        # concatenate original input image with stage 1 output
        x2_input = torch.cat((x, x1), dim=1)

        # stage 2
        x2, skip_connections2 = self.stage2_encoder(x2_input)
        x2 = self.stage2_bottleneck(x2)

        # add connections from stage 1 decoder to stage 2 encoder
        for i in range(len(skip_connections1)):
            x2 = x2 + skip_connections1[i]

        x2 = self.stage2_decoder(x2, skip_connections2[-1])

        # upsample and concatenate skip connections of stage 2
        for i, upconv_block in enumerate(self.upconv_blocks_stage2):
            x2 = upconv_block(x2)
            if i < len(skip_connections2):
                x2 = torch.cat((x2, skip_connections2[-(i+1)]), dim=1)

        x2 = self.final_conv(x2)
        return x2

    def _shared_step(self, batch, stage):
        """
        Run the model on one (images, masks) batch, log loss and metrics under
        'DNASegment_<stage>_*', and return the combined loss.
        """
        images, masks = batch
        outputs = self(images) # raw logits

        # the losses need raw logits (cross-entropy applies softmax internally);
        # the metrics need predicted class labels
        preds = torch.argmax(outputs, dim=1)
        loss = self.loss_fn(outputs, masks) # combined focal + Dice loss

        mean_iou_score = self.mean_iou(preds, masks)
        dice_coeff_bg, dice_coeff_peanut_seed = self.dice_coefficient(preds, masks) # class 0, class 1

        # enable_graph stays at its default (False) so logged values are detached from autograd
        log_kwargs = dict(on_step=True, on_epoch=True, prog_bar=True)
        self.log(f'DNASegment_{stage}_combined_loss', loss, **log_kwargs)
        self.log(f'DNASegment_{stage}_mean_IoU', mean_iou_score, **log_kwargs)
        self.log(f'DNASegment_{stage}_dice_coeff_bg', dice_coeff_bg, **log_kwargs)
        self.log(f'DNASegment_{stage}_dice_coeff_peanut_seed', dice_coeff_peanut_seed, **log_kwargs)

        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log loss/metrics for one training batch; the returned loss is back-propagated."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Compute and log loss/metrics for one validation batch."""
        return self._shared_step(batch, 'validation')

    def mean_iou(self, predictions, targets, num_classes=2) -> Any:
        """
        Mean Intersection-over-Union over the classes present in predictions or targets.

        A class absent from both has an undefined IoU and is left out of the mean (it used to
        be added as NaN, which turned the whole mean into NaN). Returns a Python float, NaN
        only when no class in range(num_classes) is present at all.
        """
        predictions = predictions.reshape(-1) # flatten to 1-D
        targets = targets.reshape(-1) # flatten to 1-D

        ious = []
        for cls in range(num_classes):
            pred_mask = predictions == cls # binary mask of the current class in the predictions
            target_mask = targets == cls # binary mask of the current class in the targets
            union = (pred_mask | target_mask).sum().item()
            if union == 0:
                continue # class absent from both: IoU undefined, skip it
            intersection = (pred_mask & target_mask).sum().item()
            ious.append(intersection / union)

        if not ious:
            return float('nan')
        return sum(ious) / len(ious)

    def dice_coefficient(self, predictions, targets, num_classes=2, smooth=1e-5) -> Any:
        """
        Per-class Dice coefficient (2 * |P & T| + smooth) / (|P| + |T| + smooth).

        Returns a tuple of Python floats, one per class in range(num_classes). The class masks
        are derived from the original label tensors on every iteration; previously the inputs
        were overwritten by the class-0 masks, so later classes were scored on the wrong masks.
        """
        predictions = predictions.reshape(-1) # flatten to 1-D
        targets = targets.reshape(-1) # flatten to 1-D

        dice_scores = []
        for cls in range(num_classes):
            pred_mask = (predictions == cls).float() # 1 where the prediction is the current class
            target_mask = (targets == cls).float() # 1 where the ground truth is the current class
            intersection = (pred_mask * target_mask).sum()
            total = pred_mask.sum() + target_mask.sum()
            dice_coeff = (2. * intersection + smooth) / (total + smooth)
            dice_scores.append(dice_coeff.item()) # convert tensor to scalar

        return tuple(dice_scores)

    # Regularisation notes:
    # [1] LLRD: Layer-wise Learning Rate Decay
    # [2] Weight Decay: L2 regularisation (for Adam/RMSprop/SGD; AdamW applies decoupled weight decay)
    # [3] Drop Path Rate (Stochastic Depth): randomly drops entire layers during training to help
    #     prevent overfitting and improve the robustness of the model.
    # Only [2] is applied below; LLRD and drop path are not implemented in this model.
    def configure_optimizers(self):
        """
        Build the optimizer named by self.optimizer.

        Raises:
        - ValueError: if the name is not one of 'adam', 'adamw', 'rmsprop', 'sgd' (case-insensitive);
          previously an unknown name silently returned None
        """
        name = self.optimizer.lower()
        if name == 'adam':
            return torch.optim.Adam(self.parameters(),
                                    lr=self.learning_rate,
                                    betas=(0.9, 0.999),
                                    weight_decay=self.weight_decay)
        if name == 'adamw':
            return torch.optim.AdamW(self.parameters(),
                                     lr=self.learning_rate,
                                     betas=(0.9, 0.999),
                                     weight_decay=self.weight_decay)
        if name == 'rmsprop':
            return torch.optim.RMSprop(self.parameters(),
                                       lr=self.learning_rate,
                                       weight_decay=self.weight_decay)
        if name == 'sgd':
            return torch.optim.SGD(self.parameters(),
                                   lr=self.learning_rate,
                                   weight_decay=self.weight_decay)
        raise ValueError(f"Unsupported optimizer {self.optimizer!r}; expected one of 'Adam', 'AdamW', 'RMSProp', 'SGD'")

### Train DNASegment Model

In [16]:
check_gpu_config() # get GPU (General Processing Unit) information if it is available

config = Config() # create an instance of Config class
config.printConfiguration() # print all configuration set by defualt

# wandb.init(entity=config.ENTITY, # assign team/organization name
#                project=config.PROJECT, # assign project name
#                anonymous=config.ANONYMOUS, # set anonymous value type
#                reinit=config.REINIT) # initialize the weights and biases cloud server instance

print("-----------------------------------------------------------")
print("--------------UNDERSTANDING JSON FILE STRUCTURE------------")
print("-----------------------------------------------------------")

look_at_json_structure(config.SAMPLE_JSON_FILE_PATH) # understand JSON file structure and which information it contains

print("-----------------------------------------------------------")
print("------------------CREATING PROCESSED DATASET---------------")
print("-----------------------------------------------------------")

input_files = glob.glob(f'{config.INPUT_JSON_FILE_PATH}/File*.json') # create list of json files available @ config.INPUT_JSON_FILE_PATH
output_file = f'{config.INPUT_JSON_FILE_PATH}/Merge.json' # define path for json file, contains merged data from multiple json files
merge_json_files(input_files=input_files, output_file=output_file) # merge multiple json files

print("Creating mask images from json data...")
createMasks(json_file_path=output_file,
                raw_image_dir=config.RAW_IMAGE_DIR,
                base_data_path=config.BASE_DATA_PATH) # create mask images from existing mask region information i.e. XY co-ordinates

config.TRAINSET_PATH = os.path.join(config.BASE_DATA_PATH,f'processed/train')

show_mask_image_pair(image_dir=os.path.join(config.TRAINSET_PATH,'images'),
                         mask_dir=os.path.join(config.TRAINSET_PATH,'masks')) # visualize mask-image pairs

config.TRAIN_IMAGE_DIR = os.path.join(config.TRAINSET_PATH,'images') # set path to a directory contains input images for training
config.TRAIN_MASK_DIR = os.path.join(config.TRAINSET_PATH, 'masks') # set path to a directory contains input mask images for training

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # set device for model training
data_module = SegmentationDataModule(train_image_dir=config.TRAIN_IMAGE_DIR,
                                         train_mask_dir=config.TRAIN_MASK_DIR,
                                         batch_size=config.BATCH_SIZE,
                                         transform=config.TRANSFORM) # initialize the data module

print("-----------------------------------------------------------")
print("---------------NN ARCHITECTURE (MODEL) SELECTION-----------")
print("-----------------------------------------------------------")

# Let the user pick an architecture and an optimizer. Each helper returns the
# list of option names together with the index the user chose.
available_nn_archs, user_choice_nn_arch = available_models()
available_optims, user_choice_optimizer = available_optimizers()
selected_arch = available_nn_archs[user_choice_nn_arch]
selected_optimizer = available_optims[user_choice_optimizer]
print(f'{selected_arch} neural net architecture is selected with {selected_optimizer} optimizer.')

# Only architecture index 0 is wired up in this notebook. Fail loudly for any
# other choice. Otherwise `model` would be undefined (NameError further down),
# or, in a live kernel, a model left over from an earlier run would be trained
# silently.
if user_choice_nn_arch == 0:
    # NOTE(review): assuming non-overlapping 16x16 patches, each 720x1280 image
    # becomes (720 // 16) * (1280 // 16) = 3600 tokens, and the model stacks two
    # encoder/bottleneck/decoder stages. The recorded run hit CUDA OOM on a
    # ~15 GiB T4 during Lightning's sanity check. Smaller inputs, a larger
    # patch size, fewer layers, a smaller batch size, or mixed precision may be
    # required.
    model = DNASegmentModel(
        img_height=720,      # image height
        img_width=1280,      # image width
        patch_size=16,       # side length of a single square patch
        in_channels=3,       # number of channels in the input images
        embed_dim=768,       # embedding dimension
        num_layers=18,       # number of layers in the transformers
        num_heads=12,        # number of attention heads
        mlp_dim=3072,        # MLP (feed-forward) dimension
        num_classes=2,       # total number of output classes
        optimizer=selected_optimizer)  # optimizer name chosen by the user
else:
    raise NotImplementedError(
        f'Architecture {selected_arch!r} (choice {user_choice_nn_arch}) is not '
        f'implemented in this notebook; only choice 0 is supported.')

# Lightning's ModelSummary callback already logs a per-layer parameter table
# when `trainer.fit` starts, so no separate summary header is printed here.
# For a detailed input/output-shape summary, uncomment the call below
# (presumably `torchinfo.summary`; import it first):
# summary(model=model,
#         input_size=(1, 3, 720, 1280),
#         col_names=['input_size', 'output_size', 'kernel_size'])

model = model.to(device)  # move the selected architecture to the chosen computing device

# Optional Weights & Biases logging (requires `WandbLogger` to be imported and a
# wandb login). Use this trainer instead of the one below to enable it:
# wandb_logger = WandbLogger(log_model=config.LOG_MODEL)
# trainer = pl.Trainer(max_epochs=config.MAX_EPOCHS,
#                      log_every_n_steps=1,
#                      logger=wandb_logger)

# No `logger` argument is passed, so Lightning uses its default logger.
trainer = pl.Trainer(max_epochs=config.MAX_EPOCHS,  # maximum number of training epochs
                     log_every_n_steps=1)           # log metrics after every training step

print("\n".join([
    "-----------------------------------------------------------",
    "---------------NN ARCHITECTURE (MODEL) TRAINING------------",
    "-----------------------------------------------------------",
]))

# Run Lightning's fit loop on the user-selected model. The data module supplies
# the data, and Lightning runs a sanity-check pass before the first epoch.
print('Training started...')
trainer.fit(model, data_module)
print('Training finished.')

# Only needed when the wandb logger is enabled: close the Weights & Biases run.
# wandb.finish()

print("\n".join([
    "-----------------------------------------------------------",
    "----------------------SAVE TRAINED MODEL-------------------",
    "-----------------------------------------------------------",
]))

print('Saving trained model..')
# Persist the trained network (.pth format) under PATH_TO_SAVE_TRAINED_MODEL.
# The architecture name, optimizer name and epoch count are passed along as
# well (presumably to build the saved file name; confirm in save_trained_model).
save_trained_model(
    model=model,                                    # trained model
    path=config.PATH_TO_SAVE_TRAINED_MODEL,         # destination directory/path
    model_prefix=available_nn_archs[user_choice_nn_arch],  # architecture name
    optimizer=available_optims[user_choice_optimizer],     # selected optimizer
    epochs=config.MAX_EPOCHS,                       # max. training epochs
)


- Number of GPUs available: 1
- GPU name: Tesla T4
Wed Jul 31 17:27:20 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   65C    P0              30W /  70W |   4985MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                 

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name                 | Type                     | Params | Mode 
--------------------------------------------------------------------------
0 | stage1_encoder       | VisionTransformerEncoder | 130 M  | train
1 | stage1_bottleneck    | ConvBlock                | 31.9 M | train
2 | stage1_decoder       | VisionTransformerDecoder | 170 M  | train
3 | stage2_encoder       | VisionTransformerEncoder | 130 M  | train
4 | stage2_bottleneck    | ConvBlock                | 31.9 M | train
5 | stage2_decoder       | VisionTransformerDecoder | 170 M  | train
6 | upconv_blocks_stage1 | ModuleList               | 6.3 M  | tra

-----------------------------------------------------------
---------------NN ARCHITECTURE (MODEL) TRAINING------------
-----------------------------------------------------------
Training started...


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 4.06 GiB. GPU 