In [None]:
# General utilities
import os
import gc
import re
import json
import time
import shutil
import pickle

# Numerical and data handling
import numpy as np
import pandas as pd

# Image processing
import cv2

# Visualization
import matplotlib
import matplotlib.pyplot as plt

# PyTorch and related libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import _LRScheduler

# Torchvision
import torchvision
import torchvision.models as models
import torchvision.transforms as T

# Segmentation models
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss

# Evaluation
from torchmetrics import ConfusionMatrix

# Progress bar
from tqdm import tqdm

# Remember which matplotlib backend was active right after import, then print a
# short environment summary (backend name and NumPy version).
default_matplotlib_backend = matplotlib.get_backend()
print('imported')
print('default_matplotlib_backend: {}'.format(default_matplotlib_backend))
print(' numpy version', np.__version__) 

In [None]:
# Important, to have the same repartition of data between different machines
# (seeds NumPy's global RNG only).
np.random.seed(42)
RUN_MODE = "RUN"
root_path = '' # path of the root project folder
weights_path =  ''  # folder holding 'backbone_<BACKBONE>_<OPTIMIZER>.pth' files
stats_path = ''  # folder holding 'stats_<BACKBONE>_<OPTIMIZER>.pkl' files

USE_2_GPUS = False

if USE_2_GPUS:
    # Expose GPUs 0 and 1, ordered by PCI bus id so the indices are stable.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # Arrange GPU devices starting from 0
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
elif torch.cuda.is_available():
    # Single-device run: prefer the second GPU, but fall back to GPU 0 on
    # machines that only have one (index 1 would be an invalid device there).
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
print('Selected device: {}'.format(device))

# Compare the device *type*: a torch.device is not equal to the plain string
# 'cuda', so the previous `device == 'cuda'` check never printed the name.
if device.type == 'cuda':
    print('Device name: {}'.format(torch.cuda.get_device_name(device)))

encoders  = ['resnet34', 'timm-regnetx_160', 'timm-regnety_160', 'mobilenet_v2']

encoders_v2 = ['resnet34', 'resnext50_32x4d', 'resnet101', 'resnet152',
            'timm-regnetx_064', 'timm-regnety_064', 'resnext101_32x8d',
            'se_resnet50', 'se_resnet152', 'se_resnext101_32x4d']
decoders = ['unet', 'unet++', 'pspnet', 'deeplabv3+']
#decoders_v2 = ['unet', 'unet++', 'pspnet', 'deeplabv3+']
#encoders = ['timm-resnest26d', 'resnext50_32x4d']
decoders_v2 = ['deeplabv3+']

# These counts are plain sums of the two lists; names present in both lists
# (e.g. 'resnet34', 'deeplabv3+') are counted twice.
print(f"number of encoders {len(encoders) + len(encoders_v2)}")
print(f"number of decoders {len(decoders) + len(decoders_v2)}")

In [None]:
# NOTE(review): this rebinds the global `device`, replacing whatever the
# configuration cell selected — confirm the override is intended. Functions
# defined later capture this value as their default `device` argument.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Number of GPUS available: ", torch.cuda.device_count())

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torchvision.transforms.functional as TF
import random
import imgaug.augmenters as iaa
from imgaug import SegmentationMapsOnImage

# Load the train/valid/test split description (a JSON file under root_path).
# get_file_paths() iterates it as {dataset: {sub-folder: split name}}.
train_valid_test_split_json_name = ''
with open(os.path.join(root_path, train_valid_test_split_json_name), 'r') as f:
    data_dict = json.load(f)


# Photometric augmentations: CustomDataset.transform picks one at random and
# applies it to the image only (the mask is left untouched).
augmentations_list = [
                    iaa.Sharpen(alpha=(0, 0.5), lightness=(2.5, 3.0)),
                    iaa.GammaContrast((0.4, 0.9)),  # previous values iaa.GammaContrast((0.4, 0.9))
                    iaa.AddToHueAndSaturation((-40, 20)),
                    iaa.Multiply((1.1, 1.8))   # previous values iaa.Multiply((0.9, 1.1))
                    
                ]
# Zoom-in augmentations: CustomDataset.transform picks one at random and
# applies it jointly to the image and its segmentation mask.
zoomin_list = [
    iaa.Affine(scale=(1.2, 1.8)), 
    iaa.Crop(percent=(0.01, 0.3))
]

# Folder with background images for the background-replacement augmentation
# (left empty here; fill in before enabling augmentation).
bg_images_folder = ""


# Custom Dataset class to load images and masks
class CustomDataset(Dataset):
    """Segmentation dataset yielding ``(image, mask)`` tensor pairs resized to 512x512.

    Args:
        image_paths: Sequence of image file paths.
        mask_paths: Sequence of mask file paths, aligned index-by-index with
            ``image_paths``.
        aug: When truthy, random flips, rotation, background replacement,
            photometric changes and zoom are applied before resizing.
        bg_images_folder: Optional folder with background images for the
            background-replacement augmentation. When ``None`` the
            module-level ``bg_images_folder`` is used.
    """

    _BG_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_paths, mask_paths, aug=None, bg_images_folder=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.aug = aug
        self.bg_images_folder = bg_images_folder

    def __len__(self):
        return len(self.image_paths)

    def _background_folder(self):
        """Return the folder to draw background images from."""
        if self.bg_images_folder is not None:
            return self.bg_images_folder
        # Falls back to the notebook-level setting (resolved at call time).
        return bg_images_folder

    def _replace_background(self, image, mask):
        """Keep pixels where ``mask != 0`` and fill the rest with a random background.

        Args:
            image: PIL image.
            mask: PIL mask with the same spatial size as ``image``.

        Returns:
            A new RGB PIL image.

        Raises:
            ValueError: If the folder has no usable image or the chosen file
                cannot be decoded.
        """
        image_np = np.array(image.convert('RGB'))
        folder = self._background_folder()
        bg_images_list = [
            f for f in os.listdir(folder)
            if f.lower().endswith(self._BG_EXTENSIONS)
        ]

        if not bg_images_list:
            raise ValueError("No valid background images found in the specified folder.")

        bg_image_path = os.path.join(folder, random.choice(bg_images_list))

        # Read the background image
        background_image = cv2.imread(bg_image_path)

        if background_image is None:
            raise ValueError(f"Failed to load background image: {bg_image_path}")

        # cv2 decodes to BGR while the PIL-derived array is RGB: convert the
        # background *before* compositing so foreground colours are not swapped.
        background_image = cv2.cvtColor(background_image, cv2.COLOR_BGR2RGB)
        background_image = cv2.resize(background_image, (image_np.shape[1], image_np.shape[0]))

        # The mask is a PIL image; it must become an array before comparison
        # (comparing the PIL object itself with 0 is just a scalar False).
        foreground = (np.array(mask) != 0)[..., None]
        return Image.fromarray(np.where(foreground, image_np, background_image))

    def transform(self, image, mask):
        """Apply the optional augmentations, then resize, tensorize and normalize.

        Args:
            image: PIL image.
            mask: PIL mask aligned with ``image``.

        Returns:
            ``(image, mask)`` tensors; the image is normalized with the
            mean/std defined below, the mask goes through ``to_tensor`` only.
        """
        # Resize: the mask uses nearest-neighbour so resizing cannot create
        # intermediate (non-label) values along object borders.
        resize_image = transforms.Resize(size=(512, 512))
        resize_mask = transforms.Resize(
            size=(512, 512), interpolation=transforms.InterpolationMode.NEAREST
        )

        if self.aug:
            # Random horizontal flipping
            if random.random() > 0.5:
                image = TF.hflip(image)
                mask = TF.hflip(mask)

            # Random vertical flipping
            if random.random() > 0.5:
                image = TF.vflip(image)
                mask = TF.vflip(mask)

            # Random rotating
            if random.random() > 0.5:
                angle = random.randrange(45, 270)
                image = TF.rotate(image, angle)
                mask = TF.rotate(mask, angle)

            # Random background replacement (image only)
            if random.random() > 0.6:
                image = self._replace_background(image, mask)

            # Random photometric augmentation (image only)
            if random.random() > 0.5:
                image_np = np.array(image)
                selected_augmentation = random.choice(augmentations_list)
                image_np = selected_augmentation.augment_image(image_np)
                image = Image.fromarray(image_np)

            # Random zoom-in, applied jointly to image and mask
            if random.random() > 0.5:
                image_np = np.array(image)
                mask_np = np.array(mask)

                selected_augmentation = random.choice(zoomin_list)
                aug = iaa.Sequential([selected_augmentation])
                mask_segmaps = SegmentationMapsOnImage(mask_np, shape=mask_np.shape)
                zoom_image_np, masks_aug = aug(image=image_np, segmentation_maps=mask_segmaps)
                zoom_mask_np = masks_aug.get_arr()

                image = Image.fromarray(zoom_image_np)
                mask = Image.fromarray(zoom_mask_np)

        image = resize_image(image)
        mask = resize_mask(mask)

        # Transform to tensor
        image = TF.to_tensor(image)
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        image = TF.normalize(image, mean=mean, std=std)
        mask = TF.to_tensor(mask)
        return image, mask

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk and return its transformed tensors."""
        # Force 3 channels: normalization uses a 3-value mean/std, so
        # grayscale, palette or RGBA files would otherwise fail.
        image = Image.open(self.image_paths[idx]).convert('RGB')
        mask = Image.open(self.mask_paths[idx]).convert('L')
        image, mask = self.transform(image, mask)
        return image, mask


# Base directories: get_file_paths() reads images from
# <base_dir>/<dataset>/<rep>/ and expects the matching masks under
# <base_dir_label>/<dataset>/<rep>/masks/.
base_dir = ""
base_dir_label = ""

# Function to gather file paths based on the dictionary
def get_file_paths(data_dict, base_dir, base_dir_label):
    """Collect image/mask path lists for the train, valid and test splits.

    Args:
        data_dict: Mapping ``{dataset: {rep: split}}``; ``split`` is compared
            against 'train', 'valid' and 'test'. Folders assigned any other
            value are listed but contribute nothing.
        base_dir: Root folder of the images (``base_dir/dataset/rep/<file>``).
        base_dir_label: Root folder of the masks
            (``base_dir_label/dataset/rep/masks/<file with '.png' -> '_mask.png'>``).

    Returns:
        Tuple ``(train_images, train_masks, val_images, val_masks,
        test_images, test_masks)`` of path lists.
    """
    train_images, train_masks = [], []
    val_images, val_masks = [], []
    test_images, test_masks = [], []
    destinations = (
        ('train', train_images, train_masks),
        ('valid', val_images, val_masks),
        ('test', test_images, test_masks),
    )

    for dataset_name, rep_to_split in data_dict.items():
        for rep_name, split_name in rep_to_split.items():
            for file_name in os.listdir(os.path.join(base_dir, dataset_name, rep_name)):
                img_path = os.path.join(base_dir, dataset_name, rep_name, file_name)
                msk_path = os.path.join(
                    base_dir_label, dataset_name, rep_name, "masks",
                    file_name.replace('.png', '_mask.png'),
                )
                for key, image_bucket, mask_bucket in destinations:
                    if split_name == key:
                        image_bucket.append(img_path)
                        mask_bucket.append(msk_path)
                        break

    return train_images, train_masks, val_images, val_masks, test_images, test_masks

# Resolve the image/mask paths of every split from the split dictionary and
# report how many images each split contains.
print('Loading data ...')
train_image_paths, train_mask_paths, val_image_paths, val_mask_paths, test_image_paths, test_mask_paths = get_file_paths(data_dict, base_dir, base_dir_label)


print('Number of training images: {}'.format(len(train_image_paths)))
print('Number of validation images: {}'.format(len(val_image_paths)))
print('Number of test images: {}'.format(len(test_image_paths)))

In [None]:
# Per-channel mean and std used by plot_image_and_mask() to undo normalization.
# They are the same values CustomDataset.transform normalizes with; keep the
# two in sync.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def unnormalize(image, mean, std):
    """
    Undo per-channel normalization (``channel * std + mean``).

    The input tensor is left untouched: the computation runs on a copy, so
    calling this repeatedly on the same tensor (e.g. re-running a plotting
    cell) always starts from the normalized values.

    Args:
    - image (torch tensor): The normalized image tensor, channels first.
    - mean (sequence of float): Per-channel mean used for normalization.
    - std (sequence of float): Per-channel std used for normalization.

    Returns:
    - unnormalized_image (torch tensor): A new tensor with the same shape.
    """
    restored = image.clone()
    # Iterating a tensor yields views of its first dimension, so the in-place
    # ops below write into `restored` only.
    for channel, m, s in zip(restored, mean, std):
        channel.mul_(s).add_(m)  # Unnormalize
    return restored

def plot_image_and_mask(image, mask):
    """
    Plots the given image and mask side by side.

    Tensor inputs are copied to the CPU before being modified, so the caller's
    tensors are never altered and CUDA / grad-tracking tensors are accepted.

    Args:
    - image (torch tensor or array): The input image. Tensors are expected
      channels-first and normalized with the module-level ``mean``/``std``.
    - mask (torch tensor or array): The corresponding mask.
    """
    if torch.is_tensor(image):
        # Work on a private CPU copy: unnormalize() may operate in place.
        image = unnormalize(image.detach().cpu().clone(), mean, std)
        image = image.permute(1, 2, 0).numpy()  # Convert from CxHxW to HxWxC
        image = np.clip(image, 0, 1)  # Ensure values are in the range [0, 1]
    if torch.is_tensor(mask):
        mask = mask.detach().cpu()
        if mask.dim() == 3:
            mask = mask.permute(1, 2, 0)  # CxHxW -> HxWxC
        mask = mask.numpy()

    fig, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))

    ax_image.imshow(image)
    ax_image.set_title('Image')
    ax_image.axis('off')

    ax_mask.imshow(mask, cmap='gray')
    ax_mask.set_title('Mask')
    ax_mask.axis('off')

    plt.show()


In [None]:
# Create datasets
# Only the training split can be augmented, and only when AUGMENTATION is True;
# validation and test data are always loaded without augmentation.
AUGMENTATION = False
train_dataset = CustomDataset(train_image_paths, train_mask_paths, aug=AUGMENTATION)
val_dataset = CustomDataset(val_image_paths, val_mask_paths, aug=False)
test_dataset = CustomDataset(test_image_paths, test_mask_paths, aug=False)

# Create data loaders
# Training batches are shuffled and the last incomplete batch is dropped;
# the test loader yields one sample per batch.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, drop_last=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

print('Data loaders created')

### Data sample (Run this only if you want to see an example of the data)


In [None]:
# Example usage: draw one batch from train_loader (defined in the previous
# cell) for visual inspection. The loader shuffles, so each run shows a
# different sample.
data_iter = iter(train_loader)
images, masks = next(data_iter)

# Plot the first image and mask from the batch
plot_image_and_mask(images[0], masks[0])

### LR Scheduler and utility functions

In [None]:
class PolyScheduler(_LRScheduler):
    """Linear warm-up followed by polynomial (power 2) decay of the learning rate.

    Every parameter group receives the same learning rate:

    * ``last_epoch == -1``: ``warmup_lr_init`` (0.0001);
    * ``last_epoch < warmup_steps``: ``base_lr * last_epoch / warmup_steps``;
    * afterwards: ``base_lr * (1 - progress) ** 2`` where ``progress`` is the
      fraction of the decay phase already done, clamped to 1 so the rate
      stays at 0 once ``max_steps`` is passed instead of growing again.

    Args:
        optimizer: Wrapped optimizer.
        base_lr: Peak learning rate reached at the end of the warm-up.
        max_steps: Step at which the learning rate reaches 0.
        warmup_steps: Number of warm-up steps.
        last_epoch: Value stored in ``self.last_epoch`` after construction.
    """

    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1):
        self.base_lr = base_lr
        self.warmup_lr_init = 0.0001
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        self.power = 2
        # Do not pass the base class' `verbose` flag positionally: it is
        # deprecated in recent PyTorch releases and its default already
        # matches the previously passed False.
        super(PolyScheduler, self).__init__(optimizer, -1)
        # The base constructor performs an initial step; restore the
        # requested counter afterwards.
        self.last_epoch = last_epoch

    def get_warmup_lr(self):
        """Return the linearly ramped warm-up rate for every parameter group."""
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]

    def get_lr(self):
        """Return the learning rate of every parameter group for ``last_epoch``."""
        if self.last_epoch == -1:
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()

        # Guard against a zero-length decay phase and clamp the progress so
        # the squared term cannot rise again after max_steps.
        decay_steps = max(self.max_steps - self.warmup_steps, 1)
        progress = min(float(self.last_epoch - self.warmup_steps) / float(decay_steps), 1.0)
        alpha = pow(1 - progress, self.power)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]

def plot_model_stats(model_name, train_loss, train_acc, train_iou, test_loss, test_acc, test_iou):
    """Plot loss, accuracy and IoU curves (train vs. validation) in three panels.

    Args:
        model_name: Name shown in the panel titles.
        train_loss, train_acc, train_iou: Per-epoch training values.
        test_loss, test_acc, test_iou: Per-epoch validation values; the last
            entry of each is reported in the corresponding title.

    Accuracy and IoU are multiplied by 100 for display.
    """
    fig, (ax_loss, ax_acc, ax_iou) = plt.subplots(1, 3, figsize=(16, 5))

    ax_loss.plot(np.arange(len(train_loss)), train_loss, label='Train loss')
    ax_loss.plot(np.arange(len(test_loss)), test_loss, label='Val loss')
    ax_loss.set_title('Model: {} - Validation loss: {:.4f}'.format(model_name, test_loss[-1]))
    ax_loss.set_ylabel("Loss")
    ax_loss.set_xlabel("Epochs")
    ax_loss.legend(loc='upper right')

    ax_acc.plot(np.arange(len(train_acc)), np.array(train_acc) * 100, color='green', label='Train accuracy')
    ax_acc.plot(np.arange(len(test_acc)), np.array(test_acc) * 100, color='red', label='Val accuracy')
    ax_acc.set_title('Model: {} - Validation accuracy: {:.2f} %'.format(model_name, test_acc[-1] * 100))
    ax_acc.set_ylabel("Accuracy")
    ax_acc.set_xlabel("Epochs")
    ax_acc.legend(loc='lower right')

    # The IoU curves get their own (third) panel; they used to be drawn on
    # panel 2, replacing the accuracy plot and reusing its labels and title.
    ax_iou.plot(np.arange(len(train_iou)), np.array(train_iou) * 100, color='green', label='Train IoU')
    ax_iou.plot(np.arange(len(test_iou)), np.array(test_iou) * 100, color='red', label='Val IoU')
    ax_iou.set_title('Model: {} - Validation IoU: {:.2f} %'.format(model_name, test_iou[-1] * 100))
    ax_iou.set_ylabel("IOU")
    ax_iou.set_xlabel("Epochs")
    ax_iou.legend(loc='lower right')

    plt.show()

def load_model_weights(BACKBONE, OPTIMIZER):
    """Load a fully pickled model (architecture and weights) onto ``device``.

    Args:
        BACKBONE: Backbone name used in the checkpoint file name.
        OPTIMIZER: Optimizer name used in the checkpoint file name.

    Returns:
        The model stored in
        ``<weights_path>/backbone_<BACKBONE>_<OPTIMIZER>.pth``, moved to ``device``.

    Raises:
        AssertionError: If the checkpoint file does not exist.
    """
    pth = os.path.join(weights_path, 'backbone_{}_{}.pth'.format(BACKBONE, OPTIMIZER))
    assert os.path.exists(pth), 'Configuration not found'

    # map_location lets a checkpoint saved on another device (e.g. a GPU that
    # is not present here) be deserialized directly onto the working device.
    # SECURITY: torch.load unpickles arbitrary objects; only load trusted files.
    # NOTE(review): recent PyTorch releases default to weights_only=True, which
    # rejects fully pickled modules — pass weights_only=False if that applies.
    model = torch.load(pth, map_location=device).to(device)
    print('Backbone weights loaded: {}'.format(BACKBONE))

    return model

def load_model_stats(BACKBONE, OPTIMIZER):
    """Load the pickled training statistics saved for a backbone/optimizer pair.

    Reads ``<stats_path>/stats_<BACKBONE>_<OPTIMIZER>.pkl`` and returns the
    unpickled object. Raises ``AssertionError`` when the file is missing.
    """
    pickle_path = os.path.join(stats_path, f'stats_{BACKBONE}_{OPTIMIZER}.pkl')
    assert os.path.exists(pickle_path), 'Configuration not found'

    # SECURITY: pickle can execute arbitrary code; only load trusted files.
    with open(pickle_path, 'rb') as handle:
        loaded_stats = pickle.load(handle)

    print(f'Train/Validation stats loaded for the model: {BACKBONE}')
    return loaded_stats

## Training and evaluation functions

In [None]:
from pytorch_utils import training_utils as pt_train

# Dice Loss (binary mode); used as the default loss_function of CustomModelBase.
dice_loss = DiceLoss(mode='binary')
# NOTE(review): SMOOTH is not referenced by the code visible in this part of
# the notebook — confirm it is still needed.
SMOOTH = 1e-6
# Helper functions for evaluation and accuracy calculation
def _evaluate(model: torch.nn.Module, val_loader: DataLoader, device=device):
    """Run one validation pass and return the aggregated epoch metrics.

    The model is switched to eval mode, every batch is moved to ``device`` as
    float tensors and passed to ``validation_step``; the collected step outputs
    are reduced by ``validation_epoch_end``. When ``USE_2_GPUS`` is set those
    hooks are looked up on ``model.module`` while ``model`` itself is still
    used for the forward pass.
    """
    model.eval()
    hook_owner = model.module if USE_2_GPUS else model

    with torch.no_grad():
        step_outputs = [
            hook_owner.validation_step([tensor.to(device).float() for tensor in batch], model)
            for batch in val_loader
        ]

    return hook_owner.validation_epoch_end(step_outputs)

def _accuracy(outputs: torch.Tensor, labels: torch.Tensor):
    """Accuracy of the thresholded predictions (sigmoid > 0.5), "macro" reduction."""
    binary_pred = (torch.sigmoid(outputs) > 0.5).float()
    confusion = smp.metrics.get_stats(binary_pred.long(), labels.long(), mode="binary")
    return smp.metrics.accuracy(*confusion, reduction="macro")

def _iou(outputs: torch.Tensor, labels: torch.Tensor):
    """IoU of the thresholded predictions (sigmoid > 0.5), "micro" reduction."""
    binary_pred = (torch.sigmoid(outputs) > 0.5).float()
    confusion = smp.metrics.get_stats(binary_pred.long(), labels.long(), mode="binary")
    return smp.metrics.iou_score(*confusion, reduction="micro")

def _iou_plant(outputs: torch.Tensor, labels: torch.Tensor):
    """IoU of the foreground ("plant") class over the whole batch.

    Predictions are ``sigmoid(outputs) > 0.5``. In binary mode
    ``smp.metrics.get_stats`` already reports the statistics of the positive
    class, one row per image, so no class indexing is needed. The previous
    ``tp[1]`` picked the second *image* of the batch and raised for batches
    of size 1.
    """
    pred_mask = (outputs.sigmoid() > 0.5).long()
    tp, fp, fn, tn = smp.metrics.get_stats(pred_mask, labels.long(), mode="binary")
    return smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

def precision_score(outputs: torch.Tensor, labels: torch.Tensor):
    """Precision (positive predictive value) of predictions thresholded at 0.5, "micro" reduction."""
    binary_pred = (torch.sigmoid(outputs) > 0.5).float()
    confusion = smp.metrics.get_stats(binary_pred.long(), labels.long(), mode="binary")
    return smp.metrics.positive_predictive_value(*confusion, reduction="micro")

def recall_score(outputs: torch.Tensor, labels: torch.Tensor):
    """Recall of predictions thresholded at 0.5, "micro-imagewise" reduction."""
    binary_pred = (torch.sigmoid(outputs) > 0.5).float()
    confusion = smp.metrics.get_stats(binary_pred.long(), labels.long(), mode="binary")
    # NOTE(review): the sibling metrics use reduction="micro"; confirm that
    # "micro-imagewise" is intended here.
    return smp.metrics.recall(*confusion, reduction="micro-imagewise")

def f1_score(outputs: torch.Tensor, labels: torch.Tensor):
    """F1 score of predictions thresholded at 0.5, "micro" reduction."""
    binary_pred = (torch.sigmoid(outputs) > 0.5).float()
    confusion = smp.metrics.get_stats(binary_pred.long(), labels.long(), mode="binary")
    return smp.metrics.f1_score(*confusion, reduction="micro")

def specificity(outputs: torch.Tensor, labels: torch.Tensor):
    """Specificity of predictions thresholded at 0.5, "micro" reduction."""
    binary_pred = (torch.sigmoid(outputs) > 0.5).float()
    confusion = smp.metrics.get_stats(binary_pred.long(), labels.long(), mode="binary")
    return smp.metrics.specificity(*confusion, reduction="micro")

def false_positive_rate(outputs: torch.Tensor, labels: torch.Tensor):
    """False positive rate of predictions thresholded at 0.5, "micro" reduction."""
    binary_pred = (torch.sigmoid(outputs) > 0.5).float()
    confusion = smp.metrics.get_stats(binary_pred.long(), labels.long(), mode="binary")
    return smp.metrics.false_positive_rate(*confusion, reduction="micro")

def false_negative_rate(outputs: torch.Tensor, labels: torch.Tensor):
    """False negative rate of predictions thresholded at 0.5, "micro" reduction."""
    binary_pred = (torch.sigmoid(outputs) > 0.5).float()
    confusion = smp.metrics.get_stats(binary_pred.long(), labels.long(), mode="binary")
    return smp.metrics.false_negative_rate(*confusion, reduction="micro")

def false_discovery_rate(outputs: torch.Tensor, labels: torch.Tensor):
    """False discovery rate of predictions thresholded at 0.5, "micro" reduction."""
    binary_pred = (torch.sigmoid(outputs) > 0.5).float()
    confusion = smp.metrics.get_stats(binary_pred.long(), labels.long(), mode="binary")
    return smp.metrics.false_discovery_rate(*confusion, reduction="micro")

def false_omission_rate(outputs: torch.Tensor, labels: torch.Tensor):
    """False omission rate of predictions thresholded at 0.5, "micro" reduction.

    This function used to be defined twice with identical bodies; a single
    definition is kept.
    """
    prob_mask = outputs.sigmoid()
    pred_mask = (prob_mask > 0.5).float()
    tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), labels.long(), mode="binary")
    return smp.metrics.false_omission_rate(tp, fp, fn, tn, reduction="micro")

def _handlezero_division_np(a,b):
    # initialize output tensor with desired value
    # c = torch.zeros_like(a)
    #c = torch.full_like(a, fill_value=float('nan'))
    # zero mask
    c = np.zeros_like(a)
    mask = (b != 0)
    # finally perform division
    c[mask] = a[mask] / b[mask]
    return c

def mathews_correlation_coefficient_np(tp, fp, fn, tn, eps=1e-11):
    """Matthews correlation coefficient from confusion-matrix counts.

    The counts are summed over all elements first, then
    ``MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))`` is computed,
    with 0 returned when the denominator is 0.

    Args:
        tp, fp, fn, tn: NumPy arrays of true/false positive/negative counts.
        eps: Unused; kept so existing callers passing it keep working. It only
            fed an intermediate value that was immediately overwritten.

    Returns:
        The MCC as returned by ``_handlezero_division_np``.
    """
    tp = tp.sum().astype(np.float64)
    tn = tn.sum().astype(np.float64)
    fp = fp.sum().astype(np.float64)
    fn = fn.sum().astype(np.float64)
    _numerator = (tp*tn - fp*fn)
    _denomerator = np.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn))
    return _handlezero_division_np(_numerator, _denomerator)

class MCC_Loss(nn.Module):
    """
    Calculates the proposed Matthews Correlation Coefficient-based loss.

    Args:
        inputs (torch.Tensor): binary predictions (cast to long internally)
        targets (torch.Tensor): binary ground truth (cast to long internally)

    NOTE(review): the inputs go through ``.long()`` and integer confusion
    counts, so no gradient flows back to them; treat the value as a metric
    rather than a trainable loss unless confirmed otherwise.
    """

    def __init__(self):
        super(MCC_Loss, self).__init__()

    def forward(self, inputs, targets):
        """
        MCC = (TP.TN - FP.FN) / sqrt((TP+FP) . (TP+FN) . (TN+FP) . (TN+FN))
        where TP, TN, FP, and FN are elements in the confusion matrix.

        Returns ``1 - MCC`` (float32), where numerators and denominators are
        summed over the stats tensor and 1 is added to the summed denominator.
        Returns 0 when the summed denominator is 0.
        """
        tp, fp, fn, tn = smp.metrics.get_stats(inputs.long(), targets.long(), mode="binary")
        # Work in float64: the product of four pixel counts overflows int64
        # for large images (e.g. 512x512).
        tp, fp, fn, tn = (stat.to(torch.float64) for stat in (tp, fp, fn, tn))

        numerator = tp * tn - fp * fn
        # Plain `+` replaces the deprecated torch.add(input, alpha, other) overload.
        denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))

        if denominator.sum() == 0.0:
            # Return 0 to avoid division by zero (on the same device as the stats).
            return torch.tensor(0.0, dtype=torch.float32, device=denominator.device)

        # Adding 1 to the denominator to avoid divide-by-zero errors.
        mcc = torch.div(numerator.sum(), denominator.sum() + 1.0)
        return (1 - mcc).to(torch.float32)

def mcc_cal_old(outputs: torch.Tensor, labels: torch.Tensor):
    """Legacy MCC computation based on ``MCC_Loss`` (returns ``1 - loss``).

    The logits are thresholded (sigmoid > 0.5) and the resulting binary mask
    is what gets scored. The raw logits used to be passed on although the
    mask had been computed, so the confusion counts were built from
    truncated logits.
    """
    mcc_loss = MCC_Loss()
    pred_mask = (outputs.sigmoid() > 0.5).float()
    return 1 - mcc_loss(pred_mask, labels)

def mcc_cal(outputs: torch.Tensor, labels: torch.Tensor):
    """Matthews correlation coefficient of predictions thresholded at 0.5.

    The confusion statistics are computed with smp, moved to the CPU as NumPy
    arrays and reduced by ``mathews_correlation_coefficient_np``.
    """
    binary_pred = (torch.sigmoid(outputs) > 0.5).float()
    confusion = smp.metrics.get_stats(binary_pred.long(), labels.long(), mode="binary")
    tp, fp, fn, tn = (stat.cpu().numpy() for stat in confusion)
    return mathews_correlation_coefficient_np(tp, fp, fn, tn)

# Function to create model
# class CustomModelBase(pt_train.CustomModelBase):
class CustomModelBase(torch.nn.Module):
    """
    Base module providing the training/validation hooks used by ``fit`` and
    ``_evaluate``: per-batch steps, epoch aggregation and epoch logging.
    """

    def __init__(self, class_weights=None, loss_function=dice_loss, accuracy_function=_accuracy, iou_function=_iou):
        super().__init__()
        self.class_weights = class_weights
        self.loss_function = loss_function
        self.accuracy_function = accuracy_function
        self.iou_function = iou_function

    def _score_batch(self, batch, forward_func):
        """Forward ``batch`` and return ``(loss, accuracy, iou)``."""
        images, labels = batch
        predictions = forward_func(images)
        return (
            self.loss_function(predictions, labels),
            self.accuracy_function(predictions, labels),
            self.iou_function(predictions, labels),
        )

    def training_step(self, batch: list, forward_func: torch.nn.Module):
        """Return ``(loss, acc, iou)`` for one training batch.

        ``forward_func`` is passed in from outside because the callable
        differs when data parallelization is used.
        """
        return self._score_batch(batch, forward_func)

    def validation_step(self, batch: list, forward_func: torch.nn.Module):
        """Return a dict with the detached loss, accuracy and IoU of one batch."""
        batch_loss, batch_acc, batch_iou = self._score_batch(batch, forward_func)
        return {'val_loss': batch_loss.detach(), 'val_acc': batch_acc, 'val_iou': batch_iou}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation results into plain Python floats."""
        def epoch_mean(key):
            return torch.stack([step[key] for step in outputs]).mean().item()

        return {key: epoch_mean(key) for key in ('val_loss', 'val_acc', 'val_iou')}

    def epoch_end(self, epoch, result):
        """Print train/validation loss, accuracy and IoU, followed by a blank line."""
        lines = [
            "train_{0}: {1:.4f}, val_{0}: {2:.4f}".format(
                metric, result['train_' + metric], result['val_' + metric]
            )
            for metric in ('loss', 'acc', 'iou')
        ]
        print("\n".join(lines))
        print()

class CreateModel(CustomModelBase):
    """Wrap a network so it gains the ``CustomModelBase`` training hooks.

    The base class is initialized with its default loss and metric functions.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Delegate the forward pass to the wrapped network."""
        wrapped = self.model
        return wrapped(x)


def get_model(encoder_name, decoder_name):
    """Build a single-class segmentation model wrapped in ``CreateModel``.

    Args:
        encoder_name: Encoder (backbone) name passed to
            segmentation_models_pytorch.
        decoder_name: One of 'unet', 'unet++', 'pspnet', 'deeplabv3+',
            'deeplabv3', 'pan'.

    Returns:
        ``CreateModel`` wrapping the network (3 input channels, 1 output
        class, no pretrained encoder weights).

    Raises:
        ValueError: If ``decoder_name`` is not supported. An unknown name
            used to surface as an ``UnboundLocalError`` on the return line.
    """
    # Map decoder names to the smp class names; the class is resolved only
    # after the name has been validated.
    architectures = {
        'unet': 'Unet',
        'unet++': 'UnetPlusPlus',
        'pspnet': 'PSPNet',
        'deeplabv3+': 'DeepLabV3Plus',
        'deeplabv3': 'DeepLabV3',
        'pan': 'PAN',
    }
    if decoder_name not in architectures:
        raise ValueError(
            "Unknown decoder '{}'; expected one of {}".format(decoder_name, sorted(architectures))
        )

    architecture = getattr(smp, architectures[decoder_name])
    model = architecture(encoder_name=encoder_name, encoder_weights=None, in_channels=3, classes=1)
    return CreateModel(model)

In [None]:
def save_results_and_plots(encoder_name, decoder_name, train_loss_history, train_acc_history, train_iou_history, val_loss_history, val_acc_history, val_iou_history):
    """Write the per-epoch metrics to an Excel sheet and save one curve plot per metric.

    Files are written to the current working directory:
    ``results_<encoder>_<decoder>.xlsx`` plus ``loss_curves_``,
    ``accuracy_curves_`` and ``iou_curves_<encoder>_<decoder>.png``.
    Epochs are numbered from 1 and follow the length of ``train_loss_history``.
    """
    epoch_numbers = list(range(1, len(train_loss_history) + 1))
    table = {
        'Epoch': epoch_numbers,
        'Train Loss': train_loss_history,
        'Val Loss': val_loss_history,
        'Train Accuracy': train_acc_history,
        'Val Accuracy': val_acc_history,
        'Train IOU': train_iou_history,
        'Val IOU': val_iou_history,
    }
    results_df = pd.DataFrame(table)
    results_df.to_excel(f'results_{encoder_name}_{decoder_name}.xlsx', index=False)

    # One figure per metric: (metric name as used in labels, file name prefix).
    figure_specs = (
        ('Loss', 'loss_curves'),
        ('Accuracy', 'accuracy_curves'),
        ('IOU', 'iou_curves'),
    )
    for metric, file_prefix in figure_specs:
        plt.figure()
        for split in ('Train', 'Val'):
            column = f'{split} {metric}'
            plt.plot(results_df['Epoch'], results_df[column], label=column)
        plt.xlabel('Epoch')
        plt.ylabel(metric)
        plt.legend()
        plt.title(f'{metric} Curves for {encoder_name} with {decoder_name}')
        plt.savefig(f'{file_prefix}_{encoder_name}_{decoder_name}.png')

    # metrics_df = pd.DataFrame([metrics])
    # metrics_df.to_excel(f'metrics_{encoder_name}_{decoder_name}.xlsx', index=False)


In [None]:
import pytorch_utils.callbacks as pt_callbacks
import pytorch_utils.training_utils as pt_train

from torch.optim import Adam as adam_opt
import traceback

def fit(
        epochs: int,
        lr: float,
        weight_decay: float,
        model: torch.nn.Module,
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
        callbacks_function=None,
        continue_training=False,
        opt_func=adam_opt,
        device=device,
        num_retries_inner=10,
        max_retry=10,
        evaluate=_evaluate
):
    """
    Meant to resemble the fit function in keras.

    Parameters
    ----------
    epochs - Set this to a high number and use callbacks to stop training early
    lr - Initial learning rate in case of a scheduler
    weight_decay - Weight decay to be fed to the optimizer (None to omit it)
    model - The model to train. Must provide the training_step / epoch_end hooks
            of CustomModelBase (looked up on model.module when USE_2_GPUS is set)
    train_loader - The training data loader
    val_loader - The validation data loader
    callbacks_function - A function that takes the model and returns a list of callbacks
    continue_training - Forwarded unchanged to callbacks_function
    opt_func - The optimizer function to use
    device - The device to use
    num_retries_inner - Number of times to retry training if it fails
    max_retry - Maximum number of times to retry training if anything other than training_step fails, like dataloader
    evaluate - The function to use for evaluation, defaults to _evaluate()

    Returns
    -------
    history - A list of dictionaries containing the loss and accuracy for each epoch
              (recorded for every epoch, whether or not callbacks are used)

    Raises
    ------
    RuntimeError - when a training step fails num_retries_inner times in a row
                   or an epoch fails max_retry times

    """

    if weight_decay is not None:
        optimizer = opt_func(model.parameters(), lr, weight_decay=weight_decay)
    else:
        optimizer = opt_func(model.parameters(), lr)

    # model.to(device)
    defined_callbacks = None  # must be None so that it can be defined in the function when it is called for the first time
    num_retry = 0
    history = []

    for epoch in range(epochs):
        model.train()  # Make sure the model is in training mode at each epoch, because it is set to eval() in evaluate()
        train_losses = []
        accuracies = []
        ious = []
        print("LR: ", optimizer.param_groups[0]['lr'])

        while num_retry < max_retry:
            # Start every attempt with empty accumulators so a failed, partial
            # pass does not leak into the statistics of the retry.
            train_losses, accuracies, ious = [], [], []
            try:
                # Wrap the train_loader with tqdm to create a progress bar
                progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", delay=1)

                for batch in progress_bar:
                    batch = [tensor.to(device).float() for tensor in batch]

                    # run the training step many times until it works
                    step_succeeded = False
                    for attempt in range(num_retries_inner):
                        try:
                            if USE_2_GPUS:
                                loss, acc, iou = model.module.training_step(batch, model)
                            else:
                                loss, acc, iou = model.training_step(batch, model)

                            step_succeeded = True
                            break
                        except Exception:
                            # `Exception` (not a bare except) so KeyboardInterrupt
                            # and SystemExit still stop the training immediately.
                            if attempt == num_retries_inner - 1:
                                traceback.print_exc()

                            # try cleaning the cache
                            torch.cuda.empty_cache()
                            gc.collect()

                    if not step_succeeded:
                        raise RuntimeError(f"Training step failed {num_retries_inner} times")

                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()

                    # Keep only the value: storing the attached tensor would
                    # keep references to its autograd graph for the whole epoch.
                    train_losses.append(loss.detach())
                    accuracies.append(acc)
                    ious.append(iou)

                    # Update the progress bar with the current loss and accuracy
                    progress_bar.set_postfix(loss=loss.item(), accuracy=acc.item())

                num_retry = 0
                break
            except Exception:
                # try cleaning the cache
                torch.cuda.empty_cache()
                gc.collect()

                num_retry += 1
                if num_retry < max_retry:
                    continue
                else:
                    traceback.print_exc()
                    raise RuntimeError(f"Training failed {max_retry} times")

        result = evaluate(model, val_loader, device)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['train_acc'] = torch.stack(accuracies).mean().item()
        result['train_iou'] = torch.stack(ious).mean().item()

        stop_flag = False
        if callbacks_function is not None:
            defined_callbacks, stop_flag = callbacks_function(
                optimiser=optimizer,
                result=result,
                model=model,
                defined_callbacks=defined_callbacks,
                continue_training=continue_training,
            )

        # Log and record every epoch, also when no callbacks are configured
        # (the history used to stay empty in that case).
        if USE_2_GPUS:
            model.module.epoch_end(epoch, result)
        else:
            model.epoch_end(epoch, result)
        history.append(result)

        if stop_flag:
            print("Early stopping triggered")
            break

    return history


def train_loop(
        model,
        encoder,
        decoder,
        optimizer,
        epochs,
        train_loader,
        val_loader,
        device=device,
        initial_lr=None,
        weight_decay=None,
        verbose=False,
        save=False
):
    """Run ``fit`` and return the reloaded best checkpoint with per-epoch metric histories.

    Args:
        model: network to train; forwarded to ``fit``.
        encoder: passed to ``str.format`` when building the checkpoint directory (``save=True``).
        decoder: passed to ``str.format`` when building the checkpoint directory (``save=True``).
        optimizer: forwarded to ``fit`` as ``opt_func`` (``train_model`` passes an
            optimizer class here, not an instance).
        epochs: forwarded to ``fit``.
        train_loader: forwarded to ``fit``.
        val_loader: forwarded to ``fit``.
        device: forwarded to ``fit``; defaults to the module-level ``device``.
        initial_lr: forwarded to ``fit`` as ``lr``.
        weight_decay: forwarded to ``fit``.
        verbose: accepted for interface compatibility; not used by this function.
        save: selects which checkpoint directory is used (see below).

    Returns:
        ``(model, train_loss, train_acc, train_iou, val_loss, val_acc, val_iou)`` where
        ``model`` is whatever ``torch.load`` reads back from ``best_model.pth`` and each
        history is a list with one entry per element of the list returned by ``fit``.
    """
    # Single source of truth for the checkpoint directory, so the file written by the
    # checkpoint callback is the same file that is reloaded after training.
    # NOTE(review): the save template contains no `{}` placeholders, so `encoder` and
    # `decoder` do not currently influence the path -- confirm the intended layout.
    if save:
        model_save_path = '/'.format(encoder, decoder)
    else:
        model_save_path = ''
    checkpoint_file = model_save_path + 'best_model.pth'

    # os.makedirs('') raises FileNotFoundError, so only create a directory when one is named.
    if model_save_path:
        os.makedirs(model_save_path, exist_ok=True)

    def get_result_list(history, metric):
        """Collect one metric across all epoch result dicts."""
        return [epoch_result[metric] for epoch_result in history]

    def get_callbacks(
            optimiser,
            result,
            model,
            defined_callbacks=None,
            continue_training=False,
            other_stats=None):
        """Per-epoch hook given to ``fit``; every callback monitors ``result["val_iou"]``.

        Creates the ``pt_callbacks.Callbacks`` objects on the first call, then invokes
        reduce_lr_on_plateau, model_checkpoint and early_stopping on the 'val' callbacks.
        Returns ``(defined_callbacks, stop_flag)``.
        """
        if defined_callbacks is None:
            defined_callbacks = {
                'val': pt_callbacks.Callbacks(optimizer=optimiser,
                                              model_save_path=checkpoint_file,
                                              training_stats_path=model_save_path + 'training_stats_val',
                                              continue_training=continue_training),

                'train': pt_callbacks.Callbacks(optimizer=optimiser,
                                                training_stats_path=model_save_path + 'training_stats_train',
                                                continue_training=continue_training)
            }

        val_callbacks = defined_callbacks['val']
        val_callbacks.reduce_lr_on_plateau(
            monitor_value=result["val_iou"],
            mode='max',
            factor=0.5,
            patience=5,
            indicator_text="Val LR scheduler: "
        )
        val_callbacks.model_checkpoint(
            model=model,
            monitor_value=result["val_iou"],
            mode='max',
            indicator_text="Val checkpoint: "
        )
        stop_flag = val_callbacks.early_stopping(
            monitor_value=result["val_iou"],
            mode='max',
            patience=20,
            indicator_text="Early stopping: "
        )
        val_callbacks.clear_memory()
        print("_________")

        return defined_callbacks, stop_flag

    history = fit(
        epochs=epochs,
        lr=initial_lr,
        weight_decay=weight_decay,
        model=model,
        device=device,
        callbacks_function=get_callbacks,
        train_loader=train_loader,
        val_loader=val_loader,
        opt_func=optimizer,
    )

    # Drop the last-epoch weights and reload the best checkpoint from where it was written.
    # NOTE(review): torch.load unpickles the file; presumably pt_callbacks stores the whole
    # model object rather than a state_dict -- confirm against pt_callbacks.
    del model
    model = torch.load(checkpoint_file)

    train_loss_history = get_result_list(history, "train_loss")
    train_acc_history = get_result_list(history, "train_acc")
    train_iou_history = get_result_list(history, "train_iou")
    val_loss_history = get_result_list(history, "val_loss")
    val_acc_history = get_result_list(history, "val_acc")
    val_iou_history = get_result_list(history, "val_iou")

    return model, train_loss_history, train_acc_history, train_iou_history, val_loss_history, val_acc_history, val_iou_history


def train_model(model, encoder, decoder, train_loader, val_loader, device, kwargs):
    """Unpack the training configuration and delegate to ``train_loop``.

    Args:
        model: network to train; moved to ``device`` before training.
        encoder: forwarded unchanged to ``train_loop``.
        decoder: forwarded unchanged to ``train_loop``.
        train_loader: forwarded unchanged to ``train_loop``.
        val_loader: forwarded unchanged to ``train_loop``.
        device: target device for the model and for ``train_loop``.
        kwargs: configuration dict read as ``kwargs["epochs"]``, ``kwargs["save"]`` and
            ``kwargs["optim"] = {"name": ..., "params": {"lr": ..., "weight_decay": ...}}``.

    Returns:
        The 7-tuple returned by ``train_loop``:
        ``(model, train_loss, train_acc, train_iou, val_loss, val_acc, val_iou)``.
    """
    print(kwargs)

    epochs = kwargs.get("epochs")
    save = kwargs.get("save")

    model = model.to(device)

    # Resolve the optimizer class by name. Only the class is handed on (train_loop
    # forwards it to fit as opt_func together with lr / weight_decay), so no optimizer
    # instance is built here.
    optim_args = kwargs.get("optim")
    optim_params = optim_args["params"]
    print(optim_params)
    optimizer_cls = get_optimizer_by_name(optim_args.get("name"))
    lr = optim_params["lr"]
    weight_decay = optim_params["weight_decay"]

    model, train_loss_history, train_acc_history, train_iou_history, val_loss_history, val_acc_history, val_iou_history = train_loop(
        model, encoder, decoder, optimizer_cls, epochs, train_loader, val_loader,
        device=device, initial_lr=lr, weight_decay=weight_decay, verbose=True, save=save)

    return model, train_loss_history, train_acc_history, train_iou_history, val_loss_history, val_acc_history, val_iou_history

# Module-level MCC_Loss instance.
# NOTE(review): nothing in the code below this line references `mcc_loss`
# (evaluate_model uses `mcc_cal` instead) -- confirm it is still needed.
mcc_loss = MCC_Loss()

def evaluate_model(model, test_loader, device, verbose=True, eps=1e-10):
    """Run ``model`` over ``test_loader`` and compute the test-set metric table.

    The model is put in eval mode, every batch is passed through the model and a
    sigmoid, and all targets/predictions are concatenated on the CPU before the
    metric helpers are applied once to the full set.

    Args:
        model: callable network returning logits for a batch of inputs.
        test_loader: iterable yielding ``(inputs, masks)`` batches.
        device: device the inputs are moved to before the forward pass.
        verbose: when true, print a header before and a separator after evaluation.
        eps: accepted for interface compatibility; not used by this function.

    Returns:
        dict mapping metric name to a rounded Python float.
    """
    if verbose:
        print('--------------------------------------------')
        print('Test metrics (on test set)')

    model.eval()

    y_true = []
    y_pred = []

    with torch.no_grad():
        for inputs, masks in test_loader:
            outputs = model(inputs.to(device).float()).sigmoid()
            # Targets are only needed on the CPU, so skip the round trip through `device`.
            y_true.append(masks.float().cpu())
            y_pred.append(outputs.cpu())

    y_true = torch.cat(y_true)
    y_pred = torch.cat(y_pred)

    # Compute shared quantities once instead of re-running the helpers per table entry.
    accuracy = _accuracy(y_true, y_pred).item()
    precision = precision_score(y_true, y_pred).item()
    recall = recall_score(y_true, y_pred).item()
    fpr = false_positive_rate(y_true, y_pred).item()

    metrics = {
        'Overall Accuracy': round(accuracy * 100, 2),
        'Precision': round(precision, 2),
        'Recall': round(recall, 2),
        'F1 Score': round(f1_score(y_true, y_pred).item(), 2),
        # Specificity is the true-negative rate, i.e. 1 - FPR; it was previously a copy
        # of Recall. Assumes false_positive_rate returns FP / (FP + TN) -- TODO confirm.
        'Specificity': round(1 - fpr, 2),
        'False Positive Rate': round(fpr, 4),
        'False Negative Rate': round(false_negative_rate(y_true, y_pred).item(), 4),
        'False Discovery Rate': round(false_discovery_rate(y_true, y_pred).item(), 4),
        'False Omission Rate': round(false_omission_rate(y_true, y_pred).item(), 7),
        'Misclassification Rate': round(1 - accuracy, 2),
        'Matthew\'s Correlation Coefficient': round(mcc_cal(y_true, y_pred).item(), 2),
        'IoU (Jaccard Index)': round(_iou(y_true, y_pred).item(), 2),
        'IoU_plant': round(_iou_plant(y_true, y_pred).item(), 2),
        # Defined here as sqrt(precision * recall).
        'Geometric Mean': round(np.sqrt(precision * recall), 2)
    }
    if verbose:
        print('--------------------------------------------')
        print()

    return metrics


## Train non-transformer models

### Training loop

In [None]:
# Fall back to the CPU when no GPU is visible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Spreadsheet that is read here and rewritten by the sweep loop below
excel_file = ''

# Training configuration handed to train_model
kwargs = {
    'epochs': 10,
    'optim': {
        'name': 'Adam',
        'params': {'lr': 0.01, 'weight_decay': 3.310305423548208e-05},
    },
    'save': False,
}

try:
    existing_df = pd.read_excel(excel_file)
except FileNotFoundError:
    print("Experiment from the begining!!!")
    processed_models = []
    # No previous results: start every column of the results table empty
    results = {
        column: []
        for column in (
            'Encoder',
            'Decoder',
            'Train Loss',
            'Train Accuracy',
            'Train IoU',
            'Val Loss',
            'Val Accuracy',
            'Val IoU',
            'Overall Accuracy',
            'Precision',
            'Recall',
            'F1 Score',
            'Specificity',
            'False Positive Rate',
            'False Negative Rate',
            'False Discovery Rate',
            'False Omission Rate',
            'Misclassification Rate',
            'Matthew\'s Correlation Coefficient',
            'IoU (Jaccard Index)',
            'IoU_plant',
            'Geometric Mean',
        )
    }
else:
    # Resume: keep earlier rows and remember which "<encoder>--<decoder>" pairs are done
    results = existing_df.to_dict(orient='list')
    processed_models = list(existing_df.apply(lambda row: f"{row['Encoder']}--{row['Decoder']}", axis=1))

# Train and test every encoder/decoder pair that is not already in the spreadsheet.
for encoder_name in encoders:
    for decoder_name in decoders:
        model_info = f"{encoder_name}--{decoder_name}"
        if model_info in processed_models:
            print(f'Skipping {model_info}, already processed.')
            continue
        print(f'Training {encoder_name} with {decoder_name}')

        if USE_2_GPUS:
            model = get_model(encoder_name, decoder_name).cuda()
            model = nn.DataParallel(model, device_ids=[0,1])
        else:
            model = get_model(encoder_name, decoder_name).to(device)

        # Train the model and collect the per-epoch metric histories
        model, train_loss_history, train_acc_history, train_iou_history, val_loss_history, val_acc_history, val_iou_history = train_model(model, encoder_name, decoder_name, train_loader, val_loader, device, kwargs)

        results_df = pd.DataFrame({
            'Epoch': list(range(1, len(train_loss_history) + 1)),
            'Train Loss': train_loss_history,
            'Val Loss': val_loss_history,
            'Train Accuracy': train_acc_history,
            'Val Accuracy': val_acc_history,
            'Train IOU': train_iou_history,
            'Val IOU': val_iou_history
        })

        # One train-vs-val figure per metric, saved as <metric>_curves_<encoder>_<decoder>.png
        for curve_name in ('Loss', 'Accuracy', 'IOU'):
            fig = plt.figure()
            plt.plot(results_df['Epoch'], results_df[f'Train {curve_name}'], label=f'Train {curve_name}')
            plt.plot(results_df['Epoch'], results_df[f'Val {curve_name}'], label=f'Val {curve_name}')
            plt.xlabel('Epoch')
            plt.ylabel(curve_name)
            plt.legend()
            plt.title(f'{curve_name} Curves for {encoder_name} with {decoder_name}')
            plt.savefig(f'{curve_name.lower()}_curves_{encoder_name}_{decoder_name}.png')
            plt.show()
            # Release the figure so a long sweep does not accumulate open figures
            plt.close(fig)

        # Evaluate on the test set
        test_metrics = evaluate_model(model, test_loader, device=device)
        for key, value in test_metrics.items():
            if key in ["False Discovery Rate", "False Negative Rate", "False Positive Rate"]:
                print(f'"{key}": {value:.4f},')
            elif key == "False Omission Rate":
                print(f'"{key}": {value:.7f},')
            else:
                print(f'"{key}": {value:.2f},')

        # Record the best value of each training curve plus the test metrics
        results['Encoder'].append(encoder_name)
        results['Decoder'].append(decoder_name)
        results['Train Loss'].append(np.min(train_loss_history))
        results['Train Accuracy'].append(np.max(train_acc_history))
        results['Train IoU'].append(np.max(train_iou_history))
        results['Val Loss'].append(np.min(val_loss_history))
        results['Val Accuracy'].append(np.max(val_acc_history))
        results['Val IoU'].append(np.max(val_iou_history))
        for metric_name, metric_value in test_metrics.items():
            results[metric_name].append(metric_value)
        processed_models.append(model_info)

        # Persist after every model so an interrupted sweep can be resumed.
        # NOTE(review): `results` already holds the rows loaded from the spreadsheet, so
        # the concat below relies on drop_duplicates to avoid doubling them.
        df_new = pd.DataFrame(results)
        try:
            df_existing = pd.read_excel(excel_file)
            df_combined = pd.concat([df_existing, df_new]).drop_duplicates().reset_index(drop=True)
        except FileNotFoundError:
            df_combined = df_new

        df_combined.to_excel(excel_file, index=False)

        print(f'Results saved to {excel_file}')


### Train 2 best models + save pt files + test

In [None]:
# Fall back to the CPU when no GPU is visible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Enter your best encoder-decoder combination here ("<encoder>--<decoder>")
best_models = ['inceptionv4--unet']

# Training configuration handed to train_model
kwargs = {
    'epochs': 10000,
    'optim': {
        'name': 'Adam',
        'params': {'lr': 0.01, 'weight_decay': 3.310305423548208e-05},
    },
    'save': True,
}

# Spreadsheet that is read here and rewritten after the training loop below
excel_file = ''

try:
    existing_df = pd.read_excel(excel_file)
except FileNotFoundError:
    print("Experiment from the begining!!!")
    processed_models = []
    # No previous results: start every column of the results table empty
    results = {
        column: []
        for column in (
            'Encoder',
            'Decoder',
            'Train Loss',
            'Train Accuracy',
            'Train IoU',
            'Val Loss',
            'Val Accuracy',
            'Val IoU',
            'Overall Accuracy',
            'Precision',
            'Recall',
            'F1 Score',
            'Specificity',
            'False Positive Rate',
            'False Negative Rate',
            'False Discovery Rate',
            'False Omission Rate',
            'Misclassification Rate',
            'Matthew\'s Correlation Coefficient',
            'IoU (Jaccard Index)',
            'IoU_plant',
            'Geometric Mean',
        )
    }
else:
    # Resume: keep earlier rows and remember which "<encoder>--<decoder>" pairs are done
    results = existing_df.to_dict(orient='list')
    processed_models = list(existing_df.apply(lambda row: f"{row['Encoder']}--{row['Decoder']}", axis=1))

# Train, plot and test each selected "<encoder>--<decoder>" pair, then save all rows.
for model_info in best_models:
    if model_info in processed_models:
        print(f'Skipping {model_info}, already processed.')
        continue
    encoder_name, decoder_name = model_info.split('--')
    print(f'Training {encoder_name} with {decoder_name}')
    if USE_2_GPUS:
        model = get_model(encoder_name, decoder_name)
        model = nn.DataParallel(model).to(device)
    else:
        model = get_model(encoder_name, decoder_name).to(device)

    # Train the model and collect the per-epoch metric histories
    model, train_loss_history, train_acc_history, train_iou_history, val_loss_history, val_acc_history, val_iou_history = train_model(model, encoder_name, decoder_name, train_loader, val_loader, device, kwargs)

    results_df = pd.DataFrame({
        'Epoch': list(range(1, len(train_loss_history) + 1)),
        'Train Loss': train_loss_history,
        'Val Loss': val_loss_history,
        'Train Accuracy': train_acc_history,
        'Val Accuracy': val_acc_history,
        'Train IOU': train_iou_history,
        'Val IOU': val_iou_history
    })

    # One train-vs-val figure per metric, saved as <metric>_curves_<encoder>_<decoder>.png
    for curve_name in ('Loss', 'Accuracy', 'IOU'):
        fig = plt.figure()
        plt.plot(results_df['Epoch'], results_df[f'Train {curve_name}'], label=f'Train {curve_name}', marker=11)
        plt.plot(results_df['Epoch'], results_df[f'Val {curve_name}'], label=f'Val {curve_name}', marker=11)
        plt.xlabel('Epoch')
        plt.ylabel(curve_name)
        plt.legend()
        plt.title(f'{curve_name} Curves for {encoder_name} with {decoder_name}')
        plt.savefig(f'{curve_name.lower()}_curves_{encoder_name}_{decoder_name}.png')
        plt.show()
        # Release the figure once it has been saved and displayed
        plt.close(fig)

    # Evaluate on the test set
    test_metrics = evaluate_model(model, test_loader, device=device)
    for key, value in test_metrics.items():
        if key in ["False Discovery Rate", "False Negative Rate", "False Positive Rate"]:
            print(f'"{key}": {value:.4f}')
        elif key == "False Omission Rate":
            print(f'"{key}": {value:.7f}')
        else:
            print(f'"{key}": {value:.2f}')
    print('--------------------------------------------')

    # Record the best value of each training curve plus the test metrics
    results['Encoder'].append(encoder_name)
    results['Decoder'].append(decoder_name)
    results['Train Loss'].append(np.min(train_loss_history))
    results['Train Accuracy'].append(np.max(train_acc_history))
    results['Train IoU'].append(np.max(train_iou_history))
    results['Val Loss'].append(np.min(val_loss_history))
    results['Val Accuracy'].append(np.max(val_acc_history))
    results['Val IoU'].append(np.max(val_iou_history))
    for metric_name, metric_value in test_metrics.items():
        results[metric_name].append(metric_value)
    processed_models.append(model_info)

# Convert results to pandas DataFrame
df_new = pd.DataFrame(results)

# Merge with whatever is currently on disk.
# NOTE(review): `results` already holds the rows loaded from the spreadsheet, so the
# concat relies on drop_duplicates to avoid doubling them.
try:
    df_existing = pd.read_excel(excel_file)
    df_combined = pd.concat([df_existing, df_new]).drop_duplicates().reset_index(drop=True)
except FileNotFoundError:
    df_combined = df_new

# Save DataFrame to Excel
df_combined.to_excel(excel_file, index=False)

print(f'Results saved to {excel_file}')

### Show some results

In [None]:
# Show image / ground truth / prediction side by side for one test batch.
# NOTE(review): uses whatever `model` the previous training cell left in the namespace.
# A single fetch is enough; the extra discarded `next(iter(test_loader))` only loaded a
# batch and threw it away.
images, masks = next(iter(test_loader))
model.eval()
with torch.no_grad():
    logits = model(images.to(device).float())
pr_masks = logits.sigmoid()

for image, gt_mask, pr_mask in zip(images, masks, pr_masks):
    fig = plt.figure(figsize=(10, 5))

    panels = (
        ("Image", image.numpy().transpose(1, 2, 0)),        # convert CHW -> HWC
        ("Ground truth", gt_mask.numpy().squeeze()),         # squeeze the single class dim
        ("Prediction", pr_mask.cpu().numpy().squeeze()),     # squeeze the single class dim
    )
    for position, (panel_title, picture) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(picture)
        plt.title(panel_title)
        plt.axis("off")

    plt.show()
    # Release the figure so one figure per sample does not pile up in memory
    plt.close(fig)

# Segformer Training

In [None]:
def get_model_v2(encoder_name, decoder_name):
    """Build a 3-channel-in, 1-class-out segmentation model and wrap it in ``CreateModel``.

    The architecture is created with ``encoder_weights=None``.

    Args:
        encoder_name: encoder identifier passed to segmentation_models_pytorch.
        decoder_name: one of 'unet', 'unet++', 'pspnet', 'deeplabv3+', 'deeplabv3'.

    Returns:
        ``CreateModel(model)`` for the selected architecture.

    Raises:
        ValueError: if ``decoder_name`` is not one of the supported names (previously
            this fell through to an UnboundLocalError at the return statement).
    """
    # Decoder name -> class name inside segmentation_models_pytorch
    architecture_names = {
        'unet': 'Unet',
        'unet++': 'UnetPlusPlus',
        'pspnet': 'PSPNet',
        'deeplabv3+': 'DeepLabV3Plus',
        'deeplabv3': 'DeepLabV3',
    }
    if decoder_name not in architecture_names:
        raise ValueError(
            f"Unsupported decoder {decoder_name!r}; expected one of {sorted(architecture_names)}"
        )

    architecture = getattr(smp, architecture_names[decoder_name])
    model = architecture(encoder_name=encoder_name, encoder_weights=None, in_channels=3, classes=1)
    return CreateModel(model)

In [None]:
# Encoder/decoder grid for the Segformer sweep; rebinds the notebook-level
# `encoders` and `decoders` names used by the sweep cells.
decoders = ['unet', 'unet++', 'pspnet', 'deeplabv3+']
encoders = ['mit_b1', 'mit_b3', 'mit_b5']

print("number of encoders {}".format(len(encoders)))
print("number of decoders {}".format(len(decoders)))

In [None]:
# Fall back to the CPU when no GPU is visible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Spreadsheet that is read here and rewritten by the sweep loop below
excel_file = ''

# Training configuration handed to train_model
kwargs = {'epochs': 10000, 'optim': {'name': 'Adam', 'params': {
     'lr': 0.01, 'weight_decay': 3.310305423548208e-05}}, 'save': False}

try:
    existing_df = pd.read_excel(excel_file)
    results = existing_df.to_dict(orient='list')
    processed_models = list(existing_df.apply(lambda row: f"{row['Encoder']}--{row['Decoder']}", axis=1))
except FileNotFoundError:
    print("Experiment from the begining!!!")
    processed_models = []
    # Initialize lists to store metrics. The keys must cover every metric name returned
    # by evaluate_model, because the sweep loop appends to results[metric_name].
    results = {
        'Encoder': [],
        'Decoder': [],
        'Train Loss': [],
        'Train Accuracy': [],
        'Train IoU': [],
        'Val Loss': [],
        'Val Accuracy': [],
        'Val IoU': [],
        'Overall Accuracy': [],
        'Precision': [],
        'Recall': [],
        'F1 Score': [],
        'Specificity': [],
        'False Positive Rate': [],
        'False Negative Rate': [],
        'False Discovery Rate': [],
        'False Omission Rate': [],
        'Misclassification Rate': [],
        'Matthew\'s Correlation Coefficient': [],
        'IoU (Jaccard Index)': [],
        # Was a second 'IoU (Jaccard Index)' entry, which left 'IoU_plant' missing and
        # made the sweep loop fail with KeyError after the first model finished training.
        'IoU_plant' : [],
        'Geometric Mean': []
    }

# Train and test every encoder/decoder pair that is not already in the spreadsheet.
for encoder_name in encoders:
    for decoder_name in decoders:
        model_info = f"{encoder_name}--{decoder_name}"
        if model_info in processed_models:
            print(f'Skipping {model_info}, already processed.')
            continue
        print(f'Training {encoder_name} with {decoder_name}')

        if USE_2_GPUS:
            model = get_model_v2(encoder_name, decoder_name).cuda()
            model = nn.DataParallel(model, device_ids=[0,1])
        else:
            model = get_model_v2(encoder_name, decoder_name).to(device)

        # Train the model and collect the per-epoch metric histories
        model, train_loss_history, train_acc_history, train_iou_history, val_loss_history, val_acc_history, val_iou_history = train_model(model, encoder_name, decoder_name, train_loader, val_loader, device, kwargs)

        results_df = pd.DataFrame({
            'Epoch': list(range(1, len(train_loss_history) + 1)),
            'Train Loss': train_loss_history,
            'Val Loss': val_loss_history,
            'Train Accuracy': train_acc_history,
            'Val Accuracy': val_acc_history,
            'Train IOU': train_iou_history,
            'Val IOU': val_iou_history
        })

        # One train-vs-val figure per metric, saved as <metric>_curves_<encoder>_<decoder>.png
        for curve_name in ('Loss', 'Accuracy', 'IOU'):
            fig = plt.figure()
            plt.plot(results_df['Epoch'], results_df[f'Train {curve_name}'], label=f'Train {curve_name}')
            plt.plot(results_df['Epoch'], results_df[f'Val {curve_name}'], label=f'Val {curve_name}')
            plt.xlabel('Epoch')
            plt.ylabel(curve_name)
            plt.legend()
            plt.title(f'{curve_name} Curves for {encoder_name} with {decoder_name}')
            plt.savefig(f'{curve_name.lower()}_curves_{encoder_name}_{decoder_name}.png')
            plt.show()
            # Release the figure so a long sweep does not accumulate open figures
            plt.close(fig)

        # Evaluate on the test set
        test_metrics = evaluate_model(model, test_loader, device=device)
        for key, value in test_metrics.items():
            if key in ["False Discovery Rate", "False Negative Rate", "False Positive Rate"]:
                print(f'"{key}": {value:.4f},')
            elif key == "False Omission Rate":
                print(f'"{key}": {value:.7f},')
            else:
                print(f'"{key}": {value:.2f},')

        # Record the best value of each training curve plus the test metrics
        results['Encoder'].append(encoder_name)
        results['Decoder'].append(decoder_name)
        results['Train Loss'].append(np.min(train_loss_history))
        results['Train Accuracy'].append(np.max(train_acc_history))
        results['Train IoU'].append(np.max(train_iou_history))
        results['Val Loss'].append(np.min(val_loss_history))
        results['Val Accuracy'].append(np.max(val_acc_history))
        results['Val IoU'].append(np.max(val_iou_history))
        for metric_name, metric_value in test_metrics.items():
            results[metric_name].append(metric_value)
        processed_models.append(model_info)

        # Persist after every model so an interrupted sweep can be resumed.
        # NOTE(review): `results` already holds the rows loaded from the spreadsheet, so
        # the concat below relies on drop_duplicates to avoid doubling them.
        df_new = pd.DataFrame(results)
        try:
            df_existing = pd.read_excel(excel_file)
            df_combined = pd.concat([df_existing, df_new]).drop_duplicates().reset_index(drop=True)
        except FileNotFoundError:
            df_combined = df_new

        df_combined.to_excel(excel_file, index=False)

        print(f'Results saved to {excel_file}')


### Run 2 best models, save and test

In [None]:
# Fall back to the CPU when no GPU is visible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Enter your best encoder-decoder combination here ("<encoder>--<decoder>")
best_models = ['mit_b5--unet']

# Training configuration handed to train_model
kwargs = {
    'epochs': 10000,
    'optim': {
        'name': 'Adam',
        'params': {'lr': 0.01, 'weight_decay': 3.310305423548208e-05},
    },
    'save': True,
}

# Spreadsheet that is read here and rewritten after the training loop below
excel_file = ''

try:
    existing_df = pd.read_excel(excel_file)
except FileNotFoundError:
    print("Experiment from the begining!!!")
    processed_models = []
    # No previous results: start every column of the results table empty
    results = {
        column: []
        for column in (
            'Encoder',
            'Decoder',
            'Train Loss',
            'Train Accuracy',
            'Train IoU',
            'Val Loss',
            'Val Accuracy',
            'Val IoU',
            'Overall Accuracy',
            'Precision',
            'Recall',
            'F1 Score',
            'Specificity',
            'False Positive Rate',
            'False Negative Rate',
            'False Discovery Rate',
            'False Omission Rate',
            'Misclassification Rate',
            'Matthew\'s Correlation Coefficient',
            'IoU (Jaccard Index)',
            'IoU_plant',
            'Geometric Mean',
        )
    }
else:
    # Resume: keep earlier rows and remember which "<encoder>--<decoder>" pairs are done
    results = existing_df.to_dict(orient='list')
    processed_models = list(existing_df.apply(lambda row: f"{row['Encoder']}--{row['Decoder']}", axis=1))

# Train, plot and test each selected "<encoder>--<decoder>" pair, then save all rows.
for model_info in best_models:
    if model_info in processed_models:
        print(f'Skipping {model_info}, already processed.')
        continue
    encoder_name, decoder_name = model_info.split('--')
    print(f'Training {encoder_name} with {decoder_name}')
    # NOTE(review): the Segformer sweep builds its models with get_model_v2, while this
    # cell uses get_model -- confirm that get_model is intended for the mit_* encoders.
    if USE_2_GPUS:
        model = get_model(encoder_name, decoder_name)
        model = nn.DataParallel(model).to(device)
    else:
        model = get_model(encoder_name, decoder_name).to(device)

    # Train the model and collect the per-epoch metric histories
    model, train_loss_history, train_acc_history, train_iou_history, val_loss_history, val_acc_history, val_iou_history = train_model(model, encoder_name, decoder_name, train_loader, val_loader, device, kwargs)

    results_df = pd.DataFrame({
        'Epoch': list(range(1, len(train_loss_history) + 1)),
        'Train Loss': train_loss_history,
        'Val Loss': val_loss_history,
        'Train Accuracy': train_acc_history,
        'Val Accuracy': val_acc_history,
        'Train IOU': train_iou_history,
        'Val IOU': val_iou_history
    })

    # One train-vs-val figure per metric, saved as <metric>_curves_<encoder>_<decoder>.png
    for curve_name in ('Loss', 'Accuracy', 'IOU'):
        fig = plt.figure()
        plt.plot(results_df['Epoch'], results_df[f'Train {curve_name}'], label=f'Train {curve_name}', marker=11)
        plt.plot(results_df['Epoch'], results_df[f'Val {curve_name}'], label=f'Val {curve_name}', marker=11)
        plt.xlabel('Epoch')
        plt.ylabel(curve_name)
        plt.legend()
        plt.title(f'{curve_name} Curves for {encoder_name} with {decoder_name}')
        plt.savefig(f'{curve_name.lower()}_curves_{encoder_name}_{decoder_name}.png')
        plt.show()
        # Release the figure once it has been saved and displayed
        plt.close(fig)

    # Evaluate on the test set
    test_metrics = evaluate_model(model, test_loader, device=device)
    for key, value in test_metrics.items():
        if key in ["False Discovery Rate", "False Negative Rate", "False Positive Rate"]:
            print(f'"{key}": {value:.4f}')
        elif key == "False Omission Rate":
            print(f'"{key}": {value:.7f}')
        else:
            print(f'"{key}": {value:.2f}')
    print('--------------------------------------------')

    # Record the best value of each training curve plus the test metrics
    results['Encoder'].append(encoder_name)
    results['Decoder'].append(decoder_name)
    results['Train Loss'].append(np.min(train_loss_history))
    results['Train Accuracy'].append(np.max(train_acc_history))
    results['Train IoU'].append(np.max(train_iou_history))
    results['Val Loss'].append(np.min(val_loss_history))
    results['Val Accuracy'].append(np.max(val_acc_history))
    results['Val IoU'].append(np.max(val_iou_history))
    for metric_name, metric_value in test_metrics.items():
        results[metric_name].append(metric_value)
    processed_models.append(model_info)

# Convert results to pandas DataFrame
df_new = pd.DataFrame(results)

# Merge with whatever is currently on disk.
# NOTE(review): `results` already holds the rows loaded from the spreadsheet, so the
# concat relies on drop_duplicates to avoid doubling them.
try:
    df_existing = pd.read_excel(excel_file)
    df_combined = pd.concat([df_existing, df_new]).drop_duplicates().reset_index(drop=True)
except FileNotFoundError:
    df_combined = df_new

# Save DataFrame to Excel
df_combined.to_excel(excel_file, index=False)

print(f'Results saved to {excel_file}')

In [None]:
# Show image / ground truth / prediction side by side for one test batch.
# NOTE(review): uses whatever `model` the previous training cell left in the namespace.
# A single fetch is enough; the extra discarded `next(iter(test_loader))` only loaded a
# batch and threw it away.
images, masks = next(iter(test_loader))
model.eval()
with torch.no_grad():
    logits = model(images.to(device).float())
pr_masks = logits.sigmoid()

for image, gt_mask, pr_mask in zip(images, masks, pr_masks):
    fig = plt.figure(figsize=(10, 5))

    panels = (
        ("Image", image.numpy().transpose(1, 2, 0)),        # convert CHW -> HWC
        ("Ground truth", gt_mask.numpy().squeeze()),         # squeeze the single class dim
        ("Prediction", pr_mask.cpu().numpy().squeeze()),     # squeeze the single class dim
    )
    for position, (panel_title, picture) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(picture)
        plt.title(panel_title)
        plt.axis("off")

    plt.show()
    # Release the figure so one figure per sample does not pile up in memory
    plt.close(fig)