# Semantic Segmentation with convpaint and DINOv2

This notebook demonstrates how to perform semantic segmentation of an image, using DINOv2 for feature extraction and a random forest for pixel classification. It is based on the notebook provided by convpaint.


## Imports

In [44]:
```python
%load_ext autoreload
%autoreload 2

import napari
import numpy as np
import skimage
from matplotlib import pyplot as plt
from dino_paint_utils import (train_dino_forest,
                              predict_dino_forest,
                              selfpredict_dino_forest)
```



## Choose the model

Choose the DINOv2 model to use:

| Key | Model | Feature dimension |
|---|---|---|
| `'s'` | `dinov2_vits14` | 384 |
| `'b'` | `dinov2_vitb14` | 768 |
| `'l'` | `dinov2_vitl14` | 1024 |
| `'g'` | `dinov2_vitg14` | 1536 |

In [45]:
```python
dinov2_model = 's'  # one of 's', 'b', 'l', 'g' (see the table above)
```
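The key selects one of the four DINOv2 backbones listed in the table. The sketch below shows one way such a key could be resolved to a model name and feature dimension; `resolve_dinov2` and `DINOV2_MODELS` are hypothetical names, not part of `dino_paint_utils`. The commented `torch.hub.load('facebookresearch/dinov2', ...)` call is the standard way to fetch these backbones from the official repository.

```python
# Hypothetical lookup mirroring the table above; dino_paint_utils may resolve
# the key differently internally.
DINOV2_MODELS = {
    's': ('dinov2_vits14', 384),
    'b': ('dinov2_vitb14', 768),
    'l': ('dinov2_vitl14', 1024),
    'g': ('dinov2_vitg14', 1536),
}


def resolve_dinov2(key):
    """Return (torch.hub model name, feature dimension) for a model key."""
    try:
        return DINOV2_MODELS[key]
    except KeyError:
        raise ValueError(
            f"Unknown DINOv2 key {key!r}; expected one of {sorted(DINOV2_MODELS)}"
        ) from None


name, n_features = resolve_dinov2('s')
print(name, n_features)  # dinov2_vits14 384

# Loading the backbone itself needs torch and network access (or a cached copy):
# import torch
# model = torch.hub.load('facebookresearch/dinov2', name)
```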

## Train

Load an image and its annotation/labels to train the model on.

In [46]:
```python
# Alternative: a 2D slice of skimage.data.cells3d() with the convpaint sample annotation
# image_to_train = skimage.data.cells3d()
# image_to_train = image_to_train[30, 1]
# from napari_convpaint.convpaint_sample import create_annotation_cell3d
# labels_to_train = create_annotation_cell3d()[0][0]
# crop = ((60,188), (0,128))  # choose one crop window
# crop = ((20,20+224), (0,224))
# image_to_train = image_to_train[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1]]
# labels_to_train = labels_to_train[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1]]

# Load the astronaut image (RGB) and its annotation (first channel of the label TIFF).
# Appending the commented slices crops both to 504 x 504.
image_to_train = skimage.data.astronaut()  # [0:504, 0:504, :]
labels_to_train = plt.imread('astro_labels_2.tif')[:, :, 0]  # [0:504, 0:504]
```
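All DINOv2 backbones in the table work on 14 x 14-pixel patches (the "14" in the model names), so the image is processed as a grid of patches. The astronaut image is 512 x 512 pixels, which is not a multiple of 14: 512 // 14 = 36 patches cover only 504 pixels, matching the commented `[0:504, 0:504]` slices. A minimal sketch of such a crop follows; `crop_to_patch_multiple` is a hypothetical helper, and whether `crop_to_patch=True` in `dino_paint_utils` behaves exactly like this is an assumption.

```python
import numpy as np

PATCH_SIZE = 14  # patch size of all DINOv2 models listed in the table


def crop_to_patch_multiple(image, patch_size=PATCH_SIZE):
    """Crop the first two (spatial) axes down to the nearest multiple of patch_size."""
    h = (image.shape[0] // patch_size) * patch_size
    w = (image.shape[1] // patch_size) * patch_size
    return image[:h, :w]


# Same shape as skimage.data.astronaut(), without needing the actual pixels.
astronaut_like = np.zeros((512, 512, 3), dtype=np.uint8)
cropped = crop_to_patch_multiple(astronaut_like)
print(cropped.shape)  # (504, 504, 3), i.e. a 36 x 36 grid of patches
```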

Extract the features with DINOv2 and use them to train a random forest classifier.

In [47]:
```python
train = train_dino_forest(image_to_train, labels_to_train,
                          crop_to_patch=False, scale=1, upscale_order=1,
                          dinov2_model=dinov2_model, show_napari=True)
random_forest, image_train, labels_train, features_space_train = train
```
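DINOv2 yields one feature vector per patch, while the annotation is per pixel, so the patch features must be brought back to pixel resolution before the annotated pixels can serve as training samples. The sketch below assumes nearest-neighbour upscaling (each patch vector repeated over its 14 x 14 block) and treats label 0 as "not annotated"; both are assumptions, and `upscale_order=1` suggests that `train_dino_forest` interpolates linearly instead. `upscale_nearest` and `annotated_training_set` are hypothetical names.

```python
import numpy as np


def upscale_nearest(patch_features, patch_size=14):
    """Repeat each patch feature vector over its patch_size x patch_size pixel block.

    patch_features has shape (rows, cols, n_features); the result has shape
    (rows * patch_size, cols * patch_size, n_features).
    """
    return np.repeat(np.repeat(patch_features, patch_size, axis=0), patch_size, axis=1)


def annotated_training_set(pixel_features, labels):
    """Collect (X, y) from annotated pixels only; label 0 is assumed to mean 'not annotated'."""
    mask = labels > 0
    return pixel_features[mask], labels[mask]


# Toy example: a 2 x 2 patch grid with 3 features per patch.
rng = np.random.default_rng(0)
patch_features = rng.normal(size=(2, 2, 3))
pixel_features = upscale_nearest(patch_features)  # shape (28, 28, 3)

labels = np.zeros((28, 28), dtype=np.uint8)
labels[:5, :5] = 1    # a few strokes of class 1
labels[-5:, -5:] = 2  # a few strokes of class 2

X, y = annotated_training_set(pixel_features, labels)
print(X.shape, y.shape)  # (50, 3) (50,)
```

`X` and `y` have the shape expected by a scikit-learn `RandomForestClassifier().fit(X, y)` call.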



## Predict

Load an image whose labels will be predicted with the model trained above.

In [48]:
```python
image_to_pred = skimage.data.camera()
# Alternative test images:
# image_to_pred = skimage.data.cat()
# image_to_pred = skimage.data.horse().astype(np.int32)
# image_to_pred = skimage.data.binary_blobs().astype(np.int32)
# image_to_pred = skimage.data.coins()
```
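`skimage.data.camera()` is a single-channel grayscale image, whereas the training image is RGB and DINOv2 expects three input channels. A common preprocessing step is to replicate the gray channel three times and normalize with the ImageNet mean and standard deviation used by DINOv2. The sketch below assumes that convention; `dino_paint_utils` may preprocess differently, and `to_normalized_rgb` is a hypothetical name.

```python
import numpy as np

# ImageNet per-channel statistics, the normalization used with DINOv2 backbones.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def to_normalized_rgb(image):
    """Return a float (H, W, 3) array scaled to [0, 1] and normalized per channel."""
    image = np.asarray(image, dtype=np.float64)
    if image.ndim == 2:                  # grayscale, e.g. skimage.data.camera()
        image = np.stack([image] * 3, axis=-1)
    image = image[..., :3]               # drop an alpha channel if present
    if image.max() > 1.0:                # assume 8-bit input when values exceed 1
        image = image / 255.0
    return (image - IMAGENET_MEAN) / IMAGENET_STD


gray = np.full((4, 5), 128, dtype=np.uint8)  # stand-in for a grayscale image
normalized = to_normalized_rgb(gray)
print(normalized.shape)  # (4, 5, 3)
```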

Extract the features and use them, together with the trained classifier, to predict the labels.

In [49]:
```python
pred = predict_dino_forest(image_to_pred, random_forest,
                           crop_to_patch=True, scale=1, upscale_order=1,
                           dinov2_model=dinov2_model, show_napari=True)
predicted_labels, image_pred, features_space_pred = pred
```
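Prediction works on the same kind of per-pixel features: the (H, W, C) feature map is flattened to one row per pixel, classified, and reshaped back into a label image. A sketch with a hypothetical `predict_per_pixel` helper; the toy `ThresholdClassifier` stands in for the trained `random_forest`, and any object with a scikit-learn style `predict` method would work in its place.

```python
import numpy as np


def predict_per_pixel(pixel_features, classifier):
    """Flatten an (H, W, C) feature map, classify every pixel, reshape back to (H, W)."""
    h, w, c = pixel_features.shape
    flat = pixel_features.reshape(-1, c)  # one row per pixel
    return classifier.predict(flat).reshape(h, w)


class ThresholdClassifier:
    """Hypothetical stand-in for the random forest: class 2 if feature 0 > 0, else class 1."""

    def predict(self, X):
        return np.where(X[:, 0] > 0, 2, 1)


features = np.zeros((4, 6, 3))
features[:, 3:, 0] = 1.0  # the right half "looks like" class 2
predicted = predict_per_pixel(features, ThresholdClassifier())
print(predicted)
```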



## Selfpredict

Training and prediction can also be run in a single step on the same image.

In [50]:
```python
# Uncomment to train and predict on the astronaut image in a single call
# self_pred_image = skimage.data.astronaut()
# self_pred_labels = plt.imread('astro_labels_1.tif')[:,:,0]

# self_pred = selfpredict_dino_forest(self_pred_image, self_pred_labels,
#                                     crop_to_patch=True, scale=1, upscale_order=1,
#                                     dinov2_model=dinov2_model, show_napari=True)
```
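Self-prediction chains both steps: fit on the annotated pixels of one image, then classify every pixel of that same image. The sketch below illustrates this data flow on a toy feature map. To stay dependency-free it swaps the random forest for a minimal nearest-centroid classifier, so it shows the workflow, not the classifier the notebook actually uses; `TinyNearestCentroid` and `selfpredict` are hypothetical names, and label 0 is again assumed to mean "not annotated".

```python
import numpy as np


class TinyNearestCentroid:
    """Minimal nearest-centroid classifier, a stand-in for the random forest."""

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        self.centroids_ = np.stack([X[y == k].mean(axis=0) for k in self.classes_])
        return self

    def predict(self, X):
        # Squared Euclidean distance from every sample to every class centroid.
        d = ((X[:, None, :] - self.centroids_[None, :, :]) ** 2).sum(axis=-1)
        return self.classes_[d.argmin(axis=1)]


def selfpredict(pixel_features, labels, classifier):
    """Train on annotated pixels (label > 0), then predict every pixel of the same image."""
    mask = labels > 0
    classifier.fit(pixel_features[mask], labels[mask])
    h, w, c = pixel_features.shape
    return classifier.predict(pixel_features.reshape(-1, c)).reshape(h, w)


# Toy image: the left half has feature value 0, the right half has value 5.
features = np.zeros((8, 8, 1))
features[:, 4:, 0] = 5.0
labels = np.zeros((8, 8), dtype=int)
labels[0, 0] = 1  # a single stroke in the left half
labels[0, 7] = 2  # a single stroke in the right half
prediction = selfpredict(features, labels, TinyNearestCentroid())
```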