In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
from PIL import Image

from dinotool import DinoToolModel

# Test images used for the comparison (one per figure row, in this order).
image_paths = [
    "../test/data/bird1.jpg",
    "../test/data/imagefolder/suo.tif",
    "../test/data/pepper.png",
    "../test/data/still_life.png",
]
images = [Image.open(path) for path in image_paths]

# DINOv3 backbones: small, base and the large "-sat" variant.
dinov3s, dinov3, dinov3_sat = [
    DinoToolModel(model_name=name)
    for name in ("dinov3-s", "dinov3-b", "dinov3-l-sat")
]

# Comparison backbones: DINOv2 (small, base), RADIO and SigLIP2.
dinov2s, dinov2, radio, siglip2 = [
    DinoToolModel(model_name=name)
    for name in ("dinov2-s", "dinov2-b", "radio-b", "siglip2")
]

In [None]:
models = [dinov3s, dinov3, dinov3_sat, dinov2s, dinov2, radio, siglip2]

# Column titles for the comparison figure ("<display name>\n<shortcut>"),
# one per entry of `models`, in the same order.
# NOTE(review): the DINOv2 titles show "vit-s"/"vit-b" while the models were
# loaded with "dinov2-s"/"dinov2-b" — confirm these are equivalent aliases.
model_names = ["DINOv3-S\ndinov3-s",
               "DINOv3-B\ndinov3-b",
               "DINOv3-L (SAT)\ndinov3-l-sat",
               "DINOv2-S\nvit-s",
               "DINOv2-B\nvit-b",
               "C-RADIOv3-B\nradio-b",
               "SigLIP2-B\nsiglip2"]
image_names = ["bird1", "suo", "pepper", "still_life"]
img_sizes = [img.size for img in images]

# Labels are paired with models/images purely by position, so the lists must align.
assert len(model_names) == len(models), "expected one column title per model"
assert len(image_names) == len(images), "expected one name per image"

In [None]:
# Every model consumes the same RGB version of each image, so convert once
# instead of re-converting inside the per-model loop.
rgb_images = [img.convert("RGB") for img in images]

# `pcas` is a flat list in model-major order: all images for models[0], then
# all images for models[1], and so on. The plotting cell indexes it that way.
pcas = []
for model in models:
    for rgb_img in rgb_images:
        print(model.model_name, rgb_img.size)
        img_tensor = model.get_transform(rgb_img.size).transform(rgb_img).unsqueeze(0)
        features = model(img_tensor)
        pcas.append(model.pca(features))

assert len(pcas) == len(models) * len(images)

In [None]:
n_rows = len(images)
n_cols = len(model_names) + 1  # +1 for the "Original image" column

# `pcas` is model-major (see the compute cell): model j, image i -> j * n_rows + i.
assert len(pcas) == len(model_names) * n_rows, "pcas does not match the model/image grid"

# squeeze=False keeps `ax` 2-D even with a single image or a single model.
fig, ax = plt.subplots(n_rows, n_cols, figsize=(8, 5), squeeze=False)

# Column titles: first column is the input image, then one column per model.
ax[0][0].set_title("Original image", fontsize=10)
for j, model_name in enumerate(model_names):
    ax[0][j + 1].set_title(model_name, fontsize=10)

# First column: original images, labelled with their size on the y-axis.
for i, img in enumerate(images):
    ax[i][0].imshow(img)
    ax[i][0].set_ylabel(img_sizes[i], fontsize=10)
    ax[i][0].xaxis.set_visible(False)
    ax[i][0].tick_params(left=False, bottom=False, labelleft=False)

# Remaining columns: one PCA map per (model, image) pair.
for j in range(len(model_names)):
    for i in range(n_rows):
        pca_ax = ax[i][j + 1]
        pca_ax.imshow(pcas[j * n_rows + i])
        pca_ax.xaxis.set_visible(False)
        pca_ax.yaxis.set_visible(False)
        pca_ax.tick_params(left=False, bottom=False, labelleft=False)

plt.tight_layout()

# savefig does not create missing directories, so make sure the target exists.
output_path = Path("resources/model_comparison.png")
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=300)