In [9]:
from pathlib import Path
from random import seed
from types import NoneType

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from skimage.color import label2rgb
from tifffile import imread

from vulture.main import get_dv2_model, get_upsampler_and_expr, get_hr_feats
from vulture.utils import to_numpy, do_2D_pca, closest_crop, convert_image
from vulture.feature_prep import get_lr_feats

from interactive_seg_backend.file_handling import load_labels


# One seed shared by every RNG this notebook touches (stdlib `random`, NumPy's
# legacy global generator, PyTorch), so a fresh "Run All" is repeatable.
SEED = 10672
for _seed_fn in (seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Device string handed to the model loaders and tensor moves further down.
DEVICE = "cuda:0"

In [10]:
# 8-bit RGB palette. The code in this notebook only ever passes the entries
# from index 1 onwards (`COLORS[1:]` / `colors[1:]`) to label2rgb.
color_list = [
    [255, 255, 255],
    [0, 62, 131],
    [181, 209, 204],
    [250, 43, 0],
    [255, 184, 82],
]
# Same palette as floats in [0, 1].
COLORS = np.asarray(color_list, dtype=np.float64) / 255.0

In [11]:
def apply_labels_as_overlay(labels: np.ndarray, img: Image.Image, colors: list, alpha: float=1.0) -> Image.Image:
    """Paint the labelled pixels of `img` with their class colours.

    Pixels where `labels` is 0 keep the original image value; every other
    pixel is replaced by (or, for `alpha < 1`, blended with) the colour that
    `label2rgb` assigns to its label.

    Args:
        labels: integer label map; 0 means "no label". It is broadcast against
            the image with a trailing channel axis, so it is expected to match
            the image's height and width.
        img: image to draw on. Assumed to be 3-channel RGB so that it lines up
            with label2rgb's RGB output (the caller in this notebook converts
            to 'RGB' first).
        colors: palette with values in [0, 1]; entry 0 is skipped and
            `colors[1:]` is handed to label2rgb.
        alpha: opacity of the label colour on labelled pixels. 1.0 (default)
            fully replaces the pixel, 0.0 leaves the image untouched.

    Returns:
        A new uint8 PIL image with the labels drawn on top.
    """
    img_arr = np.array(img)
    labels_unsqueezed = np.expand_dims(labels, -1)

    # No image is passed to label2rgb, so it returns the plain label colours;
    # its own `alpha` is documented as ignored in that case. The blending with
    # the photo is therefore done explicitly below.
    overlay = label2rgb(labels, colors=colors[1:], kind='overlay', bg_label=0, image_alpha=1, alpha=1)
    tinted = overlay * 255
    if alpha != 1:
        tinted = alpha * tinted + (1 - alpha) * img_arr

    # Only labelled pixels take the (blended) colour; the rest stay as in img.
    out = np.where(labels_unsqueezed, tinted, img_arr).astype(np.uint8)
    return Image.fromarray(out)


def rescale(a: np.ndarray) -> np.ndarray:
    """Min-max scale `a` to [0, 1], independently for each trailing index.

    The minimum and maximum are taken over the first two axes (rows and
    columns), so for an (H, W, C) array every channel is stretched on its own.

    A slice whose values are all equal has zero range; it is mapped to zeros
    instead of the NaNs a plain division would produce.

    Args:
        a: array with at least two dimensions.

    Returns:
        Array of the same shape as `a` with each slice scaled to [0, 1].
    """
    a_min = a.min(axis=(0, 1), keepdims=True)
    a_max = a.max(axis=(0, 1), keepdims=True)
    span = a_max - a_min
    # Guard the constant-slice case: the numerator is 0 there, so dividing by
    # 1 yields 0 rather than 0/0 = NaN.
    span[span == 0] = 1
    return (a - a_min) / span


In [12]:
# Checkpoint and config locations, written relative to the notebook.
# NOTE: `model_path` is reassigned in the next cell before anything is loaded,
# and `cfg_path` is not passed to the loader there (None is passed instead),
# so neither value set here reaches `get_upsampler_and_expr` as written.
model_path = "../trained_models/e5000_full_fit_reg.pth"
cfg_path = "../vulture/models/configs/combined_no_shift.json"

In [None]:
# Backbone used for the low-resolution features further down.
# NOTE(review): the meaning of the positional `True` is not visible from this
# notebook — confirm against `get_dv2_model`'s signature.
dv2 = get_dv2_model(True, device=DEVICE)

# Replaces the checkpoint path set in the previous cell; this is the file that
# is actually loaded below.
model_path = "../trained_models/fit_reg_f128.pth"
# cfg_path = "../vulture/models/configs/combined_no_shift.json"

# No config path is supplied (second argument is None). `expr` is used later
# for its `n_ch_in` attribute when extracting low-resolution features.
upsampler, expr = get_upsampler_and_expr(model_path, None, device=DEVICE)

Using cache found in /home/ronan/.cache/torch/hub/ywyue_FiT3D_main
INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/vit_small_patch14_reg4_dinov2.lvd142m)
INFO:timm.models._hub:[timm/vit_small_patch14_reg4_dinov2.lvd142m] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


[128, 128, 128, 128]


In [15]:
DATA_PATH = "fig_data/natural"

img_names = ['plant', 'bird', 'church', 'balls']

# One entry per image, in `img_names` order; consumed by the plotting cell.
imgs: list[Image.Image] = []
lr_feat_vis: list[np.ndarray] = []
hr_feat_vis: list[np.ndarray] = []
classical_preds: list[np.ndarray] = []
deep_preds: list[np.ndarray] = []

# torch.autocast is documented to take a device *type* ("cuda"), not a device
# string with an index ("cuda:0"); older PyTorch releases reject the latter.
AUTOCAST_DEVICE_TYPE = torch.device(DEVICE).type

with torch.no_grad():
    for name in img_names:
        img_path = f"{DATA_PATH}/{name}.png"
        img = Image.open(img_path).convert('RGB')

        tr = closest_crop(img.height, img.width, 14, True)

        # Full-resolution image passed to the upsampler together with the
        # low-resolution features.
        # NOTE(review): the pixel values are cast to float32 but not divided
        # by 255 before being normalised with the ImageNet mean/std. Left
        # as-is because the checkpoint may expect exactly this input —
        # confirm against the training pipeline.
        inp_img = (
            TF.normalize(
                TF.pil_to_tensor(img).to(torch.float32),
                [0.485, 0.456, 0.406],
                [0.229, 0.224, 0.225],
            )
            .unsqueeze(0)
            .to(DEVICE)
        )
        # Separately prepared (cropped via `tr`) input for the backbone.
        inp_img_dino = convert_image(img, tr, to_half=True, device_str=DEVICE)

        # Column 1: the photo with the user labels drawn on top.
        labels = load_labels(f"{DATA_PATH}/{name}_labels.tiff")
        with_labels = apply_labels_as_overlay(labels, img, COLORS)
        imgs.append(with_labels)

        # Columns 2-3: low-res backbone features and their upsampled version.
        lr_feats, _ = get_lr_feats(dv2, [inp_img_dino], 50, fit3d=True, n_feats_in=expr.n_ch_in)
        lr_feats = F.normalize(lr_feats, p=1, dim=1)
        with torch.autocast(AUTOCAST_DEVICE_TYPE, torch.float16):
            hr_feats = upsampler(inp_img, lr_feats)

        for feats, vis_store in ((lr_feats, lr_feat_vis), (hr_feats, hr_feat_vis)):
            # First batch item as float32 NumPy; assumes a (C, H, W) layout.
            feats_np = feats.cpu()[0].numpy().astype(np.float32)
            # Move the leading axis last and show its first three entries as
            # RGB, each stretched to [0, 1] independently.
            feats_red = feats_np.transpose((1, 2, 0))[:, :, 0:3]
            vis_store.append(rescale(feats_red))

        # Columns 4-5: saved segmentations, coloured with the same palette.
        # The stored indices are shifted up by one before colouring.
        for seg_kind, pred_store in (("classical", classical_preds), ("deep", deep_preds)):
            pred = load_labels(f"{DATA_PATH}/{name}_seg_{seg_kind}.tiff") + 1
            pred_store.append(label2rgb(pred, None, COLORS[1:]))

In [19]:
%%capture
plt.rcParams["font.family"] = "serif"
TITLE_FS = 25
N_ROWS = len(img_names) 
N_COLS = 5
FIG_W = 3

titles = ["Image + labels", "LR features", "HR ViT", "Classical seg.", "+HR ViT seg."]
fig, axs = plt.subplots(nrows=N_ROWS, ncols=N_COLS, figsize=(N_COLS * FIG_W, N_ROWS * FIG_W))
for i in range(N_ROWS):

    arrs = (imgs[i], lr_feat_vis[i], hr_feat_vis[i], classical_preds[i], deep_preds[i])

    for j, arr in enumerate(arrs):
        ax = axs[i,j]
        ax.set_axis_off()
        ax.imshow(arr)

        if i == 0:
            ax.set_title(titles[j], fontsize=TITLE_FS)

plt.tight_layout()
plt.savefig('fig_out/supp_natural_examples.png' ,bbox_inches='tight')