In [29]:
!pip install albumentations wandb

Collecting wandb
  Downloading wandb-0.17.3-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting docker-pycreds>=0.4.0 (from wandb)
  Using cached docker_pycreds-0.4.0-py2.py3-none-any.whl.metadata (1.8 kB)
Collecting gitpython!=3.1.29,>=1.0.0 (from wandb)
  Using cached GitPython-3.1.43-py3-none-any.whl.metadata (13 kB)
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-2.7.1-py2.py3-none-any.whl.metadata (14 kB)
Collecting setproctitle (from wandb)
  Using cached setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.9 kB)
Collecting gitdb<5,>=4.0.1 (from gitpython!=3.1.29,>=1.0.0->wandb)
  Using cached gitdb-4.0.11-py3-none-any.whl.metadata (1.2 kB)
Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb)
  Using cached smmap-5.0.1-py3-none-any.whl.metadata (4.3 kB)
Downloading wandb-0.17.3-py3-none

In [55]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
import torchvision.transforms.functional as TF
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from PIL import Image
import numpy as np
import os
import glob
from sklearn.metrics import f1_score, precision_score, recall_score
from skimage.segmentation import find_boundaries
import matplotlib.pyplot as plt


import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import wandb

import warnings
warnings.filterwarnings("ignore")

In [41]:
class CarvanaDataset(Dataset):
    """Image/mask pairs read from two parallel directories.

    Every entry found in ``image_dir`` is treated as an image. Its mask is
    looked up in ``mask_dir`` under the same file name with ``.jpg`` replaced
    by ``_mask.gif``.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        """Index the image directory.

        Args:
            image_dir: Directory containing the input images.
            mask_dir: Directory containing the matching masks.
            transform: Optional callable invoked as
                ``transform(image=..., mask=...)`` that returns a mapping with
                ``"image"`` and ``"mask"`` entries.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting makes
        # index -> file stable across runs and machines.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        """Return the number of entries found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """Load the ``index``-th image and its mask.

        Returns:
            ``(image, mask)``. Without a transform, ``image`` is the RGB image
            as a NumPy array and ``mask`` is a float32 NumPy array of the
            grayscale mask in which the value 255 has been replaced by 1.0.
            With a transform, both are whatever the transform returned under
            the ``"image"`` and ``"mask"`` keys.
        """
        file_name = self.images[index]
        img_path = os.path.join(self.image_dir, file_name)
        mask_path = os.path.join(self.mask_dir, file_name.replace(".jpg", "_mask.gif"))

        # Context managers close the underlying files as soon as the pixel
        # data has been copied into NumPy arrays.
        with Image.open(img_path) as img_file:
            image = np.array(img_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.float32)
        # Map the 255 foreground value to 1.0; other values are left as they are.
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask


In [42]:
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution stages, each followed by batch norm and ReLU.

    Both convolutions use stride 1, padding 1 and no bias term. The first
    stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend(
                [
                    nn.Conv2d(
                        stage_in,
                        out_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=False,
                    ),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                ]
            )
        # Stored under ``conv`` so parameter names in the state dict stay the same.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to ``x`` and return the result."""
        out = self.conv(x)
        return out

class UNET(nn.Module):
    """Encoder/decoder network with skip connections built from ``DoubleConv`` blocks.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final 1x1 convolution.
        features: Channel widths of the encoder stages, from first to last.
            The decoder uses the same widths in reverse order. Any indexable,
            reversible sequence (list or tuple) is accepted.
    """

    def __init__(
            # A tuple default avoids sharing one mutable list between all calls.
            self, in_channels=3, out_channels=1, features=(64, 128, 256, 512),
    ):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET: one DoubleConv per entry in ``features``.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET: ``self.ups`` alternates a transposed convolution
        # (even index) with a DoubleConv (odd index). The DoubleConv takes
        # ``feature*2`` channels because its input is the upsampled tensor
        # concatenated with the matching skip connection.
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` and return the output of the final convolution.

        No activation is applied after the final convolution.
        """
        skip_connections = []

        # Encoder: remember each stage's output before pooling.
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        # Deepest skip connection first, matching the decoder order.
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            # The upsampled tensor can differ in shape from the skip tensor
            # (e.g. after pooling an odd-sized input); resize it to the skip
            # tensor's spatial size so the two can be concatenated.
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)

In [43]:
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Print a notice and serialize ``state`` to ``filename`` with ``torch.save``."""
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, model):
    """Print a notice and load ``checkpoint["state_dict"]`` into ``model``."""
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    """Build the training and validation data loaders.

    Args:
        train_dir: Directory of training images.
        train_maskdir: Directory of training masks.
        val_dir: Directory of validation images.
        val_maskdir: Directory of validation masks.
        batch_size: Batch size used by both loaders.
        train_transform: Transform handed to the training dataset.
        val_transform: Transform handed to the validation dataset.
        num_workers: Forwarded to both ``DataLoader`` instances.
        pin_memory: Forwarded to both ``DataLoader`` instances.

    Returns:
        ``(train_loader, val_loader)``. Only the training loader shuffles.
    """
    # Settings that are identical for the two loaders.
    shared_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    train_ds = CarvanaDataset(
        image_dir=train_dir, mask_dir=train_maskdir, transform=train_transform
    )
    train_loader = DataLoader(train_ds, shuffle=True, **shared_kwargs)

    val_ds = CarvanaDataset(
        image_dir=val_dir, mask_dir=val_maskdir, transform=val_transform
    )
    val_loader = DataLoader(val_ds, shuffle=False, **shared_kwargs)

    return train_loader, val_loader

def check_accuracy(loader, model, device="cuda"):
    """Evaluate pixel accuracy and Dice score of ``model`` over ``loader``.

    The model output is passed through a sigmoid and thresholded at 0.5.
    Accuracy is accumulated over all pixels of all batches; the Dice score is
    computed per batch and averaged over ``len(loader)`` batches.

    The model is switched to eval mode for the evaluation and put back into
    train mode before returning.

    Args:
        loader: Sized iterable yielding ``(x, y)`` batches. ``y`` receives an
            extra dimension at position 1 before being compared with the
            predictions.
        model: Module producing logits for ``x``.
        device: Device the batches are moved to.

    Returns:
        ``(acc, dice_score)`` where ``acc`` is the pixel accuracy in percent
        and ``dice_score`` is the mean per-batch Dice score, i.e. the same
        value that is printed.

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    num_correct = 0
    num_pixels = 0
    dice_total = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            # The small epsilon keeps the ratio defined when both the
            # prediction and the target are empty.
            dice_total += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )

    acc = num_correct/num_pixels*100
    # Average over batches once, so the printed and the returned value agree
    # (previously the un-averaged sum was returned).
    dice_score = dice_total/len(loader)
    print(
        f"Got {num_correct}/{num_pixels} with acc {acc:.2f}"
    )
    print(f"Dice score: {dice_score}")
    model.train()
    return acc, dice_score

def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="cuda"
):
    """Write thresholded predictions and their targets as PNG files.

    For batch number ``idx`` the prediction is written to
    ``<folder>/pred_<idx>.png`` and the target to ``<folder>/<idx>.png``.
    ``folder`` is created if it does not exist and may be given with or
    without a trailing slash.

    The model is switched to eval mode while predicting and put back into
    train mode before returning.

    Args:
        loader: Iterable yielding ``(x, y)`` batches.
        model: Module producing logits for ``x``.
        folder: Output directory.
        device: Device the inputs are moved to.
    """
    # Create the output directory up front; an empty ``folder`` means the
    # current directory and needs no creation.
    if folder:
        os.makedirs(folder, exist_ok=True)

    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        # os.path.join builds both paths the same way, so predictions and
        # targets always land in the same directory.
        torchvision.utils.save_image(
            preds, os.path.join(folder, f"pred_{idx}.png")
        )
        torchvision.utils.save_image(
            y.unsqueeze(1), os.path.join(folder, f"{idx}.png")
        )

    model.train()

In [44]:
# --- Optimisation ---
LEARNING_RATE = 1e-4
BATCH_SIZE = 16
NUM_EPOCHS = 30

# --- Runtime ---
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_WORKERS = 2
PIN_MEMORY = True
# When True, a previously saved checkpoint is loaded before training.
LOAD_MODEL = False

# --- Input size (originals: 1280 high, 1918 wide) ---
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 240

# --- Data locations ---
TRAIN_IMG_DIR = "data/train_images/"
TRAIN_MASK_DIR = "data/train_masks/"
VAL_IMG_DIR = "data/val_images/"
VAL_MASK_DIR = "data/val_masks/"

In [45]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one pass over ``loader``, updating ``model`` after every batch.

    Batches are moved to the module-level ``DEVICE``. The forward pass runs
    under ``torch.cuda.amp.autocast`` and the backward pass and optimizer
    step go through ``scaler``. The loss of the latest batch is shown in the
    progress bar.

    Args:
        loader: Iterable yielding ``(data, targets)`` batches.
        model: Module being trained.
        optimizer: Optimizer stepping the model parameters.
        loss_fn: Callable taking ``(predictions, targets)``.
        scaler: Gradient scaler providing ``scale``, ``step`` and ``update``.
    """
    progress = tqdm(loader)

    for inputs, labels in progress:
        inputs = inputs.to(device=DEVICE)
        # Targets get an extra dimension at position 1 to line up with the
        # model output.
        labels = labels.float().unsqueeze(1).to(device=DEVICE)

        # Forward pass under autocast.
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)

        # Backward pass and parameter update via the gradient scaler.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        progress.set_postfix(loss=loss.item())


In [46]:
def _resize_to_model_input():
    """Return a fresh resize step targeting IMAGE_HEIGHT x IMAGE_WIDTH."""
    return A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH)


def _normalize_and_tensorize():
    """Return fresh normalisation and tensor-conversion steps.

    Normalisation uses zero mean and unit std with ``max_pixel_value=255.0``.
    NOTE(review): this presumably just rescales pixel values by 1/255; confirm
    against the albumentations ``Normalize`` documentation.
    """
    return [
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]


# Training pipeline: resize, random augmentation, normalise, convert to tensor.
train_transform = A.Compose(
    [
        _resize_to_model_input(),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        *_normalize_and_tensorize(),
    ],
)

# Validation pipeline: the same steps without the random augmentations.
val_transforms = A.Compose(
    [
        _resize_to_model_input(),
        *_normalize_and_tensorize(),
    ],
)


In [47]:
# Network with 3 input channels and 1 output channel, placed on DEVICE.
model = UNET(in_channels=3, out_channels=1)
model = model.to(DEVICE)
# The loss takes raw logits; the network itself applies no final activation.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# Start a Weights & Biases run under the given project name.
wandb.init(project="Image segmentation of cars")
# Register the model and loss with the run.
# NOTE(review): log_freq is set to BATCH_SIZE here; if log_freq counts batches
# (not samples), this ties the logging interval to the batch size. Confirm intent.
wandb.watch(model, criterion=loss_fn, log="all", log_freq=BATCH_SIZE, log_graph=True)

In [50]:
# Build both loaders from the configured directories and transforms.
train_loader, val_loader = get_loaders(
    train_dir=TRAIN_IMG_DIR,
    train_maskdir=TRAIN_MASK_DIR,
    val_dir=VAL_IMG_DIR,
    val_maskdir=VAL_MASK_DIR,
    batch_size=BATCH_SIZE,
    train_transform=train_transform,
    val_transform=val_transforms,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

In [51]:
if LOAD_MODEL:
    # map_location lets a checkpoint saved on a GPU be restored on a
    # CPU-only machine (and vice versa) instead of failing at load time.
    load_checkpoint(
        torch.load("my_checkpoint.pth.tar", map_location=DEVICE), model
    )


# Baseline metrics before any training in this session.
check_accuracy(val_loader, model, device=DEVICE)
# Gradient scaler used by train_fn for the mixed-precision backward pass.
scaler = torch.cuda.amp.GradScaler()

INFO:numexpr.utils:Note: detected 96 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
INFO:numexpr.utils:Note: NumExpr detected 96 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.


Got 10184306/12211200 with acc 83.40
Dice score: 0.002155422465875745


In [56]:
for epoch in range(1, NUM_EPOCHS + 1):
    # One pass over the training data.
    train_fn(train_loader, model, optimizer, loss_fn, scaler)

    # Persist model and optimizer state after every epoch (same file each time).
    save_checkpoint(
        dict(
            state_dict=model.state_dict(),
            optimizer=optimizer.state_dict(),
        )
    )

    # Evaluate on the validation set and log the returned metrics.
    acc, dice = check_accuracy(val_loader, model, device=DEVICE)
    wandb.log({"epoch": epoch, "Accuracy": acc, "DICE": dice})

    # Write example predictions for the validation set to disk.
    save_predictions_as_imgs(
        val_loader, model, folder="saved_images/", device=DEVICE
    )

100%|██████████| 299/299 [01:09<00:00,  4.33it/s, loss=0.0163]


=> Saving checkpoint
Got 12110036/12211200 with acc 99.17
Dice score: 0.9754392504692078


100%|██████████| 299/299 [01:04<00:00,  4.61it/s, loss=0.057] 


=> Saving checkpoint
Got 12124259/12211200 with acc 99.29
Dice score: 0.9787141680717468


100%|██████████| 299/299 [01:12<00:00,  4.13it/s, loss=0.0386]


=> Saving checkpoint
Got 12120650/12211200 with acc 99.26
Dice score: 0.9776038527488708


100%|██████████| 299/299 [01:06<00:00,  4.52it/s, loss=0.032] 


=> Saving checkpoint
Got 12139341/12211200 with acc 99.41
Dice score: 0.9823150634765625


100%|██████████| 299/299 [01:08<00:00,  4.38it/s, loss=0.0272]


=> Saving checkpoint
Got 12142847/12211200 with acc 99.44
Dice score: 0.9831758737564087


100%|██████████| 299/299 [01:09<00:00,  4.30it/s, loss=0.0222]


=> Saving checkpoint
Got 12146669/12211200 with acc 99.47
Dice score: 0.9841365218162537


100%|██████████| 299/299 [01:03<00:00,  4.68it/s, loss=0.0251]


=> Saving checkpoint
Got 12137262/12211200 with acc 99.39
Dice score: 0.9819357991218567


100%|██████████| 299/299 [01:10<00:00,  4.24it/s, loss=0.0209]


=> Saving checkpoint
Got 12147086/12211200 with acc 99.47
Dice score: 0.9842544794082642


100%|██████████| 299/299 [01:10<00:00,  4.24it/s, loss=0.025] 


=> Saving checkpoint
Got 12146721/12211200 with acc 99.47
Dice score: 0.9841861128807068


100%|██████████| 299/299 [01:10<00:00,  4.24it/s, loss=0.0273]


=> Saving checkpoint
Got 12146724/12211200 with acc 99.47
Dice score: 0.9841023683547974


100%|██████████| 299/299 [01:08<00:00,  4.38it/s, loss=0.0253]


=> Saving checkpoint
Got 12058556/12211200 with acc 98.75
Dice score: 0.9631582498550415


100%|██████████| 299/299 [01:09<00:00,  4.29it/s, loss=0.0169]


=> Saving checkpoint
Got 12154112/12211200 with acc 99.53
Dice score: 0.9859102368354797


100%|██████████| 299/299 [01:07<00:00,  4.41it/s, loss=0.0153]


=> Saving checkpoint
Got 12155956/12211200 with acc 99.55
Dice score: 0.9864065051078796


100%|██████████| 299/299 [01:11<00:00,  4.16it/s, loss=0.015] 


=> Saving checkpoint
Got 12149543/12211200 with acc 99.50
Dice score: 0.9847881197929382


100%|██████████| 299/299 [01:10<00:00,  4.22it/s, loss=0.0115]


=> Saving checkpoint
Got 12159859/12211200 with acc 99.58
Dice score: 0.9873365759849548


100%|██████████| 299/299 [01:02<00:00,  4.77it/s, loss=0.0131]


=> Saving checkpoint
Got 12161403/12211200 with acc 99.59
Dice score: 0.9876896142959595


100%|██████████| 299/299 [01:10<00:00,  4.24it/s, loss=0.0147]


=> Saving checkpoint
Got 12161292/12211200 with acc 99.59
Dice score: 0.9876465201377869


100%|██████████| 299/299 [01:05<00:00,  4.55it/s, loss=0.00929]


=> Saving checkpoint
Got 12160132/12211200 with acc 99.58
Dice score: 0.987423837184906


100%|██████████| 299/299 [01:12<00:00,  4.13it/s, loss=0.0146] 


=> Saving checkpoint
Got 12163294/12211200 with acc 99.61
Dice score: 0.9881804585456848


100%|██████████| 299/299 [01:10<00:00,  4.22it/s, loss=0.00826]


=> Saving checkpoint
Got 12163661/12211200 with acc 99.61
Dice score: 0.9882450103759766


100%|██████████| 299/299 [01:11<00:00,  4.19it/s, loss=0.012]  


=> Saving checkpoint
Got 12163165/12211200 with acc 99.61
Dice score: 0.9881316423416138


100%|██████████| 299/299 [01:10<00:00,  4.22it/s, loss=0.0104] 


=> Saving checkpoint
Got 12159173/12211200 with acc 99.57
Dice score: 0.9871697425842285


100%|██████████| 299/299 [01:11<00:00,  4.18it/s, loss=0.0097] 


=> Saving checkpoint
Got 12160943/12211200 with acc 99.59
Dice score: 0.987514317035675


100%|██████████| 299/299 [01:06<00:00,  4.47it/s, loss=0.0109] 


=> Saving checkpoint
Got 12159877/12211200 with acc 99.58
Dice score: 0.9873296022415161


100%|██████████| 299/299 [01:03<00:00,  4.74it/s, loss=0.0103] 


=> Saving checkpoint
Got 12143442/12211200 with acc 99.45
Dice score: 0.9834669232368469


100%|██████████| 299/299 [01:11<00:00,  4.19it/s, loss=0.0229] 


=> Saving checkpoint
Got 12163884/12211200 with acc 99.61
Dice score: 0.9882590174674988


100%|██████████| 299/299 [01:01<00:00,  4.89it/s, loss=0.0111] 


=> Saving checkpoint
Got 12158075/12211200 with acc 99.56
Dice score: 0.9868267178535461


100%|██████████| 299/299 [01:04<00:00,  4.65it/s, loss=0.00888]


=> Saving checkpoint
Got 12165843/12211200 with acc 99.63
Dice score: 0.9887617230415344


100%|██████████| 299/299 [01:10<00:00,  4.23it/s, loss=0.0108] 


=> Saving checkpoint
Got 12165859/12211200 with acc 99.63
Dice score: 0.9887663722038269


100%|██████████| 299/299 [01:07<00:00,  4.45it/s, loss=0.00975]


=> Saving checkpoint
Got 12169113/12211200 with acc 99.66
Dice score: 0.9896017909049988
