# Advanced Computer Vision - Week_03 - Image segmentation

In [34]:
import os
import sys
import time
import torch
import wandb
import numpy as np
from torch import nn
from torch.utils.data import Dataset
import pandas as pd
from torch.utils.data import Dataset
from torchvision import transforms
from sklearn.model_selection import KFold
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt


## Loss Functions: Binary Cross-Entropy, Jaccard Loss and Dice Loss
### Binary Cross-Entropy

Binary Cross-Entropy (BCE) loss is a commonly used loss function for binary classification tasks. It measures the dissimilarity between two probability distributions, typically the predicted probabilities and the true labels.

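The BCE loss is calculated as follows:

1. If the model outputs raw scores (logits), the sigmoid activation is applied first so that the predicted values lie in the range [0, 1] and can be interpreted as probabilities.
2. The Binary Cross-Entropy between the predicted probabilities $p_i$ and the true binary labels $y_i \in \{0, 1\}$ is then averaged over all $N$ elements (pixels):

$$\mathcal{L}_{\text{BCE}} = -\frac{1}{N}\sum_{i=1}^{N}\left[y_i \log p_i + (1 - y_i)\log(1 - p_i)\right]$$
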
BCE loss penalizes the model based on the difference between predicted probabilities and true labels. It encourages the model to output probabilities close to 1 for positive examples and close to 0 for negative examples.

### Jaccard Loss:
The Jaccard index, also known as the Intersection over Union (IoU), is a measure of the similarity between two sets. In the context of segmentation tasks (such as image segmentation), the Jaccard index quantifies the overlap between the predicted segmentation and the ground truth segmentation.

Jaccard loss is a loss function derived from the Jaccard index. The index is the ratio of the size of the intersection of the predicted set $P$ and the ground-truth set $T$ to the size of their union; the loss turns this similarity into a dissimilarity:

$$\mathcal{L}_{\text{Jaccard}} = 1 - \frac{|P \cap T|}{|P \cup T|}$$


Where:
- **Intersection**: Number of common elements between the predicted segmentation and the ground truth segmentation.
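- **Union**: Number of elements that belong to the predicted segmentation, the ground truth segmentation, or both (each element counted once), i.e. $|P \cup T| = |P| + |T| - |P \cap T|$.

For soft (probabilistic) predictions $p_i$ and labels $y_i$, the intersection is computed as $\sum_i p_i y_i$ and the set sizes as $\sum_i p_i$ and $\sum_i y_i$, which keeps the loss differentiable.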

The Jaccard loss penalizes the model when the predicted segmentation deviates from the ground truth segmentation, encouraging the model to produce segmentations with higher overlap with the ground truth.

### Dice Loss:

The Dice coefficient, also known as the Sørensen-Dice coefficient, is another measure of the similarity between two sets. Like the Jaccard index, it is commonly used in segmentation tasks to evaluate the overlap between the predicted segmentation and the ground truth segmentation.

Dice loss is a loss function derived from the Dice coefficient. It is calculated as:
$$\mathcal{L}_{\text{Dice}} = 1 - \frac{2\,|P \cap T| + \varepsilon}{|P| + |T| + \varepsilon}$$

Where:
- **Intersection**: Number of common elements between the predicted segmentation and the ground truth segmentation.
- **Total Predicted**: Total number of elements in the predicted segmentation.
- **Total Targets**: Total number of elements in the ground truth segmentation.
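- **ε**: A small smoothing constant added to both the numerator and the denominator. It avoids division by zero and makes two empty segmentations count as a perfect match (loss 0).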

Similar to the Jaccard loss, the Dice loss encourages the model to produce segmentations with higher overlap with the ground truth. It penalizes deviations between the predicted and ground truth segmentations by measuring the dissimilarity between them.

Both Jaccard loss and Dice loss are commonly used as loss functions in tasks involving segmentation, such as medical image analysis, object detection, and scene understanding. They are effective in guiding the training process towards producing accurate segmentations.



In [36]:
import torch
import torch.nn as nn

class JaccardLoss(nn.Module):
    """Soft Jaccard (IoU) loss: ``1 - (|P ∩ T| + eps) / (|P ∪ T| + eps)``.

    The soft set sizes are ``|P ∩ T| = sum(P * T)`` and
    ``|P ∪ T| = sum(P) + sum(T) - sum(P * T)``, summed over *all* elements of
    the inputs, so a whole batch is pooled into a single score. The soft form
    keeps the loss differentiable for probabilistic predictions (presumably
    values in [0, 1], e.g. sigmoid outputs — confirm at the call site).
    """

    def __init__(self, eps=1e-5):
        """
        Initializes the JaccardLoss module.

        Args:
            eps (float): Smoothing constant added to numerator and denominator.
                It avoids division by zero and makes two empty masks a perfect
                match (loss 0). Defaults to 1e-5.
        """
        super(JaccardLoss, self).__init__()
        self.eps = eps

    def forward(self, predictions, targets):
        """
        Calculates the Jaccard loss between predictions and targets.

        Args:
            predictions (torch.Tensor): Predicted values.
            targets (torch.Tensor): Target values, broadcastable to ``predictions``.

        Returns:
            torch.Tensor: Scalar Jaccard loss.
        """
        intersection = torch.sum(predictions * targets)
        # |P ∪ T| = |P| + |T| - |P ∩ T|; summing separately avoids allocating
        # the temporary tensor ``predictions + targets``.
        union = torch.sum(predictions) + torch.sum(targets) - intersection
        jaccard = (intersection + self.eps) / (union + self.eps)
        return 1 - jaccard


class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - (2 * |P ∩ T| + eps) / (|P| + |T| + eps)``.

    The soft set sizes are ``|P ∩ T| = sum(P * T)``, ``|P| = sum(P)`` and
    ``|T| = sum(T)``, summed over *all* elements of the inputs, so a whole
    batch is pooled into a single score. Predictions are presumably values in
    [0, 1] (e.g. sigmoid outputs) — confirm at the call site.
    """

    def __init__(self, eps=1e-5):
        """
        Initializes the DiceLoss module.

        Args:
            eps (float): Smoothing constant added to numerator and denominator.
                It avoids division by zero and makes two empty masks a perfect
                match (loss 0). Defaults to 1e-5.
        """
        super(DiceLoss, self).__init__()
        self.eps = eps

    def forward(self, predictions, targets):
        """
        Calculates the Dice loss between predictions and targets.

        Args:
            predictions (torch.Tensor): Predicted values.
            targets (torch.Tensor): Target values, broadcastable to ``predictions``.

        Returns:
            torch.Tensor: Scalar Dice loss.
        """
        intersection = torch.sum(predictions * targets)
        total = torch.sum(predictions) + torch.sum(targets)
        dice_coefficient = (2.0 * intersection + self.eps) / (total + self.eps)
        return 1 - dice_coefficient

class BinaryCrossEntropyLoss(nn.Module):
    """Binary cross-entropy for raw logits (sigmoid is applied internally).

    Uses ``binary_cross_entropy_with_logits``, which fuses the sigmoid and the
    log into one numerically stable expression. The previous
    ``sigmoid`` + ``nn.BCELoss`` pair saturated for large logits, because
    BCELoss clamps its log terms at -100. For example, a logit of 200 with
    target 0 gave a loss of 100 instead of 200.

    Note: do NOT feed it outputs that already went through a sigmoid; use
    ``nn.BCELoss`` for probabilities.
    """

    def __init__(self):
        """
        Initializes the BinaryCrossEntropyLoss module.
        """
        super(BinaryCrossEntropyLoss, self).__init__()

    def forward(self, predictions, targets):
        """
        Calculates the mean Binary Cross-Entropy loss between logits and targets.

        Args:
            predictions (torch.Tensor): Raw, unbounded predicted scores (logits).
            targets (torch.Tensor): Float target values in [0, 1], same shape as
                ``predictions``.

        Returns:
            torch.Tensor: Scalar Binary Cross-Entropy loss (mean reduction).
        """
        return torch.nn.functional.binary_cross_entropy_with_logits(predictions, targets)


In [37]:
class CustomDataset(Dataset):
    """Segmentation dataset of (frame, mask) image pairs listed in a CSV file.

    Rows are filtered by ``polygon_label`` according to ``tp`` and then split
    into 5 cross-validation folds. The split is made over *video ids*
    (``video_id_x``), not over individual rows, so all frames of one video end
    up on the same side of the train/validation boundary. Splitting rows would
    put frames of the same video into both sets and leak information into the
    validation scores.

    Args:
        csv_file: Path to the CSV index. It must contain the columns
            ``video_id_x``, ``frame_cropped_path``, ``mask_cropped_path`` and
            ``polygon_label``. Image paths are resolved relative to the
            module-level ``PATH``.
        fold: Index of the fold used as validation set. If no fold has this
            index, no split is applied and all filtered rows are kept.
        transform: Optional callable applied to both the frame and the mask.
        train: If True keep the training folds, otherwise the validation fold.
        tp: Structure type. 0 selects ``lungslidingpresent`` /
            ``lungslidingabsent`` (pleura), 1 selects ``aline`` and 2 selects
            ``bline``. Any other value keeps every row.
    """

    def __init__(self, csv_file, fold, transform=None, train=True, tp=0):
        columns = ["video_id_x", "frame_cropped_path", "mask_cropped_path", "polygon_label"]
        frame = pd.read_csv(csv_file)[columns]
        if tp == 0:
            frame = frame[frame["polygon_label"].isin(["lungslidingpresent", "lungslidingabsent"])]
        elif tp == 1:
            frame = frame[frame["polygon_label"] == "aline"]
        elif tp == 2:
            frame = frame[frame["polygon_label"] == "bline"]
        self.transform = transform

        # Fold over unique video ids, then keep every row of the selected videos.
        video_ids = frame["video_id_x"].unique()
        kf = KFold(n_splits=5, shuffle=True, random_state=42)
        for fold_num, (train_index, val_index) in enumerate(kf.split(video_ids)):
            if fold_num == fold:
                selected = video_ids[train_index if train else val_index]
                frame = frame[frame["video_id_x"].isin(selected)]
                break
        self.data_frame = frame

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Return ``[image, mask]`` for row ``idx`` (transformed if a transform is set)."""
        row = self.data_frame.iloc[idx]
        image = Image.open(f"{PATH}/" + row["frame_cropped_path"])
        mask = Image.open(f"{PATH}/" + row["mask_cropped_path"])

        if self.transform:
            # NOTE(review): the same transform is applied to the mask. A Resize
            # using the default bilinear interpolation makes the mask non-binary;
            # consider nearest-neighbour resizing for masks.
            image = self.transform(image)
            mask = self.transform(mask)

        return [image, mask]

In [38]:
PATH = "/Users/hlibokymaros/Documents/_datasets/APVV_Lung/revision_8"


def getDataset(batch_size=256, workers=16, fold=0, tp=0):
    """Build the train/validation DataLoaders for one CV fold and structure type.

    Args:
        batch_size: Samples per batch for both loaders.
        workers: ``num_workers`` for both loaders.
        fold: Cross-validation fold used as the validation split.
        tp: Structure type forwarded to ``CustomDataset``.

    Returns:
        dict: ``{"train": <shuffled DataLoader>, "val": <unshuffled DataLoader>}``.
    """
    # One transform shared by frames and masks of both splits: resize to
    # 512x512, then convert the PIL image to a tensor.
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor()
    ])

    csv_file = f"{PATH}/frames_label_full_final_all.csv"
    trainset = CustomDataset(csv_file, fold, transform=transform, tp=tp)
    testset = CustomDataset(csv_file, fold, transform=transform, train=False, tp=tp)

    # Training loader: reshuffled every epoch.
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, num_workers=workers, shuffle=True)
    # Validation loader: fixed order.
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=workers)
    return {"train": trainloader, "val": testloader}


In [39]:
class UNet(nn.Module):
    """Four-level U-Net mapping a 1-channel image to a 1-channel probability map.

    Encoder: four double-conv stages (1 -> 64 -> 128 -> 256 -> 512 channels),
    with 2x2 max pooling between them. Bottleneck: double conv up to 1024
    channels, then a transposed conv back down to 512 channels at twice the
    resolution. Each decoder stage first concatenates the matching encoder
    feature map (skip connection). The final stage projects to one channel
    and applies a sigmoid.

    Height and width of the input must be divisible by 16, otherwise the
    upsampled maps do not line up with the skip connections.
    """

    def __init__(self):
        super().__init__()

        # Encoder stages, registered as encoder1 .. encoder4 (in this order).
        encoder_channels = [(1, 64), (64, 128), (128, 256), (256, 512)]
        for level, (c_in, c_out) in enumerate(encoder_channels, start=1):
            setattr(self, f"encoder{level}", self.conv_block(c_in, c_out))

        # Bottleneck plus the pooling layer shared by every down-sampling step.
        self.middle = self.deconv_block(512, 1024)
        self.maxpool = nn.MaxPool2d(kernel_size=2)

        # Decoder stages decoder4 .. decoder2 upsample; decoder1 emits the mask.
        for level, (c_in, c_out) in ((4, (1024, 512)), (3, (512, 256)), (2, (256, 128))):
            setattr(self, f"decoder{level}", self.deconv_block(c_in, c_out))
        self.decoder1 = self.last_deconv_block(128, 64)

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Layers of two 3x3 'same' convolutions, each followed by an in-place ReLU."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding="same"),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding="same"),
            nn.ReLU(inplace=True),
        ]

    def conv_block(self, in_channels, out_channels):
        """Resolution-preserving double convolution."""
        return nn.Sequential(*self._double_conv(in_channels, out_channels))

    def deconv_block(self, in_channels, out_channels):
        """Double convolution, then a 2x upsampling transposed conv that halves the channels."""
        layers = self._double_conv(in_channels, out_channels)
        layers.append(nn.ConvTranspose2d(out_channels, out_channels // 2, kernel_size=2, stride=2))
        return nn.Sequential(*layers)

    def last_deconv_block(self, in_channels, out_channels):
        """Double convolution, then a 1x1 projection to one channel and a sigmoid."""
        layers = self._double_conv(in_channels, out_channels)
        layers += [nn.Conv2d(out_channels, 1, kernel_size=1), nn.Sigmoid()]
        return nn.Sequential(*layers)

    def forward(self, x):
        encoders = (self.encoder1, self.encoder2, self.encoder3, self.encoder4)
        decoders = (self.decoder4, self.decoder3, self.decoder2, self.decoder1)

        # Contracting path: keep every stage output for the skip connections.
        skips = []
        features = x
        for depth, encoder in enumerate(encoders):
            if depth > 0:
                features = self.maxpool(features)
            features = encoder(features)
            skips.append(features)

        features = self.middle(self.maxpool(features))

        # Expanding path: deepest skip first, concatenated before the decoder stage.
        for skip, decoder in zip(reversed(skips), decoders):
            features = decoder(torch.cat([skip, features], dim=1))
        return features


In [44]:
# Compute device for the model and all batches. CPU is selected explicitly;
# alternatives for machines with an accelerator:
#   torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#   torch.device('mps')
device = torch.device("cpu")

In [45]:
# Run configuration (inlined body of the former
# `def train(device, lr=0.0001, fold=0, tp=0, loss_type=0)`).
lr = 0.0001
fold = 0        # cross-validation fold used as the validation split
tp = 0          # structure type: 0 = Pleura, 1 = Aline, 2 = Bline
loss_type = 0   # training loss: 0 = BCE, 1 = Jaccard (IoU), 2 = Dice
# wandb.login()
batch_size = 2**4

# Pass everything by keyword: getDataset's positional order is
# (batch_size, workers, fold, tp), so the old call getDataset(batch_size, fold, tp)
# silently used `fold` as the worker count and `tp` as the fold.
# workers=0 (load in the main process) is what that call effectively did for fold=0.
loaders = getDataset(batch_size=batch_size, workers=0, fold=fold, tp=tp)

model = UNet().to(device)
# model = smp.Unet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Exponential learning-rate decay: lr <- lr * 0.98 on every lr_decay.step().
lr_decay = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.98)

# UNet already ends with a sigmoid, so the plain (non-logits) BCELoss matches its output.
bce = nn.BCELoss().to(device)
iou = JaccardLoss().to(device)
dice = DiceLoss().to(device)

In [46]:
# Training criterion; the same index selects the validation metric that drives
# checkpointing and early stopping below.
criteria = {0: bce, 1: iou, 2: dice}
if loss_type not in criteria:
    raise ValueError(f"loss_type must be 0 (BCE), 1 (IoU) or 2 (Dice), got {loss_type!r}")
criterion = criteria[loss_type]
val_loss_min = np.inf  # best validation loss so far (checkpoint threshold)
patience = 10
counter = patience  # epochs left without improvement before early stopping

types = ["Pleura", "Aline", "Bline"]
losses = ["BCE", "IOU", "DICE"]

# wandb.init(
#     project="Lung",
#
#     config={
#         "learning_rate": lr,
#         "Model": "Unet",
#         "fold": fold,
#         "type": types[tp],
#         "loss_type": losses[loss_type]
#     }
# )

for epoch in range(100):
    start = time.time()

    # --- training pass ---
    model.train()
    train_losses = []
    for inputs, labels in tqdm(loaders["train"]):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)  # per-pixel probabilities from the UNet's final sigmoid
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
    lr_decay.step()  # lr = lr * gamma, once per epoch
    ls = sum(train_losses) / len(train_losses)  # mean training loss of the epoch

    # --- validation pass: all three metrics, averaged over batches ---
    bce_val, iou_val, dice_val = [], [], []
    model.eval()
    with torch.no_grad():
        for inputs, labels in loaders["val"]:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            bce_val.append(bce(outputs, labels).item())
            iou_val.append(iou(outputs, labels).item())
            dice_val.append(dice(outputs, labels).item())
    bce_val = sum(bce_val) / len(bce_val)
    iou_val = sum(iou_val) / len(iou_val)
    dice_val = sum(dice_val) / len(dice_val)
    val_loss = (bce_val, iou_val, dice_val)[loss_type]

    # Checkpoint on improvement and reset the early-stopping budget.
    if val_loss < val_loss_min:
        counter = patience
        val_loss_min = val_loss
        torch.save(model.state_dict(), f"{fold}_{tp}_{loss_type}.mo")
    print(
        "Epoch %d, train_loss %4.4f, bce_val %4.4f, iou_val %4.4f, dice_val %4.4f, time %4.2f" % (
            epoch, ls, bce_val, iou_val, dice_val, time.time() - start))
    # wandb.log({"train_loss": ls, "val_loss": val_loss, "bce_val": bce_val, "iou_val": iou_val, "dice_val": dice_val, "time": time.time() - start})
    if counter == 0:
        break
    counter -= 1
# wandb.config["final_loss"] = val_loss_min
# wandb.finish()

  2%|▏         | 7/396 [09:16<8:35:02, 79.44s/it]


KeyboardInterrupt: 

In [None]:
if __name__ == '__main__':
    # Usage: python <script> <gpu_index>. The index selects the CUDA device.
    # NOTE(review): inside Jupyter, __name__ is also '__main__' and sys.argv[1]
    # is a kernel flag, so this cell fails there; run it as a script.
    cd = int(sys.argv[1])  # only referenced by the disabled sharding filter below
    device = "cuda:" + sys.argv[1]
    # Sweep: 5 CV folds x 3 structure types x 3 loss functions.
    for i in range(5):
        for j in range(3):
            for k in range(3):
                # NOTE(review): the disabled filter `i*9+j*3+k % 2 == cd` parses as
                # i*9 + j*3 + (k % 2); the intent was likely `(i*9 + j*3 + k) % 2 == cd`.
                # if (i * 9 + j * 3 + k) % 2 == cd:
                seed = 42
                torch.manual_seed(seed)
                torch.cuda.manual_seed(seed)
                torch.backends.cudnn.deterministic = True
                torch.backends.cudnn.benchmark = False
                np.random.seed(seed)
                # pretrained = int(sys.argv[1]) > 0
                # `train(device, lr=..., fold=..., tp=..., loss_type=...)` takes the
                # device as its first, required argument; it must be defined
                # before this runs (its `def` line is commented out in the setup cell).
                train(device, fold=i, tp=j, loss_type=k)