# Set Up TPU (XLA)

In [None]:
# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# !python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

%reload_ext autoreload
%autoreload
import os

if 'TPU_NAME' in os.environ.keys():
    
    try:
        import torch_xla
    except:
        # XLA powers the TPU support for PyTorch
        !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
        !python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev

# Import Libraries

In [None]:
# Install the training framework and the segmentation-model zoo at run time.
# NOTE(review): both installs are unpinned, but this notebook uses APIs such as
# `pytorch_lightning.metrics.functional` and `Trainer(early_stop_callback=...)`
# that newer pytorch-lightning releases may no longer provide — consider pinning
# a compatible version.
!pip install pytorch-lightning
!pip install git+https://github.com/qubvel/segmentation_models.pytorch

In [None]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import torchvision
import torch.nn.functional as F
from torch.autograd import Variable

from PIL import Image
import cv2
import albumentations as A

import time
import os
from tqdm.notebook import tqdm

import segmentation_models_pytorch as smp
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import iou

# Datasets

In [None]:
IMAGE_PATH = '../input/semantic-drone-dataset/semantic_drone_dataset/original_images/'
MASK_PATH = '../input/semantic-drone-dataset/semantic_drone_dataset/label_images_semantic/'


n_classes = 23
# NOTE(review): only 20 class names are listed below although n_classes is 23 —
# TODO confirm the full label list against the dataset documentation.
# tree, grass, other vegetation, dirt, gravel, rocks, water,
# paved area, pool, person, dog, car, bicycle, roof, wall, fence,
# fence-pole, window, door, obstacle

# Read the sample ids, i.e. image file names without their extension.
# Only top-level '.jpg' files are kept because Drone_data rebuilds each path as
# IMAGE_PATH + id + '.jpg'. splitext (not split('.')) keeps ids that contain dots.
# Sorting makes the split reproducible: directory listing order is
# filesystem-dependent, so random_state alone does not fix which ids land where.
name = sorted(
    os.path.splitext(filename)[0]
    for filename in os.listdir(IMAGE_PATH)
    if filename.endswith('.jpg')
)

df = pd.DataFrame({'id': name}, index = np.arange(0, len(name)))
print('Number of Data: ', len(df))

# Split: 10% held out for test, then 15% of the remainder for validation.
X_trainval, X_test = train_test_split(df['id'].values, test_size=0.1, random_state=19)
X_train, X_val = train_test_split(X_trainval, test_size=0.15, random_state=19)

print('Train Size   : ', len(X_train))
print('Val Size     : ', len(X_val))
print('Test Size    : ', len(X_test))

# Custom dataset

class Drone_data(Dataset):
    """Semantic-drone segmentation dataset.

    Sample id ``X[idx]`` is read from ``img_path + id + '.jpg'`` (colour image)
    and ``mask_path + id + '.png'`` (single-channel label map).

    Args:
        img_path: image directory prefix; ids are appended by plain string
            concatenation, so it must end with a path separator.
        mask_path: mask directory prefix (same convention).
        X: sequence of sample ids.
        mean, std: per-channel statistics passed to ``T.Normalize``.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``'image'`` and ``'mask'``.
        patch: if True, every sample is split into 512x768 tiles (see ``tiles``).
    """

    def __init__(self, img_path, mask_path, X, mean, std, transform=None, patch=False):
        self.img_path = img_path
        self.mask_path = mask_path
        self.X = X
        self.transform = transform
        self.patches = patch
        self.mean = mean
        self.std = std

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(img, mask)`` for sample ``idx``.

        ``img`` is a normalised float tensor of shape (3, H, W) and ``mask`` an
        int64 tensor of shape (H, W). When ``patch`` is set, both are tiled
        (see ``tiles``).

        Raises:
            FileNotFoundError: if the image or mask file cannot be read.
        """
        img_file = self.img_path + self.X[idx] + '.jpg'
        mask_file = self.mask_path + self.X[idx] + '.png'

        # cv2.imread signals a missing/unreadable file by returning None; fail
        # here with the path instead of inside cvtColor/from_numpy.
        img = cv2.imread(img_file)
        if img is None:
            raise FileNotFoundError(f'Could not read image: {img_file}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR

        mask = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f'Could not read mask: {mask_file}')

        if self.transform is not None:
            aug = self.transform(image=img, mask=mask)
            img, mask = aug['image'], aug['mask']

        # Built-in transform: to a CHW float tensor, then per-channel normalisation.
        to_tensor = T.Compose([T.ToTensor(), T.Normalize(self.mean, self.std)])
        img = to_tensor(Image.fromarray(img))
        mask = torch.from_numpy(mask).long()

        if self.patches:
            img, mask = self.tiles(img, mask)

        return img, mask

    def tiles(self, img, mask):
        """Split a (3, H, W) image and its (H, W) mask into 512x768 tiles.

        The unfold step equals the tile size, so tiles do NOT overlap, and any
        remainder rows/columns that do not fill a whole tile are dropped.

        Returns:
            ``(img_patches, mask_patches)`` with shapes (N, 3, 512, 768) and
            (N, 512, 768); tiles are ordered row-major (left to right, then
            top to bottom).
        """
        tile_h, tile_w = 512, 768

        # (3, H, W) -> (3, nH, nW, tile_h, tile_w) -> (3, N, tile_h, tile_w)
        img_patches = img.unfold(1, tile_h, tile_h).unfold(2, tile_w, tile_w)
        img_patches = img_patches.contiguous().view(3, -1, tile_h, tile_w)
        # Put the tile index first: (N, 3, tile_h, tile_w).
        img_patches = img_patches.permute(1, 0, 2, 3)

        # (H, W) -> (nH, nW, tile_h, tile_w) -> (N, tile_h, tile_w)
        mask_patches = mask.unfold(0, tile_h, tile_h).unfold(1, tile_w, tile_w)
        mask_patches = mask_patches.contiguous().view(-1, tile_h, tile_w)

        return img_patches, mask_patches
    
    
def calc_resize_ration(re_ration_to_depth=10, unet_depth=5, width=6000, height=4000):
    """Suggest a resize that keeps the aspect ratio and suits a U-Net-style encoder.

    The target height is ``2**unet_depth * re_ration_to_depth``, which makes it
    divisible by the encoder's total downsampling factor. The target width is
    that height scaled by ``width / height``. The width is a float and is not
    guaranteed to be divisible by ``2**unet_depth``.

    Prints the suggested size and the feature-map size after ``unet_depth``
    halvings.

    Returns:
        ``(height, width)`` of the suggested size.
    """
    aspect = width / height
    stride = 2 ** unet_depth  # total downsampling of a depth-`unet_depth` encoder

    heig = stride * re_ration_to_depth
    wid = stride * re_ration_to_depth * aspect

    print('After resize with the same ration:')
    print(f' height:width = {heig, wid}')
    print(f' size after encoding: {heig/stride, wid/stride}')
    return heig, wid

# Lovasz Loss

In [None]:
#===================================================================================================================
# Source: https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytorch/lovasz_losses.py

def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss
      probas: [B, C, H, W] tensor, class probabilities at each prediction (between 0 and 1).
              A [B, H, W] input is interpreted as binary (sigmoid) output.
      labels: [B, H, W] tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
      per_image: compute the loss per image and average, instead of over the whole batch
      ignore: void class label; pixels with this label are excluded
    Returns the loss (0 for an empty batch when per_image is True).
    """
    if per_image:
        image_losses = [
            lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore),
                                classes=classes)
            for prob, lab in zip(probas, labels)
        ]
        # Average inline instead of calling the module-level `mean()` helper: later
        # notebook cells rebind `mean` to a list of channel means, which would make
        # `mean(...)` raise TypeError. An empty batch yields 0, as mean() did.
        if not image_losses:
            return 0
        return sum(image_losses) / len(image_losses)
    return lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)

def lovasz_grad(gt_sorted):
    """
    Gradient of the Lovasz extension of the Jaccard loss w.r.t. sorted errors
    (Alg. 1 of the Lovasz-Softmax paper).
      gt_sorted: 1-D tensor of foreground indicators (0/1), ordered by
                 decreasing prediction error
    Returns a floating-point tensor with the same length as gt_sorted.
    """
    n_items = len(gt_sorted)
    total_fg = gt_sorted.sum()
    fg_seen = gt_sorted.float().cumsum(0)
    bg_seen = (1 - gt_sorted).float().cumsum(0)

    jaccard = 1. - (total_fg - fg_seen) / (total_fg + bg_seen)

    # Discrete derivative: keep the first entry and replace every later one by
    # its difference to the previous entry (a single pixel needs no differencing).
    if n_items > 1:
        jaccard = torch.cat((jaccard[:1], jaccard[1:] - jaccard[:-1]))
    return jaccard

def lovasz_softmax_flat(probas, labels, classes='present'):
    """
    Multi-class Lovasz-Softmax loss
      probas: [P, C] tensor, class probabilities at each prediction (between 0 and 1)
      labels: [P] tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
    Returns the mean of the per-class losses, or 0 when no class contributes.
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return probas * 0.
    C = probas.size(1)
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    # Compare by value: the previous `classes is 'present'` relied on string
    # interning (and is a SyntaxWarning since Python 3.8), so a non-interned
    # 'present' string never skipped absent classes.
    skip_absent = isinstance(classes, str) and classes == 'present'
    losses = []
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground for class c
        if skip_absent and fg.sum() == 0:
            continue
        if C == 1:
            # NOTE(review): for classes='all'/'present', len(classes) is the string
            # length (> 1), so this raises for every single-channel input. Kept unchanged.
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = (fg - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, lovasz_grad(fg_sorted)))
    # Average inline instead of calling the module-level `mean()` helper: later
    # notebook cells rebind `mean` to a list of channel means, which would make
    # `mean(losses)` raise TypeError. No contributing class yields 0, as mean() did.
    if not losses:
        return 0
    return sum(losses) / len(losses)


def flatten_probas(probas, labels, ignore=None):
    """
    Flattens predictions in the batch to one row per pixel.
      probas: [B, C, H, W] class probabilities, or [B, H, W] sigmoid output (treated as C = 1)
      labels: ground truth labels with B * H * W elements
      ignore: label value whose pixels are dropped (None keeps every pixel)
    Returns (probas [P, C], labels [P]).
    """
    if probas.dim() == 3:
        # assumes output of a sigmoid layer
        probas = probas.unsqueeze(1)
    C = probas.size(1)
    probas = probas.permute(0, 2, 3, 1).reshape(-1, C)  # B * H * W, C = P, C
    labels = labels.reshape(-1)  # reshape also accepts non-contiguous labels
    if ignore is None:
        return probas, labels
    valid = labels != ignore
    # Boolean-mask indexing keeps the [P, C] shape even when exactly one pixel is
    # valid; indexing with nonzero().squeeze() collapsed that case to a 1-D [C] row.
    return probas[valid], labels[valid]

def xloss(logits, labels, ignore=None):
    """
    Cross entropy loss
      logits: unnormalised scores in the layout F.cross_entropy expects (classes on dim 1)
      labels: integer class ids matching logits
      ignore: label value excluded from the loss; None falls back to 255, the
              value this function always used before
    """
    ignore_index = 255 if ignore is None else ignore
    return F.cross_entropy(logits, labels, ignore_index=ignore_index)


# --------------------------- HELPER FUNCTIONS ---------------------------
def isnan(x):
    """Return ``x != x``: true only for NaN, the one value unequal to itself.

    Applied to tensors/arrays, the comparison is element-wise.
    """
    is_self_unequal = (x != x)
    return is_self_unequal
    
    
def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.
      l: iterable of numbers or tensors
      ignore_nan: skip entries for which `isnan` is true
      empty: value returned for an empty iterable; 'raise' raises ValueError instead
    A single element is returned unchanged (not divided by 1). The caller's
    elements are never modified.
    """
    # Python 3 name of Python 2's itertools.ifilterfalse (the old name raised NameError).
    from itertools import filterfalse

    l = iter(l)
    if ignore_nan:
        l = filterfalse(isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean') from None
        return empty
    for n, v in enumerate(l, 2):
        # Out-of-place add: `acc += v` mutated the caller's first element in place
        # when it was a tensor or array.
        acc = acc + v
    if n == 1:
        return acc
    return acc / n

# Dataloader

In [None]:
# Per-channel statistics for T.Normalize (the standard ImageNet values, matching
# the ImageNet-pretrained encoder used by the model).
# NOTE(review): binding the name `mean` here shadows the `mean()` helper defined
# in the Lovasz cell.
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]


# albumentations' Resize takes (height, width). The images are landscape:
# calc_resize_ration defaults to 6000x4000 (W x H), and both `tiles` (512x768)
# and t_test (768x1152) are landscape too. So train/val use 448x672 (H x W),
# which keeps the 3:2 aspect ratio with both sides divisible by 32.
t_train = A.Compose([A.Resize(448, 672, interpolation=cv2.INTER_NEAREST),
                     A.HorizontalFlip(), A.VerticalFlip(),
                     A.GridDistortion(p=0.2),
                     A.RandomBrightnessContrast((0,0.5),(0,0.5)),
                     A.GaussNoise()])

# Validation must be deterministic so val_loss is comparable across epochs: resize only.
t_val = A.Compose([A.Resize(448, 672, interpolation=cv2.INTER_NEAREST)])

t_test = A.Resize(768, 1152, interpolation=cv2.INTER_NEAREST)

#datasets
train_set = Drone_data(IMAGE_PATH, MASK_PATH, X_train, mean, std, t_train, patch=False)
val_set = Drone_data(IMAGE_PATH, MASK_PATH, X_val, mean, std, t_val, patch=False)
# Needed by the Evaluation cells (`test_set[0]`).
test_set = Drone_data(IMAGE_PATH, MASK_PATH, X_test, mean, std, transform=t_test, patch=False)

#dataloader
batch_size = 3  # previously 4

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0)

# Model

In [None]:
class Drone_Net(pl.LightningModule):
    """Lightning module wrapping an FPN segmentation model with an EfficientNet-B3
    encoder (ImageNet weights) that outputs raw per-pixel class logits.

    Args:
        max_lr: peak learning rate for AdamW / OneCycleLR.
        epoch: number of epochs the OneCycle schedule spans.
        n_classes: number of output channels (default 23, the previous hard-coded value).
    """

    def __init__(self, max_lr, epoch, n_classes=23):
        super().__init__()

        self.model = smp.FPN('efficientnet-b3', encoder_weights='imagenet',
                             classes=n_classes, activation=None)

        self.max_lr = max_lr
        self.epoch = epoch

    def forward(self, x):
        # activation=None: the network returns logits, as F.cross_entropy expects.
        return self.model(x)

    def configure_optimizers(self):
        # REQUIRED
        # Use this module's own parameters (previously the global `model` was used,
        # which only worked when that global happened to be this instance).
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.max_lr, weight_decay=1e-4)
        # NOTE(review): steps_per_epoch is read from the module-level `train_loader`.
        sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.max_lr, epochs=self.epoch,
                                                    steps_per_epoch=len(train_loader))
        # The scheduler was created but never returned, so it was silently unused.
        # OneCycleLR is defined per optimisation step, hence interval='step'.
        return [optimizer], [{'scheduler': sched, 'interval': 'step'}]

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        return {'loss': loss}

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self(x)

        loss = F.cross_entropy(y_hat, y)

        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        tensorboard_logs = {'val_loss': avg_loss}
        # Also expose the average under 'val_loss', the key that the (commented-out)
        # EarlyStopping/ModelCheckpoint configuration in this notebook monitors.
        return {'val_loss': avg_loss,
                'avg_val_loss': avg_loss,
                'log': tensorboard_logs}

In [None]:
# Show TensorBoard inline, reading the `lightning_logs` directory (presumably where
# the Lightning Trainer's default logger writes — confirm if a custom logger is used).
%load_ext tensorboard
%tensorboard --logdir lightning_logs

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# Peak learning rate (AdamW / OneCycleLR) and number of training epochs.
lr, epoch = 1e-3, 20
model = Drone_Net(max_lr=lr, epoch=epoch)

# Optional callbacks, currently disabled. The two imports above are only needed
# if this configuration is re-enabled.
# early_stop_callback = EarlyStopping(
#    monitor='val_loss',
#    min_delta=0.00,
#    patience=5,
#    verbose=False,
#    mode='min')

# checkpoint_callback = ModelCheckpoint(
#     filepath='/kaggle/working/',
#     save_top_k=True,
#     verbose=True,
#     monitor='val_loss',
#     mode='min')

In [None]:
%%time
# most basic trainer, uses good defaults (1 TPU)
trainer = pl.Trainer(tpu_cores=8, max_epochs=epoch, precision=16, fast_dev_run=True,
                     early_stop_callback=True, 
                     #checkpoint_callback=checkpoint_callback,
                     check_val_every_n_epoch=1,
                     num_sanity_val_steps=1)

trainer.fit(model, train_loader, val_loader)

# Evaluation

In [None]:
def predict_image_mask(model, image, mask):
    """Predict a label map for one image and score it against ``mask``.

    The model is put in eval mode and moved to the module-level ``device``;
    ``image`` and ``mask`` are moved there too and given a leading batch
    dimension before inference (run without gradient tracking).

    NOTE(review): ``device`` and ``mIoU`` are not defined in the visible part of
    this notebook — TODO confirm they are defined before this cell runs.

    Returns:
        ``(pred, score)``. ``pred`` is the argmax of the model output over dim 1
        (presumably the class dimension), moved to the CPU with the batch
        dimension removed. ``score`` is ``mIoU(output, mask)``.
    """
    model.eval()
    model.to(device)
    image = image.to(device)
    mask = mask.to(device)
    with torch.no_grad():
        batch_image, batch_mask = image.unsqueeze(0), mask.unsqueeze(0)
        logits = model(batch_image)
        score = mIoU(logits, batch_mask)
        pred = torch.argmax(logits, dim=1).cpu().squeeze(0)
    return pred, score

# Same channel statistics that were passed to Drone_data, needed to undo its
# normalisation for display.
# NOTE(review): rebinding the name `mean` shadows the `mean()` helper from the Lovasz cell.
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]

class UnNormalize(object):
    """Invert ``T.Normalize(mean, std)`` in place: ``t * std + mean`` per channel."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): normalized image of size (C, H, W); modified in place.
        Returns:
            Tensor: the same tensor object, now un-normalized.
        """
        for channel, m, s in zip(tensor, self.mean, self.std):
            # Inverse of Normalize's t.sub_(m).div_(s).
            channel.mul_(s).add_(m)
        return tensor

unormal = UnNormalize(mean, std)

In [None]:
%%time
image, mask = test_set[0]
pred_mask, score = predict_image_mask(model, image,mask)

print(score)

In [None]:
# Undo the normalisation in place so the photo displays with natural colours.
# NOTE: `unormal` mutates `image`; re-running this cell without reloading the
# sample applies the inverse transform a second time.
unormal(image)
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 10))
ax1.imshow(pred_mask)
ax1.set_title('Prediction Mask | mIoU Score {:.3f}'.format(score))
ax2.imshow(mask)
ax2.set_title('Ground truth')
# Channels-last for matplotlib; clamp to [0, 1], the valid range for float RGB images.
ax3.imshow(image.permute(1, 2, 0).clamp(0, 1))
ax3.set_title('Picture')
plt.show()  # previously `plt.show` without parentheses, which never called it

In [None]:
# Flatten the predicted and ground-truth label maps to 1-D int64 NumPy arrays.
pred, true = (t.view(-1).long().detach().numpy() for t in (pred_mask, mask))