<a href="https://colab.research.google.com/github/paudelsushil/labelcombinations/blob/main/project_adleo_geog315_spring24.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Objective 3:
# DeepLabv3+ Model
DeepLabv3+ utilizes an encoder-decoder structure to perform image segmentation. The encoder extracts shallow and high-level semantic information from the image, while the decoder combines low-level and high-level features to improve the accuracy of segmentation boundaries and classify the semantic information of different pixels [Chen et al., (2018)](https://link.springer.com/content/pdf/10.1007/978-3-030-01234-2_49.pdf?pdf=inline%20link).

This project is based on the improved classic DeepLabv3+ network model proposed by [Chen et al., (2023)](https://link.springer.com/content/pdf/10.1007/s40747-023-01304-z.pdf).

## Architecture of improved DeepLabv3+ with MobileNetv2 backbone

**`A. Encoder`**
 1. `Backbone` : lightweight network `MobileNetv2` in place of Xception.
 2. `ASPP` : `Hybrid Dilated Convolution` (HDC) module to alleviate the gridding effect. In addition, a `Strip Pooling Module` is used instead of spatial mean pooling to improve the local segmentation effect.
 3. `Normalization-based Attention Module` (NAM): This lightweight attention mechanism is also applied to the stacked compressed high-level feature maps to help improve the segmentation accuracy of the image.

**`B. Decoder`**
1. `NAM`: The seventh layer feature with `NAM` attention is upsampled to the same size as the fourth layer feature after fusion and channel adjustment.
2. `ResNet50`: This module is added to obtain richer low-level target feature information.
3. `Concatenate`: The **deep features** and **shallow features** are concatenated as in the original model.
4. `Upsampling`: After a 3 × 3 convolution and 4× `upsampling`, the image is restored to its original size.
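
To make the strip pooling idea from the encoder list more concrete, here is a minimal NumPy sketch of the pooling step only. It assumes a `(C, H, W)` feature map and simply sums the row-strip and column-strip averages. The function name `strip_pool` is hypothetical, and the learned convolutions and gating that a full strip pooling module would normally apply around this step are omitted.

```python
import numpy as np

def strip_pool(feature_map):
    """Pool a (C, H, W) feature map along full-width and full-height strips.

    Illustrative only: returns a (C, H, W) array in which each position holds
    the sum of its row-strip mean and its column-strip mean.
    """
    # Horizontal strips: average over the width axis -> (C, H, 1)
    row_context = feature_map.mean(axis=2, keepdims=True)
    # Vertical strips: average over the height axis -> (C, 1, W)
    col_context = feature_map.mean(axis=1, keepdims=True)
    # Broadcasting expands both back to (C, H, W) before they are summed
    return row_context + col_context

x = np.arange(12, dtype=np.float32).reshape(1, 3, 4)
pooled = strip_pool(x)
print(pooled.shape)
```

Unlike spatial mean pooling, which collapses the whole map to one value per channel, each output position here keeps context from its own row and column.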


 [Architecture Image]()

# 1. Data Preparation

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
%%capture
# Install required packages (a cell magic must be the first line of the cell)
!pip install rasterio


In [None]:
# Import required packages
import os
from pathlib import Path

from datetime import datetime, timedelta
import tqdm # Adds a smart progress meter to any iterable or file operation

import math
import random
import pandas as pd
import numpy as np


import cv2
import rasterio
#  defines a rectangular area within the raster using four properties
# xoff, yoff, width, height
from rasterio.windows import Window


import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.tensorboard import SummaryWriter

import logging
import pickle
import itertools


from IPython.core.debugger import set_trace # Insert a breakpoint into the code
from IPython.display import Image

import matplotlib.pyplot as plt

## Utility Functions

In [None]:
class InputError(Exception):
    '''
    Exception raised for errors in the input
    '''
    def __init__(self, message):
        '''
        Params:
            message (str): explanation of the error
        '''
        self.message = message
    def __str__(self):
        '''
        Define message to return when error is raised
        '''
        if self.message:
            return 'InputError, {} '.format(self.message)
        else:
            return 'InputError'
# =============================================================================
def load_data(data_path, usage="train", window=None, norm_stats_type=None,
              is_label=False):
    '''
    Read geographic data into a numpy array
    Params:
        data_path : str
            Path of data to load
        usage : str
            Usage of the data: "train", "validate", or "predict"
        window : tuple
            The view onto a rectangular subset of the data, in the format of
            (column offset, row offset, width in pixels, height in pixels).
            Required when usage is "train" or "validate" and is_label is False
        norm_stats_type : str
            How the normalization statistics are calculated: "local_per_tile",
            "local_per_band", or "global_per_band"
        is_label : bool
            Whether the file is a single-band label; labels are read as-is,
            without normalization or windowing
    Returns:
        ndarray
    '''
    # Open the data file using the 'rasterio' library
    with rasterio.open(data_path, "r") as src:
      # Check if the data is a label (segmentation mask)
        if is_label:
            if src.count != 1:  # Ensure the label has a single channel
                raise InputError("Label shape not applicable: "
                                 "expected 1 channel")
            img = src.read(1)  # Read the single channel of the label data

        else:
        # Store the value representing 'no data' in the image
            nodata = src.nodata
            # Verify normalization type is valid
            assert norm_stats_type in ["local_per_tile", "local_per_band",
                                      "global_per_band"]

            if norm_stats_type == "local_per_tile":
              # Apply per-tile normalization
                img = mmnorm1(src.read(), nodata=nodata)
            elif norm_stats_type == "local_per_band":
              # Per-band normalization, clipping values
                img = mmnorm2(src.read(), nodata=nodata, clip_val=1.5)
            elif norm_stats_type == "global_per_band":
              # Global per-band normalization, clipping values
                img = mmnorm3(src.read(), nodata=nodata, clip_val=1.5)

            # For 'train' or 'validate' subsets
            if usage in ['train', 'validate']:
              # Extract a specific window from the image
                img = img[:, max(0, window[1]): window[1] + window[3],
                          max(0, window[0]): window[0] + window[2]]

    return img  # Return the processed image or label data
# ==============================================================================

def get_stacked_img(img_paths, usage, norm_stats_type="local_per_tile",
                    window=None):
    '''
    Read one or more images and stack them into a single (H,W,C) numpy array
    Params:
        img_paths : list
            List of paths for images
        usage : str
            Usage of the image: "train", "validate", or "predict"
        norm_stats_type : str
            How the normalization statistics is calculated.
        window : tuple
            The view onto a rectangular subset of the data, in the
            format of (column offsets, row offsets, width in pixel, height in
            pixel)
    Returns:
        ndarray
    '''

    if len(img_paths) > 1:  # If there are multiple image paths:
      img_ls = [load_data(m, usage, window, norm_stats_type) for m in img_paths]
      # Load data for each image path, potentially applying normalization
      img = np.concatenate(img_ls, axis=0).transpose(1, 2, 0)
      # Combine the loaded data into a single array and rearrange dimensions
    else:  # If there's only a single image path:
      # Load data for the single image path and rearrange dimensions
      img = load_data(img_paths[0], usage, \
                      window, norm_stats_type).transpose(1, 2, 0)

    # For 'train' or 'validate' subsets:
    if usage in ["train", "validate"]:
      # Extract window parameters
      col_off, row_off, col_target, row_target = window
      row, col, c = img.shape  # Get image dimensions

      # Check if image is smaller than the target window
      if row < row_target or col < col_target:
          row_off = abs(row_off) if row_off < 0 else 0  # Adjust offsets
          col_off = abs(col_off) if col_off < 0 else 0

          # Create a larger blank canvas
          canvas = np.zeros((row_target, col_target, c))
          # Place image onto canvas
          canvas[row_off: row_off + row, col_off : col_off + col, :] = img
          return canvas  # Return the canvas with the padded image

      else:
          return img  # The image fits the window, so return it directly

    elif usage == "predict":  # For prediction purposes:
      return img  # Return the image as is

    else:
      # Invalid 'usage' value
      raise ValueError("usage must be 'train', 'validate', or 'predict'")

# ==============================================================================
def get_buffered_window(src_path, dst_path, buffer):
    '''
    Get bounding box representing the subset of the source image that overlaps
    with the buffered destination image, in the format of (column offsets,
    row offsets, width, height)

    Params:
        src_path : str
            Path of source image to get subset bounding box
        dst_path : str
            Path of destination image as a reference to define the
            bounding box. Size of the bounding box is
            (destination width + buffer * 2, destination height + buffer * 2)
        buffer :int
            Buffer distance of bounding box edges to destination image
            measured by pixel numbers

    Returns:
        tuple in form of (column offsets, row offsets, width, height)
    '''

    with rasterio.open(src_path, "r") as src:
        gt_src = src.transform

    with rasterio.open(dst_path, "r") as dst:
        gt_dst = dst.transform
        w_dst = dst.width
        h_dst = dst.height

    col_off = round((gt_dst[2] - gt_src[2]) / gt_src[0]) - buffer
    row_off = round((gt_dst[5] - gt_src[5]) / gt_src[4]) - buffer
    width = w_dst + buffer * 2
    height = h_dst + buffer * 2

    return col_off, row_off, width, height

# ==============================================================================

def get_meta_from_bounds(file, buffer):
    '''
    Get metadata of unbuffered region in given file
    Params:
        file (str):  File name of a image chip
        buffer (int): Buffer distance measured by pixel numbers
    Returns:
        dictionary
    '''

    with rasterio.open(file, "r") as src:

        meta = src.meta
        dst_width = src.width - 2 * buffer
        dst_height = src.height - 2 * buffer

        window = Window(buffer, buffer, dst_width, dst_height)
        win_transform = src.window_transform(window)

    meta.update({
        'width': dst_width,
        'height': dst_height,
        'transform': win_transform,
        'count': 1,
        'nodata': -128,
        'dtype': 'int8'
    })

    return meta


# ==============================================================================
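def display_hist(img, nodata=0):
    '''
    Display data distribution of input image in a histogram
    Params:
        img (ndarray): Image in form of (H,W,C) to display data distribution
        nodata (int or float): Value that marks NoData pixels
    '''

    img = mmnorm1(img, nodata)
    # Mask zeros so that they do not set the lower bound of the histogram
    im = np.where(img == 0, np.nan, img)

    # Create the figure before plotting so the histogram is drawn on it
    plt.figure(figsize=(20, 20))
    plt.hist(img.ravel(), bins=500, range=(np.nanmin(im), img.max()))
    plt.show()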

# ==============================================================================
def mmnorm1(img, nodata):
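    '''
    Data normalization with min/max method
    Params:
        img (ndarray): The targeted image for normalization
        nodata (int or float): Value that marks NoData pixels; these are
            excluded when computing the min and max
    Returns:
        ndarray
    '''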

    img_tmp = np.where(img == nodata, np.nan, img)
    img_max = np.nanmax(img_tmp)
    img_min = np.nanmin(img_tmp)
    normalized = (img - img_min) / (img_max - img_min)
    normalized = np.clip(normalized, 0, 1)

    return normalized

# ------------------------------------------------------------------------------
def mmnorm2(img, nodata, clip_val=None):
    r"""
    Normalize the input image pixels to the [0, 1] range based on the
    minimum and maximum statistics of each band per tile.
    Arguments:
            img : numpy array
                Stacked image bands with a dimension of (C,H,W).
            nodata : int or float
                Value reserved to represent NoData in the image chip.
            clip_val : int or float
                Percentile cut off from each tail of the distribution.
    Returns:
            img : numpy array
                Normalized image stack of size (C,H,W).
    Note 1: If clip then min, max are calculated from the clipped image.
    """

    # Filter out NoData and zero pixels when generating statistics.
    nan_corr_img = np.where(img == nodata, np.nan, img)
    nan_corr_img = np.where(nan_corr_img == 0, np.nan, nan_corr_img)

    # Test for None first: comparing None with 0 raises a TypeError.
    if clip_val is not None and clip_val > 0:
        left_tail_clip = np.nanpercentile(nan_corr_img, clip_val)
        right_tail_clip = np.nanpercentile(nan_corr_img, 100 - clip_val)

        left_clipped_img = np.where(img < left_tail_clip, left_tail_clip, img)
        clipped_img = np.where(left_clipped_img > right_tail_clip,
                               right_tail_clip, left_clipped_img)

        normalized_bands = []
        for i in range(img.shape[0]):
            band_min = np.nanmin(clipped_img[i, :, :])
            band_max = np.nanmax(clipped_img[i, :, :])
            normalized_band = (clipped_img[i, :, :] - band_min) /\
                (band_max - band_min)
            normalized_bands.append(np.expand_dims(normalized_band, 0))
        normal_img = np.concatenate(normalized_bands, 0)

    elif clip_val == 0 or clip_val is None:
        normalized_bands = []
        for i in range(img.shape[0]):
            band_min = np.nanmin(nan_corr_img[i, :, :])
            band_max = np.nanmax(nan_corr_img[i, :, :])
            normalized_band = (nan_corr_img[i, :, :] - band_min) /\
                (band_max - band_min)
            normalized_bands.append(np.expand_dims(normalized_band, 0))
        normal_img = np.concatenate(normalized_bands, 0)

    else:
        raise ValueError("clip must be a non-negative decimal.")

    normal_img = np.clip(normal_img, 0, 1)
    return normal_img

# ------------------------------------------------------------------------------
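def mmnorm3(img, nodata, clip_val=None):
    '''
    Normalize each band of a (C,H,W) stack to [0, 1] using fixed (global)
    per-band minimum and maximum statistics, defined below for four bands.
    '''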
    hardcoded_stats = {
        "mins": np.array([331.0, 581.0, 560.0, 1696.0]),
        "maxs": np.array([1403.0, 1638.0, 2076.0, 3652.0])
    }

    num_bands = img.shape[0]
    mins = hardcoded_stats["mins"]
    maxs = hardcoded_stats["maxs"]

    if clip_val:
        normalized_bands = []
        for i in range(num_bands):
            # Mask NoData and zero pixels before computing the percentiles
            nan_corr_img = np.where(img[i, :, :] == nodata, np.nan,
                                    img[i, :, :])
            nan_corr_img = np.where(nan_corr_img == 0, np.nan, nan_corr_img)
            left_tail_clip = np.nanpercentile(nan_corr_img, clip_val)
            right_tail_clip = np.nanpercentile(nan_corr_img, 100 - clip_val)
            left_clipped_band = np.where(img[i, :, :] < left_tail_clip,
                                         left_tail_clip, img[i, :, :])
            clipped_band = np.where(left_clipped_band > right_tail_clip,
                                    right_tail_clip, left_clipped_band)
            normalized_band = (clipped_band - mins[i]) / (maxs[i] - mins[i])
            normalized_bands.append(np.expand_dims(normalized_band, 0))
        img = np.concatenate(normalized_bands, 0)

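    else:
        # Cast to float first: writing normalized values back into an integer
        # array would truncate them.
        img = img.astype(np.float32)
        for i in range(num_bands):
            img[i, :, :] = (img[i, :, :] - mins[i]) / (maxs[i] - mins[i])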

    img = np.clip(img, 0, 1)
    return img

# ==============================================================================
def get_chips(img, dsize, buffer):
    '''
    Generate small chips from input images and the corresponding index of each
    chip. The index marks the location of the upper-left pixel of a
    chip.
    Params:
        img (ndarray): Image in format of (H,W,C) to be cropped; in this case
            it is the concatenated image of growing season and off season
        dsize (int): Cropped chip size
        buffer (int): Number of overlapping pixels when extracting image chips
    Returns:
        list of cropped chips and corresponding coordinates
    '''

    h, w, _ = img.shape
    x_ls = range(0,h - 2 * buffer, dsize - 2 * buffer)
    y_ls = range(0, w - 2 * buffer, dsize - 2 * buffer)

    index = list(itertools.product(x_ls, y_ls))

    img_ls = []
    for i in range(len(index)):
        x, y = index[i]
        img_ls.append(img[x:x + dsize, y:y + dsize, :])

    return img_ls, index


# ==============================================================================
def display(img, label, mask):

    '''
    Display composites and their labels
    Params:
        img (torch.tensor): Image in format of (C,H,W)
        label (torch.tensor): Label in format of (H,W)
        mask (torch.tensor): Mask in format of (H,W)
    '''

    gsimg = (comp432_dis(img, "GS") * 255).permute(1, 2, 0).int()
    osimg = (comp432_dis(img, "OS") * 255).permute(1, 2, 0).int()


    _, figs = plt.subplots(1, 4, figsize=(20, 20))

    label = label.cpu()

    figs[0].imshow(gsimg)
    figs[1].imshow(osimg)
    figs[2].imshow(label)
    figs[3].imshow(mask)

    plt.show()

# ==============================================================================
# color composite
def comp432_dis(img, season):
    '''
    Generate false color composites
    Params:
        img (torch.tensor): Image in format of (C,H,W)
        season (str): Season of the composite to generate, be  "GS" or "OS"
    '''

    viewsize = img.shape[1:]

    if season == "GS":

        b4 = mmnorm1(img[3, :, :].cpu().view(1, *viewsize),0)
        b3 = mmnorm1(img[2, :, :].cpu().view(1, *viewsize),0)
        b2 = mmnorm1(img[1, :, :].cpu().view(1, *viewsize),0)

    elif season == "OS":
        b4 = mmnorm1(img[7, :, :].cpu().view(1, *viewsize), 0)
        b3 = mmnorm1(img[6, :, :].cpu().view(1, *viewsize), 0)
        b2 = mmnorm1(img[5, :, :].cpu().view(1, *viewsize), 0)

    else:
        raise ValueError("Bad season value")

    img = torch.cat([b4, b3, b2], 0)

    return img

# ==============================================================================
def make_reproducible(seed=42, cudnn=True):
    """Make all the randomization processes start from a shared seed"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    if cudnn:
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
# ==============================================================================

def pickle_dataset(dataset, file_path):
    with open(file_path, "wb") as fp:
        pickle.dump(dataset, fp)

# ------------------------------------------------------------------------------
def load_pickle(file_path):
    return pd.read_pickle(file_path)

# ==============================================================================
def progress_reporter(msg, verbose, logger=None):
    """Helps control print statements and log writes
    Parameters
    ----------
    msg : str
      Message to write out
    verbose : bool
      Prints or not to console
    logger : logging.logger
      logger (defaults to none)

    Returns:
    --------
        Message to console and or log
    """

    if verbose:
        print(msg)

    if logger:
        logger.info(msg)

# ==============================================================================
def setup_logger(log_dir, log_name, use_date=False):
    """Create logger
    """
    if use_date:
        dt = datetime.now().strftime("%d%m%Y_%H%M")
        log = "{}/{}_{}.log".format(log_dir, log_name, dt)
    else:
        log = "{}/{}.log".format(log_dir, log_name)

    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    log_format = (
        "%(asctime)s::%(levelname)s::%(name)s::%(filename)s::"
        "%(lineno)d::%(message)s"
    )
    logging.basicConfig(filename=log, filemode='w',
                        level=logging.INFO, format=log_format)

    return logging.getLogger()
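
The tiling logic in `get_chips` above is easier to follow with concrete numbers. The sketch below reproduces only its index arithmetic in plain Python; the helper name `chip_origins` and the example sizes are hypothetical and chosen for illustration.

```python
import itertools

def chip_origins(height, width, dsize, buffer):
    """Upper-left corners of chips, mirroring the index logic of `get_chips`."""
    # Neighbouring chips overlap by 2 * buffer pixels
    stride = dsize - 2 * buffer
    rows = range(0, height - 2 * buffer, stride)
    cols = range(0, width - 2 * buffer, stride)
    return list(itertools.product(rows, cols))

# A 10 x 10 image cut into 6 x 6 chips with a 1-pixel buffer
origins = chip_origins(height=10, width=10, dsize=6, buffer=1)
print(origins)
```

Here the stride is 4, so chips start at rows and columns 0 and 4 and each pair of neighbours shares 2 pixels. When the image size is not aligned with the stride, the last chip along an axis can come out smaller than `dsize`, because NumPy slicing truncates at the array boundary.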

# Data Preparation
1. Find a suitable `image dataset` to which the `improved DeepLabv3+ model` can be applied for image segmentation.
  - For this task, we used the image dataset from `S. Khallaghi, (2024) ch. 2`.
2. Prepare the `labels (pixel-wise annotations)` that are compatible with the selected image dataset.
  - For this task, we filtered the [all-class catalog](/content/gdrive/MyDrive/adleo/project_data/label_catalog_allclasses.csv) using the methods and functions given in the [notebook](https://github.com/paudelsushil/labelcombinations/blob/main/Make_Labels_ADLEO_Final.ipynb) and prepared the final [filtered catalog](/content/gdrive/MyDrive/adleo/project_data/label-catalog-filtered.csv) to get our pixel-wise annotations as [label images](/content/gdrive/MyDrive/adleo/project_data/labels).


In [None]:
# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

Using cuda device


## Defining the Dataset for training, validating, and testing the model

In [None]:
# Define the datasets source and workingFolder source
src_dir = "/content/gdrive/MyDrive/adleo/project_data"

WorkingFolder = "/content/gdrive/MyDrive/adleo/project_data"


# Define the paths for images, labels, and the catalog
# Image paths
img_paths = list(Path(os.path.join(src_dir, "images")).glob("*.tif"))

# Label paths
lbl_paths = list(Path(os.path.join(src_dir, "labels")).glob("*.tif"))

# Label catalog
cataloge = pd.read_csv(os.path.join(src_dir, "label-catalog-filtered.csv"))

# Check if all paths are valid lists
if not all(isinstance(path_list, list) for path_list in (img_paths, lbl_paths)):
    raise ValueError("Both image_paths and label_paths must be lists.")

# Prints valid number of images and labels
print("No. of images:",len(img_paths), "\n",
      "No. of labels:", len(lbl_paths),"\n",
      "No. of rows in cataloge:", len(cataloge))



No. of images: 33873 
 No. of labels: 33756 
 No. of rows in cataloge: 33746
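
The counts above differ between images and labels, so pairing the two lists by position would misalign them. One possible way to pair them is sketched below, under the assumption (not confirmed by this notebook) that an image and its label share the same file stem. The helper name `pair_by_stem` and the file names are hypothetical.

```python
from pathlib import Path

def pair_by_stem(image_paths, label_paths):
    """Return (image, label) pairs with matching stems, plus unmatched images."""
    labels_by_stem = {Path(p).stem: Path(p) for p in label_paths}
    pairs, unmatched = [], []
    for img in sorted(Path(p) for p in image_paths):
        lbl = labels_by_stem.get(img.stem)
        if lbl is None:
            unmatched.append(img)  # no label with this stem
        else:
            pairs.append((img, lbl))
    return pairs, unmatched

# Hypothetical file names, for illustration only
images = ["images/a.tif", "images/b.tif", "images/c.tif"]
labels = ["labels/a.tif", "labels/c.tif"]
pairs, unmatched = pair_by_stem(images, labels)
print(len(pairs), len(unmatched))
```

With real data, the same call could take `img_paths` and `lbl_paths`, and the `unmatched` list shows which images have no label.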




## Pre-process Datasets:
1. `Resize images` to a standard size suitable for the model.

2. `Normalize pixel values` (e.g., scale to the range 0-1 or subtract the mean).

3. `Image Augmentation` (e.g., random flipping, cropping).


### Image Normalization

In [None]:
def min_max_normalize_image(image, dtype=np.float32):
    """
    Min-max normalize each band of an image to the range [0, 1].

    image (numpy array) : Image stack with a dimension of (C,H,W).
    dtype (numpy datatype) : Data type of the normalized image; default is
    "np.float32".
    """

    # Calculate the minimum and maximum values for each band
    min_values = np.nanmin(image, axis=(1, 2))[:, np.newaxis, np.newaxis]
    max_values = np.nanmax(image, axis=(1, 2))[:, np.newaxis, np.newaxis]

    # Normalize the image data to the range [0, 1]
    normalized_img = (image - min_values) / (max_values - min_values)

    # Return the normalized image data in the requested data type
    return normalized_img.astype(dtype)
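
Per-band min-max normalization divides by `max - min`, which is zero for a band with a single constant value. The sketch below shows one way to guard against that by mapping constant bands to 0. The function name `min_max_normalize_safe` and the `eps` threshold are assumptions for illustration, not part of the pipeline above.

```python
import numpy as np

def min_max_normalize_safe(image, dtype=np.float32, eps=1e-12):
    """Min-max normalize a (C, H, W) stack per band, tolerating constant bands."""
    band_min = np.nanmin(image, axis=(1, 2), keepdims=True)
    band_max = np.nanmax(image, axis=(1, 2), keepdims=True)
    band_range = band_max - band_min
    # Constant bands have zero range; divide by 1 so they map to 0 instead
    safe_range = np.where(band_range > eps, band_range, 1.0)
    return ((image - band_min) / safe_range).astype(dtype)

image = np.stack([
    np.array([[0.0, 5.0], [10.0, 20.0]]),  # band with varying values
    np.full((2, 2), 7.0),                  # constant band
])
normalized = min_max_normalize_safe(image)
print(normalized)
```

The varying band is scaled to [0, 1], while the constant band becomes all zeros rather than NaN.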

### Image Augmentation
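
For segmentation, any geometric augmentation has to be applied identically to the image and its label. The following is a minimal NumPy sketch of paired flipping and 90-degree rotation, assuming `(H, W, C)` images and `(H, W)` labels. The helper names `flip_pair` and `rot90_pair` and the flip names `"hflip"` and `"vflip"` are hypothetical; they are not the notebook's own `flip_image_and_label` or `rotate_image_and_label`.

```python
import numpy as np

def flip_pair(image, label, flip_type):
    """Flip an (H, W, C) image and its (H, W) label along the same axis."""
    if flip_type == "hflip":    # mirror left-right
        axis = 1
    elif flip_type == "vflip":  # mirror top-bottom
        axis = 0
    else:
        raise ValueError(f"Unknown flip type: {flip_type}")
    # copy() returns contiguous arrays rather than negative-stride views
    return np.flip(image, axis=axis).copy(), np.flip(label, axis=axis).copy()

def rot90_pair(image, label, k=1):
    """Rotate image and label by k * 90 degrees counter-clockwise."""
    rotated_image = np.rot90(image, k, axes=(0, 1)).copy()
    rotated_label = np.rot90(label, k, axes=(0, 1)).copy()
    return rotated_image, rotated_label

image = np.arange(12).reshape(2, 3, 2)
label = np.array([[1, 0, 0],
                  [0, 0, 2]])

flipped_image, flipped_label = flip_pair(image, label, "hflip")
rotated_image, rotated_label = rot90_pair(image, label, k=1)
print(flipped_label)
print(rotated_label.shape)
```

The copies matter later in the pipeline, since `torch.from_numpy` does not accept arrays with negative strides.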

## Active dataset loading pipeline


In [None]:
class AquacultureData(Dataset):
    def __init__(self, src_dir, usage, dataset_name=None,
                 apply_normalization=False, transform=None, csv_name=None,
                 patch_size=None, overlap=None, catalog_index=None):
        r"""
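        src_dir (str or path): Root of resource directory.
        usage (str): One of 'train', 'validation', or 'inference'.
        dataset_name (str): Name of the training/validation dataset containing
                              structured folders for image, label
        apply_normalization (bool): Whether to normalize the image chips.
        transform (list): Each element is string name of the transformation to
            be used.
        csv_name (str): Name of the catalog CSV file; required for 'inference'.
        patch_size (int): Size in pixels of the square patches cropped for
            'inference'.
        overlap (int): Overlap in pixels between neighbouring patches.
        catalog_index (int): Row of the catalog to use for 'inference'.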
        """
        self.src_dir = src_dir
        self.dataset_name = dataset_name
        self.csv_name = csv_name
        self.apply_normalization = apply_normalization
        self.transform = transform
        self.patch_size = patch_size
        self.overlap = overlap

        self.usage = usage
        assert self.usage in ["train", "validation", "inference"], \
            "Usage is not recognized."

        if self.usage in ["train", "validation"]:
            assert self.dataset_name is not None
            img_dir = Path(src_dir) / self.dataset_name / self.usage / "bands"
            img_fnames = [Path(dirpath) / f
                          for (dirpath, dirnames, filenames) in os.walk(img_dir)
                          for f in filenames if f.endswith(".tif")]
            img_fnames.sort()

            lbl_dir = Path(src_dir) / self.dataset_name / self.usage / "labels"
            lbl_fnames = [Path(dirpath) / f
                          for (dirpath, dirnames, filenames) in os.walk(lbl_dir)
                          for f in filenames if f.endswith(".tif")]
            lbl_fnames.sort()

            self.img_chips = []
            self.lbl_chips = []

            for img_path, lbl_path in tqdm.tqdm(zip(img_fnames, lbl_fnames),
                                                total=len(img_fnames)):
                img_chip = load_data(
                    img_path, is_label=False,
                    apply_normalization=self.apply_normalization
                )
                img_chip = img_chip.transpose((1, 2, 0))

                lbl_chip = load_data(lbl_path, is_label=True)

                self.img_chips.append(img_chip)
                self.lbl_chips.append(lbl_chip)

            print('--------------{} patches cropped--------------'\
                  .format(len(self.img_chips)))

        # This part handles prediction dataset
        else:
            assert self.csv_name is not None

            # Read the catalog CSV file.
            catalog = pd.read_csv(os.path.join(self.src_dir, self.csv_name))

            # Select the catalog row given by "catalog_index".
            self.catalog = catalog.iloc[catalog_index]

            self.tile = (self.catalog["wrs_path"], self.catalog["wrs_row"])

            img_path_ls = [self.catalog["img_dir"]]
            mask_path_ls = [self.catalog["mask_dir"]]

            self.meta = get_meta_from_bounds(Path(src_dir) / img_path_ls[0])

            half_size = self.patch_size // 2

            self.img_chips = []
            self.coor = []

            for img_path, mask_path in zip(img_path_ls, mask_path_ls):

                # Load the image with the "load_data" utility function.
                img = load_data(os.path.join(self.src_dir, img_path),
                                is_label = False,
                                apply_normalization = self.apply_normalization)

                img = np.transpose(img, (1, 2, 0))

                # Load the mask with the "load_data" utility function.
                mask = load_data(os.path.join(self.src_dir, mask_path),
                                 is_label=True)

                crop_ref = mask

                index = patch_center_index(crop_ref, self.patch_size,
                                           self.overlap, self.usage)

                for i in range(len(index)):
                    x = index[i][0]
                    y = index[i][1]

                    self.img_chips.append(img[x - half_size: x + half_size,
                                              y - half_size: y + half_size, :])
                    self.coor.append([x, y])



            print('--------------{} patches cropped--------------'\
                  .format(len(self.img_chips)))


    def __getitem__(self, index):

        if self.usage in ["train", "validation"]:
            image_chip = self.img_chips[index]
            label_chip = self.lbl_chips[index]

            if self.usage == "train" and self.transform:
                trans_flip_ls = [m for m in self.transform if "flip" in m]
                if random.randint(0, 1) and len(trans_flip_ls) > 1:
                    trans_flip = random.sample(trans_flip_ls, 1)[0]
                    image_chip, label_chip = flip_image_and_label(
                        image_chip, label_chip, trans_flip
                    )

                if random.randint(0, 1) and "rotate" in self.transform:
                    image_chip, label_chip = rotate_image_and_label(
                        image_chip, label_chip, angle=[0,90]
                    )

            # Convert numpy arrays to torch tensors.
            # Image chips should be: CHW if not transpose to correct order of
            # dimensions.
            image_tensor = torch.from_numpy(image_chip.transpose((2, 0, 1)))\
                .float()
            label_tensor = torch.from_numpy(np.ascontiguousarray(label_chip))\
                .long()

            return image_tensor, label_tensor
        else:
            coor = self.coor[index]
            img_chip = self.img_chips[index]
            image_tensor = torch.from_numpy(img_chip.transpose((2, 0, 1)))\
                .float()

            return image_tensor, coor


    def __len__(self):
        return len(self.img_chips)
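
The inference branch above crops patches as `img[x - half_size: x + half_size, ...]` around the centers returned by `patch_center_index`, which is not defined in this section. The sketch below shows what such a helper might compute. It is an assumption, not the notebook's implementation: the name `patch_centers`, the stride rule `patch_size - overlap`, and the extra center added to cover the far edge are all illustrative choices.

```python
def patch_centers(height, width, patch_size, overlap):
    """Centers (row, col) of square patches that tile a height x width grid."""
    if patch_size % 2:
        raise ValueError("patch_size is assumed to be even")
    stride = patch_size - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than patch_size")
    half = patch_size // 2

    def axis_centers(length):
        if length < patch_size:
            return []
        centers = list(range(half, length - half + 1, stride))
        # Add a final center so the far edge is covered as well
        if centers[-1] != length - half:
            centers.append(length - half)
        return centers

    return [(x, y) for x in axis_centers(height) for y in axis_centers(width)]

centers = patch_centers(height=9, width=8, patch_size=4, overlap=2)
print(centers)
```

Every center stays at least `patch_size // 2` pixels from each border, so slices of the form `x - half: x + half` never run outside the image.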

# Model Building
DeepLabv3+ based on [Chen et al., 2024](https://link.springer.com/content/pdf/10.1007/s40747-023-01304-z.pdf)


# References


Chen LC et al. (2018). Encoder-decoder with atrous separable convolution for semantic image segmentation. In: Proceedings of the European Conference on Computer Vision (ECCV), 801–818.