# Semantic Segmentation with convpaint and DINOv2

This notebook demonstrates how to run semantic segmentation on an image, using DINOv2 for feature extraction and a random forest for classification. It is based on the notebook provided by convpaint and does not need the convpaint napari widget; napari is only used to visualize the results.


## Imports

In [635]:
%load_ext autoreload
%autoreload 2

# import napari and its screenshot function
import napari
from napari.utils.notebook_display import nbscreenshot

# import what we need from conv_paint
from napari_convpaint.conv_paint import ConvPaintWidget
# from napari_convpaint.conv_paint_utils import Hookmodel
from napari_convpaint.convpaint_sample import create_annotation_cell3d
from napari_convpaint.conv_paint_utils import (filter_image_multioutputs, get_features_current_layers,
get_multiscale_features, train_classifier, predict_image)
from napari_convpaint.conv_paint_utils import extract_annotated_pixels
 
# import the other general modules used
import numpy as np
import skimage
# import tifffile
from matplotlib import pyplot as plt

# import pytorch and pillow Image
import torch
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, ToPILImage, InterpolationMode, CenterCrop
from PIL import Image




## Define the parameters

In [None]:
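PATCH_SIZE = (14, 14)  # DINOv2 splits its input into patches of 14 x 14 pixels
crop_to_patch = True   # True: crop the image to a multiple of the patch size; False: resize it up to the next multiple
scale = 1              # Additional factor by which the model input shape is scaled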

## Load data

First, we load an image and the corresponding annotation.

In [636]:
# LOAD CELL3D IMAGE AND ANNOTATION
# Load 3D image with 2 channels (cell borders and nuclei)
# image_original = skimage.data.cells3d()
# Take a layer in middle of cell (30 of 0-59) and take 2nd channel (nuclei)
# image_original = image_original[30, 1]
# Load annotation defined in conv_paint
# labels_original = create_annotation_cell3d()[0][0]

# Take crops of image and annotation
# crop = ((60,188), (0,128))
# crop = ((20,20+224), (0,224))
# image_original = image_original[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1]]
# labels_original = labels_original[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1]]

# LOAD ASTRONAUT IMAGE (RGB) AND ANNOTATION
image_original = skimage.data.astronaut()#[0:504,0:504,:]
labels_original = plt.imread('astro_labels_2.tif')[:,:,0]#[0:504,0:504]

# PRINT SHAPES
print(f"Original image shape: {image_original.shape}")
print(f"Original label image shape: {labels_original.shape}")

Original image shape: (512, 512, 3)
Original label image shape: (512, 512)


Optionally, show the image and annotation as layers in napari (uncomment the cell to open a viewer).

In [637]:
# # create a napari viewer; add the image to it; add the labels/annotation
# viewer = napari.Viewer()
# viewer.add_image(image_original)
# viewer.add_labels(labels_original)

We can also show a napari screenshot.

In [638]:
# show a screenshot of the napari viewer here in the notebook
# nbscreenshot(viewer)

In [639]:
#tifffile.imwrite('label_cell3d.tiff', viewer.layers['Labels'].data)

## Create model

DINOv2 comes in four sizes (ViT-S/14, ViT-B/14, ViT-L/14 and ViT-g/14), with an increasing number of parameters and increasingly large feature vectors. Choose the desired model by assigning it to `model`.
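For reference, the sketch below maps the four `torch.hub` entry points to the number of features each backbone returns per patch (its embedding dimension). The dictionary and the helper `feature_dim` are hypothetical conveniences, not part of DINOv2 or convpaint.

```python
# torch.hub entry points of the DINOv2 backbones (repo 'facebookresearch/dinov2')
# mapped to the embedding dimension of the patch tokens they output.
DINOV2_EMBED_DIMS = {
    "dinov2_vits14": 384,   # ViT-S/14
    "dinov2_vitb14": 768,   # ViT-B/14
    "dinov2_vitl14": 1024,  # ViT-L/14
    "dinov2_vitg14": 1536,  # ViT-g/14
}


def feature_dim(model_name: str) -> int:
    """Return the number of features per patch for a DINOv2 hub model name (hypothetical helper)."""
    if model_name not in DINOV2_EMBED_DIMS:
        raise ValueError(f"Unknown DINOv2 model: {model_name!r}")
    return DINOV2_EMBED_DIMS[model_name]


print(feature_dim("dinov2_vits14"))  # 384
```

This dimension becomes `NUM_FEATURES` further below and, together with the image size, determines how much memory the upsampled feature maps need.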

In [640]:
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
# model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
# model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
# model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')



Define the input shape for the model and how the image is cropped or resized to it.
DINOv2 uses patches of 14 x 14 pixels, so the input height and width must both be multiples of 14, and the number of patches along each axis is the input size along that axis divided by 14.
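The shape arithmetic of the next cell can be written as a small pure-Python function (hypothetical name `input_and_grid_shape`, scale factor 1): with `crop=True` the image is cropped to the largest multiple of the patch size that fits, otherwise it is resized up to the smallest multiple that covers it.

```python
import math


def input_and_grid_shape(height, width, patch=14, crop=True):
    """Return (input_shape, patch_grid_shape) for a ViT with square patches.

    crop=True:  floor to the largest multiple of `patch` that fits in the image.
    crop=False: ceil to the smallest multiple of `patch` that covers the image.
    """
    rounding = math.floor if crop else math.ceil
    in_h = rounding(height / patch) * patch
    in_w = rounding(width / patch) * patch
    return (in_h, in_w), (in_h // patch, in_w // patch)


# The 512 x 512 astronaut image
print(input_and_grid_shape(512, 512, crop=True))   # ((504, 504), (36, 36))
print(input_and_grid_shape(512, 512, crop=False))  # ((518, 518), (37, 37))
```

Cropping loses a border of up to 13 pixels but keeps the pixel grid unchanged; resizing keeps the whole image but slightly rescales it.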

In [641]:
# Choose the input shape for the model; both dimensions must be multiples of the patch size
if crop_to_patch:
    # Crop to the largest multiple of the patch size that fits inside the image, then apply the scaling factor
    crop_shape = ((image_original.shape[0] // PATCH_SIZE[0]) * PATCH_SIZE[0],
                  (image_original.shape[1] // PATCH_SIZE[1]) * PATCH_SIZE[1])
    in_shape = (crop_shape[0] * scale, crop_shape[1] * scale)

else:
    # Keep the whole image and resize it to the smallest multiple of the patch size that covers it, times the scaling factor
    crop_shape = image_original.shape[:2]
    in_shape = (int(np.ceil(image_original.shape[0]/PATCH_SIZE[0]))*PATCH_SIZE[0] * scale,
                int(np.ceil(image_original.shape[1]/PATCH_SIZE[1]))*PATCH_SIZE[1] * scale)

image_cropped = image_original[0:crop_shape[0], 0:crop_shape[1]]
labels_cropped = labels_original[0:crop_shape[0], 0:crop_shape[1]]
image_scaled = skimage.transform.resize(image_cropped, in_shape, mode='edge', order=1, preserve_range=True)
labels_scaled = skimage.transform.resize(labels_cropped, in_shape, mode='edge', order=0, preserve_range=True)

# Calculate the shape of the patched image (i.e. how many patches fit in the image)
if not (in_shape[0]%PATCH_SIZE[0] == 0 and in_shape[1]%PATCH_SIZE[1] == 0):
    raise ValueError('Input shape must be divisible by patch size')
else:
    patched_image_shape = (int(in_shape[0]/PATCH_SIZE[0]), int(in_shape[1]/PATCH_SIZE[1]))


print(f"Original image is: {image_original.shape[0]} x {image_original.shape[1]} pixels")
print(f"Image is cropped to: {crop_shape[0]} x {crop_shape[1]} pixels")
print(f"Shape of input used for model (multiple of patch size): {in_shape[0]} x {in_shape[1]} pixels")
print(f"Patched image shape: {patched_image_shape[0]} x {patched_image_shape[1]} patches")

Original image is: 512 x 512 pixels
Image is cropped to: 504 x 504 pixels
Shape of input used for model (multiple of patch size): 504 x 504 pixels
Patched image shape: 36 x 36 patches


## Convert & preprocess image

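The image and labels were already cropped and resized to the input shape (a multiple of the patch size 14) in the previous cell. Here we only cast them to the data types used below: float32 for the image and int32 for the labels.

In [642]:
# Cast the scaled image and labels to the data types used downstream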
image_scaled = image_scaled.astype(np.float32)
labels_scaled = labels_scaled.astype(np.int32)

Show the scaled image and labels to check that they still look correct.

In [643]:
# viewer = napari.Viewer()
# viewer.add_image(image_scaled.astype(np.int32))
# viewer.add_labels(labels_scaled)

Convert the image to RGB (a grayscale image is replicated into three channels). Then normalize it to the distribution expected by the model and convert it to a torch tensor.
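For comparison, the standard ImageNet-style preprocessing for an 8-bit image scales the values to [0, 1], applies `(x - mean) / std` per channel with the ImageNet statistics, and moves the channel axis first. The NumPy sketch below assumes 8-bit input; the notebook itself normalizes with per-image statistics instead, which also works for data with other value ranges, such as `cells3d`. The function name `preprocess_uint8` is hypothetical.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def preprocess_uint8(image):
    """Turn an 8-bit (H, W) or (H, W, 3) image into a (3, H, W) float32 array
    normalized with the ImageNet mean and standard deviation."""
    image = np.asarray(image, dtype=np.float64)
    if image.ndim == 2:
        # Grayscale: replicate the single channel into three RGB channels
        image = np.stack([image] * 3, axis=-1)
    image = image / 255.0                              # scale to [0, 1]
    image = (image - IMAGENET_MEAN) / IMAGENET_STD     # per-channel normalization
    return np.transpose(image, (2, 0, 1)).astype(np.float32)  # channels first


gray = np.full((4, 4), 255, dtype=np.uint8)
out = preprocess_uint8(gray)
print(out.shape)     # (3, 4, 4)
print(out[:, 0, 0])  # approximately 2.249, 2.429, 2.640
```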

In [644]:
# Convert to RGB
if image_scaled.ndim == 2:
    image_rgb = np.stack((image_scaled,)*3, axis=-1)
else:
    image_rgb = image_scaled

# The shape is now (H, W, 3), e.g. (504, 504, 3) for the cropped astronaut image; the dtype is float32
# print(image_rgb.shape); print(image_rgb.dtype)
# napari.view_image(image_rgb.astype(np.int64), rgb=True)

In [645]:
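# Standardize each channel with the image's own statistics and map it to the pixel
# statistics of ImageNet images in [0, 1]; this works regardless of the original value range
current_mean, current_sd = np.mean(image_rgb, axis=(0,1)), np.std(image_rgb, axis=(0,1))
new_mean, new_sd = np.array([0.485, 0.456, 0.406]), np.array([0.229, 0.224, 0.225])
image_norm = (image_rgb - current_mean) / current_sd
image_norm = image_norm * new_sd + new_mean
# Apply the ImageNet normalization (x - mean) / std that DINOv2 expects at its input
image_norm = (image_norm - new_mean) / new_sd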

# Check that it worked
# print(np.mean(image_rgb, axis=(0,1)))
# print(np.mean(image_norm, axis=(0,1)))

# Show normalized image
# napari.view_image(image_norm, rgb=True)

In [646]:
# Convert the (H, W, 3) array to a (3, H, W) float32 tensor; ToTensor only rescales uint8 input, so the values are unchanged
image_tensor = ToTensor()(image_norm).float()

# Preprocess image and convert to PyTorch tensor
# preprocess = Compose([
#     ToTensor(),  # Convert to PyTorch tensor
#     # Resize(in_shape, interpolation=InterpolationMode.BILINEAR, antialias=True),  # Resize to the input size expected by the model
#     # CenterCrop(in_shape),  # Crop to the input size expected by the model
#     # ToTensor(),  # Convert to PyTorch tensor
#     # Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # Normalize to the input distribution expected by the model
# ])

# image_tensor = preprocess(image_norm).float()
# image_tensor = preprocess(image_rgb)

# image_tensor = np.array(image_tensor)
# image_tensor = preprocess(Image.fromarray(image_rgb.astype(np.uint8))).float()


Check the output of the preprocessing to verify that it is correct.

In [647]:
# print(f"Mean: {image_tensor.mean(dim=(1, 2))}")
# print(f"Standard Deviation: {image_tensor.std(dim=(1, 2))}")
# np.mean(image_tensor, axis=(1,2))
# print(image_rgb[0:4,0:4,:])
# print(image_tensor.numpy().transpose(1,2,0)[0:4,0:4,:])

# tensor_np = image_tensor.numpy().transpose(1,2,0)
# print(tensor_np.shape)
# napari.view_image(tensor_np, rgb=True)

## Feature extraction

Now that the model is defined, we can run an image through it and extract features from it.

In [648]:
# Add an extra batch dimension 
image_batch = image_tensor.unsqueeze(0)

# Move image to the GPU if available
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# image_batch = image_batch.to(device)
# model = model.to(device)

# Pass image through the model (assuming image_batch is a batch of test images)
with torch.no_grad():
    features_dict = model.forward_features(image_batch)
    features = features_dict['x_norm_patchtokens']

# The output shape is [batch_size, num_patches, features], e.g. [1, 1296, 384] for a 504 x 504 input and ViT-S/14
# print(features.shape)

Rearrange and reshape dimensions of the DINOv2 output.
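The reshaping relies on DINOv2 returning the patch tokens in row-major order (row by row, left to right), so token `k` belongs to patch row `k // w` and patch column `k % w`. The NumPy sketch below reproduces the permute, reshape and nearest-neighbour upsampling steps of the next cell on a tiny made-up feature array, using `np.repeat` in place of `torch.nn.functional.interpolate`: for an integer scale factor, nearest-neighbour upsampling repeats each patch value `patch` times along each axis.

```python
import numpy as np

# Made-up sizes: a 2 x 3 grid of patches, 4 features per patch, patch size 14
h, w, num_features, patch = 2, 3, 4, 14

# DINOv2-style patch tokens [batch, num_patches, features], patches in row-major order.
# Feature f of patch k is set to 10 * k + f so that values can be traced back.
tokens = (10 * np.arange(h * w)[:, None] + np.arange(num_features)[None, :])[None, ...]

# [1, N, D] -> [1, D, N] -> [1, D, h, w]   (same as permute(0, 2, 1) followed by reshape)
grid = tokens.transpose(0, 2, 1).reshape(1, num_features, h, w)

# Nearest-neighbour upsampling by an integer factor: repeat every patch `patch` times per axis
pixels = grid.repeat(patch, axis=2).repeat(patch, axis=3)
print(pixels.shape)  # (1, 4, 28, 42)

# Pixel (row=20, col=30) lies in patch row 1, patch column 2, i.e. token 1 * w + 2 = 5
print(pixels[0, :, 20, 30])  # [50 51 52 53]
```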

In [649]:
# Read out the number of features, which is dependent on the model chosen (ViTs14 = 384; ViTb14 = 768; ViTl14 = 1024; ViTg14 = 1536)
NUM_FEATURES = features.shape[2]
print(f"Number features: {NUM_FEATURES}")

# Rearrange dimensions of the feature tensor to [batch_size, features, num_patches], e.g. [1, 384, 1296]
features_perm = features.permute(0, 2, 1)
# print(features_perm.shape)

# Reshape the linear sequence of patches (stored row by row) into 2D:
# [batch_size, features, patch_rows, patch_cols], e.g. [1, 384, 36, 36]
features_wh = features_perm.reshape(1, NUM_FEATURES, patched_image_shape[0], patched_image_shape[1])
# print(features_wh.shape)

# Upsample to the size of the scaled image (nearest-neighbour interpolation with scaling factor = patch size = 14),
# i.e. [batch_size, features, image_h, image_w], e.g. [1, 384, 504, 504]
features_int = torch.nn.functional.interpolate(features_wh, scale_factor=PATCH_SIZE)
# print(features_int.shape)

# Convert to a numpy array and remove the batch dimension to get [features, image_h, image_w], e.g. [384, 504, 504]
features_np = features_int.numpy()
features_np = np.squeeze(features_np, axis=0)
# print(f"Shape of features_np: {features_np.shape}")


Number features: 384


Show feature space in napari.

In [650]:
# # Show the loaded image and the annotation
# viewer = napari.Viewer()
# viewer.add_image(image_scaled.astype(np.int32))
# viewer.add_labels(labels_scaled)
# # add the feature space
# viewer.add_image(features_np)

## Train and use Classifier

Extract the features and target values (labels) of all pixels where the image is annotated.
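Conceptually, this step selects every pixel with a non-zero label (0 means not annotated) and turns its feature vector into one training sample and its label into the corresponding target. The sketch below is a hypothetical NumPy re-implementation of that idea for a partial annotation; it is not convpaint's `extract_annotated_pixels`, whose exact behaviour may differ.

```python
import numpy as np


def annotated_samples(features, labels):
    """Hypothetical helper: features [D, H, W], labels [H, W] (0 = not annotated).

    Returns (X, y) with X of shape [n_annotated, D] and y of shape [n_annotated].
    """
    mask = labels > 0
    X = features[:, mask].T   # boolean mask over the two image axes -> [D, n] -> [n, D]
    y = labels[mask]
    return X, y


features = np.arange(2 * 3 * 3).reshape(2, 3, 3)   # D=2 features on a 3 x 3 image
labels = np.array([[0, 1, 0],
                   [0, 0, 2],
                   [1, 0, 0]])
X, y = annotated_samples(features, labels)
print(X.shape, y)  # (3, 2) [1 2 1]
```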

In [651]:
features_annot, targets = extract_annotated_pixels(features_np, labels_scaled, full_annotation=False)
# features_annot.shape = (number of annotated pixels, NUM_FEATURES)
# targets.shape = (number of annotated pixels,)
print(f"Shape of annotated features: {features_annot.shape}")
print(f"Number of targets: {targets.shape[0]}")
# Count the non-zero label pixels (summing the label values would over-count pixels with labels > 1)
print(f"Number of originally annotated pixels: {np.count_nonzero(labels_original)}")
print(f"Number of annotated pixels in resized annotation: {np.count_nonzero(labels_scaled)}")

# NOTE: in convpaint, we had
# features.shape = (218, 640)
# targets.shape = (218,)
# And the number of originally annotated pixels was 327

Shape of annotated features: (5744, 384)
Number of targets: 5744


Train the classifier.
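`train_classifier` comes from convpaint. Assuming it fits a scikit-learn random forest, as the variable name `random_forest` suggests, the step is similar in spirit to the sketch below. The hyperparameters and the synthetic two-class data are illustrative assumptions, not convpaint's settings.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)

# Synthetic stand-in for (features_annot, targets): two well-separated classes, 384 features
n_per_class, num_features = 50, 384
X = np.vstack([rng.normal(0.0, 1.0, (n_per_class, num_features)),
               rng.normal(5.0, 1.0, (n_per_class, num_features))])
y = np.repeat([1, 2], n_per_class)  # labels 1 and 2, as in a napari annotation

# Illustrative hyperparameters (assumed, not taken from convpaint)
clf = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=0)
clf.fit(X, y)

print(clf.classes_)     # [1 2]
print(clf.score(X, y))  # accuracy on the training samples
```

Each row of `X` is one annotated pixel and each column one DINOv2 feature, so the forest learns a per-pixel mapping from feature vector to label.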

In [652]:
random_forest = train_classifier(features_annot, targets)

## Prediction

In [653]:
# NOTE: If we wanted to predict on another image, we would have to do the following:
# 1) Extract features from the new image using the same DINOv2 model
# 2) predict on the features using the random forest created above (learned from the original image)

# Convert features to numpy array
features_to_predict = features.numpy()
# Remove the batch dimension
features_to_predict_lin = features_to_predict.squeeze(0)

# Run predict of random forest on all features
predictions = random_forest.predict(features_to_predict_lin)

# We get one prediction per patch, e.g. 1296 predictions for the 36 x 36 patches of a 504 x 504 input
# print(predictions.shape)

Reshape and resize the predictions so we can show and overlay them on the image.
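Predicting on the patch tokens instead of on the upsampled per-pixel features is a valid shortcut because nearest-neighbour upsampling only copies each patch's feature vector to its 14 x 14 pixels, so any per-sample classifier assigns all of those pixels the same label. The NumPy sketch below checks this with a hypothetical nearest-centroid classifier standing in for the random forest, and upsamples the patch labels with `np.repeat`, which for an integer factor should match order-0 (nearest-neighbour) resizing.

```python
import numpy as np

rng = np.random.default_rng(0)
h, w, num_features, patch = 3, 4, 8, 14

# Hypothetical stand-ins: patch tokens [N, D] and two class centroids [2, D]
tokens = rng.normal(size=(h * w, num_features))
centroids = rng.normal(size=(2, num_features))
class_labels = np.array([1, 2])


def predict(samples):
    """Nearest-centroid classifier (stand-in for random_forest.predict)."""
    dists = ((samples[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return class_labels[dists.argmin(axis=1)]


# Route 1 (as in the notebook): predict per patch, then upsample the label map
patch_labels = predict(tokens).reshape(h, w)
labels_fast = patch_labels.repeat(patch, axis=0).repeat(patch, axis=1)

# Route 2: upsample the features first, then predict every pixel (196 times more samples)
feat_grid = tokens.reshape(h, w, num_features)
feat_pixels = feat_grid.repeat(patch, axis=0).repeat(patch, axis=1)
labels_slow = predict(feat_pixels.reshape(-1, num_features)).reshape(h * patch, w * patch)

print(np.array_equal(labels_fast, labels_slow))  # True
```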

In [654]:
# Reshape the predictions to the shape of the image of patches
predicted_image = predictions.reshape(patched_image_shape[0], patched_image_shape[1])
# Resize to the size of the scaled input image
predicted_image = skimage.transform.resize(predicted_image, in_shape, mode='edge', order=0, anti_aliasing=False)
# Cast to uint8 for display as a napari labels layer (order=0 keeps the original label values)
predicted_image = predicted_image.astype(np.uint8)

Finally, we visualize the prediction together with the image and the annotation in napari:

In [655]:
viewer = napari.Viewer()
# add the loaded image to it
viewer.add_image(image_scaled.astype(np.int32))
# add the loaded labels/annotation
viewer.add_labels(labels_scaled)
# add the prediction
viewer.add_labels(predicted_image)

<Labels layer 'predicted_image' at 0x2c07f3b19d0>
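One simple way to quantify the result is the agreement between prediction and annotation on the annotated pixels, overall and per class. These are the pixels the classifier was trained on, so the numbers measure how well it fits the annotation rather than how well it generalizes. The helper below is a hypothetical NumPy sketch; with the notebook's arrays it would be called as `annotation_agreement(predicted_image, labels_scaled)`.

```python
import numpy as np


def annotation_agreement(prediction, annotation):
    """Fraction of annotated pixels (annotation > 0) whose predicted label matches,
    overall and per class. Hypothetical helper."""
    mask = annotation > 0
    overall = float(np.mean(prediction[mask] == annotation[mask]))
    per_class = {}
    for c in np.unique(annotation[mask]):
        class_mask = annotation == c
        per_class[int(c)] = float(np.mean(prediction[class_mask] == c))
    return overall, per_class


annotation = np.array([[1, 1, 0],
                       [0, 2, 2],
                       [0, 0, 2]])
prediction = np.array([[1, 2, 1],
                       [1, 2, 2],
                       [2, 2, 1]])
print(annotation_agreement(prediction, annotation))  # (0.6, {1: 0.5, 2: 0.6666666666666666})
```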

In [656]:
# nbscreenshot(viewer)

## Tests Roman: VGG16 layer features

In [657]:
# # CREATE AND SHOW FULL OUTPUT OF A LAYER OF VGG16 (= 64 FEATURES)
# def get_layer_features(image, layer, show_napari = False, interpolate = False):
        
#     model = Hookmodel(model_name='vgg16')


#     all_layers = [key for key in model.module_dict.keys()]
#     # Choose just 1 layer, and register a hook there
#     if isinstance(layer, str):
#         layers = [layer]
#     elif isinstance(layer, int):
#         layers = [all_layers[layer]]
    
#     # layers = ['features.30 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) avgpool AdaptiveAvgPool2d(output_size=(7, 7))']
#     model.register_hooks(selected_layers=layers)

#     # Get features using only this first layer and without scaling
#     features, targets = get_features_current_layers(
#         model=model, image=image, annotations=image, scalings=[1], use_min_features=False, order=interpolate)

#     # Convert the DataFrame to a numpy array
#     features_array = features.values
#     # Get the shape of the image
#     image_shape = image.shape
#     # Reshape the features array to match the image shape and add the second dimension of features as the third dimension
#     features_image = features_array.reshape(*image_shape, -1)

#     # Move the last dimension to the first position
#     features_image = np.moveaxis(features_image, -1, 0)
#     # print(features.shape)
#     # print(features_image.shape)

#     # Now you can view the new_features using napari
#     if show_napari: napari.view_image(features_image)
#     return features_image

In [658]:
# # RUN

# # image = image.T

# # Get features of multiple (all) layers
# conv_layers = [0,2]#,5,7,10,12,14,17,19,21,24,26,28]
# all_conv = [get_layer_features(image, l) for l in conv_layers]


# ### Pad first dimension of the layers with fewer features and concatenate all layers into a 4D Image

# # Get the shapes of all outputs
# shapes = [output.shape for output in all_conv]
# # Find the maximum shape in each dimension
# max_shape = np.max(shapes, axis=0)
# # Pad all outputs to have the max shape
# all_conv_padded = np.array([np.pad(output, [(0, max_dim - dim) for dim, max_dim in zip(output.shape, max_shape)]) for output in all_conv])

# # Show in Napari
# napari.view_image(all_conv_padded)
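The padding step in the commented cell above zero-pads each layer output at the end of every axis up to the largest size along that axis, so that outputs with different numbers of features (and possibly different spatial sizes) can be stacked into one 4D array. A self-contained NumPy version of that idea, with made-up array sizes, could look like this:

```python
import numpy as np

# Made-up layer outputs [features, H, W] with different numbers of features and sizes
all_conv = [np.ones((64, 8, 8)), np.ones((128, 8, 8)), np.ones((64, 4, 4))]

# Largest size along every axis
max_shape = np.max([output.shape for output in all_conv], axis=0)

# Zero-pad each output at the end of every axis up to max_shape, then stack into 4D
all_conv_padded = np.stack([
    np.pad(output, [(0, max_dim - dim) for dim, max_dim in zip(output.shape, max_shape)])
    for output in all_conv
])
print(all_conv_padded.shape)  # (3, 128, 8, 8)
```

Note that a smaller spatial output is padded with zeros at the bottom and right rather than resampled, so its pixels no longer align with those of the larger outputs; upsampling first would be needed for a pixel-aligned comparison.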