# Semantic Segmentation with convpaint and DINOv2

This notebook demonstrates how to run semantic segmentation on an image, using DINOv2 for feature extraction and a random forest for classification. It is based on the notebook provided by convpaint.


## Imports

In [1]:
%load_ext autoreload
%autoreload 2

import napari
import numpy as np
import skimage
from matplotlib import pyplot as plt
from dino_paint_utils import (train_dino_forest,
                              predict_dino_forest,
                              selfpredict_dino_forest,
                              test_dino_forest,
                              pad_to_patch)

## Choose the model

1) Choose the **DINOv2 model** to be used (assign None to not use DINOv2):

|key|model|features|
|---|---|---|
|'s' | dinov2_vits14| 384|
|'b' | dinov2_vitb14| 768|
|'l' | dinov2_vitl14| 1024|
|'g' | dinov2_vitg14| 1536|
|+ '_r' | *base_model*_reg (not supported yet)| add registers|

2) Choose the **layers of DINOv2** to use as features (give a list of indices 0-11); each layer contributes the number of features listed for the model in the table above.

3) Choose the **layers of VGG16** to be attached as additional features (give a list of indices; only use Conv2d layers; assign None to not use VGG16):

|index|layer|
|---|---|
|**0**|**Conv2d(3, 64, kernel_size=3, stride=1, padding=1)**|
|1|ReLU(inplace=True)|
|**2**|**Conv2d(64, 64, kernel_size=3, stride=1, padding=1)**|
|3|ReLU(inplace=True)|
|4|MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)|
|**5**|**Conv2d(64, 128, kernel_size=3, stride=1, padding=1)**|
|6|ReLU(inplace=True)|
|**7**|**Conv2d(128, 128, kernel_size=3, stride=1, padding=1)**|
|8|ReLU(inplace=True)|
|9|MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)|
|**10**|**Conv2d(128, 256, kernel_size=3, stride=1, padding=1)**|
|11|ReLU(inplace=True)|
|**12**|**Conv2d(256, 256, kernel_size=3, stride=1, padding=1)**|
|13|ReLU(inplace=True)|
|**14**|**Conv2d(256, 256, kernel_size=3, stride=1, padding=1)**|
|15|ReLU(inplace=True)|
|16|MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)|
|**17**|**Conv2d(256, 512, kernel_size=3, stride=1, padding=1)**|
|18|ReLU(inplace=True)|
|**19**|**Conv2d(512, 512, kernel_size=3, stride=1, padding=1)**|
|20|ReLU(inplace=True)|
|**21**|**Conv2d(512, 512, kernel_size=3, stride=1, padding=1)**|
|22|ReLU(inplace=True)|
|23|MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)|
|**24**|**Conv2d(512, 512, kernel_size=3, stride=1, padding=1)**|
|25|ReLU(inplace=True)|
|**26**|**Conv2d(512, 512, kernel_size=3, stride=1, padding=1)**|
|27|ReLU(inplace=True)|
|**28**|**Conv2d(512, 512, kernel_size=3, stride=1, padding=1)**|
|29|ReLU(inplace=True)|
|30|MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)|

4) Choose whether the **image itself** (3 RGB channels) should be added as features.

5) Choose the **scale factor combinations** to use.
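To get a feeling for how these choices affect the size of the feature space, the sketch below counts features per pixel at a single scale. It assumes that the selected DINOv2 layers, VGG16 layers and image channels are simply concatenated; `count_features` is a hypothetical helper, not part of `dino_paint_utils`, and the effect of the scale factors is left out. The widths are taken from the two tables above.

```python
# Feature widths taken from the DINOv2 and VGG16 tables above.
DINOV2_FEATURES = {"s": 384, "b": 768, "l": 1024, "g": 1536}
VGG16_CONV_CHANNELS = {0: 64, 2: 64, 5: 128, 7: 128, 10: 256, 12: 256, 14: 256,
                       17: 512, 19: 512, 21: 512, 24: 512, 26: 512, 28: 512}


def count_features(dinov2_model, dinov2_layers, vgg16_layers, image_as_feature):
    """Hypothetical helper: features per pixel at one scale, assuming plain concatenation."""
    n = 0
    if dinov2_model is not None:
        # every selected DINOv2 layer contributes the full model width
        n += DINOV2_FEATURES[dinov2_model] * len(dinov2_layers)
    if vgg16_layers is not None:
        # every selected Conv2d layer contributes its output channels
        n += sum(VGG16_CONV_CHANNELS[i] for i in vgg16_layers)
    if image_as_feature:
        n += 3  # the 3 RGB channels
    return n


n_features = count_features("s", [8, 9, 10, 11], None, True)
print(n_features)  # 4 * 384 + 3 = 1539
```

With the small model and four layers the feature space already has more than 1500 dimensions, so adding layers or a larger model quickly increases memory use and training time.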



In [4]:
dinov2_model = 's'
dinov2_layers = [8, 9, 10, 11]
extra_pads = []            # alternatives: [7] or [2, 4, 6, 8, 10, 12]
vgg16 = None               # alternatives: [2] or [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28]
image_as_feature = True
scales = []                # alternative: [1, 2, 3]

upscale_order = 0
pad_mode = 'reflect'
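The model names in the table (`dinov2_vits14` and so on) end in the ViT patch size of 14, so the image height and width have to be multiples of 14 before they can be fed to DINOv2. The sketch below shows one way such padding could be done with `np.pad` and the `'reflect'` mode chosen above. `pad_to_multiple` is a hypothetical helper, and padding at the bottom and right is an assumption; the `pad_to_patch` function imported from `dino_paint_utils` may behave differently.

```python
import numpy as np


def pad_to_multiple(image, patch_size=14, pad_mode="reflect"):
    """Hypothetical helper: pad height and width up to the next multiple of patch_size."""
    h, w = image.shape[:2]
    pad_h = (-h) % patch_size  # rows missing to reach the next multiple
    pad_w = (-w) % patch_size  # columns missing to reach the next multiple
    # pad only the two spatial axes; leave a possible channel axis untouched
    pad_width = [(0, pad_h), (0, pad_w)] + [(0, 0)] * (image.ndim - 2)
    return np.pad(image, pad_width, mode=pad_mode)


image = np.arange(256 * 126, dtype=np.float32).reshape(256, 126)
padded = pad_to_multiple(image)
print(image.shape, "->", padded.shape)  # (256, 126) -> (266, 126)
```

A width of 126 is already 9 patches wide and needs no padding, while a height of 256 is padded by 10 rows to 19 patches.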

## Train

Load an image and its annotation/labels to train the model on.

In [56]:
from napari_convpaint.convpaint_sample import create_annotation_cell3d

# cells3d has shape (z, channel, y, x); take z-slice 30 of channel 1 (nuclei)
image_to_train = skimage.data.cells3d()[30, 1]
labels_to_train = create_annotation_cell3d()[0][0]
# keep only the first 126 columns of the image and its labels
image_to_train = image_to_train[:, :126]
labels_to_train = labels_to_train[:, :126]

# crop = ((60,288), (0,178))
# crop = ((20,20+224), (0,224))
# image_to_train = image_to_train[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1]]
# labels_to_train = labels_to_train[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1]]

# LOAD ASTRONAUT IMAGE (RGB) AND ANNOTATION
# image_to_train = skimage.data.astronaut()#[0:504,0:504,:]
# labels_to_train = plt.imread('images_and_labels/astro_labels.tif')#[0:504,0:504]

# LOAD HARDER CELL IMAGE AND ITS LABELS
# image_to_train = plt.imread('images_and_labels/00_00016.tiff')
# labels_to_train = plt.imread('images_and_labels/00_00016_labels.tiff')



Extract the features using DINOv2 and/or VGG16 and use them to train a random forest classifier.

In [57]:
train = train_dino_forest(image_to_train, labels_to_train,
                          upscale_order=upscale_order, pad_mode=pad_mode, extra_pads=extra_pads, scales=scales,
                          dinov2_model=dinov2_model, dinov2_layers=dinov2_layers, vgg16_layers=vgg16, append_image_as_feature=image_as_feature,
                          show_napari=True)
random_forest, image_train, labels_train, features_space_train = train
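As a rough illustration of what training on sparse annotations involves, the sketch below selects only the annotated pixels from a per-pixel feature array. It assumes the convention that label 0 means "not annotated" and uses small stand-in arrays rather than real DINOv2 features; what `train_dino_forest` does internally may differ.

```python
import numpy as np

# Stand-in data: a 4x5 image with 6 features per pixel and a few scribbled labels.
features = np.arange(4 * 5 * 6, dtype=np.float32).reshape(4, 5, 6)
labels = np.zeros((4, 5), dtype=np.uint8)
labels[0, :2] = 1  # scribble for class 1
labels[3, 2:] = 2  # scribble for class 2

# Assumption: label 0 marks unannotated pixels, which are excluded from training.
annotated = labels > 0
X_train = features[annotated]  # shape (n_annotated_pixels, n_features)
y_train = labels[annotated]    # shape (n_annotated_pixels,)
print(X_train.shape, y_train.tolist())  # (5, 6) [1, 1, 2, 2, 2]
```

Arrays of this shape are what a classifier such as scikit-learn's `RandomForestClassifier` expects in `fit(X, y)`.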



## Predict

Load an image whose labels will be predicted with the model trained above.

In [45]:
image_to_pred = skimage.data.cells3d()
image_to_pred = image_to_pred[40, 1][:,125:251]
ground_truth = plt.imread('images_and_labels/cells_cross_ground_truth.tif')

# crop = ((20,248), (50,278))
# crop = ((20,20+224), (0,224))
# image_to_pred = image_to_pred[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1]]

# LOAD AN IMAGE TO PREDICT BASED ON THE CLASSIFIER TRAINED ON THE ASTRONAUT IMAGE
# image_to_pred = skimage.data.camera()
# ground_truth = plt.imread('images_and_labels/cam_ground_truth.tif')
# image_to_pred = skimage.data.cat()
# image_to_pred = skimage.data.horse().astype(np.int32)
# image_to_pred = skimage.data.binary_blobs().astype(np.int32)
# image_to_pred = skimage.data.coins()
# ground_truth = None



Extract the features and use them together with the trained classifier to predict the labels.

In [46]:
pred = predict_dino_forest(image_to_pred, random_forest, ground_truth=ground_truth,
                           upscale_order=upscale_order, pad_mode=pad_mode, extra_pads=extra_pads, scales=scales,
                           dinov2_model=dinov2_model, dinov2_layers=dinov2_layers, vgg16_layers=vgg16, append_image_as_feature=image_as_feature,
                           show_napari=True)
predicted_labels, image_pred, features_space_pred, acc = pred
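When a ground truth is given, the prediction can be scored against it. The sketch below computes a plain pixel accuracy, i.e. the fraction of pixels where prediction and ground truth agree. This is an assumption about what `acc` represents; `pixel_accuracy` is a hypothetical helper, and the function in `dino_paint_utils` may compute its score differently.

```python
import numpy as np


def pixel_accuracy(predicted, ground_truth):
    """Hypothetical helper: fraction of pixels where the two label images agree."""
    predicted = np.asarray(predicted)
    ground_truth = np.asarray(ground_truth)
    if predicted.shape != ground_truth.shape:
        raise ValueError("predicted and ground_truth must have the same shape")
    return float(np.mean(predicted == ground_truth))


example_pred = np.array([[1, 1, 2],
                         [2, 2, 1]])
example_truth = np.array([[1, 2, 2],
                          [2, 2, 1]])
example_acc = pixel_accuracy(example_pred, example_truth)
print(f"{example_acc:.2%}")  # 5 of 6 pixels agree: 83.33%
```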



## Selfpredict

We can also train and predict on the same image in one step, which requires extracting the features only once.

In [53]:
self_pred_image = skimage.data.astronaut()#[0:504,0:504,:]
self_pred_labels = plt.imread('images_and_labels/astro_labels.tif')#[0:504,0:504]
ground_truth = plt.imread('images_and_labels/astro_ground_truth.tif')

# self_pred_image = image_to_train
# self_pred_labels = labels_to_train
# # ground_truth = None

self_pred = selfpredict_dino_forest(self_pred_image, self_pred_labels, ground_truth,
                                    upscale_order=upscale_order, pad_mode=pad_mode, extra_pads=extra_pads, scales=scales,
                                    dinov2_model=dinov2_model, dinov2_layers=dinov2_layers, vgg16_layers=vgg16, append_image_as_feature=image_as_feature,
                                    show_napari=True)
predicted_labels, image_scaled, labels_scaled, feature_space, acc = self_pred
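The idea of reusing one feature array for both training and prediction can be sketched with synthetic data. To keep the example dependency-free, a nearest-centroid rule is swapped in for the random forest, and the features are random stand-ins rather than DINOv2 output; only the data flow is meant to carry over.

```python
import numpy as np

rng = np.random.default_rng(0)
h, w, n_feat = 8, 8, 5

# Synthetic "truth": left half is class 1, right half is class 2.
truth = np.ones((h, w), dtype=int)
truth[:, w // 2:] = 2
# Stand-in features, extracted once: the class value plus small noise.
features = truth[..., None] + 0.05 * rng.normal(size=(h, w, n_feat))

# Sparse annotation: one labelled pixel per class, 0 = unlabelled.
labels = np.zeros((h, w), dtype=int)
labels[0, 0] = 1
labels[0, -1] = 2

flat = features.reshape(-1, n_feat)
flat_labels = labels.ravel()
classes = np.unique(flat_labels[flat_labels > 0])

# "Train": one centroid per class (stand-in for fitting the random forest).
centroids = np.stack([flat[flat_labels == c].mean(axis=0) for c in classes])

# "Predict" on every pixel, reusing the same feature array.
dist = np.linalg.norm(flat[:, None, :] - centroids[None, :, :], axis=2)
predicted = classes[dist.argmin(axis=1)].reshape(h, w)
print(np.mean(predicted == truth))
```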

## Tests against ground truth

In [35]:
image_to_train = skimage.data.astronaut()#[200:300,250:400]
labels_to_train = plt.imread('images_and_labels/astro_labels.tif')#[200:300,250:400]
image_to_pred = None #skimage.data.camera()
ground_truth = plt.imread('images_and_labels/astro_ground_truth.tif')#[200:300,250:400]

# viewer = napari.Viewer()
# viewer.add_image(image_to_train)
# viewer.add_labels(labels_to_train)
# viewer_2 = napari.Viewer()
# viewer_2.add_image(image_to_pred)
# viewer_2.add_labels(ground_truth)

all_vggs = [0,2,5,7,10,12,14,17,19,21,24,26,28]
single_vggs = [[i] for i in all_vggs]
consecutive_vggs = [all_vggs[:s] for s in range(1,len(all_vggs))]
dual_vggs = [[all_vggs[i], all_vggs[j]] for i in range(len(all_vggs)) for j in range(i+1, len(all_vggs))]

dino_models = [None, 's']#, 'b']
dino_layer_comboss = [()] #[[8, 9, 10, 11], [11]]
vgg_layer_combos = [None, [2, 7]] #, [24,26,28]] #, all_vggs]#[0], [10], [17], [24], [0, 10, 17, 24]]#
im_feats = [False] #, True]
scale_combos = [(), [1.5]] #[[1,2]] #, [1,2]]
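The lists above define a grid of configurations. Assuming the test function tries every combination, the grid can be enumerated with `itertools.product`; this is a sketch of the bookkeeping only, not of `test_dino_forest` itself.

```python
from itertools import product

dino_models = [None, 's']
dino_layer_combos = [()]
vgg_layer_combos = [None, [2, 7]]
im_feats = [False]
scale_combos = [(), [1.5]]

# Assumption: every combination of the five settings is tested once.
grid = list(product(dino_models, dino_layer_combos, vgg_layer_combos,
                    im_feats, scale_combos))
print(len(grid))  # 2 * 1 * 2 * 1 * 2 = 8
for model, layers, vgg, im_feat, scale in grid:
    print(model, layers, vgg, im_feat, scale)
```

The count of 8 matches the eight innermost "Running tests ..." lines printed by the test run.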

In [36]:
test = test_dino_forest(image_to_train, labels_to_train, ground_truth, image_to_pred,
                        dinov2_models=dino_models, dinov2_layer_combos=dino_layer_comboss, scale_combos=scale_combos, vgg16_layer_combos=vgg_layer_combos, im_feats=im_feats,
                        print_avg=True, print_max=True)
accs, avg_accs, max_acc = test

Running tests for DINOv2 model None, layers ()...
    Running tests for VGG16 layers None...
        Running tests without image as feature and with scale combination ()...
        Running tests without image as feature and with scale combination [1.5]...
    Running tests for VGG16 layers [2, 7]...
        Running tests without image as feature and with scale combination ()...
        Running tests without image as feature and with scale combination [1.5]...
Running tests for DINOv2 model s, layers ()...
    Running tests for VGG16 layers None...
        Running tests without image as feature and with scale combination ()...
        Running tests without image as feature and with scale combination [1.5]...
    Running tests for VGG16 layers [2, 7]...
        Running tests without image as feature and with scale combination ()...
        Running tests without image as feature and with scale combination [1.5]...

--- AVERAGES ---
Average accuracy for DINOv2 model None, layers (): 82.25%

In [40]:
# plt.imshow(image_to_train)
# im_pad = pad_to_patch(image_to_train, "top", "left", pad_mode="symmetric", extra_pad=15, patch_size=(1,1))
# plt.imshow(im_pad)
# np.set_printoptions(linewidth=np.inf)
# print(im_pad.shape, "\n", im_pad[:30,:30])
# print(accs)
acc_3d = accs[:,:,0,:]
print(acc_3d)
napari.imshow(acc_3d)

[[[0.         0.        ]
  [0.83196259 0.81309128]]

 [[0.96504593 0.96849823]
  [0.96096039 0.9721756 ]]]
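The printed array can be searched for the best configuration. The sketch below assumes the axes are ordered as (DINOv2 model, VGG16 combination, scale combination). This fits the log above: the mean of the two non-zero entries in the first block reproduces the reported 82.25% for DINOv2 model None. The zero entries belong to the configuration with neither DINOv2 nor VGG16, where no network features are available.

```python
import numpy as np

# Values copied from the printed output above.
acc_3d = np.array([[[0.0, 0.0],
                    [0.83196259, 0.81309128]],
                   [[0.96504593, 0.96849823],
                    [0.96096039, 0.9721756]]])
dino_models = [None, 's']
vgg_layer_combos = [None, [2, 7]]
scale_combos = [(), [1.5]]

# Index of the highest accuracy, assuming axes (model, VGG16 combo, scale combo).
best = tuple(int(i) for i in np.unravel_index(np.argmax(acc_3d), acc_3d.shape))
print(best, acc_3d[best])
print("model:", dino_models[best[0]],
      "| vgg16:", vgg_layer_combos[best[1]],
      "| scales:", scale_combos[best[2]])

# Average over the non-zero entries for DINOv2 model None.
avg_none = acc_3d[0][acc_3d[0] > 0].mean()
print(f"{avg_none:.2%}")  # 82.25%
```

Under this reading, the best run combines the small DINOv2 model with VGG16 layers 2 and 7 and the scale factor 1.5, although all four DINOv2 runs lie within about one percentage point of each other.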

