In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics import jaccard_score
from google.colab import drive

In [None]:
# Mount Google Drive inside the Colab runtime so the dataset files stored there can be read
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Dataset of (image, mask) pairs stored in two directories under identical filenames
class MedicalImageDataset(Dataset):
    """Paired image / segmentation-mask dataset.

    Every regular file found in ``image_dir`` is treated as one sample, and
    its mask is looked up in ``mask_dir`` under the same filename.

    Parameters:
    image_dir (str): Directory containing the input images.
    mask_dir (str): Directory containing one mask per image, same filenames.
    image_transform (callable, optional): Applied to the RGB PIL image.
    mask_transform (callable, optional): Applied to the single-channel PIL mask.
    """

    def __init__(self, image_dir, mask_dir, image_transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_transform = image_transform
        self.mask_transform = mask_transform
        # os.listdir returns entries in arbitrary order and also lists
        # sub-directories (e.g. Jupyter's .ipynb_checkpoints). Sorting makes the
        # index -> sample mapping reproducible; the isfile filter keeps
        # __getitem__ from trying to open a directory as an image.
        self.image_files = sorted(
            name
            for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)`` after the optional transforms."""
        filename = self.image_files[idx]
        image_path = os.path.join(self.image_dir, filename)
        mask_path = os.path.join(self.mask_dir, filename)  # mask shares the image's filename

        # convert() returns a fully loaded copy, so the file handles can be
        # released as soon as the with-blocks exit.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert("L")

        if self.image_transform:
            image = self.image_transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask

# Define transformations for images and masks
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Masks are label maps, not photographs: the default bilinear resize blends
# foreground and background into fractional values along every edge.
# Nearest-neighbour resizing only ever copies existing pixel values.
mask_transform = transforms.Compose([
    transforms.Resize((104, 104), interpolation=transforms.InterpolationMode.NEAREST),  # match model's output size
    transforms.ToTensor()
])

# Load dataset with separate transforms for images and masks
# NOTE(review): image_dir and mask_dir are not defined in any visible cell; they
# must be set (e.g. to folders under the mounted Drive) before this cell runs.
train_dataset = MedicalImageDataset(image_dir, mask_dir, image_transform=image_transform, mask_transform=mask_transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)


In [None]:
class SqueezeNetSegmentation(nn.Module):
    """SqueezeNet 1.1 encoder followed by a small transposed-convolution decoder.

    The encoder is the pretrained ``squeezenet1_1`` feature stack with its
    last module removed. The decoder is a 3x3 convolution (512 -> 256
    channels) and three stride-2 transposed convolutions, each of which
    doubles the spatial size, ending in a Sigmoid so the output holds
    per-pixel probabilities.

    The decoder therefore upsamples the encoder feature map by 8x overall;
    the output is not restored to the input resolution (this notebook resizes
    its masks to 104x104 to match the output for 224x224 inputs).

    Parameters:
    num_classes (int): Number of output channels (1 for binary segmentation).
    """

    def __init__(self, num_classes=1):
        super().__init__()

        # Encoder: every module of the pretrained feature extractor except the last
        backbone = models.squeezenet1_1(pretrained=True)
        encoder_layers = list(backbone.features.children())
        self.features = nn.Sequential(*encoder_layers[:-1])

        # Decoder: channel reduction, then three x2 upsampling stages
        decoder_layers = [
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, num_classes, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),  # squash logits to probabilities
        ]
        self.upsample = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and decode it into a probability map."""
        return self.upsample(self.features(x))

In [None]:
# Model, loss function and optimizer
# The training loop moves every batch to `device`; the model has to live on the
# same device or the first forward pass fails when CUDA is available. It is
# moved before the optimizer is created so the optimizer is built from the
# parameters in their final location.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SqueezeNetSegmentation().to(device)
criterion = nn.BCELoss()  # Binary Cross-Entropy on probabilities (the model already ends in a Sigmoid)
optimizer = optim.Adam(model.parameters(), lr=0.001)




In [None]:
def refine_breast_area(mask, min_area_threshold=5000):
    """
    Remove small regions from a binary segmentation mask.

    External contours are extracted from the mask and only those whose
    contour area is at least ``min_area_threshold`` are redrawn, filled, into
    an otherwise empty mask.

    Parameters:
    mask (numpy.ndarray): 2-D mask whose foreground is 1 (or True) and background 0.
    min_area_threshold (int): Minimum contour area, in pixels, for a region to be kept.
        NOTE(review): 5000 px is close to half of a 104x104 mask (10816 px),
        the mask size used elsewhere in this notebook; confirm it is intended.

    Returns:
    numpy.ndarray: uint8 mask of the same shape containing only 0 and 1.
    """
    # Imported locally: this cell precedes the notebook's `import cv2` cell, so
    # relying on the global name would raise NameError if called before it.
    import cv2

    # OpenCV's contour functions work on 8-bit single-channel images
    mask = (mask * 255).astype(np.uint8)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    refined_mask = np.zeros_like(mask)

    for contour in contours:
        if cv2.contourArea(contour) >= min_area_threshold:
            cv2.drawContours(refined_mask, [contour], -1, 255, thickness=cv2.FILLED)

    # Back to a {0, 1} mask
    return (refined_mask > 0).astype(np.uint8)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def train(model, loader, criterion, optimizer, epochs=10):
    """
    Train ``model`` for ``epochs`` passes over ``loader`` and print the mean loss per epoch.

    Parameters:
    model (torch.nn.Module): Network to optimise; it is moved to the module-level ``device``.
    loader (iterable): Yields ``(images, masks)`` tensor batches.
    criterion (callable): Loss called as ``criterion(outputs, masks)``.
    optimizer (torch.optim.Optimizer): Optimizer stepping the model's parameters.
    epochs (int): Number of passes over ``loader``.

    Returns:
    None. Progress is reported through ``print``.
    """
    # Batches are moved to `device` below, so the model must be there as well;
    # otherwise a CPU-initialised model fails on the first forward pass whenever
    # CUDA is available.
    model.to(device)
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 0
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        # Counting batches (instead of len(loader)) also supports loaders
        # without __len__ and avoids a ZeroDivisionError on an empty loader.
        mean_loss = epoch_loss / num_batches if num_batches else float("nan")
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {mean_loss}")

# Kick off training: ten passes over the training loader
train(model=model, loader=train_loader, criterion=criterion, optimizer=optimizer, epochs=10)


Epoch [1/10], Loss: 0.3884918871853087
Epoch [2/10], Loss: 0.1335462292449342
Epoch [3/10], Loss: 0.08776108589437273
Epoch [4/10], Loss: 0.0651119156844086
Epoch [5/10], Loss: 0.05618860189699464
Epoch [6/10], Loss: 0.050617175073259406
Epoch [7/10], Loss: 0.050209627486765385
Epoch [8/10], Loss: 0.04277064268373781
Epoch [9/10], Loss: 0.03880380766673221
Epoch [10/10], Loss: 0.03658887878474262


In [None]:
def evaluate_with_refinement(model, loader, min_area_threshold=5000):
    """
    Evaluate ``model`` on ``loader`` after contour-based post-processing of its predictions.

    Each predicted probability map is thresholded at 0.5, refined with
    ``refine_breast_area`` and compared with the ground-truth mask using the
    Dice coefficient and the Jaccard index. All metric arithmetic is done in
    NumPy on binary arrays.

    Parameters:
    model (torch.nn.Module): Trained model returning per-pixel probabilities.
    loader (iterable): Yields ``(images, masks)`` tensor batches.
    min_area_threshold (int): Forwarded to ``refine_breast_area``.

    Returns:
    tuple[float, float]: ``(avg_dice, avg_jaccard)`` over all samples, or
    ``(nan, nan)`` when the loader yields no samples.
    """
    model.eval()
    dice_scores = []
    jaccard_indices = []

    with torch.no_grad():
        for images, masks in loader:
            # Forward pass and binary thresholding of the probabilities
            outputs = model(images.to(device))
            preds_np = (outputs > 0.5).squeeze(1).cpu().numpy()

            # Resized masks can contain fractional values; binarise them so both
            # metrics compare two binary masks.
            # NOTE(review): assumes foreground pixels are stored close to 1.0
            # after ToTensor (i.e. ~255 in the mask file); confirm.
            masks_np = masks.cpu().numpy() > 0.5

            for pred, mask in zip(preds_np, masks_np):
                pred_flat = refine_breast_area(pred, min_area_threshold).astype(bool).ravel()
                mask_flat = mask.ravel()

                intersection = np.logical_and(pred_flat, mask_flat).sum()
                total = pred_flat.sum() + mask_flat.sum()
                union = total - intersection

                dice_scores.append(2 * intersection / (total + 1e-7))
                # TP / (TP + FP + FN); defined as 0 when both masks are empty
                jaccard_indices.append(intersection / union if union else 0.0)

    avg_dice = float(np.mean(dice_scores)) if dice_scores else float("nan")
    avg_jaccard = float(np.mean(jaccard_indices)) if jaccard_indices else float("nan")
    print(f"Avg Dice Coefficient: {avg_dice}, Avg Jaccard Index: {avg_jaccard}")
    return avg_dice, avg_jaccard


In [None]:
pip install opencv-python




In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np

def refine_breast_area(mask, min_area_threshold=5000):
    """
    Keep only the large regions of a binary segmentation mask.

    Parameters:
    mask (numpy.ndarray): Binary mask where the breast area is 1, and background is 0.
    min_area_threshold (int): Smallest contour area for which a region is retained.

    Returns:
    numpy.ndarray: uint8 mask (values 0 and 1) with the smaller regions removed.
    """
    # OpenCV's contour routines operate on 8-bit images
    mask_u8 = (mask * 255).astype(np.uint8)

    # Outer boundaries of every connected region
    outlines, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Start from an all-background canvas and paint back the regions that pass the area test
    canvas = np.zeros_like(mask_u8)
    large_outlines = (c for c in outlines if cv2.contourArea(c) >= min_area_threshold)
    for outline in large_outlines:
        cv2.drawContours(canvas, [outline], -1, 255, thickness=cv2.FILLED)

    # Collapse the 0/255 canvas to a 0/1 mask
    return (canvas > 0).astype(np.uint8)

def overlay_mask_on_image(image, mask, alpha=0.5):
    """
    Overlay a binary mask on an image with transparency.

    Pixels where ``mask == 1`` are painted red on a copy of the image, which
    is then blended with the original. A mask whose height/width differ from
    the image is first resized to the image size with nearest-neighbour
    interpolation, so a lower-resolution prediction can be overlaid directly.

    Parameters:
    image (numpy.ndarray): The original image as an (H, W, 3) array.
    mask (numpy.ndarray): The binary mask to overlay (1 = region to highlight).
    alpha (float): Weight of the coloured overlay in the blend.

    Returns:
    numpy.ndarray: Image with the mask overlayed.
    """
    # Normalise to uint8 {0, 1}: cv2.resize does not accept boolean arrays
    mask_u8 = (np.asarray(mask) == 1).astype(np.uint8)

    height, width = image.shape[:2]
    if mask_u8.shape[:2] != (height, width):
        # cv2.resize takes dsize as (width, height); INTER_NEAREST keeps values in {0, 1}
        mask_u8 = cv2.resize(mask_u8, (width, height), interpolation=cv2.INTER_NEAREST)

    overlay = image.copy()
    overlay[mask_u8 == 1] = [255, 0, 0]  # Color the mask region red

    # Blend the original image and mask overlay
    return cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0)

def test_and_visualize_with_roi(model, test_loader, min_area_threshold=5000):
    """
    Run the model on test data and visualize the refined ROI mask overlayed on the original image.

    Only the first batch of ``test_loader`` is displayed. For every sample in
    it a four-panel figure is shown: input image, ground-truth mask, refined
    prediction and the prediction overlaid on the image.

    Parameters:
    model (torch.nn.Module): The trained model.
    test_loader (DataLoader): DataLoader for the test dataset.
    min_area_threshold (int): Minimum area threshold for keeping contours in the mask.
    """
    model.eval()
    with torch.no_grad():
        for image, mask in test_loader:
            image = image.to(device)
            output = model(image)

            # Binary thresholding, then contour-based refinement of each prediction
            pred_np = (output > 0.5).squeeze(1).cpu().numpy()
            refined_preds = [refine_breast_area(single_pred, min_area_threshold) for single_pred in pred_np]

            for j, refined_pred in enumerate(refined_preds):
                img = image[j].cpu().numpy().transpose(1, 2, 0)  # Convert tensor to HWC format
                img = (img * 255).astype(np.uint8)  # Scale image to [0, 255] for display
                mask_gt = mask[j].cpu().squeeze().numpy()  # Ground truth mask

                # The prediction can have a lower resolution than the input image;
                # bring it to the image size before overlaying, otherwise the
                # boolean indexing inside the overlay fails on the shape mismatch.
                height, width = img.shape[:2]
                roi_mask = refined_pred
                if roi_mask.shape != (height, width):
                    roi_mask = cv2.resize(roi_mask, (width, height), interpolation=cv2.INTER_NEAREST)
                overlay_img = overlay_mask_on_image(img, roi_mask)

                # Display the images
                fig, axes = plt.subplots(1, 4, figsize=(15, 5))
                axes[0].imshow(img, cmap='gray')
                axes[0].set_title("Original Image")
                axes[1].imshow(mask_gt, cmap='gray')
                axes[1].set_title("Ground Truth Mask")
                axes[2].imshow(refined_pred, cmap='gray')
                axes[2].set_title("Refined Predicted Mask")
                axes[3].imshow(overlay_img)
                axes[3].set_title("ROI Overlay on Original Image")
                plt.show()
                plt.close(fig)  # release the figure once it has been rendered
            break  # Display one batch for example
