1. Padded convolutions
2. Learn about U-Net and convolutions
3. PyTorch course
4. Study and understand the model architecture

In [1]:
import torch

In [2]:
import torch.nn as nn
import torchvision.transforms.functional as TF
import torchvision

In [3]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim

In [4]:
class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages that keep the spatial size.

    Both convolutions use stride 1 and padding 1, so height and width are
    preserved; only the channel count changes.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            # bias=False: the BatchNorm right after each conv has its own shift
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    # Must be defined at class level (not inside __init__) so it overrides
    # nn.Module.forward and is invoked by module(x).
    def forward(self, x):
        return self.conv(x)


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as TF

class DoubleConv(nn.Module):
    """Stack of two conv3x3 -> BatchNorm -> ReLU stages (height/width preserved).

    Args:
        in_channels: channel count fed into the first convolution.
        out_channels: channel count produced by each of the two convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in->out, second maps out->out; each is conv/BN/ReLU.
        for stage_in in (in_channels, out_channels):
            stages.extend(
                [
                    nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                ]
            )
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class UNET(nn.Module):
    """U-Net producing a per-pixel map with ``out_channels`` channels.

    Encoder: one ``DoubleConv`` per entry of ``features`` with 2x2 max-pooling
    after each.  Bottleneck: a ``DoubleConv`` doubling the deepest width.
    Decoder: for each width (deepest first) a 2x2 stride-2 transposed conv,
    concatenation with the matching encoder output, and a ``DoubleConv``.
    A final 1x1 conv maps to ``out_channels`` (raw logits, no activation).

    Inputs whose height/width are not multiples of ``2 ** len(features)`` are
    supported: when an upsampled tensor's spatial size differs from its skip
    connection (pooling floors odd sizes) it is bilinearly resized to match.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output logits.
        features: encoder widths, shallowest first (any iterable of ints).
            Consecutive widths must double (``features[i + 1] == 2 *
            features[i]``) because each decoder stage's ConvTranspose2d expects
            ``2 * feature`` input channels.

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super(UNET, self).__init__()
        # Copy so a caller-supplied list is never aliased; tuple default avoids
        # the shared-mutable-default pitfall.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one entry")

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Decoder: self.ups alternates (ConvTranspose2d, DoubleConv) pairs.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Walk skips deepest-first, pairing each with its (upsample, conv) pair.
        for stage, skip_connection in enumerate(reversed(skip_connections)):
            x = self.ups[2 * stage](x)

            # Only height/width can disagree (batch and channels always match).
            if x.shape[2:] != skip_connection.shape[2:]:
                x = nn.functional.interpolate(
                    x, size=skip_connection.shape[2:], mode="bilinear", align_corners=True
                )

            x = self.ups[2 * stage + 1](torch.cat((skip_connection, x), dim=1))

        return self.final_conv(x)

# Smoke check: a 1x3x256x256 input should yield a 1x1x256x256 output.
x = torch.randn((1, 3, 256, 256))
model = UNET()
output = model(x)
print(output.size())


torch.Size([1, 1, 256, 256])


In [6]:
def test():
    """Check UNET maps a (3, 1, 161, 161) batch to an output of the same shape.

    161 is odd, so pooling floors the spatial size and the last decoder
    stage has to be resized to match its skip connection.
    """
    sample = torch.randn((3, 1, 161, 161))
    net = UNET(in_channels=1, out_channels=1)
    prediction = net(sample)
    print(prediction.shape)
    print(sample.shape)
    assert prediction.shape == sample.shape


if __name__ == "__main__":
    test()

torch.Size([3, 1, 161, 161])
torch.Size([3, 1, 161, 161])


In [7]:
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

In [8]:
class carData(Dataset):
    """Paired image/mask segmentation dataset read from two directories.

    Every entry in ``image_dir`` is treated as an image.  Its mask is looked up
    in ``mask_dir`` under the same name, with a trailing ``.jpg`` replaced by
    ``_mask.gif``.  Images are loaded as RGB uint8 arrays.  Masks are loaded as
    single-channel float32 arrays with 255 mapped to 1.0 (assumes masks are
    stored as 0/255 — TODO confirm for other datasets).

    Args:
        image_dir: directory of input images.
        mask_dir: directory of mask files.
        transform: optional albumentations-style callable invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``"image"`` and ``"mask"`` keys.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Sorted so the index -> file mapping is deterministic across runs and
        # platforms (os.listdir order is arbitrary).
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _mask_name(image_name):
        """Return the mask filename paired with ``image_name``.

        Only a trailing ``.jpg`` is replaced by ``_mask.gif``.  Names without
        that suffix are returned unchanged.
        """
        suffix = ".jpg"
        if image_name.endswith(suffix):
            return image_name[: -len(suffix)] + "_mask.gif"
        return image_name

    def __getitem__(self, index):
        image_name = self.images[index]
        img_path = os.path.join(self.image_dir, image_name)
        mask_path = os.path.join(self.mask_dir, self._mask_name(image_name))

        # Context managers close the underlying file handles promptly.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        with Image.open(mask_path) as msk:
            mask = np.array(msk.convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask


In [11]:
from torch.utils.data import DataLoader
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce, then serialize ``state`` to ``filename`` with ``torch.save``."""
    announcement = "=> Saving checkpoint"
    print(announcement)
    torch.save(state, f=filename)

def load_checkpoint(checkpoint, model):
    """Copy the weights under ``checkpoint["state_dict"]`` into ``model``."""
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=0,
    pin_memory=True,
):
    """Build the training and validation DataLoaders over ``carData`` datasets.

    Both loaders share ``batch_size``, ``num_workers`` and ``pin_memory``.
    Only the training loader shuffles.

    Returns:
        ``(train_loader, val_loader)``.
    """

    def make_loader(image_dir, mask_dir, transform, shuffle):
        # One dataset/loader pair; shared loader settings come from the
        # enclosing call's arguments.
        dataset = carData(image_dir=image_dir, mask_dir=mask_dir, transform=transform)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=shuffle,
        )

    train_loader = make_loader(train_dir, train_maskdir, train_transform, shuffle=True)
    val_loader = make_loader(val_dir, val_maskdir, val_transform, shuffle=False)
    return train_loader, val_loader

def check_accuracy(loader, model, device="cuda"):
    """Print (and return) pixel accuracy and mean Dice score of ``model`` on ``loader``.

    Predictions are ``sigmoid(model(x)) > 0.5``.  Dice is computed per batch
    over all pixels in the batch, then averaged over batches.  The model is
    switched to eval mode for the pass. Afterwards it is restored to whatever
    train/eval mode it was in, even if an exception is raised.

    Args:
        loader: iterable of ``(x, y)`` batches.  ``y`` gets a channel axis
            inserted at dim 1 before comparison (presumably ``y`` is
            ``(N, H, W)`` — TODO confirm for other datasets).
        model: module whose output is broadcast-compatible with
            ``y.unsqueeze(1)``.
        device: device the batches are moved to.

    Returns:
        ``(accuracy_percent, mean_dice)`` as floats, or ``None`` when
        ``loader`` yields no batches.
    """
    was_training = model.training
    model.eval()

    num_correct = 0
    num_pixels = 0
    dice_total = 0.0
    num_batches = 0

    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)
                preds = (torch.sigmoid(model(x)) > 0.5).float()
                # Accumulate as tensors to avoid a host sync per batch.
                num_correct += (preds == y).sum()
                num_pixels += torch.numel(preds)
                dice_total += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
                num_batches += 1
    finally:
        model.train(was_training)

    if num_batches == 0:
        # Avoid dividing by zero when the loader is empty.
        print("No batches to evaluate")
        return None

    num_correct = int(num_correct)
    accuracy = num_correct / num_pixels * 100
    mean_dice = float(dice_total / num_batches)

    print(f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}")
    print(f"Dice score: {mean_dice}")
    return accuracy, mean_dice

def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="cuda"
):
    """Write thresholded predictions and ground-truth masks as PNGs into ``folder``.

    For batch ``idx``, predictions go to ``<folder>/pred_<idx>.png`` and
    targets to ``<folder>/<idx>.png``.  ``folder`` is created if missing.
    The model is switched to eval mode for the pass and restored to its
    previous train/eval mode afterwards.

    Args:
        loader: iterable of ``(x, y)`` batches.  ``y`` gets a channel axis
            inserted at dim 1 before saving.
        model: module producing logits.  Predictions are
            ``sigmoid(logits) > 0.5``.
        folder: output directory (with or without a trailing slash).
        device: device the inputs are moved to.
    """
    os.makedirs(folder, exist_ok=True)
    was_training = model.training
    model.eval()
    try:
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            with torch.no_grad():
                preds = (torch.sigmoid(model(x)) > 0.5).float()
            # os.path.join keeps both files inside `folder` regardless of a
            # trailing slash.
            torchvision.utils.save_image(preds, os.path.join(folder, f"pred_{idx}.png"))
            torchvision.utils.save_image(y.unsqueeze(1), os.path.join(folder, f"{idx}.png"))
    finally:
        model.train(was_training)

Training

In [9]:
# hyperparameters
# -- optimisation / schedule
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
NUM_EPOCHS = 7
# -- data loading
NUM_WORKERS = 0
IMAGE_HT, IMAGE_W = 300, 300  # every image/mask is resized to (height, width)
PIN_MEMORY = True
LOAD_MODEL = False  # when True, main() resumes from my_checkpoint.pth.tar
# -- dataset directories (relative to the working directory)
TRAIN = "train"
MASK = "train_mask"
VAL_T = "val_train"
VAL_MASK = "val_mask"

In [15]:
print("Using Device:", DEVICE)  # expect 'cuda' when torch sees a GPU

Using Device: cuda


In [16]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim


def train_fn(loader, model, optimizer, loss_fn, scaler, device=None):
    """Run one training epoch over ``loader`` with gradient scaling.

    Autocast is enabled only when the target device is CUDA; on CPU the
    forward pass runs in full precision instead of requesting CUDA autocast
    on a CPU tensor.

    Args:
        loader: iterable of ``(data, targets)`` batches.  Targets are cast to
            float and given a channel axis at dim 1 to match the logits.
        model: module being trained.
        optimizer: optimizer stepping ``model``'s parameters.
        loss_fn: callable ``loss_fn(predictions, targets)``.
        scaler: ``torch.amp.GradScaler`` (or compatible) used for backward/step.
        device: device for the batches.  Defaults to the module-level
            ``DEVICE`` when ``None``.

    Returns:
        Mean loss over the epoch as a float, or ``None`` if ``loader`` was
        empty.
    """
    if device is None:
        device = DEVICE
    device_type = torch.device(device).type
    use_amp = device_type == "cuda"

    loop = tqdm(loader)
    running_loss = 0.0
    num_batches = 0

    for data, targets in loop:
        data = data.to(device=device)
        targets = targets.float().unsqueeze(1).to(device=device)

        # forward
        with torch.amp.autocast(device_type=device_type, enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_value = loss.item()
        running_loss += loss_value
        num_batches += 1
        loop.set_postfix(loss=loss_value)

    return running_loss / num_batches if num_batches else None


def _build_transforms():
    """Return ``(train_transform, val_transform)`` albumentations pipelines."""

    def normalize():
        # mean 0 / std 1 with max_pixel_value 255 just rescales pixels to [0, 1]
        return A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        )

    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HT, width=IMAGE_W),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            normalize(),
            ToTensorV2(),
        ],
    )
    val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HT, width=IMAGE_W),
            normalize(),
            ToTensorV2(),
        ],
    )
    return train_transform, val_transforms


def main():
    """Train UNET on the image/mask folders configured by the module constants.

    Each epoch trains, saves a checkpoint (model + optimizer state) to
    ``my_checkpoint.pth.tar``, prints validation metrics, and writes prediction
    images to ``saved_images/``.  When ``LOAD_MODEL`` is set, training resumes
    from that checkpoint, restoring the optimizer state if present.
    """
    train_transform, val_transforms = _build_transforms()

    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN,
        MASK,
        VAL_T,
        VAL_MASK,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    if LOAD_MODEL:
        # map_location lets a checkpoint written on GPU be loaded on CPU.
        checkpoint = torch.load("my_checkpoint.pth.tar", map_location=DEVICE)
        load_checkpoint(checkpoint, model)
        if "optimizer" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer"])

    check_accuracy(val_loader, model, device=DEVICE)
    # Loss scaling is only meaningful for CUDA autocast; disable it otherwise.
    scaler = torch.amp.GradScaler("cuda", enabled=torch.device(DEVICE).type == "cuda")

    os.makedirs("saved_images", exist_ok=True)

    for _ in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        check_accuracy(val_loader, model, device=DEVICE)

        save_predictions_as_imgs(
            val_loader, model, folder="saved_images/", device=DEVICE
        )


if __name__ == "__main__":
    main()


Got 288907/1350000 with acc 21.40
Dice score: 0.3116714656352997


100%|██████████| 22/22 [00:13<00:00,  1.68it/s, loss=0.378]


=> Saving checkpoint
Got 1105006/1350000 with acc 81.85
Dice score: 0.0


100%|██████████| 22/22 [00:16<00:00,  1.35it/s, loss=0.231]


=> Saving checkpoint
Got 1312185/1350000 with acc 97.20
Dice score: 0.9168542623519897


100%|██████████| 22/22 [00:21<00:00,  1.01it/s, loss=0.219]


=> Saving checkpoint
Got 1312632/1350000 with acc 97.23
Dice score: 0.9187583923339844


100%|██████████| 22/22 [00:18<00:00,  1.16it/s, loss=0.218]


=> Saving checkpoint
Got 1331381/1350000 with acc 98.62
Dice score: 0.9629092216491699


100%|██████████| 22/22 [00:15<00:00,  1.46it/s, loss=0.191]


=> Saving checkpoint
Got 1270051/1350000 with acc 94.08
Dice score: 0.810137152671814


100%|██████████| 22/22 [00:22<00:00,  1.04s/it, loss=0.207]


=> Saving checkpoint
Got 1336026/1350000 with acc 98.96
Dice score: 0.9713790416717529


100%|██████████| 22/22 [00:23<00:00,  1.05s/it, loss=0.176]


=> Saving checkpoint
Got 1333836/1350000 with acc 98.80
Dice score: 0.9670965075492859
