In [1]:
!git clone https://github.com/siapai/tuft-dental-segmentation.git temp
!mv temp/segmentation .
!rm -drf temp

Cloning into 'temp'...
remote: Enumerating objects: 371, done.[K
remote: Counting objects: 100% (27/27), done.[K
remote: Compressing objects: 100% (25/25), done.[K
remote: Total 371 (delta 4), reused 25 (delta 2), pack-reused 344[K
Receiving objects: 100% (371/371), 65.62 MiB | 10.16 MiB/s, done.
Resolving deltas: 100% (11/11), done.


In [2]:
!pip install -q pytorch-lightning
!pip install -q pretrainedmodels
!pip install -q torchmetrics
!pip install -q albumentations


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m802.3/802.3 kB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m868.8/868.8 kB[0m [31m48.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.3/21.3 MB[0m [31m64.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for pretrainedmodels (setup.py) ... [?25l[?25hdone


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import os

# Ensure the destination directory exists
os.makedirs('data', exist_ok=True)

# Unzip the file
!unzip /content/drive/MyDrive/Datasets/TuftDental/Radiographs.zip -d data/
!unzip /content/drive/MyDrive/Datasets/TuftDental/Segmentation.zip -d data/

!cp /content/drive/MyDrive/Datasets/TuftDental/data_split.json .



Archive:  /content/drive/MyDrive/Datasets/TuftDental/Radiographs.zip
   creating: data/Radiographs/
  inflating: data/Radiographs/1.JPG  
  inflating: data/Radiographs/100.JPG  
  inflating: data/Radiographs/1000.JPG  
  inflating: data/Radiographs/1001.JPG  
  inflating: data/Radiographs/1002.JPG  
  inflating: data/Radiographs/1004.JPG  
  inflating: data/Radiographs/1007.JPG  
  inflating: data/Radiographs/1008.JPG  
  inflating: data/Radiographs/1009.JPG  
  inflating: data/Radiographs/101.JPG  
  inflating: data/Radiographs/1010.JPG  
  inflating: data/Radiographs/1011.JPG  
  inflating: data/Radiographs/1012.JPG  
  inflating: data/Radiographs/1013.JPG  
  inflating: data/Radiographs/1014.JPG  
  inflating: data/Radiographs/1015.JPG  
  inflating: data/Radiographs/1016.JPG  
  inflating: data/Radiographs/1017.JPG  
  inflating: data/Radiographs/1018.JPG  
  inflating: data/Radiographs/102.JPG  
  inflating: data/Radiographs/1020.JPG  
  inflating: data/Radiographs/1021.JPG  
  in

In [5]:
import os

import torch
import torchvision
import albumentations as A

print(torch.__version__)
print(torchvision.__version__)
print(A.__version__)

2.3.0+cu121
0.18.0+cu121
1.3.1


In [6]:
import cv2
from albumentations.pytorch import ToTensorV2

simple_transform = A.Compose([
    A.Resize(height=256, width=512),  # Resize to 256x512
    A.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)),  # Normalize image to 0-1 range
    ToTensorV2()
])

In [7]:
import json
from torch.utils.data import Dataset
import cv2
import numpy as np

class DentalDataset(Dataset):
    def __init__(self, data_path: str, split: str, transform=None):
        with open(data_path, 'r') as f:
            self.data = json.load(f)

        self.images = self.data[split]
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]['image']
        mask_path = self.images[idx]['mask']

        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = self.preprocess_mask(mask)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask'].permute(2, 0, 1)

        return image, mask

    def preprocess_mask(self, mask):
        # Threshold the mask to ensure it only contains 0 and 1 values
        _, binary_mask = cv2.threshold(mask, 127, 1, cv2.THRESH_BINARY)
        binary_mask = np.expand_dims(binary_mask, axis=-1)  # Add channel dimension
        return binary_mask

In [9]:
test_dataset = DentalDataset(
    data_path='data_split.json',
    split='test',
    transform=simple_transform
)

print(f'[INFO] Length of test dataset: {len(test_dataset)}')

[INFO] Length of test dataset: 150


In [10]:
import os
from torch.utils.data import DataLoader

NUM_WORKERS = os.cpu_count()

print(f'Number of Workers: {NUM_WORKERS}')
BATCH_SIZE = 32

test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

Number of Workers: 12


In [11]:
import random
import matplotlib.pyplot as plt

def visualize_batch(dataloader, num_images=2):
    images, masks = next(iter(dataloader))

    random_indices = random.sample(range(len(images)), num_images)

    # Plot the images and annotations side by side
    fig, axes = plt.subplots(num_images, 2, figsize=(16, 4 * num_images))

    for i, idx in enumerate(random_indices):

        image = images[i].squeeze(0).permute(1, 2, 0)
        mask = masks[i].squeeze(0)

        axes[i, 0].imshow(image, alpha=1)
        # axes[i, 0].imshow(mask, alpha=0.5)
        axes[i, 0].axis(False)

        axes[i, 1].imshow(image, alpha=0.5)
        axes[i, 1].imshow(mask, alpha=0.5)
        axes[i, 1].axis(False)

    plt.show()

In [19]:
import segmentation as sm
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from torchmetrics import JaccardIndex, Accuracy
import torch
import time

checkpoint_cb = ModelCheckpoint(
    monitor='val_loss',
    save_top_k=1,
    mode='min',
    filename='best_checkpoint',
    verbose=True
)

early_stopping_cb = EarlyStopping(
    monitor='val_loss',
    patience=20,
    mode='min',
    verbose=True
)

lr_monitor = LearningRateMonitor(logging_interval='epoch')


class DentalModel(pl.LightningModule):
    def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):
        super().__init__()
        self.model = sm.create_model(
            arch=arch,
            encoder_name=encoder_name,
            in_channels=in_channels,
            classes=out_classes,
            **kwargs
        )

        self.loss_fn = sm.losses.DiceLoss(sm.losses.BINARY_MODE, from_logits=True)

        self.iou = JaccardIndex(task='binary')  # IoU
        self.accuracy = Accuracy(task='binary')  # Pixel Accuracy

    def forward(self, x):
        return self.model(x)

    def _common_step(self, batch):
        image, mask = batch
        assert image.ndim == 4

        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0

        assert mask.ndim == 4
        assert mask.max() <= 1.0 and mask.min() >= 0

        logits_mask = self.forward(image)
        loss = self.loss_fn(logits_mask, mask)

        return loss, logits_mask

    def _calc_metrics(self, loss, logits_mask, target_mask, stage):
        dice = 1.0 - loss

        prob_mask = logits_mask.sigmoid()
        pred_mask = (prob_mask > 0.5).float()

        iou = self.iou(pred_mask, target_mask)
        acc = self.accuracy(pred_mask, target_mask)

        if stage == 'val':
            self.logger.experiment.add_scalars('dice', {stage: dice}, self.current_epoch)
            self.logger.experiment.add_scalars('iou', {stage: iou}, self.current_epoch)
            self.logger.experiment.add_scalars('accuracy', {stage: acc}, self.current_epoch)

        self.log(f'{stage}_loss', loss, on_step=False, on_epoch=True, prog_bar=True)

        self.log_dict({
            f'{stage}_dice': dice,
            f'{stage}_iou': iou,
            f'{stage}_acc': acc
        }, on_step=False, on_epoch=True)


    def test_step(self, batch, batch_idx):
        _, target_mask = batch
        start_time = time.time()
        loss, logits_mask = self._common_step(batch)
        end_time = time.time()

        inference_time = end_time - start_time
        self.log('inference_time', inference_time, prog_bar=True)

        self._calc_metrics(loss, logits_mask, target_mask, "test")
        return {loss, inference_time}

    def configure_optimizers(self):
        pass


In [20]:
parent_dir = "/content/drive/MyDrive/Experiment/Tuft"
ckpt_path = "version_0/checkpoints/best_checkpoint.ckpt"

checkpoints = [
    {
        "arch": "unet",
        "encoder_name": "resnet34",
        "path": os.path.join(parent_dir, "unet-resnet34", ckpt_path)
    },
    {
        "arch": "unet",
        "encoder_name": "mobilenet_v2",
        "path": os.path.join(parent_dir, "unet-mobilenet_v2", ckpt_path)
    },
    {
        "arch": "unetplusplus",
        "encoder_name": "resnet34",
       "path": os.path.join(parent_dir, "unetplusplus-resnet34", ckpt_path)
    },
    {
        "arch": "unetplusplus",
        "encoder_name": "mobilenet_v2",
        "path": os.path.join(parent_dir, "unetplusplus-mobilenet_v2", ckpt_path)
    },
    {
        "arch": "deeplabv3",
        "encoder_name": "resnet34",
       "path": os.path.join(parent_dir, "deeplabv3-resnet34", ckpt_path)
    },
    {
        "arch": "deeplabv3",
        "encoder_name": "mobilenet_v2",
        "path": os.path.join(parent_dir, "deeplabv3-mobilenet_v2", ckpt_path)
    },
    {
        "arch": "deeplabv3plus",
        "encoder_name": "resnet34",
       "path": os.path.join(parent_dir, "deeplabv3plus-resnet34", ckpt_path)
    },
    {
        "arch": "deeplabv3plus",
        "encoder_name": "mobilenet_v2",
        "path": os.path.join(parent_dir, "deeplabv3plus-mobilenet_v2", ckpt_path)
    },
    {
        "arch": "fpn",
        "encoder_name": "resnet34",
       "path": os.path.join(parent_dir, "fpn-resnet34", ckpt_path)
    },
    {
        "arch": "fpn",
        "encoder_name": "mobilenet_v2",
        "path": os.path.join(parent_dir, "fpn-mobilenet_v2", ckpt_path)
    },
    {
        "arch": "pan",
        "encoder_name": "resnet34",
       "path": os.path.join(parent_dir, "pan-resnet34", ckpt_path)
    },
    {
        "arch": "pan",
        "encoder_name": "mobilenet_v2",
        "path": os.path.join(parent_dir, "pan-mobilenet_v2", ckpt_path)
    },
    {
        "arch": "pspnet",
        "encoder_name": "resnet34",
       "path": os.path.join(parent_dir, "pspnet-resnet34", ckpt_path)
    },
    {
        "arch": "pspnet",
        "encoder_name": "mobilenet_v2",
        "path": os.path.join(parent_dir, "pspnet-mobilenet_v2", ckpt_path)
    },
]


In [21]:
trainer = pl.Trainer()

for checkpoint in checkpoints:
    arch = checkpoint["arch"]
    encoder_name = checkpoint["encoder_name"]
    path = checkpoint["path"]
    model = DentalModel.load_from_checkpoint(path, arch=arch, encoder_name=encoder_name, in_channels=3, out_classes=1)
    print(f"arch: {arch}, encoder_name: {encoder_name}")
    trainer.test(model, dataloaders=test_dataloader, verbose=True)
    print("\n\n")



INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


arch: unet, encoder_name: resnet34


Testing: |          | 0/? [00:00<?, ?it/s]






INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


arch: unet, encoder_name: mobilenet_v2


Testing: |          | 0/? [00:00<?, ?it/s]






INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


arch: unetplusplus, encoder_name: resnet34


Testing: |          | 0/? [00:00<?, ?it/s]






INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


arch: unetplusplus, encoder_name: mobilenet_v2


Testing: |          | 0/? [00:00<?, ?it/s]






INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


arch: deeplabv3, encoder_name: resnet34


Testing: |          | 0/? [00:00<?, ?it/s]






INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


arch: deeplabv3, encoder_name: mobilenet_v2


Testing: |          | 0/? [00:00<?, ?it/s]






INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


arch: deeplabv3plus, encoder_name: resnet34


Testing: |          | 0/? [00:00<?, ?it/s]






INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


arch: deeplabv3plus, encoder_name: mobilenet_v2


Testing: |          | 0/? [00:00<?, ?it/s]






INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


arch: fpn, encoder_name: resnet34


Testing: |          | 0/? [00:00<?, ?it/s]






INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


arch: fpn, encoder_name: mobilenet_v2


Testing: |          | 0/? [00:00<?, ?it/s]






INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


arch: pan, encoder_name: resnet34


Testing: |          | 0/? [00:00<?, ?it/s]






INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


arch: pan, encoder_name: mobilenet_v2


Testing: |          | 0/? [00:00<?, ?it/s]






INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


arch: pspnet, encoder_name: resnet34


Testing: |          | 0/? [00:00<?, ?it/s]






INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


arch: pspnet, encoder_name: mobilenet_v2


Testing: |          | 0/? [00:00<?, ?it/s]




