<a href="https://colab.research.google.com/github/sclfunonr/image-segmentation-unet/blob/main/segment_images_unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import os
import matplotlib.pyplot as plt
from torch import optim
from torch.utils.data import DataLoader, random_split
from torch.utils.data.dataset import Dataset
from tqdm import tqdm

In [2]:
class DoubleConv(nn.Module):
  """Two consecutive 3x3 convolutions, each followed by an in-place ReLU.

  The convolutions are created without padding, so every one trims a
  one-pixel border: the output is 4 pixels smaller than the input in both
  height and width.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    first_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3)
    second_conv = nn.Conv2d(out_channels, out_channels, kernel_size=3)
    # Layout stays (conv, relu, conv, relu) so parameters remain convs.0 / convs.2.
    self.convs = nn.Sequential(
        first_conv,
        nn.ReLU(inplace=True),
        second_conv,
        nn.ReLU(inplace=True),
    )

  def forward(self, x):
    features = self.convs(x)
    return features

In [3]:
class UNet(nn.Module):
  """U-Net style encoder/decoder built from DoubleConv blocks.

  Four encoder stages (64, 128, 256, 512 channels) are each followed by 2x2
  max pooling, then a 1024-channel bottleneck, then four decoder stages that
  upsample with a transposed convolution, concatenate the matching encoder
  feature map and refine the result with a DoubleConv. A 1x1 convolution
  maps the last 64 channels to `n_classes` output channels.

  Args:
    in_channels: Number of channels of the input image.
    n_classes: Number of output channels produced by the classifier.
  """

  def __init__(self, in_channels, n_classes):
    super().__init__()
    self.in_channels = in_channels
    self.n_classes = n_classes

    self.encoder = nn.ModuleList([
        DoubleConv(in_channels, 64),
        DoubleConv(64, 128),
        DoubleConv(128, 256),
        DoubleConv(256, 512),
    ])

    self.pool = nn.MaxPool2d(2)

    self.bottleneck = DoubleConv(512, 1024)

    # Alternating (upsample, refine) pairs; forward() walks them two at a time.
    self.decoder = nn.ModuleList([
        nn.ConvTranspose2d(1024, 512, 2, 2),
        DoubleConv(1024, 512),
        nn.ConvTranspose2d(512, 256, 2, 2),
        DoubleConv(512, 256),
        nn.ConvTranspose2d(256, 128, 2, 2),
        DoubleConv(256, 128),
        nn.ConvTranspose2d(128, 64, 2, 2),
        DoubleConv(128, 64)
    ])

    self.classifier = nn.Conv2d(64, n_classes, 1)

  def forward(self, x):
    """Return un-activated class scores with the same height/width as `x`.

    Args:
      x: Input batch; dimensions 2 and 3 are treated as height and width.

    Returns:
      Tensor with `n_classes` channels, resized to the spatial size of `x`.
    """
    # Remember the input resolution here: the output must be resized to it
    # without depending on any variable that lives outside the model.
    input_size = x.shape[2:]

    encoder_features = []
    for encoder in self.encoder:
      x = encoder(x)
      encoder_features.append(x)
      x = self.pool(x)

    x = self.bottleneck(x)

    for idx in range(0, len(self.decoder), 2):
      x = self.decoder[idx](x)
      encoder_feature = encoder_features.pop()

      # The unpadded convolutions make the upsampled map smaller than the
      # skip connection, so resize it before concatenating along channels.
      x = F.interpolate(x, size=encoder_feature.shape[2:], mode='bilinear', align_corners=False)

      x = torch.cat((x, encoder_feature), dim=1)
      x = self.decoder[idx+1](x)

    x = self.classifier(x)
    # Undo the border shrinkage so the prediction lines up with the input pixels.
    x = F.interpolate(x, size=input_size, mode="bilinear", align_corners=False)
    return x

In [4]:
# ! pip install kaggle

In [5]:
# Mount Google Drive at /content/drive so files stored there (the Kaggle API
# token copied in the next cell) are reachable from the Colab filesystem.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
! mkdir ~/.kaggle
! cp /content/drive/MyDrive/kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json
! kaggle competitions download -c carvana-image-masking-challenge
! unzip carvana-image-masking-challenge.zip

Downloading carvana-image-masking-challenge.zip to /content
100% 24.4G/24.4G [11:23<00:00, 45.5MB/s]
100% 24.4G/24.4G [11:23<00:00, 38.4MB/s]
Archive:  carvana-image-masking-challenge.zip
  inflating: 29bb3ece3180_11.jpg     
  inflating: metadata.csv.zip        
  inflating: sample_submission.csv.zip  
  inflating: test.zip                
  inflating: test_hq.zip             
  inflating: train.zip               
  inflating: train_hq.zip            
  inflating: train_masks.csv.zip     
  inflating: train_masks.zip         


In [7]:
import zipfile
import os
# (target folder, archive) pairs; an archive is unpacked into /content/ only
# when its folder is missing or contains at most one entry.
_ARCHIVES = (
  ("/content/train", "/content/train.zip"),
  ("/content/train_masks", "/content/train_masks.zip"),
)

for _target_dir, _archive_path in _ARCHIVES:
  _already_extracted = os.path.isdir(_target_dir) and len(os.listdir(_target_dir)) > 1
  if _already_extracted:
    continue
  with zipfile.ZipFile(_archive_path, "r") as _archive:
    _archive.extractall("/content/")

In [None]:
# Function to visualize the images and masks
def visualize_images(images, masks, num_images=1):
  """Show the first `num_images` image/mask pairs, one figure per pair.

  Each item is indexed from `images` / `masks` and must offer
  `.permute(1, 2, 0)`; the permuted result is handed to `plt.imshow`.
  """
  panels = (("Image", images), ("Mask", masks))
  for sample_idx in range(num_images):
    plt.figure(figsize=(10, 5))
    for column, (title, source) in enumerate(panels, start=1):
      plt.subplot(1, 2, column)
      plt.imshow(source[sample_idx].permute(1, 2, 0))
      plt.title(title)
    plt.show()

# Sorted path lists; sorting both folders by name keeps images and masks aligned.
# The data was extracted under the absolute /content/ directory.
images = sorted(["/content"+"/train/" + i for i in os.listdir("/content"+"/train/")])
masks = sorted(["/content"+"/train_masks/" + i for i in os.listdir("/content"+"/train_masks/")])

# visualize_images() calls .permute() on every item, so it needs (C, H, W)
# tensors rather than path strings: load the first few files before plotting.
num_preview = 1
to_tensor = transforms.ToTensor()
preview_images = [to_tensor(Image.open(path).convert("RGB")) for path in images[:num_preview]]
preview_masks = [to_tensor(Image.open(path).convert("L")) for path in masks[:num_preview]]
visualize_images(preview_images, preview_masks, num_images=len(preview_images))

In [8]:
class CarvanaDataset(Dataset):
  """Image/mask pairs read from `<root_path>/train/` and `<root_path>/train_masks/`.

  The files of both folders are sorted by name and paired by position. Every
  sample is resized to 512x512 and converted to a tensor.

  Args:
    root_path: Directory that contains the `train` and `train_masks` folders.
    limit: Optional cap on the number of samples; None keeps every file.
  """

  def __init__(self, root_path, limit=None):
    self.root_path = root_path
    self.limit = limit

    self.images = self._list_files(root_path+"/train/")[:limit]
    self.masks = self._list_files(root_path+"/train_masks/")[:limit]
    if self.limit is None:
      self.limit = len(self.images)

    # NOTE(review): the same (default-interpolation) resize is applied to masks,
    # which can yield non-binary mask values along edges; confirm this is intended.
    self.transform = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor()])

  @staticmethod
  def _list_files(folder):
    """Return the sorted paths of all entries in `folder` (which ends with '/')."""
    return sorted(folder + name for name in os.listdir(folder))

  def __getitem__(self, index):
    """Return the (image, mask) tensor pair at `index`."""
    # Context managers release the file handles immediately instead of
    # leaving them open until garbage collection; convert() returns an
    # independent, fully loaded copy.
    with Image.open(self.images[index]) as image_file:
      img = image_file.convert("RGB")
    with Image.open(self.masks[index]) as mask_file:
      mask = mask_file.convert("L")
    return self.transform(img), self.transform(mask)

  def __len__(self):
    return min(len(self.images), self.limit)




In [9]:
# Seeded generator so both splits are reproducible.
generator = torch.Generator().manual_seed(25)
full_dataset = CarvanaDataset("/content/")

# 80% for training; the remaining 20% is halved into test and validation parts.
train_dataset, holdout_dataset = random_split(full_dataset, [0.8, 0.2], generator=generator)
test_dataset, val_dataset = random_split(holdout_dataset, [0.5,0.5], generator=generator)

# Use the GPU when present, with four loader workers per visible GPU;
# on CPU the data is loaded in the main process.
if torch.cuda.is_available():
  device = "cuda"
  num_workers = torch.cuda.device_count() * 4
else:
  device = "cpu"
  num_workers = 0

In [10]:
LEARNING_RATE = 3e-4 # 3x10^-4
BATCH_SIZE = 8

# Only the training loader shuffles. Validation and test loaders keep a fixed
# order so their per-batch-averaged metrics are reproducible between runs and epochs.
# Pinned memory is a region of RAM that the OS is prevented from swapping out to disk;
# it is left disabled here.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=False)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=False)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=False)

# Single-channel output trained with a logits-based binary cross-entropy loss.
model = UNet(in_channels=3, n_classes=1).to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.BCEWithLogitsLoss()

In [11]:
# DICE score = 2 * |A ∩ B| / (|A| + |B|) -- metric, evaluate segmentation performance
def dice_coefficient(prediction, target, epsilon=1e-07):
  """Return the DICE score between thresholded predictions and a target.

  Predictions are binarized at zero (values > 0 become 1, everything else 0),
  then the score is computed over all elements at once as
  (2 * |A ∩ B| + epsilon) / (|A| + |B| + epsilon).

  Args:
    prediction: Tensor of raw scores; it is not modified.
    target: Tensor of the same shape holding the reference mask.
    epsilon: Small constant that keeps the ratio defined when both are empty.

  Returns:
    A 0-dim tensor that does not track gradients.
  """
  # This is an evaluation metric only, so skip autograd bookkeeping even when
  # `prediction` is part of a training graph.
  with torch.no_grad():
    binary_prediction = (prediction > 0).to(prediction.dtype)

    intersection = torch.abs(torch.sum(binary_prediction * target)) # "*" is element-wise
    union = torch.abs(torch.sum(binary_prediction) + torch.sum(target))
    dice = (2. * intersection + epsilon) / (union + epsilon)
  return dice

In [13]:
torch.cuda.empty_cache()
EPOCHS = 5

# Per-epoch history of the mean loss and mean DICE score.
train_losses = []
train_dcs = []
val_losses = []
val_dcs = []

for epoch in tqdm(range(EPOCHS)):
  # ---- training pass ----
  model.train()
  train_running_loss = 0
  train_running_dc = 0

  for idx, (image, mask) in enumerate(tqdm(train_dataloader, position=0, leave=True)):
    image = image.float().to(device)
    mask = mask.float().to(device)

    optimizer.zero_grad()
    y_pred = model(image)
    loss = loss_fn(y_pred, mask)
    # The DICE score is only reported, never optimized: detach it from the graph.
    dc = dice_coefficient(y_pred.detach(), mask)

    loss.backward()
    optimizer.step()

    train_running_loss += loss.item() # .item() copies the scalar to the CPU as a Python float
    train_running_dc += dc.item()

  # Average over the number of batches; max() keeps an empty loader from dividing by zero.
  train_batches = max(len(train_dataloader), 1)
  train_loss = train_running_loss / train_batches
  train_dc = train_running_dc / train_batches

  train_losses.append(train_loss)
  train_dcs.append(train_dc)

  # ---- validation pass ----
  model.eval()
  val_running_loss = 0
  val_running_dc = 0

  with torch.no_grad():
    for idx, (image, mask) in enumerate(tqdm(val_dataloader, position=0, leave=True)):
      image = image.float().to(device)
      mask = mask.float().to(device)

      y_pred = model(image)
      loss = loss_fn(y_pred, mask)
      dc = dice_coefficient(y_pred, mask)

      val_running_loss += loss.item()
      val_running_dc += dc.item()

  val_batches = max(len(val_dataloader), 1)
  val_loss = val_running_loss / val_batches
  val_dc = val_running_dc / val_batches

  val_losses.append(val_loss)
  val_dcs.append(val_dc)

  print("\n")
  print("-" * 30)
  print(f"Training Loss EPOCH {epoch + 1}: {train_loss:.4f}")
  print(f"Training DICE EPOCH {epoch + 1}: {train_dc:.4f}")
  print("\n")
  print(f"Validation Loss EPOCH {epoch + 1}: {val_loss:.4f}")
  print(f"Validation DICE EPOCH {epoch + 1}: {val_dc:.4f}")
  print("-" * 30)

  # Save after every epoch so an interrupted run still leaves a usable checkpoint.
  torch.save(model.state_dict(), 'my_checkpoint.pth')


100%|██████████| 509/509 [13:14<00:00,  1.56s/it]
100%|██████████| 64/64 [00:37<00:00,  1.71it/s]
 20%|██        | 1/5 [13:51<55:27, 831.91s/it]



------------------------------
Training Loss EPOCH 1: 0.0402
Training DICE EPOCH 1: 0.9630


Validation Loss EPOCH 1: 0.0443
Validation DICE EPOCH 1: 0.9589
------------------------------


 18%|█▊        | 91/509 [02:23<10:57,  1.57s/it]
 20%|██        | 1/5 [16:15<1:05:00, 975.01s/it]


KeyboardInterrupt: 

In [None]:
# Plot only the epochs that finished both training and validation. Deriving the
# x axis from the recorded history (not from EPOCHS) keeps x and y the same
# length when training was interrupted early.
completed_epochs = min(len(train_losses), len(val_losses))
epochs_list = list(range(1, completed_epochs + 1))

plt.figure(figsize=(12, 5))

# Left panel: loss curves.
plt.subplot(1,2,1)
plt.plot(epochs_list, train_losses[:completed_epochs], label='Training Loss')
plt.plot(epochs_list, val_losses[:completed_epochs], label='Validation Loss')
plt.xticks(ticks=epochs_list)
plt.title('Loss over epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.grid()
plt.legend()

# Right panel: DICE curves.
plt.subplot(1,2,2)
plt.plot(epochs_list, train_dcs[:completed_epochs], label='Training DICE')
plt.plot(epochs_list, val_dcs[:completed_epochs], label='Validation DICE')
plt.xticks(ticks=epochs_list)
plt.title('DICE Coefficient over epochs')
plt.xlabel('Epochs')
plt.ylabel('DICE')
plt.grid()
plt.legend()

plt.tight_layout()
plt.show()


In [None]:
# Zoomed loss over epochs (y axis limited to [0, 0.05]).
# The x axis follows the recorded history rather than EPOCHS so the plot
# also works when training stopped before all epochs completed.
completed_epochs = min(len(train_losses), len(val_losses))
epochs_list = list(range(1, completed_epochs + 1))

plt.figure(figsize=(12, 5))
plt.plot(epochs_list, train_losses[:completed_epochs], label='Training Loss')
plt.plot(epochs_list, val_losses[:completed_epochs], label='Validation Loss')
plt.xticks(ticks=epochs_list)
plt.ylim(0, 0.05)
plt.title('Loss over epochs (zoomed)')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.grid()
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
# Test - how the model performs on unseen images.
# Absolute path, matching the /content/ paths used elsewhere in this notebook;
# the relative "content/..." form resolves inside the working directory instead.
model_pth = "/content/my_checkpoint.pth"
trained_model = UNet(in_channels=3, n_classes=1).to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime.
trained_model.load_state_dict(torch.load(model_pth, map_location=torch.device(device)))
# The model is used for inference only from here on.
trained_model.eval()


In [None]:
# Evaluate the restored model on the held-out test split.
test_running_loss = 0
test_running_dc = 0

trained_model.eval()
with torch.no_grad():
  for idx, img_mask in enumerate(tqdm(test_dataloader, position=0, leave=True)):
    img = img_mask[0].float().to(device)
    mask = img_mask[1].float().to(device)

    y_pred = trained_model(img)
    loss = loss_fn(y_pred, mask)
    dc = dice_coefficient(y_pred, mask)

    test_running_loss += loss.item()
    test_running_dc += dc.item()

# Average once after the loop; max() keeps an empty loader from dividing by zero.
test_batches = max(len(test_dataloader), 1)
test_loss = test_running_loss / test_batches
test_dc = test_running_dc / test_batches

print(f"Test Loss: {test_loss:.4f}")
print(f"Test DICE: {test_dc:.4f}")


In [None]:
def random_images_inference(image_tensors, mask_tensors, image_paths, model_pth, device):
    """Run a saved UNet on individual samples and plot prediction vs. ground truth.

    For every (image, mask, path) triple the function prints the DICE
    coefficient and shows the image, the thresholded prediction and the mask
    side by side.

    Args:
        image_tensors: Iterable of image tensors (channels first).
        mask_tensors: Iterable of mask tensors (channels first), same order.
        image_paths: Iterable of source file paths, used only for the printout.
        model_pth: Path of the checkpoint written with `torch.save(state_dict)`.
        device: Device string the model and the tensors are moved to.
    """
    model = UNet(in_channels=3, n_classes=1).to(device)
    model.load_state_dict(torch.load(model_pth, map_location=torch.device(device)))
    model.eval()

    transform = transforms.Compose([
        transforms.Resize((512, 512))
    ])

    # Distinct loop names so the `image_paths` argument is not rebound mid-iteration.
    for image_tensor, mask_tensor, image_path in zip(image_tensors, mask_tensors, image_paths):
        # The input must live on the same device as the model weights.
        img = transform(image_tensor).to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            pred_mask = model(img.unsqueeze(0))
        pred_mask = pred_mask.squeeze(0).permute(1,2,0)

        # Ground-truth mask in (H, W, C) layout for comparison.
        mask = transform(mask_tensor).permute(1, 2, 0).to(device)

        print(f"Image: {os.path.basename(image_path)}, DICE coefficient: {round(float(dice_coefficient(pred_mask, mask)),5)}")

        # matplotlib needs host memory, so bring everything back to the CPU.
        img = img.cpu().permute(1, 2, 0)
        pred_mask = (pred_mask.cpu() > 0).float()  # same zero threshold as the metric
        mask = mask.cpu()

        plt.figure(figsize=(15, 16))
        plt.subplot(131), plt.imshow(img), plt.title("original")
        plt.subplot(132), plt.imshow(pred_mask, cmap="gray"), plt.title("predicted")
        plt.subplot(133), plt.imshow(mask, cmap="gray"), plt.title("mask")
        plt.show()

In [None]:
import copy
import random
import shutil

n = 10

image_tensors = []
mask_tensors = []
image_paths = []


def _resolve_sample_path(dataset, index):
    """Return the image file path behind `dataset[index]`.

    `random_split` wraps the data in (possibly nested) Subset objects, so the
    index is translated through each wrapper's `indices` until the underlying
    dataset, which keeps the paths in its `images` list, is reached.
    """
    while isinstance(dataset, torch.utils.data.Subset):
        index = dataset.indices[index]
        dataset = dataset.dataset
    return dataset.images[index]


# Draw n samples (with replacement) from the test split.
for _ in range(n):
    random_index = random.randrange(len(test_dataloader.dataset))
    random_sample = test_dataloader.dataset[random_index]

    image_tensors.append(random_sample[0])
    mask_tensors.append(random_sample[1])
    # Samples are (image, mask) pairs only, so the path is looked up separately.
    image_paths.append(_resolve_sample_path(test_dataloader.dataset, random_index))

In [None]:
# Checkpoint to restore, then visualize predictions for the sampled images.
# Inference is run on the CPU here regardless of the training device.
model_pth = "/content/my_checkpoint.pth"
random_images_inference(image_tensors, mask_tensors, image_paths, model_pth, device="cpu")