In [1]:
import torch
from transformers import SamModel, SamProcessor

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the processor and model published under "wanglab/medsam-vit-base" and
# move the model to the selected device.
# NOTE: `processor` and `model` are both rebound later in this notebook to the
# "facebook/sam-vit-base" checkpoint.
processor = SamProcessor.from_pretrained("wanglab/medsam-vit-base")
model = SamModel.from_pretrained("wanglab/medsam-vit-base").to(device)

In [81]:
from datasets import load_dataset

# Load the "train" split of the "nielsr/breast-cancer" dataset.
dataset = load_dataset("nielsr/breast-cancer", split="train")

In [85]:
# Show the dataset summary (feature names and number of rows).
print(dataset)

Dataset({
    features: ['image', 'label'],
    num_rows: 130
})


In [55]:
# Show the feature schema of the dataset.
print(dataset.features)

{'image': Image(mode=None, decode=True, id=None), 'label': Image(mode=None, decode=True, id=None)}


# Import datasets

In [149]:
from torch.utils.data import Dataset
from PIL import Image
import os

class CustomSegmentationDataset(Dataset):
    """Lazily loads (image, mask) pairs from an image and a mask directory.

    Every file in ``img_dir`` whose name ends with ``.png`` is paired with the
    file in ``mask_dir`` that has the same stem and a ``.bmp`` extension.
    Files are opened on every ``__getitem__`` call; nothing is cached.

    Args:
        img_dir: Directory containing the PNG images.
        mask_dir: Directory containing the BMP masks.
        transform: Optional callable applied to the RGB image.
        mask_transform: Optional callable applied to the single-channel mask.
            Defaults to ``transform`` (the previous behaviour). Pass a separate
            pipeline, e.g. one that resizes with nearest-neighbour
            interpolation, to avoid blending label values.
    """

    def __init__(self, img_dir, mask_dir, transform=None, mask_transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = transform if mask_transform is None else mask_transform

        # os.listdir() returns entries in arbitrary order; sorting makes an
        # index refer to the same file on every run and every machine.
        self.img_names = sorted(f for f in os.listdir(img_dir) if f.endswith('.png'))

    def __len__(self):
        """Return the number of PNG images found in ``img_dir``."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Load and return the ``(image, mask)`` pair at position ``idx``.

        The image is converted to RGB and the mask to single-channel ("L")
        before the optional transforms are applied.
        """
        img_name = self.img_names[idx]

        # Swap only the trailing extension; str.replace would also rewrite a
        # ".png" that happens to occur in the middle of the file name.
        stem, _ = os.path.splitext(img_name)
        img_path = os.path.join(self.img_dir, img_name)
        mask_path = os.path.join(self.mask_dir, stem + '.bmp')

        # The context managers close the underlying files as soon as the
        # converted copies have been made.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask




In [None]:
from torch.utils.data import DataLoader
from torchvision import transforms

# Placeholder locations -- replace with the real image and mask directories.
img_dir = "images folder"
mask_dir = "masks folder"

# Resize every sample to 512x512 and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((512, 512)),  
    transforms.ToTensor(),          
])

# NOTE: this rebinds `dataset`, replacing the dataset loaded earlier in the notebook.
dataset = CustomSegmentationDataset(img_dir=img_dir, mask_dir=mask_dir, transform=transform)

dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)

# Sanity check: iterate over the whole loader and print the mask shape of every batch.
for images, masks in dataloader:
    print(masks.shape)
    pass

In [122]:
import matplotlib.pyplot as plt
import random

In [None]:
# numpy is used below but is otherwise first imported in a later cell, so
# import it here to keep this cell runnable on a fresh kernel.
import numpy as np

# Draw one random (image, mask) pair from the dataset.
random_idx = random.randint(0, len(dataset) - 1)
image, mask = dataset[random_idx]

# Move the channel axis last for matplotlib and drop the mask's singleton axes.
image_np = image.permute(1, 2, 0).numpy()
mask_np = mask.squeeze().numpy()

if mask_np.dtype != np.uint8:
    mask_np = (mask_np * 255).astype(np.uint8)

# Rescale to [0, 1] for display. An all-zero mask is left untouched, because
# dividing by its maximum would be 0 / 0 and fill the array with NaN.
mask_peak = mask_np.max()
if mask_peak > 0:
    mask_np = mask_np / mask_peak

plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.imshow(image_np)
plt.title('Image')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(mask_np, cmap='gray')
plt.title('Mask')
plt.axis('off')
print(mask_np.shape)

plt.show()

# Alternative approach: load all images and masks into memory up front

In [71]:
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class SegmentationDataset(Dataset):
    """Eagerly loads every image/mask pair into memory and drops empty masks.

    Each entry of ``images_dir`` is paired with a mask in ``masks_dir``: a
    ``.png`` extension is swapped for ``.bmp``, any other name is looked up
    unchanged. Pairs that fail to load are reported and skipped. After
    loading, samples whose mask is entirely zero are removed.

    Attributes:
        image_names: Sorted listing of ``images_dir`` (not filtered).
        sample_names: File names aligned index-for-index with ``images`` and
            ``masks`` (i.e. after load failures and empty masks are removed).
        images: Array of RGB images stacked along the first axis.
        masks: Array of single-channel masks stacked along the first axis.

    Args:
        images_dir: Directory containing the images.
        masks_dir: Directory containing the masks.
    """

    def __init__(self, images_dir, masks_dir):
        self.images_dir = images_dir
        self.masks_dir = masks_dir

        # os.listdir() order is arbitrary; sort so sample indices are reproducible.
        self.image_names = sorted(os.listdir(images_dir))
        self.sample_names = []
        loaded_images = []
        loaded_masks = []

        for image_name in self.image_names:
            img_path = os.path.join(self.images_dir, image_name)
            mask_path = os.path.join(self.masks_dir, self._mask_name_for(image_name))

            try:
                # Context managers release the file handles right after decoding.
                with Image.open(img_path) as img_file:
                    image_np = np.array(img_file.convert("RGB"))
                with Image.open(mask_path) as mask_file:
                    mask_np = np.array(mask_file.convert("L"))
            except Exception as e:
                # Best effort: report the failure and continue with the next file.
                print(f"Error loading image or mask: {e}")
                continue

            loaded_images.append(image_np)
            loaded_masks.append(mask_np)
            self.sample_names.append(image_name)

        self.images = np.array(loaded_images)
        self.masks = np.array(loaded_masks)

        print(f"Images shape before filtering: {self.images.shape}")
        print(f"Masks shape before filtering: {self.masks.shape}")

        self._filter_empty_masks()

    @staticmethod
    def _mask_name_for(image_name):
        """Return the mask file name for ``image_name`` (trailing ``.png`` -> ``.bmp``)."""
        stem, ext = os.path.splitext(image_name)
        return stem + ".bmp" if ext == ".png" else image_name

    def _reset_to_empty(self):
        """Replace all sample storage with empty containers."""
        self.images = np.array([])
        self.masks = np.array([])
        self.sample_names = []

    def _filter_empty_masks(self):
        """Keep only the samples whose mask contains at least one non-zero pixel."""
        if self.masks.size == 0:
            print("No masks loaded. Check data loading process.")
            self._reset_to_empty()
            return
        valid_indices = [i for i, mask in enumerate(self.masks) if mask.max() != 0]

        if not valid_indices:
            print("No non-empty masks found.")
            self._reset_to_empty()
            return

        self.images = self.images[valid_indices]
        self.masks = self.masks[valid_indices]
        # Keep the names in step with the arrays so a sample can be traced to its file.
        self.sample_names = [self.sample_names[i] for i in valid_indices]

        print(f"Images shape after filtering: {self.images.shape}")
        print(f"Masks shape after filtering: {self.masks.shape}")

    def __len__(self):
        """Return the number of samples that survived loading and filtering."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` arrays at position ``idx``.

        Raises:
            IndexError: If the dataset holds no samples.
        """
        if len(self.images) == 0 or len(self.masks) == 0:
            raise IndexError("Dataset is empty. Ensure filtering did not remove all data.")
        image = self.images[idx]
        mask = self.masks[idx]
        return image, mask

# Placeholder locations -- replace with the real image and label directories.
images_dir = "images path"
masks_dir = "labels path"
# The constructor loads everything eagerly and prints the array shapes
# before and after empty masks are filtered out.
dataset = SegmentationDataset(images_dir=images_dir, masks_dir=masks_dir)


Images shape before filtering: (763, 512, 512, 3)
Masks shape before filtering: (763, 512, 512)
Images shape after filtering: (654, 512, 512, 3)
Masks shape after filtering: (654, 512, 512)


## Resizing the filtered images and building a Hugging Face dataset

In [73]:
from datasets import Dataset
from PIL import Image

# Column-oriented storage for the Hugging Face dataset built below.
dataset_dict = {
    "image": [],
    "label": [],
}
target_size = (256, 256)

for i in range(len(dataset)):
    image_np, mask_np = dataset[i]

    image_pil = Image.fromarray(image_np)
    mask_pil = Image.fromarray(mask_np)

    # LANCZOS gives a smooth downscale for the image; NEAREST keeps the mask's
    # label values intact instead of blending them.
    image_resized = image_pil.resize(target_size, Image.Resampling.LANCZOS)
    # Fixed: this was assigned to the misspelled `mask_resizeed`, while the
    # append below read `mask_resized`.
    mask_resized = mask_pil.resize(target_size, Image.Resampling.NEAREST)

    dataset_dict["image"].append(image_resized)
    dataset_dict["label"].append(mask_resized)

# `Dataset` is `datasets.Dataset` here (imported at the top of this cell), not
# the torch class of the same name. Rebinding `dataset` means this cell cannot
# be re-run without first re-running the cell that builds the array dataset.
dataset = Dataset.from_dict(dataset_dict)

In [None]:
# Display the dataset built in the previous cell.
dataset

In [None]:
import random
import matplotlib.pyplot as plt
# Draw the index from the dataset itself. The previous bound,
# `image_np.shape[0]`, was the height of the last image left over from the
# resize loop, not the number of samples.
img_num = random.randint(0, len(dataset) - 1)
example = dataset[img_num]
example_image = example["image"]
example_mask = example["label"]

fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# The colormap only takes effect for single-channel data.
axes[0].imshow(np.array(example_image), cmap='gray')
axes[0].set_title("Image")

axes[1].imshow(example_mask, cmap='gray')
axes[1].set_title("Mask")

# Hide ticks and tick labels on both panels.
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

plt.show()

In [80]:
def get_bounding_box(ground_truth_map, max_jitter=20):
  """Return a randomly enlarged bounding box around a mask's foreground.

  The tight box around all pixels greater than zero is computed first. Each
  side is then pushed outwards by an independent random amount drawn from
  ``[0, max_jitter)`` and clamped to the mask extent.

  Args:
    ground_truth_map: 2-D array; pixels ``> 0`` are treated as foreground.
    max_jitter: Exclusive upper bound of the random enlargement per side, in
      pixels. Values ``<= 0`` disable the enlargement.

  Returns:
    ``[x_min, y_min, x_max, y_max]`` as plain Python ints.

  Raises:
    ValueError: If the mask contains no foreground pixels.
  """
  y_indices, x_indices = np.where(ground_truth_map > 0)
  if y_indices.size == 0:
    # np.min/np.max on an empty array raise an opaque reduction error; fail
    # with a message that names the actual problem instead.
    raise ValueError("ground_truth_map has no foreground pixels; cannot compute a bounding box")

  x_min, x_max = np.min(x_indices), np.max(x_indices)
  y_min, y_max = np.min(y_indices), np.max(y_indices)
  H, W = ground_truth_map.shape

  def jitter():
    return np.random.randint(0, max_jitter) if max_jitter > 0 else 0

  # Draw order (x_min, x_max, y_min, y_max) is kept so that seeded runs still
  # produce the same boxes as before.
  x_min = max(0, x_min - jitter())
  x_max = min(W, x_max + jitter())
  y_min = max(0, y_min - jitter())
  y_max = min(H, y_max + jitter())

  # Cast to built-in ints so the box does not mix Python and numpy scalars.
  return [int(x_min), int(y_min), int(x_max), int(y_max)]

In [82]:
from torch.utils.data import Dataset

class SAMDataset(Dataset):
  """
  Map-style dataset that turns each (image, label) record into model inputs.

  It wraps an indexable collection of records with "image" and "label" keys
  together with a processor. For every record a box prompt is derived from
  the label mask and handed to the processor along with the image.
  """
  def __init__(self, dataset, processor):
    self.processor = processor
    self.dataset = dataset

  def __len__(self):
    """Number of records in the wrapped collection."""
    return len(self.dataset)

  def __getitem__(self, idx):
    """Return the processor outputs for record ``idx`` plus its ground-truth mask."""
    sample = self.dataset[idx]
    mask = np.array(sample["label"])

    # Box prompt derived from the foreground of the mask.
    box = get_bounding_box(mask)
    encoded = self.processor(sample["image"], input_boxes=[[box]], return_tensors="pt")

    # Drop the leading batch axis the processor adds.
    features = {}
    for key, tensor in encoded.items():
      features[key] = tensor.squeeze(0)

    # NOTE(review): the mask is stored with its original pixel values -- confirm
    # they are 0/1 if the downstream loss expects binary targets.
    features["ground_truth_mask"] = mask
    return features

In [84]:
from transformers import SamProcessor
# Rebind `processor` to the one published under "facebook/sam-vit-base".
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

In [85]:
# Wrap the dataset so each item carries the processor outputs, a box prompt
# and the ground-truth mask.
train_dataset = SAMDataset(dataset=dataset, processor=processor)

In [None]:
# Inspect the first training sample: print each key with the shape of its value.
example = train_dataset[0]
for k,v in example.items():
  print(k,v.shape)

In [96]:
from torch.utils.data import DataLoader
# Shuffled batches of 8; drop_last=False keeps the final, possibly smaller, batch.
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=False)

In [None]:
# Pull one batch from the loader and print each key with the shape of its value.
batch = next(iter(train_dataloader))
for k,v in batch.items():
  print(k,v.shape)

In [None]:
# Shape of the ground-truth masks held in `batch`.
batch["ground_truth_mask"].shape

In [None]:
from transformers import SamModel
# Rebind `model` to the pretrained "facebook/sam-vit-base" weights.
model = SamModel.from_pretrained("facebook/sam-vit-base")

# Freeze every parameter of the vision and prompt encoders, reporting the
# trainable status of each parameter right after it has been handled.
for name, param in model.named_parameters():
    if name.startswith(("vision_encoder", "prompt_encoder")):
        param.requires_grad_(False)
    print(f"Layer: {name}, Requires Grad: {param.requires_grad}")

In [104]:
from torch.optim import Adam
import monai
# Only the mask decoder's parameters are handed to the optimizer.
optimizer = Adam(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)
# Combined Dice + cross-entropy loss, averaged over the batch. With
# sigmoid=True the loss applies the sigmoid itself, so it is fed raw logits.
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

In [None]:
import torch
# Report whether PyTorch can see a CUDA device and, if so, its CUDA version
# and the name of device 0.
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Device name: {torch.cuda.get_device_name(0)}")

In [None]:
from tqdm import tqdm
from statistics import mean
import torch
from torch.nn.functional import threshold, normalize

# Fine-tuning loop: one optimizer step per batch, mean loss reported per epoch.
num_epochs = 1

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(device)

model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
        # Forward pass using the box prompts, requesting a single mask per prompt.
        outputs = model(pixel_values=batch["pixel_values"].to(device),
                        input_boxes=batch["input_boxes"].to(device),
                        multimask_output=False)

        # Drop axis 1 of the predictions and add an axis 1 to the targets so
        # both tensors line up for the loss.
        # NOTE(review): the targets are passed with their stored pixel values;
        # confirm the masks are 0/1 rather than 0/255 before trusting the loss.
        predicted_masks = outputs.pred_masks.squeeze(1)
        ground_truth_masks = batch["ground_truth_mask"].float().to(device)
        loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

        # Clear stale gradients, backpropagate, then update the decoder weights.
        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
        epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')

In [112]:
 # Persist the fine-tuned weights as a state dict.
 # NOTE(review): presumably the "weight/" directory must already exist -- confirm.
 # NOTE(review): the reload cell further down reads
 # "weight/mito_model_checkpoint.pth", a different file name from the one
 # written here -- confirm which checkpoint is intended.
 torch.save(model.state_dict(), "weight/checkpoint.pth")

In [114]:
from transformers import SamModel, SamConfig, SamProcessor
import numpy as np
import random
import torch
import matplotlib.pyplot as plt

In [116]:
# Rebuild the architecture from the published config, then restore the
# fine-tuned weights from disk.
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
my_mito_model = SamModel(config=model_config)
# map_location="cpu" lets a checkpoint saved from CUDA tensors load on a
# machine without a GPU; the model is moved to the target device afterwards.
# NOTE(review): this file name differs from the "weight/checkpoint.pth" written
# by the save cell above -- confirm which checkpoint is intended.
my_mito_model.load_state_dict(
    torch.load("weight/mito_model_checkpoint.pth", map_location="cpu")
)
     

<All keys matched successfully>

In [None]:
# Select the GPU when available and move the reloaded model onto it.
device = "cuda" if torch.cuda.is_available() else "cpu"
my_mito_model.to(device)

In [None]:

# Draw the index from the dataset itself. The previous bound,
# `image_np.shape[0]`, was the height of an image left over from an earlier
# cell, not the number of samples.
idx = random.randint(0, len(dataset) - 1)
print(idx)

sample = dataset[idx]
test_image = sample["image"]
ground_truth_mask = np.array(sample["label"])

# Box prompt derived from the ground-truth mask.
prompt = get_bounding_box(ground_truth_mask)

inputs = processor(test_image, input_boxes=[[prompt]], return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

my_mito_model.eval()

# Inference only: no gradients are needed.
with torch.no_grad():
    outputs = my_mito_model(**inputs, multimask_output=False)

# Logits -> probabilities -> binary mask (threshold 0.5).
medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(np.array(test_image), cmap='gray')
axes[0].set_title("Image")

axes[1].imshow(medsam_seg, cmap='gray')
axes[1].set_title("Mask")

axes[2].imshow(medsam_seg_prob)
axes[2].set_title("Probability Map")

# Hide ticks and tick labels on every panel.
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
output_path = "Sam_result.png"
plt.savefig(output_path, bbox_inches='tight')
plt.show()

## Randomly picking prompts

In [None]:
import numpy as np
import random
import torch
import matplotlib.pyplot as plt

# Pick a random sample from the dataset and fetch its image and mask.
idx = random.randint(0, len(dataset) - 1)
print(f"Selected index: {idx}")

test_image = dataset[idx]["image"]
ground_truth_mask = np.array(dataset[idx]["label"])

def get_normalized_points(mask, num_points=50):
    """Sample foreground and background pixels as normalised (x, y) points.

    Up to ``num_points`` pixels are drawn without replacement from the
    foreground (``mask > 0``) and from the background (``mask == 0``). Each
    pixel is returned as ``[column / width, row / height]``, i.e. x first.

    Foreground is ``> 0`` rather than ``== 1`` so that 0/255 masks work too;
    this matches the foreground test used by ``get_bounding_box`` elsewhere in
    this notebook.

    Args:
        mask: 2-D array-like mask.
        num_points: Maximum number of points to draw per class.

    Returns:
        ``(pos_points, neg_points)``, two lists of ``[x, y]`` float pairs. A
        class with no pixels yields the single placeholder ``[[0.0, 0.0]]``.
    """
    mask = np.asarray(mask)
    height, width = mask.shape

    def sample(coords):
        # `coords` holds (row, col) pairs as returned by np.argwhere.
        if len(coords) == 0:
            return [[0.0, 0.0]]
        chosen = np.random.choice(len(coords), min(num_points, len(coords)), replace=False)
        return [[float(col) / width, float(row) / height] for row, col in coords[chosen]]

    # Positives are sampled before negatives, preserving the original draw order.
    pos_points = sample(np.argwhere(mask > 0))
    neg_points = sample(np.argwhere(mask == 0))

    return pos_points, neg_points

pos_points, neg_points = get_normalized_points(ground_truth_mask)

# Label 1 marks a foreground point, 0 a background point.
labels = [1] * len(pos_points) + [0] * len(neg_points)

# The sampled points are normalised to [0, 1], but the processor expects pixel
# coordinates of the original image (it rescales them itself). Convert back
# before building the prompt.
# Assumes the image and the mask have the same height and width -- TODO confirm.
height, width = ground_truth_mask.shape
points = [[x * width, y * height] for x, y in pos_points + neg_points]

inputs = processor(test_image, input_points=[points], input_labels=[labels], return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

my_mito_model.eval()

# Inference only: no gradients are needed.
with torch.no_grad():
    outputs = my_mito_model(**inputs, multimask_output=False)

# Logits -> probabilities -> binary mask (threshold 0.5).
medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)

fig, axes = plt.subplots(1, 4, figsize=(20, 5))

axes[0].imshow(np.array(test_image))
axes[0].set_title("Image")

axes[1].imshow(ground_truth_mask, cmap='gray')
axes[1].set_title("Ground Truth Mask")

axes[2].imshow(medsam_seg, cmap='gray')
axes[2].set_title("Predicted Mask")

axes[3].imshow(medsam_seg_prob, cmap='gray')
axes[3].set_title("Probability Map")

# Hide ticks and tick labels on every panel.
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

# Overlay the prompts on the input image: red = positive, blue = negative.
for point in pos_points:
    axes[0].scatter(point[0] * width, point[1] * height, c='red', s=50)
for point in neg_points:
    axes[0].scatter(point[0] * width, point[1] * height, c='blue', s=50)
# NOTE: this is the same file name the box-prompt cell saves to, so that
# figure is overwritten.
output_path = "Sam_result.png"
plt.savefig(output_path, bbox_inches='tight')
plt.show()