This notebook is for running predictions on an entire image (rather than just non-empty tiles) and stitching them together into a single image of predictions using the DINOv2 Model. You will need a saved, trained DINOv2 model to load in. 

In [20]:
#packages

from datasets import load_dataset
from PIL import Image
import os
import evaluate
import numpy as np
import matplotlib.pyplot as plt
from albumentations.pytorch import ToTensorV2
import albumentations as A
from torch.utils.data import Dataset, DataLoader
import torch
from patchify import patchify
from transformers import Dinov2Model, Dinov2PreTrainedModel
from transformers.modeling_outputs import SemanticSegmenterOutput
from tqdm import tqdm

# Loading and Preparing Data

Load in the image you would like to perform predictions on:

In [12]:
#dataset
image_paths_val = ['/explore/nobackup/people/sking11/MakingDinoDataset/gliht_1969_rgbi.tif']
label_paths_val = ['/explore/nobackup/people/sking11/MakingDinoDataset/binarymasks/cleaned_1969_binarymask_500cluster.png']

Tile the image with patchify

In [18]:
def split_image_and_mask(image_path, mask_path, patch_size=256):
    # Load the image and mask
    image = np.array(Image.open(image_path).convert('RGB'))
    mask = np.array(Image.open(mask_path))

    # Ensure masks are 2D for simplicity
    if len(mask.shape) > 2 and mask.shape[-1] == 1:
        mask = mask[:, :, 0]

    # Split the image and mask into patches
    image_patches = patchify(image, (patch_size, patch_size, 3), step=patch_size)
    mask_patches = patchify(mask, (patch_size, patch_size), step=patch_size)
    
    print(f"Image patches shape: {image_patches.shape}")
    print(f"Mask patches shape: {mask_patches.shape}")
    
    return image_patches, mask_patches

def filter_patches(image_patches, mask_patches):
    all_img_patches = []
    all_mask_patches = []

    num_patches_x, num_patches_y = image_patches.shape[0], image_patches.shape[1]
    for i in range(num_patches_x):
        for j in range(num_patches_y):
            # Remove the extra dimension
            single_patch_img = image_patches[i, j, 0, :, :, :]  # (patch_size, patch_size, 3)
            single_patch_mask = mask_patches[i, j, :, :]         # (patch_size, patch_size)

            # Directly append patches without filtering
            all_img_patches.append(single_patch_img)
            all_mask_patches.append(single_patch_mask)
    
    return np.array(all_img_patches), np.array(all_mask_patches)

def process_and_save_images(image_paths, mask_paths, patch_size=256):
    all_images = []
    all_masks = []

    for img_path, mask_path in zip(image_paths, mask_paths):
        image_patches, mask_patches = split_image_and_mask(img_path, mask_path, patch_size)
        
        # Process patches without filtering
        images, masks = filter_patches(image_patches, mask_patches)
        all_images.append(images)
        all_masks.append(masks)

    # Concatenate all image and mask patches
    all_images = np.concatenate(all_images)
    all_masks = np.concatenate(all_masks)

    return all_images, all_masks

class CustomSegmentationDataset(Dataset):
    def __init__(self, images, masks, transform=None):
        self.images = images
        self.masks = masks
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        original_image = self.images[idx]
        original_segmentation_map = self.masks[idx]

        # Ensure image and mask sizes match
        assert original_image.shape[:2] == original_segmentation_map.shape[:2], \
            f"Image and mask shape mismatch: {original_image.shape[:2]} vs {original_segmentation_map.shape[:2]}"

        # Convert masks to binary values: 0 and 1
        original_segmentation_map = (original_segmentation_map / 255).astype(np.uint8)

        if self.transform:
            transformed = self.transform(image=original_image, mask=original_segmentation_map)
            image = transformed['image']
            target = transformed['mask']
        else:
            image = original_image
            target = original_segmentation_map

        # Convert to tensor
        image = torch.tensor(image).float()
        target = torch.tensor(target).long()

        # Convert to C, H, W if necessary
        if image.ndimension() == 3 and image.shape[-1] in {1, 3}:
            image = image.permute(2, 0, 1)  # Convert to C, H, W

        return image, target, original_image, original_segmentation_map

ADE_MEAN = np.array([123.675, 116.280, 103.530]) / 255
ADE_STD = np.array([58.395, 57.120, 57.375]) / 255

val_transform = A.Compose([
    A.Resize(width=448, height=448),
    A.Normalize(mean=ADE_MEAN.tolist(), std=ADE_STD.tolist()),
])

def collate_fn(inputs):
    batch = dict()
    pixel_values = torch.stack([torch.tensor(i[0]).float() for i in inputs], dim=0)
    labels = torch.stack([torch.tensor(i[1]).long() for i in inputs], dim=0)
    original_images = [torch.tensor(i[2]).float() for i in inputs]
    original_segmentation_maps = [torch.tensor(i[3]).long() for i in inputs]

    # Uncomment this if your pixel_values are in (batch_size, height, width, channels) format
    # pixel_values = pixel_values.permute(0, 3, 1, 2)

    batch["pixel_values"] = pixel_values
    batch["labels"] = labels
    batch["original_images"] = original_images
    batch["original_segmentation_maps"] = original_segmentation_maps

    return batch

In [19]:
val_images, val_masks = process_and_save_images(image_paths_val, label_paths_val)
val_dataset = CustomSegmentationDataset(val_images, val_masks, transform=val_transform)
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

Image patches shape: (37, 49, 1, 256, 256, 3)
Mask patches shape: (37, 49, 256, 256)


# Setting the Model 

In [5]:
class Dinov2ForSemanticSegmentation(Dinov2PreTrainedModel):
    def __init__(self, config, class_weights=None):
        """
        Initialize the DINOv2 model for semantic segmentation.

        Args:
        - config: Configuration object containing model parameters.
        - class_weights: Tensor containing class weights for the loss function.
        """
        super().__init__(config)
        self.dinov2 = Dinov2Model(config)  # Load the base DINOv2 model
        self.classifier = LinearClassifier(config.hidden_size, 32, 32, config.num_labels)  # Define the classification head
        self.class_weights = class_weights  # Store class weights

    def forward(self, pixel_values, output_hidden_states=False, output_attentions=False, labels=None):
        """
        Forward pass for the semantic segmentation model.

        Args:
        - pixel_values: Tensor of shape (batch_size, num_channels, height, width) containing input images.
        - output_hidden_states: If True, return hidden states.
        - output_attentions: If True, return attention weights.
        - labels: Tensor of shape (batch_size, height, width) containing segmentation maps (optional).

        Returns:
        - SemanticSegmenterOutput containing:
          - loss: Cross-entropy loss (if labels are provided).
          - logits: Tensor of shape (batch_size, num_labels, height, width) containing class logits.
          - hidden_states: List of hidden states from the DINOv2 model (if output_hidden_states is True).
          - attentions: List of attention weights from the DINOv2 model (if output_attentions is True).
        """
        # Obtain DINOv2 model outputs
        outputs = self.dinov2(pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions)
        
        # Extract patch embeddings (excluding the CLS token)
        patch_embeddings = outputs.last_hidden_state[:, 1:, :]
        
        # Check the shape of patch_embeddings
        batch_size, num_patches, embedding_dim = patch_embeddings.shape
        #print(f"patch_embeddings shape: {patch_embeddings.shape}")

        # Calculate the expected number of patches
        num_patches_expected = self.classifier.width * self.classifier.height
        if num_patches != num_patches_expected:
            raise ValueError(f"Unexpected number of patches: {num_patches}, expected: {num_patches_expected}")

        # Apply the classification head to the patch embeddings
        logits = self.classifier(patch_embeddings)
        
        # Resize logits to match the size of the input images
        logits = torch.nn.functional.interpolate(logits, size=pixel_values.shape[2:], mode="bilinear", align_corners=False)
        
        # Initialize loss to None
        loss = None
        if labels is not None:
            # Compute the cross-entropy loss with class weights
            #loss_fct = torch.nn.CrossEntropyLoss(weight=self.class_weights, ignore_index=0)  # Use 0 as the ignore index for background

            loss_fct = torch.nn.CrossEntropyLoss(weight=self.class_weights)
            loss = loss_fct(logits, labels)
        
        # Return the output containing the loss, logits, hidden states, and attentions
        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [6]:
class LinearClassifier(torch.nn.Module):
    def __init__(self, in_channels, tokenW=32, tokenH=32, num_labels=1):
        super(LinearClassifier, self).__init__()

        self.in_channels = in_channels
        self.width = tokenW
        self.height = tokenH
        self.classifier = torch.nn.Conv2d(in_channels, num_labels, (1,1))

    def forward(self, embeddings):
        embeddings = embeddings.reshape(-1, self.height, self.width, self.in_channels)
        embeddings = embeddings.permute(0,3,1,2)

        return self.classifier(embeddings)

In [7]:
#set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the complete model
model = torch.load('/explore/nobackup/people/sking11/dinov2model_6400.pth')
model.to(device)

  model = torch.load('/explore/nobackup/people/sking11/dinov2model_6400.pth')


Dinov2ForSemanticSegmentation(
  (dinov2): Dinov2Model(
    (embeddings): Dinov2Embeddings(
      (patch_embeddings): Dinov2PatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Dinov2Encoder(
      (layer): ModuleList(
        (0-11): 12 x Dinov2Layer(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attention): Dinov2Attention(
            (attention): Dinov2SelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): Dinov2SelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
      

# Doing Predictions 

In [22]:
# Directory to save predicted masks
output_dir = "/explore/nobackup/people/sking11/outputmasks"

# List to store all predicted masks
predicted_masks_list = []

# Ensure the model is in evaluation mode
model.eval()

# Loop through each image in the validation dataset
for idx in tqdm(range(len(val_dataset))):
    val_image, _, original_image, _ = val_dataset[idx]
    
    # Convert the image to a tensor and add a batch dimension
    pixel_values = torch.tensor(val_image).permute(0, 1, 2).unsqueeze(0).to(device)
    
    # Forward pass through the model
    with torch.no_grad():
        outputs = model(pixel_values)
    
    # Upsample the logits to match the original image size
    upsampled_logits = torch.nn.functional.interpolate(outputs.logits,
                                                       size=original_image.shape[:2],
                                                       mode="bilinear", align_corners=False)
    
    # Get the predicted segmentation map
    predicted_map = upsampled_logits.argmax(dim=1).squeeze().cpu().numpy()
    
    # Save the predicted map in the list
    predicted_masks_list.append(predicted_map)
    
    # Convert predicted map to an image (you may need to normalize the values)
    #predicted_image = Image.fromarray((predicted_map * 255).astype(np.uint8))
    
    # Save the predicted image
    #output_path = os.path.join(output_dir, f"predicted_mask_{idx}.png")
    #predicted_image.save(output_path)

    #print(f"Processed image {idx + 1}/{len(val_dataset)}")

print("All images processed and saved.")

# Now, `predicted_masks_list` contains all the predicted masks

  pixel_values = torch.tensor(val_image).permute(0, 1, 2).unsqueeze(0).to(device)
100%|██████████| 1813/1813 [01:16<00:00, 23.69it/s]

All images processed and saved.





In [23]:
#Double check that lengths match up
print(len(val_images))
print(len(predicted_masks_list))

1813
1813


Put all the mask pieces together into one complete image. The num_patches_x and num_patches_y should be changed to match the Image patches shape from the prior tiling step. 

In [24]:
import numpy as np
from PIL import Image

def unpatchify_binary(patches_list, patch_size, num_patches_x, num_patches_y):
    """
    Reconstruct the full binary mask from a list of patches.
    
    Args:
        patches_list (list of numpy.ndarray): List of patch arrays with binary values.
        patch_size (tuple): Size of each patch (height, width).
        num_patches_x (int): Number of patches along the height.
        num_patches_y (int): Number of patches along the width.
        
    Returns:
        numpy.ndarray: The reconstructed full-size binary mask.
    """
    patch_height, patch_width = patch_size
    full_height = num_patches_x * patch_height
    full_width = num_patches_y * patch_width
    
    # Initialize the full-size mask array
    full_mask = np.zeros((full_height, full_width), dtype=np.float32)
    
    # Initialize a count array to handle overlapping patches (if any)
    count_array = np.zeros((full_height, full_width), dtype=np.float32)
    
    # Place each patch in the correct position
    index = 0
    for i in range(num_patches_x):
        for j in range(num_patches_y):
            if index >= len(patches_list):
                print(f"Warning: Expected more patches but reached the end of the list.")
                break
            
            patch = patches_list[index]
            start_x = i * patch_height
            start_y = j * patch_width
            
            full_mask[start_x:start_x+patch_height, start_y:start_y+patch_width] += patch
            count_array[start_x:start_x+patch_height, start_y:start_y+patch_width] += 1
            
            index += 1
    
    # Normalize if any overlapping patches were averaged
    full_mask = np.divide(full_mask, count_array, out=np.zeros_like(full_mask, dtype=np.float32), where=count_array!=0).astype(np.uint8)
    
    # Threshold to ensure binary output (0 or 1)
    full_mask = np.clip(full_mask, 0, 1)

    return full_mask

# Example usage:
num_patches_x = 37  # Number of rows
num_patches_y = 49  # Number of columns
patch_size = (256, 256)  # Height, width of each patch

# Reconstruct the full binary mask from the list
full_mask = unpatchify_binary(predicted_masks_list, patch_size, num_patches_x, num_patches_y)

# Save or visualize the results
# Convert binary mask to 'L' mode for saving
Image.fromarray(full_mask * 255).convert('L').save('/explore/nobackup/people/sking11/fullmaskasdasd.png')

The mask won't line up over the original image as the patching process likely cropped it slightly, so the RGB image needs to be re-patched together as well. Then, you will have a complete predicted mask and an image to overlay it against.  

In [None]:
def unpatchify_rgb(patches_list, patch_size, num_patches_x, num_patches_y):
    """
    Reconstruct the full RGB image from a list of patches.
    
    Args:
        patches_list (list of numpy.ndarray): List of patch arrays with RGB values.
        patch_size (tuple): Size of each patch (height, width).
        num_patches_x (int): Number of patches along the height.
        num_patches_y (int): Number of patches along the width.
        
    Returns:
        numpy.ndarray: The reconstructed full-size RGB image.
    """
    patch_height, patch_width = patch_size
    full_height = num_patches_x * patch_height
    full_width = num_patches_y * patch_width
    
    # Initialize the full-size image array
    full_image = np.zeros((full_height, full_width, 3), dtype=np.float32)
    
    # Initialize a count array to handle overlapping patches (if any)
    count_array = np.zeros((full_height, full_width, 3), dtype=np.float32)
    
    # Place each patch in the correct position
    index = 0
    for i in range(num_patches_x):
        for j in range(num_patches_y):
            if index >= len(patches_list):
                print(f"Warning: Expected more patches but reached the end of the list.")
                break
            
            patch = patches_list[index]
            start_x = i * patch_height
            start_y = j * patch_width
            
            full_image[start_x:start_x+patch_height, start_y:start_y+patch_width, :] += patch
            count_array[start_x:start_x+patch_height, start_y:start_y+patch_width, :] += 1
            
            index += 1
    
    # Normalize if any overlapping patches were averaged
    full_image = np.divide(full_image, count_array, out=np.zeros_like(full_image, dtype=np.float32), where=count_array!=0).astype(np.uint8)
    
    return full_image

# Example usage:
num_patches_x = 37  # Number of rows
num_patches_y = 49  # Number of columns
patch_size = (256, 256)  # Height, width of each patch

# Reconstruct the full RGB image from the list
full_image = unpatchify_rgb(val_images, patch_size, num_patches_x, num_patches_y)

# Save or visualize the results
# Convert RGB image for saving
Image.fromarray(full_image).save('/explore/nobackup/people/sking11/full_imageasdasd.png')