In [1]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
class SegmentationDataset(Dataset):
    """Paired image/mask dataset for binary segmentation.

    Image and mask file names are each sorted independently and paired by
    position, so both directories must contain the same number of files whose
    sorted orders line up (e.g. identical base names in both folders).

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the corresponding masks.
        transform: Optional callable applied to every PIL image. It is also
            applied to the masks unless ``mask_transform`` is given.
        mask_transform: Optional callable applied to every PIL mask instead
            of ``transform`` (e.g. a nearest-neighbour resize, so mask values
            are not blended by interpolation).

    Raises:
        ValueError: If the two directories contain a different number of
            image files, since positional pairing would then be wrong.
    """

    # Matched case-insensitively so files such as "img.JPG" are not skipped.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, images_dir, masks_dir, transform=None, mask_transform=None):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.images = self._list_image_files(images_dir)
        self.masks = self._list_image_files(masks_dir)
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {images_dir!r} but "
                f"{len(self.masks)} masks in {masks_dir!r}; images and masks are "
                "paired by sorted position, so the counts must match."
            )
        self.transform = transform
        self.mask_transform = mask_transform

    @classmethod
    def _list_image_files(cls, directory):
        """Return the sorted names of files in ``directory`` with an image extension."""
        return sorted(f for f in os.listdir(directory) if f.lower().endswith(cls.IMAGE_EXTENSIONS))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at position ``idx``.

        The image is converted to RGB and the mask to single-channel "L"
        before any transforms are applied.
        """
        img_path = os.path.join(self.images_dir, self.images[idx])
        mask_path = os.path.join(self.masks_dir, self.masks[idx])

        # Context managers release the file handles promptly; convert()
        # returns a new, fully loaded copy that outlives the ``with`` block.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        with Image.open(mask_path) as msk:
            mask = msk.convert("L")

        if self.transform is not None:
            image = self.transform(image)
        # Fall back to the image transform when no mask-specific one is given
        # (the original behaviour of this class).
        mask_tf = self.mask_transform if self.mask_transform is not None else self.transform
        if mask_tf is not None:
            mask = mask_tf(mask)

        return image, mask

# Preprocessing shared by images and masks: resize to 224x224, then convert
# the PIL image to a float tensor (ToTensor scales 8-bit pixels to [0, 1]).
# NOTE(review): transforms.Resize defaults to bilinear interpolation, so
# applying this same pipeline to masks blends label values along object
# boundaries; a nearest-neighbour resize for masks would keep them binary.
# The pretrained backbone presumably also expects ImageNet-normalised inputs,
# which are not applied here — TODO confirm.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

TRAIN_IMAGES_DIR = "C:\\ProgramData\\anaconda3\\Scripts\\Yana\\Dataset/Split/train/images"
TRAIN_MASKS_DIR = "C:\\ProgramData\\anaconda3\\Scripts\\Yana\\Dataset/Split/train/masks"

train_dataset = SegmentationDataset(TRAIN_IMAGES_DIR, TRAIN_MASKS_DIR, transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)


In [2]:
import torch
import torch.nn as nn
from torchvision.models.segmentation import deeplabv3_resnet50

# Segmentation model for binary masks (see class docstring about the name).
class DeepLabV3Plus(nn.Module):
    """Pretrained torchvision ``deeplabv3_resnet50`` with a new output layer.

    Despite the class name, this wraps torchvision's DeepLabV3 model; there is
    no DeepLabV3+ encoder-decoder here. The final 1x1 convolution of the
    classifier head is replaced so the network emits ``num_classes`` channels
    of raw logits (1 for binary segmentation).

    Args:
        num_classes: Number of output channels of the replaced final layer.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        # Backbone + head with pretrained torchvision weights.
        self.model = deeplabv3_resnet50(pretrained=True)

        # Index 4 of the classifier head is its last conv (256 input channels);
        # swap it for a freshly initialised conv with ``num_classes`` outputs.
        head = self.model.classifier
        head[4] = nn.Conv2d(256, num_classes, kernel_size=1)

    def forward(self, x):
        """Return only the ``'out'`` logits from torchvision's output dict.

        NOTE(review): with pretrained weights torchvision presumably also
        builds the auxiliary head; its ``'aux'`` output is computed and
        discarded here — confirm whether it should be disabled.
        """
        outputs = self.model(x)
        return outputs['out']

# Seed torch's global RNG so the new head's random initialisation and the
# DataLoader shuffling order (drawn from the global RNG when iteration
# starts) are reproducible across "Restart & Run All". GPU kernels may still
# introduce small non-deterministic differences.
SEED = 42
torch.manual_seed(SEED)

# Initialize the model: 1 output channel (logits) for binary segmentation.
model = DeepLabV3Plus(num_classes=1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)




In [3]:
class CombinedLoss(nn.Module):
    """Weighted sum of binary cross-entropy and soft Dice loss on logits.

    BCE is computed by ``nn.BCEWithLogitsLoss`` (which applies the sigmoid
    internally); the Dice term applies an explicit sigmoid first.

    Args:
        alpha: Weight of the BCE term.
        beta: Weight of the Dice term.
    """

    def __init__(self, alpha=0.5, beta=0.5):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.bce = nn.BCEWithLogitsLoss()

    def dice_loss(self, preds, targets):
        """Return ``1 - mean soft Dice`` between ``sigmoid(preds)`` and ``targets``.

        Dice is computed per entry of dims 0 and 1 by summing over dims 2 and 3
        (the spatial dims, assuming NCHW tensors), then averaged. The 1e-6
        smoothing term keeps the denominator non-zero.
        """
        eps = 1e-6
        probs = torch.sigmoid(preds)
        spatial = (2, 3)
        overlap = (probs * targets).sum(dim=spatial)
        denom = probs.sum(dim=spatial) + targets.sum(dim=spatial) + eps
        per_item = (2 * overlap + eps) / denom
        return 1 - per_item.mean()

    def forward(self, preds, targets):
        """Return ``alpha * BCE + beta * Dice`` for logits ``preds``."""
        bce_term = self.bce(preds, targets)
        dice_term = self.dice_loss(preds, targets)
        return self.alpha * bce_term + self.beta * dice_term


# Equal weighting of the BCE and Dice terms.
criterion = CombinedLoss(alpha=0.5, beta=0.5)


In [4]:
import torch.optim as optim

# Adam over every model parameter (pretrained backbone and new head alike).
LEARNING_RATE = 1e-3
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


In [5]:
def train_model(model, train_loader, criterion, optimizer, num_epochs=50, device=None):
    """Train ``model`` in place and report the mean batch loss of each epoch.

    Args:
        model: Network to optimise; it is switched to training mode.
        train_loader: Iterable of ``(images, masks)`` batches.
        criterion: Loss called as ``criterion(outputs, masks)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Device each batch is moved to. Defaults to the device of the
            model's parameters (CPU if it has none), so the function does not
            depend on a notebook-global ``device`` variable.

    Returns:
        List with the mean batch loss of every epoch (the printed values).

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()  # Set the model to training mode
    history = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Counting batches (rather than len(loader)) also works for loaders
        # without __len__ and avoids a ZeroDivisionError on an empty loader.
        if num_batches == 0:
            raise ValueError("train_loader yielded no batches; nothing to train on")

        epoch_loss = running_loss / num_batches
        history.append(epoch_loss)
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")

    return history

# Train on the training split; the loss is printed once per epoch.
NUM_EPOCHS = 200
train_model(model, train_loader, criterion, optimizer, num_epochs=NUM_EPOCHS)


Epoch 1/200, Loss: 0.5628
Epoch 2/200, Loss: 0.5065
Epoch 3/200, Loss: 0.5019
Epoch 4/200, Loss: 0.4972
Epoch 5/200, Loss: 0.4885
Epoch 6/200, Loss: 0.4654
Epoch 7/200, Loss: 0.4442
Epoch 8/200, Loss: 0.4319
Epoch 9/200, Loss: 0.4141
Epoch 10/200, Loss: 0.4030
Epoch 11/200, Loss: 0.3917
Epoch 12/200, Loss: 0.3835
Epoch 13/200, Loss: 0.3758
Epoch 14/200, Loss: 0.3699
Epoch 15/200, Loss: 0.3659
Epoch 16/200, Loss: 0.3618
Epoch 17/200, Loss: 0.3612
Epoch 18/200, Loss: 0.3556
Epoch 19/200, Loss: 0.3543
Epoch 20/200, Loss: 0.3524
Epoch 21/200, Loss: 0.3542
Epoch 22/200, Loss: 0.3518
Epoch 23/200, Loss: 0.3488
Epoch 24/200, Loss: 0.3422
Epoch 25/200, Loss: 0.3412
Epoch 26/200, Loss: 0.3393
Epoch 27/200, Loss: 0.3392
Epoch 28/200, Loss: 0.3427
Epoch 29/200, Loss: 0.3394
Epoch 30/200, Loss: 0.3347
Epoch 31/200, Loss: 0.3294
Epoch 32/200, Loss: 0.3280
Epoch 33/200, Loss: 0.3265
Epoch 34/200, Loss: 0.3224
Epoch 35/200, Loss: 0.3208
Epoch 36/200, Loss: 0.3205
Epoch 37/200, Loss: 0.3186
Epoch 38/2

In [7]:
import torch.nn.functional as F

def iou_and_dice(pred, target, eps=1e-6):
    """Calculate IoU and Dice coefficient between a prediction and a target.

    Both tensors are flattened, so the scores are pooled over every element
    (for a batch: over all images together, not averaged per image).

    Args:
        pred: Prediction tensor, typically a 0/1 mask. Any shape, contiguous
            or not; bool/int tensors are cast to float.
        target: Ground-truth tensor with the same number of elements.
        eps: Small constant that keeps the denominators non-zero.

    Returns:
        ``(iou, dice)`` as Python floats. When both ``pred`` and ``target``
        are entirely zero (an empty mask predicted as empty) both scores are
        1.0, rather than 0 for a perfect prediction.
    """
    # reshape (not view) so non-contiguous inputs work; casting to float up
    # front makes the arithmetic (and clamp) valid for bool/int masks.
    pred = pred.reshape(-1).float()
    target = target.reshape(-1).float()

    total = pred.sum() + target.sum()
    if total.item() == 0:
        return 1.0, 1.0

    intersection = (pred * target).sum()
    union = (pred + target).clamp(0, 1).sum()
    dice = (2 * intersection) / (total + eps)
    iou = intersection / (union + eps)
    return iou.item(), dice.item()

def evaluate_model(model, test_loader, device=None, threshold=0.5):
    """Report mean IoU and Dice of thresholded predictions over ``test_loader``.

    Each batch is scored with ``iou_and_dice`` and the per-batch scores are
    averaged, so every batch carries equal weight regardless of its size.

    Args:
        model: Segmentation network returning one logit channel per pixel.
        test_loader: Iterable of ``(images, masks)`` batches.
        device: Device used for inference. Defaults to the device of the
            model's parameters (CPU if it has none) instead of a
            notebook-global ``device`` variable.
        threshold: Probability above which a pixel counts as foreground.

    Returns:
        ``(avg_iou, avg_dice)`` as Python floats (also printed).

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()  # Set the model to evaluation mode
    total_iou = 0.0
    total_dice = 0.0
    num_batches = 0

    with torch.no_grad():
        for images, masks in test_loader:
            images, masks = images.to(device), masks.to(device)
            logits = model(images)

            # Resize the logits, not the thresholded 0/1 map, to the mask's
            # spatial size: bilinear interpolation of a binary map produces
            # fractional values and no longer yields a hard mask.
            if logits.shape[2:] != masks.shape[2:]:
                logits = F.interpolate(logits, size=masks.shape[2:], mode="bilinear", align_corners=False)

            # Apply sigmoid and threshold
            predictions = (torch.sigmoid(logits) > threshold).float()

            # Calculate IoU and Dice for the batch
            batch_iou, batch_dice = iou_and_dice(predictions, masks)
            total_iou += batch_iou
            total_dice += batch_dice
            num_batches += 1

    if num_batches == 0:
        raise ValueError("test_loader yielded no batches; cannot compute metrics")

    avg_iou = total_iou / num_batches
    avg_dice = total_dice / num_batches

    print(f"Average IoU: {avg_iou:.4f}")
    print(f"Average Dice Coefficient: {avg_dice:.4f}")
    return avg_iou, avg_dice


In [8]:
TEST_IMAGES_DIR = "C:\\ProgramData\\anaconda3\\Scripts\\Yana\\Dataset/Split/test/images"
TEST_MASKS_DIR = "C:\\ProgramData\\anaconda3\\Scripts\\Yana\\Dataset/Split/test/masks"

# Evaluation data: same preprocessing as training and no shuffling, so the
# batches (and any visualised examples) come out in a stable order.
test_dataset = SegmentationDataset(TEST_IMAGES_DIR, TEST_MASKS_DIR, transform)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)


In [9]:
# evaluate_model prints its own averages; echo a compact one-line summary too.
test_metrics = evaluate_model(model, test_loader)
avg_iou, avg_dice = test_metrics
print(f"Test Results: IoU = {avg_iou:.4f}, Dice = {avg_dice:.4f}")


Average IoU: 0.2989
Average Dice Coefficient: 0.4456
Test Results: IoU = 0.2989, Dice = 0.4456


In [None]:
import matplotlib.pyplot as plt
import torch.nn.functional as F

def visualize_predictions(model, dataloader, max_images=5, device=None, threshold=0.5):
    """Plot input, ground-truth mask and predicted mask for the first batch.

    Only the first batch of ``dataloader`` is used. Up to ``max_images``
    samples from it are shown, one three-panel figure per sample. An empty
    ``dataloader`` shows nothing.

    Args:
        model: Segmentation network returning one logit channel per pixel.
        dataloader: Iterable of ``(images, masks)`` batches.
        max_images: Maximum number of samples from the batch to plot.
        device: Device used for inference. Defaults to the device of the
            model's parameters (CPU if it has none).
        threshold: Probability above which a pixel is predicted as foreground.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        return
    images, masks = first_batch

    with torch.no_grad():
        logits = model(images.to(device))
        # Resize logits (not the thresholded 0/1 map) so the shown mask stays binary.
        if logits.shape[2:] != masks.shape[2:]:
            logits = F.interpolate(logits, size=masks.shape[2:], mode="bilinear", align_corners=False)
        predictions = (torch.sigmoid(logits) > threshold).float().cpu()

    # Convert to CPU for visualization
    images = images.cpu()
    masks = masks.cpu()

    # Display up to ``max_images`` samples from the batch.
    for i in range(min(len(images), max_images)):
        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        axes[0].imshow(images[i].permute(1, 2, 0))  # Convert from CHW to HWC
        axes[0].set_title("Input Image")
        axes[1].imshow(masks[i][0], cmap="gray")
        axes[1].set_title("Ground Truth Mask")
        axes[2].imshow(predictions[i][0], cmap="gray")
        axes[2].set_title("Model Prediction")
        for ax in axes:
            ax.axis("off")
        plt.show()
        plt.close(fig)  # free the figure; otherwise open figures accumulate

# Show test-set examples from the first batch next to their masks and predictions.
visualize_predictions(model, dataloader=test_loader)
