# DINOv3 Evaluation on CREMI (ssTEM neurite segmentation)

The **CREMI** dataset contains serial section transmission electron microscopy (ssTEM) volumes of the adult *Drosophila* brain. The training samples A, B and C each span approximately 5×5×5 μm at 4×4×40 nm resolution and come with neuron segmentation labels, synaptic cleft labels and synaptic partner annotations. The challenge has three tracks — neuron segmentation, synaptic cleft detection and synaptic partner identification — each scored with task-specific metrics (for example, variation of information and adapted Rand error for neuron segmentation, and an F-score for synaptic partners).

This notebook demonstrates how to use the `neuros` package to extract DINOv3 features and train simple segmentation models. We first show how one might download and load the raw data; if you are running in an environment without internet access you can skip those cells. For demonstration purposes we then generate a synthetic dataset that mimics EM imagery (bright, round neurite-like profiles on a noisy background). On it we run a lightweight comparison of two DINOv3 backbones (ConvNeXt-Tiny and ViT-Large) on a binary, patch-level segmentation task. The reported patch accuracy and F1 score are sanity checks on synthetic data, not the official CREMI metrics.

## Downloading the data

If you have network access and the necessary permissions, you can download the dataset using the following commands. These lines are commented out by default to prevent accidental downloads in restricted environments.

```bash
# Download the CREMI dataset (approx 200 MB).
# !wget -O cremi.zip http://hp06.mindhackers.org/rhoana_product/dataset/cremi.zip
# !unzip cremi.zip -d ./cremi_data
```

## Loading the data

The following code shows how you might load the downloaded files into Python. Adjust the paths as necessary.

```python
# Example: load HDF5 volumes from CREMI.
# import h5py
# with h5py.File('cremi_data/sample_A_20160501.hdf', 'r') as f:
#     volume = f['volumes/raw'][:]  # numpy array of EM slices

```

In [None]:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score

from neuros.plugins.cv.dinov3_backbone import DINOv3Backbone


def _randint(rng, low, high, count=None):
    """Return ``rng.integers(low, high, count)``, widening an empty range to ``[low, low + 1)``.

    For ranges that are already non-empty this is exactly ``rng.integers``, so
    the random stream is unchanged. For very small images the size-derived
    bounds (e.g. ``size // 50``) can collapse to ``low == high``, where
    ``rng.integers`` would raise.
    """
    return rng.integers(low, max(high, low + 1), count)


def _synth_em(rng, size):
    """Round bright neurite profiles on a noisy grey background."""
    img = rng.normal(0.5, 0.15, size=(size, size, 1)).clip(0, 1)
    mask = np.zeros((size, size), int)
    Y, X = np.ogrid[:size, :size]
    for _ in range(rng.integers(3, 7)):
        cx, cy = rng.integers(0, size, 2)
        r = _randint(rng, size // 20, size // 10)
        circle = (X - cx) ** 2 + (Y - cy) ** 2 <= r ** 2
        mask[circle] = 1
        img[circle] += 0.5
    return np.repeat(img.clip(0, 1), 3, axis=2), mask


def _synth_mri(_rng, size):
    """Deterministic diagonal intensity gradient with a bright central lesion (uses no randomness)."""
    axis = np.linspace(-1, 1, size)
    X, Y = np.meshgrid(axis, axis)
    img = 0.5 + 0.5 * (X + Y) / 2
    mask = ((X ** 2 + Y ** 2) < 0.2 ** 2).astype(int)
    img[mask == 1] = 1.0
    return np.repeat(img[:, :, None], 3, axis=2), mask


def _synth_histology(rng, size):
    """Light textured RGB tissue with one darker circular nucleus."""
    base = rng.uniform(0.8, 1.0, size=(size, size, 3))
    noise = rng.normal(0, 0.05, size=(size, size, 3))
    img_rgb = (base + noise).clip(0, 1)
    mask = np.zeros((size, size), int)
    cx, cy = _randint(rng, size // 4, 3 * size // 4, 2)
    r = size // 6
    Y, X = np.ogrid[:size, :size]
    nucleus = (X - cx) ** 2 + (Y - cy) ** 2 <= r ** 2
    mask[nucleus] = 1
    img_rgb[nucleus] -= 0.4
    return img_rgb.clip(0, 1), mask


def _synth_connectomics(rng, size):
    """Sinusoidal, roughly vertical bright strands (curvilinear neurites) on noise."""
    img = rng.normal(0.5, 0.15, (size, size, 1)).clip(0, 1)
    mask = np.zeros((size, size), int)
    rows = np.arange(size)
    for _ in range(rng.integers(2, 5)):
        amp = rng.uniform(5, 20)
        freq = rng.uniform(0.02, 0.1)
        phase = rng.uniform(0, 2 * np.pi)
        center = rng.uniform(size * 0.2, size * 0.8)
        x_vals = center + amp * np.sin(freq * rows + phase)
        for y, x_center in enumerate(x_vals.astype(int)):
            x_center = np.clip(x_center, 0, size - 1)
            # Half-width of at least 1 px so small images still get visible strands
            # (for the default 128 px the bounds are 2..4 either way).
            width = _randint(rng, max(1, size // 50), max(1, size // 30))
            xs = max(x_center - width, 0)
            # NOTE: half-open slice capped at size - 1, so the last column is never painted.
            xe = min(x_center + width, size - 1)
            mask[y, xs:xe] = 1
            img[y, xs:xe, 0] = 1.0
    return np.repeat(img, 3, axis=2), mask


def _synth_atlas(rng, size):
    """Several overlapping circular regions, each with a small random colour shift."""
    img_rgb = rng.uniform(0.4, 0.8, (size, size, 3))
    mask = np.zeros((size, size), int)
    num_regions = rng.integers(3, 6)
    centers = _randint(rng, size // 4, 3 * size // 4, (num_regions, 2))
    radii = _randint(rng, size // 10, size // 5, num_regions)
    labels = rng.permutation(num_regions)
    Y, X = np.ogrid[:size, :size]
    for (cx, cy), radius, label in zip(centers, radii, labels):
        region = (X - cx) ** 2 + (Y - cy) ** 2 <= radius ** 2
        mask[region] = label + 1
        color_shift = rng.uniform(-0.1, 0.1, 3)
        img_rgb[region] = np.clip(img_rgb[region] + color_shift, 0, 1)
    return img_rgb, mask


def _synth_calcium(rng, size):
    """Small bright circular ROIs (cell bodies) on a darker noisy background."""
    img = rng.normal(0.4, 0.1, (size, size, 1)).clip(0, 1)
    mask = np.zeros((size, size), int)
    Y, X = np.ogrid[:size, :size]
    for _ in range(rng.integers(3, 8)):
        cx, cy = rng.integers(0, size, 2)
        r = _randint(rng, size // 30, size // 15)
        cell = (X - cx) ** 2 + (Y - cy) ** 2 <= r ** 2
        mask[cell] = 1
        img[cell] += 0.6
    return np.repeat(img.clip(0, 1), 3, axis=2), mask


def _synth_tracking(rng, size):
    """Randomly rotated, possibly overlapping bright ellipses (cells)."""
    img = rng.normal(0.5, 0.1, (size, size, 1)).clip(0, 1)
    mask = np.zeros((size, size), int)
    Yg, Xg = np.ogrid[:size, :size]
    for _ in range(rng.integers(5, 10)):
        cx, cy = rng.integers(0, size, 2)
        # Semi-axes of at least 1 px avoid division by zero on tiny images
        # (unchanged for size >= 40).
        rx = _randint(rng, max(1, size // 40), max(1, size // 20))
        ry = _randint(rng, max(1, size // 40), max(1, size // 20))
        angle = rng.uniform(0, np.pi)
        x_rot = (Xg - cx) * np.cos(angle) + (Yg - cy) * np.sin(angle)
        y_rot = -(Xg - cx) * np.sin(angle) + (Yg - cy) * np.cos(angle)
        ellipse = (x_rot / rx) ** 2 + (y_rot / ry) ** 2 <= 1
        mask[ellipse] = 1
        img[ellipse] += 0.5
    return np.repeat(img.clip(0, 1), 3, axis=2), mask


# Mode name -> generator(rng, size) returning (float RGB image in [0, 1], integer mask).
_SYNTHETIC_GENERATORS = {
    "em": _synth_em,
    "mri": _synth_mri,
    "histology": _synth_histology,
    "connectomics": _synth_connectomics,
    "atlas": _synth_atlas,
    "calcium": _synth_calcium,
    "tracking": _synth_tracking,
}


def generate_synthetic_dataset(num_images: int, size: int = 128, mode: str = "em", *, seed: int = 123):
    """Generate a synthetic image/mask dataset for the specified modality.

    Args:
        num_images: Number of image/mask pairs to draw.
        size: Side length in pixels of each square image.
        mode: One of ``"em"``, ``"mri"``, ``"histology"``, ``"connectomics"``,
            ``"atlas"``, ``"calcium"`` or ``"tracking"``.
        seed: Seed for ``np.random.default_rng``. All images of one call come
            from a single random stream, so they differ from each other.
            Repeated calls with the same seed return identical data, so use
            distinct seeds (or one larger call that you split) to get
            independent train/test sets.

    Returns:
        ``(images, masks)``: lists of ``(size, size, 3)`` ``uint8`` RGB images
        and ``(size, size)`` integer masks binarised to {0, 1}.

    Raises:
        ValueError: If ``mode`` is not a known modality.
    """
    if mode not in _SYNTHETIC_GENERATORS:
        raise ValueError(f"Unknown mode: {mode}")
    make_sample = _SYNTHETIC_GENERATORS[mode]
    rng = np.random.default_rng(seed)
    images, masks = [], []
    for _ in range(num_images):
        img_rgb, mask = make_sample(rng, size)
        images.append((img_rgb * 255).astype(np.uint8))
        # Multi-label masks (e.g. "atlas") are collapsed to foreground/background.
        masks.append((mask > 0).astype(int))
    return images, masks


def flatten_dataset(images, masks, backbone):
    """Embed each image and pair its patch features with majority-vote patch labels.

    Every image is embedded on its own via ``backbone.embed([image])[0]``. The
    matching mask is cropped to its top-left ``grid_size * patch_size`` square,
    tiled into ``patch_size x patch_size`` patches, and each patch is labelled 1
    only when strictly more than half of its pixels are non-zero (ties -> 0).

    Assumes ``embed`` returns one feature row per patch in the same row-major
    ``(grid_size, grid_size)`` order as the labels built here — TODO confirm
    against ``DINOv3Backbone``. A mask smaller than ``grid_size * patch_size``
    on either side makes the reshape raise ``ValueError``.

    Returns:
        ``(features, labels)``: per-image features and flattened labels,
        concatenated across all images.
    """
    feature_chunks = []
    label_chunks = []
    for image, mask in zip(images, masks):
        patch_feats = backbone.embed([image])[0]
        grid = backbone.grid_size
        patch = backbone.patch_size
        side = grid * patch
        # View the cropped mask as a (grid, grid) arrangement of patch x patch tiles.
        tiles = mask[:side, :side].reshape(grid, patch, grid, patch)
        foreground_count = tiles.sum(axis=(1, 3))
        patch_labels = (foreground_count > (patch * patch / 2)).astype(int)
        feature_chunks.append(patch_feats)
        label_chunks.append(patch_labels.reshape(-1))
    return np.concatenate(feature_chunks), np.concatenate(label_chunks)


# (display name, Hugging Face model id) pairs compared by default.
_DEFAULT_BACKBONES = (
    ('CNX-T', 'facebook/dinov3-convnext-tiny-pretrain-lvd1689m'),
    ('ViT-L', 'facebook/dinov3-vit-large-pretrain-lvd1689m'),
)


def evaluate_backbones(mode: str, *, backbones=_DEFAULT_BACKBONES, num_train: int = 8, num_test: int = 4):
    """Train and evaluate a patch classifier per DINOv3 backbone on a synthetic modality.

    One synthetic stream of ``num_train + num_test`` images is drawn once and
    split into disjoint train/test sets. Two separate calls to the seeded
    generator would return identical images, so the test set would leak into
    training. Every backbone sees exactly the same split. For each backbone,
    patch features are extracted with ``flatten_dataset`` and a logistic
    regression is fitted. If the training labels contain a single class, that
    class is predicted everywhere instead.

    Args:
        mode: Modality name passed to ``generate_synthetic_dataset``.
        backbones: Iterable of ``(display_name, model_id)`` pairs.
        num_train: Number of training images (must be >= 1).
        num_test: Number of test images (must be >= 1).

    Returns:
        Dict mapping each display name to ``(patch_accuracy, patch_f1)`` on the
        test split. F1 is reported as 0.0 when it is undefined (no positive
        labels or predictions).

    Raises:
        ValueError: If ``num_train`` or ``num_test`` is smaller than 1.
    """
    if num_train < 1 or num_test < 1:
        raise ValueError("num_train and num_test must both be at least 1")

    # Draw one stream and split it, so the train and test images are distinct.
    images, masks = generate_synthetic_dataset(num_train + num_test, mode=mode)
    train_images, test_images = images[:num_train], images[num_train:]
    train_masks, test_masks = masks[:num_train], masks[num_train:]

    results = {}
    for back_name, model_id in backbones:
        backbone = DINOv3Backbone(model_id=model_id)
        X_train, y_train = flatten_dataset(train_images, train_masks, backbone)
        X_test, y_test = flatten_dataset(test_images, test_masks, backbone)
        if np.unique(y_train).size < 2:
            # Logistic regression cannot be fitted on one class: predict it everywhere.
            preds = np.full_like(y_test, y_train[0])
        else:
            clf = LogisticRegression(max_iter=300)
            clf.fit(X_train, y_train)
            preds = clf.predict(X_test)
        acc = accuracy_score(y_test, preds)
        # Score both branches the same way; zero_division=0 avoids warnings when
        # neither labels nor predictions contain the positive class.
        f1 = f1_score(y_test, preds, zero_division=0)
        results[back_name] = (acc, f1)
    return results


In [None]:

# Compare the DINOv3 backbones on synthetic EM-like images, then print and plot
# the per-backbone patch accuracy and F1 score.
results = evaluate_backbones("em")
for model, (acc, f1) in results.items():
    print(f"{model} accuracy: {acc:.3f}, F1 score: {f1:.3f}")

models = list(results)
accs = [metrics[0] for metrics in results.values()]
f1s = [metrics[1] for metrics in results.values()]

# One bar chart per metric; both metrics live in [0, 1].
for metric_values, metric_title, metric_label in (
    (accs, "Segmentation accuracy on synthetic em dataset", "Accuracy"),
    (f1s, "F1 score on synthetic em dataset", "F1 score"),
):
    plt.figure(figsize=(6, 3))
    plt.bar(models, metric_values)
    plt.title(metric_title)
    plt.ylabel(metric_label)
    plt.ylim(0, 1)
    plt.show()
