In [None]:

!pip install torchgeometry


In [None]:
!pip show torch


In [None]:

# from torchgeometry.losses import one_hot
import os
import os.path as osp

import cv2
import time
import imageio
import random
import math
import numbers
from tqdm import tqdm
from collections import OrderedDict
import wandb
import glob
from pathlib import Path
import argparse
from typing import Dict,Tuple,Optional,List

import pandas as pd
import numpy as np
from PIL import Image, ImageOps, ImageEnhance
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch import Tensor
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import (Pad, ColorJitter, Resize, FiveCrop, RandomResizedCrop,
                                    RandomHorizontalFlip, RandomRotation, RandomVerticalFlip,
                                    PILToTensor, ToPILImage, Compose, InterpolationMode)

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

import albumentations as A
import albumentations as albu
from albumentations.pytorch.transforms import ToTensorV2

import multiprocessing.pool as mpp
import multiprocessing as mp

from scipy.ndimage.morphology import generate_binary_structure, binary_erosion
from scipy.ndimage import maximum_filter


In [None]:
# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

# Parameters

In [None]:
# Number of segmentation classes.
num_classes = 6

# Number of training epochs.
epochs = 100

# Optimisation hyperparameters.
learning_rate = 0.008
batch_size = 16
display_step = 2

# Checkpoint written during training, and the pretrained checkpoint that can be resumed from.
checkpoint_path = "/kaggle/working/unet_model.pth"
pretrained_path = "/kaggle/input/model2/unet_model (1).pth"

# Per-epoch history buffers for loss and accuracy (four independent lists).
loss_epoch_array, train_accuracy, test_accuracy, valid_accuracy = [], [], [], []

In [None]:
SEED = 42


def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, the ``PYTHONHASHSEED`` environment variable, NumPy and
    PyTorch (CPU and all CUDA devices), and switches cuDNN to deterministic mode.

    :param seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Auto-tuning may pick a different convolution algorithm from run to run,
    # which defeats the deterministic flag above, so it has to stay off.
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)

# Colour triplets matched by rgb_to_2D_label, with the label index each one is mapped to.
ImSurf = np.asarray((255, 255, 255))    # label 0
Building = np.asarray((255, 0, 0))      # label 1
LowVeg = np.asarray((255, 255, 0))      # label 2
Tree = np.asarray((0, 255, 0))          # label 3
Car = np.asarray((0, 255, 255))         # label 4
Clutter = np.asarray((0, 0, 255))       # label 5
Boundary = np.asarray((0, 0, 0))        # label 6
# Boundary (label 6) is not counted as a class.
num_classes = 6





def pv2rgb(mask):
    """Colourise a 2-D map of class indices.

    :param mask: Array of class indices; its first two axes give the output height and width.
    :return: ``uint8`` array of shape ``(h, w, 3)``. Indices 0-5 get a fixed colour,
        any other value stays ``[0, 0, 0]``.
    """
    height, width = mask.shape[0], mask.shape[1]
    coloured = np.zeros(shape=(height, width, 3), dtype=np.uint8)
    expanded = mask[np.newaxis, :, :]
    colour_table = (
        (3, [0, 255, 0]),
        (0, [255, 255, 255]),
        (1, [255, 0, 0]),
        (2, [255, 255, 0]),
        (4, [0, 204, 255]),
        (5, [0, 0, 255]),
    )
    for class_index, colour in colour_table:
        coloured[np.all(expanded == class_index, axis=0)] = colour
    return coloured
def label2rgb(mask):
    """Turn a 2-D label map into a colour image (same colour table as ``pv2rgb``).

    :param mask: Array of class indices; its first two axes give the output height and width.
    :return: ``uint8`` array of shape ``(h, w, 3)``; labels outside 0-5 are left black.
    """
    rows, cols = mask.shape[0], mask.shape[1]
    mask_rgb = np.zeros(shape=(rows, cols, 3), dtype=np.uint8)
    stacked = mask[np.newaxis, :, :]
    palette = {
        3: [0, 255, 0],
        0: [255, 255, 255],
        1: [255, 0, 0],
        2: [255, 255, 0],
        4: [0, 204, 255],
        5: [0, 0, 255],
    }
    for label, rgb in palette.items():
        selected = np.all(stacked == label, axis=0)
        mask_rgb[selected] = rgb
    return mask_rgb



def car_color_replace(mask):
    """Convert an RGB mask to BGR and repaint pixels equal to ``[0, 255, 255]`` as ``[0, 204, 255]``.

    :param mask: RGB mask; it is copied, so the caller's object is not modified.
    :return: The converted array in BGR channel order.
    """
    bgr = cv2.cvtColor(np.array(mask.copy()), cv2.COLOR_RGB2BGR)
    # NOTE(review): the comparison runs on BGR data, so [0, 255, 255] here corresponds
    # to RGB (255, 255, 0) rather than the Car colour (0, 255, 255) - confirm the intent.
    matches = np.all(bgr == [0, 255, 255], axis=-1)
    bgr[matches] = [0, 204, 255]
    return bgr


def rgb_to_2D_label(_label):
    """Map a colour-coded label image to a 2-D array of class indices.

    :param _label: Array laid out as height x width x channel.
    :return: ``uint8`` array of shape ``(H, W)`` holding 0-6 for the seven known colours.
        Pixels that match none of them keep the initial value 0.
    """
    channels_first = _label.transpose(2, 0, 1)
    label_seg = np.zeros(channels_first.shape[1:], dtype=np.uint8)
    # Back to a height x width x channel view so each pixel can be compared with a colour.
    pixels = channels_first.transpose([1, 2, 0])
    colours_by_index = (ImSurf, Building, LowVeg, Tree, Car, Clutter, Boundary)
    for class_index, colour in enumerate(colours_by_index):
        label_seg[np.all(pixels == colour, axis=-1)] = class_index
    return label_seg



    

 # Split Data

 # Dataloader

In [None]:
# Class names, ordered by class index.
CLASSES = ('ImSurf', 'Building', 'LowVeg', 'Tree', 'Car', 'Clutter')
# Index -> name mapping (used as `class_labels` when logging masks to wandb).
class_label = dict(enumerate(CLASSES))
# NOTE(review): these triplets are the reverse channel order of the colours used by
# pv2rgb / label2rgb (e.g. Building is [0, 0, 255] here but [255, 0, 0] there) - confirm
# whether PALETTE is meant to be BGR.
PALETTE = [
    [255, 255, 255],
    [0, 0, 255],
    [0, 255, 255],
    [0, 255, 0],
    [255, 204, 0],
    [255, 0, 0],
]

# Image sizes as (height, width).
ORIGIN_IMG_SIZE = (256, 256)
INPUT_IMG_SIZE = (256, 256)
TEST_IMG_SIZE = (256, 256)

def get_training_transform():
    """Build the albumentations pipeline for training samples.

    Only ``albu.Normalize()`` (library defaults) is active.

    :return: An ``albu.Compose`` wrapping the active transforms.
    """
    # Brightness/contrast jitter and random 90-degree rotation are currently disabled.
    active_transforms = [albu.Normalize()]
    return albu.Compose(active_transforms)

def get_val_transform():
    """Build the albumentations pipeline for validation samples.

    :return: An ``albu.Compose`` containing a single ``albu.Normalize()`` with library defaults.
    """
    return albu.Compose([albu.Normalize()])


def val_aug(img, mask):
    """Apply the validation pipeline to one image/mask pair.

    :param img: Image convertible with ``np.array``.
    :param mask: Mask convertible with ``np.array``.
    :return: ``(image, mask)`` as returned by the pipeline from ``get_val_transform``.
    """
    img_array = np.array(img)
    mask_array = np.array(mask)
    # The pipeline receives copies, so the converted arrays are never modified in place.
    augmented = get_val_transform()(image=img_array.copy(), mask=mask_array.copy())
    return augmented['image'], augmented['mask']


class PotsdamDataset(Dataset):
    """Image / label-mask pairs read from a directory, with optional mosaic augmentation.

    ``data_root/img_dir`` holds the images and ``data_root/mask_dir`` the masks. A sample id
    is a mask file name without its extension; the matching image is ``id + img_suffix``.

    Each item is a dict with keys ``img_id`` (str), ``img`` (float tensor, channels first)
    and ``gt_semantic_seg`` (long tensor).

    :param data_root: Directory containing ``img_dir`` and ``mask_dir``.
    :param mode: ``'val'`` and ``'test'`` disable mosaic augmentation; any other value allows it.
    :param img_dir: Sub-directory holding the images.
    :param mask_dir: Sub-directory holding the masks.
    :param img_suffix: Image file extension, including the dot.
    :param mask_suffix: Mask file extension, including the dot.
    :param transform: Callable ``(img, mask) -> (img, mask)`` or ``None``.
    :param mosaic_ratio: Probability of replacing a sample with a 2x2 mosaic.
    :param img_size: ``(height, width)`` of the mosaic output.
    """

    def __init__(self, data_root='/kaggle/input/deep-data-potsdam-vaihingen/Deep_data_segmentation/Split/Potsdam/', mode='val', img_dir='Image', mask_dir='Label/',
                 img_suffix='.tif', mask_suffix='.png', transform=val_aug, mosaic_ratio=0.0,
                 img_size=ORIGIN_IMG_SIZE):
        self.data_root = data_root
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_suffix = img_suffix
        self.mask_suffix = mask_suffix
        self.transform = transform
        self.mode = mode
        self.mosaic_ratio = mosaic_ratio
        self.img_size = img_size
        self.img_ids = self.get_img_ids(self.data_root, self.img_dir, self.mask_dir)

    def __getitem__(self, index):
        """Return the sample at ``index`` as a dict of id, image tensor and mask tensor."""
        p_ratio = random.random()
        plain_sample = p_ratio > self.mosaic_ratio or self.mode in ('val', 'test')
        if plain_sample:
            img, mask = self.load_img_and_mask(index)
        else:
            img, mask = self.load_mosaic_img_and_mask(index)
        if self.transform:
            img, mask = self.transform(img, mask)

        # Without a transform the loaders hand back PIL images, which
        # torch.from_numpy cannot consume; convert them to arrays first.
        if not isinstance(img, np.ndarray):
            img = np.array(img)
        if not isinstance(mask, np.ndarray):
            mask = np.array(mask)

        img = torch.from_numpy(img).permute(2, 0, 1).float()
        mask = torch.from_numpy(mask).long()
        img_id = self.img_ids[index]
        return dict(img_id=img_id, img=img, gt_semantic_seg=mask)

    def __len__(self):
        return len(self.img_ids)

    def get_img_ids(self, data_root, img_dir, mask_dir):
        """List the sample ids (mask file names without extension) in sorted order.

        :raises AssertionError: If the image and mask directories hold different file counts.
        """
        img_filename_list = os.listdir(osp.join(data_root, img_dir))
        mask_filename_list = os.listdir(osp.join(data_root, mask_dir))
        assert len(img_filename_list) == len(mask_filename_list), (
            'image/mask count mismatch: {} images, {} masks'.format(
                len(img_filename_list), len(mask_filename_list)))
        # os.listdir returns names in arbitrary order. Sorting keeps index -> sample stable,
        # so seeded splits are reproducible. splitext removes only the final extension,
        # which keeps ids that contain dots intact.
        return sorted(osp.splitext(name)[0] for name in mask_filename_list)

    def load_img_and_mask(self, index):
        """Open the image (as RGB) and mask (as 8-bit greyscale) for ``index`` as PIL images."""
        img_id = self.img_ids[index]
        img_name = osp.join(self.data_root, self.img_dir, img_id + self.img_suffix)
        mask_name = osp.join(self.data_root, self.mask_dir, img_id + self.mask_suffix)
        img = Image.open(img_name).convert('RGB')
        mask = Image.open(mask_name).convert('L')
        return img, mask

    def load_mosaic_img_and_mask(self, index):
        """Build a 2x2 mosaic from ``index`` and three randomly chosen samples.

        A random splice centre is drawn inside the central half of the output; each of the
        four samples is randomly cropped to the size of its quadrant and the crops are
        stitched into one ``img_size`` image/mask pair, returned as PIL images.
        """
        indexes = [index] + [random.randint(0, len(self.img_ids) - 1) for _ in range(3)]
        tiles = []
        for sample_index in indexes:
            tile_img, tile_mask = self.load_img_and_mask(sample_index)
            tiles.append((np.array(tile_img), np.array(tile_mask)))

        w = self.img_size[1]
        h = self.img_size[0]

        start_x = w // 4
        start_y = h // 4
        # The coordinates of the splice center
        offset_x = random.randint(start_x, (w - start_x))
        offset_y = random.randint(start_y, (h - start_y))

        # (width, height) of the top-left, top-right, bottom-left and bottom-right quadrants.
        quadrant_sizes = [
            (offset_x, offset_y),
            (w - offset_x, offset_y),
            (offset_x, h - offset_y),
            (w - offset_x, h - offset_y),
        ]

        img_crops, mask_crops = [], []
        for (tile_img, tile_mask), (crop_w, crop_h) in zip(tiles, quadrant_sizes):
            random_crop = albu.RandomCrop(width=crop_w, height=crop_h)
            cropped = random_crop(image=tile_img.copy(), mask=tile_mask.copy())
            img_crops.append(cropped['image'])
            mask_crops.append(cropped['mask'])

        top = np.concatenate((img_crops[0], img_crops[1]), axis=1)
        bottom = np.concatenate((img_crops[2], img_crops[3]), axis=1)
        img = np.ascontiguousarray(np.concatenate((top, bottom), axis=0))

        top_mask = np.concatenate((mask_crops[0], mask_crops[1]), axis=1)
        bottom_mask = np.concatenate((mask_crops[2], mask_crops[3]), axis=1)
        mask = np.ascontiguousarray(np.concatenate((top_mask, bottom_mask), axis=0))

        return Image.fromarray(img), Image.fromarray(mask)

In [None]:

POTSDAM_ROOT = '/kaggle/input/deep-data-potsdam-vaihingen/Deep_data_segmentation/Split/Potsdam/'

train_dataset = PotsdamDataset(data_root=POTSDAM_ROOT, mode='train',
                               mosaic_ratio=0.25, transform=val_aug)

# 35% of the images are used for training; the other 65% form a hold-out pool.
train_size = 0.35
valid_size = 0.65
train_length = round(train_size * len(train_dataset))
# Derive the second length by subtraction: random_split requires the lengths to sum to
# len(dataset), which two independently rounded products do not guarantee.
valid_length = len(train_dataset) - train_length
train_set, holdout_set = random_split(train_dataset, [train_length, valid_length])

# 15% of the hold-out pool is used for validation; the remaining 85% is not used.
val_fraction = 0.15
unused_fraction = 0.85
val_split, unused_split = random_split(holdout_set, [val_fraction, unused_fraction])

# `train_dataset` runs in mode='train' with mosaic_ratio=0.25, so a subset of it would
# apply random mosaic augmentation to validation samples. Serve the same files through
# a mode='val' dataset instead.
eval_dataset = PotsdamDataset(data_root=POTSDAM_ROOT, mode='val', transform=val_aug)
assert eval_dataset.img_ids == train_dataset.img_ids, 'dataset listings differ between instances'
val_indices = [holdout_set.indices[i] for i in val_split.indices]
val_set = torch.utils.data.Subset(eval_dataset, val_indices)

# Legacy names, kept in case a later cell still refers to them.
dumb1, dumb2 = val_fraction, unused_fraction
dumb0, xxx = holdout_set, unused_split


test_dataset = PotsdamDataset(data_root='/kaggle/input/deep-data-potsdam-vaihingen/Deep_data_segmentation/Split/Potsdam/Test',
                              transform=val_aug)

train_loader = DataLoader(dataset=train_set,
                          batch_size=batch_size,
                          num_workers=0,
                          pin_memory=False,
                          shuffle=True,
                          drop_last=True)

val_loader = DataLoader(dataset=val_set,
                        batch_size=batch_size,
                        num_workers=0,
                        shuffle=False,
                        pin_memory=False,
                        drop_last=False)


# Metric

In [None]:
class Evaluator(object):
    """Accumulates a confusion matrix and derives segmentation metrics from it.

    Rows of ``confusion_matrix`` index the ground-truth class and columns the predicted
    class. Per-class metrics whose denominator is zero evaluate to ``NaN``.

    :param num_class: Number of classes; ground-truth values outside ``[0, num_class)``
        are ignored when a batch is added.
    """

    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((self.num_class,) * 2)
        self.eps = 1e-8

    @staticmethod
    def _ratio(numerator, denominator):
        """Element-wise division that yields NaN/inf silently instead of emitting warnings."""
        with np.errstate(divide='ignore', invalid='ignore'):
            return numerator / denominator

    def get_tp_fp_tn_fn(self):
        """Return per-class ``(tp, fp, tn, fn)`` counts as 1-D arrays."""
        tp = np.diag(self.confusion_matrix)
        fp = self.confusion_matrix.sum(axis=0) - tp
        fn = self.confusion_matrix.sum(axis=1) - tp
        # True negatives are all pixels that are neither labelled nor predicted as the class.
        tn = self.confusion_matrix.sum() - (tp + fp + fn)
        return tp, fp, tn, fn

    def Precision(self):
        """Per-class ``tp / (tp + fp)``."""
        tp, fp, tn, fn = self.get_tp_fp_tn_fn()
        return self._ratio(tp, tp + fp)

    def Recall(self):
        """Per-class ``tp / (tp + fn)``."""
        tp, fp, tn, fn = self.get_tp_fp_tn_fn()
        return self._ratio(tp, tp + fn)

    def F1(self):
        """Per-class harmonic mean of precision and recall."""
        tp, fp, tn, fn = self.get_tp_fp_tn_fn()
        Precision = self._ratio(tp, tp + fp)
        Recall = self._ratio(tp, tp + fn)
        return self._ratio(2.0 * Precision * Recall, Precision + Recall)

    def OA(self):
        """Overall accuracy: correctly classified pixels over all counted pixels."""
        OA = np.diag(self.confusion_matrix).sum() / (self.confusion_matrix.sum() + self.eps)
        return OA

    def Intersection_over_Union(self):
        """Per-class ``tp / (tp + fn + fp)``."""
        tp, fp, tn, fn = self.get_tp_fp_tn_fn()
        return self._ratio(tp, tp + fn + fp)

    def Dice(self):
        """Per-class ``2 * tp / ((tp + fp) + (tp + fn))``."""
        tp, fp, tn, fn = self.get_tp_fp_tn_fn()
        return self._ratio(2 * tp, (tp + fp) + (tp + fn))

    def Pixel_Accuracy_Class(self):
        """Per-class diagonal over column sum, i.e. ``tp / (tp + fp)`` stabilised with ``eps``."""
        Acc = np.diag(self.confusion_matrix) / (self.confusion_matrix.sum(axis=0) + self.eps)
        return Acc

    def Frequency_Weighted_Intersection_over_Union(self):
        """IoU averaged with ground-truth class frequencies as weights (absent classes skipped)."""
        freq = np.sum(self.confusion_matrix, axis=1) / (np.sum(self.confusion_matrix) + self.eps)
        iou = self.Intersection_over_Union()
        FWIoU = (freq[freq > 0] * iou[freq > 0]).sum()
        return FWIoU

    def _generate_matrix(self, gt_image, pre_image):
        """Confusion matrix of one ground-truth/prediction pair, ignoring out-of-range labels."""
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        # Encode each (gt, pred) pair as a single index into the flattened matrix.
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class ** 2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def add_batch(self, gt_image, pre_image):
        """Accumulate one pair of equally shaped ground-truth and prediction arrays."""
        assert gt_image.shape == pre_image.shape, 'pre_image shape {}, gt_image shape {}'.format(pre_image.shape,
                                                                                                 gt_image.shape)
        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)

    def reset(self):
        """Clear all accumulated counts."""
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

# Model

**Backbone**

In [None]:
class WideBlock(nn.Module):
    """Residual bottleneck: 1x1 conv -> 3x3 conv (carries the stride) -> 1x1 conv, plus a shortcut.

    The last convolution outputs ``c2 * expansion`` channels.

    :param c1: Input channels.
    :param c2: Channels of the two inner convolutions.
    :param stride: Stride of the 3x3 convolution.
    :param downsample: Optional module applied to the shortcut so it matches the main branch.
    """
    expansion: int = 4

    def __init__(self, c1, c2, stride = 1, downsample = None):
        super().__init__()
        widened = c2 * self.expansion

        self.conv1 = nn.Conv2d(c1, c2, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(c2)
        self.conv2 = nn.Conv2d(c2, c2, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(c2)
        self.conv3 = nn.Conv2d(c2, widened, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.downsample = downsample

    def forward(self, x):
        shortcut = x
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        out = self.bn3(self.conv3(hidden))
        if self.downsample is not None:
            shortcut = self.downsample(shortcut)
        # NOTE(review): no activation is applied after the residual sum, unlike the usual
        # ResNet bottleneck - confirm this is intended before changing it.
        out += shortcut
        return out

In [None]:
# Per-variant configuration: [blocks per stage, output channels per stage].
settings = {
    '50': [[3, 4, 6, 3], [256, 512, 1024, 2048]],
    '101': [[3, 4, 23, 3], [256, 512, 1024, 2048]],
}


class Resnet(nn.Module):
    """ResNet-style backbone built from ``WideBlock`` that returns the output of all four stages.

    :param setting: Key into ``settings`` (``'50'`` or ``'101'``) selecting the stage depths.
    """

    def __init__(self, setting:str = '50'):
        super().__init__()
        assert setting in settings.keys(), f"ResNet model name should be in {list(settings.keys())}"
        depths, channels = settings[setting]

        self.inplanes = 64
        self.channels = channels
        # Stem: stride-2 7x7 convolution followed by a stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(64, depths[0], s=1)
        self.layer2 = self._make_layer(128, depths[1], s=2)
        self.layer3 = self._make_layer(256, depths[2], s=2)
        self.layer4 = self._make_layer(512, depths[3], s=2)

    def _make_layer(self, planes, depth, s=1) -> nn.Sequential:
        """Stack ``depth`` blocks; only the first one carries the stride and the projection."""
        out_planes = planes * WideBlock.expansion
        needs_projection = s != 1 or self.inplanes != out_planes
        downsample = None
        if needs_projection:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, 1, s, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        blocks = [WideBlock(self.inplanes, planes, s, downsample)]
        for _ in range(1, depth):
            blocks.append(WideBlock(out_planes, planes))
        # The next stage consumes this stage's output width.
        self.inplanes = out_planes
        return nn.Sequential(*blocks)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        stem = self.maxpool(F.relu(self.bn1(self.conv1(x))))   # 64 channels, H/4 x W/4
        stage1 = self.layer1(stem)      # 256 channels, H/4 x W/4
        stage2 = self.layer2(stage1)    # 512 channels, H/8 x W/8
        stage3 = self.layer3(stage2)    # 1024 channels, H/16 x W/16
        stage4 = self.layer4(stage3)    # 2048 channels, H/32 x W/32
        return stage1, stage2, stage3, stage4


**Head**

In [None]:
class PPM(nn.Module):
    """Pyramid Pooling Module in PSPNet.

    The input is average-pooled to each size in ``scales``, projected to ``c2`` channels,
    resized back to the input resolution, concatenated with the input and fused by a
    3x3 ``ConvModule``.

    :param c1: Input channels.
    :param c2: Channels of every pooled branch and of the output.
    :param scales: Output sizes of the adaptive average pools.
    """
    def __init__(self, c1, c2=128, scales=(1, 2, 3, 6)):
        super().__init__()
        branches = []
        for scale in scales:
            branches.append(nn.Sequential(nn.AdaptiveAvgPool2d(scale), ConvModule(c1, c2, 1)))
        self.stages = nn.ModuleList(branches)

        self.bottleneck = ConvModule(c1 + c2 * len(scales), c2, 3, 1, 1)

    def forward(self, x: Tensor) -> Tensor:
        spatial_size = x.shape[-2:]
        pooled = [
            F.interpolate(stage(x), size=spatial_size, mode='bilinear', align_corners=True)
            for stage in self.stages
        ]
        # Concatenation order: the input first, then the branches from last scale to first.
        pooled.reverse()
        return self.bottleneck(torch.cat([x, *pooled], dim=1))

In [None]:
class ConvModule(nn.Sequential):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and an in-place ``ReLU``.

    :param c1: Input channels.
    :param c2: Output channels.
    :param k: Kernel size.
    :param s: Stride.
    :param p: Padding.
    :param d: Dilation.
    :param g: Groups.
    """
    def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1):
        conv = nn.Conv2d(c1, c2, kernel_size=k, stride=s, padding=p, dilation=d, groups=g, bias=False)
        norm = nn.BatchNorm2d(c2)
        activation = nn.ReLU(inplace=True)
        super().__init__(conv, norm, activation)


In [None]:
class UperHead(nn.Module):
    """Segmentation head: PPM on the deepest feature map plus a top-down FPN over the others.

    :param in_channels: Channel count of each input feature map, shallowest first.
    :param channel: Width used throughout the head.
    :param num_classes: Number of output channels of the final 1x1 convolution.
    :param scales: Pooling sizes handed to ``PPM``.
    """

    def __init__(self, in_channels, channel=128, num_classes: int = num_classes, scales=(1, 2, 3, 6)):
        super().__init__()
        # Pyramid pooling on the deepest feature map.
        self.ppm = PPM(in_channels[-1], channel, scales)

        # One lateral 1x1 and one output 3x3 ConvModule per remaining level.
        self.fpn_in = nn.ModuleList()
        self.fpn_out = nn.ModuleList()
        for lateral_channels in in_channels[:-1]:
            self.fpn_in.append(ConvModule(lateral_channels, channel, 1))
            self.fpn_out.append(ConvModule(channel, channel, 3, 1, 1))

        self.bottleneck = ConvModule(len(in_channels)*channel, channel, 3, 1, 1)
        self.dropout = nn.Dropout2d(0.1)
        self.conv_seg = nn.Conv2d(channel, num_classes, 1)

    def forward(self, features: Tuple[Tensor, Tensor, Tensor, Tensor]) -> Tensor:
        top_down = self.ppm(features[-1])
        pyramid = [top_down]

        # Walk from the second-deepest level to the shallowest, merging top-down context.
        for level in range(len(features) - 2, -1, -1):
            lateral = self.fpn_in[level](features[level])
            top_down = lateral + F.interpolate(top_down, size=lateral.shape[-2:], mode='bilinear', align_corners=False)
            pyramid.append(self.fpn_out[level](top_down))

        # Shallowest level first; bring every other level to its resolution.
        pyramid.reverse()
        finest_size = pyramid[0].shape[-2:]
        for level in range(1, len(features)):
            pyramid[level] = F.interpolate(pyramid[level], size=finest_size, mode='bilinear', align_corners=False)

        fused = self.bottleneck(torch.cat(pyramid, dim=1))
        return self.conv_seg(self.dropout(fused))

In [None]:
class UNet(nn.Module):
    """``Resnet`` backbone + ``UperHead``, refined by a 3x3 ``ConvModule`` that also sees the input.

    The head output is upsampled to the input resolution, concatenated with the
    3-channel input image and passed through a final ``ConvModule``.

    :param num_classes: Number of output channels. The default of 6 reproduces the
        previously hard-coded ``ConvModule(9, 6, ...)`` (6 class maps + 3 image channels).
    """

    def __init__(self, num_classes: int = 6):
        super(UNet, self).__init__()
        self.backbone = Resnet()
        self.head = UperHead(self.backbone.channels, num_classes=num_classes)
        # Input is the class maps concatenated with the 3 image channels.
        # NOTE(review): ConvModule ends in BatchNorm + ReLU, so the scores handed to the
        # loss are non-negative - confirm this is intended.
        self.conv1 = ConvModule(num_classes + 3, num_classes, 3, p=1)

    def forward(self, x):
        features = self.backbone(x)
        outs = self.head(features)
        outs = F.interpolate(outs, x.shape[-2:], mode='bilinear', align_corners=False)
        outs = torch.cat([outs, x], dim = 1)
        outs = self.conv1(outs)
        return outs

# Loss function

In [None]:
def label_smoothed_nll_loss(
    lprobs: torch.Tensor, target: torch.Tensor, epsilon: float, ignore_index=None, reduction="mean", dim=-1
) -> torch.Tensor:
    """Negative log-likelihood with label smoothing.

    Source: https://github.com/pytorch/fairseq/blob/master/fairseq/criterions/label_smoothed_cross_entropy.py

    :param lprobs: Log-probabilities of predictions (e.g. after log_softmax).
    :param target: Class indices; a missing class axis is inserted at ``dim``.
    :param epsilon: Smoothing factor; the result is
        ``(1 - epsilon) * nll + epsilon / num_classes * smooth``.
    :param ignore_index: Target value whose positions contribute zero loss. When given,
        the per-element terms keep the singleton class axis; when ``None`` it is squeezed.
    :param reduction: ``"sum"`` or ``"mean"`` reduce over all elements (ignored positions
        count as zeros in the mean); any other value leaves the terms unreduced.
    :param dim: Class axis of ``lprobs``.
    :return: The smoothed loss.
    """
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(dim)

    if ignore_index is None:
        nll_loss = (-lprobs.gather(dim=dim, index=target)).squeeze(dim)
        smooth_loss = (-lprobs.sum(dim=dim, keepdim=True)).squeeze(dim)
    else:
        ignored = target.eq(ignore_index)
        # Ignored positions get a valid placeholder index for gather and are zeroed afterwards.
        target = target.masked_fill(ignored, 0)
        nll_loss = (-lprobs.gather(dim=dim, index=target)).masked_fill(ignored, 0.0)
        smooth_loss = (-lprobs.sum(dim=dim, keepdim=True)).masked_fill(ignored, 0.0)

    if reduction == "sum":
        nll_loss, smooth_loss = nll_loss.sum(), smooth_loss.sum()
    elif reduction == "mean":
        nll_loss, smooth_loss = nll_loss.mean(), smooth_loss.mean()

    uniform_weight = epsilon / lprobs.size(dim)
    return (1.0 - epsilon) * nll_loss + uniform_weight * smooth_loss

In [None]:
__all__ = ["SoftCrossEntropyLoss"]
from typing import Optional
from torch.nn.modules.loss import _Loss
from typing import List
class SoftCrossEntropyLoss(nn.Module):
    """
    Drop-in replacement for nn.CrossEntropyLoss with few additions:
    - Support of label smoothing

    :param reduction: Passed through to ``label_smoothed_nll_loss``.
    :param smooth_factor: Label-smoothing epsilon.
    :param ignore_index: Target value excluded from the loss, or ``None`` to keep every position.
    :param dim: Class axis of the input logits.
    """

    __constants__ = ["reduction", "ignore_index", "smooth_factor"]

    def __init__(self, reduction: str = "mean", smooth_factor: float = 0.0, ignore_index: Optional[int] = -100, dim=1):
        super().__init__()
        self.smooth_factor = smooth_factor
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.dim = dim

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Compute the label-smoothed cross entropy of ``input`` logits against ``target`` indices."""
        log_prob = F.log_softmax(input, dim=self.dim)
        # Masking only applies when an ignore value is configured; comparing the target
        # with None would raise.
        if self.ignore_index is not None:
            pad_mask = target.eq(self.ignore_index)
            target = target.masked_fill(pad_mask, 0)
            # Insert the class axis where the logits have it so the mask broadcasts
            # correctly for any `dim`, not only dim=1.
            log_prob = log_prob.masked_fill(pad_mask.unsqueeze(self.dim), 0)
        return label_smoothed_nll_loss(
            log_prob,
            target,
            epsilon=self.smooth_factor,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            dim=self.dim,
        )
__all__ = ["JointLoss", "WeightedLoss"]


class WeightedLoss(_Loss):
    """Scales the value of a wrapped loss by a fixed factor.

    Useful for balancing several losses whose magnitudes differ.

    :param loss: Callable computing the unweighted loss.
    :param weight: Factor the loss value is multiplied by.
    """

    def __init__(self, loss, weight=1.0):
        super().__init__()
        self.weight = weight
        self.loss = loss

    def forward(self, *input):
        unweighted = self.loss(*input)
        return unweighted * self.weight
class JointLoss(_Loss):
    """Weighted sum of two losses evaluated on the same arguments.

    :param first: First loss module.
    :param second: Second loss module.
    :param first_weight: Factor applied to the first loss.
    :param second_weight: Factor applied to the second loss.
    """

    def __init__(self, first: nn.Module, second: nn.Module, first_weight=1.0, second_weight=1.0):
        super().__init__()
        self.first = WeightedLoss(first, weight=first_weight)
        self.second = WeightedLoss(second, weight=second_weight)

    def forward(self, *input):
        first_term = self.first(*input)
        second_term = self.second(*input)
        return first_term + second_term
__all__ = ["DiceLoss"]

# Target layouts understood by DiceLoss.
BINARY_MODE = "binary"
MULTICLASS_MODE = "multiclass"
MULTILABEL_MODE = "multilabel"


def soft_dice_score(
    output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None
) -> torch.Tensor:
    """Soft Dice coefficient ``(2 * sum(o * t) + smooth) / max(sum(o + t) + smooth, eps)``.

    :param output: Predicted scores.
    :param target: Reference tensor with exactly the same size as ``output``.
    :param smooth: Constant added to numerator and denominator.
    :param eps: Lower bound applied to the denominator.
    :param dims: Axes to sum over; ``None`` sums over every element.
    :return: A scalar when ``dims`` is ``None``, otherwise one score per remaining index.

    Shape:
        - Input: :math:`(N, NC, *)` where :math:`*` means any number
            of additional dimensions
        - Target: :math:`(N, NC, *)`, same shape as the input
    """
    assert output.size() == target.size()
    sum_kwargs = {} if dims is None else {"dim": dims}
    intersection = torch.sum(output * target, **sum_kwargs)
    cardinality = torch.sum(output + target, **sum_kwargs)
    denominator = (cardinality + smooth).clamp_min(eps)
    return (2.0 * intersection + smooth) / denominator
class DiceLoss(_Loss):
    """
    Implementation of Dice loss for image segmentation task.
    It supports binary, multiclass and multilabel cases
    """

    def __init__(
        self,
        mode: str = 'multiclass',
        classes: List[int] = None,
        log_loss=False,
        from_logits=True,
        smooth: float = 0.0,
        ignore_index=None,
        eps=1e-7,
    ):
        """

        :param mode: Metric mode {'binary', 'multiclass', 'multilabel'}
        :param classes: Optional list of classes that contribute in loss computation;
        By default, all channels are included.
        :param log_loss: If True, loss computed as `-log(dice)`; otherwise `1 - dice`
        :param from_logits: If True assumes input is raw logits
        :param smooth: Constant added to numerator and denominator of the Dice score
        :param ignore_index: Label that indicates ignored pixels (does not contribute to loss)
        :param eps: Small epsilon for numerical stability
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super(DiceLoss, self).__init__()
        self.mode = mode
        if classes is not None:
            assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary"
            # Stored as a long tensor so it can index the per-class loss vector in forward().
            classes = torch.as_tensor(classes, dtype=torch.long)

        self.classes = classes
        self.from_logits = from_logits
        self.smooth = smooth
        self.eps = eps
        self.ignore_index = ignore_index
        self.log_loss = log_loss

    def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
        """

        :param y_pred: NxCxHxW
        :param y_true: NxHxW
        :return: scalar
        """
        assert y_true.size(0) == y_pred.size(0)

        if self.from_logits:
            # Apply activations to get [0..1] class probabilities
            # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on
            # extreme values 0 and 1
            if self.mode == MULTICLASS_MODE:
                y_pred = y_pred.log_softmax(dim=1).exp()
            else:
                y_pred = F.logsigmoid(y_pred).exp()

        bs = y_true.size(0)
        num_classes = y_pred.size(1)
        # Sum over batch and flattened spatial axes, leaving one score per class.
        dims = (0, 2)

        if self.mode == BINARY_MODE:
            y_true = y_true.view(bs, 1, -1)
            y_pred = y_pred.view(bs, 1, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        if self.mode == MULTICLASS_MODE:
            y_true = y_true.view(bs, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask.unsqueeze(1)

                # Ignored pixels are mapped to class 0 for one_hot and zeroed again right after.
                y_true = F.one_hot((y_true * mask).to(torch.long), num_classes)  # N,H*W -> N,H*W, C
                y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1)  # N, C, H*W
            else:
                y_true = F.one_hot(y_true, num_classes)  # N,H*W -> N,H*W, C
                y_true = y_true.permute(0, 2, 1)  # N, C, H*W

        if self.mode == MULTILABEL_MODE:
            y_true = y_true.view(bs, num_classes, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        scores = soft_dice_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims)

        if self.log_loss:
            loss = -torch.log(scores.clamp_min(self.eps))
        else:
            loss = 1.0 - scores

        # Dice loss is undefined for empty classes
        # So we zero contribution of channel that does not have true pixels
        # NOTE: A better workaround would be to use loss term `mean(y_pred)`
        # for this case, however it will be a modified jaccard loss

        mask = y_true.sum(dims) > 0
        loss *= mask.to(loss.dtype)

        if self.classes is not None:
            loss = loss[self.classes]

        return loss.mean()


In [None]:

class UnetLoss(nn.Module):
    """Training loss: smoothed cross entropy + Dice, with an optional auxiliary term.

    :param ignore_index: Label value excluded from every loss term.
    """

    def __init__(self, ignore_index=255):
        super().__init__()
        self.main_loss = JointLoss(SoftCrossEntropyLoss(smooth_factor=0.05, ignore_index=ignore_index),
                                   DiceLoss(smooth=0.05, ignore_index=ignore_index), 1.0, 1.0)
        self.aux_loss = SoftCrossEntropyLoss(smooth_factor=0.05, ignore_index=ignore_index)

    def forward(self, logits, labels):
        """Compute the loss.

        :param logits: Either a single logits tensor, or a ``(main, aux)`` tuple/list.
        :param labels: Ground-truth labels handed to the wrapped losses.
        :return: ``main_loss(main)`` plus, in training mode with a ``(main, aux)`` pair,
            ``0.4 * aux_loss(aux)``.
        """
        # Test the container type as well as the length: a plain tensor whose batch size
        # is 2 also has len() == 2 and must not be unpacked into (main, aux).
        has_aux_head = isinstance(logits, (tuple, list)) and len(logits) == 2
        if not has_aux_head:
            return self.main_loss(logits, labels)

        logit_main, logit_aux = logits
        if self.training:
            return self.main_loss(logit_main, labels) + 0.4 * self.aux_loss(logit_aux, labels)
        # Outside training only the main output is scored.
        return self.main_loss(logit_main, labels)

# Training

**Initialize weights**

In [None]:
def weights_init(model):
    """Re-initialise the weight of an ``nn.Linear`` with Xavier-uniform values.

    Any other module type is left untouched.

    :param model: The module to initialise.
    """
    if not isinstance(model, nn.Linear):
        return
    # Xavier (Glorot) uniform distribution.
    nn.init.xavier_uniform_(model.weight)

In [None]:
def save_model(model, optimizer, path):
    """Write the model and optimizer state dicts to ``path`` as a single checkpoint.

    :param model: Module whose ``state_dict()`` is stored under the key ``"model"``.
    :param optimizer: Optimizer whose ``state_dict()`` is stored under ``"optimizer"``.
    :param path: Destination file.
    """
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        path,
    )

def load_model(model, optimizer, path, map_location="cpu"):
    """Restore model and optimizer state from a checkpoint written by ``save_model``.

    :param model: Module to load the ``"model"`` state dict into (modified in place).
    :param optimizer: Optimizer to load the ``"optimizer"`` state dict into (modified in place).
    :param path: Checkpoint file.
    :param map_location: Forwarded to ``torch.load``. The default ``"cpu"`` lets a checkpoint
        saved on a GPU be opened on a CPU-only machine.
    :return: The same ``(model, optimizer)`` objects that were passed in.
    """
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer

**Train model**

In [None]:
def _summarize_metrics(evaluator):
    """Return ``(per-class IoU, mIoU, mean F1, OA)``; both means leave out the last class index."""
    iou_per_class = evaluator.Intersection_over_Union()
    mIoU = np.nanmean(iou_per_class[:-1])
    F1 = np.nanmean(evaluator.F1()[:-1])
    OA = np.nanmean(evaluator.OA())
    return iou_per_class, mIoU, F1, OA


def train(train_dataloader, valid_dataloader, learing_rate_scheduler, epoch, display_step):
    """Run one training epoch followed by one validation pass.

    Uses the module-level ``model``, ``optimizer``, ``loss_function``, ``device`` and
    ``class_label``. The scheduler is only read for logging; stepping it is left to the caller.

    :param train_dataloader: Yields dicts with ``'img'`` and ``'gt_semantic_seg'`` tensors.
    :param valid_dataloader: Same format, evaluated without gradients.
    :param learing_rate_scheduler: Scheduler whose current learning rate is printed.
    :param epoch: Zero-based epoch index (printed as ``epoch + 1``).
    :param display_step: Unused; kept for call compatibility.
    :return: ``(train_loss, train_metrics, val_loss, val_metrics)`` where each loss is the
        mean over batches and each metrics value is ``(per-class IoU, mIoU, F1, OA)``.
    """
    print(f"Start epoch #{epoch+1}, learning rate for this epoch: {learing_rate_scheduler.get_last_lr()}")
    start_time = time.time()
    train_loss_epoch = 0
    test_loss_epoch = 0
    model.train()
    metrics_train = Evaluator(num_class=6)
    metrics_val = Evaluator(num_class=6)

    num_train_batches = 0
    for batch in tqdm(train_dataloader):
        # Load data into GPU
        data = batch['img'].to(device)
        mask = batch['gt_semantic_seg'].to(device)
        optimizer.zero_grad()
        prediction = model(data)
        loss = loss_function(prediction, mask.long())
        # Softmax is monotonic, so the arg-max over raw scores selects the same class.
        pre_mask = prediction.argmax(dim=1)
        # The evaluator flattens its inputs, so the whole batch can be added in one call.
        metrics_train.add_batch(mask.cpu().numpy(), pre_mask.cpu().numpy())

        # Backpropagation, compute gradients
        loss.backward()

        # Apply gradients
        optimizer.step()

        # Save loss
        train_loss_epoch += loss.item()
        num_train_batches += 1
    print(f"Done epoch #{epoch+1}, time for this epoch: {time.time()-start_time}s")
    # Average over the number of batches; a dedicated counter avoids relying on a loop
    # variable that inner loops could overwrite, and also covers an empty loader.
    train_loss_epoch /= max(num_train_batches, 1)
    train_eval_value = _summarize_metrics(metrics_train)
    metrics_train.reset()

    # Evaluate the validation set
    model.eval()
    num_val_batches = 0
    with torch.no_grad():
        for k, batch in enumerate(tqdm(valid_dataloader)):
            data = batch['img'].to(device)
            mask = batch['gt_semantic_seg'].to(device)
            prediction = model(data)
            test_loss = loss_function(prediction, mask.long())
            pre_mask = prediction.argmax(dim=1)
            test_loss_epoch += test_loss.item()
            num_val_batches += 1

            if k <= 3:
                # Log the first sample of the first four batches to wandb with both masks.
                predictions_2d = pre_mask[0].cpu().numpy().astype(np.int8)
                masks_true_2d = mask[0].cpu().numpy().astype(np.int8)
                wandb.log({
                    f"val_image{k}": wandb.Image(data[0], masks={
                        "predictions": {
                            "mask_data": predictions_2d,
                            "class_labels": class_label,
                        },
                        "ground_truth": {
                            "mask_data": masks_true_2d,
                            "class_labels": class_label,
                        },
                    })
                })

            metrics_val.add_batch(mask.cpu().numpy(), pre_mask.cpu().numpy())
    test_loss_epoch /= max(num_val_batches, 1)
    eval_value = _summarize_metrics(metrics_val)
    print(eval_value)
    metrics_val.reset()
    return train_loss_epoch, train_eval_value, test_loss_epoch, eval_value


**Set up the model and run training**

In [None]:
device

In [None]:
# Build the network and move its parameters to the selected device.
model = UNet()
model = model.to(device)
model

In [None]:
# Report whether PyTorch can see a CUDA device. `torch` comes from the notebook's
# import cell, so it is not re-imported here.
if torch.cuda.is_available():
    print("CUDA is available.")
else:
    print("CUDA is not available.")

In [None]:

# Training loss; the value 6 is passed to UnetLoss as its ignore_index.
loss_function = UnetLoss(ignore_index=6)

# Adam over all model parameters, starting from the notebook-level learning_rate.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Step decay: the learning rate is multiplied by 0.6 after every 4 scheduler steps.
learing_rate_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=4, gamma=0.6)

In [None]:
# Write a checkpoint of the current model/optimizer to checkpoint_path before training starts.
# NOTE(review): the weights look untrained at this point — confirm this early save is intended.
save_model(model, optimizer, checkpoint_path)
# When True, the checkpoint-loading cell restores model/optimizer from pretrained_path.
load_checkpoint_flag=False

In [None]:
# print("Model keys:")
# print(model.state_dict().keys())

# print("\nCheckpoint keys:")
# print(checkpoint['model'].keys())


In [None]:
# Optionally resume from the pretrained checkpoint.
if load_checkpoint_flag:
    # Rebind `optimizer` itself: the result used to be assigned to a misspelled
    # `optimize`, so whatever optimizer load_model returned was silently dropped.
    # NOTE(review): if load_model returns a *new* optimizer object instead of
    # updating the given one in place, learing_rate_scheduler still wraps the old
    # one and must be rebuilt — confirm against load_model.
    model, optimizer = load_model(model, optimizer, pretrained_path)
    


In [None]:
# Authenticate with Weights & Biases. The API key is read from the environment
# rather than being committed in the notebook; when WANDB_API_KEY is unset,
# key=None leaves credential lookup to wandb itself.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

# Generate a run id (named run_id so the builtin `id` is not shadowed) and start the run.
run_id = wandb.util.generate_id()
print(run_id)
wandb.init(id=run_id, project="UperNet", resume="allow")

# Training loop
train_loss_array = []
test_loss_array = []
last_loss = float("inf")  # lowest validation loss seen so far
for epoch in range(epochs):
    train_loss_epoch, train_eval_value, test_loss_epoch, eval_value = train(
        train_loader, val_loader, learing_rate_scheduler, epoch, display_step)

    # Keep the checkpoint with the best (lowest) validation loss.
    if test_loss_epoch < last_loss:
        save_model(model, optimizer, checkpoint_path)
        last_loss = test_loss_epoch

    # Gather every metric of this epoch into one dict and log it with a single
    # wandb.log call: each call advances the run's step by default, so logging
    # the values one by one scattered a single epoch over many steps.
    epoch_log = {"Train loss": train_loss_epoch, "Valid loss": test_loss_epoch}
    for split, iou_tag, split_eval, accuracy_history in (
            ("train", "IOU", train_eval_value, train_accuracy),
            ("val", "IoU", eval_value, valid_accuracy)):
        # Each evaluation tuple is (per-class IoU, mIoU, F1, OA).
        iou_per_class, mIoU, F1, OA = split_eval
        print(f"{split}:", {'mIoU': mIoU, 'F1': F1, 'OA': OA})
        accuracy_history.append(OA)
        epoch_log.update({f"mIoU_{split}": mIoU,
                          f"F1_{split}": F1,
                          f"OA_{split}": OA})

        # Per-class IoU; the metric keys keep their historical spelling
        # ("<class>_train_IOU" / "<class>_val_IoU") so existing charts still match.
        iou_value = dict(zip(CLASSES, iou_per_class))
        epoch_log.update({f"{class_name}_{split}_{iou_tag}": iou
                          for class_name, iou in iou_value.items()})
        print(iou_value)

    wandb.log(epoch_log)

    learing_rate_scheduler.step()
    train_loss_array.append(train_loss_epoch)
    test_loss_array.append(test_loss_epoch)
    print("Epoch {}: loss: {:.4f}, train accuracy: {:.4f}, valid accuracy:{:.4f}".format(
        epoch + 1, train_loss_array[-1], train_accuracy[-1], valid_accuracy[-1]))
    

In [None]:
# Ask PyTorch to release its unused cached CUDA memory before the evaluation cells run.
torch.cuda.empty_cache()

In [None]:
def img_writer(inp):
    """Write one mask to ``<mask_id>.png`` with OpenCV.

    ``inp`` is a ``(mask, mask_id, rgb)`` tuple. When ``rgb`` is truthy the mask
    is passed through ``label2rgb`` before being written; otherwise it is cast
    to ``uint8`` and written as-is.
    """
    mask, mask_id, rgb = inp
    out_path = mask_id + '.png'
    image = label2rgb(mask) if rgb else mask.astype(np.uint8)
    cv2.imwrite(out_path, image)


def get_args():
    """Parse the command-line options and return the resulting namespace.

    Options: ``-c/--config_path`` and ``-o/--output_path`` (both required,
    converted to ``Path``), ``-t/--tta`` (``None``, ``"d4"`` or ``"lr"``;
    defaults to ``None``) and the boolean ``--rgb`` flag.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("-c", "--config_path", type=Path, required=True,
                     help="Path to  config")
    cli.add_argument("-o", "--output_path", type=Path, required=True,
                     help="Path where to save resulting masks.")
    cli.add_argument("-t", "--tta", default=None, choices=[None, "d4", "lr"],
                     help="Test time augmentation.")
    cli.add_argument("--rgb", action='store_true',
                     help="whether output rgb images")
    return cli.parse_args()

In [None]:
# Dataset over the Test split folder, using the `val_aug` transform (defined outside this section).
# NOTE(review): the class is named PotsdamDataset but data_root points at the Vaihingen split —
# presumably both datasets share the same folder layout; confirm.
test_dataset = PotsdamDataset(data_root='/kaggle/input/deep-data-potsdam-vaihingen/Deep_data_segmentation/Split/Vaihingen/Test',
                              transform=val_aug)

In [None]:
# Evaluate the model on the test split, print the metrics and save the predicted masks.
# The notebook-level `device` is used instead of a hard-coded .cuda() so the cell
# also runs on CPU-only sessions.
model.to(device)
model.eval()
evaluator = Evaluator(num_class=num_classes)
evaluator.reset()

test_loader = DataLoader(
    test_dataset,
    batch_size=2,
    num_workers=4,
    pin_memory=True,
    drop_last=False,
)

# Predicted masks are written into a folder next to the training checkpoint.
output_dir = Path(checkpoint_path).parent / "test_predictions"
output_dir.mkdir(parents=True, exist_ok=True)
write_rgb = False  # True -> img_writer converts each mask with label2rgb before saving

# Jobs for img_writer: (mask, output path without extension, rgb flag).
# This list used to stay empty, so the writer pool below never saved anything.
results = []
with torch.no_grad():
    for batch in tqdm(test_loader):
        # raw_predictions: N x C x H x W class scores
        raw_predictions = model(batch['img'].to(device))
        image_ids = batch["img_id"]
        masks_true = batch['gt_semantic_seg'].cpu().numpy()

        # Softmax is monotonic, so the argmax over the raw scores selects the
        # same class as the argmax over the softmax probabilities.
        predictions = raw_predictions.argmax(dim=1).cpu().numpy()

        for i in range(predictions.shape[0]):
            evaluator.add_batch(pre_image=predictions[i], gt_image=masks_true[i])
            # Store the mask as uint8 to keep the queued jobs small.
            results.append((predictions[i].astype(np.uint8),
                            str(output_dir / str(image_ids[i])),
                            write_rgb))

iou_per_class = evaluator.Intersection_over_Union()
f1_per_class = evaluator.F1()
OA = evaluator.OA()
for class_name, class_iou, class_f1 in zip(CLASSES, iou_per_class, f1_per_class):
    print('F1_{}:{}, IOU_{}:{}'.format(class_name, class_f1, class_name, class_iou))
# The last class is left out of the averaged F1 / mIoU.
print('F1:{}, mIOU:{}, OA:{}'.format(np.nanmean(f1_per_class[:-1]), np.nanmean(iou_per_class[:-1]), OA))

# Write the masks with two worker processes; the context manager shuts the pool
# down once map() has returned instead of leaking the workers.
t0 = time.time()
with mpp.Pool(processes=2) as pool:
    pool.map(img_writer, results)
t1 = time.time()
img_write_time = t1 - t0
print('images writing spends: {} s'.format(img_write_time))