In [1]:
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.transforms import functional as F
from torchvision import transforms as T
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches


In [2]:
#  Configuration
DATA_ROOT = 'C:\\Users\\91741\\Downloads\\dental_xray' # Path to dataset
NUM_CLASSES = 2  # 1 (cavity) + 1 (background)
BATCH_SIZE = 2
LEARNING_RATE = 0.005
NUM_EPOCHS = 8
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(DEVICE)


cuda


In [3]:
def read_mask_image(mask_path):
    """
    Reads a mask image and converts it into a binary mask tensor.
    Assumes mask images are grayscale where non-zero pixels represent the cavity.
    """
    mask = Image.open(mask_path).convert("L") # Convert to grayscale
    mask = np.array(mask)

    # Convert to binary mask: 1 for cavity, 0 for background
    # Thresholding might be needed if masks are not perfectly binary (e.g., 0-255 values)
    mask[mask > 0] = 1 # Assuming any non-zero value is a cavity
    mask = torch.as_tensor(mask, dtype=torch.uint8)
    return mask


In [4]:
def get_bounding_box_from_mask(mask):
    """
    Generates bounding box coordinates from a binary mask.
    Handles cases where a mask might contain multiple disconnected components,
    treating each as a separate instance.
    """
    # Find contours in the mask
    # cv2.RETR_EXTERNAL retrieves only the extreme outer contours.
    # cv2.CHAIN_APPROX_SIMPLE compresses horizontal, vertical, and diagonal segments.
    contours, _ = cv2.findContours(mask.numpy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    boxes = []
    for contour in contours:
        # Get bounding box for each contour
        x, y, w, h = cv2.boundingRect(contour)
        boxes.append([x, y, x + w, y + h])

    if not boxes:
        # If no contours found (e.g., empty mask), return a dummy box
        # This is important to avoid errors in Mask R-CNN training if some masks are truly empty
        return torch.zeros((0, 4), dtype=torch.float32)

    boxes = torch.as_tensor(boxes, dtype=torch.float32)
    return boxes

In [5]:
class DentalXrayDataset(Dataset):
    """
    Custom Dataset for loading dental X-ray images and their corresponding
    cavity masks and bounding box annotations.
    """
    def __init__(self, root_dir, transforms=None):
        self.root_dir = root_dir
        self.transforms = transforms
        self.image_dir = os.path.join(root_dir, 'def_images')
        self.mask_dir = os.path.join(root_dir, 'mask_images')

        # List all image files. Assuming image names are consistent with mask names.
        self.image_filenames = sorted([f for f in os.listdir(self.image_dir) if f.endswith(('.png', '.jpg', '.jpeg'))])

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_name = self.image_filenames[idx]
        img_path = os.path.join(self.image_dir, img_name)

        # --- MORE ROBUST MASK NAME GENERATION AND CHECK ---
        base_name, img_ext = os.path.splitext(img_name)
        possible_mask_paths = []

        # 1. Try exact same name and extension (e.g., images_7666.png -> images_7666.png)
        possible_mask_paths.append(os.path.join(self.mask_dir, img_name))

        # 2. Try with '_mask.png' suffix (e.g., images_7666.png -> images_7666_mask.png)
        possible_mask_paths.append(os.path.join(self.mask_dir, f"{base_name}_mask.png"))

        # 4. Try exact same name but different common image extensions for masks
        if img_ext.lower() == '.jpg':
            possible_mask_paths.append(os.path.join(self.mask_dir, f"{base_name}.png"))
        elif img_ext.lower() == '.png':
            possible_mask_paths.append(os.path.join(self.mask_dir, f"{base_name}.jpg"))


        mask_path = None
        for p in possible_mask_paths:
            if os.path.exists(p):
                mask_path = p
                break

        if mask_path is None:
            # If no mask file is found after trying all common patterns, raise an error
            raise FileNotFoundError
        # --- END ROBUST MASK NAME GENERATION AND CHECK ---

        # Load image
        img = Image.open(img_path).convert("RGB") # Ensure image is in RGB format

        # Load mask and convert to binary mask tensor
        mask = read_mask_image(mask_path) # This is a single mask image for the whole image

        # Get bounding boxes from the mask.
        boxes = get_bounding_box_from_mask(mask)

        # Create labels: 1 for cavity (foreground)
        labels = torch.ones((boxes.shape[0],), dtype=torch.int64)

        # Find connected components in the mask to get individual instance masks
        num_labels, labeled_mask = cv2.connectedComponents(mask.numpy())
        instance_masks = []
        for i in range(1, num_labels): # Iterate through each detected component (skip background 0)
            instance_mask = (labeled_mask == i).astype(np.uint8)
            instance_masks.append(torch.as_tensor(instance_mask, dtype=torch.uint8))

        if not instance_masks:
            instance_masks = torch.zeros((0, mask.shape[0], mask.shape[1]), dtype=torch.uint8)
        else:
            instance_masks = torch.stack(instance_masks)


        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = instance_masks
        target["image_id"] = torch.tensor([idx]) # Unique ID for the image
        target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) # Area of the bounding box
        target["iscrowd"] = torch.zeros((boxes.shape[0],), dtype=torch.int64) # All instances are not crowd

        if self.transforms is not None:
            # Apply transforms. The Compose class is now updated to handle this correctly.
            img, target = self.transforms(img, target)


        return img, target


In [6]:
class Compose(object):
    """
    Composes several transforms together.
    This custom Compose is simplified and only applies F.to_tensor to the image.
    For augmentations that also affect bounding boxes and masks,
    use `torchvision.transforms.v2.Compose`.
    """
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        # Apply transforms. F.to_tensor only applies to the image.
        # The target dictionary is passed through unchanged.
        for t in self.transforms:
            # F.to_tensor is a functional transform that only takes the image.
            # Other transforms (like geometric augmentations in v2) would handle target too.
            image = t(image) # Apply transform to image

        return image, target

In [7]:
# --- Transformations ---
def get_transform(train):
    """
    Defines image transformations for training and validation.
    """
    transforms = []
    transforms.append(F.to_tensor) # Convert PIL Image to PyTorch Tensor

    if train:
        # Add data augmentation for training
        # Example: Random horizontal flip
        transforms.append(T.RandomHorizontalFlip(0.5))

    return Compose(transforms)

In [8]:
# --- Model Definition ---

def get_model_instance_segmentation(num_classes):
    """
    Loads a pre-trained Mask R-CNN model and modifies its head
    for the specific number of classes.
    """
    # Load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(weights=torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT)

    # Get the number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # Replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Get the number of input features for the mask predictor
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256 # A common choice for the hidden layer size
    # Replace the pre-trained mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)
    return model

In [9]:
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """
    Conceptual training loop for one epoch.
    In a real scenario, you would calculate losses, perform backpropagation,
    and update model weights here.
    """
    model.train() # Set model to training mode
    total_loss = 0

    # Iterate over the data loader
    for i, (images, targets) in enumerate(data_loader):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Forward pass: model returns a dictionary of losses
        loss_dict = model(images, targets)

        # Sum up all losses
        losses = sum(loss for loss in loss_dict.values())
        total_loss += losses.item()

        # Backward pass and optimization
        optimizer.zero_grad() # Clear gradients
        losses.backward()     # Compute gradients
        optimizer.step()      # Update weights

        if i % 10 == 0: # Print loss every 10 batches
            print(f"Epoch: {epoch}, Batch: {i}, Loss: {losses.item():.4f}")

    avg_loss = total_loss / len(data_loader)
    print(f"Epoch {epoch} finished. Average Loss: {avg_loss:.4f}")
    return avg_loss


In [10]:
def visualize_prediction(image_path, model, device, threshold=0.7):
    """
    Performs inference on a single image and visualizes the predicted
    bounding boxes and segmentation masks.
    """
    model.eval() # Set model to evaluation mode
    img = Image.open(image_path).convert("RGB")
    img_tensor = F.to_tensor(img).to(device)

    with torch.no_grad(): # Disable gradient calculation for inference
        prediction = model([img_tensor])

    # Convert image back to numpy for visualization
    img_np = np.array(img)

    fig, ax = plt.subplots(1, figsize=(10, 10))
    ax.imshow(img_np)
    ax.set_title("Predicted Cavities")
    ax.axis('off')

    # Process predictions
    boxes = prediction[0]['boxes'].cpu().numpy()
    labels = prediction[0]['labels'].cpu().numpy()
    scores = prediction[0]['scores'].cpu().numpy()
    masks = prediction[0]['masks'].cpu().numpy() # Masks are (N, 1, H, W)

    # Filter predictions based on score threshold
    for i in range(len(scores)):
        if scores[i] > threshold:
            box = boxes[i]
            label = labels[i]
            score = scores[i]
            mask = masks[i, 0] # Take the first channel of the mask (binary mask)

            # Draw bounding box
            rect = patches.Rectangle((box[0], box[1]), box[2]-box[0], box[3]-box[1],
                                     linewidth=2, edgecolor='r', facecolor='none',
                                     label=f'Cavity: {score:.2f}')
            ax.add_patch(rect)

            # Apply mask overlay
            # Convert mask to binary (0 or 1) and resize to original image dimensions if necessary
            mask = (mask > 0.5).astype(np.uint8) # Threshold mask to binary
            # If mask is smaller than original image, resize it (Mask R-CNN outputs masks at 28x28 by default, then upsamples)
            # Ensure mask is the same size as the image for overlay
            if mask.shape[0] != img_np.shape[0] or mask.shape[1] != img_np.shape[1]:
                mask = cv2.resize(mask, (img_np.shape[1], img_np.shape[0]), interpolation=cv2.INTER_NEAREST)

            # Create a colored overlay for the mask
            color_mask = np.zeros_like(img_np)
            color_mask[mask == 1] = [255, 0, 0] # Red color for cavity mask

            # Blend the mask with the original image
            # alpha controls the transparency of the mask (e.g., 0.5 for 50% transparency)
            ax.imshow(color_mask, alpha=0.4)

    plt.legend()
    plt.show()

In [11]:
if __name__ == "__main__":
    # 1. Preparing dataset and configuring test train split
    dataset = DentalXrayDataset(DATA_ROOT, get_transform(train=True))
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    # Define a collate_fn for the DataLoader, as targets are dictionaries
    def collate_fn(batch):
        return tuple(zip(*batch))

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

    print(f"Dataset loaded. Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

    # Get the Model
    print("Initializing Mask R-CNN model")
    model = get_model_instance_segmentation(NUM_CLASSES)
    model.to(DEVICE)

    #Define Optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=LEARNING_RATE, momentum=0.9, weight_decay=0.0005)
    # And a learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=3,gamma=0.1)

    print(f"This code will train for {NUM_EPOCHS} epochs.")

    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
        train_one_epoch(model, optimizer, train_loader, DEVICE, epoch)
        lr_scheduler.step() # Update learning rate

    # Saving Model 
    torch.save(model.state_dict(), 'mask_rcnn_cavity_detector.pth')
    print("Model weights saved  ")

    


Dataset loaded. Training samples: 240, Validation samples: 60
Initializing Mask R-CNN model
This code will train for 8 epochs.

Epoch 1/8
Epoch: 0, Batch: 0, Loss: 8.1458
Epoch: 0, Batch: 10, Loss: 1.0486
Epoch: 0, Batch: 20, Loss: 1.5216
Epoch: 0, Batch: 30, Loss: 1.3281
Epoch: 0, Batch: 40, Loss: 1.2826
Epoch: 0, Batch: 50, Loss: 0.9276
Epoch: 0, Batch: 60, Loss: 0.7774
Epoch: 0, Batch: 70, Loss: 1.0983
Epoch: 0, Batch: 80, Loss: 0.7748
Epoch: 0, Batch: 90, Loss: 0.9026
Epoch: 0, Batch: 100, Loss: 0.7977
Epoch: 0, Batch: 110, Loss: 0.8276
Epoch 0 finished. Average Loss: 1.2639

Epoch 2/8
Epoch: 1, Batch: 0, Loss: 0.9894
Epoch: 1, Batch: 10, Loss: 0.6987
Epoch: 1, Batch: 20, Loss: 0.8194
Epoch: 1, Batch: 30, Loss: 0.5716
Epoch: 1, Batch: 40, Loss: 3.3626
Epoch: 1, Batch: 50, Loss: 0.7859
Epoch: 1, Batch: 60, Loss: 0.7635
Epoch: 1, Batch: 70, Loss: 0.7549
Epoch: 1, Batch: 80, Loss: 0.9237
Epoch: 1, Batch: 90, Loss: 0.7735
Epoch: 1, Batch: 100, Loss: 0.6959
Epoch: 1, Batch: 110, Loss: 0