<a href="https://colab.research.google.com/github/soulsharp/Attentive-Segnet/blob/main/Attentive_SegNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
import os
from torch.utils.data import DataLoader, random_split
from torch.utils.data.dataset import Dataset
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

In [None]:
class CrackDataset(Dataset):
    """Paired image/mask dataset for binary (crack / background) segmentation.

    Images and masks are paired by position after sorting the file paths of
    each directory, so the i-th image is matched with the i-th mask. Masks are
    binarized to {0, 1} so they can be used directly as targets of a
    sigmoid-based loss.
    """

    # File extensions (compared case-insensitively) that are treated as images.
    _EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, image_dir, mask_dir, transform=None, mask_threshold=127):
        """
        Custom dataset for segmentation tasks.

        Args:
            image_dir (str): Path to the directory containing the input images.
            mask_dir (str): Path to the directory containing mask images.
            transform (callable, optional): Called as ``transform(image, mask)``
                on the raw OpenCV arrays; must return the ``(image, mask)`` pair.
            mask_threshold (int): Grey level above which a mask pixel counts as
                foreground for masks stored in the 0-255 range. Masks whose
                maximum value is at most 1 are treated as already binary.

        Raises:
            ValueError: If the two directories hold a different number of
                image files.
        """
        self.image_paths = self._list_images(image_dir)
        self.mask_paths = self._list_images(mask_dir)

        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError("Number of images and masks must be the same!")

        self.transform = transform
        self.mask_threshold = mask_threshold

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted full paths of the image files in ``directory``."""
        return sorted(
            os.path.join(directory, fname)
            for fname in os.listdir(directory)
            if fname.lower().endswith(cls._EXTENSIONS)
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, transform and tensorize the ``idx``-th image/mask pair.

        Returns:
            tuple: ``(image, mask)``. ``image`` is a float32 tensor of shape
            (3, H, W) scaled to [0, 1], in OpenCV's BGR channel order.
            ``mask`` is an int64 tensor of shape (H, W) with values in {0, 1}.

        Raises:
            ValueError: If OpenCV cannot read the image or the mask file.
        """
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        image = cv.imread(image_path)
        mask = cv.imread(mask_path, cv.IMREAD_GRAYSCALE)

        if image is None:
            raise ValueError(f"Failed to load image: {image_path}")
        if mask is None:
            raise ValueError(f"Failed to load mask: {mask_path}")

        if self.transform:
            image, mask = self.transform(image, mask)

        # The loss expects {0, 1} targets. Grey-level masks (0/255, possibly
        # with JPEG noise) are thresholded. Masks already stored as 0/1 keep
        # every non-zero pixel as foreground.
        threshold = self.mask_threshold if mask.max() > 1 else 0
        binary_mask = mask > threshold

        image = torch.as_tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0
        mask = torch.as_tensor(binary_mask, dtype=torch.long)

        return image, mask

# Simple preprocessing function
def transform_image_and_mask(image, mask, size=512):
    """Resize an image/mask pair to ``size`` x ``size`` pixels.

    The image is resampled with bilinear interpolation, which avoids the
    aliasing that nearest-neighbour sampling produces on photographs. The mask
    keeps nearest-neighbour interpolation so no intermediate label values are
    invented at object borders.

    Args:
        image: OpenCV image array.
        mask: OpenCV single-channel mask array.
        size (int): Side length of the square output.

    Returns:
        tuple: The resized ``(image, mask)``.
    """
    image = cv.resize(image, (size, size), interpolation=cv.INTER_LINEAR)
    mask = cv.resize(mask, (size, size), interpolation=cv.INTER_NEAREST)
    return image, mask

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def make_train_val_loaders(image_directory, mask_directory, batch_size=32,
                           train_ratio=0.8, seed=None):
    """Build shuffled-train / ordered-validation DataLoaders over a CrackDataset.

    Args:
        image_directory (str): Folder with the input images.
        mask_directory (str): Folder with the corresponding masks.
        batch_size (int): Batch size of both loaders.
        train_ratio (float): Fraction of samples assigned to the training split.
            The remainder goes to validation.
        seed (int, optional): Seed for the random train/val split. If None,
            the split uses torch's global RNG (non-reproducible across runs).

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    dataset = CrackDataset(image_dir=image_directory, mask_dir=mask_directory,
                           transform=transform_image_and_mask)

    train_size = int(len(dataset) * train_ratio)
    val_size = len(dataset) - train_size

    # Only pass a generator when a seed is requested, so the default behaviour
    # (global RNG) is unchanged.
    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size], **split_kwargs)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    return train_loader, val_loader

In [None]:
# Encodes patch representation of an image batch
class Patch_Embedding(nn.Module):
    """Cut an image batch into non-overlapping square patches and embed each one.

    A single ``Conv2d`` whose kernel size and stride both equal ``patch_size``
    maps every ``patch_size x patch_size`` patch to an ``embed_dim`` vector.
    For an input of shape (batch, in_channels, H, W) the output therefore has
    shape (batch, embed_dim, H // patch_size, W // patch_size).
    """

    def __init__(self, embed_dim=64, patch_size=32, in_channels=3):
        super().__init__()
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.convolution = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed the patches of ``x`` (batch, channels, height, width).

        Raises:
            ValueError: If height or width is not a multiple of ``patch_size``.
        """
        height = x.shape[2]
        width = x.shape[3]
        step = self.patch_size
        if height % step or width % step:
            raise ValueError("Input dimensions must be divisible by the patch size.")
        return self.convolution(x)

embedding_dim = 64
# Placeholder list for patch embeddings; nothing in this notebook appends to it.
patch_information = []
patch_embedding = Patch_Embedding()

# Tensor for validating the entire process
test_tensor = torch.randn((2, 3, 512, 512))
output_patches = patch_embedding(test_tensor)
# (batch, embed_dim, h, w) -> (batch, h * w, embed_dim): one row per patch.
# A plain ``view(batch, -1, embed_dim)`` would instead chop the channel-major
# buffer into rows that mix channels and spatial positions.
output_patch_encoding = output_patches.flatten(2).transpose(1, 2)

output_patch_encoding.shape

torch.Size([2, 256, 64])

In [None]:
# Max-pooling layer that also reports where each maximum came from.
class MaxPoolWithIndices(nn.Module):
    """Thin wrapper around ``nn.MaxPool2d(return_indices=True)``.

    ``forward`` returns the tuple ``(pooled, indices)``. The indices can be
    handed to ``nn.MaxUnpool2d`` to scatter values back to their original
    positions.
    """

    def __init__(self, kernel_size, stride):
        super().__init__()
        self.kernel_size, self.stride = kernel_size, stride
        self.maxpool = nn.MaxPool2d(
            kernel_size=kernel_size,
            stride=stride,
            return_indices=True,
        )

    def forward(self, x):
        return tuple(self.maxpool(x))

In [None]:
# Encoder building block: conv -> batchnorm -> relu, optionally max-pooled.
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, optionally followed by a 2x2/stride-2 max-pool.

    ``forward`` always returns a pair ``(features, indices)``. ``indices`` holds
    the max-pool indices, or ``None`` when ``maxpool_flag`` is False.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, maxpool_flag=True):
        super().__init__()
        self.maxpool_flag = maxpool_flag
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                         stride=stride, padding=padding)
        self.operation_seq = nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.ReLU())
        if maxpool_flag:
            self.maxpool = MaxPoolWithIndices(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.operation_seq(x)
        if not self.maxpool_flag:
            return features, None
        pooled, indices = self.maxpool(features)
        return pooled, indices

In [None]:
def Downsample(input_tensor, conv_blocks=None):
    """Run the SegNet encoder and collect the max-pool indices for unpooling.

    Args:
        input_tensor: Image batch of shape (batch, channels, H, W). The default
            encoder expects 3 input channels.
        conv_blocks: Optional sequence of ConvBlock-like modules, each returning
            ``(features, indices_or_None)``. If omitted, the default encoder
            (3->8->16->32 with 2x pooling, then a stride-4 conv to 64 channels)
            is built fresh on every call and moved to ``input_tensor``'s device.
            NOTE(review): such per-call blocks belong to no nn.Module, so their
            weights are never trained; pass persistent blocks for training.

    Returns:
        tuple: ``(max_indices, bottleneck_input)``. ``max_indices`` maps
        ``"block<i>"`` (1-based position) to that block's pooling indices, for
        blocks that pool. ``bottleneck_input`` is the last block's output.

    Raises:
        ValueError: If H or W is not divisible by 8.
    """
    height, width = input_tensor.shape[-2], input_tensor.shape[-1]
    # Three 2x2 max-pools each halve the resolution. An odd intermediate size
    # would drop a row/column, and the indices could no longer be unpooled back
    # to the original size.
    if height % 8 or width % 8:
        raise ValueError("Input dimensions must be divisible by 8 for this downsampling pipeline.")

    if conv_blocks is None:
        conv_blocks = [
            ConvBlock(3, 8, 3, 1, 1),
            ConvBlock(8, 16, 3, 1, 1),
            ConvBlock(16, 32, 3, 1, 1),
            ConvBlock(32, 64, 3, 4, 1, False),
        ]
        # Fresh modules live on the CPU; match the input's device to avoid a
        # device-mismatch error for CUDA inputs.
        conv_blocks = [block.to(input_tensor.device) for block in conv_blocks]

    max_indices = {}
    features = input_tensor
    for position, conv_block in enumerate(conv_blocks, start=1):
        features, indices = conv_block(features)
        if indices is not None:
            max_indices[f"block{position}"] = indices

    return max_indices, features

In [None]:
# Attention weights to scale the patch embed representation
class AttentionWeights(nn.Module):
    """Learnable per-row scaling factors.

    Holds ``output_dim`` scalars (drawn from a standard normal at construction).
    ``forward`` multiplies row ``i`` along the second-to-last axis of ``x`` by
    ``alphas[i]``. For a (batch, num_patches, embed_dim) input, this scales each
    patch embedding by its own weight.
    """

    def __init__(self, output_dim):
        super().__init__()
        self.alphas = nn.Parameter(torch.randn(output_dim))

    def forward(self, x):
        # (output_dim,) -> (output_dim, 1) so each weight broadcasts across the
        # last axis of x.
        column_weights = self.alphas[:, None]
        return column_weights * x


In [None]:
# Takes in the tensor containing patch information of all images in the batch
# and applies a per-patch attention scaling to every image.
def attention_to_all_images(patch_emb_tensor, attention_dim):
    """Scale patch embeddings with a newly created AttentionWeights module.

    Args:
        patch_emb_tensor: Patch embeddings, expected (batch, attention_dim,
            embed_dim) so that the weights broadcast over the last axis.
        attention_dim (int): Number of attention weights (patches per image).

    Returns:
        Tensor of the same shape as ``patch_emb_tensor``.

    NOTE(review): the AttentionWeights module is created on every call, so its
    weights are freshly random and never learned. A model should own a
    persistent AttentionWeights instead.
    """
    # New modules are created on the CPU; move to the input's device so CUDA
    # inputs do not hit a device-mismatch error.
    att_weights = AttentionWeights(attention_dim).to(patch_emb_tensor.device)
    return att_weights(patch_emb_tensor)

In [None]:
# Fuses the patch embeddings into the encoder bottleneck by addition.
def add_bottleneck_and_patch_embeddings(bottleneck, patch_embedding):
    """Add patch embeddings to the bottleneck feature map.

    Args:
        bottleneck: Tensor of shape (batch, channels, h, w).
        patch_embedding: Rank-3 tensor, expected (batch, h * w, channels), i.e.
            one row per spatial position in row-major order.

    Returns:
        ``bottleneck`` plus the embeddings laid out as (batch, channels, h, w).
    """
    batch, channels, height, width = bottleneck.shape
    as_feature_map = patch_embedding.transpose(1, 2).view(batch, channels, height, width)
    return bottleneck + as_feature_map

In [None]:
class UpBlock(nn.Module):
    """ConvTranspose2d (bias-free) -> BatchNorm2d -> ReLU decoder stage.

    The transposed convolution has no bias because the BatchNorm2d that follows
    would cancel any constant per-channel offset. The trailing ReLU makes the
    output non-negative.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.operation_sequence = nn.Sequential(*layers)

    def forward(self, x):
        return self.operation_sequence(x)

In [None]:
class Unpool2d(nn.Module):
    """Max-unpooling: scatter pooled values back to the positions given by
    max-pool indices, filling every other position with zero."""

    def __init__(self, k_sz, stride):
        super().__init__()
        self.unpool2d = nn.MaxUnpool2d(kernel_size=k_sz, stride=stride)

    def forward(self, pooled, indices):
        unpooled = self.unpool2d(pooled, indices)
        return unpooled

In [None]:
def Upsample(input_tensor, max_indices, up_blocks=None):
    """Run the SegNet decoder, unpooling with the encoder's saved indices.

    ``up_blocks`` alternates learned stages (even positions, called as
    ``block(x)``) with unpooling stages (odd positions, called as
    ``block(x, indices)``). Unpooling consumes ``max_indices`` deepest-first:
    with the three encoder entries, that is block3, then block2, then block1.

    Default pipeline:
    UpBlock(64->32, k=4, s=4) -> unpool(block3) -> UpBlock(32->16) ->
    unpool(block2) -> UpBlock(16->8) -> unpool(block1) -> ConvTranspose2d(8->1).

    The last stage is a plain transposed convolution with no BatchNorm/ReLU.
    The loss applies a sigmoid to this output, and a ReLU there would clamp
    every logit to >= 0, i.e. every pixel probability to >= 0.5.

    Args:
        input_tensor: Bottleneck tensor, expected (batch, 64, h, w) for the
            default blocks.
        max_indices (dict): ``"block<i>"`` -> pooling indices, as returned by
            Downsample.
        up_blocks: Optional custom stage list. If omitted, the default blocks
            are built fresh on every call (untrained) and moved to
            ``input_tensor``'s device.

    Returns:
        Tensor of raw logits, (batch, 1, H, W) for the default blocks.
    """
    if up_blocks is None:
        up_blocks = [
            UpBlock(64, 32, kernel_size=4, stride=4, padding=0),
            Unpool2d(2, 2),
            UpBlock(32, 16, kernel_size=3, stride=1, padding=1),
            Unpool2d(2, 2),
            UpBlock(16, 8, kernel_size=3, stride=1, padding=1),
            Unpool2d(2, 2),
            nn.ConvTranspose2d(8, 1, kernel_size=3, stride=1, padding=1),
        ]
        # Fresh modules live on the CPU; match the input's device.
        up_blocks = [block.to(input_tensor.device) for block in up_blocks]

    deepest = len(max_indices)
    unpool_stage = 0

    for position, up_block in enumerate(up_blocks):
        if position % 2 == 0:
            input_tensor = up_block(input_tensor)
        else:
            max_index_key = f"block{deepest - unpool_stage}"
            input_tensor = up_block(input_tensor, max_indices[max_index_key])
            unpool_stage += 1

    return input_tensor


In [None]:
class weightedBCEDiceLoss(nn.Module):
    """Weighted sum of binary cross-entropy and soft Dice loss on logits.

    ``loss = weight_BCE * BCE + weight_Dice * (1 - Dice)``. Both terms are
    computed over all elements of the batch at once: predictions and targets
    are flattened, so (B, 1, H, W) logits may be paired with (B, H, W) masks.
    """

    def __init__(self, weight_BCE, weight_Dice, epsilon):
        """
        Args:
            weight_BCE (float): Weight of the BCE term.
            weight_Dice (float): Weight of the Dice term; must sum to 1 with
                ``weight_BCE``.
            epsilon (float): Smoothing constant added to the numerator and
                denominator of the Dice coefficient, avoiding 0/0 on empty
                masks. Keep it small (e.g. 1e-6): a large value pushes the
                Dice coefficient towards 1 whatever the prediction.

        Raises:
            ValueError: If the two weights do not sum to 1.
        """
        super().__init__()
        self.alpha = weight_BCE
        self.beta = weight_Dice
        self.eps = epsilon
        # Use a tolerance: sums like 0.1 + 0.2 are not exact in floating point.
        if abs(self.alpha + self.beta - 1.0) > 1e-9:
            raise ValueError("weights of BCE loss and Dice Loss must sum to 1")

    def forward(self, preds, targets):
        """Compute the loss from raw logits ``preds`` and {0, 1} ``targets``.

        ``targets`` may be an integer tensor; it is cast to the logits' dtype.
        """
        logits = preds.reshape(-1)
        targets = targets.reshape(-1).to(logits.dtype)

        # The logits form fuses the sigmoid into the log for numerical
        # stability (no log(0) for saturated probabilities).
        bce_loss = F.binary_cross_entropy_with_logits(logits, targets)

        probs = torch.sigmoid(logits)
        intersection = torch.sum(probs * targets)
        union = torch.sum(probs) + torch.sum(targets)
        dice_coeff = (2 * intersection + self.eps) / (union + self.eps)
        dice_loss = 1 - dice_coeff

        return self.alpha * bce_loss + self.beta * dice_loss


In [None]:
# Smoke test: random logits scored against a random binary mask of the same shape.
pred = torch.randn(2, 1, 256, 256)
target = torch.randint(0, 2, (2, 1, 256, 256)).float()

loss_fn = weightedBCEDiceLoss(weight_BCE=0.6, weight_Dice=0.4, epsilon=0.001)
loss_fn(pred, target)

tensor(0.6630)

In [None]:
config = {
    # Patch-embedding branch. embed_dim must match the channel count of the
    # encoder bottleneck, because the two are added element-wise.
    "embed_dim": 64,
    "patch_size": 32,
    "in_channels": 3,
    # Loss: the two weights must sum to 1. epsilon is the Dice smoothing term
    # and must be small, otherwise the Dice loss is ~0 for any prediction.
    "weight_bce": 0.5,
    "weight_dice": 0.5,
    "epsilon_loss_fn": 1e-6,
    # NOTE(review): not read anywhere in this notebook; the DataLoaders' batch
    # size is set in make_train_val_loaders.
    "train_batch_size": 64,
    "train_epochs": 100,
    "lr": 0.001,
    "ckpt_name": 'att_segnet.pth',
    "model_save_path": "/content/saved_models",
}

In [None]:
class AttentiveSegNet(nn.Module):
    """SegNet-style encoder/decoder whose bottleneck is fused with
    attention-scaled patch embeddings.

    All sub-modules are created once in ``__init__`` and registered on the
    model. They are therefore moved by ``.to(device)``, returned by
    ``parameters()`` (and so trained), and stored in ``state_dict()``.

    Config keys read: ``embed_dim``, ``patch_size``, ``in_channels`` and,
    optionally, ``image_size`` (default 512). ``image_size`` is the side length
    of the square inputs and fixes the number of patches, hence the number of
    attention weights.

    The encoder reduces resolution by 32 (three 2x max-pools plus a stride-4
    convolution). For the bottleneck grid to line up with the patch grid,
    ``patch_size`` is expected to be 32 and the input side a multiple of 32.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config["embed_dim"]
        patch_size = config["patch_size"]
        in_channels = config["in_channels"]
        image_size = config.get("image_size", 512)

        self.embed_dim = embed_dim
        self.patch_generator = Patch_Embedding(embed_dim, patch_size, in_channels)

        # One learnable scale per patch.
        self.num_patches = (image_size // patch_size) ** 2
        self.attention = AttentionWeights(self.num_patches)

        # Encoder: three conv+pool stages (their indices are kept for
        # unpooling), then a stride-4 conv to an embed_dim-channel bottleneck.
        self.encoder = nn.ModuleList([
            ConvBlock(in_channels, 8, 3, 1, 1),
            ConvBlock(8, 16, 3, 1, 1),
            ConvBlock(16, 32, 3, 1, 1),
            ConvBlock(32, embed_dim, 3, 4, 1, False),
        ])

        # Decoder mirrors the encoder.
        self.bottleneck_up = UpBlock(embed_dim, 32, kernel_size=4, stride=4, padding=0)
        self.unpool = Unpool2d(2, 2)
        self.decoder = nn.ModuleList([
            UpBlock(32, 16, kernel_size=3, stride=1, padding=1),
            UpBlock(16, 8, kernel_size=3, stride=1, padding=1),
            # Raw logits (no BatchNorm/ReLU): the loss applies the sigmoid,
            # and a ReLU here would forbid predicting background.
            nn.ConvTranspose2d(8, 1, kernel_size=3, stride=1, padding=1),
        ])

    def forward(self, x):
        """Return per-pixel logits of shape (batch, 1, H, W) for images ``x``.

        Raises:
            ValueError: If the number of patches in ``x`` differs from the
                number of attention weights (i.e. the input size does not match
                ``config['image_size']``).
        """
        # Patch branch: (B, E, h, w) -> (B, h*w, E), one row per patch.
        patches = self.patch_generator(x)
        patch_tokens = patches.flatten(2).transpose(1, 2)
        if patch_tokens.shape[1] != self.num_patches:
            raise ValueError(
                f"Expected {self.num_patches} patches but got {patch_tokens.shape[1]}; "
                "set config['image_size'] to the side length of the input images."
            )
        scaled_embeddings = self.attention(patch_tokens)

        # Encoder, remembering pooling indices shallowest-first.
        pool_indices = []
        features = x
        for block in self.encoder:
            features, indices = block(features)
            if indices is not None:
                pool_indices.append(indices)

        # Fuse patch information with the downsampled features.
        fused = add_bottleneck_and_patch_embeddings(features, scaled_embeddings)

        # Decoder: unpool with the deepest indices first, then refine.
        out = self.bottleneck_up(fused)
        for indices, stage in zip(reversed(pool_indices), self.decoder):
            out = self.unpool(out, indices)
            out = stage(out)

        return out


In [None]:
def train_for_one_epoch(epoch_idx, model, train_loader, optimizer, config):
    r"""
    Method to run the training for one epoch.
    :param epoch_idx: iteration number of current epoch (0-based; printed 1-based)
    :param model: Attentive SegNet model
    :param train_loader: Dataloader yielding (image, mask) batches
    :param optimizer: optimizer to be used
    :param config: config dictionary with the loss settings
        (weight_bce, weight_dice, epsilon_loss_fn)
    :return: mean training loss over the epoch's batches
    :raises ValueError: if train_loader yields no batches
    """
    # Enable training behaviour (BatchNorm batch statistics).
    model.train()
    criterion = weightedBCEDiceLoss(config["weight_bce"], config["weight_dice"], config["epsilon_loss_fn"])
    losses = []

    for im, mask in tqdm(train_loader):
        im = im.to(device)
        # The loss expects floating-point targets; the dataset yields int64 masks.
        mask = mask.to(device).float()

        optimizer.zero_grad()
        loss = criterion(model(im), mask)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    if not losses:
        raise ValueError("train_loader yielded no batches")

    mean_loss = float(np.mean(losses))
    print('Finished epoch: {} | Mean Loss : {:.4f}'.format(epoch_idx + 1, mean_loss))
    return mean_loss

def train(train_loader, config):
    """Train an AttentiveSegNet, resuming from a saved checkpoint if present.

    The checkpoint at ``model_save_path/ckpt_name`` is overwritten whenever the
    epoch's mean loss improves on the best loss seen in this call.

    Args:
        train_loader: DataLoader yielding (image, mask) batches.
        config (dict): Needs ``train_epochs``, ``lr``, ``model_save_path``,
            ``ckpt_name`` plus the model and loss keys.

    Returns:
        The model as it stands after the last epoch.
    """
    model = AttentiveSegNet(config).to(device)
    optimizer = Adam(model.parameters(), lr=config['lr'])
    # Halve the learning rate after 2 epochs without loss improvement.
    scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=2)

    os.makedirs(config['model_save_path'], exist_ok=True)
    ckpt_path = os.path.join(config['model_save_path'], config['ckpt_name'])

    if os.path.exists(ckpt_path):
        print('Loading checkpoint')
        model.load_state_dict(torch.load(ckpt_path, map_location=device))

    best_loss = np.inf
    for epoch_idx in range(config['train_epochs']):
        mean_loss = train_for_one_epoch(epoch_idx, model, train_loader, optimizer, config)

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(mean_loss)
        lr_after = optimizer.param_groups[0]['lr']
        # Report LR reductions explicitly (the scheduler's `verbose` flag is deprecated).
        if lr_after != lr_before:
            print('Reducing learning rate to {:.2e}'.format(lr_after))

        # Updates checkpoint if better model params found
        if mean_loss < best_loss:
            print('Improved Loss to {:.4f} .... Saving Model'.format(mean_loss))
            torch.save(model.state_dict(), ckpt_path)
            best_loss = mean_loss
        else:
            print('No Loss Improvement')

    return model

In [None]:
def get_segmentation_accuracy(preds, targets):
    """Return the fraction of elements at which ``preds`` equals ``targets``.

    Both tensors are flattened, so (B, 1, H, W) predictions can be compared
    with (B, H, W) masks. ``preds`` must already be hard labels (e.g.
    thresholded logits); raw logits will almost never equal integer labels.

    Raises:
        ValueError: If the tensors hold a different number of elements, or
            none at all.
    """
    preds = preds.reshape(-1)
    targets = targets.reshape(-1)

    if preds.numel() != targets.numel():
        raise ValueError(
            f"preds has {preds.numel()} elements but targets has {targets.numel()}"
        )
    if preds.numel() == 0:
        raise ValueError("cannot compute accuracy of empty tensors")

    matches = (preds == targets).sum().item()
    return matches / preds.numel()

def inference(val_loader, config):
    """Evaluate the checkpointed model on ``val_loader`` and print pixel accuracy.

    If no checkpoint exists at ``model_save_path/ckpt_name``, a freshly
    initialised model is evaluated instead (a message is printed).

    Args:
        val_loader: DataLoader yielding (image, mask) batches.
        config (dict): Model config plus ``model_save_path`` and ``ckpt_name``.

    Returns:
        float: Mean of the per-batch pixel accuracies (nan if no batches).
    """
    model = AttentiveSegNet(config).to(device)
    ckpt_path = os.path.join(config['model_save_path'], config['ckpt_name'])

    if os.path.exists(ckpt_path):
        print('Loading checkpoint')
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    else:
        print('No checkpoint found at {}'.format(ckpt_path))

    model.eval()
    accuracies = []

    # No gradients are needed for evaluation; this saves memory and time.
    with torch.no_grad():
        for idx, (im, mask) in enumerate(tqdm(val_loader)):
            im = im.to(device)
            mask = mask.to(device)
            logits = model(im)
            # A logit > 0 is a sigmoid probability > 0.5 -> foreground (1).
            preds = (logits > 0).to(mask.dtype)

            accuracy = get_segmentation_accuracy(preds, mask)
            accuracies.append(accuracy)
            print(f"Accuracy in batch {idx} during validation is: {accuracy}")

    return float(np.mean(accuracies)) if accuracies else float('nan')

In [None]:
# Fill in the folders holding the input images and their masks before running this cell.
image_directory = ""
mask_directory = ""
train_loader, val_loader = make_train_val_loaders(
    image_directory=image_directory,
    mask_directory=mask_directory,
)

In [None]:
# Train (checkpointing on loss improvement), then evaluate the saved checkpoint on the validation split.
train(train_loader=train_loader, config=config)
inference(val_loader=val_loader, config=config)