# DINOv3 Experiments on Multi‑Modal Synthetic Datasets

This notebook expands our previous experiments by incorporating two
additional synthetic modalities:

* **Connectomics** – long, curvilinear structures that approximate
  neurites in 3D electron microscopy data.
* **Atlas** – coarse anatomical regions inspired by reference atlases
  such as the Allen Mouse Brain Atlas.

We compare two DINOv3 backbones – ConvNeXt‑Tiny and ViT‑Large – on
segmentation, cross‑modality generalisation and registration.  The
underlying feature extractor is the placeholder implementation from
neuros, so results are illustrative rather than definitive.


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score

from neuros.plugins.cv.dinov3_backbone import DINOv3Backbone
from neuros.plugins.cv.feature_matching import patch_correlation, estimate_translation


def _exclusive_upper(low, high):
    '''Return an exclusive upper bound that keeps ``rng.integers(low, ...)`` non-empty.

    ``rng.integers(low, high)`` raises ``ValueError`` when ``high <= low``,
    which happens for small image sizes where integer division collapses
    both bounds to the same value.
    '''
    return max(high, low + 1)


def _synth_em(rng, size):
    '''Gaussian-noise image with 3-6 bright discs; the discs form the mask.'''
    img = rng.normal(loc=0.5, scale=0.15, size=(size, size, 1)).clip(0, 1)
    mask = np.zeros((size, size), dtype=np.int32)
    yy, xx = np.ogrid[:size, :size]
    r_low = size // 20
    r_high = _exclusive_upper(r_low, size // 10)
    for _spot in range(rng.integers(3, 7)):
        cx, cy = rng.integers(0, size, size=2)
        r = rng.integers(r_low, r_high)
        disc = (xx - cx) ** 2 + (yy - cy) ** 2 <= r ** 2
        mask[disc] = 1
        img[disc] += 0.5
    img = img.clip(0, 1)
    return np.repeat(img, 3, axis=2), mask


def _synth_mri(rng, size):
    '''Diagonal intensity ramp with a saturated central disc as the lesion mask.

    This modality draws nothing from ``rng``: every sample is identical.
    '''
    axis = np.linspace(-1, 1, size)
    X, Y = np.meshgrid(axis, axis)
    img = 0.5 + 0.5 * (X + Y) / 2
    mask = ((X ** 2 + Y ** 2) < 0.2 ** 2).astype(np.int32)
    img[mask == 1] = 1.0
    return np.repeat(img[:, :, None], 3, axis=2), mask


def _synth_histology(rng, size):
    '''Bright noisy RGB texture with one darker disc (the nucleus mask).'''
    base = rng.uniform(0.8, 1.0, size=(size, size, 3))
    noise = rng.normal(0, 0.05, size=(size, size, 3))
    img_rgb = (base + noise).clip(0, 1)
    mask = np.zeros((size, size), dtype=np.int32)
    c_low = size // 4
    cx, cy = rng.integers(c_low, _exclusive_upper(c_low, 3 * size // 4), size=2)
    r = size // 6
    yy, xx = np.ogrid[:size, :size]
    nucleus = (xx - cx) ** 2 + (yy - cy) ** 2 <= r ** 2
    mask[nucleus] = 1
    img_rgb[nucleus] -= 0.4
    return img_rgb.clip(0, 1), mask


def _synth_connectomics(rng, size):
    '''Gaussian-noise image crossed by 2-4 sinusoidal strands (the mask).'''
    img = rng.normal(0.5, 0.15, size=(size, size, 1)).clip(0, 1)
    mask = np.zeros((size, size), dtype=np.int32)
    num_strands = rng.integers(2, 5)
    row_idx = np.arange(size)
    w_low = size // 50
    w_high = _exclusive_upper(w_low, size // 30)
    for _strand in range(num_strands):
        amp = rng.uniform(5, 20)
        freq = rng.uniform(0.02, 0.1)
        phase = rng.uniform(0, 2 * np.pi)
        center = rng.uniform(size * 0.2, size * 0.8)
        x_vals = center + amp * np.sin(freq * row_idx + phase)
        # One strand segment per image row; the half-width is re-drawn for
        # every row so the strand thickness jitters along its length.
        for y, x_center in enumerate(x_vals.astype(np.int32)):
            x_center = np.clip(x_center, 0, size - 1)
            width = rng.integers(w_low, w_high)
            x_start = int(max(x_center - width, 0))
            x_end = int(min(x_center + width, size - 1))
            mask[y, x_start:x_end] = 1
            img[y, x_start:x_end, 0] = 1.0
    return np.repeat(img, 3, axis=2), mask


def _synth_atlas(rng, size):
    '''Uniform-noise RGB image with 3-5 labelled, colour-shifted discs.'''
    img_rgb = rng.uniform(0.4, 0.8, size=(size, size, 3))
    mask = np.zeros((size, size), dtype=np.int32)
    num_regions = rng.integers(3, 6)
    c_low = size // 4
    centers = rng.integers(
        c_low, _exclusive_upper(c_low, 3 * size // 4), size=(num_regions, 2)
    )
    r_low = size // 10
    radii = rng.integers(r_low, _exclusive_upper(r_low, size // 5), size=num_regions)
    labels = rng.permutation(num_regions)
    yy, xx = np.ogrid[:size, :size]
    for idx, (cx, cy) in enumerate(centers):
        region = (xx - cx) ** 2 + (yy - cy) ** 2 <= radii[idx] ** 2
        mask[region] = labels[idx] + 1  # non-zero labels; later discs overwrite earlier ones
        color_shift = rng.uniform(-0.1, 0.1, size=3)
        img_rgb[region] = np.clip(img_rgb[region] + color_shift, 0, 1)
    return img_rgb, mask


# Maps a modality name to the helper that produces one (float RGB image, mask) pair.
_SYNTH_GENERATORS = {
    "em": _synth_em,
    "mri": _synth_mri,
    "histology": _synth_histology,
    "connectomics": _synth_connectomics,
    "atlas": _synth_atlas,
}


def generate_synthetic_dataset(num_images: int, size: int = 128, mode: str = "em", seed: int = 123):
    '''Generate synthetic datasets for different modalities.

    Modes:
        em          - random noise with bright spots (neurites)
        mri         - smooth gradient with circular lesion
        histology   - texture with darker nucleus region
        connectomics- curvilinear strands on noise (neurite tracing)
        atlas       - coarse anatomical regions delineated by boundaries

    Args:
        num_images: Number of (image, mask) pairs to produce.
        size: Side length in pixels of the square images.
        mode: One of the modality names listed above.
        seed: Seed for the random generator. The default (123) reproduces
            the historical output; pass a different value to obtain an
            independent dataset (e.g. a held-out test set).

    Returns:
        A pair ``(images, masks)`` of equal-length lists. Each image is a
        ``uint8`` array of shape ``(size, size, 3)``; each mask is an
        ``int32`` array of shape ``(size, size)`` holding 0 (background)
        or 1 (foreground).

    Raises:
        ValueError: If ``mode`` is not a known modality. The check happens
            before any image is generated, so it also fires when
            ``num_images`` is 0.
    '''
    make_sample = _SYNTH_GENERATORS.get(mode)
    if make_sample is None:
        raise ValueError(f"Unknown mode: {mode}")
    rng = np.random.default_rng(seed)
    images = []
    masks = []
    for _ in range(num_images):
        img_rgb, mask = make_sample(rng, size)
        images.append((img_rgb * 255).astype(np.uint8))
        # For atlas and connectomics, we consider foreground vs background only
        masks.append((mask > 0).astype(np.int32))
    return images, masks


def flatten_dataset(images, masks, backbone):
    '''Turn (image, mask) pairs into per-patch features and binary patch labels.

    For each image the backbone's patch embedding is taken as-is, and the
    pixel mask is reduced to one label per patch: a patch is labelled 1
    when more than half of its ``patch_size * patch_size`` pixels are
    non-zero in the mask (assuming a 0/1 mask), otherwise 0.

    Args:
        images: Iterable of images accepted by ``backbone.embed``.
        masks: Iterable of 2-D integer masks, paired with ``images``. Each
            mask is cropped to ``grid_size * patch_size`` pixels per side
            and must be at least that large for the reshape to succeed.
        backbone: Object exposing ``embed``, ``grid_size`` and ``patch_size``.

    Returns:
        ``(features, labels)`` concatenated over all images; labels are a
        flat ``int32`` array with ``grid_size ** 2`` entries per image.
    '''
    features_per_image = []
    labels_per_image = []
    for image, mask in zip(images, masks):
        patch_features = backbone.embed([image])[0]
        grid, patch = backbone.grid_size, backbone.patch_size
        extent = grid * patch
        # View the cropped mask as a (grid, patch, grid, patch) tiling so that
        # summing over the two patch axes counts foreground pixels per patch.
        tiled = mask[:extent, :extent].reshape(grid, patch, grid, patch)
        foreground_pixels = tiled.sum(axis=(1, 3))
        majority_foreground = foreground_pixels > (patch * patch / 2)
        features_per_image.append(patch_features)
        labels_per_image.append(majority_foreground.astype(np.int32).ravel())
    return np.concatenate(features_per_image), np.concatenate(labels_per_image)


def train_segmentation(backbone_id, train_mode, test_mode):
    '''Train a patch classifier on one modality and evaluate it on another.

    A logistic-regression classifier is fitted on patch features of 10
    training images and scored on patch features of 5 test images.

    When ``train_mode == test_mode`` the 15 images are generated in a
    single call and split, so the test images are not the same samples as
    the first training images. (Two separate calls would restart the
    generator's random stream and reproduce the same leading images.)
    Modalities whose generator uses no randomness still yield identical
    images in both splits.

    Args:
        backbone_id: Model identifier passed to ``DINOv3Backbone``.
        train_mode: Modality name used for the training set.
        test_mode: Modality name used for the test set.

    Returns:
        ``(accuracy, f1)`` of the patch-level predictions on the test set.
    '''
    n_train, n_test = 10, 5
    backbone = DINOv3Backbone(model_id=backbone_id)
    if train_mode == test_mode:
        images_all, masks_all = generate_synthetic_dataset(n_train + n_test, mode=train_mode)
        images_train, masks_train = images_all[:n_train], masks_all[:n_train]
        images_test, masks_test = images_all[n_train:], masks_all[n_train:]
    else:
        images_train, masks_train = generate_synthetic_dataset(n_train, mode=train_mode)
        images_test, masks_test = generate_synthetic_dataset(n_test, mode=test_mode)
    X_train, y_train = flatten_dataset(images_train, masks_train, backbone)
    X_test, y_test = flatten_dataset(images_test, masks_test, backbone)
    if len(np.unique(y_train)) < 2:
        # LogisticRegression cannot be fitted on a single class; fall back to
        # predicting that one class for every test patch.
        preds = np.full_like(y_test, y_train[0])
    else:
        clf = LogisticRegression(max_iter=300)
        clf.fit(X_train, y_train)
        preds = clf.predict(X_test)
    acc = accuracy_score(y_test, preds)
    # zero_division=0 keeps F1 at 0.0 (without a warning) when no positive
    # patch is predicted or present, and lets the constant-predictor branch
    # report its real F1 instead of a hard-coded value.
    f1 = f1_score(y_test, preds, zero_division=0)
    return acc, f1


def evaluate_registration(backbone_id, mode, shift=(1, 1)):
    '''Shift a synthetic image by a known amount and estimate that shift back.

    Args:
        backbone_id: Model identifier passed to ``DINOv3Backbone``.
        mode: Modality name for the single generated image.
        shift: ``(dy, dx)`` displacement in whole patches; it is multiplied
            by the backbone's ``patch_size`` to obtain the pixel offset.

    Returns:
        ``((dy, dx), (est_dy, est_dx))`` - the requested shift and the pair
        returned by ``estimate_translation``.
    '''
    backbone = DINOv3Backbone(model_id=backbone_id)
    reference = generate_synthetic_dataset(1, mode=mode)[0][0]
    dy, dx = shift
    patch = backbone.patch_size
    row_offset = dy * patch
    col_offset = dx * patch
    moved = np.roll(reference, shift=(row_offset, col_offset), axis=(0, 1))
    # np.roll wraps pixels around the border; blank the band that wrapped so
    # the shifted copy contains no content from the opposite edge.
    if dy > 0:
        moved[:row_offset, :] = 0
    elif dy < 0:
        moved[row_offset:, :] = 0
    if dx > 0:
        moved[:, :col_offset] = 0
    elif dx < 0:
        moved[:, col_offset:] = 0
    # Patch features of both views, then their correlation
    feats_reference = backbone.embed([reference])[0]
    feats_moved = backbone.embed([moved])[0]
    similarity = patch_correlation(feats_reference, feats_moved)
    est_dy, est_dx = estimate_translation(similarity, backbone.grid_size)
    return (dy, dx), (est_dy, est_dx)


In [None]:
# Cross-modality segmentation: every (train, test) modality pair, per backbone.
# results_seg[train][test][model name] holds the (accuracy, f1) tuple.
modalities = ["em", "mri", "histology", "connectomics", "atlas"]
models = {"CNX-Tiny": "cnx-tiny", "ViT-Large": "vit-large"}
results_seg = {}
for train_mod in modalities:
    by_test = results_seg[train_mod] = {}
    for test_mod in modalities:
        by_model = by_test[test_mod] = {}
        for name, model_id in models.items():
            acc, f1 = train_segmentation(model_id, train_mod, test_mod)
            by_model[name] = (acc, f1)
            print(f"{name} train {train_mod} -> test {test_mod}: acc={acc:.3f}, f1={f1:.3f}")
results_seg

In [None]:
# Registration: mean L1 error (in patches) between requested and estimated
# shift, averaged over a few test shifts, per modality and backbone.
modalities = ["em", "mri", "histology", "connectomics", "atlas"]
shifts = [(1, 1), (0, 2), (-2, -1)]
results_reg = {}
for mode in modalities:
    per_model = results_reg[mode] = {}
    for name, model_id in models.items():
        errors = []
        for shift in shifts:
            true_shift, est_shift = evaluate_registration(model_id, mode, shift=shift)
            # Sum of absolute row and column differences
            err = sum(abs(t - e) for t, e in zip(true_shift, est_shift))
            errors.append(err)
            print(f"{name} on {mode} shift {true_shift} -> estimated {est_shift}")
        mean_error = float(np.mean(errors))
        per_model[name] = mean_error
    print()
results_reg

In [None]:
# Visualise segmentation cross-modality accuracy and F1
import numpy as np
import matplotlib.pyplot as plt
# Each results_seg leaf is an (accuracy, f1) tuple; map metric name -> tuple index.
metrics = {"accuracy": 0, "f1": 1}
n_models = len(models)
for metric_name, idx in metrics.items():
    # One heatmap per configured model. squeeze=False keeps `axes` 2-D so the
    # indexing below also works when `models` has a single entry.
    fig, axes = plt.subplots(1, n_models, figsize=(6 * n_models, 4), squeeze=False)
    for j, model_name in enumerate(models):
        # Rows: train modality, columns: test modality
        data = np.array([
            [results_seg[train_mod][test_mod][model_name][idx] for test_mod in modalities]
            for train_mod in modalities
        ])
        ax = axes[0, j]
        im = ax.imshow(data, vmin=0, vmax=1)
        ax.set_xticks(range(len(modalities)))
        ax.set_xticklabels(modalities, rotation=45)
        ax.set_yticks(range(len(modalities)))
        ax.set_yticklabels(modalities)
        ax.set_xlabel("Test modality")
        ax.set_ylabel("Train modality")
        ax.set_title(f"{metric_name.capitalize()} - {model_name}")
        # Annotate cells; white text below 0.5, black text otherwise
        for (i, k), val in np.ndenumerate(data):
            ax.text(k, i, f"{val:.2f}", ha='center', va='center', color='white' if val < 0.5 else 'black')
    fig.suptitle(f"Cross-modality {metric_name.capitalize()}")
    plt.tight_layout()
    plt.show()


In [None]:
# Grouped bar chart of the mean registration error per modality, one bar per backbone
import matplotlib.pyplot as plt
labels = modalities
width = 0.35
errors_by_model = {
    model_name: [results_reg[mod][model_name] for mod in modalities]
    for model_name in ("CNX-Tiny", "ViT-Large")
}
cnx_errors = errors_by_model["CNX-Tiny"]
vit_errors = errors_by_model["ViT-Large"]
tick_positions = np.arange(len(labels))
fig, ax = plt.subplots(figsize=(8, 4))
# Offset the two series by half a bar width on either side of each tick
ax.bar(tick_positions - width/2, cnx_errors, width, label='CNX-Tiny')
ax.bar(tick_positions + width/2, vit_errors, width, label='ViT-Large')
ax.set_xticks(tick_positions)
ax.set_xticklabels(labels, rotation=45)
ax.set_ylabel('Mean absolute patch shift error')
ax.set_title('Registration performance across modalities')
ax.legend()
plt.tight_layout()
plt.show()
