In [None]:
# Install the third-party packages imported in the next cell
# (Hugging Face `transformers` and `segmentation_models_pytorch`).
!pip install transformers
!pip install segmentation_models_pytorch

In [None]:
import json, os, torch, cv2, numpy as np, albumentations as A
from PIL import Image
import matplotlib.pyplot as plt
from glob import glob
from torch.utils.data import random_split, Dataset, DataLoader
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import time
from tqdm import tqdm
import torch.nn.functional as F
from segmentation_models_pytorch.losses import JaccardLoss, DiceLoss
import torch.nn as nn
from transformers import (
    SegformerForSemanticSegmentation,
    Mask2FormerForUniversalSegmentation,
    Mask2FormerImageProcessor
)
from functools import partial

In [None]:
class CustomSegmentationDataset(Dataset):
    """Paired image / mask dataset read from disk.

    Images are collected from ``{root}/*/images/*.jpg`` and masks from
    ``{root}/*/masks/*.png``. Both path lists are sorted and paired by
    position, so every image needs exactly one mask sorting to the same index.

    Args:
        root: Directory whose sub-folders each contain an ``images`` and a
            ``masks`` folder.
        transformations: Optional callable invoked as
            ``transformations(image=im, mask=gt)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries.

    Raises:
        ValueError: If the number of images and masks found differs.
    """

    def __init__(self, root, transformations=None):
        self.im_paths = sorted(glob(f'{root}/*/images/*.jpg'))
        self.gt_paths = sorted(glob(f'{root}/*/masks/*.png'))
        # Pairing is purely positional, so a count mismatch would silently
        # attach masks to the wrong images.
        if len(self.im_paths) != len(self.gt_paths):
            raise ValueError(
                f'Found {len(self.im_paths)} images but {len(self.gt_paths)} masks under {root!r}'
            )
        self.transformations = transformations
        # Number of classes reported to callers (see get_dls).
        self.n_cls = 5

    def __len__(self):
        """Return the number of image / mask pairs."""
        return len(self.im_paths)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at ``idx``, transformed when transformations are set."""
        im, gt = self.get_im_gt(self.im_paths[idx], self.gt_paths[idx])
        if self.transformations:
            # apply_transformations also returns the untouched originals;
            # only the transformed pair forms the sample.
            im, gt = self.apply_transformations(im, gt)[:2]
        return im, gt

    def get_im_gt(self, im_path, gt_path):
        """Read the image and its mask from disk, both as RGB arrays."""
        return self.read_im(im_path), self.read_im(gt_path)

    def read_im(self, path):
        """Load ``path`` with OpenCV as a 3-channel image and convert BGR -> RGB.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        bgr = cv2.imread(path, cv2.IMREAD_COLOR)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {path}')
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def apply_transformations(self, im, gt):
        """Apply the transformations and return ``(image, mask, original image, original mask)``."""
        transformed = self.transformations(image=im, mask=gt)
        return transformed['image'], transformed['mask'], im, gt

def collate_fn(batch, image_processor):
    """Collate dataset samples into one batch via an image processor.

    Args:
        batch: Sequence of samples. Each sample starts with
            ``(image, segmentation_map)`` and may additionally carry
            ``(original_image, original_mask)`` as third and fourth items.
        image_processor: Callable invoked as
            ``image_processor(images, segmentation_maps=..., return_tensors='pt', ...)``
            with resizing, rescaling and normalisation switched off; it must
            return a mapping that supports item assignment.

    Returns:
        The processor's mapping. When the samples carry originals they are
        attached under ``'orig_image'`` and ``'orig_mask'`` as tuples.
    """
    columns = list(zip(*batch))
    images, segmentation_maps = columns[0], columns[1]

    encoded = image_processor(
        list(images),
        segmentation_maps=list(segmentation_maps),
        return_tensors='pt',
        do_resize=False,
        do_rescale=False,
        do_normalize=False
    )

    # The originals are optional: samples of plain (image, mask) pairs
    # must not fail with an IndexError here.
    if len(columns) >= 4:
        encoded['orig_image'] = columns[2]
        encoded['orig_mask'] = columns[3]
    return encoded

def get_dls(root, transformations, bs, processor=None, split=(0.9, 0.05, 0.05), seed=None):
    """Build train / validation / test dataloaders over ``CustomSegmentationDataset``.

    Args:
        root: Dataset root passed to ``CustomSegmentationDataset``.
        transformations: Transformations passed to ``CustomSegmentationDataset``.
        bs: Batch size of the train and validation loaders (the test loader
            always uses a batch size of 1).
        processor: Optional image processor. When given, batches are collated
            with ``collate_fn`` bound to this processor; otherwise the
            DataLoader default collation is used.
        split: Three fractions ``(train, validation, test)`` summing to 1.
        seed: Optional integer. When given, the random split is drawn from a
            generator seeded with it, so repeated calls produce the same
            partition. ``None`` keeps the unseeded behaviour.

    Returns:
        ``(tr_dl, val_dl, test_dl, n_cls)``.
    """
    # Tolerance instead of ``==``: fractions such as (0.7, 0.2, 0.1) do not
    # add up to exactly 1.0 in floating point.
    assert abs(sum(split) - 1.0) < 1e-6, 'Sum of the split must be exactly 1'
    ds = CustomSegmentationDataset(root=root, transformations=transformations)
    n_cls = ds.n_cls
    tr_len = int(len(ds) * split[0])
    val_len = int(len(ds) * split[1])
    # The test subset absorbs whatever the two truncations above left over.
    test_len = len(ds) - (tr_len + val_len)

    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    tr_ds, val_ds, test_ds = torch.utils.data.random_split(
        ds,
        [tr_len, val_len, test_len],
        **split_kwargs
    )
    print(len(tr_ds))
    print(len(val_ds))
    print(len(test_ds))

    collate_func = None
    if processor:
        collate_func = partial(collate_fn, image_processor=processor)

    def make_dl(subset, batch_size, shuffle):
        # All three loaders share the same collation.
        return DataLoader(
            dataset=subset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_func
        )

    tr_dl = make_dl(tr_ds, bs, True)
    val_dl = make_dl(val_ds, bs, False)
    test_dl = make_dl(test_ds, 1, False)
    return tr_dl, val_dl, test_dl, n_cls
    
# Dataset location (Kaggle input mount).
root = '/kaggle/input/semantic-segmentation-of-aerial-imagery/Semantic segmentation dataset'
# Normalisation statistics and target size; tn_2_np undoes exactly this
# mean / std when converting tensors back to displayable images.
mean, std, im_h, im_w = [0.485, 0.456, 0.406], [0.229,0.224,0.225], 256, 256
# train(), Metrics and tn_2_np all consume normalised, channel-first tensors
# (they read ``ims.shape[-2:]``, take ``argmax(gt, dim=1)`` and ``permute(1,2,0)``),
# so normalisation and tensor conversion have to be part of the pipeline.
trans = A.Compose([
    A.Resize(im_h, im_w),
    A.Normalize(mean=mean, std=std),
    ToTensorV2(transpose_mask=True)
])
tr_dl, val_dl, test_dl, n_cls = get_dls(
    root=root,
    transformations=trans,
    bs=4
)
    

In [None]:
class Metrics():
    """Pixel accuracy, mean IoU and loss for one batch of segmentation logits.

    Args:
        pred: Logits with the class axis at ``dim=1``.
        gt: Ground truth. When it has as many dimensions as ``pred`` it is
            reduced with ``argmax`` over ``dim=1``; otherwise it is taken to
            already hold class indices.
        loss_fn: Callable invoked as ``loss_fn(pred, gt_class_indices)``.
        eps: Smoothing term added to the IoU numerator and denominator.
        n_cls: Number of classes iterated over by ``mIoU``.

    The class-index view of ``gt`` is computed once in the constructor, so
    ``PA``, ``mIoU`` and ``loss`` may be called in any order and any number of
    times.
    """

    def __init__(self, pred, gt, loss_fn, eps=1e-10, n_cls=2):
        self.pred_ = pred
        # argmax over logits; a softmax first would not change the winner.
        self.pred = torch.argmax(pred, dim=1)
        self.gt = gt
        if gt.dim() == pred.dim():
            self.gt_idx = torch.argmax(gt, dim=1)
        else:
            self.gt_idx = gt.long()
        self.loss_fn, self.eps, self.n_cls = loss_fn, eps, n_cls

    def to_contiguous(self, inp):
        """Flatten ``inp`` to one dimension."""
        return inp.contiguous().view(-1)

    def PA(self):
        """Return the fraction of pixels whose predicted class equals the ground truth."""
        with torch.no_grad():
            match = torch.eq(self.pred, self.gt_idx).int()
        return float(match.sum()) / float(match.numel())

    def mIoU(self):
        """Return the IoU averaged over the classes present in the ground truth."""
        with torch.no_grad():
            pred, gt = self.to_contiguous(self.pred), self.to_contiguous(self.gt_idx)
            iou_per_class = []
            for c in range(self.n_cls):
                match_pred = pred == c
                match_gt = gt == c
                if match_gt.long().sum().item() == 0:
                    # Class absent from the ground truth: excluded by nanmean below.
                    iou_per_class.append(np.nan)
                else:
                    intersect = torch.logical_and(match_pred, match_gt).sum().float().item()
                    union = torch.logical_or(match_pred, match_gt).sum().float().item()
                    iou_per_class.append((intersect + self.eps) / (union + self.eps))
            return np.nanmean(iou_per_class)

    def loss(self):
        """Return ``loss_fn`` evaluated on the raw logits and the class-index ground truth."""
        return self.loss_fn(self.pred_, self.gt_idx)
    
def tic_toc(start_time=None):
    """Return the current ``time.time()``, or the seconds elapsed since ``start_time`` when one is given."""
    now = time.time()
    # Identity check: ``== None`` would defer to the argument's own ``__eq__``.
    return now if start_time is None else now - start_time

In [None]:
def _forward_logits(model_name, model, ims):
    """Run ``model`` on ``ims`` and return its class logits as a tensor."""
    if model_name == "segformer":
        logits = model(ims).logits
        # Resize the logits to the spatial size of the input batch so they
        # line up with the ground-truth masks.
        return nn.functional.interpolate(
            logits,
            size=ims.shape[-2:],
            mode="bilinear",
            align_corners=False
        )
    return model(ims)


def _run_epoch(model_name, model, dl, loss_fn, device, opt=None):
    """Make one pass over ``dl`` and return the mean ``(loss, mIoU, PA)`` per batch.

    When ``opt`` is given an optimisation step is taken on every batch;
    otherwise the pass only evaluates.
    """
    loss_sum, iou_sum, pa_sum = 0, 0, 0
    for batch in tqdm(dl):
        ims, gts = batch
        ims, gts = ims.to(device), gts.to(device)
        preds = _forward_logits(model_name, model, ims)
        # The class count is read off the logits instead of a notebook global.
        met = Metrics(preds, gts, loss_fn, n_cls=preds.shape[1])
        loss_ = met.loss()
        iou_sum += met.mIoU()
        pa_sum += met.PA()
        loss_sum += loss_.item()
        if opt is not None:
            # Clear gradients left over from any earlier step before accumulating new ones.
            opt.zero_grad()
            loss_.backward()
            opt.step()
    n_batches = len(dl)
    return loss_sum / n_batches, iou_sum / n_batches, pa_sum / n_batches


def train(
    model_name,
    model,
    tr_dl,
    val_dl,
    loss_fn,
    opt,
    device,
    epochs,
    save_prefix,
    save_path='saved_models',
    is_huggingface_model=False
):
    """Train ``model`` with per-epoch validation, best-model checkpointing and early stopping.

    Args:
        model_name: Selects the forward path. ``"segformer"`` reads ``.logits``
            from the model output and resizes them to the input size; any
            other supported name uses the model output directly.
        model: Model to train; it is moved to ``device``.
        tr_dl: Training loader yielding ``(images, masks)`` batches.
        val_dl: Validation loader yielding ``(images, masks)`` batches.
        loss_fn: Callable invoked through ``Metrics.loss``.
        opt: Optimizer stepped once per training batch.
        device: Device the model and batches are moved to.
        epochs: Maximum number of epochs.
        save_prefix: File-name prefix of the checkpoint.
        save_path: Directory the checkpoint is written to (created if missing).
        is_huggingface_model: Accepted for compatibility; currently unused.

    Returns:
        Dict of per-epoch lists under ``"tr_loss"``, ``"tr_iou"``, ``"tr_pa"``,
        ``"val_loss"``, ``"val_iou"`` and ``"val_pa"``.

    Raises:
        NotImplementedError: For ``model_name == "mask2former"``, which has no
            forward / loss path here.
        ValueError: If either loader has no batches.
    """
    if model_name == "mask2former":
        raise NotImplementedError("train() has no forward/loss path for 'mask2former' yet")
    if len(tr_dl) == 0 or len(val_dl) == 0:
        raise ValueError("tr_dl and val_dl must each contain at least one batch")

    tr_loss, tr_pa, tr_iou = [],[],[]
    val_loss, val_pa, val_iou = [],[],[]
    # early_stop_threshold = number of consecutive epochs without improvement tolerated.
    best_loss, not_improve, early_stop_threshold = np.inf, 0, 7
    os.makedirs(save_path, exist_ok=True)
    model.to(device)
    train_start = tic_toc()
    print('Start training process...')
    for epoch in range(1, epochs+1):
        tic = tic_toc()

        model.train()
        print(f'Epoch {epoch} train process is started...')
        tr_loss_, tr_iou_, tr_pa_ = _run_epoch(model_name, model, tr_dl, loss_fn, device, opt=opt)

        print(f"Epoch {epoch} validation process is started...")
        model.eval()
        with torch.no_grad():
            val_loss_, val_iou_, val_pa_ = _run_epoch(model_name, model, val_dl, loss_fn, device)

        print(f'Epoch {epoch} train process is completed')

        print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        print(f"\nEpoch {epoch} train process results: \n")
        print(f"Train Time         -> {tic_toc(tic):.3f} secs")
        print(f"Train Loss         -> {tr_loss_:.3f}")
        print(f"Train PA           -> {tr_pa_:.3f}")
        print(f"Train IoU          -> {tr_iou_:.3f}")
        print(f"Validation Loss    -> {val_loss_:.3f}")
        print(f"Validation PA      -> {val_pa_:.3f}")
        print(f"Validation IoU     -> {val_iou_:.3f}\n")

        tr_loss.append(tr_loss_)
        tr_iou.append(tr_iou_)
        tr_pa.append(tr_pa_)

        val_loss.append(val_loss_)
        val_iou.append(val_iou_)
        val_pa.append(val_pa_)

        if val_loss_ < best_loss:
            print(f"Loss decreased from {best_loss:.3f} to {val_loss_:.3f}!")
            best_loss = val_loss_
            not_improve = 0
            # Checkpoint on every improvement so the file on disk is always
            # the best model seen so far.
            print("Saving the model with the best loss value...")
            torch.save(model, f"{save_path}/{save_prefix}_best_model.pt")
        else:
            # best_loss is left untouched: a worse epoch must not lower the bar.
            not_improve += 1
            print(f"Loss did not decrease for {not_improve} epoch(s)!")
            if not_improve == early_stop_threshold:
                print(f"Stopping training process becuase loss value did not decrease for {early_stop_threshold} epochs!")
                break
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")

    print(f"Train process is completed in {(tic_toc(train_start)) / 60:.3f} minutes.")

    return {"tr_loss": tr_loss, "tr_iou": tr_iou, "tr_pa": tr_pa,
            "val_loss": val_loss, "val_iou": val_iou, "val_pa" : val_pa}

In [None]:
# Smoke test: push one processor-collated batch through a pretrained Mask2Former.
processor = Mask2FormerImageProcessor()

# The processor is called with resizing / rescaling / normalisation disabled
# (see collate_fn), so this path only resizes the raw arrays itself.
m2f_trans = A.Compose([A.Resize(im_h, im_w)])

# Dedicated loader names: train(), visualize() and inference() unpack
# (image, mask) pairs from tr_dl / val_dl / test_dl, which processor-collated
# batches do not provide.
m2f_tr_dl, m2f_val_dl, m2f_test_dl, n_cls = get_dls(
    root=root,
    transformations=m2f_trans,
    bs=4,
    processor=processor
)
batch = next(iter(m2f_tr_dl))
# collate_fn returns the processor's mapping, so fields are read by key.
# NOTE(review): assumes the processor emits 'pixel_values' and 'mask_labels' - confirm.
pixel_values = batch['pixel_values']
print(pixel_values.shape)
print([mask.shape for mask in batch['mask_labels']])

m2f_device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Mask2FormerForUniversalSegmentation.from_pretrained(
    'facebook/mask2former-swin-tiny-ade-semantic',
    num_labels=n_cls,
    ignore_mismatched_sizes=True
).to(m2f_device)
model.eval()
with torch.no_grad():
    # NOTE(review): with rescaling and normalisation disabled the pixel values are
    # presumably still raw integers, hence the explicit float cast - confirm the
    # value range the checkpoint expects.
    outputs = model(pixel_values=pixel_values.float().to(m2f_device))

In [None]:
# post_process_semantic_segmentation needs one (height, width) entry per image
# in the batch, so the target size is repeated for every item in `outputs`.
semantic_segmetation = processor.post_process_semantic_segmentation(
    outputs,
    target_sizes=[(256,256)] * outputs.class_queries_logits.shape[0]
)

In [None]:
# Architecture to train. Names without a dedicated branch below
# (including "deeplab3plus") build a DeepLabV3+.
model_name = "segformer"

if model_name == "unet":
    model = smp.Unet(
        encoder_name="timm-efficientnet-b8",
        encoder_weights='imagenet',
        in_channels=3,
        classes=n_cls,
    )
elif model_name == "segformer":
    model = SegformerForSemanticSegmentation.from_pretrained('nvidia/mit-b5', num_labels=n_cls)
elif model_name == "mask2former":
    model = Mask2FormerForUniversalSegmentation.from_pretrained(
        'facebook/mask2former-swin-tiny-ade-semantic',
        num_labels=n_cls,
        ignore_mismatched_sizes=True,
    )
else:
    model = smp.DeepLabV3Plus(classes=n_cls)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Alternative criterion: torch.nn.CrossEntropyLoss()
loss_fn = JaccardLoss(mode='multiclass')
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# Train for up to 50 epochs; the returned history feeds the learning-curve plots.
history = train(
    model_name=model_name, model=model,
    tr_dl=tr_dl, val_dl=val_dl,
    loss_fn=loss_fn, opt=optimizer,
    device=device, epochs=50,
    save_prefix="aerial",
)

In [None]:
class Plot():
    """Draws the train-vs-validation learning curves of a history dict on construction.

    ``res`` must map ``"tr_iou"``, ``"val_iou"``, ``"tr_pa"``, ``"val_pa"``,
    ``"tr_loss"`` and ``"val_loss"`` to per-epoch sequences.
    """

    def __init__(self, res):
        self.res = res

        # (train key, validation key, train label, validation label, title, y-axis label)
        curve_specs = (
            ("tr_iou", "val_iou", "Train IoU", "Validation IoU",
             "Mean Intersection Over Union Learning Curve", "mIoU Score"),
            ("tr_pa", "val_pa", "Train PA", "Validation PA",
             "Pixel Accuracy Learning Curve", "PA Score"),
            ("tr_loss", "val_loss", "Train Loss", "Validation Loss",
             "Loss Learning Curve", "Loss Value"),
        )
        for tr_key, val_key, tr_label, val_label, title, ylabel in curve_specs:
            self.visualize(metric1=tr_key, metric2=val_key, label1=tr_label,
                           label2=val_label, title=title, ylabel=ylabel)

    def plot(self, metric, label):
        """Add the series stored under ``metric`` to the current figure."""
        plt.plot(self.res[metric], label=label)

    def decorate(self, ylabel, title):
        """Title the current figure, label its axes, add the legend and show it."""
        plt.title(title)
        plt.xlabel("Epochs")
        plt.ylabel(ylabel)
        plt.legend()
        plt.show()

    def visualize(self, metric1, metric2, label1, label2, title, ylabel):
        """Open a new figure and draw two series on it."""
        plt.figure(figsize=(10, 5))
        for metric, label in ((metric1, label1), (metric2, label2)):
            self.plot(metric, label)
        self.decorate(ylabel, title)
        
# Constructing Plot draws the IoU, pixel-accuracy and loss curves from the training history.
Plot(history)

In [None]:
import random
from torchvision import transforms as tfs

def tn_2_np(t):
    """Convert a tensor to a ``uint8`` NumPy array for display.

    A tensor whose first dimension has length 3 is treated as a normalised
    RGB image: the mean ``[0.485, 0.456, 0.406]`` / std ``[0.229, 0.224, 0.225]``
    normalisation is undone, the result is scaled by 255 and returned
    channels-last. Any other tensor is only scaled by 255.
    """
    if len(t) != 3:
        return (t * 255).detach().cpu().numpy().astype(np.uint8)

    # Undo the normalisation in two steps: first the division by std, then the mean shift.
    undo_std = tfs.Normalize(mean = [ 0., 0., 0. ], std = [ 1/0.229, 1/0.224, 1/0.225 ])
    undo_mean = tfs.Normalize(mean = [ -0.485, -0.456, -0.406 ], std = [ 1., 1., 1. ])
    restored = undo_mean(undo_std(t))
    return (restored * 255).detach().cpu().permute(1,2,0).numpy().astype(np.uint8)

def plot(rows, cols, count, im, gt = None, title = "Original Image"):
    """Draw ``im`` into cell ``count`` of a ``rows`` x ``cols`` subplot grid.

    A leading dimension of size 1 is squeezed away; when ``gt`` is truthy the
    tensor is additionally cast to float before conversion.

    Returns:
        ``count + 1``, the index of the next free cell.
    """
    plt.subplot(rows, cols, count)
    tensor = im.squeeze(0)
    if gt:
        tensor = tensor.float()
    plt.imshow(tn_2_np(tensor))
    plt.axis("off")
    plt.title(title)
    return count + 1

def visualize(ds, n_ims):
    """Plot randomly drawn samples of ``ds`` as image / mask panel pairs.

    Args:
        ds: Indexable dataset whose items unpack into ``(image, mask)``.
        n_ims: Total number of panels. Every sample occupies two panels, so
            ``n_ims // 2`` samples are drawn (with replacement).
    """
    n_pairs = n_ims // 2
    if n_pairs == 0 or len(ds) == 0:
        # Nothing to draw; also avoids dividing by a zero row count below.
        return

    rows = max(n_ims // 4, 1)
    # Ceiling division so the grid always has room for all n_ims panels.
    cols = -(-n_ims // rows)

    plt.figure(figsize = (25, 20))
    count = 1
    for _ in range(n_pairs):
        im, gt = ds[random.randint(0, len(ds) - 1)]

        # First panel: the image.
        count = plot(rows, cols, count, im = im)

        # Second panel: its mask.
        count = plot(rows, cols, count, im = gt, gt = True)
        
# Show random samples from the dataset behind the training loader (20 panels in total).
visualize(tr_dl.dataset, n_ims = 20)

In [None]:
def inference(dl, model_name, model, device, n_ims = 15):
    """Plot image / ground truth / predicted mask triplets for the first samples of ``dl``.

    Args:
        dl: Loader yielding ``(image, mask)`` batches.
        model_name: ``"segformer"`` reads ``.logits`` from the model output and
            resizes them to the input size; any other supported name uses the
            model output directly.
        model: Trained model; it is moved to ``device`` and put in eval mode.
        device: Device used for the forward pass.
        n_ims: Total number of panels. Every sample occupies three panels, so
            ``n_ims // 3`` samples are shown (at least one).

    Raises:
        NotImplementedError: For ``model_name == "mask2former"``, whose output
            is not handled here.
    """
    if model_name == "mask2former":
        raise NotImplementedError("inference() has no prediction path for 'mask2former' yet")

    n_samples = max(n_ims // 3, 1)
    model.to(device)
    model.eval()

    samples = []
    with torch.no_grad():
        for im, gt in dl:
            # Only the samples that get plotted are run through the model.
            if len(samples) == n_samples:
                break
            if model_name == "segformer":
                logits = nn.functional.interpolate(
                    model(im.to(device)).logits,
                    size=im.shape[-2:],
                    mode="bilinear",
                    align_corners=False
                )
            else:
                logits = model(im.to(device))
            samples.append((im, gt, torch.argmax(logits, dim = 1)))

    plt.figure(figsize = (25, 20))
    count = 1
    for im, gt, pred in samples:
        # One grid row per sample: image, ground truth, prediction.
        count = plot(n_samples, 3, count, im)
        count = plot(n_samples, 3, count, im = gt, gt = True, title = "Ground Truth")
        count = plot(n_samples, 3, count, im = pred, title = "Predicted Mask")

# train() stores the whole model object with torch.save(model, ...), so the
# checkpoint is a full pickle rather than a state dict and has to be loaded
# with weights_only=False. Unpickling can execute arbitrary code: only load
# checkpoints produced by this notebook.
model = torch.load(
    "/kaggle/working/saved_models/aerial_best_model.pt",
    map_location=device,
    weights_only=False,
)
inference(
    test_dl,
    model_name=model_name,
    model=model,
    device=device,
)