# Semantic Segmentation with convpaint and DINOv2

This notebook demonstrates how to run semantic segmentation on an image, using DINOv2 for feature extraction and a random forest for pixel classification. It is based on the notebook provided by convpaint.


## Imports

In [31]:
%load_ext autoreload
%autoreload 2

import napari
import numpy as np
import skimage
from matplotlib import pyplot as plt
from dino_paint_utils import train_dino_forest, predict_dino_forest



## Choose the model

Choose the DINOv2 model to use for feature extraction:

| Key | Model |
|---|---|
| `'s'` | `dinov2_vits14` |
| `'b'` | `dinov2_vitb14` |
| `'l'` | `dinov2_vitl14` |
| `'g'` | `dinov2_vitg14` |
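The single-letter keys appear to map onto the DINOv2 backbones published on PyTorch Hub (the cache message printed during training points to `facebookresearch_dinov2_main`). A minimal sketch of how such a key could be resolved and loaded, assuming `torch` is installed and network access is available for the first download. `DINOV2_MODELS`, `dinov2_hub_name` and `load_dinov2` are hypothetical helpers, not part of `dino_paint_utils`:

```python
# Hypothetical lookup table: notebook keys -> torch.hub entry-point names
DINOV2_MODELS = {
    "s": "dinov2_vits14",
    "b": "dinov2_vitb14",
    "l": "dinov2_vitl14",
    "g": "dinov2_vitg14",
}


def dinov2_hub_name(key: str) -> str:
    """Return the torch.hub model name for a single-letter key."""
    try:
        return DINOV2_MODELS[key]
    except KeyError:
        raise ValueError(
            f"unknown DINOv2 key {key!r}; expected one of {sorted(DINOV2_MODELS)}"
        ) from None


def load_dinov2(key: str):
    """Download (or reuse the cached copy of) the chosen DINOv2 backbone."""
    import torch  # imported lazily so the lookup above works without torch

    model = torch.hub.load("facebookresearch/dinov2", dinov2_hub_name(key))
    model.eval()  # inference only: the backbone is used as a frozen feature extractor
    return model
```

For example, `dinov2_hub_name('s')` returns `'dinov2_vits14'`. All four backbones use 14×14 pixel patches (the `14` in the name); larger models give richer features at the cost of memory and runtime.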

In [32]:
dinov2_model = 's'

## Train

Load a training image together with its annotation (a label image).

In [33]:
# ALTERNATIVE: CELLS3D IMAGE WITH THE CONVPAINT SAMPLE ANNOTATION
# image_original = skimage.data.cells3d()
# image_original = image_original[30, 1]
# from napari_convpaint.convpaint_sample import create_annotation_cell3d
# labels_original = create_annotation_cell3d()[0][0]
# crop = ((60,188), (0,128))
# crop = ((20,20+224), (0,224))
# image_original = image_original[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1]]
# labels_original = labels_original[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1]]

# LOAD ASTRONAUT IMAGE (RGB) AND ANNOTATION
image_train = skimage.data.astronaut()  # optionally crop, e.g. [0:504, 0:504, :]
labels_train = plt.imread('astro_labels_2.tif')[:, :, 0]  # use the first channel as the label image
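The label image must have the same height and width as the training image. A common convention in napari-based annotation tools, and an assumption here, is that `0` marks unannotated pixels and positive integers mark classes. A small hypothetical helper (`summarize_labels` is not part of `dino_paint_utils`) to check this before training:

```python
import numpy as np


def summarize_labels(image: np.ndarray, labels: np.ndarray) -> dict:
    """Check that labels match the image grid and count annotated pixels per class.

    Assumes 0 means 'not annotated' (a convention adopted for this sketch).
    """
    if labels.shape != image.shape[:2]:
        raise ValueError(
            f"labels shape {labels.shape} does not match image shape {image.shape[:2]}"
        )
    # Only annotated pixels (label > 0) are counted
    classes, counts = np.unique(labels[labels > 0], return_counts=True)
    return {int(c): int(n) for c, n in zip(classes, counts)}


# Tiny synthetic example: two pixels of class 1, one pixel of class 2
img = np.zeros((4, 5, 3), dtype=np.uint8)
lab = np.zeros((4, 5), dtype=np.uint8)
lab[0, :2] = 1
lab[3, 4] = 2
counts = summarize_labels(img, lab)  # {1: 2, 2: 1}
```

With the notebook's data this would be called as `summarize_labels(image_train, labels_train)`.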

Extract the features with DINOv2 and use them to train a random forest classifier.
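DINOv2 splits the input into non-overlapping 14×14 patches and returns one feature vector per patch. The `crop_to_patch=True` option presumably crops the image to a multiple of the patch size (an assumption; the code of `dino_paint_utils` is not shown here), which would explain crop sizes such as 224 = 16 × 14 and 504 = 36 × 14 in the cell above. The sketch below shows such a crop and a nearest-neighbour upsampling of per-patch features back to pixel resolution, so that every pixel gets a feature vector. Random numbers stand in for real DINOv2 features, and `crop_to_patch_multiple` / `patch_features_to_pixels` are hypothetical names:

```python
import numpy as np

PATCH_SIZE = 14  # DINOv2 ViT-*/14 models use 14x14 pixel patches


def crop_to_patch_multiple(image: np.ndarray, patch: int = PATCH_SIZE) -> np.ndarray:
    """Crop the two spatial axes down to the nearest multiple of `patch`."""
    h = (image.shape[0] // patch) * patch
    w = (image.shape[1] // patch) * patch
    return image[:h, :w]


def patch_features_to_pixels(patch_feats: np.ndarray, patch: int = PATCH_SIZE) -> np.ndarray:
    """Upsample (rows, cols, C) patch features to (rows*patch, cols*patch, C) by repetition."""
    return np.repeat(np.repeat(patch_feats, patch, axis=0), patch, axis=1)


# Small synthetic example: a 30x45 RGB image is cropped to 28x42, i.e. a 2x3 patch grid
image = np.zeros((30, 45, 3), dtype=np.uint8)
cropped = crop_to_patch_multiple(image)
rows, cols = cropped.shape[0] // PATCH_SIZE, cropped.shape[1] // PATCH_SIZE
fake_feats = np.random.default_rng(0).normal(size=(rows, cols, 4))  # stand-in for DINOv2 output
pixel_feats = patch_features_to_pixels(fake_feats)  # shape (28, 42, 4)
```

For the 512×512 astronaut image the same crop gives 504×504 pixels and a 36×36 patch grid. Nearest-neighbour repetition is the simplest choice; bilinear interpolation of the patch features is a common smoother alternative.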

In [34]:
random_forest, image_to_train, labels_to_train, features_space_train = train_dino_forest(
    image_train, labels_train,
    crop_to_patch=True, scale=1,
    dinov2_model=dinov2_model, show_napari=True,
)
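Conceptually, the training step reduces to an ordinary scikit-learn classification problem: every annotated pixel becomes one sample whose features are the (upsampled) DINOv2 vector at that position. A minimal sketch assuming scikit-learn's `RandomForestClassifier`; the helper `train_pixel_forest`, the `0 = unannotated` convention and the hyperparameters are assumptions, not taken from `dino_paint_utils`:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier


def train_pixel_forest(features: np.ndarray, labels: np.ndarray,
                       n_estimators: int = 100) -> RandomForestClassifier:
    """Fit a random forest on annotated pixels only.

    features: (H, W, C) per-pixel feature image
    labels:   (H, W) integer annotation, 0 = not annotated (assumed convention)
    """
    mask = labels > 0
    X = features[mask]  # (N_annotated, C): one row per annotated pixel
    y = labels[mask]    # (N_annotated,): class of each annotated pixel
    clf = RandomForestClassifier(n_estimators=n_estimators, random_state=0)
    clf.fit(X, y)
    return clf


# Synthetic example: sparse annotations of two classes on a 20x20 feature image
rng = np.random.default_rng(0)
features = rng.normal(size=(20, 20, 8))
labels = np.zeros((20, 20), dtype=int)
labels[:3, :3] = 1
labels[-3:, -3:] = 2
features[labels == 2] += 5.0  # make class 2 separable in feature space
forest = train_pixel_forest(features, labels, n_estimators=10)
```

Because only annotated pixels enter `fit`, a few brush strokes per class are enough to train the forest; unannotated pixels are labelled later at prediction time.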



Optionally, show the image and labels as used for training in napari (uncomment the lines below).

In [13]:
# viewer = napari.Viewer()
# viewer.add_image(image_to_train.astype(np.int32))
# viewer.add_labels(labels_to_train)


## Predict

Load an image whose labels will be predicted with the model trained above.

In [28]:
# image_pred = skimage.data.camera()
# image_pred = skimage.data.cat()
image_pred = skimage.data.horse().astype(np.int32)
# image_pred = skimage.data.binary_blobs().astype(np.int32)
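Several of the candidate prediction images are single-channel (`camera`, `horse`, `binary_blobs`, the last two being binary images, hence the cast to `np.int32`), whereas DINOv2 was trained on RGB images normalized with ImageNet statistics. Whether `predict_dino_forest` handles this internally is not shown here. The sketch below is one way to prepare such an image, assuming the standard ImageNet mean and standard deviation used in DINOv2's preprocessing; `to_normalized_rgb` is a hypothetical helper that accepts 2D images or images with exactly three channels:

```python
import numpy as np

# ImageNet channel statistics (RGB order), as commonly used for DINOv2 inputs
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def to_normalized_rgb(image: np.ndarray) -> np.ndarray:
    """Return an (H, W, 3) float image scaled to [0, 1] and ImageNet-normalized."""
    img = np.asarray(image, dtype=np.float64)
    if img.ndim == 2:  # grayscale or binary: replicate into 3 channels
        img = np.stack([img] * 3, axis=-1)
    lo, hi = img.min(), img.max()
    # Min-max scale to [0, 1]; a constant image becomes all zeros
    img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
    return (img - IMAGENET_MEAN) / IMAGENET_STD


# Example with a tiny binary image, similar in kind to skimage.data.horse()
demo = to_normalized_rgb(np.array([[0, 1], [1, 0]], dtype=np.int32))  # shape (2, 2, 3)
```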

Extract the features and use them, together with the trained classifier, to predict the labels.

In [29]:
predicted_labels, image_to_predict, features_space_predict = predict_dino_forest(
    image_pred, random_forest,
    crop_to_patch=True, scale=1,
    dinov2_model=dinov2_model,
)
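Prediction applies the trained classifier to every pixel's feature vector and folds the flat result back into the image grid. A sketch assuming a fitted scikit-learn classifier and an (H, W, C) feature image; `predict_pixel_labels` is hypothetical, and the small forest below is trained on synthetic data only so that the block runs on its own:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier


def predict_pixel_labels(clf, features: np.ndarray) -> np.ndarray:
    """Classify every pixel of an (H, W, C) feature image and return an (H, W) label image."""
    h, w, c = features.shape
    flat = features.reshape(-1, c)          # one row per pixel
    return clf.predict(flat).reshape(h, w)  # back to image layout


# Synthetic stand-in for DINOv2 features: two well-separated clusters
rng = np.random.default_rng(0)
X_train = np.vstack([rng.normal(0.0, 1.0, (50, 4)), rng.normal(10.0, 1.0, (50, 4))])
y_train = np.array([1] * 50 + [2] * 50)
clf = RandomForestClassifier(n_estimators=20, random_state=0).fit(X_train, y_train)

features = rng.normal(0.0, 1.0, (6, 5, 4))
features[:, 3:] += 10.0                     # right-hand columns resemble class 2
pred = predict_pixel_labels(clf, features)  # (6, 5) array of labels 1 and 2
```

Unlike training, prediction uses all pixels, so the result is a dense label image with the same height and width as the (cropped) input.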



Optionally, show the image and the predicted labels in napari (uncomment the lines below).

In [30]:
# viewer = napari.Viewer()
# viewer.add_image(image_to_predict.astype(np.int32))
# viewer.add_labels(predicted_labels)
