# Semantic Segmentation with convpaint and DINOv2

This notebook demonstrates how to perform semantic segmentation of an image, using DINOv2 for feature extraction and a random forest for pixel classification. It is based on the notebook provided by convpaint.


## Imports

In [1]:
%load_ext autoreload
%autoreload 2

import napari
import numpy as np
import skimage
from matplotlib import pyplot as plt

from napari_convpaint.conv_paint import ConvPaintWidget
from napari_convpaint.conv_paint_utils import (train_classifier,
                                               extract_annotated_pixels)
from dino_paint_utils import (extract_single_image_dinov2_features,
                              scale_to_patch,
                              dino_features_to_image,
                              predict_to_image)


## Define the parameters

In [14]:
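PATCH_SIZE = (14, 14)          # DINOv2 patch size in pixels
dinov2_model = 's'             # DINOv2 backbone size: 's', 'b', 'l' or 'g'
crop_to_patch_train = True     # crop the training image to fit the patch grid
crop_to_patch_predict = True   # crop the prediction image to fit the patch grid
scale_train = 1                # rescaling factor for the training image
scale_predict = 2              # rescaling factor for the prediction image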
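The parameters above determine which DINOv2 backbone is used and what size the images have when they reach the feature extractor. The following is a minimal sketch under stated assumptions: `dinov2_hub_name` and `crop_to_patch_multiple` are hypothetical helpers, not part of `dino_paint_utils`; the mapping of `'s'`, `'b'`, `'l'`, `'g'` to the torch hub names `dinov2_vits14`, `dinov2_vitb14`, etc. is assumed; and upscaling is simplified to integer nearest-neighbour repetition, whereas the notebook rescales images with `interpolation_order=1`.

```python
# Hypothetical helpers illustrating how the parameters translate into
# DINOv2 model names and image sizes; NOT part of dino_paint_utils.
import numpy as np

PATCH_SIZE = (14, 14)

# Embedding dimension of each DINOv2 ViT backbone (patch size 14)
DINOV2_FEATURE_DIMS = {'s': 384, 'b': 768, 'l': 1024, 'g': 1536}


def dinov2_hub_name(size: str) -> str:
    """Map 's', 'b', 'l' or 'g' to a torch hub model name such as 'dinov2_vits14'."""
    if size not in DINOV2_FEATURE_DIMS:
        raise ValueError(f"unknown DINOv2 size {size!r}; choose from s, b, l, g")
    # Such a name would be loaded with torch.hub.load('facebookresearch/dinov2', name)
    return f"dinov2_vit{size}14"


def crop_to_patch_multiple(image: np.ndarray, scale: int = 1,
                           patch_size=PATCH_SIZE) -> np.ndarray:
    """Upscale by an integer factor (nearest neighbour, a simplification) and
    crop height and width down to the nearest multiple of the patch size."""
    if scale > 1:
        image = np.repeat(np.repeat(image, scale, axis=0), scale, axis=1)
    h = image.shape[0] // patch_size[0] * patch_size[0]
    w = image.shape[1] // patch_size[1] * patch_size[1]
    return image[:h, :w]


# A 512x512 RGB training image at scale 1 and a 512x512 grayscale image at scale 2
train = crop_to_patch_multiple(np.zeros((512, 512, 3)), scale=1)
pred = crop_to_patch_multiple(np.zeros((512, 512)), scale=2)
print(dinov2_hub_name('s'), DINOV2_FEATURE_DIMS['s'])                   # dinov2_vits14 384
print(train.shape, pred.shape)                                          # (504, 504, 3) (1022, 1022)
print(train.shape[0] // PATCH_SIZE[0], pred.shape[0] // PATCH_SIZE[0])  # 36 73
```

The 504-pixel result matches the commented-out `[0:504,0:504]` crop in the loading cell: 512 is not a multiple of 14, so 8 pixels are dropped. With `scale_predict = 2`, the prediction image yields a 73×73 patch grid instead of 36×36, i.e. finer spatial resolution at a higher compute cost.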

## Load data

First, we load an image for training the classifier and the corresponding annotation, as well as an image to predict.

In [10]:
# image_original = skimage.data.cells3d()
# image_original = image_original[30, 1]
# from napari_convpaint.convpaint_sample import create_annotation_cell3d
# labels_original = create_annotation_cell3d()[0][0]
# crop = ((60,188), (0,128))
# crop = ((20,20+224), (0,224))
# image_original = image_original[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1]]
# labels_original = labels_original[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1]]

# LOAD ASTRONAUT IMAGE (RGB) AND ANNOTATION
image_train = skimage.data.astronaut()#[0:504,0:504,:]
labels_train = plt.imread('astro_labels_2.tif')[:,:,0]#[0:504,0:504]

image_pred = skimage.data.camera()

# PRINT SHAPES
print(f"Training image shape: {image_train.shape}")
print(f"Training label image shape: {labels_train.shape}")

Training image shape: (512, 512, 3)
Training label image shape: (512, 512)


## Train
Extract DINOv2 features from the training image, map them back to pixel space, and use the annotated pixels to train a random forest classifier.
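DINOv2 produces one feature vector per 14×14 patch, while the annotations are per pixel, so the patch features have to be brought to pixel resolution before the annotated pixels can be selected. The sketch below illustrates these two steps with NumPy only. It assumes patch features arrive as an `(n_patches, n_features)` array in row-major patch order, a channel-first `(n_features, H, W)` layout for the pixel features, and that label `0` means "not annotated"; the function names are hypothetical and not the actual `dino_features_to_image` or `extract_annotated_pixels` implementations.

```python
# Hypothetical NumPy sketch: patch features -> pixel features -> training samples.
import numpy as np


def patch_features_to_pixels(patch_features, image_shape, patch_size=(14, 14)):
    """Turn (n_patches, F) patch features (row-major patch order) into an
    (F, H, W) stack by repeating each patch vector over its pixels."""
    gh, gw = image_shape[0] // patch_size[0], image_shape[1] // patch_size[1]
    grid = patch_features.reshape(gh, gw, -1)                  # (gh, gw, F)
    pixels = np.repeat(np.repeat(grid, patch_size[0], axis=0),
                       patch_size[1], axis=1)                   # (H, W, F)
    return np.moveaxis(pixels, -1, 0)                           # (F, H, W)


def annotated_pixels(feature_stack, labels):
    """Return (n_annotated, F) features and their labels; 0 means unlabelled."""
    mask = labels > 0
    return feature_stack[:, mask].T, labels[mask]


# 2x2 patches with 3 features each, for a 28x28 image
features = np.arange(4 * 3, dtype=float).reshape(4, 3)
stack = patch_features_to_pixels(features, (28, 28, 3))

# Sparse annotation: one pixel of class 1, one pixel of class 2
labels = np.zeros((28, 28), dtype=np.int32)
labels[0, 0] = 1     # lies in the top-left patch
labels[27, 27] = 2   # lies in the bottom-right patch

X, y = annotated_pixels(stack, labels)
print(stack.shape, X.shape, y)  # (3, 28, 28) (2, 3) [1 2]
```

Only the annotated pixels become training samples, which is why a few brush strokes suffice; the resulting `(X, y)` pair is the kind of input a scikit-learn style classifier expects.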

In [16]:
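# Rescale/crop image and labels so their size fits the 14x14 patch grid
image_to_train = scale_to_patch(image_train, crop_to_patch_train, scale_train, interpolation_order=1)
# Order 0 (nearest neighbour) keeps the label values integer
labels_to_train = scale_to_patch(labels_train, crop_to_patch_train, scale_train, interpolation_order=0)
labels_to_train = labels_to_train.astype(np.int32)

# Extract patch-level DINOv2 features and map them back to pixel resolution
features_trained = extract_single_image_dinov2_features(image_to_train, dinov2_model)
features_space_trained = dino_features_to_image(features_trained, image_to_train.shape, PATCH_SIZE)
# Keep only the annotated pixels and train the random forest on them
features_annot, targets = extract_annotated_pixels(features_space_trained, labels_to_train, full_annotation=False)
random_forest = train_classifier(features_annot, targets)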

Using cache found in C:\Users\roman/.cache\torch\hub\facebookresearch_dinov2_main


## Predict
Extract the features of the prediction image and pass them to the trained classifier to predict the labels.
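The classifier returns one label per feature vector, so the predictions must be arranged back into an image. A minimal sketch, assuming one prediction per 14×14 patch in row-major order; `patch_predictions_to_image` is a hypothetical stand-in for illustration, not the notebook's `predict_to_image`:

```python
# Hypothetical sketch: one class label per patch -> (H, W) label image.
import numpy as np


def patch_predictions_to_image(predictions, image_shape, patch_size=(14, 14)):
    """Reshape per-patch labels (row-major) to the patch grid and repeat
    each label over its patch (nearest-neighbour upsampling)."""
    gh, gw = image_shape[0] // patch_size[0], image_shape[1] // patch_size[1]
    grid = np.asarray(predictions).reshape(gh, gw)
    return np.kron(grid, np.ones(patch_size, dtype=grid.dtype))


# Four patches (2x2 grid) of a 28x28 image
preds = np.array([1, 2, 2, 1])
label_image = patch_predictions_to_image(preds, (28, 28))
print(label_image.shape, np.unique(label_image))                  # (28, 28) [1 2]
print(label_image[0, 0], label_image[0, 14], label_image[14, 0])  # 1 2 2
```

Nearest-neighbour repetition guarantees that the output contains only class IDs the classifier actually produced. Linearly interpolating hard labels could create in-between values, so smoother boundaries are usually obtained by interpolating class probabilities (for example from a random forest's `predict_proba`) and taking the arg-max afterwards.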

In [15]:
# Rescale/crop the prediction image to fit the patch grid
image_to_predict = scale_to_patch(image_pred, crop_to_patch_predict, scale_predict, interpolation_order=1)
# Classify the DINOv2 feature vectors, then map the predictions back to an image
features_to_predict = extract_single_image_dinov2_features(image_to_predict, dinov2_model)
predictions = random_forest.predict(features_to_predict)
predicted_labels = predict_to_image(predictions, image_to_predict.shape, interpolation_order=1)

viewer = napari.Viewer()
# viewer.add_image(image_to_train.astype(np.int32))
# viewer.add_labels(labels_to_train)
viewer.add_image(image_to_predict.astype(np.int32))
viewer.add_labels(predicted_labels)

Using cache found in C:\Users\roman/.cache\torch\hub\facebookresearch_dinov2_main


<Labels layer 'predicted_labels' at 0x28633e2a730>