<a href="https://www.kaggle.com/code/alanyu223/unet-segmentation-on-carvana-dataset?scriptVersionId=100433479" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
import gc
import glob
import math
import os
import random
import shutil
from typing import List

import albumentations as A
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics as tm
import torchvision.transforms.functional as TF
import torchvision.utils
from albumentations.pytorch import ToTensorV2
from PIL import Image

In [2]:
def set_seed(seed = 42):
    '''Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and every CUDA
    device) and puts cuDNN into deterministic mode so repeated runs produce
    the same results.

    Args:
        seed (int): value handed to every generator. Defaults to 42.
    '''
    # Python's own generator was previously left unseeded.
    # NOTE(review): albumentations is assumed to draw its augmentation
    # randomness from it (version dependent) -- confirm for the installed release.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # `manual_seed_all` covers every visible GPU, not only the current one,
    # and is a harmless no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # PYTHONHASHSEED is read at interpreter start-up, so this assignment only
    # influences processes launched after this point, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    print('> SEEDING DONE')

set_seed()

> SEEDING DONE


In [3]:
# Start every run from an empty Kaggle working directory so that stale
# artefacts from a previous run (checkpoints, logs, cloned repositories)
# cannot leak into this one.
folder = '/kaggle/working/'
# Outside Kaggle the directory may not exist; there is then nothing to clean.
if os.path.isdir(folder):
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                # Requires `shutil` from the import cell; it used to be missing,
                # so directories were silently left in place.
                shutil.rmtree(file_path)
        except OSError as e:
            # Best effort: report filesystem failures and keep going, but let
            # programming errors (e.g. a missing import) surface.
            print('Failed to delete %s. Reason: %s' % (file_path, e))

In [4]:
# Check for GPU or setup TPU
# if not torch.cuda.is_available():
#     !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
#     !python pytorch-xla-env-setup.py --version 1.7 --apt-packages libomp5 libopenblas-dev
#     !pip install pytorch-lightning==1.1.5

# Model Architecture

The following image depicts the architecture that will be used for segmentation. We will first define a torch module as the building block for the model, then use a pytorch lightning module to define the final model.

![](https://raw.githubusercontent.com/aladdinpersson/Machine-Learning-Collection/master/ML/Pytorch/image_segmentation/semantic_segmentation_unet/UNET_architecture.png)

In [5]:
!git clone https://github.com/wemoveon2/SegLoss.git
!cp -r SegLoss/losses_pytorch/* .
from dice_loss import GDiceLossV2, FocalTversky_loss
from hausdorff import HausdorffERLoss

Cloning into 'SegLoss'...
remote: Enumerating objects: 494, done.[K
remote: Counting objects: 100% (160/160), done.[K
remote: Compressing objects: 100% (91/91), done.[K
remote: Total 494 (delta 85), reused 128 (delta 69), pack-reused 334[K
Receiving objects: 100% (494/494), 457.25 KiB | 777.00 KiB/s, done.
Resolving deltas: 100% (242/242), done.


In [6]:
# Compound loss assembled from the SegLoss repository cloned above.
# `torch.sigmoid` replaces the deprecated `torch.nn.functional.sigmoid`;
# both compute the same element-wise sigmoid.
ft_loss_args = {
    'apply_nonlin': torch.sigmoid,
    'batch_dice': True,
}
hausdorff_er_loss = HausdorffERLoss()
generalized_dice_loss_v2 = GDiceLossV2(apply_nonlin=torch.sigmoid)
ft_loss = FocalTversky_loss(ft_loss_args)

def loss_fn(pred, gt):
    """Weighted sum of three SegLoss terms.

    Combines 0.5 * Hausdorff-ER, 0.25 * generalized Dice (v2) and
    0.25 * focal Tversky, each evaluated on the same ``(pred, gt)`` pair.

    NOTE: a second function named ``loss_fn`` (BCE + Tversky) is defined in a
    later cell; when the notebook is run top to bottom that later definition
    replaces this one before the model is built.

    Args:
        pred: model output. The Dice and Tversky terms apply a sigmoid
            themselves via ``apply_nonlin``, so raw logits are presumably
            expected -- confirm for the Hausdorff term.
        gt: ground-truth mask.

    Returns:
        The combined loss value.
    """
    return 0.5*hausdorff_er_loss(pred, gt) + 0.25*generalized_dice_loss_v2(pred, gt) + 0.25*ft_loss(pred, gt)

In [7]:
!pip install segmentation_models_pytorch
import segmentation_models_pytorch as smp

# Loss components from segmentation_models_pytorch; both are combined by the
# `loss_fn` defined at the end of this cell.
# NOTE(review): the class name suggests this expects raw logits rather than
# probabilities -- confirm against the smp documentation.
BCELoss     = smp.losses.SoftBCEWithLogitsLoss()
# NOTE(review): 'multilabel' mode is used although the model has a single
# output class -- confirm this is intended rather than 'binary'.
TverskyLoss = smp.losses.TverskyLoss(mode='multilabel', log_loss=False)

@torch.no_grad()
def dice_coef(y_pred, y_true, thr=0.5, dim=(2,3), epsilon=0.001):
    """Smoothed Dice coefficient of thresholded predictions.

    Args:
        y_pred: prediction scores; values strictly above ``thr`` count as
            foreground.
        y_true: ground-truth mask with the same shape as ``y_pred``.
        thr: binarisation threshold applied to ``y_pred``.
        dim: axes summed over for each map. With the default ``(2, 3)`` the
            inputs are expected to be 4-D, since the result is afterwards
            averaged over axes 1 and 0.
        epsilon: smoothing constant added to numerator and denominator.

    Returns:
        Scalar tensor: the mean of ``(2*intersection + eps) / (sum + eps)``
        over the first two axes.
    """
    target = y_true.float()
    binary = (y_pred > thr).float()
    overlap = torch.sum(target * binary, dim=dim)
    total = torch.sum(target, dim=dim) + torch.sum(binary, dim=dim)
    per_map = (2 * overlap + epsilon) / (total + epsilon)
    return per_map.mean(dim=(1, 0))
@torch.no_grad()
def iou_coef(y_pred, y_true, thr=0.5, dim=(2,3), epsilon=0.001):
    """Smoothed intersection-over-union of thresholded predictions.

    Args:
        y_pred: prediction scores; values strictly above ``thr`` count as
            foreground.
        y_true: ground-truth mask with the same shape as ``y_pred``.
        thr: binarisation threshold applied to ``y_pred``.
        dim: axes summed over for each map. With the default ``(2, 3)`` the
            inputs are expected to be 4-D, since the result is afterwards
            averaged over axes 1 and 0.
        epsilon: smoothing constant added to numerator and denominator.

    Returns:
        Scalar tensor: the mean of ``(intersection + eps) / (union + eps)``
        over the first two axes.
    """
    target = y_true.float()
    binary = (y_pred > thr).float()
    product = target * binary
    overlap = product.sum(dim=dim)
    # |A or B| = |A| + |B| - |A and B|, evaluated element-wise before summing.
    combined = (target + binary - product).sum(dim=dim)
    per_map = (overlap + epsilon) / (combined + epsilon)
    return per_map.mean(dim=(1, 0))

def loss_fn(y_pred, y_true):
    """Equal-weight blend of the module-level ``BCELoss`` and ``TverskyLoss``.

    Args:
        y_pred: model output passed unchanged to both loss objects.
        y_true: ground-truth mask passed unchanged to both loss objects.

    Returns:
        ``0.5 * BCELoss(y_pred, y_true) + 0.5 * TverskyLoss(y_pred, y_true)``.
    """
    bce_term = BCELoss(y_pred, y_true)
    tversky_term = TverskyLoss(y_pred, y_true)
    return 0.5*bce_term + 0.5*tversky_term

Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.2.1-py3-none-any.whl (88 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.6/88.6 kB[0m [31m300.2 kB/s[0m eta [36m0:00:00[0m
Collecting pretrainedmodels==0.7.4
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l- done
[?25hCollecting efficientnet-pytorch==0.6.3
  Downloading efficientnet_pytorch-0.6.3.tar.gz (16 kB)
  Preparing metadata (setup.py) ... [?25l- done
[?25hCollecting timm==0.4.12
  Downloading timm-0.4.12-py3-none-any.whl (376 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m377.0/377.0 kB[0m [31m931.8 kB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: efficientnet-pytorch, pretrainedmodels
  Building wheel for efficientnet-pytorch (setup.py) ..

In [8]:
class DropBlock2D(nn.Module):
    r"""Randomly zeroes 2D spatial blocks of the input tensor.
    As described in the paper
    `DropBlock: A regularization method for convolutional networks`_ ,
    dropping whole blocks of feature map allows to remove semantic
    information as compared to regular dropout.

    The layer is the identity in eval mode or when ``drop_prob`` is 0.
    In training mode one block mask is sampled per batch element and shared
    by all channels; surviving activations are rescaled by
    ``numel / kept`` so the expected magnitude is preserved.

    Args:
        drop_prob (float): probability of an element to be dropped.
        block_size (int): size of the block to drop
    Shape:
        - Input: `(N, C, H, W)`
        - Output: `(N, C, H, W)`
    .. _DropBlock: A regularization method for convolutional networks:
       https://arxiv.org/abs/1810.12890
    """

    def __init__(self, drop_prob, block_size):
        super().__init__()
        self.drop_prob = drop_prob
        self.block_size = block_size

    def forward(self, x):
        # shape: (bsize, channels, height, width)
        assert x.dim() == 4, \
            "Expected input with 4 dimensions (bsize, channels, height, width)"
        if not self.training or self.drop_prob == 0.:
            return x

        # per-position probability of seeding a dropped block
        gamma = self._compute_gamma(x)
        # Sample the seed mask directly on the input's device; this avoids a
        # host-to-device copy on every forward pass.
        mask = (torch.rand(x.shape[0], *x.shape[2:], device=x.device) < gamma).float()
        # grow every seed into a block_size x block_size hole (1 = keep, 0 = drop)
        block_mask = self._compute_block_mask(mask)
        # the same spatial mask is broadcast over all channels
        out = x * block_mask[:, None, :, :]
        # Rescale the survivors. The mask holds only 0/1, so clamping the kept
        # count to >= 1 changes nothing unless *everything* was dropped, where
        # it turns a 0/0 = NaN into an all-zero output.
        out = out * block_mask.numel() / block_mask.sum().clamp(min=1)
        return out

    def _compute_block_mask(self, mask):
        """Expand seed positions into square blocks and invert to a keep-mask."""
        # Max-pooling with stride 1 marks every position whose window contains a seed.
        block_mask = F.max_pool2d(input=mask[:, None, :, :],
                                  kernel_size=(self.block_size, self.block_size),
                                  stride=(1, 1),
                                  padding=self.block_size // 2)
        # An even kernel with this padding yields one extra row/column; trim it.
        if self.block_size % 2 == 0:
            block_mask = block_mask[:, :, :-1, :-1]
        block_mask = 1 - block_mask.squeeze(1)
        return block_mask

    def _compute_gamma(self, x):
        """Seed probability: ``drop_prob`` spread over the area of one block."""
        return self.drop_prob / (self.block_size ** 2)

class Block(nn.Module):
    """Residual double-convolution building block of the U-Net.

    Main path: Conv -> BatchNorm -> ReLU -> Conv -> BatchNorm -> DropBlock.
    The input is added back to the main-path output and the sum is passed
    through a ReLU. When the channel count changes, the input is first
    projected by a Conv -> BatchNorm shortcut so the shapes match.

    This is a plain ``nn.Module``: it is only ever used as a sub-module of
    the Lightning model and needs none of the Lightning hooks.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        kernel, stride, padding, bias: forwarded to every ``nn.Conv2d``.
        shortcut: whether to build the projection shortcut. With
            ``shortcut=False`` the caller must ensure
            ``in_channels == out_channels``, otherwise the residual addition
            fails on mismatching shapes.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel: int = 3, 
                 stride: int = 1, padding: int = 1, bias: bool = False, shortcut: bool = True):
        super().__init__()
        # Stored so `expand_dims()` can tell whether a projection is required
        # (these attributes were previously never set).
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel, stride, padding, bias=bias),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel, stride, padding, bias=bias),
            nn.BatchNorm2d(out_channels),
            # NOTE(review): drop_prob=0.8 is very aggressive for DropBlock --
            # confirm this is not meant to be a keep probability.
            DropBlock2D(0.8, 3)
        )
        if shortcut:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel, stride, padding, bias=bias),
                nn.BatchNorm2d(out_channels))
        else:
            self.shortcut = None

    def forward(self, x):
        identity = x
        out = self.conv(x)
        # `expand_dims` must be *called*: the bare bound method is always
        # truthy, which made this condition depend on `self.shortcut` alone.
        # Blocks whose channel count changes behave exactly as before; blocks
        # with equal in/out channels now use a true identity skip.
        if self.expand_dims() and self.shortcut is not None:
            identity = self.shortcut(identity)
        out += identity
        out = F.relu(out)
        return out

    def expand_dims(self):
        """Return True when input and output channel counts differ."""
        return self.in_channels != self.out_channels

class UNET(pl.LightningModule):
    """U-Net segmentation model wrapped as a LightningModule.

    The encoder is a stack of ``Block`` modules, each followed by 2x2 max
    pooling; the decoder mirrors it with transposed convolutions whose output
    is concatenated with the matching encoder feature map (skip connection)
    and refined by another ``Block``. A final 1x1 convolution maps to
    ``n_classes`` logit channels.

    Besides the loss, pixel accuracy and a Dice score (both at a 0.5
    threshold on the sigmoid of the logits) are accumulated over each epoch
    and logged at epoch end for training and validation.

    Args:
        loss_fn: callable ``loss_fn(logits, target)`` returning the loss.
        in_channels: channels of the input image.
        n_classes: number of output channels.
        features: channel widths of the encoder stages, shallow to deep.
        learning_rate: learning rate handed to AdamW.
    """

    def __init__(self, loss_fn, in_channels: int = 3, n_classes: int = 1, 
                 features: List[int] = [64,128,256,512,1024], learning_rate: float = 1.5e-3):
        super().__init__()
        # Private copy: the default list is shared between calls and must
        # never be exposed to mutation through this instance.
        features = list(features)
        self.learning_rate = learning_rate 
        self.down = nn.ModuleList()
        self.up = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)
        for feature in features:
            self.down.append(Block(in_channels, feature))
            in_channels=feature
        # `self.up` alternates [upsample, Block, upsample, Block, ...], deepest stage first.
        for feature in reversed(features):
            self.up.append(
                nn.ConvTranspose2d(feature*2, feature, 2, 2)
            )
            self.up.append(
                Block(feature*2, feature) # x gets concat to 2xchannel
            )
        self.bottleneck = Block(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], n_classes, 1)
        self.loss_fn = loss_fn

        # Running per-epoch metric accumulators (reset in the *_epoch_end hooks).
        self.val_num_correct = 0
        self.val_num_pixels = 0
        self.val_dice_score = 0
        self.num_correct = 0
        self.num_pixels = 0
        self.dice_score = 0

    def forward(self, x):
        """Return raw logits for a batch of images."""
        skip_connections = []
        for down in self.down:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        # Decoder consumes the skips deepest-first.
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.up), 2):
            x = self.up[idx](x)
            skip_connection = skip_connections[idx//2]
            # Pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip; resize to the skip's spatial size before concat.
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])
            concat_skip = torch.cat((skip_connection, x), dim=1) # Concat along channels (b, c, h, w)
            x = self.up[idx+1](concat_skip)
        return self.final_conv(x)

    @staticmethod
    @torch.no_grad()
    def _batch_stats(logits, target):
        """Return ``(n_correct, n_pixels, dice)`` for one batch of logits."""
        pred = (torch.sigmoid(logits) > 0.5).float()
        n_correct = (pred == target).sum()
        n_pixels = torch.numel(pred)
        dice = (2 * (pred * target).sum()) / (
            (pred + target).sum() + 1e-8
        )
        return n_correct, n_pixels, dice

    @staticmethod
    def _percent(correct, total):
        """Accuracy in percent rounded to two decimals; 0.0 when nothing was counted."""
        if total == 0:
            return 0.0
        return float(f'{(correct/total)*100:.2f}')

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and update the training metric accumulators."""
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        correct, pixels, dice = self._batch_stats(logits, y)
        self.num_correct += correct
        self.num_pixels += pixels
        self.dice_score += dice
        self.log('train_loss', loss, prog_bar = True, logger = True)
        return {'loss': loss}

    def training_epoch_end(self, output):
        """Log epoch-level training accuracy and mean Dice, then reset the accumulators."""
        train_acc = self._percent(self.num_correct, self.num_pixels)
        self.log('train_acc', train_acc, prog_bar = True, logger = True)
        # max(..., 1) avoids a ZeroDivisionError when no step output was collected.
        dice_score = self.dice_score/max(len(output), 1)
        self.log('train_dice_score', dice_score, prog_bar = True, logger = True)
        self.num_correct, self.num_pixels, self.dice_score = 0,0,0

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one batch and update the validation metric accumulators."""
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        correct, pixels, dice = self._batch_stats(logits, y)
        self.val_num_correct += correct
        self.val_num_pixels += pixels
        self.val_dice_score += dice
        self.log('val_loss', loss, prog_bar = True, logger = True)
        return {'loss': loss}

    def validation_epoch_end(self, output):
        """Log epoch-level validation accuracy and mean Dice, then reset the accumulators."""
        val_acc = self._percent(self.val_num_correct, self.val_num_pixels)
        self.log('val_acc', val_acc, prog_bar = True, logger = True)
        dice_score = self.val_dice_score/max(len(output), 1)
        self.log('val_dice_score', dice_score, prog_bar = True, logger = True)
        self.val_num_correct, self.val_num_pixels, self.val_dice_score = 0,0,0

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters."""
        # NOTE(review): weight_decay=0.3 is unusually strong -- confirm it is intended.
        optimizer = torch.optim.AdamW(params = self.parameters(), lr = self.learning_rate, weight_decay=0.3)
        return optimizer

We can see how many parameters we have in our model using the following command. Note that the order shown does not reflect how a training example flows through the architecture. 

In [9]:
# pl.utilities.model_summary.summarize(UNET(),-1)

# Data
Now that we have our model, we define a torch Dataset to retrieve and apply our transformations to our data. We define the three necessary methods and move on to defining our LightningDataModule, which will split our Dataset into train/val splits and prepare the appropriate Dataloader when called by the Trainer object later on when we start training.

In [10]:
class SegmentationDataset(torch.utils.data.Dataset):
    """Image/mask pairs for segmentation, read lazily from two folders.

    Every ``*.jpg`` in ``image_path`` is one sample; its mask is the file of
    the same base name with a ``.png`` extension inside ``mask_path``.

    Args:
        image_path: folder containing the ``.jpg`` images.
        mask_path: folder containing the ``.png`` masks.
        transforms: callable invoked as ``transforms(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries.
            The ``"mask"`` entry must be a torch tensor, since
            ``torch.unsqueeze`` is applied to it.
    """

    def __init__(self, image_path, mask_path, transforms):
        # `glob` returns files in arbitrary order; sorting makes the sample
        # index -> file mapping (and therefore any index-based train/val
        # split) reproducible across runs and machines.
        self.images = sorted(glob.glob(os.path.join(image_path, '*.jpg')))
        self.image_path = image_path
        self.mask_path = mask_path
        self.transforms = transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)``: a float image and a float32 mask with a leading channel axis."""
        image_file = self.images[idx]
        mask_file = os.path.join(self.mask_path, os.path.basename(image_file).replace('.jpg', '.png'))
        # Context managers close the file handles as soon as the pixel data
        # has been copied into NumPy arrays.
        with Image.open(image_file) as im:
            img = np.array(im.convert('RGB'))
        with Image.open(mask_file) as im:
            mask = np.array(im)
        # Map foreground stored as 255 to class label 1.
        mask[mask == 255.0] = 1.0  
        augmentations = self.transforms(image=img, mask=mask)
        image = augmentations["image"]
        mask = augmentations["mask"]
        # Add the channel axis expected by the loss: (H, W) -> (1, H, W).
        mask = torch.unsqueeze(mask, 0)
        mask = mask.type(torch.float32)
        return image.float(), mask

In [11]:
import cv2

# Albumentations pipelines keyed by split name. Each pipeline is called with
# `image=` and `mask=` by SegmentationDataset, so image and mask go through
# the same pipeline call.
# NOTE(review): 255 is odd, so the model's 2x2 pooling stages produce sizes
# the decoder has to resize to match its skip connections -- confirm 256 was
# not intended.
# NOTE(review): INTER_NEAREST is also applied to the RGB image here; confirm a
# smoother interpolation was not intended for the image.
data_transforms = {
    # Evaluation: resize and convert to tensors only, no augmentation.
    "valid": A.Compose([
        A.Resize(*[255,255], interpolation=cv2.INTER_NEAREST),
        A.pytorch.ToTensorV2()]),
    # Training: resize, random flips, a mild shift/scale/rotate and, with
    # probability 0.40, one of three distortion transforms.
    "train": A.Compose([
        A.Resize(*[255,255], interpolation=cv2.INTER_NEAREST),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5),
        A.OneOf([
            A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
            A.OpticalDistortion(distort_limit=0.05, shift_limit=0.05, p=0.3),
            A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0)
        ], p=0.40),
        A.pytorch.ToTensorV2()])
}

In [12]:
class SegmentationDataModule(pl.LightningDataModule):
    """Builds the train/validation split and its DataLoaders.

    A single ``SegmentationDataset`` is created and split randomly, so the
    training and validation subsets share the same ``transform``.

    Args:
        image_path: folder with the input images.
        mask_path: folder with the masks.
        transform: transform pipeline handed to the dataset.
        train_size: fraction of samples assigned to the training subset.
        batch_size: batch size of both loaders.
        seed: seed of the generator used for the split. A fixed seed makes
            every instance select the same sample indices, so a data module
            created later for evaluation does not silently put training
            images into its "validation" subset (provided the dataset lists
            its files in a stable order).
    """

    def __init__(self, image_path: str, mask_path:str , transform, 
                 train_size: float = 0.90, batch_size: int = 5, seed: int = 42):
        super().__init__()
        self.image_path = image_path
        self.mask_path = mask_path
        self.batch_size = batch_size
        self.transform = transform
        self.train_size = train_size
        self.seed = seed

    def setup(self, stage = None):
        """Create ``train_dataset`` / ``val_dataset`` for the 'fit' stage (or when no stage is given)."""
        if stage in (None, 'fit'):
            ds = SegmentationDataset(self.image_path, self.mask_path, self.transform)
            train_size = math.floor(len(ds)*self.train_size)
            val_size = len(ds)-train_size
            # A dedicated generator decouples the split from the global RNG
            # state, which changes as soon as anything else draws from it.
            split_rng = torch.Generator().manual_seed(self.seed)
            train_ds, val_ds = torch.utils.data.random_split(
                ds, [train_size, val_size], generator=split_rng)
            self.train_dataset = train_ds
            self.val_dataset = val_ds

    def train_dataloader(self):
        """Shuffled loader over the training subset."""
        return torch.utils.data.DataLoader(self.train_dataset, self.batch_size, num_workers=2, shuffle = True, persistent_workers=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation subset."""
        return torch.utils.data.DataLoader(self.val_dataset, self.batch_size, num_workers=2, persistent_workers=True)

# Training

Once our model, optimizer+loss, and LightningDataModule are defined, we can use Trainer to run our training and validation loops. We can also use callbacks provided by lightning (check out [Bolts](https://lightning-bolts.readthedocs.io/en/latest/) for advanced callbacks, such as for sparsification) to create checkpoints or for early stopping. Trainer also takes care of multi-GPU and TPU training. 

We also use albumentations to apply the **same** data augmentations on the image **and** mask. I don't believe the torchvision transforms allow you to do the same, but correct me if I'm wrong. 

In [13]:
# Release cached GPU memory left over from earlier cells before building the model.
torch.cuda.empty_cache()

# Constructor arguments of the U-Net; `loss_fn` is whichever definition of
# that name was executed last.
args = {
    'in_channels': 3,
    'n_classes': 1,
    'features': [64,128,256,512,1024],
    'loss_fn': loss_fn
}
# The data module applies the "train" pipeline to the whole dataset, i.e. to
# the validation subset as well, because both subsets wrap one dataset.
ds = SegmentationDataModule('../input/carvana-image-masking-png/train_images', '../input/carvana-image-masking-png/train_masks', transform=data_transforms['train'])
model = UNET(**args)
# Idempotent and free of the check-then-create race of isdir() + mkdir().
os.makedirs('./unet_model', exist_ok=True)
# Keep only the checkpoint with the lowest validation loss.
checkpointCallback = pl.callbacks.ModelCheckpoint(dirpath="./unet_model", 
                                                  save_top_k=1, 
                                                  monitor="val_loss",
                                                 filename='{epoch}-{val_loss:.5f}',
                                                 mode='min')
# NOTE(review): `auto_lr_find=True` presumably only takes effect when
# `trainer.tune(...)` is called; no such call is made here, so training runs
# with the model's default learning rate -- confirm.
if torch.cuda.is_available():
    trainer = pl.Trainer(max_epochs=8, accelerator='gpu', gpus=1, 
                         callbacks=[checkpointCallback], profiler='simple',
                         auto_lr_find=True, accumulate_grad_batches=20)
else:
    # NOTE(review): every non-CUDA machine is assumed to be a TPU host; a
    # CPU-only environment will fail here -- confirm.
    trainer = pl.Trainer(max_epochs=7, accelerator='tpu',  
                         callbacks=[checkpointCallback], auto_lr_find=True)

trainer.fit(model, datamodule=ds)

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

# Utility Functions

Here we define three utility functions:
  1. **save_images**(model, loader, folder, device)
      - This is used to save predictions and their corresponding masks from a specified model across a specified dataloader.
      - Device can be set to cpu if not running on gpu
  2. **get_concat_v**(im1, im2)
      - This is used to later concatenate the prediction and mask images we create using save_images.
      - im1 & im2 are PIL.Image objects
  3. **merge_photos**(src_folder, dst_folder, remove_single)
      - This is used to read and concatenate the prediction and mask images using the previously defined get_concat_v().
      - remove_single can be set to False to save the unmerged images

In [14]:
def save_images(model, loader, folder='val_img', device='cuda'):
    """Save thresholded predictions and their ground-truth masks as PNG grids.

    For batch ``idx`` of ``loader`` two files are written into ``folder``:
    ``pred_{idx}.png`` (sigmoid of the model output thresholded at 0.5) and
    ``mask_{idx}.png`` (the target batch as yielded by the loader).

    Args:
        model: network returning logits; it is switched to eval mode and
            moved to ``device`` (and left that way).
        loader: iterable yielding ``(x, y)`` batches.
        folder: output directory, created if missing.
        device: device used for inference, e.g. ``'cuda'`` or ``'cpu'``.
    """
    model.eval()
    os.makedirs(folder, exist_ok=True)
    model.to(device=device)
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            # The output already lives on `device`; the former unconditional
            # `.cuda()` call made the function unusable with device='cpu'.
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(
            preds, f"{folder}/pred_{idx}.png"
        )
        # The target is only written to disk, so it needs no device transfer.
        torchvision.utils.save_image(y, f"{folder}/mask_{idx}.png")

def get_concat_v(im1, im2):
    """Stack two images vertically, ``im1`` on top and ``im2`` directly below.

    Args:
        im1: upper image; its width determines the width of the result.
        im2: lower image.

    Returns:
        A new RGB image of size ``(im1.width, im1.height + im2.height)``.
    """
    top_height = im1.height
    canvas_size = (im1.width, top_height + im2.height)
    canvas = Image.new('RGB', canvas_size)
    # Paste each tile at its vertical offset, left-aligned.
    for tile, offset_y in ((im1, 0), (im2, top_height)):
        canvas.paste(tile, (0, offset_y))
    return canvas

def merge_photos(src_folder: str='./val_img', dst_folder: str='./merged_val_img', remove_single: bool=True):
    """Merge every ``pred_{i}.png`` / ``mask_{i}.png`` pair into one stacked image.

    The number of pairs is taken to be half the number of ``.png`` files in
    ``src_folder``; pair ``i`` is written to
    ``{dst_folder}/merged_pred_mask_{i}.png`` with the prediction on top.

    Args:
        src_folder: folder holding the ``pred_*.png`` and ``mask_*.png`` files.
        dst_folder: output folder, created if missing.
        remove_single: delete the two source files of a pair once merged.
    """
    files = glob.glob(src_folder+'/*.png')
    os.makedirs(dst_folder, exist_ok=True)
    for i in range(len(files) // 2):
        pred_path = f'{src_folder}/pred_{i}.png'
        mask_path = f'{src_folder}/mask_{i}.png'
        # Close both inputs before they are (optionally) deleted below.
        with Image.open(pred_path) as pred_img, Image.open(mask_path) as mask_img:
            get_concat_v(pred_img, mask_img).save(f'{dst_folder}/merged_pred_mask_{i}.png')
        if remove_single:
            # Remove from `src_folder`; the path used to be hard-coded to
            # './val_img', which broke any other source folder.
            os.remove(pred_path)
            os.remove(mask_path)

In [15]:
# Rebuild the data module with the "valid" pipeline (resize + tensor
# conversion only) so the saved images are not augmented.
ds = SegmentationDataModule('../input/carvana-image-masking-png/train_images', '../input/carvana-image-masking-png/train_masks', transform=data_transforms['valid'])
# setup() performs a new random split of the dataset.
# NOTE(review): the subset used below equals the validation set seen during
# training only if that split is reproducible across calls -- confirm.
ds.setup()
# Write one prediction grid and one mask grid per batch to ./val_img ...
save_images(model, ds.val_dataloader())
# ... then stack each pair into ./merged_val_img and delete the single images.
merge_photos()
# TODO: Implement feature for sampling random img/label pairs for model prediction/eval
#         - Should show output and mask as images, pref side by side or across batches
#         - Conduct additional testing using the competition or data from the internet

In [16]:
!zip -r lightning_logs.zip ./lightning_logs
!zip -r merged_val_img.zip ./merged_val_img
!zip -r unet.zip ./unet_model

  adding: lightning_logs/ (stored 0%)
  adding: lightning_logs/version_0/ (stored 0%)
  adding: lightning_logs/version_0/hparams.yaml (stored 0%)
  adding: lightning_logs/version_0/events.out.tfevents.1657398734.e00bbb48e8c2.23.0 (deflated 61%)
  adding: merged_val_img/ (stored 0%)
  adding: merged_val_img/merged_pred_mask_89.png (deflated 11%)
  adding: merged_val_img/merged_pred_mask_48.png (deflated 11%)
  adding: merged_val_img/merged_pred_mask_41.png (deflated 11%)
  adding: merged_val_img/merged_pred_mask_44.png (deflated 10%)
  adding: merged_val_img/merged_pred_mask_79.png (deflated 9%)
  adding: merged_val_img/merged_pred_mask_56.png (deflated 12%)
  adding: merged_val_img/merged_pred_mask_28.png (deflated 12%)
  adding: merged_val_img/merged_pred_mask_97.png (deflated 11%)
  adding: merged_val_img/merged_pred_mask_85.png (deflated 10%)
  adding: merged_val_img/merged_pred_mask_86.png (deflated 9%)
  adding: merged_val_img/merged_pred_mask_25.png (deflated 10%)

Since Kaggle doesn't work with Tensorboard, download and load the logs in Google Colab to analyze the training and validation metrics.

<button><a href='./lightning_logs.zip'>Download Logs</a></button>
<button><a href='./merged_val_img.zip'>Download Images</a></button>
<button><a href='./unet.zip'>Download Model</a></button>