In [None]:
# Render matplotlib figures inline in the notebook's cell output.
%matplotlib inline


In [None]:
# Mount Google Drive at /content/drive; the checkpoint and dataset paths used
# later in this notebook are given relative to that mount ("drive/MyDrive/...").
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
import torchvision
# from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
# from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

def load_maskrcnn_model(save_path, device, num_classes=33):
    """
    Load a trained torchvision Mask R-CNN (ResNet-50 FPN) from a saved state dict.

    Args:
        save_path: Path to the checkpoint file containing the model's state dict
        device: Device to load the model to (cuda/cpu)
        num_classes: Number of classes the checkpoint was trained with. Defaults
            to 33, the value that used to be hard-coded here, so existing
            two-argument calls behave exactly as before.
    Returns:
        model: The loaded model, moved to `device` and set to evaluation mode
    """
    # Build the architecture; the head sizes must match the checkpoint or the
    # (strict) load_state_dict call below raises.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(num_classes=num_classes)

    # Only the read needs the file handle, so keep the `with` scope minimal.
    # weights_only=True restricts unpickling to plain tensor data.
    with open(save_path, 'rb') as f:
        state_dict = torch.load(f, map_location=device, weights_only=True)

    model.load_state_dict(state_dict)

    # Move model to device and switch to inference behaviour
    model = model.to(device)
    model.eval()

    return model


# Prefer the GPU when the runtime has one; everything below follows `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoint location, relative to the Colab working directory.
maskrcnn_path = "drive/MyDrive/training_data/quadrant_enumeration/dental_maskrcnn.pth"
# Was bound to the misspelled name `markrcnn_model`, which nothing else in the
# notebook reads; use the spelling the later cells rely on.
maskrcnn_model = load_maskrcnn_model(maskrcnn_path, device)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 131MB/s]


# Modified UNet

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ModifiedUNet(nn.Module):
    """
    Four-level U-Net over a multi-channel input.

    The caller stacks the extra guidance channel onto the image before calling
    the network (default: 2 input channels, i.e. grayscale image + bbox mask).
    The output holds one logit map per class (default: 33).
    """

    def __init__(self, in_channels=2, out_channels=33):
        super().__init__()

        # Contracting path. Layers are created in the same order as the
        # attribute listing below so seeded initialisation stays reproducible.
        self.encoder1 = self._double_conv(in_channels, 64)
        self.encoder2 = self._double_conv(64, 128)
        self.encoder3 = self._double_conv(128, 256)
        self.encoder4 = self._double_conv(256, 512)

        # One shared 2x2 max-pool halves the resolution between levels.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = self._double_conv(512, 1024)

        # Expanding path: each decoder sees the upsampled features concatenated
        # with the matching encoder output, hence twice the channel count.
        self.upconv4 = self._upsample(1024, 512)
        self.decoder4 = self._double_conv(1024, 512)
        self.upconv3 = self._upsample(512, 256)
        self.decoder3 = self._double_conv(512, 256)
        self.upconv2 = self._upsample(256, 128)
        self.decoder2 = self._double_conv(256, 128)
        self.upconv1 = self._upsample(128, 64)
        self.decoder1 = self._double_conv(128, 64)

        # 1x1 convolution mapping the 64 feature maps to per-class logits.
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _double_conv(n_in, n_out):
        """Two 3x3 same-padding convolutions, each followed by an in-place ReLU."""
        layers = [
            nn.Conv2d(n_in, n_out, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_out, n_out, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    @staticmethod
    def _upsample(n_in, n_out):
        """Transposed convolution that doubles the spatial resolution."""
        return nn.ConvTranspose2d(n_in, n_out, kernel_size=2, stride=2)

    def forward(self, x):
        """Run the network on `x`, the already-concatenated input tensor."""
        # Encoder: remember each level's output for its skip connection.
        skips = []
        features = x
        for encoder in (self.encoder1, self.encoder2, self.encoder3, self.encoder4):
            features = encoder(features)
            skips.append(features)
            features = self.pool(features)

        features = self.bottleneck(features)

        # Decoder: walk back up, pairing each stage with the deepest unused skip.
        stages = (
            (self.upconv4, self.decoder4),
            (self.upconv3, self.decoder3),
            (self.upconv2, self.decoder2),
            (self.upconv1, self.decoder1),
        )
        for (upsample, decoder), skip in zip(stages, reversed(skips)):
            features = decoder(torch.cat((upsample(features), skip), dim=1))

        return self.final_conv(features)





class ImprovedHybridSegmentation:
    """
    Two-stage predictor: a detector proposes labelled boxes, the boxes are
    rasterised into an extra input channel, and a U-Net segments the image
    with that channel attached.

    Args:
        maskrcnn: Callable that takes a list of [C, H, W] image tensors and
            returns one dict per image with 'boxes', 'labels' and 'scores'
            (presumably a torchvision Mask R-CNN in eval mode - confirm).
        modified_unet: Callable that takes the concatenated [1, C + 1, H, W]
            tensor (image channels followed by the bbox channel).
        device: Device the models live on; the input image is moved there
            before prediction. None leaves the image where it is.

    All arguments default to None so the class can still be created without
    arguments and configured by assigning the attributes afterwards.
    """

    def __init__(self, maskrcnn=None, modified_unet=None, device=None):
        self.maskrcnn = maskrcnn
        self.modified_unet = modified_unet
        self.device = device

    def predict(self, image, confidence_threshold=0.5):
        """
        Improved hybrid segmentation using bbox information directly in U-Net

        Args:
            image: A single image as [1, C, H, W] (or [C, H, W], which gets a
                batch dimension added).
            confidence_threshold: Detections scoring below this value are
                dropped before the bbox channel is built and are not returned.
        Returns:
            dict with 'masks' (raw U-Net output) and the kept detections'
            'boxes', 'labels' and 'scores'.
        Raises:
            ValueError: if `image` is not a single 3-D/4-D image.
        """
        if image.dim() == 3:
            image = image.unsqueeze(0)
        if image.dim() != 4 or image.shape[0] != 1:
            raise ValueError(
                f"predict expects one image shaped [1, C, H, W] or [C, H, W], got {tuple(image.shape)}"
            )
        if self.device is not None:
            image = image.to(self.device)

        # Inference only: avoid building an autograd graph.
        with torch.no_grad():
            # 1. Get detector predictions. The detector takes a list of 3-D
            #    images, so strip the batch dimension instead of wrapping the
            #    4-D tensor in a list.
            detections = self.maskrcnn([image[0]])[0]

            # Honour confidence_threshold (it used to be accepted but ignored).
            keep = detections['scores'] >= confidence_threshold
            boxes = detections['boxes'][keep]
            labels = detections['labels'][keep]
            scores = detections['scores'][keep]

            # 2. Create bbox channel on the same device as the image.
            height, width = image.shape[-2:]
            bbox_mask = torch.zeros((1, 1, height, width), device=image.device)

            # Fill bbox_mask with the label value inside each box; where boxes
            # overlap, the box processed last wins.
            for box, label in zip(boxes, labels):
                x1, y1, x2, y2 = (int(v) for v in box.tolist())
                bbox_mask[:, :, y1:y2, x1:x2] = label.item()

            # 3. The U-Net takes ONE tensor: image channels + bbox channel.
            unet_pred = self.modified_unet(torch.cat([image, bbox_mask], dim=1))

        return {
            'masks': unet_pred,
            'boxes': boxes,
            'labels': labels,
            'scores': scores
        }

In [None]:

class DiceLoss(nn.Module):
    def __init__(self, weights=None, smooth=1e-6):
        """
        Soft multi-class Dice loss with optional class weights.

        Args:
            weights (torch.Tensor): Weights for each class. Shape: (num_classes,)
            smooth (float): Added to numerator and denominator so a class that
                is absent from both prediction and target scores 1 rather
                than 0 (and to avoid division by zero).
        """
        super().__init__()
        self.weights = weights
        self.smooth = smooth

    def forward(self, pred, target):
        """
        Args:
            pred: Raw logits, shape (batch, num_classes, H, W).
            target: Integer class labels, shape (batch, H, W).
        Returns:
            Scalar loss: 1 - (weighted) mean Dice over classes and batch.
        """
        # Take the class count from the logits instead of hard-coding 33.
        num_classes = pred.shape[1]

        probs = torch.softmax(pred, dim=1)  # Class probabilities
        one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).to(probs.dtype)

        intersection = (probs * one_hot).sum(dim=(2, 3))  # Per-sample, per-class
        cardinality = probs.sum(dim=(2, 3)) + one_hot.sum(dim=(2, 3))

        dice_score = (2.0 * intersection + self.smooth) / (cardinality + self.smooth)

        if self.weights is None:
            return 1.0 - dice_score.mean()

        # Weighted mean over classes: divide by the weight sum so the loss
        # stays in [0, 1] whatever the scale of `weights`. A plain mean of
        # (score * weight) could never approach 0 with normalised weights.
        class_weights = self.weights.to(device=dice_score.device, dtype=dice_score.dtype).view(1, -1)
        weighted_dice = (dice_score * class_weights).sum(dim=1) / class_weights.sum()

        return 1.0 - weighted_dice.mean()


def dice_metric(pred, target, num_classes=33):
    """
    Compute per-class Dice scores from hard (argmax) predictions.

    Args:
        pred: Logits or scores, shape (batch_size, num_classes, H, W).
        target: Integer class labels, shape (batch_size, H, W).
        num_classes: Number of class indices to score (0 .. num_classes - 1).
    Returns:
        List of `num_classes` 0-dim tensors, all on the same device as the
        inputs. Pixel counts are pooled over the whole batch. A class absent
        from both prediction and target scores 1.0.
    """
    pred_labels = torch.argmax(pred, dim=1)  # Shape: (batch_size, H, W)
    dice_scores = []

    for c in range(num_classes):
        pred_c = (pred_labels == c).float()
        target_c = (target == c).float()

        intersection = (pred_c * target_c).sum()
        union = pred_c.sum() + target_c.sum()

        if union == 0:  # Avoid NaN for empty classes
            # Perfect score for empty classes. Build it next to the other
            # scores (same device/dtype) so the list can be stacked or
            # averaged without mixing CPU and GPU tensors.
            dice_scores.append(torch.ones_like(intersection))
        else:
            dice_scores.append((2.0 * intersection) / (union + 1e-6))

    return dice_scores







In [None]:
import os
import json
from PIL import Image, ImageDraw
import numpy as np
from torch.utils.data import Dataset
from torchvision.transforms import functional as TF

class ModifiedToothSegmentationDataset(Dataset):
    """
    Dataset yielding (image, bbox_mask, mask) triples for the modified U-Net.

    For every image listed in a COCO-style JSON file it
      * rasterises the polygon annotations into a label mask
        (0 = background, category_id + 1 for each annotated object), and
      * runs the supplied detection model on the full-resolution image and
        paints each predicted box with its predicted label to form a
        one-channel "bbox mask".

    Args:
        image_dir: Directory containing the image files named in the JSON.
        mask_dir: Stored on the instance but not read by this class.
        coco_json: Path to the annotation file ("images" / "annotations").
        maskrcnn_model: Detection model called as model([image_tensor]); it
            must return dicts with 'boxes' and 'labels'. It is expected to
            already be in eval mode - this class does not change its mode.
        device: Device the detection model runs on.
        transform: Optional callable (image, mask) -> (image, mask).
    """

    def __init__(self, image_dir, mask_dir, coco_json, maskrcnn_model, device, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.maskrcnn_model = maskrcnn_model
        self.device = device

        with open(coco_json, "r") as f:
            coco_data = json.load(f)

        # Index image records by id and group annotations under their image.
        self.image_info = {img["id"]: img for img in coco_data["images"]}
        self.image_annotations = {img_id: [] for img_id in self.image_info.keys()}
        for annotation in coco_data["annotations"]:
            self.image_annotations[annotation["image_id"]].append(annotation)

        self.image_ids = list(self.image_info.keys())
        print(f"Dataset initialized with {len(self.image_ids)} images.")

    @staticmethod
    def _iter_polygons(segmentation):
        """
        Yield each polygon of a segmentation as a list of (x, y) tuples.

        Accepts a list of flat coordinate lists ([[x1, y1, x2, y2, ...], ...])
        or a single flat coordinate list. Each polygon is yielded separately so
        an object made of several parts is not merged into one self-crossing
        outline.
        """
        if isinstance(segmentation, dict):
            raise ValueError("Only polygon segmentations are supported, got a dict (RLE?)")
        if len(segmentation) == 0:
            return
        if isinstance(segmentation[0], (list, tuple)):
            rings = segmentation
        else:
            rings = [segmentation]
        for ring in rings:
            points = np.asarray(ring, dtype=np.float64).reshape(-1, 2)
            if len(points) == 0:
                continue
            yield [(float(x), float(y)) for x, y in points]

    def __getitem__(self, index):
        image_id = self.image_ids[index]
        image_name = self.image_info[image_id]["file_name"]
        image_path = os.path.join(self.image_dir, image_name)

        # Load image once; the same object feeds the detector and the transform.
        image = Image.open(image_path).convert("L")  # Grayscale
        width, height = image.size  # PIL reports (width, height)

        # Create segmentation mask: 0 is background, objects get category_id + 1.
        mask = Image.new("L", (width, height), 0)
        draw = ImageDraw.Draw(mask)
        for annotation in self.image_annotations[image_id]:
            fill_value = annotation["category_id"] + 1
            for polygon in self._iter_polygons(annotation["segmentation"]):
                draw.polygon(polygon, fill=fill_value)

        # Get detector predictions on the original-size image.
        with torch.no_grad():
            # Single image tensor, no batch dimension; passed as a list.
            image_tensor = TF.to_tensor(image).to(self.device)
            maskrcnn_pred = self.maskrcnn_model([image_tensor])[0]

        # Create bbox mask at original size; where boxes overlap, the box
        # processed last wins.
        bbox_mask = torch.zeros((1, height, width), device=self.device)
        for box, label in zip(maskrcnn_pred['boxes'], maskrcnn_pred['labels']):
            x1, y1, x2, y2 = (int(v) for v in box.tolist())
            bbox_mask[:, y1:y2, x1:x2] = label.item()

        # Apply transforms to image and mask
        if self.transform:
            image, mask = self.transform(image, mask)

        # Resize bbox_mask to whatever size the image ended up with, rather
        # than assuming 256x256, so the two can always be concatenated.
        if torch.is_tensor(image):
            target_size = tuple(image.shape[-2:])
        else:
            target_size = (height, width)
        # Nearest-neighbour keeps the values exact label ids.
        bbox_mask = F.interpolate(bbox_mask.unsqueeze(0), size=target_size,
                                  mode='nearest').squeeze(0)

        return image, bbox_mask.cpu(), mask

    def __len__(self):
        return len(self.image_ids)


from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

def train_modified_unet(model, train_loader, val_loader, epochs, device, weights):
    """
    Train the modified U-Net with (optionally class-weighted) cross-entropy and
    report per-class Dice on the validation loader after every epoch.

    Args:
        model: Network taking the image and bbox mask concatenated on dim 1.
        train_loader: Iterable of (images, bbox_masks, masks) batches.
        val_loader: Iterable of (images, bbox_masks, masks) batches.
        epochs: Number of passes over train_loader.
        device: Device to train on.
        weights: Per-class weight tensor for CrossEntropyLoss, or None.
    Returns:
        List with one dict per epoch:
        {'train_loss': float, 'per_class_dice': {class: float}, 'overall_dice': float}.
        The model is trained in place.
    """
    model = model.to(device)
    optimizer = Adam(model.parameters(), lr=1e-3)

    if weights is not None:
        criterion = CrossEntropyLoss(weight=weights.to(device))
    else:
        criterion = CrossEntropyLoss()

    history = []
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        num_train_batches = 0
        for images, bbox_masks, masks in train_loader:
            images = images.to(device)
            bbox_masks = bbox_masks.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            # Concatenate image and bbox information
            inputs = torch.cat([images, bbox_masks], dim=1)
            outputs = model(inputs)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            num_train_batches += 1

        # Counting batches (instead of len(loader)) avoids a ZeroDivisionError
        # when the loader is empty.
        train_loss /= max(num_train_batches, 1)
        print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.4f}")

        # Validation step
        model.eval()
        all_dice_scores = {}
        with torch.no_grad():
            for images, bbox_masks, masks in val_loader:
                images = images.to(device)
                bbox_masks = bbox_masks.to(device)
                masks = masks.to(device)

                inputs = torch.cat([images, bbox_masks], dim=1)
                outputs = model(inputs)
                # Score as many classes as the model outputs, not a fixed 33.
                per_class_dice = dice_metric(outputs, masks, num_classes=outputs.shape[1])
                for c, score in enumerate(per_class_dice):
                    # Plain floats keep the printed summary readable.
                    all_dice_scores.setdefault(c, []).append(float(score))

        mean_dice_scores = {c: sum(scores) / len(scores) for c, scores in all_dice_scores.items()}
        print(f"Epoch {epoch + 1}/{epochs}, Per-Class Dice Scores: {mean_dice_scores}")
        # An empty validation loader yields NaN rather than a crash.
        if mean_dice_scores:
            overall_dice = sum(mean_dice_scores.values()) / len(mean_dice_scores)
        else:
            overall_dice = float("nan")
        print(f"Epoch {epoch + 1}/{epochs}, Overall Val Dice Score: {overall_dice:.4f}")

        history.append({
            "train_loss": train_loss,
            "per_class_dice": mean_dice_scores,
            "overall_dice": overall_dice,
        })

    return history



In [None]:
def transform(image, mask, size=(256, 256)):
    """
    Resize an image/label-mask pair and convert both to tensors.

    Args:
        image: PIL image.
        mask: PIL label mask aligned with `image`.
        size: Target (height, width). Defaults to (256, 256), the size that
            used to be hard-coded, so two-argument calls are unchanged.
    Returns:
        (image, mask): `image` as produced by TF.to_tensor, `mask` as an int64
        tensor of label ids.
    """
    image = TF.resize(image, size)
    # Nearest-neighbour so resizing never invents label values between classes.
    mask = TF.resize(mask, size, interpolation=Image.NEAREST)
    image = TF.to_tensor(image)
    mask = torch.from_numpy(np.array(mask, dtype=np.int64))  # Convert to tensor
    return image, mask

In [None]:

# Paths
base_dir = "drive/MyDrive/training_data/quadrant_enumeration"
image_dir = os.path.join(base_dir, "xrays_2048_1024")
mask_dir = os.path.join(base_dir, "masks_teeth_2048_1024")
coco_json = os.path.join(base_dir, "coco_quadrant_enumeration_2048_1024.json")


num_classes = 33
background_proportion = 0.9
tooth_proportion = 0.1 / (num_classes - 1)  # Each of the 32 tooth classes shares 10%

# Compute inverse-frequency weights.
# The dataset encodes background as label 0 and teeth as category_id + 1, so
# the background weight has to be FIRST. Appending it last handed the near-zero
# weight to tooth class 32 and a full tooth weight to the background.
weights = [1 / background_proportion] + [1 / tooth_proportion] * (num_classes - 1)

# Normalize weights
weights = torch.tensor(weights, dtype=torch.float32)
weights /= weights.sum()  # Normalize so weights sum to ~1

# Print weights for reference
print("Class Weights:", weights)



# Usage:
batch_size = 4

# Initialize Mask R-CNN
maskrcnn_model = load_maskrcnn_model(maskrcnn_path, device)
maskrcnn_model.eval()  # Set to evaluation mode

# 3. Create dataset
dataset = ModifiedToothSegmentationDataset(
    image_dir=image_dir,
    mask_dir=mask_dir,
    coco_json=coco_json,
    maskrcnn_model=maskrcnn_model,
    device=device,
    transform=transform
)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# Seed the split so the same images land in train/val on every run; otherwise
# validation scores are not comparable between sessions.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Initialize modified U-Net (2 input channels for image + bbox mask)
model = ModifiedUNet(in_channels=2, out_channels=num_classes)



Class Weights: tensor([0.0312, 0.0312, 0.0312, 0.0312, 0.0312, 0.0312, 0.0312, 0.0312, 0.0312,
        0.0312, 0.0312, 0.0312, 0.0312, 0.0312, 0.0312, 0.0312, 0.0312, 0.0312,
        0.0312, 0.0312, 0.0312, 0.0312, 0.0312, 0.0312, 0.0312, 0.0312, 0.0312,
        0.0312, 0.0312, 0.0312, 0.0312, 0.0312, 0.0001])
Dataset initialized with 634 images.


In [None]:
# Kick off training: 100 epochs with the class-weighted loss.
train_modified_unet(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=100,
    device=device,
    weights=weights,
)

Epoch 1/100, Train Loss: 1.0250
Epoch 1/100, Per-Class Dice Scores: {0: tensor(0.9205, device='cuda:0'), 1: tensor(0., device='cuda:0'), 2: tensor(0., device='cuda:0'), 3: tensor(0., device='cuda:0'), 4: tensor(0., device='cuda:0'), 5: tensor(0., device='cuda:0'), 6: tensor(0., device='cuda:0'), 7: tensor(0., device='cuda:0'), 8: tensor(0., device='cuda:0'), 9: tensor(0., device='cuda:0'), 10: tensor(0., device='cuda:0'), 11: tensor(0., device='cuda:0'), 12: tensor(0., device='cuda:0'), 13: tensor(0., device='cuda:0'), 14: tensor(0., device='cuda:0'), 15: tensor(0., device='cuda:0'), 16: tensor(0., device='cuda:0'), 17: tensor(0., device='cuda:0'), 18: tensor(0., device='cuda:0'), 19: tensor(0., device='cuda:0'), 20: tensor(0., device='cuda:0'), 21: tensor(0., device='cuda:0'), 22: tensor(0., device='cuda:0'), 23: tensor(0., device='cuda:0'), 24: tensor(0., device='cuda:0'), 25: tensor(0., device='cuda:0'), 26: tensor(0., device='cuda:0'), 27: tensor(0., device='cuda:0'), 28: tensor(0.

In [None]:
# Persist the trained U-Net weights. nn.Module has no `.save()` method, so
# save the state dict with torch.save. This cell needs `model` from the
# training cells to exist in the current kernel session (the recorded
# NameError shows it was last run without them).
torch.save(model.state_dict(), 'bbunet_maskrcnn.pth')

NameError: name 'model' is not defined