In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import DataLoader, Dataset

# Custom Dataset


In [None]:
class image_shadow_dataset(Dataset):
    """Paired image / shadow-mask dataset read from two directories.

    Every regular file in ``image_dir`` is one sample. Its mask is looked up
    in ``mask_dir`` under the name ``<image stem>_shadow.png``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the ``*_shadow.png`` masks.
        transform: Optional callable invoked as
            ``transform(image=image, mask=mask)`` that returns a mapping with
            ``"image"`` and ``"mask"`` entries.
    """

    def __init__(self, image_dir, mask_dir, transform = None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order and also lists
        # sub-directories (e.g. a stray ``.ipynb_checkpoints``). Keep regular
        # files only and sort them so an index always maps to the same sample.
        self.images = sorted(
            entry for entry in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, entry))
        )

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, mask)``.

        Without a transform the image is an RGB ``numpy`` array and the mask a
        ``float32`` grayscale array in which the value 255 has been replaced
        by 1.0. With a transform, whatever it returns under ``"image"`` and
        ``"mask"`` is passed through unchanged.
        """
        img_name = self.images[index]
        img_path = os.path.join(self.image_dir, img_name)
        base_name = os.path.splitext(img_name)[0]
        mask_path = os.path.join(self.mask_dir, base_name + '_shadow.png')

        # Context managers close the underlying files as soon as the pixel
        # data has been copied into numpy arrays.
        with Image.open(img_path) as img_file:
            image = np.array(img_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype = np.float32)
        # Only the exact value 255 is mapped to 1.0; any other gray level is
        # left as it is.
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image = image, mask = mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Normalise with fixed per-channel mean / std, then convert to a torch tensor.
# The original note says the statistics are calculated further down in the
# notebook; background on the approach:
# https://forums.fast.ai/t/image-normalization-in-pytorch/7534/7?u=laochanlam
transform_n = A.Compose(
    [
        A.Normalize(
            mean=[0.36184, 0.34747, 0.33529],
            std=[0.22139, 0.22265, 0.24461],
        ),
        ToTensorV2(),
    ]
)

In [None]:
BATCH_SIZE = 8
NUM_WORKERS = os.cpu_count()


def _make_loader(image_dir, mask_dir, shuffle):
    """Build the dataset for one split and wrap it in a DataLoader.

    Args:
        image_dir: Directory with the split's images.
        mask_dir: Directory with the split's shadow masks.
        shuffle: Whether the loader reshuffles the samples every epoch.

    Returns:
        ``(dataset, loader)`` for the split.
    """
    dataset = image_shadow_dataset(image_dir = image_dir,
                                   mask_dir = mask_dir,
                                   transform = transform_n)
    loader = DataLoader(dataset = dataset,
                        batch_size = BATCH_SIZE,
                        # Batches are loaded in the main process; NUM_WORKERS
                        # above is defined but not used by these loaders.
                        num_workers = 0,
                        shuffle = shuffle)
    return dataset, loader


# NOTE(review): train_dir / train_masks / val_dir / val_masks / test_dir /
# test_masks are not defined in the visible part of the notebook - they must
# exist before this cell runs.
# Only the training split is shuffled.
train_ds, train_data_loader = _make_loader(train_dir, train_masks, shuffle = True)
val_ds, val_data_loader = _make_loader(val_dir, val_masks, shuffle = False)
test_ds, test_data_loader = _make_loader(test_dir, test_masks, shuffle = False)

# Custom Model Classes


In [None]:
class ResNetEncoder(nn.Module):
    """ResNet feature extractor that returns every intermediate feature map.

    Args:
        backbone: Name of a torchvision model builder (an attribute of
            ``torchvision.models``), e.g. ``'resnet50'``. The network it
            builds must expose ``conv1``, ``bn1``, ``relu``, ``maxpool`` and
            ``layer1`` .. ``layer4``.
        pretrained: Forwarded to the builder's ``pretrained`` argument.
        channels: Channel counts of the returned feature maps. Stored on the
            instance for reference only; this class does not read it.

    Raises:
        ValueError: If ``backbone`` does not name a callable in
            ``torchvision.models``.
    """

    def __init__(self, backbone = 'resnet50', pretrained = True, channels = [64, 256, 512, 1024, 2048]):
        super(ResNetEncoder, self).__init__()
        # Honour the requested backbone / weights instead of always building
        # a pretrained resnet50.
        builder = getattr(models, backbone, None)
        if not callable(builder):
            raise ValueError(f"Unknown torchvision backbone: {backbone!r}")
        resnet = builder(pretrained = pretrained)
        # Copy so instances never share (or mutate) the default list.
        self.channels = list(channels)
        self.stem = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
        )
        self.maxpool = resnet.maxpool
        # The residual stages are registered both as attributes and inside
        # ``encoder_layers``; both registrations are kept so the state_dict
        # keys stay the same as before.
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        self.encoder_layers = nn.ModuleList([self.layer1,
                                             self.layer2,
                                             self.layer3,
                                             self.layer4])

    def forward(self, x):
        """Return ``[stem, layer1, layer2, layer3, layer4]`` feature maps.

        The stem output is recorded before max-pooling; each following entry
        is the output of the next residual stage.
        """
        skip_connections = []
        x = self.stem(x)
        skip_connections.append(x)
        x = self.maxpool(x)

        for stage in self.encoder_layers:
            x = stage(x)
            skip_connections.append(x)

        return skip_connections


In [None]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> batch norm -> ReLU stages.

    Both convolutions use stride 1, padding 1 and no bias. The first maps
    ``in_channels`` to ``out_channels``; the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Same module order as a hand-written Sequential, so parameter names
        # (conv.0, conv.1, conv.3, conv.4) are unchanged.
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels,
                                    kernel_size=3, stride=1, padding=1, bias=False))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        features = self.conv(x)
        return features

In [None]:
class Decoder(nn.Module):
    """U-Net style decoder operating on a list of encoder feature maps.

    ``self.decoder`` stores two modules per stage: an upsampling
    ``ConvTranspose2d`` (kernel 2, stride 2) at even positions and the block
    that processes its result at odd positions. ``self.use_skip`` says, per
    stage, whether an encoder feature map is concatenated before processing.

    Args:
        out_channels: Number of channels produced by the final 1x1 convolution.
        encoder_channels: Encoder channel counts. Only read in the stage that
            has no skip connection, as ``encoder_channels[-(idx+1)]``.
        decoder_channels: Output channels of each stage. Every stage's
            ``ConvTranspose2d`` expects ``decoder_channels[idx]*2`` inputs.

    Several sizes are hard-coded rather than derived from the arguments: the
    final ``ConvTranspose2d`` is 64 -> 64, the 1x1 convolution of the no-skip
    stage is 128 -> 256, and ``use_skip`` has exactly five entries.
    NOTE(review): non-default channel lists are therefore unlikely to work
    without also adjusting those constants - confirm before reusing.
    """
    def __init__(self, out_channels = 1, encoder_channels = [64, 256, 512, 1024, 2048],
                                         decoder_channels = [1024, 512, 256, 128, 64]):
        super(Decoder, self).__init__()
        self.decoder = nn.ModuleList()
        # Output head: one more 2x upsampling, then a 1x1 convolution that
        # maps decoder_channels[-1] channels to out_channels.
        self.final_conv = nn.Sequential(nn.ConvTranspose2d(in_channels =64,
                                         out_channels=64,
                                         kernel_size = 2,
                                         stride = 2),
                                        nn.Conv2d(decoder_channels[-1], out_channels, kernel_size=1))
        # One flag per stage; the fourth stage (index 3) runs without a skip.
        self.use_skip = [True, True,True,False, True]
        for idx in range(len(decoder_channels)):
            # Even slot: 2x upsampling that halves the channel count.
            self.decoder.append(
                nn.ConvTranspose2d(in_channels= decoder_channels[idx]*2,
                                   out_channels= decoder_channels[idx],
                                   kernel_size= 2,
                                   stride = 2)
            )
            # Odd slot: the block applied after upsampling.
            if self.use_skip[idx] == False:
                # No concatenation happens here, so a 1x1 convolution widens
                # the upsampled tensor (128 -> 256) before the DoubleConv.
                self.decoder.append(nn.Sequential(nn.Conv2d(in_channels = 128,
                                                                     out_channels = 256,
                                                                     kernel_size = 1,
                                                                     stride = 1),
                                                  DoubleConv(in_channels=   encoder_channels[-(idx+1)],
                                                             out_channels= decoder_channels[idx]),

                                                 )
                                   )
            else:
                # Input is the skip concatenated with the upsampled tensor,
                # hence twice decoder_channels[idx] channels.
                self.decoder.append(DoubleConv(in_channels=  decoder_channels[idx]*2,
                                               out_channels= decoder_channels[idx]))

    def forward(self, skip_connections):
        """Decode the encoder outputs into the final prediction map.

        Args:
            skip_connections: Sequence of encoder feature maps. The last
                element is the starting tensor; the remaining ones are used
                as skips in reverse order.
                (assumes 4-D ``(batch, channels, H, W)`` tensors ordered
                shallow to deep - TODO confirm against the encoder)

        Returns:
            The output of ``self.final_conv`` applied to the last stage.
        """
        x = skip_connections[-1]
        # Slicing creates a new list, so the insert below never modifies the
        # caller's sequence.
        skip_connections_reversed = skip_connections[:-1][::-1]

        # self.decoder holds (upsample, block) pairs, so step through it in twos.
        for i in range(0, len(self.decoder), 2):
            if self.use_skip[i//2]:
                x = self.decoder[i](x)
                skip = skip_connections_reversed[i//2]

                # If the shapes differ, resize x to the skip's spatial size so
                # the two tensors can be concatenated.
                # NOTE(review): this can also shrink x when the upsampled
                # tensor is larger than the skip - confirm that is intended.
                if x.shape != skip.shape:
                    x = TF.resize(x, size=skip.shape[2:])
                concat_skip = torch.cat((skip, x), dim=1)
                x = self.decoder[i+1](concat_skip)
            else:
                # Stage without a skip: upsample, then run the block directly.
                x = self.decoder[i](x)
                x = self.decoder[i+1](x)
                # Duplicate the skip this stage did not consume so that the
                # next stage (index i//2 + 1) reads it.
                skip_connections_reversed.insert(i//2, skip_connections_reversed[i//2])

        return self.final_conv(x)

In [None]:
class ResNet50UNet(nn.Module):
    """Thin wrapper that pipes an encoder's output straight into a decoder."""

    def __init__(self, encoder: nn.Module,
                       decoder: nn.Module):
        super().__init__()
        # Registered in this order so child names stay ("encoder", "decoder").
        self.encoder, self.decoder = encoder, decoder

    def forward(self, x):
        """Return ``decoder(encoder(x))``."""
        return self.decoder(self.encoder(x))


In [None]:
import torchvision.models as models
# Smoke test: push a random batch through a freshly built model and check
# that the prediction has the same height / width as the input.
x = torch.randn((3, 3, 512, 512))
model = ResNet50UNet(encoder = ResNetEncoder(),
                         decoder = Decoder())
# Only the output shape matters here, so skip building the autograd graph.
with torch.no_grad():
    preds = model(x)
assert preds.shape[2:] == x.shape[2:], (
    f"output spatial size {tuple(preds.shape[2:])} != input {tuple(x.shape[2:])}"
)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

# Model Training Functions

In [None]:
from surface_distance import metrics
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               scaler: torch.cuda.amp.GradScaler,
               device: torch.device):
    """Train ``model`` for one epoch and return the epoch's averaged metrics.

    Predictions are binarised with ``sigmoid(logits) > 0.5``.
    (assumes logits of shape ``(batch, 1, H, W)`` and masks of shape
    ``(batch, H, W)`` - TODO confirm against the dataset)

    Args:
        model: Network to optimise; switched to training mode.
        dataloader: Iterable of ``(X, y)`` batches.
        loss_fn: Called as ``loss_fn(logits, y.unsqueeze(1))``.
        optimizer: Optimiser stepped once per batch.
        scaler: Gradient scaler for mixed precision. When ``None``, autocast
            is disabled and a plain backward pass / optimiser step is used.
        device: Device the batches are moved to.

    Returns:
        ``(loss, accuracy, dice, iou, surface_dice)`` - the same order that
        ``validation_step`` / ``test_step`` return and ``train_loop`` unpacks.
        Loss, dice and IoU are means over batches, accuracy is the fraction
        of correctly classified pixels, and surface dice is a mean over
        samples. All values are 0.0 when the loader yields no batches.
    """
    surface_tolerance = 7.0  # tolerance_mm passed to the surface-dice metric

    model.train()

    loss_sum = 0.0
    dice_sum = 0.0
    iou_sum = 0.0
    surface_dice_sum = 0.0
    correct_pixels = 0
    total_pixels = 0
    num_batches = 0
    num_samples = 0

    use_amp = scaler is not None

    for X, y in dataloader:
        X = X.to(device).float()
        y = y.to(device).float()

        with torch.cuda.amp.autocast(enabled=use_amp):
            y_pred = model(X)
            loss = loss_fn(y_pred, y.unsqueeze(1))

        optimizer.zero_grad()
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        loss_sum += loss.item()
        num_batches += 1

        # Metrics use the predictions made before this batch's update and
        # never need gradients.
        with torch.no_grad():
            y_pred_class = (torch.sigmoid(y_pred.detach()) > 0.5).float().squeeze(1)
            pred_area = y_pred_class.sum().item()
            true_area = y.sum().item()
            intersect = (y_pred_class * y).sum().item()
            batch_correct = (y_pred_class == y).sum().item()

        # Dice and IoU are computed over the whole batch, then averaged over
        # batches.
        dice_sum += (2.0 * intersect) / (pred_area + true_area + 1e-8)
        iou_sum += intersect / (pred_area + true_area - intersect + 1e-8)

        # Surface dice is computed per sample on boolean numpy masks.
        y_np = y.cpu().numpy()
        y_pred_np = y_pred_class.cpu().numpy()
        for i in range(y_np.shape[0]):
            surface_distances = metrics.compute_surface_distances(
                y_np[i] > 0,
                y_pred_np[i] > 0,
                spacing_mm=(1.0, 1.0)
            )
            surface_dice_sum += metrics.compute_surface_dice_at_tolerance(
                surface_distances,
                tolerance_mm=surface_tolerance,
            )
        # Count the samples actually seen: dividing by
        # len(dataloader) * batch_size is wrong when the last batch is partial.
        num_samples += y_np.shape[0]

        correct_pixels += batch_correct
        total_pixels += y.numel()

    if num_batches == 0:
        return 0.0, 0.0, 0.0, 0.0, 0.0

    train_loss = loss_sum / num_batches
    train_acc = correct_pixels / total_pixels if total_pixels > 0 else 0.0
    train_dice = dice_sum / num_batches
    train_iou = iou_sum / num_batches
    train_surface_dice = surface_dice_sum / num_samples if num_samples > 0 else 0.0
    return train_loss, train_acc, train_dice, train_iou, train_surface_dice

In [None]:
def validation_step(model: torch.nn.Module,
                    dataloader: torch.utils.data.DataLoader,
                    loss_fn: torch.nn.Module,
                    device: torch.device):
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    Predictions are binarised with ``sigmoid(logits) > 0.5``.
    (assumes logits of shape ``(batch, 1, H, W)`` and masks of shape
    ``(batch, H, W)`` - TODO confirm against the dataset)

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Iterable of ``(X, y)`` batches.
        loss_fn: Called as ``loss_fn(logits, y.unsqueeze(1))``.
        device: Device the batches are moved to.

    Returns:
        ``(loss, accuracy, dice, iou, surface_dice)``. Loss, dice and IoU are
        means over batches, accuracy is the fraction of correctly classified
        pixels, and surface dice is a mean over samples. All values are 0.0
        when the loader yields no batches.
    """
    surface_tolerance = 7.0  # tolerance_mm passed to the surface-dice metric

    model.eval()

    loss_sum = 0.0
    dice_sum = 0.0
    iou_sum = 0.0
    surface_dice_sum = 0.0
    correct_pixels = 0
    total_pixels = 0
    num_batches = 0
    num_samples = 0

    with torch.inference_mode():
        for X, y in dataloader:
            X = X.to(device).float()
            y = y.to(device).float()

            val_pred_logits = model(X)
            loss_sum += loss_fn(val_pred_logits, y.unsqueeze(1)).item()
            num_batches += 1

            val_pred_labels = (torch.sigmoid(val_pred_logits) > 0.5).float().squeeze(1)
            pred_area = val_pred_labels.sum().item()
            true_area = y.sum().item()
            intersect = (val_pred_labels * y).sum().item()

            # Dice and IoU are computed over the whole batch, then averaged
            # over batches.
            dice_sum += (2.0 * intersect) / (pred_area + true_area + 1e-8)
            iou_sum += intersect / (pred_area + true_area - intersect + 1e-8)

            # Surface dice is computed per sample on boolean numpy masks.
            y_np = y.cpu().numpy()
            val_pred_labels_np = val_pred_labels.cpu().numpy()
            for i in range(y_np.shape[0]):
                surface_distances = metrics.compute_surface_distances(
                    y_np[i] > 0,
                    val_pred_labels_np[i] > 0,
                    spacing_mm=(1.0, 1.0)
                )
                surface_dice_sum += metrics.compute_surface_dice_at_tolerance(
                    surface_distances,
                    tolerance_mm=surface_tolerance,
                )
            # Count the samples actually seen: dividing by
            # len(dataloader) * batch_size is wrong for a partial last batch.
            num_samples += y_np.shape[0]

            correct_pixels += (val_pred_labels == y).sum().item()
            total_pixels += y.numel()

    if num_batches == 0:
        return 0.0, 0.0, 0.0, 0.0, 0.0

    val_loss = loss_sum / num_batches
    val_acc = correct_pixels / total_pixels if total_pixels > 0 else 0.0
    val_dice = dice_sum / num_batches
    val_iou = iou_sum / num_batches
    val_surface_dice = surface_dice_sum / num_samples if num_samples > 0 else 0.0
    return val_loss, val_acc, val_dice, val_iou, val_surface_dice

In [None]:
def test_step(model: torch.nn.Module,
                    dataloader: torch.utils.data.DataLoader,
                    loss_fn = torch.nn.Module,
                    device = torch.device):
    model.eval()
    test_loss, test_acc = 0, 0
    total_pixels = 0
    test_dice = 0
    test_iou = 0
    test_surface_dice = 0
    with torch.inference_mode():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            X = X.float()
            y = y.float()
            test_pred_logits = model(X)
            loss = loss_fn(test_pred_logits, y.unsqueeze(1))
            test_loss += loss.item()

            #test_pred_labels = torch.argmax(test_pred_logits, dim = 1)
            test_pred_labels = (torch.sigmoid(test_pred_logits) > 0.5).float().squeeze(1)
            #dice
            intersect = (test_pred_labels * y).sum().item()
            dice = (2.0 * intersect) / (test_pred_labels.sum().item() + y.sum().item() + 1e-8)
            test_dice += dice
            #iou score
            union = test_pred_labels.sum().item() + y.sum().item() - intersect
            iou = intersect / (union + 1e-8)
            test_iou += iou

            #surface_dice
            y_np= y.cpu().numpy()
            test_pred_labels_np = test_pred_labels.cpu().numpy()
            for i in range(y_np.shape[0]):
              surface_distances = metrics.compute_surface_distances(
                y_np[i] > 0,
                test_pred_labels_np[i] > 0,
                spacing_mm=(1.0, 1.0)
        )
              surface_dice = metrics.compute_surface_dice_at_tolerance(
                surface_distances,
                tolerance_mm=7.0,
            )

              test_surface_dice += surface_dice


            correct_pixels = (test_pred_labels == y).sum().item()
            batch_pixels = y.numel()
            total_pixels += batch_pixels
            test_acc += correct_pixels

        test_dice = test_dice / len(dataloader)
        test_loss = test_loss / len(dataloader)
        test_acc = test_acc / total_pixels if total_pixels > 0 else 0
        test_surface_dice /= (len(dataloader) * dataloader.batch_size)
        test_iou = test_iou / len(dataloader)
    return test_loss, test_acc, test_dice, test_iou, test_surface_dice

In [None]:
from tqdm.auto import tqdm
def train_loop(model: torch.nn.Module,
               train_dataloader: torch.utils.data.DataLoader,
               test_dataloader: torch.utils.data.DataLoader,
               val_dataloader: torch.utils.data.DataLoader,
               optimizer: torch.optim.Optimizer,
               loss_fn: torch.nn.Module,
               epochs: int,
               device: torch.device,
               scaler: torch.cuda.amp.GradScaler = None,
               scheduler: torch.optim.lr_scheduler._LRScheduler = None,
               checkpoint_dir: str = "/content/scheduler_state_dicts_2"):
    """Train for ``epochs`` epochs, validate after each, then test once.

    After every epoch the metrics are printed and recorded. Whenever the
    validation dice exceeds the best value so far, the model's state dict is
    saved to ``checkpoint_dir``. The test split is evaluated a single time
    after the last epoch, with the weights the model has at that point.

    Args:
        model: Network to train.
        train_dataloader: Batches passed to ``train_step``.
        test_dataloader: Batches passed to ``test_step`` after training.
        val_dataloader: Batches passed to ``validation_step`` every epoch.
        optimizer: Optimiser forwarded to ``train_step``.
        loss_fn: Loss forwarded to all three step functions.
        epochs: Number of training epochs.
        device: Device forwarded to the step functions.
        scaler: Gradient scaler forwarded unchanged to ``train_step``.
        scheduler: Optional LR scheduler, stepped once per epoch.
        checkpoint_dir: Directory for the best-model checkpoints. It is
            created when the first checkpoint is written.

    Returns:
        Dict with per-epoch lists under ``train_*`` / ``val_*`` keys and one
        scalar under each ``test_*`` key. The step functions are expected to
        return ``(loss, acc, dice, iou, surface_dice)`` in that order.
    """
    metric_names = ("loss", "acc", "dice", "iou", "surface_dice")
    results = {f"{split}_{name}": []
               for split in ("train", "val", "test")
               for name in metric_names}

    best_model_dice = 0.0
    for epoch in tqdm(range(epochs)):
        train_loss, train_acc, train_dice, train_iou, train_surface_dice = train_step(
            model = model,
            dataloader = train_dataloader,
            loss_fn = loss_fn,
            optimizer = optimizer,
            scaler = scaler,
            device = device)
        val_loss, val_acc, val_dice, val_iou, val_surface_dice = validation_step(
            model = model,
            dataloader = val_dataloader,
            loss_fn = loss_fn,
            device = device)
        print(
            f"Epoch: {epoch+1} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"train_dice: {train_dice:.4f} | "
            f"train_iou: {train_iou:.4f} | "
            f"train_surface_dice: {train_surface_dice:.4f}\n"
            f"Epoch: {epoch+1} | "
            f"val_loss: {val_loss:.4f} | "
            f"val_acc: {val_acc:.4f} | "
            f"val_dice: {val_dice:.4f} | "
            f"val_iou: {val_iou:.4f} | "
            f"val_surface_dice: {val_surface_dice:.4f}\n\n"
        )

        epoch_metrics = {
            "train_loss": train_loss,
            "train_acc": train_acc,
            "train_dice": train_dice,
            "train_iou": train_iou,
            "train_surface_dice": train_surface_dice,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "val_dice": val_dice,
            "val_iou": val_iou,
            "val_surface_dice": val_surface_dice,
        }
        for key, value in epoch_metrics.items():
            results[key].append(value)

        if val_dice > best_model_dice:
            best_model_dice = val_dice
            # torch.save does not create missing directories.
            os.makedirs(checkpoint_dir, exist_ok=True)
            filename = os.path.join(
                checkpoint_dir,
                f"{epoch}_best_model_{val_dice:.4f}dice_score.pth")
            torch.save(model.state_dict(), filename)
            print(f"New best model with validation dice: {val_dice:.2f}")

        if scheduler is not None:
            scheduler.step()

    test_loss, test_acc, test_dice, test_iou, test_surface_dice = test_step(
        model = model,
        dataloader = test_dataloader,
        loss_fn = loss_fn,
        device = device)
    # The test_iou spec was ":4f" (width 4, six decimals); every other metric
    # is printed with four decimals.
    print(f"test_loss = {test_loss:.4f} | "
          f"test_acc = {test_acc:.4f} |"
          f"test_dice = {test_dice:.4f} |"
          f"test_iou = {test_iou:.4f} |"
          f"test_surface_dice = {test_surface_dice:.4f}"
          )
    # Test metrics are single values, not per-epoch lists.
    results["test_loss"] = test_loss
    results["test_acc"] = test_acc
    results["test_dice"] = test_dice
    results["test_iou"] = test_iou
    results["test_surface_dice"] = test_surface_dice

    return results

# Model Training

In [None]:
# Ask PyTorch's caching allocator to release GPU memory it holds but is not
# currently using, ahead of the training run in the next cell.
torch.cuda.empty_cache()

In [None]:
from timeit import default_timer as timer
from torch.optim import optimizer
from torch.optim.lr_scheduler import StepLR
import torchvision.models as models
NUM_EPOCHS = 20
LEARNING_RATE = 0.001
LR_STEP_SIZE = 5   # epochs between learning-rate decays
# The scheduler used gamma=0., which multiplies the learning rate by zero at
# the first decay, so nothing is learned after epoch 5. 0.1 is StepLR's
# default; change it here if a different decay was intended.
LR_GAMMA = 0.1

model = ResNet50UNet(encoder = ResNetEncoder(),
                     decoder= Decoder()).to(device)
# To resume from an earlier checkpoint, uncomment:
#model.load_state_dict(torch.load("/content/scheduler_state_dicts/15epoch_best_model_0.8932dice_score.pth"))
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params = model.parameters(), lr = LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size=LR_STEP_SIZE,
                                            gamma=LR_GAMMA)
scaler = torch.cuda.amp.GradScaler()

start_time = timer()
model_0_results = train_loop(model= model,
                       train_dataloader= train_data_loader,
                       val_dataloader= val_data_loader,
                       test_dataloader= test_data_loader,
                       optimizer = optimizer,
                       scaler = scaler,
                       loss_fn = loss_fn,
                       epochs = NUM_EPOCHS,
                       device = device,
                      scheduler = scheduler)
end_time = timer()
print(f"Total training time: {end_time-start_time:.3f} seconds")

In [None]:
import pickle

# Persist the metrics dict returned by train_loop so it can be reloaded
# without retraining.
with open("training_results.pkl", mode="wb") as f:
    f.write(pickle.dumps(model_0_results))

In [None]:
# `results` only exists inside train_loop; the dict it returned is bound to
# model_0_results in this notebook.
model_0_results["val_loss"]