This notebook runs a trained SAM model on every tile of an entire image (rather than just the non-empty tiles) and stitches the per-tile predictions back together into a single full-size prediction image. You will need a saved, trained SAM model checkpoint to load in.

In [1]:
from PIL import Image
import os
import evaluate
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from transformers import SamProcessor, SamModel, SamConfig
import torch
from patchify import patchify
from torchvision import transforms
from torch.optim import AdamW
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


# Loading and Preparing Data

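Load in the image you would like to perform predictions on, together with its binary ground-truth mask (the mask is used later to derive the bounding-box prompt for each tile):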

In [2]:
# Validation inputs. The image at index i is paired with the mask at index i
# when the two lists are tiled together further down.
image_paths_val = [
    '/explore/nobackup/people/sking11/MakingDinoDataset/gliht_1969_rgbi.tif',
]
label_paths_val = [
    '/explore/nobackup/people/sking11/MakingDinoDataset/binarymasks/cleaned_1969_binarymask_500cluster.png',
]

Tile the image and its mask into non-overlapping 256×256 patches with patchify:

In [3]:
def split_image_and_mask(image_path, mask_path, patch_size=256):
    """Load an image/mask pair and cut both into non-overlapping square tiles.

    The image is converted to 3-channel RGB and the mask is divided by 255.0
    before tiling. Diagnostic statistics are printed along the way.

    Args:
        image_path (str): Path of the image file to tile.
        mask_path (str): Path of the mask file to tile.
            NOTE(review): the division by 255.0 assumes an 8-bit mask whose
            foreground value is 255 -- confirm for other mask encodings.
        patch_size (int): Side length, in pixels, of each square tile. The
            step equals the tile size, so tiles do not overlap.

    Returns:
        tuple: ``(image_patches, mask_patches)`` as returned by ``patchify``;
        image tiles are indexed ``[row, col, 0, y, x, channel]`` and mask
        tiles ``[row, col, y, x]``.
    """
    # Context managers release the underlying file handles as soon as the
    # pixel data has been copied into NumPy arrays.
    with Image.open(image_path) as image_file:
        image = np.array(image_file.convert('RGB'))
    with Image.open(mask_path) as mask_file:
        mask = np.array(mask_file)

    # Check mask statistics before normalization
    print("Mask stats before normalization: min={}, max={}".format(mask.min(), mask.max()))

    # Normalize the mask to have values between 0 and 1
    mask = mask / 255.0

    # Check mask statistics after normalization
    print("Mask stats after normalization: min={}, max={}".format(mask.min(), mask.max()))

    # Drop a trailing singleton channel axis so the mask is strictly 2-D
    if len(mask.shape) > 2 and mask.shape[-1] == 1:
        mask = mask[:, :, 0]

    # Split the image and mask into patches on the same grid
    image_patches = patchify(image, (patch_size, patch_size, 3), step=patch_size)
    mask_patches = patchify(mask, (patch_size, patch_size), step=patch_size)

    print(f"Image patches shape: {image_patches.shape}")
    print(f"Mask patches shape: {mask_patches.shape}")

    # Check patch statistics on the top-left tile
    first_image_patch = image_patches[0, 0, 0, :, :, :]
    first_mask_patch = mask_patches[0, 0, :, :]
    print("First image patch stats: min={}, max={}".format(first_image_patch.min(), first_image_patch.max()))
    print("First mask patch stats: min={}, max={}".format(first_mask_patch.min(), first_mask_patch.max()))

    return image_patches, mask_patches

def filter_patches(image_patches, mask_patches):
    """Flatten a 2-D grid of tiles into row-major arrays of patches.

    Despite the name, no tile is discarded: every patch is kept, including
    those whose mask is entirely 0 or entirely 1.

    Args:
        image_patches (numpy.ndarray): Image tiles indexed
            ``[row, col, 0, y, x, channel]``.
        mask_patches (numpy.ndarray): Mask tiles indexed ``[row, col, y, x]``.

    Returns:
        tuple: ``(images, masks)`` NumPy arrays whose first axis walks the
        grid row by row.
    """
    grid_rows, grid_cols = image_patches.shape[0], image_patches.shape[1]
    grid_positions = [(r, c) for r in range(grid_rows) for c in range(grid_cols)]

    # Index 0 on the third axis removes the extra singleton dimension of the
    # image tiles; mask tiles have no such axis.
    flat_images = [image_patches[r, c, 0, :, :, :] for r, c in grid_positions]
    flat_masks = [mask_patches[r, c, :, :] for r, c in grid_positions]

    return np.array(flat_images), np.array(flat_masks)

def process_and_save_images(image_paths, mask_paths, patch_size=256):
    """Tile every image/mask pair and stack all tiles into two arrays.

    Nothing is written to disk here, despite the function's name; the tiles
    are only returned.

    Args:
        image_paths (list of str): Paths of the images to tile.
        mask_paths (list of str): Paths of the masks, paired with
            ``image_paths`` by position.
        patch_size (int): Side length, in pixels, of each square tile.

    Returns:
        tuple: ``(all_images, all_masks)`` NumPy arrays holding the tiles of
        every pair, concatenated along the first axis in input order.

    Raises:
        ValueError: If the two path lists differ in length or are empty.
    """
    image_paths = list(image_paths)
    mask_paths = list(mask_paths)

    # zip() would silently drop the unmatched tail, so reject mismatches.
    if len(image_paths) != len(mask_paths):
        raise ValueError(
            "image_paths and mask_paths must have the same length "
            f"(got {len(image_paths)} and {len(mask_paths)})."
        )
    # np.concatenate fails with an unhelpful message on an empty list.
    if not image_paths:
        raise ValueError("At least one image/mask pair is required.")

    all_images = []
    all_masks = []

    for img_path, mask_path in zip(image_paths, mask_paths):
        image_patches, mask_patches = split_image_and_mask(img_path, mask_path, patch_size)

        # Flatten the tile grid; every tile is kept (no filtering)
        images, masks = filter_patches(image_patches, mask_patches)
        all_images.append(images)
        all_masks.append(masks)

    # Concatenate all image and mask patches
    all_images = np.concatenate(all_images)
    all_masks = np.concatenate(all_masks)

    return all_images, all_masks

from datasets import Dataset

def create_dataset(images, masks):
    """Wrap image and mask tiles in a Hugging Face ``datasets.Dataset``.

    Args:
        images (numpy.ndarray): Image tiles, one per entry along the first axis.
        masks (numpy.ndarray): Mask tiles, one per entry along the first axis.

    Returns:
        datasets.Dataset: Dataset with an ``"image"`` and a ``"label"`` column,
        each holding one Pillow image per tile.

    Raises:
        ValueError: If either argument is not a NumPy array.
    """
    # Ensure images and masks are NumPy arrays
    if not isinstance(images, np.ndarray) or not isinstance(masks, np.ndarray):
        raise ValueError("Images and masks must be NumPy arrays.")

    # Imported here under an unambiguous alias: this notebook also imports
    # torch.utils.data.Dataset, so the bare name `Dataset` refers to whichever
    # import cell was executed last.
    from datasets import Dataset as HFDataset

    # Convert the NumPy arrays to Pillow images and store them in a dictionary
    dataset_dict = {
        "image": [Image.fromarray(img) for img in images],
        "label": [Image.fromarray(mask) for mask in masks],
    }

    # Create the dataset using the datasets.Dataset class
    return HFDataset.from_dict(dataset_dict)

In [4]:
# Tile the validation image/mask pair with the default patch size, wrap the
# tiles in a dataset, and build a loader that keeps them in their original order.
val_images, val_masks = process_and_save_images(
    image_paths=image_paths_val,
    mask_paths=label_paths_val,
)
val_dataset = create_dataset(images=val_images, masks=val_masks)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    drop_last=False,
)



Mask stats before normalization: min=0, max=255
Mask stats after normalization: min=0.0, max=1.0
Image patches shape: (37, 49, 1, 256, 256, 3)
Mask patches shape: (37, 49, 256, 256)
First image patch stats: min=0, max=255
First mask patch stats: min=0.0, max=0.0


In [5]:
# Number of entries in the dataset: one per tile of the tiled image.
print(len(val_dataset))

1813


In [6]:
def get_bounding_box(ground_truth_map, margin=10):
    """Return a padded bounding box around the positive pixels of a 2-D mask.

    Args:
        ground_truth_map (numpy.ndarray): 2-D mask; every value greater than
            0 counts as foreground.
        margin (int): Number of pixels added on each side of the tight box.
            The padded box is clamped to the mask extent.

    Returns:
        list of int: ``[x_min, y_min, x_max, y_max]`` as plain Python ints.
        When the mask has no foreground pixel the placeholder box
        ``[0, 0, 1, 1]`` is returned.
    """
    y_indices, x_indices = np.where(ground_truth_map > 0)

    # Both index arrays always have the same length, so one check suffices.
    if len(x_indices) == 0:
        # Empty mask: return a small default box rather than failing.
        return [0, 0, 1, 1]

    height, width = ground_truth_map.shape

    # Cast to int so the result never mixes NumPy scalars with Python ints.
    x_min = max(0, int(x_indices.min()) - margin)
    x_max = min(width, int(x_indices.max()) + margin)
    y_min = max(0, int(y_indices.min()) - margin)
    y_max = min(height, int(y_indices.max()) + margin)

    return [x_min, y_min, x_max, y_max]

In [7]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model configuration and the matching input processor
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Create an instance of the model architecture with the loaded configuration
my_mito_model = SamModel(config=model_config)

# Update the model by loading the weights from the saved file.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine;
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict needs.
my_mito_model.load_state_dict(
    torch.load(
        '/explore/nobackup/people/sking11/sam_model_checkpoint_6400.pth',
        map_location=device,
        weights_only=True,
    )
)

  my_mito_model.load_state_dict(torch.load('/explore/nobackup/people/sking11/sam_model_checkpoint_6400.pth'))


<All keys matched successfully>

In [8]:
# Define the function to generate predicted masks
def generate_predicted_masks(val_dataset, model, processor, device, save_dir=None):
    """Predict a binary mask for every tile and save each one as a PNG.

    For each dataset entry the box prompt is derived from that entry's
    ground-truth mask, the model is run once, and the sigmoid of the predicted
    mask is thresholded at 0.5.

    Args:
        val_dataset: Indexable dataset whose entries expose an ``"image"`` and
            a ``"label"`` item.
        model: Segmentation model; it is moved to ``device`` and put in eval mode.
        processor: Callable that turns an image plus ``input_boxes`` into
            model inputs.
        device (str): Device on which inference runs.
        save_dir (str, optional): Directory that receives
            ``predicted_mask_<idx>.png`` for each tile. When omitted, the
            notebook-level ``output_dir`` variable is used, as before.

    Returns:
        list of numpy.ndarray: One uint8 mask (values 0/1) per dataset entry,
        in dataset order.
    """
    # Fall back to the notebook-level variable so existing calls keep working.
    if save_dir is None:
        save_dir = output_dir
    os.makedirs(save_dir, exist_ok=True)

    predicted_masks = []

    model.to(device)
    model.eval()
    for idx in range(len(val_dataset)):
        # Fetch the entry once; each dataset lookup loads both image and label.
        sample = val_dataset[idx]
        test_image = sample["image"]

        # Get box prompt based on ground truth segmentation map
        ground_truth_mask = np.array(sample["label"])
        prompt = get_bounding_box(ground_truth_mask)

        # Prepare image + box prompt for the model
        inputs = processor(test_image, input_boxes=[[prompt]], return_tensors="pt")

        # Move every input tensor to the inference device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Forward pass without gradient tracking
        with torch.no_grad():
            outputs = model(**inputs, multimask_output=False)

        # Apply sigmoid to get the probability map
        medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))

        # Convert soft mask to hard mask and move to CPU
        medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
        medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)

        # Save the predicted mask as PNG, scaled from {0, 1} to {0, 255}
        mask_image = Image.fromarray(medsam_seg * 255)
        mask_image.save(os.path.join(save_dir, f"predicted_mask_{idx}.png"))

        # Keep the 0/1 mask for later stitching
        predicted_masks.append(medsam_seg)

    return predicted_masks

# Run inference over every tile; the per-tile PNGs are written to this folder.
output_dir = "/explore/nobackup/people/sking11/SAM_output_masks1"
os.makedirs(output_dir, exist_ok=True)

val_predicted_masks = generate_predicted_masks(
    val_dataset=val_dataset,
    model=my_mito_model,
    processor=processor,
    device=device,
)

In [9]:
# Number of predicted masks: generate_predicted_masks appends one per dataset entry.
print(len(val_predicted_masks))

1813


In [10]:
import numpy as np
from PIL import Image

def unpatchify_binary(patches_list, patch_size, num_patches_x, num_patches_y):
    """
    Reconstruct the full binary mask from a row-major list of patches.

    Patches are laid out on a non-overlapping grid, so each one is written
    straight into its slot of a uint8 canvas; no accumulation or averaging
    buffers are needed.

    Args:
        patches_list (list of numpy.ndarray): List of patch arrays with binary values.
        patch_size (tuple): Size of each patch (height, width).
        num_patches_x (int): Number of patches along the height.
        num_patches_y (int): Number of patches along the width.

    Returns:
        numpy.ndarray: The reconstructed full-size binary mask (uint8, values
        0 or 1). Grid slots for which no patch was supplied stay 0.
    """
    patch_height, patch_width = patch_size
    full_mask = np.zeros(
        (num_patches_x * patch_height, num_patches_y * patch_width),
        dtype=np.uint8,
    )

    expected_patches = num_patches_x * num_patches_y
    available_patches = min(len(patches_list), expected_patches)
    if available_patches < expected_patches:
        # Reported once, instead of once per remaining grid row.
        print("Warning: Expected more patches but reached the end of the list.")

    for index in range(available_patches):
        # Row-major layout: the list index maps to (grid row, grid column).
        row, col = divmod(index, num_patches_y)
        top = row * patch_height
        left = col * patch_width

        # Clamp to [0, 1] so the result is binary even if a patch uses 0/255.
        full_mask[top:top + patch_height, left:left + patch_width] = np.clip(
            patches_list[index], 0, 1
        )

    return full_mask

# Stitch the per-tile predictions back into one full-size mask.
patch_size = (256, 256)  # Height, width of each patch

# Derive the tile grid from the source image instead of hard-coding it.
# This assumes a single source image tiled with the same 256-pixel patches.
with Image.open(image_paths_val[0]) as source_image:
    source_width, source_height = source_image.size
num_patches_x = source_height // patch_size[0]  # Number of rows
num_patches_y = source_width // patch_size[1]   # Number of columns

# Guard against a grid that does not match the number of predictions.
assert len(val_predicted_masks) == num_patches_x * num_patches_y, (
    f"Expected {num_patches_x * num_patches_y} predicted tiles, "
    f"got {len(val_predicted_masks)}"
)

# Reconstruct the full binary mask from the list
full_mask = unpatchify_binary(val_predicted_masks, patch_size, num_patches_x, num_patches_y)

# Scale {0, 1} to {0, 255} and save as an 8-bit grayscale ('L' mode) image
Image.fromarray(full_mask * 255).convert('L').save('/explore/nobackup/people/sking11/fullmaskSAM11212.png')

In [12]:
import numpy as np
from PIL import Image

def unpatchify_rgb(patches_list, patch_size, num_patches_x, num_patches_y):
    """
    Reconstruct the full RGB image from a row-major list of patches.

    Patches are laid out on a non-overlapping grid, so each one is written
    straight into its slot of a uint8 canvas; no accumulation or averaging
    buffers are needed.

    Args:
        patches_list (list of numpy.ndarray): List of patch arrays with RGB values.
        patch_size (tuple): Size of each patch (height, width).
        num_patches_x (int): Number of patches along the height.
        num_patches_y (int): Number of patches along the width.

    Returns:
        numpy.ndarray: The reconstructed full-size RGB image (uint8). Grid
        slots for which no patch was supplied stay black.
    """
    patch_height, patch_width = patch_size
    full_image = np.zeros(
        (num_patches_x * patch_height, num_patches_y * patch_width, 3),
        dtype=np.uint8,
    )

    expected_patches = num_patches_x * num_patches_y
    available_patches = min(len(patches_list), expected_patches)
    if available_patches < expected_patches:
        # Reported once, instead of once per remaining grid row.
        print("Warning: Expected more patches but reached the end of the list.")

    for index in range(available_patches):
        # Row-major layout: the list index maps to (grid row, grid column).
        row, col = divmod(index, num_patches_y)
        top = row * patch_height
        left = col * patch_width

        full_image[top:top + patch_height, left:left + patch_width, :] = patches_list[index]

    return full_image

# Stitch the input tiles back into one full-size RGB image for comparison.
patch_size = (256, 256)  # Height, width of each patch

# Derive the tile grid from the source image instead of hard-coding it.
# This assumes a single source image tiled with the same 256-pixel patches.
with Image.open(image_paths_val[0]) as source_image:
    source_width, source_height = source_image.size
num_patches_x = source_height // patch_size[0]  # Number of rows
num_patches_y = source_width // patch_size[1]   # Number of columns

# Guard against a grid that does not match the number of tiles.
assert len(val_images) == num_patches_x * num_patches_y, (
    f"Expected {num_patches_x * num_patches_y} image tiles, got {len(val_images)}"
)

# Reconstruct the full RGB image from the list
full_image = unpatchify_rgb(val_images, patch_size, num_patches_x, num_patches_y)

# Save the stitched RGB image
Image.fromarray(full_image).save('/explore/nobackup/people/sking11/full_imageSAM_val1212.png')