In [1]:
import os
import torch
import cv2
import numpy as np


class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, images_dir, masks_dir, bbox_dir, transform=None):
       
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.bbox_dir = bbox_dir
        self.transform = transform
        self.image_files = sorted(os.listdir(images_dir))
        self.mask_files = sorted(os.listdir(masks_dir))
        self.bbox_files = sorted(os.listdir(bbox_dir))

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        
        # Get image metadata
        image_name = self.image_files[idx]
        image_path = os.path.join(self.images_dir, image_name)
        
        # Use cv2 to read image (BGR format)
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert to RGB
        
        # Get the corresponding mask file
        mask_name = image_name.replace('.jpg', '_mask.png')  # Adjust if needed
        mask_path = os.path.join(self.masks_dir, mask_name)
        
        # Use cv2 to read the mask (grayscale)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)  # '0' flag is for grayscale

        bbox_name = image_name.replace('.jpg', '.txt')
        bbox_path = os.path.join(self.bbox_dir, bbox_name)
        
        scale_x = 256 / 1024
        scale_y = 256 / 1024
        
        with open(bbox_path, 'r') as f:
            lines = f.readlines()
        
        for line in lines:
            coords = list(map(int, line.strip().split()))
            x1, y1, x2, y2 = coords[1:]

        bboxes = [x1, y1, x2, y2]
        
        x1 = int(x1 * scale_x)
        y1 = int(y1 * scale_y)
        x2 = int(x2 * scale_x)
        y2 = int(y2 * scale_y)
        
        #bboxes = [x1, y1, x2, y2]
        
        bboxes = torch.tensor(bboxes)
        
        
        # Apply transformations to both image and mask (if provided)
        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]
            

        return image, mask, bboxes


In [2]:
from torchvision.transforms import Compose, ToTensor, Normalize
import cv2

# Define simplified transformations
def transform(image, mask):
    # Resize image and mask to the same size (for consistency)
    image = cv2.resize(image, (256, 256))  # Resize image to 256x256
    mask = cv2.resize(mask, (256, 256))    # Resize mask to 256x256
    
    # Convert image and mask to tensor
    image = ToTensor()(image)  # Convert image to tensor
    mask = ToTensor()(mask)    # Convert mask to tensor
    
    # Normalize the image (not the mask)
    image = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image)
    

    return {"image": image, "mask": mask}


In [None]:

images_dir = 'knee_segmentation_robo/valid/new/resized_images'
masks_dir = 'knee_segmentation_robo/valid/new/masks'
bbox_dir = 'knee_segmentation_robo/valid/new/bbox_coords'


# Initialize dataset with transform
test_dataset = CustomDataset(images_dir, masks_dir, bbox_dir, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

# Iterate through the DataLoader (example)
for images, masks, bboxes in test_loader:
    print(images.shape)  # Batch of images
    print(masks.shape)   # Corresponding batch of masks
    print(bboxes, bboxes.shape) 
    bbboxes = bboxes.unsqueeze(1)
    print(bbboxes, bbboxes.shape)



torch.Size([1, 3, 256, 256])
torch.Size([1, 1, 256, 256])
tensor([[ 98,   0, 794, 483]]) torch.Size([1, 4])
tensor([[[ 98,   0, 794, 483]]]) torch.Size([1, 1, 4])
torch.Size([1, 3, 256, 256])
torch.Size([1, 1, 256, 256])
tensor([[ 225,  508,  846, 1025]]) torch.Size([1, 4])
tensor([[[ 225,  508,  846, 1025]]]) torch.Size([1, 1, 4])
torch.Size([1, 3, 256, 256])
torch.Size([1, 1, 256, 256])
tensor([[ 310,  548,  898, 1025]]) torch.Size([1, 4])
tensor([[[ 310,  548,  898, 1025]]]) torch.Size([1, 1, 4])
torch.Size([1, 3, 256, 256])
torch.Size([1, 1, 256, 256])
tensor([[ 328,  450,  982, 1025]]) torch.Size([1, 4])
tensor([[[ 328,  450,  982, 1025]]]) torch.Size([1, 1, 4])
torch.Size([1, 3, 256, 256])
torch.Size([1, 1, 256, 256])
tensor([[197, 355, 860, 966]]) torch.Size([1, 4])
tensor([[[197, 355, 860, 966]]]) torch.Size([1, 1, 4])
torch.Size([1, 3, 256, 256])
torch.Size([1, 1, 256, 256])
tensor([[324, 504, 921, 995]]) torch.Size([1, 4])
tensor([[[324, 504, 921, 995]]]) torch.Size([1, 1, 4]

In [None]:
model_type = 'vit_b'
checkpoint = 'models/sam-vit-base_custom.pth'
device = 'cuda:0'

In [5]:
from segment_anything import SamPredictor, sam_model_registry
sam_model = sam_model_registry[model_type](checkpoint=checkpoint)
sam_model.to(device)
sam_model.eval()

  state_dict = torch.load(f)


Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): LayerNorm2d()
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): LayerNorm2d()
    )


In [6]:
import monai

# Note: Hyperparameter tuning could improve performance here
optimizer = torch.optim.AdamW(sam_model.mask_decoder.parameters(), lr=1e-5, weight_decay=0.1)

loss_fn = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

In [7]:
def dice_coefficient(pred, target, smooth=1e-5):
    pred = pred.float()
    target = target.float()
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum()
    dice = (2 * intersection + smooth) / (union + smooth)
    return dice

def iou_score(pred, target, smooth=1e-5):
    pred = pred.float()
    target = target.float()
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    iou = (intersection + smooth) / (union + smooth)
    return iou

def precision_recall(pred, target, smooth=1e-5):
    pred = pred.float()
    target = target.float()
    true_positives = (pred * target).sum()
    false_positives = (pred * (1 - target)).sum()
    false_negatives = ((1 - pred) * target).sum()
    
    precision = (true_positives + smooth) / (true_positives + false_positives + smooth)
    recall = (true_positives + smooth) / (true_positives + false_negatives + smooth)
    
    return precision, recall

In [8]:
# Initialize variables to track metrics
test_loss = 0
test_dice = 0
test_iou = 0
test_precision = 0
test_recall = 0

with torch.no_grad():
    for i, data in enumerate(test_loader):
        imgs, msks, bbox = data
        imgs = imgs.to(device)
        msks = msks.to(device)
        input_bbox = bbox.to(device)

        # Forward pass
        input_image = sam_model.preprocess(imgs)
        image_embedding = sam_model.image_encoder(input_image)
        sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
            points=None,
            boxes=input_bbox,
            masks=None,
        )
        low_res_masks, _ = sam_model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

        # Compute loss
        loss = loss_fn(low_res_masks, msks.float())
        test_loss += loss.item()

        # Threshold predictions to binary masks
        pred_masks = (low_res_masks > 0.5).float()

        # Compute metrics
        test_dice += dice_coefficient(pred_masks, msks)
        test_iou += iou_score(pred_masks, msks)
        precision, recall = precision_recall(pred_masks, msks)
        test_precision += precision
        test_recall += recall

    # Average metrics over the test dataset
    test_loss /= (i + 1)
    test_dice /= (i + 1)
    test_iou /= (i + 1)
    test_precision /= (i + 1)
    test_recall /= (i + 1)

    print(f'Test Loss: {test_loss:.4f} | Dice: {test_dice:.4f} | IoU: {test_iou:.4f} | Precision: {test_precision:.4f} | Recall: {test_recall:.4f}')

Test Loss: 0.2905 | Dice: 0.8171 | IoU: 0.7080 | Precision: 0.8894 | Recall: 0.7868
