<a href="https://colab.research.google.com/github/crsdvaibhav/unet/blob/main/RP_1_UNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

# Model

In [None]:
from torch.nn.modules.batchnorm import BatchNorm2d
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes from ``inChannels``
    to ``outChannels``.
    """

    def __init__(self, inChannels, outChannels):
        super().__init__()

        stages = []
        # First stage maps inChannels -> outChannels, second keeps outChannels.
        for stage_in in (inChannels, outChannels):
            stages += [
                # The conv bias is switched off because the BatchNorm2d that
                # follows has its own learnable shift.
                nn.Conv2d(
                    stage_in,
                    outChannels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(outChannels),
                nn.ReLU(inplace=True),
            ]

        # Kept as a single Sequential named ``conv`` so parameter names stay
        # ``conv.0``, ``conv.1``, ... ``conv.5``.
        self.conv = nn.Sequential(*stages)

    def forward(self, X):
        """Run ``X`` through both conv/norm/activation stages."""
        return self.conv(X)

In [None]:
class UNET(nn.Module):
    """U-Net style encoder/decoder built from ``DoubleConv`` blocks.

    Args:
        inChannels: Channels of the input image (3 for colour images).
        outChannels: Channels of the output map (1 for binary segmentation).
        features: Channel widths of the encoder levels, shallowest first.
            The decoder mirrors them in reverse and the bottleneck uses twice
            the last width. Any sequence of ints is accepted.
    """

    def __init__(self, inChannels=3, outChannels=1, features=(64, 128, 256, 512)):
        super().__init__()
        # Immutable default plus a local copy: the caller's sequence is never
        # shared between instances or mutated.
        features = list(features)

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part: one DoubleConv per level, widening the channels.
        for feature in features:
            self.downs.append(DoubleConv(inChannels, feature))
            inChannels = feature

        # Up part: ``ups`` alternates [upsample, DoubleConv, upsample, ...].
        for feature in reversed(features):
            # The transposed conv halves the channels (feature*2 -> feature)...
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            # ...and concatenating the skip connection doubles them again.
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        self.final_conv = nn.Conv2d(features[0], outChannels, kernel_size=1)

    def forward(self, x):
        """Return the raw (pre-sigmoid) output map for the batch ``x``."""
        skip_connections = []  # encoder outputs, concatenated in the decoder

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Deepest skip connection is consumed first.
        for level, skip_connection in enumerate(reversed(skip_connections)):
            x = self.ups[2 * level](x)  # even slots hold the upsampling conv

            # Pooling floors odd sizes: a 161x161 input is downsampled to
            # 80x80 and upsampled back to 160x160, which cannot be
            # concatenated with the 161x161 skip. Resize to the skip's
            # spatial size ([2:] ignores batch and channel dims).
            if x.shape[2:] != skip_connection.shape[2:]:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)  # along channels
            x = self.ups[2 * level + 1](concat_skip)  # odd slots hold DoubleConv

        return self.final_conv(x)
            



# Dataloading



In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

In [None]:
class CarvanaDataset(Dataset):
    """Image/mask pairs read from two directories.

    Every entry of ``image_dir`` is treated as one sample. Its mask is looked
    up in ``mask_dir`` under the same file name with ``".jpg"`` replaced by
    ``"_mask.gif"``.

    Args:
        image_dir: Directory holding the input images.
        mask_dir: Directory holding the matching masks.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` that returns a mapping with
            ``"image"`` and ``"mask"`` keys.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across runs and machines.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        """Return the number of entries found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """Load, and optionally transform, the ``index``-th image/mask pair.

        Returns:
            Tuple ``(image, mask)``. Without a transform these are a uint8 RGB
            array and a float32 single-channel array.
        """
        file_name = self.images[index]
        image_path = os.path.join(self.image_dir, file_name)
        mask_path = os.path.join(
            self.mask_dir, file_name.replace(".jpg", "_mask.gif")
        )

        # Context managers close the underlying files as soon as the pixel
        # data has been copied into numpy arrays.
        with Image.open(image_path) as image_file:
            image = np.array(image_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.float32)  # greyscale

        # White pixels become 1.0 so the mask can serve as a sigmoid target.
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

# Utils

In [None]:
"""
load_checkpoint
save_checkpoint
get_loaders
check_accuracy
save_predictions_as_imgs
"""

'\nload_checkpoint\nsave_checkpoint\nget_loaders\ncheck_accuracy\nsave_predictions_as_imgs\n'

In [None]:
import torchvision
from torch.utils.data import DataLoader

In [None]:
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce the save on stdout, then write ``state`` to ``filename``.

    Args:
        state: Object to serialize (passed to ``torch.save`` unchanged).
        filename: Destination path of the checkpoint file.
    """
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)

In [None]:
def load_checkpoint(checkpoint, model):
    """Announce the load on stdout, then restore ``model`` weights.

    Args:
        checkpoint: Mapping whose ``"state_dict"`` entry holds the weights.
        model: Module that receives the weights via ``load_state_dict``.
    """
    print("=> Loading checkpoint")
    saved_weights = checkpoint["state_dict"]
    model.load_state_dict(saved_weights)

In [None]:
def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    """Build the training and validation ``DataLoader`` objects.

    Both loaders wrap a ``CarvanaDataset`` and share ``batch_size``,
    ``num_workers`` and ``pin_memory``. Only the training loader shuffles.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    common_options = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }

    def build_loader(images, masks, transform, shuffle):
        # One dataset per split, wrapped with the options both splits share.
        dataset = CarvanaDataset(
            image_dir=images, mask_dir=masks, transform=transform
        )
        return DataLoader(dataset, shuffle=shuffle, **common_options)

    train_loader = build_loader(
        train_dir, train_maskdir, train_transform, shuffle=True
    )
    val_loader = build_loader(val_dir, val_maskdir, val_transform, shuffle=False)

    return train_loader, val_loader

In [None]:
def check_accuracy(loader, model, device="cuda"):
    """Print pixel accuracy and mean per-batch Dice score of ``model``.

    Model outputs are passed through a sigmoid and thresholded at 0.5 before
    being compared with the targets. The model is put in eval mode for the
    evaluation and always switched back to train mode afterwards, even if a
    batch raises.

    Args:
        loader: Iterable of ``(x, y)`` batches. ``y`` gets an extra axis at
            dim 1 before comparison with the predictions. ``len()`` is not
            required; an empty loader reports zeros instead of raising
            ``ZeroDivisionError``.
        model: Module to evaluate.
        device: Device the batches are moved to.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    num_batches = 0  # counted while iterating so len(loader) is not needed
    model.eval()

    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()
                num_correct += (preds == y).sum()
                num_pixels += torch.numel(preds)
                # Dice = 2*|A∩B| / (|A|+|B|); the epsilon avoids 0/0 when both
                # prediction and target are empty.
                dice_score += (2 * (preds * y).sum()) / (
                    (preds + y).sum() + 1e-8
                )
                num_batches += 1
    finally:
        # Restore training mode even when evaluation fails part-way.
        model.train()

    accuracy = num_correct / num_pixels * 100 if num_pixels else 0.0
    mean_dice = dice_score / num_batches if num_batches else 0.0

    print(
        f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}"
    )
    print(f"Dice score: {mean_dice}")

In [None]:
def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="cuda"
):
    """Write thresholded predictions and their targets to ``folder`` as PNGs.

    For batch number ``idx`` the prediction is saved as ``pred_{idx}.png``
    and the target as ``{idx}.png``, both inside ``folder``. The folder is
    created if it does not exist. The model is put in eval mode while
    predicting and switched back to train mode afterwards, even on error.

    Args:
        loader: Iterable of ``(x, y)`` batches.
        model: Module producing pre-sigmoid outputs for ``x``.
        folder: Output directory, with or without a trailing slash.
        device: Device the inputs are moved to.
    """
    import os  # local import keeps this block self-contained

    if folder:
        os.makedirs(folder, exist_ok=True)

    model.eval()
    try:
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            with torch.no_grad():
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()
            # Both paths go through os.path.join so files land inside
            # ``folder`` whether or not it ends with a separator.
            torchvision.utils.save_image(
                preds, os.path.join(folder, f"pred_{idx}.png")
            )
            torchvision.utils.save_image(
                y.unsqueeze(1), os.path.join(folder, f"{idx}.png")
            )
    finally:
        model.train()

# Training

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim

In [None]:
# Hyperparameters

# --- Optimisation ---
LEARNING_RATE = 1e-4
BATCH_SIZE = 16
EPOCHS = 3

# --- Runtime ---
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_WORKERS = 2
PIN_MEMORY = True
LOAD_MODEL = True  # try to resume from a saved checkpoint before training

# --- Input resolution (images are resized down to this) ---
IMAGE_HEIGHT = 160  # 1280 originally
IMAGE_WIDTH = 240  # 1918 originally

# --- Dataset locations ---
TRAIN_IMG_DIR = "/content/drive/MyDrive/carvana_dataset/train_images"
TRAIN_MASK_DIR = "/content/drive/MyDrive/carvana_dataset/train_masks"
VAL_IMG_DIR = "/content/drive/MyDrive/carvana_dataset/val_images"
VAL_MASK_DIR = "/content/drive/MyDrive/carvana_dataset/val_masks"

In [None]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Train ``model`` for one pass over ``loader`` with gradient scaling.

    Args:
        loader: Iterable of ``(data, targets)`` batches.
        model: Module being optimised.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        loss_fn: Callable taking ``(predictions, targets)``.
        scaler: Gradient scaler used to scale the loss and step the optimizer.
    """
    progress = tqdm(loader)

    for inputs, labels in progress:
        inputs = inputs.to(device=DEVICE)
        # Cast targets to float and add an axis at dim 1 before moving them
        # to the training device.
        labels = labels.float().unsqueeze(1).to(device=DEVICE)

        # Forward pass under autocast.
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            batch_loss = loss_fn(outputs, labels)

        # Backward pass: drop stale gradients, then let the scaler drive the
        # backward call, the optimizer step and its own scale update.
        optimizer.zero_grad()
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Show the latest batch loss on the progress bar.
        progress.set_postfix(loss=batch_loss.item())

In [None]:
# Training augmentations: resize, then random rotation and flips, then scale
# pixel values to [0, 1] (mean 0, std 1, max 255) and convert to tensors.
train_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

# Validation uses the same resize/normalisation but no random augmentation.
val_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

model = UNET(inChannels=3, outChannels=1).to(DEVICE)
# The model emits raw logits; BCEWithLogitsLoss applies the sigmoid itself.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_loader, val_loader = get_loaders(
    TRAIN_IMG_DIR,
    TRAIN_MASK_DIR,
    VAL_IMG_DIR,
    VAL_MASK_DIR,
    BATCH_SIZE,
    train_transform,
    val_transform,
    NUM_WORKERS,
    PIN_MEMORY,
)

if LOAD_MODEL:
    try:
        # map_location lets a checkpoint saved on GPU be resumed on CPU.
        saved_state = torch.load("my_checkpoint.pth.tar", map_location=DEVICE)
    except (FileNotFoundError, EOFError) as err:
        # First run (no file yet) or a truncated file: train from scratch,
        # but say so instead of failing or skipping silently.
        print(f"=> No usable checkpoint ({err!r}); starting from scratch")
    else:
        load_checkpoint(saved_state, model)


check_accuracy(val_loader, model, device=DEVICE)

scaler = torch.cuda.amp.GradScaler()

# The image writer does not create missing directories, so make sure the
# output folder exists before the first epoch finishes.
os.makedirs("saved_images/", exist_ok=True)

for epoch in range(EPOCHS):
    train_fn(train_loader, model, optimizer, loss_fn, scaler)

    # save model
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_checkpoint(checkpoint)

    # check accuracy
    check_accuracy(val_loader, model, device=DEVICE)

    # print some examples to a folder
    save_predictions_as_imgs(
        val_loader, model, folder="saved_images/", device=DEVICE
    )

Got 423093/1843200 with acc 22.95
Dice score: 0.37316572666168213


100%|██████████| 315/315 [12:26<00:00,  2.37s/it, loss=0.138]


=> Saving checkpoint
Got 1823827/1843200 with acc 98.95
Dice score: 0.9773280024528503


100%|██████████| 315/315 [04:11<00:00,  1.25it/s, loss=0.0843]


=> Saving checkpoint
Got 1765114/1843200 with acc 95.76
Dice score: 0.9155818223953247


100%|██████████| 315/315 [04:12<00:00,  1.25it/s, loss=0.0536]


=> Saving checkpoint
Got 1831918/1843200 with acc 99.39
Dice score: 0.9867638349533081
