In [1]:
import torch

# When no CUDA device is visible, download and run the PyTorch/XLA setup script
# and install a pinned pytorch-lightning, so the notebook can run on the XLA/TPU
# path selected later by the same torch.cuda.is_available() check.
# NOTE(review): pytorch-lightning is pinned to 1.1.5 here, while the training cell
# passes accelerator='tpu' to pl.Trainer — confirm that this version accepts it.
if not torch.cuda.is_available():
    # Fetch the environment setup script from the pytorch/xla repository.
    !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
    # Run the setup script for version 1.7 with the extra apt packages it needs.
    !python pytorch-xla-env-setup.py --version 1.7 --apt-packages libomp5 libopenblas-dev
    !pip install pytorch-lightning==1.1.5

# Model Architecture

The following image depicts the U-Net architecture that will be used for segmentation. We first define a plain `torch.nn.Module` as the building block of the model, then use a PyTorch Lightning module to define the final model.

![](https://raw.githubusercontent.com/aladdinpersson/Machine-Learning-Collection/master/ML/Pytorch/image_segmentation/semantic_segmentation_unet/UNET_architecture.png)

In [2]:
# TODO: Get test set and create CSV submission for competition
import math
import torch.nn as nn
import torchvision.transforms.functional as TF
import torchvision.utils
from transformers import get_cosine_schedule_with_warmup
import pytorch_lightning as pl
import torchmetrics as tm
from typing import List

class Block(nn.Module):
    """Two stacked 3x3 convolution stages, each followed by BatchNorm and ReLU.

    The convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes from ``in_channels``
    to ``out_channels``. The convolutions carry no bias term.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first stage maps in_channels -> out_channels, the second keeps
        # out_channels. Layers are appended in the order conv, norm, activation.
        for source_channels in (in_channels, out_channels):
            stages.append(
                nn.Conv2d(source_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU())
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolution stages and return the result."""
        features = self.conv(x)
        return features

class UNET(pl.LightningModule):
    """U-Net segmentation model packaged as a LightningModule.

    The encoder is a stack of ``Block`` modules, each followed by 2x2 max
    pooling; the decoder alternates transposed convolutions (upsampling) with
    ``Block`` modules applied to the concatenation of the upsampled tensor and
    the matching encoder output (skip connection). ``forward`` returns raw
    logits: the loss is ``BCEWithLogitsLoss`` and the validation code applies
    ``sigmoid`` itself.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of output channels produced by the final 1x1 conv.
        features: Channel widths of the encoder stages, shallowest first.
    """

    def __init__(self, in_channels: int=3, out_channels: int=1, features: List=[64,128,256,512]):
        super().__init__()
        # Work on a private copy so neither the shared default list nor a
        # caller-owned list is ever aliased by the model.
        features = list(features)
        self.down = nn.ModuleList()
        self.up = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)
        for feature in features:
            self.down.append(Block(in_channels, feature))
            in_channels = feature
        for feature in reversed(features):
            # Even indices of self.up: upsampling that halves the channel count.
            self.up.append(
                nn.ConvTranspose2d(feature*2, feature, 2, 2)
            )
            # Odd indices of self.up: Block applied after the skip connection is
            # concatenated, which doubles the channel count again.
            self.up.append(
                Block(feature*2, feature)
            )
        self.bottleneck = Block(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, 1)
        self.loss_fn = nn.BCEWithLogitsLoss()
        self._reset_val_stats()

    def _reset_val_stats(self):
        """Zero the running validation counters accumulated by validation_step."""
        self.num_correct = 0
        self.num_pixels = 0
        self.dice_score = 0

    def forward(self, x):
        """Return per-pixel logits for the image batch ``x`` (b, c, h, w)."""
        skip_connections = []
        for down in self.down:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        # The decoder consumes the encoder outputs deepest-first.
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.up), 2):
            x = self.up[idx](x)
            skip_connection = skip_connections[idx//2]
            # Pooling floors odd sizes, so the upsampled tensor can come out
            # smaller than the skip connection; resize it to match before concat.
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])
            concat_skip = torch.cat((skip_connection, x), dim=1) # Concat along channels (b, c, h, w)
            x = self.up[idx+1](concat_skip)
        return self.final_conv(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the BCE-with-logits loss for one training batch."""
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        self.log('train_loss', loss, logger = True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accumulate accuracy / dice statistics.

        Predictions are binarised at a sigmoid probability of 0.5 before being
        compared with the target mask.
        """
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        pred = torch.sigmoid(pred)
        pred = (pred > 0.5).float()
        self.num_correct += (pred == y).sum()
        self.num_pixels += torch.numel(pred)
        # The small epsilon keeps the ratio finite when both masks are empty.
        self.dice_score += (2 * (pred * y).sum()) / (
            (pred + y).sum() + 1e-8
        )
        self.log('val_loss', loss, prog_bar = True, logger = True)
        # NOTE(review): the 'len' entry is not read anywhere in this class;
        # validation_epoch_end only uses the number of step outputs.
        return {'loss': loss, 'len': len(self.trainer.val_dataloaders[0])}

    def validation_epoch_end(self, output):
        """Log pixel accuracy (percent) and mean dice score, then reset counters.

        Args:
            output: List of the dicts returned by ``validation_step``; only its
                length is used, to average the per-batch dice score.
        """
        # Nothing was validated (e.g. an empty validation set): dividing below
        # would raise ZeroDivisionError, so just clear the counters and return.
        if self.num_pixels == 0 or len(output) == 0:
            self._reset_val_stats()
            return
        val_acc = float(f'{self.num_correct/self.num_pixels*100:.2f}')
        self.log('val_acc', val_acc, prog_bar = True, logger = True)
        dice_score = self.dice_score/len(output)
        self.log('dice_score', dice_score, prog_bar = True, logger = True)
        self._reset_val_stats()

    def configure_optimizers(self):
        """Return the AdamW optimizer used for training (no LR scheduler)."""
        optimizer = torch.optim.AdamW(params = self.parameters(), lr = 1.5e-3, weight_decay = 0.3)
        return optimizer
        

In [3]:
# Display a per-layer parameter summary of a freshly constructed UNET.
# NOTE(review): the -1 is presumably the summary depth (expand all nested
# submodules, as the output below suggests) — confirm against the installed
# pytorch_lightning version.
pl.utilities.model_summary.summarize(UNET(),-1)

   | Name              | Type              | Params
---------------------------------------------------------
0  | down              | ModuleList        | 4.7 M 
1  | down.0            | Block             | 38.8 K
2  | down.0.conv       | Sequential        | 38.8 K
3  | down.0.conv.0     | Conv2d            | 1.7 K 
4  | down.0.conv.1     | BatchNorm2d       | 128   
5  | down.0.conv.2     | ReLU              | 0     
6  | down.0.conv.3     | Conv2d            | 36.9 K
7  | down.0.conv.4     | BatchNorm2d       | 128   
8  | down.0.conv.5     | ReLU              | 0     
9  | down.1            | Block             | 221 K 
10 | down.1.conv       | Sequential        | 221 K 
11 | down.1.conv.0     | Conv2d            | 73.7 K
12 | down.1.conv.1     | BatchNorm2d       | 256   
13 | down.1.conv.2     | ReLU              | 0     
14 | down.1.conv.3     | Conv2d            | 147 K 
15 | down.1.conv.4     | BatchNorm2d       | 256   
16 | down.1.conv.5     | ReLU              | 0     
17 | d

In [4]:
import numpy as np
import glob
import os
from torch.utils.data import DataLoader, Dataset
from PIL import Image

class SegmentationDataset(Dataset):
    """Image / mask pairs for binary segmentation.

    Images are every ``*.jpg`` file directly inside ``image_path``. The mask of
    an image is the file in ``mask_path`` with the same base name and a
    ``.png`` extension.

    Args:
        image_path: Directory containing the ``.jpg`` input images.
        mask_path: Directory containing the ``.png`` masks.
        transforms: Callable invoked as ``transforms(image=..., mask=...)`` that
            returns a mapping with ``"image"`` and ``"mask"`` entries.
            The ``"mask"`` entry must be a torch tensor, since it is passed to
            ``torch.unsqueeze``.
    """

    def __init__(self, image_path, mask_path, transforms):
        # glob returns files in arbitrary, platform-dependent order; sorting
        # makes the index -> file mapping stable between runs.
        self.images = sorted(glob.glob(os.path.join(image_path, '*.jpg')))
        self.image_path = image_path
        self.mask_path = mask_path
        self.transforms = transforms

    def __len__(self):
        """Number of images found in ``image_path``."""
        return len(self.images)

    def _mask_file_for(self, image_file):
        """Return the mask path for ``image_file`` (same stem, ``.png`` suffix)."""
        # splitext swaps only the extension; str.replace would also rewrite a
        # '.jpg' that happens to appear in the middle of the file name.
        stem, _ = os.path.splitext(os.path.basename(image_file))
        return os.path.join(self.mask_path, stem + '.png')

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for index ``idx``.

        The mask is returned as a float32 tensor with a leading channel
        dimension added.
        """
        image_file = self.images[idx]
        # Context managers close the underlying files as soon as the pixel data
        # has been copied into numpy arrays.
        with Image.open(image_file) as image_handle:
            img = np.array(image_handle.convert('RGB'))
        with Image.open(self._mask_file_for(image_file)) as mask_handle:
            mask = np.array(mask_handle)
        # Map foreground pixels stored as 255 to the binary label 1.
        mask[mask == 255.0] = 1.0
        augmentations = self.transforms(image=img, mask=mask)
        image = augmentations["image"]
        mask = augmentations["mask"]
        mask = torch.unsqueeze(mask, 0)
        mask = mask.type(torch.float32)
        return image, mask

In [5]:
class SegmentationDataModule(pl.LightningDataModule):
    """LightningDataModule that splits a SegmentationDataset into train / val.

    Args:
        image_path: Directory with the input images.
        mask_path: Directory with the masks.
        transform: Transform passed through to ``SegmentationDataset``.
        train_size: Fraction of the samples assigned to the training split.
        batch_size: Batch size used by every dataloader.
        seed: Seed of the generator that draws the random train / val split,
            so the same split is produced each time the module is set up.
    """

    def __init__(self, image_path, mask_path, transform, train_size=0.90, batch_size: int = 7, seed: int = 42):
        super().__init__()
        self.image_path = image_path
        self.mask_path = mask_path
        self.batch_size = batch_size
        self.transform = transform
        self.train_size = train_size
        self.seed = seed
        # Filled in by setup(); None means "not split yet".
        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage = None):
        """Create the train / val split once; later calls are no-ops.

        Re-drawing the split on every call would let samples that were used for
        training end up in the validation set of a later evaluation run.
        """
        if stage not in (None, 'fit'):
            return
        if self.train_dataset is not None and self.val_dataset is not None:
            return
        ds = SegmentationDataset(self.image_path, self.mask_path, self.transform)
        train_size = math.floor(len(ds)*self.train_size)
        val_size = len(ds)-train_size
        # A dedicated, seeded generator keeps the split independent of the
        # global RNG state.
        split_generator = torch.Generator().manual_seed(self.seed)
        train_ds, val_ds = torch.utils.data.random_split(
            ds, [train_size, val_size], generator=split_generator
        )
        self.train_dataset = train_ds
        self.val_dataset = val_ds

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(self.train_dataset, self.batch_size, num_workers=2, shuffle = True, persistent_workers=True)

    def val_dataloader(self):
        """Loader over the validation split (no shuffling)."""
        return DataLoader(self.val_dataset, self.batch_size, num_workers=2, persistent_workers=True)

    def test_dataloader(self):
        """Loader used for testing.

        NOTE(review): this iterates the validation split; no separate test set
        is created by this module.
        """
        return DataLoader(self.val_dataset, self.batch_size)

In [6]:
def tensorToLImage(tensor: torch.Tensor):
    """Convert a tensor of mask logits into an 8-bit greyscale PIL image.

    The logits are passed through a sigmoid and thresholded at 0.5; pixels
    above the threshold become 255, all others 0.

    Args:
        tensor: Logits whose last two dimensions are the image height and
            width. Any leading dimensions (batch, channel) must have size 1.

    Returns:
        A PIL image in mode ``'L'``.

    Raises:
        AssertionError: If a leading dimension has a size other than 1.
    """
    # detach/cpu so tensors that track gradients or live on an accelerator can
    # be converted to numpy.
    probabilities = torch.sigmoid(tensor.detach().cpu())
    array = ((probabilities > 0.5).float() * 255).numpy().astype(np.uint8)
    # An 'L' image is strictly 2-D: peel off every leading singleton dimension
    # (batch and/or channel) until only (height, width) is left.
    while np.ndim(array) > 2:
        assert array.shape[0] == 1
        array = array[0]
    return Image.fromarray(array, mode='L')

def tensorToRGBImage(tensor: torch.Tensor):
    """Convert an image tensor with values scaled by 1/255 into an RGB PIL image.

    Args:
        tensor: Image tensor, either channel-last ``(h, w, 3)`` or
            channel-first ``(3, h, w)``; leading dimensions of size 1 (such as
            a batch dimension) are dropped.

    Returns:
        A PIL image in mode ``'RGB'``.

    Raises:
        AssertionError: If a leading dimension has a size other than 1.
    """
    # detach/cpu so tensors that track gradients or live on an accelerator can
    # be converted to numpy.
    array = (tensor.detach().cpu() * 255).numpy().astype(np.uint8)
    while np.ndim(array) > 3:
        assert array.shape[0] == 1
        array = array[0]
    # Image.fromarray expects (height, width, channels). A channel-first array
    # would otherwise be reinterpreted silently as scrambled pixels.
    if np.ndim(array) == 3 and array.shape[0] == 3 and array.shape[-1] != 3:
        array = np.ascontiguousarray(array.transpose(1, 2, 0))
    return Image.fromarray(array, mode='RGB')

# Training

Once the model, optimizer, and Lightning data module are defined (the data module provides the train and validation dataloaders), we can use the Lightning `Trainer` to run the training and validation loops for us. This not only saves us from writing boilerplate code, but also helps keep the code organized. We use albumentations to augment our data, which lets us apply the <strong>same</strong> transformations to both the image and its mask.

A Lightning callback is also used: in this case, `ModelCheckpoint` saves the top-performing model states (the best three), ranked by validation loss.

In [7]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

torch.cuda.empty_cache()

# Augmentation pipeline; albumentations applies the same geometric transform to
# the image and its mask.
transform = A.Compose(
    [
        A.Resize(height=360, width=480),
        A.Rotate(limit=45, p=0.7),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.3),
        # Zero mean / unit std with max_pixel_value=255 rescales pixel values
        # without shifting them.
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]
)
ds = SegmentationDataModule('../input/carvana-image-masking-png/train_images', '../input/carvana-image-masking-png/train_masks', transform=transform)
model = UNET()

# exist_ok avoids the check-then-create race and makes re-running the cell safe.
os.makedirs('./unet_3', exist_ok=True)
# Keep the three checkpoints with the lowest validation loss.
checkpointCallback = pl.callbacks.ModelCheckpoint(dirpath="./unet_3",
                                                  save_top_k=3,
                                                  monitor="val_loss",
                                                  filename='{epoch}-{val_loss:.5f}')

# Arguments shared by both hardware configurations.
trainer_kwargs = dict(callbacks=[checkpointCallback], profiler='simple')
# NOTE(review): the two branches train for a different number of epochs (8 vs 7),
# and the non-CUDA branch assumes a TPU is attached — confirm both are intended.
if torch.cuda.is_available():
    trainer = pl.Trainer(max_epochs=8, accelerator='gpu', gpus=1, **trainer_kwargs)
else:
    trainer = pl.Trainer(max_epochs=7, accelerator='tpu', tpu_cores=8, **trainer_kwargs)
trainer.fit(model, ds)

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

# Utility Functions

Here we define three utility functions:
  1. **save_images**(model, loader, folder, device)
      - This is used to save predictions and their corresponding masks from a specified model across a specified dataloader.
      - Device can be set to cpu if not running on gpu
  2. **get_concat_v**(im1, im2)
      - This is used to later concatenate the prediction and mask images we create using save_images.
      - im1 & im2 are PIL.Image objects
  3. **merge_photos**(src_folder, dst_folder, remove_single)
      - This is used to read and concatenate the prediction and mask images using the previously defined get_concat_v().
      - remove_single can be set to False to save the unmerged images

In [8]:
def save_images(model, loader, folder='val_img', device='cuda'):
    """Save binarised predictions and their target masks as PNG files.

    For batch number ``idx`` of ``loader`` two files are written:
    ``{folder}/pred_{idx}.png`` (sigmoid of the model output thresholded at
    0.5) and ``{folder}/mask_{idx}.png`` (the target batch).

    Args:
        model: Module mapping an image batch to logits. It is switched to eval
            mode and moved to ``device``.
        loader: Iterable yielding ``(images, masks)`` batches.
        folder: Output directory; created if it does not exist.
        device: Device on which inference runs, e.g. ``'cuda'`` or ``'cpu'``.
    """
    model.eval()
    os.makedirs(folder, exist_ok=True)
    model.to(device=device)
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        y = y.to(device=device)
        with torch.no_grad():
            # The output already lives on `device`; forcing it onto CUDA here
            # would break CPU-only runs.
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(
            preds, f"{folder}/pred_{idx}.png"
        )
        torchvision.utils.save_image(y, f"{folder}/mask_{idx}.png")

def get_concat_v(im1, im2):
    """Return a new RGB image with ``im1`` pasted above ``im2``.

    The canvas takes its width from ``im1`` and its height from the sum of
    both image heights.
    """
    top_height = im1.height
    canvas_size = (im1.width, top_height + im2.height)
    stacked = Image.new('RGB', canvas_size)
    # Paste top image at the origin, bottom image directly underneath it.
    for picture, offset in ((im1, (0, 0)), (im2, (0, top_height))):
        stacked.paste(picture, offset)
    return stacked

def merge_photos(src_folder: str='./val_img', dst_folder: str='./merged_val_img', remove_single: bool=True):
    """Stack each prediction image on top of its mask and save the result.

    For every index ``i`` the files ``{src_folder}/pred_{i}.png`` and
    ``{src_folder}/mask_{i}.png`` are combined with ``get_concat_v`` and
    written to ``{dst_folder}/merged_pred_mask_{i}.png``.

    Args:
        src_folder: Directory holding the ``pred_{i}.png`` / ``mask_{i}.png`` pairs.
        dst_folder: Directory for the merged images; created if missing.
        remove_single: When True, delete each source pair after merging it.
    """
    # Count only prediction files so unrelated PNGs in the folder do not
    # inflate the number of pairs. Indices are expected to run from 0 upwards.
    pair_count = len(glob.glob(os.path.join(src_folder, 'pred_*.png')))
    os.makedirs(dst_folder, exist_ok=True)
    for i in range(pair_count):
        pred_file = f'{src_folder}/pred_{i}.png'
        mask_file = f'{src_folder}/mask_{i}.png'
        # Close both source images before they are (optionally) deleted.
        with Image.open(pred_file) as pred_img, Image.open(mask_file) as mask_img:
            get_concat_v(pred_img, mask_img).save(f'{dst_folder}/merged_pred_mask_{i}.png')
        if remove_single:
            # Delete from the folder that was actually read, not a fixed path.
            os.remove(pred_file)
            os.remove(mask_file)

In [9]:
# Reuse the validation split that trainer.fit() prepared. Running setup() again
# may draw a different random split, which would mix training images into the
# "validation" predictions saved below.
if getattr(ds, 'val_dataset', None) is None:
    ds.setup()
# save_images defaults to 'cuda'; pick the device explicitly so the cell does
# not request CUDA on a machine without it.
eval_device = 'cuda' if torch.cuda.is_available() else 'cpu'
save_images(model, ds.val_dataloader(), device=eval_device)
merge_photos()
# TODO: Implement feature for sampling random img/label pairs for model prediction/eval
#         - Should show output and mask as images, pref side by side or across batches
#         - Conduct additional testing using the competition or data from the internet

In [10]:
# Archive the training logs, the merged prediction/mask images, and the saved
# checkpoints so each can be downloaded as a single file via the links below.
!zip -r lightning_logs.zip ./lightning_logs
!zip -r merged_val_img.zip ./merged_val_img
!zip -r unet.zip ./unet_3

  adding: lightning_logs/ (stored 0%)
  adding: lightning_logs/version_0/ (stored 0%)
  adding: lightning_logs/version_0/hparams.yaml (stored 0%)
  adding: lightning_logs/version_0/events.out.tfevents.1654404811.03fdac3ac7f0.23.0 (deflated 67%)
  adding: merged_val_img/ (stored 0%)
  adding: merged_val_img/merged_pred_mask_2.png (deflated 10%)
  adding: merged_val_img/merged_pred_mask_21.png (deflated 17%)
  adding: merged_val_img/merged_pred_mask_13.png (deflated 9%)
  adding: merged_val_img/merged_pred_mask_5.png (deflated 12%)
  adding: merged_val_img/merged_pred_mask_0.png (deflated 13%)
  adding: merged_val_img/merged_pred_mask_37.png (deflated 11%)
  adding: merged_val_img/merged_pred_mask_60.png (deflated 10%)
  adding: merged_val_img/merged_pred_mask_63.png (deflated 17%)
  adding: merged_val_img/merged_pred_mask_12.png (deflated 13%)
  adding: merged_val_img/merged_pred_mask_8.png (deflated 9%)
  adding: merged_val_img/merged_pred_mask_67.png (deflated 10%)
  a

<button><a href='./lightning_logs.zip'>Download Logs</a></button>
<button><a href='./merged_val_img.zip'>Download Images</a></button>
<button><a href='./unet.zip'>Download Model</a></button>