In [2]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from torchvision import transforms as TF
from PIL import Image
import matplotlib.pyplot as plt
from skimage.segmentation import quickshift
from scipy import stats
from collections import deque

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Copy model and dataset from Drive
!cp /content/drive/MyDrive/UNet_segmentation.pth /content/UNet_segmentation.pth
!unzip /content/drive/MyDrive/main_combined_dataset.zip -d /content/main_combined_dataset

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

# Dataset class
class SegmentationOrientedDefectDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=True):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_names = sorted(os.listdir(image_dir))
        self.mask_names = sorted(os.listdir(mask_dir))

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_names[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_names[idx])

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")  # grayscale (for binary)

        if self.transform:
            image, mask = self.augment(image, mask)

        # Normalize image and convert to tensor
        image = TF.to_tensor(image)
        image = TF.normalize(image, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        # Convert mask to tensor and binarize
        mask = TF.to_tensor(mask)
        mask = (mask > 0.5).float()  # Ensure binary mask (0 or 1)

        return image, mask

    def augment(self, image, mask):
        # Random Horizontal Flip
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)

        # Random Vertical Flip
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask)

        # Random rotation
        angle = random.choice([0, 90, 180, 270])
        image = TF.rotate(image, angle)
        mask = TF.rotate(mask, angle)

        return image, mask

# Set dataset paths (from extracted folder)
image_dir = "/content/main_combined_dataset/images"
mask_dir = "/content/main_combined_dataset/masks"

# Load dataset
full_dataset = SegmentationOrientedDefectDataset(image_dir, mask_dir, transform=True)

# Subset and split
random_indices = random.sample(range(len(full_dataset)), 5000)
dataset = Subset(full_dataset, random_indices)

train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
batch_size = 64

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Visualize dataset
def visualize_dataset(data_loader, num_samples=4):
    images, masks = next(iter(data_loader))

    images = images * 0.5 + 0.5  # Denormalize
    images = images.numpy()
    masks = masks.numpy()

    plt.figure(figsize=(12, 4 * num_samples))

    for i in range(num_samples):
        img = images[i].transpose(1, 2, 0)  # CHW -> HWC
        mask = masks[i][0]  # binary mask (1, H, W) -> (H, W)

        plt.subplot(num_samples, 2, 2 * i + 1)
        plt.imshow(img)
        plt.title("Image")
        plt.axis("off")

        plt.subplot(num_samples, 2, 2 * i + 2)
        plt.imshow(mask, cmap='gray')
        plt.title("Mask")
        plt.axis("off")

    plt.tight_layout()
    plt.show()

visualize_dataset(train_loader, num_samples=5)

# UNet model definition
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        def conv_block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True)
            )

        self.encoder1 = conv_block(3, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.bottle_neck = conv_block(256, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.decoder3 = conv_block(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.decoder2 = conv_block(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.decoder1 = conv_block(128, 64)
        self.output = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        e1 = self.encoder1(x)
        e2 = self.encoder2(self.pool1(e1))
        e3 = self.encoder3(self.pool2(e2))
        b = self.bottle_neck(self.pool3(e3))
        d3 = self.up3(b)
        d3 = self.decoder3(torch.cat([d3, e3], dim=1))
        d2 = self.up2(d3)
        d2 = self.decoder2(torch.cat([d2, e2], dim=1))
        d1 = self.up1(d2)
        d1 = self.decoder1(torch.cat([d1, e1], dim=1))
        out = self.output(d1)
        return out

# Load pre-trained model
model = UNet().to(device)
model.load_state_dict(torch.load("/content/UNet_segmentation.pth"))
model.eval()

# Define evaluation metrics
def dice_coefficient(preds, targets, threshold=0.5):
    """
    Calculate the Dice coefficient between predicted and target binary masks.

    Args:
        preds (torch.Tensor): Predicted values (probabilities) from the model, typically after sigmoid.
        targets (torch.Tensor): Ground truth binary mask with values 0 or 1.
        threshold (float, optional): Threshold for binarizing predictions. Defaults to 0.5.

    Returns:
        float: Dice coefficient value.
    """
    # Binarize predictions: convert probabilities to 0 or 1 based on threshold
    preds = (preds > threshold).float()

    # Calculate intersection: sum of element-wise product of preds and targets
    intersection = (preds * targets).sum()

    # Calculate Dice coefficient: 2 * intersection / (sum of preds + sum of targets)
    # Add small epsilon (1e-8) to denominator to avoid division by zero
    dice = (2. * intersection) / (preds.sum() + targets.sum() + 1e-8)

    return dice

def iou_score(preds, targets, threshold=0.5):
    preds = (torch.sigmoid(preds) > threshold).float()
    intersection = (preds * targets).sum()
    union = preds.sum() + targets.sum() - intersection
    return intersection / (union + 1e-8)

# Visualization function for single prediction
def visualize_prediction(image, mask, pred, alpha=0.6):
    image_np = (image * 0.5 + 0.5).cpu().permute(1, 2, 0).numpy()  # Denormalize
    mask_np = mask.cpu().squeeze().numpy()
    pred_np = pred.cpu().squeeze().numpy()

    overlay = image_np.copy()
    green_mask = np.zeros_like(image_np)
    green_mask[..., 1] = 1  # Green channel
    overlay_mask = np.where(pred_np[..., None] > 0, green_mask, 0)
    overlay = (1 - alpha) * image_np + alpha * overlay_mask
    overlay = np.clip(overlay, 0, 1)

    fig, axs = plt.subplots(1, 4, figsize=(16, 4))
    axs[0].imshow(image_np)
    axs[0].set_title('Original Image')
    axs[1].imshow(mask_np, cmap='gray')
    axs[1].set_title('Ground Truth Mask')
    axs[2].imshow(pred_np, cmap='gray')
    axs[2].set_title('Predicted Mask')
    axs[3].imshow(overlay)
    axs[3].set_title('Overlay on Image')

    for ax in axs:
        ax.axis('off')
    plt.tight_layout()
    plt.show()

# Predict and evaluate without refinement
def predict_and_evaluate(model, test_loader, device, save_dir="predicted_masks", threshold=0.5, num_visualize=5):
    os.makedirs(save_dir, exist_ok=True)
    model.eval()
    dice_scores = []
    iou_scores = []

    with torch.no_grad():
        for idx, (images, masks) in enumerate(test_loader):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            preds = torch.sigmoid(outputs)
            preds_bin = (preds > threshold).float()

            for i in range(images.size(0)):
                pred_mask = preds_bin[i][0].cpu().numpy() * 255
                pred_mask_img = Image.fromarray(pred_mask.astype(np.uint8))
                pred_mask_img.save(os.path.join(save_dir, f"pred_{idx * test_loader.batch_size + i}.png"))

                dice = dice_coefficient(preds[i], masks[i])
                iou = iou_score(preds[i], masks[i])
                dice_scores.append(dice.item())
                iou_scores.append(iou.item())

                if idx * test_loader.batch_size + i < num_visualize:
                    visualize_prediction(images[i], masks[i], preds_bin[i])

    print(f"\nAverage Dice Score on Test Set: {np.mean(dice_scores):.4f}")
    print(f"Average IoU Score on Test Set: {np.mean(iou_scores):.4f}")

predict_and_evaluate(model, test_loader, device, save_dir="predicted_masks", threshold=0.5)

# Region Growth Refinement
def region_growing(image, seed_spacing=10, threshold=0.1):
    H, W = image.shape[:2]
    label_map = np.zeros((H, W), dtype=int)
    label_counter = 1
    neighbors = [(-1, 0), (1, 0), (0, -1), (0, 1)]

    for i in range(0, H, seed_spacing):
        for j in range(0, W, seed_spacing):
            if label_map[i, j] == 0:
                seed_color = image[i, j]
                queue = deque([(i, j)])
                label_map[i, j] = label_counter

                while queue:
                    x, y = queue.popleft()
                    for dx, dy in neighbors:
                        nx, ny = x + dx, y + dy
                        if 0 <= nx < H and 0 <= ny < W and label_map[nx, ny] == 0:
                            color_diff = np.linalg.norm(image[nx, ny] - seed_color)
                            if color_diff < threshold:
                                label_map[nx, ny] = label_counter
                                queue.append((nx, ny))
                label_counter += 1

    return label_map

def refine_with_region_growth(pred_mask, image, seed_spacing=10, threshold=0.1):
    label_map = region_growing(image, seed_spacing, threshold)
    refined_mask = np.zeros_like(pred_mask)

    for label in np.unique(label_map):
        if label == 0:
            continue
        region_mask = (label_map == label)
        majority_vote = stats.mode(pred_mask[region_mask].flatten(), keepdims=False).mode
        refined_mask[region_mask] = majority_vote

    return refined_mask

def evaluation_with_region_growth_refinement(model, test_loader, save_dir, threshold=0.5, num_visualize=5):
    os.makedirs(save_dir, exist_ok=True)
    model.eval()

    with torch.no_grad():
        for idx, (images, masks) in enumerate(test_loader):
            images = images.to(device)
            outputs = model(images)
            preds = torch.sigmoid(outputs)
            preds_bin = (preds > threshold).float()

            for i in range(images.size(0)):
                image = (images[i] * 0.5 + 0.5).cpu().permute(1, 2, 0).numpy()  # Denormalize
                pred_mask = preds_bin[i].squeeze().cpu().numpy()

                refined_mask = refine_with_region_growth(pred_mask, image, seed_spacing=10, threshold=0.1)

                refined_img = Image.fromarray((refined_mask * 255).astype(np.uint8))
                refined_img.save(os.path.join(save_dir, f"region_growth_refined_{idx * test_loader.batch_size + i}.png"))

                if idx * test_loader.batch_size + i < num_visualize:
                    visualize_prediction(images[i], masks[i], torch.tensor(refined_mask))

save_dir = "UNet_region_growth_refined"
evaluation_with_region_growth_refinement(model, test_loader, save_dir, threshold=0.5)

# Mean Shift Clustering Refinement
def mean_shift_clustering(image, kernel_size=3, max_dist=10):
    segments = quickshift(image, kernel_size=kernel_size, max_dist=max_dist, convert2lab=False)
    return segments

def refine_with_mean_shift(pred_mask, image, kernel_size=3, max_dist=10):
    segments = mean_shift_clustering(image, kernel_size, max_dist)
    refined_mask = np.zeros_like(pred_mask)

    for label in np.unique(segments):
        region_mask = (segments == label)
        majority_vote = stats.mode(pred_mask[region_mask].flatten(), keepdims=False).mode
        refined_mask[region_mask] = majority_vote

    return refined_mask

def evaluation_with_mean_shift_refinement(model, test_loader, save_dir, threshold=0.5, num_visualize=5):
    os.makedirs(save_dir, exist_ok=True)
    model.eval()

    with torch.no_grad():
        for idx, (images, masks) in enumerate(test_loader):
            images = images.to(device)
            outputs = model(images)
            preds = torch.sigmoid(outputs)
            preds_bin = (preds > threshold).float()

            for i in range(images.size(0)):
                image = (images[i] * 0.5 + 0.5).cpu().permute(1, 2, 0).numpy()  # Denormalize
                pred_mask = preds_bin[i].squeeze().cpu().numpy()

                refined_mask = refine_with_mean_shift(pred_mask, image, kernel_size=3, max_dist=10)

                refined_img = Image.fromarray((refined_mask * 255).astype(np.uint8))
                refined_img.save(os.path.join(save_dir, f"mean_shift_refined_{idx * test_loader.batch_size + i}.png"))

                if idx * test_loader.batch_size + i < num_visualize:
                    visualize_prediction(images[i], masks[i], torch.tensor(refined_mask))

save_dir = "UNet_mean_shift_refined"
evaluation_with_mean_shift_refinement(model, test_loader, save_dir, threshold=0.5)

# Full pipeline visualization
def visualize_full_pipeline(image, mask, pred_mask, region_growth_refined, mean_shift_refined):
    image_np = (image * 0.5 + 0.5).cpu().permute(1, 2, 0).numpy()  # Denormalize
    mask_np = mask.cpu().squeeze().numpy()
    pred_mask_np = pred_mask
    region_growth_np = region_growth_refined
    mean_shift_np = mean_shift_refined

    fig, axs = plt.subplots(1, 5, figsize=(20, 4))

    axs[0].imshow(image_np)
    axs[0].set_title("Original Image")
    axs[1].imshow(mask_np, cmap='gray')
    axs[1].set_title("Ground Truth")
    axs[2].imshow(pred_mask_np, cmap='gray')
    axs[2].set_title("Predicted Mask")
    axs[3].imshow(region_growth_np, cmap='gray')
    axs[3].set_title("Region Growth Refined")
    axs[4].imshow(mean_shift_np, cmap='gray')
    axs[4].set_title("Mean Shift Refined")

    for ax in axs:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Visualize full pipeline for a few test images
num_visualize = 5
with torch.no_grad():
    for idx in range(num_visualize):
        image, mask = test_dataset[idx]
        image_input = image.unsqueeze(0).to(device)

        output = model(image_input)
        pred = torch.sigmoid(output).squeeze().cpu().numpy()
        pred_bin = (pred > 0.5).astype(np.uint8)

        image_np = (image * 0.5 + 0.5).permute(1, 2, 0).numpy()  # Denormalize

        # Apply refinements
        region_growth_refined = refine_with_region_growth(pred_bin, image_np)
        mean_shift_refined = refine_with_mean_shift(pred_bin, image_np)

        visualize_full_pipeline(image, mask, pred_bin, region_growth_refined, mean_shift_refined)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Archive:  /content/drive/MyDrive/main_combined_dataset.zip
replace /content/main_combined_dataset/main_combined_dataset/images/img_00000.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: N
Using cuda device


FileNotFoundError: [Errno 2] No such file or directory: '/content/main_combined_dataset/images'