<a href="https://colab.research.google.com/github/seanreed1111/cnn-demos/blob/main/unet_segmentation_on_carvana_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install -q torchmetrics albumentations "pytorch-lightning<2.0"

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.8.2+zzzcolab20220719082949 requires tensorboard<2.9,>=2.8, but you have tensorboard 2.10.0 which is incompatible.

source: https://www.kaggle.com/code/alanyu223/unet-segmentation-on-carvana-dataset

In [3]:
import gc
import math
import glob
import os
import torch
import numpy as np
import torch.nn as nn
import torchvision.transforms.functional as TF
import torchvision.utils
import pytorch_lightning as pl
import torchmetrics as tm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from typing import List
from PIL import Image

In [None]:
# # Check for GPU or setup TPU
# if not torch.cuda.is_available():
#     !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
#     !python pytorch-xla-env-setup.py --version 1.7 --apt-packages libomp5 libopenblas-dev
#     !pip install pytorch-lightning==1.1.5

# Model Architecture

The image below depicts the U-Net architecture used for segmentation. We first define a `Block` module (two 3x3 convolutions, each followed by batch normalization and a ReLU) as the building block, then use a PyTorch Lightning module, `UNET`, to define the full model.

![](https://raw.githubusercontent.com/aladdinpersson/Machine-Learning-Collection/master/ML/Pytorch/image_segmentation/semantic_segmentation_unet/UNET_architecture.png)
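To connect the diagram to the code, the sketch below traces feature-map sizes through the encoder and decoder for the 360x480 inputs used in the training section. It assumes, as in the `UNET` cell below, that `nn.MaxPool2d(2, 2)` floors odd sizes and each `ConvTranspose2d(..., 2, 2)` exactly doubles them; `trace_unet_shapes` is a hypothetical helper, not part of the notebook.

```python
# Sketch: trace (channels, H, W) through the UNET defined in the next cell.
# Assumes 2x2 max pooling with stride 2 (floors odd sizes, like nn.MaxPool2d)
# and 2x2 transposed convolutions with stride 2 (exact doubling).

def trace_unet_shapes(height, width, features=(64, 128, 256, 512)):
    """Return skip-connection sizes, bottleneck size, and decoder stages needing a resize."""
    skips = []
    h, w = height, width
    for f in features:
        skips.append((f, h, w))    # Block output is kept as a skip connection
        h, w = h // 2, w // 2      # MaxPool2d(2, 2) floors odd sizes
    bottleneck = (features[-1] * 2, h, w)
    resized = []
    for f, skip_h, skip_w in reversed(skips):
        h, w = h * 2, w * 2        # ConvTranspose2d(2f, f, 2, 2) doubles H and W
        if (h, w) != (skip_h, skip_w):
            resized.append(((h, w), (skip_h, skip_w)))
            h, w = skip_h, skip_w  # forward() calls TF.resize at this point
    return skips, bottleneck, resized


skips, bottleneck, resized = trace_unet_shapes(360, 480)
print("skips:", skips)
print("bottleneck:", bottleneck)
print("resized:", resized)
```

For 360x480 inputs, only the deepest decoder stage hits the `TF.resize` branch in `forward`: 45 rows pool down to 22, and upsampling 22 gives 44, not 45.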

In [4]:
class Block(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages; padding=1 keeps H and W unchanged."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias= False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias= False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
    def forward(self, x):
        return self.conv(x)

class UNET(pl.LightningModule):
    def __init__(self, in_channels: int=3, out_channels: int=1, features: List=[64,128,256,512], learning_rate=1.5e-3):
        super().__init__()
        self.learning_rate = learning_rate 
        self.down = nn.ModuleList()
        self.up = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)
        for feature in features:
            self.down.append(Block(in_channels, feature))
            in_channels=feature
        for feature in reversed(features):
            self.up.append(
                nn.ConvTranspose2d(feature*2, feature, 2, 2)
            )
            self.up.append(
                Block(feature*2, feature) # x gets concat to 2xchannel
            )
        self.bottleneck = Block(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, 1)
        self.loss_fn = nn.BCEWithLogitsLoss()
        
        self.val_num_correct = 0
        self.val_num_pixels = 0
        self.val_dice_score = 0
        self.num_correct = 0
        self.num_pixels = 0
        self.dice_score = 0
    def forward(self, x):
        skip_connections = []
        for down in self.down:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.up), 2):
            x = self.up[idx](x)
            skip_connection = skip_connections[idx//2]
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])
            concat_skip = torch.cat((skip_connection, x), dim=1) # Concat along channels (b, c, h, w)
            x = self.up[idx+1](concat_skip)
        return self.final_conv(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        pred = torch.sigmoid(pred)
        pred = (pred > 0.5).float()
        self.num_correct += (pred == y).sum()
        self.num_pixels += torch.numel(pred)
        self.dice_score += (2 * (pred * y).sum()) / (
            (pred + y).sum() + 1e-8
        )
        self.log('train_loss', loss, logger = True)
        return {'loss': loss}
    
    def training_epoch_end(self, output):
        train_acc = float(f'{(self.num_correct/self.num_pixels)*100:.2f}')
        self.log('train_acc', train_acc, prog_bar = True, logger = True)
        dice_score = self.dice_score/len(output)
        self.log('train_dice_score', dice_score, prog_bar = True, logger = True)
        self.num_correct, self.num_pixels, self.dice_score = 0,0,0
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        pred = torch.sigmoid(pred)
        pred = (pred > 0.5).float()
        self.val_num_correct += (pred == y).sum()
        self.val_num_pixels += torch.numel(pred)
        self.val_dice_score += (2 * (pred * y).sum()) / (
            (pred + y).sum() + 1e-8
        )
        self.log('val_loss', loss, prog_bar = True, logger = True)
        return {'loss': loss}
    
    def validation_epoch_end(self, output):
        val_acc = float(f'{(self.val_num_correct/self.val_num_pixels)*100:.2f}')
        self.log('val_acc', val_acc, prog_bar = True, logger = True)
        dice_score = self.val_dice_score/len(output)
        self.log('val_dice_score', dice_score, prog_bar = True, logger = True)
        self.val_num_correct, self.val_num_pixels, self.val_dice_score = 0,0,0
        
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(params = self.parameters(), lr = self.learning_rate, weight_decay=0.3)
        return optimizer
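`training_step` and `validation_step` track two metrics on the thresholded predictions: pixel accuracy and the Dice coefficient, 2·|P ∩ Y| / (|P| + |Y|). A minimal NumPy sketch of the same per-batch computation follows; `batch_metrics` is a hypothetical name, and NumPy stands in for torch so the arithmetic is easy to check by hand.

```python
import numpy as np


def batch_metrics(logits, target, threshold=0.5, eps=1e-8):
    """Pixel accuracy and Dice score for one batch, mirroring training_step."""
    probs = 1.0 / (1.0 + np.exp(-logits))          # same as torch.sigmoid
    pred = (probs > threshold).astype(np.float32)  # binarize at 0.5
    accuracy = (pred == target).sum() / pred.size  # fraction of matching pixels
    dice = 2 * (pred * target).sum() / ((pred + target).sum() + eps)
    return float(accuracy), float(dice)


# Toy 2x2 "image": positive logits become foreground after thresholding.
logits = np.array([[3.0, 1.0], [-2.0, -4.0]])
target = np.array([[1.0, 0.0], [0.0, 0.0]])
acc, dice = batch_metrics(logits, target)
print(acc, round(dice, 4))  # -> 0.75 0.6667
```

Two consequences of this formulation: because `1e-8` appears only in the denominator, a batch whose predictions and targets are both empty scores 0 rather than 1; and the epoch-level value logged in `training_epoch_end`/`validation_epoch_end` is the mean of per-batch Dice scores, not a Dice score computed over all pixels at once.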

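We can count the parameters in the model with the following command. Note that the summary lists submodules in the order they are registered in `__init__` (`down`, `up`, `pool`, `bottleneck`, ...), which is not the order in which a training example flows through the network: `bottleneck` runs between the encoder (`down`) and the decoder (`up`).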

In [5]:
pl.utilities.model_summary.summarize(UNET(),-1)

   | Name              | Type              | Params
---------------------------------------------------------
0  | down              | ModuleList        | 4.7 M 
1  | down.0            | Block             | 38.8 K
2  | down.0.conv       | Sequential        | 38.8 K
3  | down.0.conv.0     | Conv2d            | 1.7 K 
4  | down.0.conv.1     | BatchNorm2d       | 128   
5  | down.0.conv.2     | ReLU              | 0     
6  | down.0.conv.3     | Conv2d            | 36.9 K
7  | down.0.conv.4     | BatchNorm2d       | 128   
8  | down.0.conv.5     | ReLU              | 0     
9  | down.1            | Block             | 221 K 
10 | down.1.conv       | Sequential        | 221 K 
11 | down.1.conv.0     | Conv2d            | 73.7 K
12 | down.1.conv.1     | BatchNorm2d       | 256   
13 | down.1.conv.2     | ReLU              | 0     
14 | down.1.conv.3     | Conv2d            | 147 K 
15 | down.1.conv.4     | BatchNorm2d       | 256   
16 | down.1.conv.5     | ReLU              | 0     
17 | d
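The counts in the summary follow directly from the layer shapes: a `Conv2d` with `bias=False` holds `c_in * c_out * k * k` weights, a `ConvTranspose2d` with bias adds `c_out` more, and `BatchNorm2d` learns two values per channel (its running statistics are buffers, not parameters). The pure-Python sketch below, with helper names of our own, recomputes the counts from the `UNET` definition.

```python
# Sketch: recompute UNET parameter counts from the layer definitions.
# Assumes: 3x3 convs with bias=False, BatchNorm2d (weight + bias per channel),
# 2x2 ConvTranspose2d with bias, and a 1x1 final Conv2d with bias.

def conv_params(c_in, c_out, k, bias):
    # Conv2d and ConvTranspose2d both store c_in * c_out * k * k weights
    return c_in * c_out * k * k + (c_out if bias else 0)


def block_params(c_in, c_out):
    conv1 = conv_params(c_in, c_out, 3, bias=False)
    conv2 = conv_params(c_out, c_out, 3, bias=False)
    batchnorms = 2 * (2 * c_out)  # two BatchNorm2d layers, gamma and beta each
    return conv1 + conv2 + batchnorms


def unet_params(in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
    down, c = 0, in_channels
    for f in features:
        down += block_params(c, f)
        c = f
    up = sum(conv_params(2 * f, f, 2, bias=True) + block_params(2 * f, f)
             for f in features)
    bottleneck = block_params(features[-1], features[-1] * 2)
    final = conv_params(features[0], out_channels, 1, bias=True)
    return {"down": down, "up": up, "bottleneck": bottleneck,
            "final_conv": final, "total": down + up + bottleneck + final}


print(block_params(3, 64))    # 38848  -> "38.8 K" for down.0
print(block_params(64, 128))  # 221696 -> "221 K" for down.1
print(unet_params())
```

The first two values and the 4.7 M for `down` agree with the summary above; the other entries are derived the same way and can be checked against the full, untruncated summary.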

# Data
Now that we have our model, we define a torch `Dataset` that loads each image/mask pair and applies our transformations. It implements the three required methods (`__init__`, `__len__`, `__getitem__`). We then define a `LightningDataModule`, which splits the `Dataset` into train/val subsets and returns the appropriate `DataLoader` when the `Trainer` requests it during training.

In [6]:
class SegmentationDataset(torch.utils.data.Dataset):
    def __init__(self, image_path, mask_path, transforms):
        # Every .jpg in image_path is expected to have a same-named .png mask in mask_path
        self.images = glob.glob(os.path.join(image_path, '*.jpg'))
        self.image_path = image_path
        self.mask_path = mask_path
        self.transforms = transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = np.array(Image.open(self.images[idx]).convert('RGB'))
        mask_file = os.path.basename(self.images[idx]).replace('.jpg', '.png')
        mask = np.array(Image.open(os.path.join(self.mask_path, mask_file)))
        mask[mask == 255.0] = 1.0  # binarize: car pixels (255) -> 1
        # Albumentations applies the same spatial transforms to image and mask
        augmentations = self.transforms(image=img, mask=mask)
        image = augmentations["image"]
        mask = augmentations["mask"]
        # (H, W) -> (1, H, W) float, matching the model's single-channel output
        mask = torch.unsqueeze(mask, 0)
        mask = mask.type(torch.float32)
        return image, mask
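`BCEWithLogitsLoss` needs a float target with the same shape as the model output, which is `(batch, 1, H, W)` because `out_channels=1`. The mask handling in `__getitem__` produces exactly that. The NumPy sketch below mirrors it outside the dataset class; it skips the albumentations step, uses a hypothetical `prepare_mask` helper, and assumes the PNG masks use 0 for background and 255 for the car, which is what the `mask == 255.0` check implies.

```python
import numpy as np


def prepare_mask(mask_uint8):
    """Mirror the mask handling in SegmentationDataset.__getitem__ (NumPy only).

    Assumes background pixels are 0 and car pixels are 255.
    """
    mask = mask_uint8.copy()
    mask[mask == 255] = 1           # 255 -> 1, as in the dataset class
    mask = mask[np.newaxis, ...]    # (H, W) -> (1, H, W), like torch.unsqueeze(mask, 0)
    return mask.astype(np.float32)  # BCEWithLogitsLoss expects float targets


raw = np.array([[0, 255, 255], [0, 0, 255]], dtype=np.uint8)
target = prepare_mask(raw)
print(target.shape, target.dtype)  # (1, 2, 3) float32
print(target[0])
```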

In [7]:
class SegmentationDataModule(pl.LightningDataModule):
    
    def __init__(self, image_path, mask_path, transform, train_size=0.90, batch_size: int = 9):
        super().__init__()
        self.image_path = image_path
        self.mask_path = mask_path
        self.batch_size = batch_size
        self.transform = transform
        self.train_size = train_size
        
    def setup(self, stage = None):
        if stage in (None, 'fit'):
            ds = SegmentationDataset(self.image_path, self.mask_path, self.transform)
            train_size = math.floor(len(ds)*self.train_size)
            val_size = len(ds)-train_size
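            # Seeded generator so that repeated setup() calls reproduce the same split
            train_ds, val_ds = torch.utils.data.random_split(
                ds, [train_size, val_size], generator=torch.Generator().manual_seed(42))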
            self.train_dataset = train_ds
            self.val_dataset = val_ds
    
    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset, self.batch_size, num_workers=2, shuffle = True, persistent_workers=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dataset, self.batch_size, num_workers=2, persistent_workers=True)
    
    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dataset, self.batch_size)
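A reproducible split matters here because `setup()` runs again in the `ds.setup()` call before `save_images`; an unseeded `random_split` would draw a different split on that call, so the saved "validation" images could include images seen during training. `torch.utils.data.random_split` accepts a `generator` argument (for example `torch.Generator().manual_seed(42)`) for this purpose. The pure-Python sketch below shows the split-size arithmetic from `setup()` and the idea of a seeded index shuffle; `split_sizes` and `seeded_split` are hypothetical helpers, and the permutation they produce is not the one torch would produce.

```python
import math
import random


def split_sizes(n_items, train_frac=0.90):
    """Same arithmetic as SegmentationDataModule.setup()."""
    train_size = math.floor(n_items * train_frac)
    return train_size, n_items - train_size


def seeded_split(n_items, train_frac=0.90, seed=42):
    """Illustrative reproducible split: shuffle indices with a fixed seed.

    This mimics the idea behind random_split with a generator, not its exact permutation.
    """
    train_size, _ = split_sizes(n_items, train_frac)
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)
    return indices[:train_size], indices[train_size:]


print(split_sizes(101))  # (90, 11)
train_a, val_a = seeded_split(101)
train_b, val_b = seeded_split(101)
print(val_a == val_b)    # True: same seed, same validation indices
```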

# Training

Once our model, optimizer, loss, and LightningDataModule are defined, we can use the `Trainer` to run our training and validation loops. We can also use callbacks provided by Lightning to create checkpoints or stop training early (check out [Bolts](https://lightning-bolts.readthedocs.io/en/latest/) for more advanced callbacks, such as sparsification). The `Trainer` also handles multi-GPU and TPU training.

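We also use albumentations to apply the **same** random augmentations to the image **and** its mask. The classic torchvision transforms draw their random parameters independently on every call, so applying them separately to an image and its mask can leave the two misaligned; newer torchvision releases (`torchvision.transforms.v2`) can transform an image and its mask together.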
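The essential property is that each random decision is drawn once per call and reused for both the image and the mask. The NumPy sketch below illustrates that idea with the two flips from the training transform (probabilities 0.5 and 0.3); it is not albumentations' implementation, and `joint_random_flip` is a hypothetical name.

```python
import numpy as np


def joint_random_flip(image, mask, rng, p_h=0.5, p_v=0.3):
    """Apply the same random flips to an image (H, W, C) and its mask (H, W).

    Each random decision is drawn once and applied to both arrays.
    """
    if rng.random() < p_h:  # horizontal flip: reverse the width axis
        image, mask = image[:, ::-1], mask[:, ::-1]
    if rng.random() < p_v:  # vertical flip: reverse the height axis
        image, mask = image[::-1], mask[::-1]
    return image, mask


rng = np.random.default_rng(0)
image = np.arange(2 * 3 * 3).reshape(2, 3, 3)  # toy RGB image, H=2, W=3
mask = (image[..., 0] > 6).astype(np.uint8)    # toy mask derived from the image
for _ in range(5):
    img_aug, mask_aug = joint_random_flip(image, mask, rng)
    # The mask still lines up with the image after every random draw.
    assert np.array_equal(mask_aug, (img_aug[..., 0] > 6).astype(np.uint8))
print("image and mask stayed aligned")
```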

### Download the data from Kaggle before running the rest of the notebook

In [None]:
torch.cuda.empty_cache()
transform = A.Compose(
    [
        A.Resize(height=360, width=480),
        A.Rotate(limit=45, p=0.7),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.3),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        A.pytorch.ToTensorV2(),
    ]
)
ds = SegmentationDataModule('../input/carvana-image-masking-png/train_images', '../input/carvana-image-masking-png/train_masks', transform=transform)
model = UNET()
if not os.path.isdir('./unet_3'): os.mkdir('./unet_3')
checkpointCallback = pl.callbacks.ModelCheckpoint(dirpath="./unet_3",
                                                  save_top_k=1,
                                                  monitor="val_loss",
                                                  filename='{epoch}-{val_loss:.5f}',
                                                  mode='min')
# Note: auto_lr_find only takes effect if trainer.tune(model, datamodule=ds) is called;
# trainer.fit() alone trains with the learning_rate set in UNET.
if torch.cuda.is_available():
    trainer = pl.Trainer(max_epochs=8, accelerator='gpu', devices=1,
                         callbacks=[checkpointCallback], profiler='simple',
                         auto_lr_find=True)
else:
    trainer = pl.Trainer(max_epochs=7, accelerator='tpu', devices=8,
                         callbacks=[checkpointCallback], profiler='simple',
                         auto_lr_find=True)

trainer.fit(model, datamodule=ds)

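With the arguments used in the training cell, `A.Normalize` does no mean/std standardization at all: albumentations documents it as `(img - mean * max_pixel_value) / (std * max_pixel_value)`, so `mean=0`, `std=1`, `max_pixel_value=255` simply rescales pixels to [0, 1]. A NumPy sketch of that formula (the function name is ours):

```python
import numpy as np


def albumentations_style_normalize(img, mean, std, max_pixel_value=255.0):
    """(img - mean * max_pixel_value) / (std * max_pixel_value), per channel."""
    mean = np.asarray(mean, dtype=np.float32) * max_pixel_value
    std = np.asarray(std, dtype=np.float32) * max_pixel_value
    return (img.astype(np.float32) - mean) / std


# One RGB pixel with channel values 0, 128, 255.
pixel = np.array([[[0, 128, 255]]], dtype=np.uint8)
out = albumentations_style_normalize(pixel, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0])
print(out)  # equal to pixel / 255
```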

# Utility Functions

Here we define three utility functions:
  1. **save_images**(model, loader, folder, device)
      - Saves the thresholded predictions of a model, and the corresponding ground-truth masks, for every batch of a dataloader.
      - `device` can be set to `'cpu'` when no GPU is available.
  2. **get_concat_v**(im1, im2)
      - Stacks two images vertically (`im1` on top); used to combine the prediction and mask images written by `save_images`.
      - `im1` and `im2` are `PIL.Image` objects.
  3. **merge_photos**(src_folder, dst_folder, remove_single)
      - Reads each prediction/mask pair from `src_folder`, concatenates them with `get_concat_v()`, and writes the result to `dst_folder`.
      - Set `remove_single` to `False` to keep the unmerged images.

In [None]:
def save_images(model, loader, folder='val_img', device='cuda'):
    model.eval()
    os.makedirs(folder, exist_ok=True)
    model.to(device=device)
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        y = y.to(device=device)
        with torch.no_grad():
            # Outputs stay on `device`, so device='cpu' works without a GPU
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        # save_image tiles the whole batch into one grid image per file
        torchvision.utils.save_image(preds, f"{folder}/pred_{idx}.png")
        torchvision.utils.save_image(y, f"{folder}/mask_{idx}.png")

def get_concat_v(im1, im2):
    dst = Image.new('RGB', (im1.width, im1.height + im2.height))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (0, im1.height))
    return dst

def merge_photos(src_folder: str='./val_img', dst_folder: str='./merged_val_img', remove_single: bool=True):
    files = glob.glob(os.path.join(src_folder, '*.png'))
    os.makedirs(dst_folder, exist_ok=True)
    # save_images writes one pred_{i}.png and one mask_{i}.png per batch
    for i in range(len(files) // 2):
        pred_img = Image.open(f'{src_folder}/pred_{i}.png')
        mask_img = Image.open(f'{src_folder}/mask_{i}.png')
        get_concat_v(pred_img, mask_img).save(f'{dst_folder}/merged_pred_mask_{i}.png')
        if remove_single:
            os.remove(f'{src_folder}/pred_{i}.png')
            os.remove(f'{src_folder}/mask_{i}.png')
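`get_concat_v` pastes `im1` at the top of a canvas that is `im1.width` wide and `im1.height + im2.height` tall, so it assumes both images share a width. That holds here because the prediction and mask grids come from tensors of the same shape. The NumPy analogue below (a hypothetical helper) makes the assumption explicit; note that PIL reports sizes as (width, height) while arrays are indexed (height, width, channels), so vertical stacking is along axis 0.

```python
import numpy as np


def concat_v_array(top, bottom):
    """NumPy analogue of get_concat_v for (H, W, C) arrays of equal width."""
    if top.shape[1] != bottom.shape[1]:
        # On a canvas im1.width wide, a wider im2 would be cropped
        # and a narrower one would leave black padding.
        raise ValueError("widths differ")
    return np.concatenate([top, bottom], axis=0)  # axis 0 is height


pred = np.zeros((2, 4, 3), dtype=np.uint8)       # 2 rows tall, 4 pixels wide
mask = np.full((3, 4, 3), 255, dtype=np.uint8)   # 3 rows tall, 4 pixels wide
print(concat_v_array(pred, mask).shape)          # (5, 4, 3)
```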

In [None]:
ds.setup()
save_images(model, ds.val_dataloader())
merge_photos()
# TODO: Implement feature for sampling random img/label pairs for model prediction/eval
#         - Should show output and mask as images, pref side by side or across batches
#         - Conduct additional testing using the competition or data from the internet

In [None]:
!zip -r lightning_logs.zip ./lightning_logs
!zip -r merged_val_img.zip ./merged_val_img
!zip -r unet.zip ./unet_3

  adding: lightning_logs/ (stored 0%)
  adding: lightning_logs/version_0/ (stored 0%)
  adding: lightning_logs/version_0/hparams.yaml (stored 0%)
  adding: lightning_logs/version_0/events.out.tfevents.1655251958.ba12c49bf83b.23.0 (deflated 66%)
  adding: merged_val_img/ (stored 0%)
  adding: merged_val_img/merged_pred_mask_12.png (deflated 26%)
  adding: merged_val_img/merged_pred_mask_18.png (deflated 25%)
  adding: merged_val_img/merged_pred_mask_35.png (deflated 28%)
  adding: merged_val_img/merged_pred_mask_34.png (deflated 28%)
  adding: merged_val_img/merged_pred_mask_20.png (deflated 29%)
  adding: merged_val_img/merged_pred_mask_5.png (deflated 29%)
  adding: merged_val_img/merged_pred_mask_31.png (deflated 30%)
  adding: merged_val_img/merged_pred_mask_24.png (deflated 27%)
  adding: merged_val_img/merged_pred_mask_38.png (deflated 26%)
  adding: merged_val_img/merged_pred_mask_56.png (deflated 15%)
  adding: merged_val_img/merged_pred_mask_16.png (deflated 25%)

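Since TensorBoard doesn't work in Kaggle notebooks, download the logs and load them in Google Colab to analyze the training and validation metrics (for example, unzip `lightning_logs.zip`, then run `%load_ext tensorboard` followed by `%tensorboard --logdir lightning_logs`).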

<button><a href='./lightning_logs.zip'>Download Logs</a></button>
<button><a href='./merged_val_img.zip'>Download Images</a></button>
<button><a href='./unet.zip'>Download Model</a></button>