In [1]:
%matplotlib inline


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# verify data
%ls drive/MyDrive/training_data/quadrant_enumeration/

coco_quadrant_enumeration_2048_1024.json  [0m[01;34mmasks_quadrant_2048_2048[0m/        [01;34mxrays[0m/
coco_quadrant_enumeration_2048_2048.json  [01;34mmasks_teeth_2048_1024[0m/           [01;34mxrays_2048_1024[0m/
explore_unet.pth                          [01;34mmasks_teeth_2048_2048[0m/           [01;34mxrays_2048_2048[0m/
[01;34mmasks_quadrant_2048_1024[0m/                 train_quadrant_enumeration.json  [01;34myolo_2048_2048[0m/


In [4]:
!nvidia-smi

Fri Dec  6 19:32:23 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   40C    P8               9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [12]:
import torch
from torch.utils.data import Dataset
from torchvision.transforms import functional as TF
from PIL import Image, ImageDraw
import numpy as np
import os
import json


class MaskRCNNDataset(Dataset):
    def __init__(self, image_dir, coco_json, transform=None):
        """
        Args:
            image_dir (str): Path to the directory containing images.
            coco_json (str): Path to the COCO-format JSON file.
            transform (callable, optional): Transformation to apply to images and masks.
        """
        self.image_dir = image_dir
        self.transform = transform

        with open(coco_json, "r") as f:
            coco_data = json.load(f)

        # Group annotations by image_id
        self.image_info = {img["id"]: img for img in coco_data["images"]}
        self.image_annotations = {img_id: [] for img_id in self.image_info.keys()}
        for annotation in coco_data["annotations"]:
            self.image_annotations[annotation["image_id"]].append(annotation)

        # Use only image IDs for indexing
        self.image_ids = list(self.image_info.keys())
        print(f"Dataset initialized with {len(self.image_ids)} images.")

    def __getitem__(self, index):
        """
        Args:
            index (int): Index of the image to retrieve.

        Returns:
            tuple: (image, target) where:
                - image: Tensor of the image.
                - target: Dictionary containing bounding boxes, labels, masks, and image_id.
        """
        image_id = self.image_ids[index]
        image_name = self.image_info[image_id]["file_name"]
        image_path = os.path.join(self.image_dir, image_name)

        # Load image
        image = Image.open(image_path).convert("RGB")

        # Prepare targets
        annotations = self.image_annotations[image_id]
        boxes = []
        labels = []
        masks = []

        for annotation in annotations:
            # Bounding box
            bbox = annotation["bbox"]  # COCO format: [x_min, y_min, width, height]
            x_min, y_min, width, height = bbox
            x_max = x_min + width
            y_max = y_min + height
            boxes.append([x_min, y_min, x_max, y_max])

            # Class label
            labels.append(annotation["category_id"])

            # Segmentation mask
            segmentation = annotation["segmentation"]
            mask = Image.new("L", image.size, 32)  # Default background class is 32
            draw = ImageDraw.Draw(mask)
            points = np.array(segmentation).reshape(-1, 2)
            draw.polygon([tuple(p) for p in points], fill=annotation["category_id"])
            masks.append(np.array(mask, dtype=np.uint8))

        # Convert lists to tensors
        boxes = torch.tensor(boxes, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.int64)
        if masks:
            masks = np.stack(masks)  # Efficient stacking of masks
            masks = torch.tensor(masks, dtype=torch.uint8)
        else:
            masks = torch.empty((0, *image.size[::-1]), dtype=torch.uint8)  # Handle no masks case

        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": torch.tensor([image_id], dtype=torch.int64),
        }

        # Apply transformations
        if self.transform:
            image, target = self.transform(image, target)
        else:
            image = TF.to_tensor(image)

        return image, target

    def __len__(self):
        return len(self.image_ids)


In [13]:
import torchvision

# Initialize Mask R-CNN model with ResNet-50 backbone
model = torchvision.models.detection.maskrcnn_resnet50_fpn(num_classes=33)  # 32 teeth + background
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


In [15]:
from torch.optim import Adam
from torchvision.models.detection import MaskRCNN

def train_maskrcnn(model, train_loader, val_loader, epochs, device):
    model = model.to(device)
    optimizer = Adam(model.parameters(), lr=1e-3)

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for images, targets in train_loader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            optimizer.zero_grad()
            loss_dict = model(images, targets)  # Returns a dictionary of losses
            losses = sum(loss for loss in loss_dict.values())
            losses.backward()
            optimizer.step()

            train_loss += losses.item()

        train_loss /= len(train_loader)
        print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.4f}")

        # # Validation step
        # model.eval()
        # val_loss = 0.0
        # with torch.no_grad():
        #     for images, targets in val_loader:
        #         images = [img.to(device) for img in images]
        #         targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        #         loss_dict = model(images, targets)
        #         losses = sum(loss for loss in loss_dict.values())
        #         val_loss += losses.item()

        # val_loss /= len(val_loader)
        # print(f"Epoch {epoch + 1}/{epochs}, Validation Loss: {val_loss:.4f}")


In [19]:
def validate_dataset(dataset):
    issues_found = False
    for idx in range(len(dataset)):
        try:
            # Retrieve an image and its target
            image, target = dataset[idx]
            print(f"Processing image {idx}: {image}")

            # Check image dimensions
            if image.size(1) == 0 or image.size(2) == 0:
                print(f"Image {idx} has invalid dimensions: {image.size()}")
                issues_found = True

            # Check bounding boxes
            for box in target["boxes"]:
                if any(coord < 0 for coord in box) or box[2] <= box[0] or box[3] <= box[1]:
                    print(f"Image {idx} has invalid bounding box: {box}")
                    issues_found = True

            # Check segmentation masks
            if target["masks"].shape[1:] != image.shape[1:]:
                print(f"Image {idx} has mismatched mask size: {target['masks'].shape} vs {image.shape[1:]}")
                issues_found = True

            # Check labels
            if any(label < 0 or label > 32 for label in target["labels"]):
                print(f"Image {idx} has invalid label: {target['labels']}")
                issues_found = True

        except Exception as e:
            print(f"Error processing image {idx}: {e}")
            issues_found = True

    if not issues_found:
        print("All dataset items look valid!")
    else:
        print("Issues found in the dataset.")


In [20]:
from torch.utils.data import DataLoader
# Paths
base_dir = "drive/MyDrive/training_data/quadrant_enumeration"
image_dir = os.path.join(base_dir, "xrays_2048_1024")
coco_json = os.path.join(base_dir, "coco_quadrant_enumeration_2048_1024.json")

# Dataset and DataLoader
dataset = MaskRCNNDataset(image_dir, coco_json)

validate_dataset(dataset, image_dir)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))

# Training
epochs = 5
# train_maskrcnn(model, train_loader, val_loader, epochs=epochs, device=device)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 0.9961,  ..., 1.0000, 1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         ...,
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 0.9961,  ..., 1.0000, 1.0000, 1.0000]]])
Target: {'boxes': tensor([[ 929.2328,  335.8701, 1001.9001,  647.1163],
        [ 894.7158,  301.3966,  958.2997,  633.3269],
        [ 844.7571,  280.7126,  917.4244,  640.2216],
        [ 798.4317,  321.0958,  893.8075,  613.6278],
        [ 740.2979,  292.5320,  834.7653,  603.7783],
        [ 664.9056,  333.9002,  772.0898,  6

KeyboardInterrupt: 