In [1]:
from torch.utils.data import Dataset
from torchvision import transforms


# Dataset class
class MySegmentationDataset(Dataset):
    """In-memory dataset yielding ``(image, binary_mask)`` pairs.

    Args:
        data_images: Indexable, sized collection of images. Each item is
            handed to ``transform`` unchanged.
        data_masks: Indexable collection of masks, aligned index-by-index
            with ``data_images``.
        transform: Callable applied to both the image and the mask. When
            ``None`` a ``transforms.ToTensor()`` pipeline is used.
    """

    def __init__(self, data_images, data_masks, transform=None):
        self.images = data_images  # PIL Images (or whatever `transform` accepts)
        self.masks = data_masks
        # Compare against None explicitly: a truthiness test would discard
        # valid callables that happen to be falsy (e.g. an empty
        # nn.Sequential, whose __len__ is 0).
        if transform is None:
            transform = transforms.Compose([
                transforms.ToTensor(),
            ])
        self.transform = transform

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the transformed image and its mask binarized to {0., 1.}."""
        item_image = self.transform(self.images[idx])
        # Any strictly positive value after the transform counts as foreground.
        item_mask = (self.transform(self.masks[idx]) > 0).float()

        return item_image, item_mask

In [2]:
import torch
from torch import nn


class ChannelAttentionBlock(nn.Module):
    """Channel attention: re-weights each channel by a learned gate.

    The input is pooled to 1x1 with both average and max pooling. Both
    descriptors go through the same two-layer 1x1-conv bottleneck, are
    summed, squashed with a sigmoid and multiplied back onto the input.

    Args:
        in_channels: Number of channels of the input feature map.
        reduction: Bottleneck reduction ratio. The hidden width is
            ``in_channels // reduction`` but never less than 1.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        # Clamp to 1 so that narrow inputs (in_channels < reduction) do not
        # produce a zero-channel bottleneck, which would carry no signal.
        reduced_channels = max(1, in_channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.shared_mlp = nn.Sequential(
            nn.Conv2d(in_channels, reduced_channels, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced_channels, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled per channel by the attention gate."""
        avg_out = self.shared_mlp(self.avg_pool(x))
        max_out = self.shared_mlp(self.max_pool(x))
        scale = self.sigmoid(avg_out + max_out)
        return x * scale


class SpatialAttentionBlock(nn.Module):
    """Spatial attention: re-weights each pixel by a learned gate.

    A two-channel descriptor (channel-wise max, then channel-wise mean) is
    convolved down to a single map, passed through a sigmoid and multiplied
    onto the input.

    Args:
        kernel_size: Size of the convolution over the descriptor; the
            padding is ``kernel_size // 2``.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=2,
            out_channels=1,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled per spatial location by the attention gate."""
        channel_max = x.max(dim=1, keepdim=True).values
        channel_mean = x.mean(dim=1, keepdim=True)
        # Order matters: the conv's first input channel sees the max map.
        descriptor = torch.cat((channel_max, channel_mean), dim=1)
        attention = self.sigmoid(self.conv(descriptor))
        return x * attention


class MultiScaleDepthwiseConv(nn.Module):
    """Sum of parallel depthwise conv branches with different kernel sizes.

    Each branch is depthwise conv -> BatchNorm -> ReLU6 and keeps the
    channel count unchanged.

    Args:
        in_channels: Number of input (and output) channels.
        kernel_sizes: Kernel size of each branch; defaults to ``[1, 3, 5]``.
    """

    def __init__(self, in_channels, kernel_sizes=None):
        super().__init__()
        sizes = [1, 3, 5] if kernel_sizes is None else kernel_sizes
        self.in_channels = in_channels
        self.branches = nn.ModuleList(
            [self._make_branch(in_channels, size) for size in sizes]
        )

    @staticmethod
    def _make_branch(channels, size):
        """Build one depthwise conv -> BN -> ReLU6 branch."""
        depthwise = nn.Conv2d(
            channels, channels,
            kernel_size=size, padding=size // 2, groups=channels, bias=False,
        )
        return nn.Sequential(depthwise, nn.BatchNorm2d(channels), nn.ReLU6(inplace=True))

    def forward(self, x):
        """Run every branch on ``x`` and return the element-wise sum."""
        # Check all branches up front so a channel mismatch fails before
        # any branch has been evaluated.
        for branch in self.branches:
            assert branch[1].num_features == x.shape[1], \
                f"Expected {branch[1].num_features} channels, got {x.shape[1]}"

        total = 0
        for branch in self.branches:
            total = total + branch(x)
        return total


class MSCB(nn.Module):
    """Multi-scale convolution block.

    Pipeline: 1x1 expansion (conv -> BN -> ReLU6), multi-scale depthwise
    convolution on the expanded features, then a 1x1 projection
    (conv -> BN) back to ``in_channels``.

    Args:
        in_channels: Number of input and output channels.
        expansion_factor: Multiplier giving the width of the expanded features.
        kernel_sizes: Kernel sizes of the depthwise branches; defaults to
            ``[1, 3, 5]``.
    """

    def __init__(self, in_channels, expansion_factor=2, kernel_sizes=None):
        super().__init__()
        branch_kernels = [1, 3, 5] if kernel_sizes is None else kernel_sizes
        self.in_channels = in_channels
        hidden = in_channels * expansion_factor

        # Pointwise expansion to the wider hidden width.
        self.expand = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
        )

        self.msd_conv = MultiScaleDepthwiseConv(hidden, branch_kernels)

        # Pointwise projection back down; no activation after the BN.
        self.project = nn.Sequential(
            nn.Conv2d(hidden, in_channels, 1, bias=False),
            nn.BatchNorm2d(in_channels),
        )

    def forward(self, x):
        """Expand, mix at multiple scales, and project back."""
        widened = self.expand(x)
        mixed = self.msd_conv(widened)
        return self.project(mixed)


class MSCAM(nn.Module):
    """Attention module chaining channel attention, spatial attention and MSCB.

    Args:
        in_channels: Channel count of the feature map; preserved by the module.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.cab = ChannelAttentionBlock(in_channels)
        self.sab = SpatialAttentionBlock()
        self.mscb = MSCB(in_channels)

    def forward(self, x):
        """Apply ``cab``, then ``sab``, then ``mscb`` to ``x``."""
        channel_refined = self.cab(x)
        spatially_refined = self.sab(channel_refined)
        return self.mscb(spatially_refined)


class LGAG(nn.Module):
    """Grouped attention gate that modulates ``x`` using a gating signal ``g``.

    Both inputs go through their own grouped 3x3 convolution and their own
    BatchNorm. The sum is rectified, reduced to a single-channel map by a
    1x1 conv + BatchNorm + sigmoid, and that map rescales ``x``.

    Args:
        in_channels: Channel count of both ``g`` and ``x``; must be
            divisible by ``groups``.
        groups: Number of groups for the two 3x3 convolutions.

    Note:
        ``bn`` normalises the gating branch and ``bn_x`` the feature branch.
        State dicts saved before ``bn_x`` existed lack its keys and need
        ``load_state_dict(..., strict=False)``.
    """

    def __init__(self, in_channels, groups=4):
        super().__init__()
        self.gc_g = nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=groups, bias=False)
        self.gc_x = nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=groups, bias=False)
        self.relu = nn.ReLU(inplace=True)
        # One BatchNorm per branch: sharing a single layer would track
        # running statistics blended from two differently distributed inputs.
        self.bn = nn.BatchNorm2d(in_channels)    # gating branch (g)
        self.bn_x = nn.BatchNorm2d(in_channels)  # feature branch (x)
        self.conv1 = nn.Conv2d(in_channels, 1, 1)
        self.bn1 = nn.BatchNorm2d(1)  # normalises the single-channel attention logits
        self.sigmoid = nn.Sigmoid()

    def forward(self, g, x):
        """Return ``x`` multiplied by the attention map computed from ``g`` and ``x``."""
        g_bn = self.bn(self.gc_g(g))
        x_bn = self.bn_x(self.gc_x(x))
        attn = self.relu(g_bn + x_bn)
        attn = self.bn1(self.conv1(attn))
        attn = self.sigmoid(attn)
        return x * attn


class EUCB(nn.Module):
    """Up-convolution block: 2x bilinear upsample followed by convolutions.

    Layer order inside ``block``: upsample, 3x3 conv, BatchNorm, ReLU,
    1x1 conv.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        spatial_conv = nn.Conv2d(in_channels, out_channels, 3, padding=1, groups=1)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(inplace=True)
        pointwise_conv = nn.Conv2d(out_channels, out_channels, 1)
        self.block = nn.Sequential(upsample, spatial_conv, norm, activation, pointwise_conv)

    def forward(self, x):
        """Return ``x`` upsampled by 2 and mapped to ``out_channels``."""
        upsampled = self.block(x)
        return upsampled


class SegHead(nn.Module):
    """Segmentation head: a single 1x1 convolution producing per-pixel logits.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Number of output maps.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(1, 1),
        )

    def forward(self, x):
        """Project ``x`` to ``out_channels`` maps at the same resolution."""
        logits = self.conv(x)
        return logits


class EMCADDecoder(nn.Module):
    """Cascaded decoder over a feature pyramid.

    Every feature map is refined by its own ``MSCAM`` and gets its own
    ``SegHead``. Starting from the last (deepest) refined map, each stage
    upsamples with an ``EUCB``, gates the result with an ``LGAG`` driven by
    the next shallower refined map, and adds the upsampled map back.

    Args:
        channels: Channel count of each feature map, shallowest first.
        out_channels: Number of output maps produced by every head.
    """

    def __init__(self, channels, out_channels=1):
        super().__init__()
        self.mscams = nn.ModuleList([MSCAM(c) for c in channels])
        self.seg_heads = nn.ModuleList([SegHead(c, out_channels) for c in channels])
        # Stage k maps channels[-1 - k] -> channels[-2 - k] (deepest first).
        self.eucbs = nn.ModuleList([
            EUCB(channels[i], channels[i - 1]) for i in range(len(channels) - 1, 0, -1)
        ])
        self.lgags = nn.ModuleList([
            LGAG(channels[i - 1]) for i in range(len(channels) - 1, 0, -1)
        ])

        self.final_seghead = SegHead(channels[0], out_channels)

    def forward(self, feats):
        """Decode a list of feature maps.

        Args:
            feats: Feature maps, shallowest first, one per entry of the
                ``channels`` given at construction.

        Returns:
            ``(predictions, fused)`` where ``predictions`` holds one head
            output per refined feature map and ``fused`` is the head output
            of the fully decoded map.

        Raises:
            ValueError: If the number of feature maps does not match the
                number of decoder levels.
        """
        feats = list(feats)
        if len(feats) != len(self.mscams):
            raise ValueError(
                f"Expected {len(self.mscams)} feature maps, got {len(feats)}"
            )

        feats = [msc(f) for msc, f in zip(self.mscams, feats)]
        predictions = [seg(f) for seg, f in zip(self.seg_heads, feats)]
        x = feats[-1]

        # Walk from the deepest level up, pairing each stage with the next
        # shallower skip feature. Iterating over the stage lists (instead
        # of a fixed range(3)) supports any number of pyramid levels.
        skips = reversed(feats[:-1])
        for eucb, lgag, skip in zip(self.eucbs, self.lgags, skips):
            x_up = eucb(x)
            # A 2x upsample only matches the skip when sizes halve exactly;
            # resize otherwise so the gate and the residual add line up.
            if x_up.shape[-2:] != skip.shape[-2:]:
                x_up = nn.functional.interpolate(
                    x_up, size=skip.shape[-2:], mode='bilinear', align_corners=False
                )
            x = lgag(skip, x_up) + x_up

        # Output stays at the resolution of the shallowest feature map.
        x = self.final_seghead(x)

        return predictions, x

In [3]:
import timm


class PVTEMCAD(nn.Module):
    """Segmentation network: a timm feature encoder followed by ``EMCADDecoder``.

    Args:
        encoder_name: Model name passed to ``timm.create_model``.
        out_channels: Number of output maps produced by the decoder heads.
        pretrained: Whether timm should load pretrained encoder weights.
            Defaults to ``True``; pass ``False`` to build the architecture
            without fetching weights, e.g. before loading a checkpoint.
    """

    def __init__(self, encoder_name='pvt_v2_b2', out_channels=1, pretrained=True):
        super().__init__()
        self.encoder = timm.create_model(encoder_name, pretrained=pretrained, features_only=True)

        # Get the actual channel sizes from the encoder
        encoder_channels = self.encoder.feature_info.channels()

        self.decoder = EMCADDecoder(
            channels=encoder_channels,
            out_channels=out_channels
        )

    def forward(self, x):
        """Return the decoder's fused output for ``x``.

        The per-level predictions returned by the decoder are discarded.
        NOTE(review): the output resolution is presumably that of the first
        encoder feature map rather than of ``x`` — confirm, and resize in
        the caller if full resolution is needed.
        """
        _, fused_out = self.decoder(self.encoder(x))
        return fused_out

## Train and Test

In [4]:
import os
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

# Get all image and mask paths
image_dir = 'datasets/CVC-ClinicDB/Original'
mask_dir = 'datasets/CVC-ClinicDB/Ground Truth'

# Pair each image with the mask that shares its file stem instead of relying
# on two independently sorted directory listings: a stray or missing file in
# either folder would otherwise silently shift every following pair.
image_names = sorted(f for f in os.listdir(image_dir) if f.endswith('.png'))
mask_name_by_stem = {os.path.splitext(f)[0]: f for f in sorted(os.listdir(mask_dir))}
unpaired = [f for f in image_names if os.path.splitext(f)[0] not in mask_name_by_stem]
if unpaired:
    raise FileNotFoundError(
        f"{len(unpaired)} image(s) in {image_dir!r} have no mask in {mask_dir!r}, "
        f"e.g. {unpaired[:5]}"
    )

image_paths = [os.path.join(image_dir, f) for f in image_names]
mask_paths = [
    os.path.join(mask_dir, mask_name_by_stem[os.path.splitext(f)[0]]) for f in image_names
]


def _load_image(path, mode):
    """Open ``path``, convert it to ``mode`` and release the file handle."""
    # convert() returns an independent, fully loaded image, so the source
    # file can be closed as soon as the with-block exits.
    with Image.open(path) as img:
        return img.convert(mode)


all_images = [_load_image(p, 'RGB') for p in image_paths]
all_masks = [_load_image(p, 'L') for p in mask_paths]

# Split into train and validation sets
train_images, val_images, train_masks, val_masks = train_test_split(
    all_images, all_masks, test_size=0.2, random_state=42
)

# Create DataLoaders
train_loader = DataLoader(MySegmentationDataset(train_images, train_masks), batch_size=8, shuffle=True)
val_loader = DataLoader(MySegmentationDataset(val_images, val_masks), batch_size=8)

In [None]:
from torch.nn import functional

# # Multi GPU
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device_ids = list(range(torch.cuda.device_count()))
# model = nn.DataParallel(PVTEMCAD().to(device), device_ids=device_ids)

# Single GPU
# Prefer GPU 3, but fall back to the first GPU (or the CPU) when that index
# does not exist, so the notebook also runs on smaller machines.
PREFERRED_GPU = 3
if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU if torch.cuda.device_count() > PREFERRED_GPU else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')
model = PVTEMCAD().to(device)

# pos_weight > 1 up-weights foreground pixels in the BCE term.
bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([5.0], device=device))


def dice_loss(predictions, target, smooth=1e-6):
    """Soft Dice loss computed over all elements of the batch at once.

    Args:
        predictions: Predicted foreground scores. The Dice ratio is only
            bounded to [0, 1] for non-negative inputs, so pass
            probabilities (apply a sigmoid to logits first).
        target: Ground-truth mask. If its shape differs from
            ``predictions`` a channel axis is inserted at dim 1.
        smooth: Added to numerator and denominator to avoid division by
            zero when both inputs are empty.

    Returns:
        Scalar tensor ``1 - (2 * |P*T| + smooth) / (|P| + |T| + smooth)``.
    """
    if predictions.shape != target.shape:
        target = target.unsqueeze(1)
    # reshape() instead of view(): view() raises on non-contiguous tensors
    # (e.g. after a transpose/permute or some slicing operations).
    predictions = predictions.reshape(-1)
    target = target.reshape(-1)
    intersection = (predictions * target).sum()
    return 1 - (2. * intersection + smooth) / (predictions.sum() + target.sum() + smooth)


def criterion(prediction, target):
    """Weighted sum of Dice loss (0.2) and BCE-with-logits loss (0.8).

    Args:
        prediction: Raw logits from the model.
        target: Ground-truth mask with values in {0, 1}.

    Returns:
        Scalar loss tensor.
    """
    # The BCE term applies its own sigmoid and must receive logits; the Dice
    # term needs probabilities, otherwise negative logits make the overlap
    # ratio meaningless (it can even leave the [0, 1] range).
    probabilities = torch.sigmoid(prediction)
    return 0.2 * dice_loss(probabilities, target) + 0.8 * bce(prediction, target)


optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
epochs = 100
for epoch in range(epochs):
    model.train()
    batch_losses = []
    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        # Resize the logits to the mask resolution before computing the loss.
        output = functional.interpolate(
            model(images), size=masks.shape[2:], mode='bilinear', align_corners=False
        )
        loss = criterion(output, masks)
        batch_losses.append(loss.item())

        # Standard update: clear stale gradients, backpropagate, step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Mean of the per-batch losses for this epoch.
    epoch_loss = sum(batch_losses)
    num_batches = len(batch_losses)
    print(f"Epoch {epoch + 1}: Avg Loss = {epoch_loss / num_batches:.4f}")

In [None]:
import matplotlib.pyplot as plt

# Get the i-th image's prediction from the validation dataset
index = 18
image, mask = val_loader.dataset[index]
image = image.unsqueeze(0).to(device)

model.eval()
with torch.no_grad():
    logits = model(image)
    # The training loop resizes logits to the mask size before the loss; do
    # the same here so the displayed prediction lines up with the mask.
    logits = functional.interpolate(logits, size=mask.shape[-2:], mode='bilinear', align_corners=False)
    pred = torch.sigmoid(logits).cpu().squeeze().numpy()

# Visualize input, ground truth and thresholded prediction side by side
fig, axes = plt.subplots(1, 3, figsize=(12, 4))

axes[0].imshow(image[0].permute(1, 2, 0).cpu())
axes[0].set_title("Input")

axes[1].imshow(mask.squeeze(), cmap='gray')
axes[1].set_title("Mask")

axes[2].imshow(pred > 0.5, cmap='gray')
axes[2].set_title("Prediction")

plt.show()