In [None]:
!pip install -q torchmetrics==0.10.1

In [None]:
import os
import math
import glob
import shutil
import random
from copy import deepcopy

import tqdm
import tqdm.notebook
import numpy as np
from PIL import Image
from skimage import io
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision
import matplotlib.pyplot as plt
from torch.nn.modules.utils import _single, _pair, _triple
from sklearn.model_selection import train_test_split
import torchmetrics

In [None]:
# Compute device and training hyper-parameters.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 8
NUM_EPOCHS = 75
IMAGE_HEIGHT = 512
IMAGE_WIDTH = 512
# Normalisation statistics passed to A.Normalize (same value for all three channels).
IMAGE_MEAN = [0.2826, 0.2826, 0.2826]
IMAGE_STD = [0.2695, 0.2695, 0.2695]
LEARNING_RATE = 1e-4
SEED = 42

# Seed the Python, NumPy and torch RNGs so repeated runs are comparable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

TRAIN_IMG_DIR = "/kaggle/input/ra-attention-util-notebookf6efc8202c/frames/"
# NOTE(review): TRAIN_MASK_DIR is used when the datasets are built but is not
# defined in any visible cell; it presumably belongs here next to TRAIN_IMG_DIR.

# os.listdir returns entries in filesystem-dependent order; sorting makes the
# seeded train/test split below reproducible across machines.
all_img_files = sorted(os.listdir(TRAIN_IMG_DIR))
print("Number of samples:", len(all_img_files))
train_img_files, test_img_files = train_test_split(all_img_files, test_size=0.2, shuffle=True, random_state=SEED)

In [None]:
class SegmentationDataset(Dataset):
    """Image/mask pairs for segmentation, matched by file name.

    Item ``i`` is read from ``image_dir/<name>`` and ``mask_dir/<name>``, where
    ``name = image_files[i]``, so an image and its mask must share a file name.

    Args:
        image_dir: directory holding the input images.
        mask_dir: directory holding the masks.
        image_files: file names belonging to this split.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            and returning a dict with "image" and "mask" entries (the
            albumentations convention used by ``get_transforms``).

    Each item is ``(image, mask)``. The mask is the greyscale ("L") mask divided
    by 255, as float32, clipped to [0, 1]. Without a transform it is returned as
    an ``(H, W, 1)`` array; with a transform the transformed mask's last axis is
    moved to the front (``permute((2, 0, 1))``).
    """

    def __init__(self, image_dir, mask_dir, image_files, transform=None):
        self.image_dir, self.mask_dir = image_dir, mask_dir
        self.transform = transform
        self.images = image_files

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        name = self.images[index]
        image = io.imread(os.path.join(self.image_dir, name))

        grey = np.array(Image.open(os.path.join(self.mask_dir, name)).convert("L"))
        # Append a channel axis and rescale grey levels 0..255 to [0, 1].
        mask = np.clip((grey[..., np.newaxis] / 255).astype(np.float32), 0.0, 1.0)

        if not self.transform:
            return image, mask

        sample = self.transform(image=image, mask=mask)
        return sample["image"], sample["mask"].permute((2, 0, 1))

In [None]:
def get_transforms(train=True):
    """Build the albumentations pipeline for one split.

    Pipeline: Resize(IMAGE_HEIGHT, IMAGE_WIDTH) -> [Rotate(limit=15, p=0.5), only
    when ``train``] -> Normalize(IMAGE_MEAN, IMAGE_STD) -> ToTensorV2.

    Args:
        train: include the random rotation augmentation.

    Returns:
        An ``A.Compose`` pipeline.
    """
    steps = [A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH)]
    if train:
        steps.append(A.Rotate(limit=15, p=0.5))
    steps += [
        A.Normalize(mean=IMAGE_MEAN, std=IMAGE_STD),
        ToTensorV2(),
    ]
    return A.Compose(steps)

# Both splits read from the same directories and differ only in their file lists;
# augmentation and per-epoch shuffling apply to the training split only.
# NOTE(review): TRAIN_MASK_DIR is not defined in any visible cell; it has to be
# set (presumably in the config cell) before this runs.
train_dataset = SegmentationDataset(
    image_dir=TRAIN_IMG_DIR,
    mask_dir=TRAIN_MASK_DIR,
    image_files=train_img_files,
    transform=get_transforms(train=True),
)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)

test_dataset = SegmentationDataset(
    image_dir=TRAIN_IMG_DIR,
    mask_dir=TRAIN_MASK_DIR,
    image_files=test_img_files,
    transform=get_transforms(train=False),
)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)

In [None]:
class FuzzyConv(nn.Module):
    """
    Fuzzy Conv Layer.

    Wraps an ``nn.Conv2d``. When ``dilation > 1`` the input is first smoothed by a
    depthwise "fuzzy" kernel (one copy per input channel, ``padding="same"``) made
    of concentric square rings: ring ``e`` (0 = outermost) holds
    ``fuzzy_points[e]``. The kernel is ``dilation x dilation`` for odd dilation and
    ``(dilation + 1) x (dilation + 1)`` for even dilation, with
    ``dilation // 2 + 1`` rings.

    ``fuzzy_points`` is initialised to ``np.linspace(mu, 1, dilation // 2 + 1)``
    normalised to sum to 1. It is a learnable parameter: the kernel is rebuilt
    from it on every forward pass so gradients reach it. ``fuzzy_mask`` holds a
    detached, non-trainable snapshot of the most recent training-mode kernel.

    Args:
        in_channels: channels of the input.
        out_channels: channels produced by the wrapped ``nn.Conv2d``.
        kernel_size: kernel size of the wrapped ``nn.Conv2d``.
        dilation: dilation of the wrapped ``nn.Conv2d``; must be equal on both axes.
        mu: outermost-ring value before normalisation.
        *args, **kwargs: forwarded to ``nn.Conv2d``.

    Raises:
        NotImplementedError: if the two dilation rates differ.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, mu=0.1, *args, **kwargs):
        super(FuzzyConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dilation = _pair(dilation)  # all pairs are considered to be of form (n, n)
        self.mu = mu
        if self.dilation[0] != self.dilation[1]:
            raise NotImplementedError(f"Unequal dilation not supported. Got dilation rates: {self.dilation}")
        if self.dilation[0] > 1:
            fuzzy_points = np.linspace(self.mu, 1, int(self.dilation[0] / 2) + 1)
            fuzzy_points /= np.sum(fuzzy_points)
            self.fuzzy_points = nn.Parameter(torch.Tensor(fuzzy_points), requires_grad=True)
            self.create_fuzzy_mask()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, dilation=self.dilation, *args, **kwargs)

    def _fuzzy_kernel(self):
        """Build the depthwise smoothing kernel from ``fuzzy_points``.

        Returns:
            A tensor of shape ``(in_channels, 1, size, size)`` (an expanded view)
            that stays connected to ``fuzzy_points`` in the autograd graph.
        """
        d = self.dilation[0]
        size = d if d % 2 else d + 1
        pos = torch.arange(size, device=self.fuzzy_points.device)
        # Distance of each row/column to its nearest border; a cell's ring index
        # is the smaller of its row and column distance.
        edge = torch.minimum(pos, size - 1 - pos)
        ring = torch.minimum(edge[:, None], edge[None, :])
        # Gathering by index (instead of in-place writes into a fresh tensor that
        # is later wrapped in nn.Parameter) keeps the gradient path to fuzzy_points.
        kernel = self.fuzzy_points[ring]
        return kernel.expand(self.in_channels, 1, size, size)

    def create_fuzzy_mask(self):
        """Rebuild the smoothing kernel and refresh the ``fuzzy_mask`` snapshot.

        Returns:
            The differentiable kernel to convolve with. ``self.fuzzy_mask`` is a
            detached, contiguous, non-trainable copy kept under the original
            attribute name so the parameter / state_dict layout is unchanged.
        """
        kernel = self._fuzzy_kernel()
        # nn.Parameter always detaches its data, so the snapshot cannot carry
        # gradients. It is made contiguous because an expanded view cannot be
        # written to by load_state_dict.
        self.fuzzy_mask = nn.Parameter(kernel.detach().contiguous(), requires_grad=False)
        return kernel

    def forward(self, x):
        if self.dilation[0] > 1:
            if self.training:
                kernel = self.create_fuzzy_mask()
            else:
                kernel = self._fuzzy_kernel()
            x = F.conv2d(x, kernel, padding="same", groups=self.in_channels)
        x = self.conv(x)
        return x


class SpatialAttention(nn.Module):
    """Residual spatial branch built around a FuzzyConv.

    Computes ``x + conv2d(fuzzy_conv(x))``. When ``downsample_ratio != 1`` the
    strided fuzzy-conv output is resized to ``out_size`` with ``nn.Upsample``
    before ``conv2d``, so ``out_size`` must equal the input's spatial size for
    the residual add.

    Args:
        in_channels: channels of the input (preserved by every layer).
        out_size: target size for ``nn.Upsample``; unused when ``downsample_ratio == 1``.
        kernel_size: kernel size of both the fuzzy conv and the following conv.
        downsample_ratio: stride of the fuzzy conv.
        dilation: dilation of the fuzzy conv.
        mu: forwarded to FuzzyConv.
        upsample_mode: interpolation mode for ``nn.Upsample``.
    """

    def __init__(self, in_channels, out_size, kernel_size, downsample_ratio, dilation, mu=0.1, upsample_mode="bilinear"):
        super(SpatialAttention, self).__init__()
        self.downsample_ratio = downsample_ratio
        # Size-preserving padding for a dilated conv: dilation * (k - 1) // 2 per
        # axis. For kernel_size=3 this equals `dilation`; for other odd kernel
        # sizes it keeps the stride-1 output the same size as x, which the
        # residual add in forward() requires.
        padding = tuple(d * (k - 1) // 2 for d, k in zip(_pair(dilation), _pair(kernel_size)))
        self.fuzzy_conv = FuzzyConv(in_channels, in_channels, kernel_size, dilation, bias=True, padding=padding, stride=downsample_ratio, mu=mu)
        if downsample_ratio != 1:
            # Bring the strided output back to the input resolution.
            self.upsample = nn.Upsample(size=out_size, mode=upsample_mode)

        self.conv2d = nn.Conv2d(in_channels, in_channels, kernel_size, padding="same")

    def forward(self, x):
        y = self.fuzzy_conv(x)
        if self.downsample_ratio != 1:
            y = self.upsample(y)
        y = self.conv2d(y)
        return torch.add(x, y)


class GlobalLearnablePool(nn.Module):
    """
    Learnable global pooling.

    A depthwise convolution (no bias) with one learnable ``in_height x in_width``
    kernel per channel, initialised to uniform averaging: every weight starts at
    ``1 / (in_height * in_width)``. For an input of shape
    ``(n, in_channels, in_height, in_width)`` the output is
    ``(n, in_channels, 1, 1)``; larger inputs yield a sliding-window output.
    """

    def __init__(self, in_channels, in_height, in_width):
        super(GlobalLearnablePool, self).__init__()
        self.in_channels = in_channels
        kernel_shape = (in_channels, 1, in_height, in_width)
        uniform = torch.ones(kernel_shape) / (in_height * in_width)
        self.pool_mask = nn.Parameter(uniform, requires_grad=True)

    def forward(self, x):
        # groups=in_channels: each channel is pooled by its own kernel.
        return F.conv2d(x, self.pool_mask, groups=self.in_channels)


class ChannelAttention(nn.Module):
    """Squeeze-and-excitation style channel gating.

    Each channel is pooled with GlobalLearnablePool. The pooled vector goes through
    Linear -> ReLU -> Dropout -> Linear, and the input is multiplied by the
    sigmoid of the result: one gate per channel, broadcast over height and width.

    Args:
        in_channels: number of input channels.
        in_height, in_width: pooling kernel size. The input's spatial size must
            match it so the pooled tensor flattens to ``in_channels`` features.
        se_hidden_channels: width of the hidden layer.
        dropout_rate: dropout probability on the hidden layer.
    """

    def __init__(self, in_channels, in_height, in_width, se_hidden_channels, dropout_rate=0.1):
        super(ChannelAttention, self).__init__()
        self.pool = GlobalLearnablePool(in_channels, in_height, in_width)
        self.linear1 = nn.Linear(in_channels, se_hidden_channels)
        self.dropout = nn.Dropout(dropout_rate)
        self.linear2 = nn.Linear(se_hidden_channels, in_channels)

    def forward(self, x):
        squeezed = self.pool(x).reshape(x.shape[0], -1)
        hidden = self.dropout(F.relu(self.linear1(squeezed)))
        # (n, c) -> (n, c, 1, 1) so the gates broadcast over the spatial axes.
        gates = torch.sigmoid(self.linear2(hidden)).unsqueeze(-1).unsqueeze(-1)
        return x * gates

class ChannelSpatialAttention(nn.Module):
    """Parallel channel-then-spatial attention branches fused by a 1x1 conv.

    One ``ChannelAttention -> SpatialAttention`` branch is built per
    (kernel size, downsample ratio, dilation) triple. All branches see the same
    input; their outputs are concatenated along channels and projected back to
    ``in_channels`` with a 1x1 convolution.

    Args:
        in_channels: channels of the input and of the output.
        out_size: forwarded to SpatialAttention (upsample target size).
        kernel_size: int shared by every branch, or one kernel size per branch.
        downsample_ratio_list: int or one stride per branch.
        dilation_rates_list: int or one dilation per branch.
        mu: forwarded to SpatialAttention.
        upsample_mode: forwarded to SpatialAttention.
        channel_attention_size: (height, width), or a single int, for the
            learnable pooling kernel in ChannelAttention. The input's spatial size
            must match it. Default (32, 32).
        se_hidden_channels: hidden width of ChannelAttention. Default 512.
        se_dropout_rate: dropout rate inside ChannelAttention. Default 0.2.

    Raises:
        AssertionError: if the per-branch lists have different lengths.
    """

    def __init__(self, in_channels, out_size, kernel_size, downsample_ratio_list, dilation_rates_list, mu=0.1,
                 upsample_mode="bilinear", channel_attention_size=(32, 32), se_hidden_channels=512,
                 se_dropout_rate=0.2):
        super().__init__()
        if isinstance(downsample_ratio_list, int):
            downsample_ratio_list = [downsample_ratio_list]
        if isinstance(dilation_rates_list, int):
            dilation_rates_list = [dilation_rates_list]
        if isinstance(kernel_size, int):
            kernel_size_list = [kernel_size] * len(dilation_rates_list)
        else:
            kernel_size_list = kernel_size
        assert len(kernel_size_list) == len(downsample_ratio_list) == len(dilation_rates_list), \
            "kernel_size, downsample_ratio_list and dilation_rates_list have to be equal length"
        att_height, att_width = _pair(channel_attention_size)
        fuzzy_convs = []
        for branch_kernel, stride, dilation in zip(kernel_size_list, downsample_ratio_list, dilation_rates_list):
            fuzzy_convs.append(
                nn.Sequential(
                    ChannelAttention(in_channels, att_height, att_width, se_hidden_channels, se_dropout_rate),
                    SpatialAttention(in_channels, out_size, branch_kernel, stride, dilation, mu, upsample_mode),
                )
            )
        self.fuzzy_convs = nn.ModuleList(fuzzy_convs)
        # Fuse the concatenated branch outputs back to in_channels.
        self.channel_conv = nn.Conv2d(in_channels * len(downsample_ratio_list), in_channels, (1, 1))

    def forward(self, x):
        out = [branch(x) for branch in self.fuzzy_convs]
        out = torch.concat(out, dim=1)
        return self.channel_conv(out)


class ChannelShuffle(nn.Module):
    """Channel shuffle (as in ShuffleNet).

    Splits the channels into ``groups`` groups and interleaves them, so channel
    ``k`` of group ``g`` moves to position ``k * groups + g``. This lets
    information flow across groups between grouped convolutions.

    Args:
        groups (int): number of channel groups; must divide the channel count.
    """

    def __init__(self, groups) -> None:
        super().__init__()
        self.groups = groups

    def forward(self, x):
        """Shuffle the channels of a ``(n, c, h, w)`` tensor.

        Raises:
            AssertionError: if ``c`` is not divisible by ``groups``.
        """
        n, c, h, w = x.size()
        assert c % self.groups == 0, "num_channels should be divisible by groups"
        per_group = c // self.groups
        grouped = x.view(n, self.groups, per_group, h, w)
        # Swapping the group and within-group axes interleaves the groups.
        return grouped.transpose(1, 2).contiguous().view(n, c, h, w)


class ResidualDenseBlock(nn.Module):
    """Residual dense block: five 3x3 convolutions with dense connections.

    ``conv1``..``conv4`` each see the block input concatenated with the outputs
    of all earlier convolutions, produce ``growth_channels`` maps, and are
    followed by LeakyReLU(0.2). ``conv5`` maps the full concatenation back to
    ``channels``. The output is ``x + 0.2 * conv5(...)``.

    See `Densely Connected Convolutional Networks <https://arxiv.org/pdf/1608.06993v5.pdf>`.

    Args:
        channels (int): channels of the input (and of the output).
        growth_channels (int): channels produced by each of conv1..conv4.
    """

    def __init__(self, channels: int, growth_channels: int) -> None:
        super(ResidualDenseBlock, self).__init__()

        def dense_conv(num_previous: int, out_channels: int) -> nn.Conv2d:
            # Input width = block input plus `num_previous` earlier growth outputs.
            return nn.Conv2d(channels + growth_channels * num_previous, out_channels, (3, 3), (1, 1), (1, 1))

        self.conv1 = dense_conv(0, growth_channels)
        self.conv2 = dense_conv(1, growth_channels)
        self.conv3 = dense_conv(2, growth_channels)
        self.conv4 = dense_conv(3, growth_channels)
        self.conv5 = dense_conv(4, channels)
        self.leaky_relu = nn.LeakyReLU(0.2, True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = [x, self.leaky_relu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            features.append(self.leaky_relu(conv(torch.cat(features, 1))))
        residual = self.conv5(torch.cat(features, 1))
        return residual * 0.2 + x


class ResidualResidualDenseBlock(nn.Module):
    """Three chained ResidualDenseBlocks wrapped in an outer skip connection.

    Output is ``x + BatchNorm2d(rdb3(rdb2(rdb1(x))))``.

    Args:
        channels (int): channels of the input (and of the output).
        growth_channels (int): growth width passed to each ResidualDenseBlock.
    """

    def __init__(self, channels: int, growth_channels: int) -> None:
        super(ResidualResidualDenseBlock, self).__init__()
        self.rdb1 = ResidualDenseBlock(channels, growth_channels)
        self.rdb2 = ResidualDenseBlock(channels, growth_channels)
        self.rdb3 = ResidualDenseBlock(channels, growth_channels)
        self.batchnorm = nn.BatchNorm2d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = x
        for block in (self.rdb1, self.rdb2, self.rdb3):
            features = block(features)
        return self.batchnorm(features) + x

In [None]:
# NOTE(review): `model` is not defined in any visible cell; the cell that builds
# it (and moves it to DEVICE) must run before this one.

# nn.BCELoss takes probabilities; the training loop applies sigmoid to the outputs.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Scale the learning rate by 0.8 every 30 scheduler steps (one step per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.8)


def make_metrics():
    """Return a fresh [Accuracy, Dice] pair of torchmetrics objects on DEVICE."""
    return [torchmetrics.Accuracy().to(DEVICE), torchmetrics.Dice().to(DEVICE)]


train_metrics = make_metrics()
test_metrics = make_metrics()

In [None]:
# Train for NUM_EPOCHS epochs, evaluating on the held-out split after each epoch
# and saving the whole model whenever the test Dice improves.
# Relies on `model`, the loaders, `criterion`, `optimizer`, `scheduler` and the
# metric lists from earlier cells (`model` is not defined in any visible cell).

# NOTE(review): overrides NUM_EPOCHS = 75 from the config cell (looks like a short
# debug run); move the value into the config cell once it is settled.
NUM_EPOCHS = 5


def run_epoch(loader, metrics, train):
    """Run `model` once over `loader` and return the mean per-batch loss.

    Args:
        loader: DataLoader yielding (images, masks) batches.
        metrics: torchmetrics objects updated with every batch (not reset here).
        train: if True, use train mode and take an optimizer step per batch;
            otherwise use eval mode with gradients disabled.

    Returns:
        Sum of batch losses divided by the number of batches. The mean, rather
        than the sum, keeps train and test losses comparable despite the
        different split sizes.
    """
    model.train(train)
    total_loss = 0.0
    with torch.set_grad_enabled(train):
        for imgs, masks in tqdm.notebook.tqdm(loader):
            imgs = imgs.to(DEVICE)
            masks = masks.to(DEVICE)
            if train:
                optimizer.zero_grad()
            # criterion is nn.BCELoss, which expects probabilities.
            outputs = torch.sigmoid(model(imgs))
            loss = criterion(outputs, masks)
            if train:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            for metric in metrics:
                metric.update(outputs, masks.to(torch.int8))
    return total_loss / max(len(loader), 1)


def compute_and_reset(metrics):
    """Print each metric's accumulated value, reset it, and return the values in order."""
    values = []
    for metric in metrics:
        value = metric.compute().item()
        print(metric, value)
        metric.reset()
        values.append(value)
    return values


max_dice = 0
max_epoch = 0
for epoch in range(1, NUM_EPOCHS + 1):
    print(epoch, '/', NUM_EPOCHS)

    print("Train Loss:", run_epoch(train_loader, train_metrics, train=True))
    compute_and_reset(train_metrics)
    print()

    print("Test Loss:", run_epoch(test_loader, test_metrics, train=False))
    for metric, value in zip(test_metrics, compute_and_reset(test_metrics)):
        # Select the Dice metric by type rather than by its printed repr.
        if isinstance(metric, torchmetrics.Dice) and value > max_dice:
            max_dice = value
            max_epoch = epoch
            torch.save(model, f'model-{epoch}.pt')
    print()

    scheduler.step()

print("Max Dice:", max_dice, "Epoch:", max_epoch)