- Generate all features beforehand and store them (cached in a `.tmp` directory).
- Train N classifiers over N nested sets of labelled images of increasing size.
- Apply each of the N classifiers, compute the mIoU, and store the result.

In [1]:
import numpy as np
import torch
from os import listdir
from random import seed

from vulture.main import get_dv2_model, get_upsampler_and_expr

from interactive_seg_backend.configs import FeatureConfig, TrainingConfig


from is_helpers import AllowedDatasets, eval_preds, get_pca_over_images_or_dir, get_and_cache_features_over_images, train_model_over_images, apply_model_over_images
    
from typing import Literal

SEED = 10672
# Seed the NumPy, PyTorch and stdlib `random` RNGs with the same value so
# repeated runs draw the same random numbers from these generators.
for _seeder in (np.random.seed, torch.manual_seed, seed):
    _seeder(SEED)
del _seeder
DEVICE = "cuda:1"

N CPUS: 110


In [2]:
model_path = "../trained_models/e5000_full_fit_reg.pth"
cfg_path = "../vulture/models/configs/combined_no_shift.json"

# Load the backbone first, then the upsampler checkpoint together with its
# JSON config; everything is placed on DEVICE.
dv2 = get_dv2_model(True, device=DEVICE)
upsampler, expr = get_upsampler_and_expr(model_path, cfg_path, device=DEVICE)

Using cache found in /home/ronan/.cache/torch/hub/ywyue_FiT3D_main


[128, 128, 128, 128]


In [3]:
SAVE: bool = False
PATH = "fig_data/is_benchmark"
# Re-declared here; shadows the `AllowedDatasets` alias imported from is_helpers.
AllowedDatasets = Literal["Ni_superalloy_SEM", "T_cell_TEM", "Cu_ore_RLM"]
dataset: tuple[AllowedDatasets, ...] = ("Ni_superalloy_SEM", "T_cell_TEM", "Cu_ore_RLM")

chosen_dataset = "Cu_ore_RLM"
# Full paths of every file in the chosen dataset's image folder, sorted by name.
image_dir = f"{PATH}/{chosen_dataset}/images/"
fnames = sorted(listdir(image_dir))
images = [image_dir + fname for fname in fnames]

In [4]:
pca = get_pca_over_images_or_dir(images, dv2)

feat_cfg = FeatureConfig()


def _xgb_train_cfg(add_dino_features: bool) -> TrainingConfig:
    """Build the XGBoost TrainingConfig shared by both pipelines.

    Both pipelines use ``feat_cfg``, ``n_samples=-1`` and the same classifier
    params; they differ only in whether DINO features are added.
    """
    return TrainingConfig(
        feat_cfg,
        n_samples=-1,
        add_dino_features=add_dino_features,
        classifier="xgb",
        classifier_params={"class_weight": "balanced", "max_depth": 32},
    )


# Deep pipeline: feat_cfg features plus DINO features.
train_cfg = _xgb_train_cfg(add_dino_features=True)
# Classical baseline: feat_cfg features only.
classical_train_cfg = _xgb_train_cfg(add_dino_features=False)

In [5]:
# Compute and cache features for every image once per pipeline; the sweeps
# below read them back from f"{PATH}/<cache_dir>/<name>.npy".
for cfg, cache_dir in ((train_cfg, ".tmp"), (classical_train_cfg, ".tmp_classical")):
    get_and_cache_features_over_images(chosen_dataset, cfg, cache_dir, PATH, dv2, upsampler, expr, pca)

In [6]:
# Per dataset: file stems (no extension) of the images whose labels are always
# part of the training set; they are placed first in the label ordering.
TRAIN_IMG_FNAMES: dict[AllowedDatasets, list[str]] = dict(
    Cu_ore_RLM=["004", "028", "049", "077"],
    Ni_superalloy_SEM=["000", "001", "005", "007"],
    T_cell_TEM=["000", "005", "007", "026"],
)

In [7]:
# Training order for the sweeps: the dataset's fixed base images first, then
# every other labelled image in sorted order, so label_fnames[:n] is a nested,
# growing training set as n increases.
base_labels = TRAIN_IMG_FNAMES[chosen_dataset]
# Skip hidden files (e.g. .DS_Store), which would otherwise yield bogus stems.
all_label_paths = sorted(fname for fname in listdir(f"{PATH}/{chosen_dataset}/labels") if not fname.startswith('.'))
all_label_fnames = [fname.split('.')[0] for fname in all_label_paths]

# Fail early with a clear message rather than deep inside training.
_missing_base = sorted(set(base_labels).difference(all_label_fnames))
if _missing_base:
    raise ValueError(f"base training images {_missing_base} have no label file in {PATH}/{chosen_dataset}/labels")

_base_set = set(base_labels)
label_fnames = base_labels + [fname for fname in all_label_fnames if fname not in _base_set]

In [8]:
# Learning-curve sweep for the deep pipeline (train_cfg, DINO features on).
# For n from the number of base images up to all labelled images:
#   - train on the first n entries of label_fnames;
#   - predict on every cached feature file;
#   - record the mIoU and std returned by eval_preds.
# The range and the printed total come from the data rather than hard-coded
# 4 / 22.
deep_mious, deep_std_mious = [], []
n_total_labels = len(label_fnames)
for n_labels in range(len(base_labels), n_total_labels + 1):
    selected_labels = label_fnames[:n_labels]
    # label_fnames already holds extension-free stems matching the cache files.
    feat_paths = [f"{PATH}/.tmp/{name}.npy" for name in selected_labels]
    classifier, _ = train_model_over_images(chosen_dataset, train_cfg, PATH, selected_labels, dv2, upsampler, expr, feat_paths, overwrite_with_gt=True)

    all_feat_fnames = [f"{PATH}/.tmp/{fname}" for fname in sorted(listdir(f"{PATH}/.tmp"))]
    deep_preds = apply_model_over_images(chosen_dataset, train_cfg, classifier, PATH, dv2, upsampler, expr, False, -1, pca, all_feat_fnames)
    miou, std_miou = eval_preds(chosen_dataset, deep_preds, PATH)
    print(f"({n_labels:2d}/{n_total_labels}): {miou:.4f} +/-{std_miou:.4f}")
    deep_mious.append(miou)
    deep_std_mious.append(std_miou)

( 4/22): 0.8711 +/-0.0332
( 5/22): 0.8730 +/-0.0321
( 6/22): 0.8727 +/-0.0301
( 7/22): 0.8780 +/-0.0299
( 8/22): 0.8800 +/-0.0318
( 9/22): 0.8821 +/-0.0320
(10/22): 0.8835 +/-0.0325
(11/22): 0.8859 +/-0.0345
(12/22): 0.8883 +/-0.0344
(13/22): 0.8923 +/-0.0303
(14/22): 0.8956 +/-0.0300
(15/22): 0.8963 +/-0.0301
(16/22): 0.8966 +/-0.0302
(17/22): 0.8976 +/-0.0304
(18/22): 0.8998 +/-0.0295
(19/22): 0.8993 +/-0.0303
(20/22): 0.9000 +/-0.0295
(21/22): 0.9020 +/-0.0290
(22/22): 0.9027 +/-0.0290


In [9]:
# Ni (08/07/25):
# ( 4/22): 0.7488 +/-0.1069
# ( 5/22): 0.7612 +/-0.0930
# ( 6/22): 0.7646 +/-0.0871
# ( 7/22): 0.7709 +/-0.0651
# ( 8/22): 0.7767 +/-0.0674
# ( 9/22): 0.7917 +/-0.0713
# (10/22): 0.7935 +/-0.0708
# (11/22): 0.7950 +/-0.0749
# (12/22): 0.7969 +/-0.0747
# (13/22): 0.8000 +/-0.0767
# (14/22): 0.8060 +/-0.0783
# (15/22): 0.8069 +/-0.0789
# (16/22): 0.8117 +/-0.0814
# (17/22): 0.8128 +/-0.0868
# (18/22): 0.8133 +/-0.0934
# (19/22): 0.8168 +/-0.0984
# (20/22): 0.8170 +/-0.0967
# (21/22): 0.8176 +/-0.0987
# (22/22): 0.8196 +/-0.1000

# T-Cell
# Deep
# ( 4/22): 0.6261 +/-0.1678
# ( 5/22): 0.6422 +/-0.1681
# ( 6/22): 0.6683 +/-0.1707
# ( 7/22): 0.6859 +/-0.1694
# ( 8/22): 0.6963 +/-0.1687
# ( 9/22): 0.7047 +/-0.1690
# (10/22): 0.7081 +/-0.1734
# (11/22): 0.7103 +/-0.1801
# (12/22): 0.7137 +/-0.1871
# (13/22): 0.7346 +/-0.1825
# (14/22): 0.7413 +/-0.1816
# (15/22): 0.7536 +/-0.1789
# (16/22): 0.7730 +/-0.1628
# (17/22): 0.7779 +/-0.1634
# (18/22): 0.7814 +/-0.1636
# (19/22): 0.7881 +/-0.1600
# (20/22): 0.7959 +/-0.1601
# (21/22): 0.8027 +/-0.1578
# (22/22): 0.8099 +/-0.1561

# Classical
# ( 4/22): 0.4520 +/-0.1453
# ( 5/22): 0.4597 +/-0.1459
# ( 6/22): 0.4836 +/-0.1436
# ( 7/22): 0.4941 +/-0.1395
# ( 8/22): 0.5031 +/-0.1399
# ( 9/22): 0.5051 +/-0.1405
# (10/22): 0.5128 +/-0.1441
# (11/22): 0.5245 +/-0.1451
# (12/22): 0.5264 +/-0.1488
# (13/22): 0.5303 +/-0.1491
# (14/22): 0.5339 +/-0.1487
# (15/22): 0.5440 +/-0.1450
# (16/22): 0.5526 +/-0.1461
# (17/22): 0.5586 +/-0.1453
# (18/22): 0.5613 +/-0.1457
# (19/22): 0.5669 +/-0.1426
# (20/22): 0.5723 +/-0.1395
# (21/22): 0.5744 +/-0.1375
# (22/22): 0.5834 +/-0.1343

In [10]:
# Learning-curve sweep for the classical baseline (classical_train_cfg, no
# DINO features). This mirrors the deep sweep, using the .tmp_classical cache.
# The range and the printed total come from the data rather than hard-coded
# 4 / 22.
classical_mious, classical_std_mious = [], []
n_total_labels = len(label_fnames)
for n_labels in range(len(base_labels), n_total_labels + 1):
    selected_labels = label_fnames[:n_labels]
    # label_fnames already holds extension-free stems matching the cache files.
    feat_paths = [f"{PATH}/.tmp_classical/{name}.npy" for name in selected_labels]
    classifier, _ = train_model_over_images(chosen_dataset, classical_train_cfg, PATH, selected_labels, dv2, upsampler, expr, feat_paths, overwrite_with_gt=True)

    all_feat_fnames = [f"{PATH}/.tmp_classical/{fname}" for fname in sorted(listdir(f"{PATH}/.tmp_classical"))]
    classical_preds = apply_model_over_images(chosen_dataset, classical_train_cfg, classifier, PATH, dv2, upsampler, expr, False, -1, pca, all_feat_fnames)
    miou, std_miou = eval_preds(chosen_dataset, classical_preds, PATH)
    print(f"({n_labels:2d}/{n_total_labels}): {miou:.4f} +/-{std_miou:.4f}")
    classical_mious.append(miou)
    classical_std_mious.append(std_miou)

( 4/22): 0.8531 +/-0.0548
( 5/22): 0.8550 +/-0.0478
( 6/22): 0.8495 +/-0.0660
( 7/22): 0.8511 +/-0.0600
( 8/22): 0.8517 +/-0.0663
( 9/22): 0.8466 +/-0.0804
(10/22): 0.8474 +/-0.0857
(11/22): 0.8493 +/-0.0909
(12/22): 0.8514 +/-0.0901
(13/22): 0.8798 +/-0.0351
(14/22): 0.8864 +/-0.0335
(15/22): 0.8877 +/-0.0327
(16/22): 0.8894 +/-0.0335
(17/22): 0.8878 +/-0.0402
(18/22): 0.8891 +/-0.0403
(19/22): 0.8904 +/-0.0390
(20/22): 0.8909 +/-0.0391
(21/22): 0.8930 +/-0.0397
(22/22): 0.8941 +/-0.0387


In [11]:
# Bundle both learning curves with the predictions of the last classifier
# trained in each sweep (the one trained on the largest label set).
data_dict = {
    "classical_miou": classical_mious,
    "classical_miou_std": classical_std_mious,
    "classical_preds": classical_preds,
    "deep_miou": deep_mious,
    "deep_miou_std": deep_std_mious,
    "deep_preds": deep_preds,
}

if SAVE:
    import os

    out_dir = f"{PATH}/miou_results"
    # np.save raises if the target directory does not exist.
    os.makedirs(out_dir, exist_ok=True)
    # np.save stores a dict as a pickled 0-d object array; read it back with
    # np.load(path, allow_pickle=True).item().
    np.save(f"{out_dir}/{chosen_dataset}.npy", data_dict)

In [12]:
%%capture
import os

import matplotlib.pyplot as plt

plt.rcParams["font.family"] = "serif"


TITLE_FS = 25
LABEL_FS = 23
TICK_FS = 21


# x-axis: number of labelled training images at each sweep step. The sweeps
# start from the base training images, so the first point is len(base_labels).
n_labels = list(range(len(base_labels), len(base_labels) + len(deep_mious)))

fig, ax = plt.subplots(figsize=(6, 6))

ax.plot(n_labels, classical_mious, marker='.', lw=3, ms=15, label='Classical')
ax.plot(n_labels, deep_mious, marker='.', lw=3, ms=15, label='+HR ViT')

ax.set_xlabel('# labelled images', fontsize=LABEL_FS)
ax.set_ylabel('mIoU', fontsize=LABEL_FS)
ax.tick_params(axis='both', labelsize=TICK_FS)

ax.grid(True, linestyle="--", alpha=0.6)
ax.legend(fontsize=TICK_FS)
fig.tight_layout(pad=2.5)
# savefig does not create missing directories.
os.makedirs("fig_out", exist_ok=True)
fig.savefig(f'fig_out/{chosen_dataset}_miou_vs_n_labels.png', bbox_inches='tight')

In [13]:
from PIL import Image
from interactive_seg_backend.file_handling import load_image, load_labels
from skimage.color import label2rgb


# Display palette as 0-255 RGB. Index 0 (white) is never used as a class
# colour: the label overlay skips label 0, and class colours start at
# COLORS[1].
color_list = [[255, 255, 255], [0, 62, 131], [181, 209, 204], [250, 43, 0], [255, 184, 82]]
# The same palette as floats in [0, 1], the range matplotlib images use.
COLORS = np.array(color_list) / 255.0

def apply_labels_as_overlay(labels: np.ndarray, img: Image.Image, colors: list | np.ndarray, alpha: float = 1.0) -> Image.Image:
    """Paint the labelled pixels of ``img`` with per-class colours.

    Pixels whose label is 0 keep their original value. A pixel with label
    ``k != 0`` is painted with ``colors[1 + (k - 1) % (len(colors) - 1)]``
    (i.e. ``colors[k]`` while ``k < len(colors)``), blended over the image
    with opacity ``alpha``.

    Args:
        labels: integer label map with the same height/width as ``img``;
            0 marks unlabelled pixels.
        img: image to draw on; converted to RGB before blending.
        colors: RGB(A) colours as floats in [0, 1]; ``colors[0]`` is reserved
            for the unlabelled value and never painted. Any alpha column is
            ignored.
        alpha: opacity of the label colours (1.0 fully replaces the pixel).

    Returns:
        A new RGB ``PIL.Image.Image``.

    Raises:
        ValueError: if ``colors`` has no entry beyond the reserved ``colors[0]``.
    """
    palette = np.asarray(colors, dtype=np.float64)[:, :3] * 255.0
    n_class_colors = len(palette) - 1
    if n_class_colors < 1:
        raise ValueError("colors needs at least one class colour after colors[0]")

    label_ids = np.asarray(labels).astype(np.intp)
    labelled = label_ids != 0
    # Look colours up by label *value*. skimage's label2rgb instead assigns
    # colours by the rank of the labels present in the given map, so a class
    # absent from one image would shift the colours of the remaining classes.
    color_idx = np.where(labelled, 1 + (label_ids - 1) % n_class_colors, 0)
    painted = palette[color_idx]

    base = np.asarray(img.convert("RGB"), dtype=np.float64)
    blended = alpha * painted + (1.0 - alpha) * base
    out = np.where(labelled[..., None], blended, base)
    # Round rather than truncate so float error (e.g. 61.9999) maps back to 62.
    return Image.fromarray(np.clip(np.rint(out), 0, 255).astype(np.uint8))

def add_inset_zoom(xywh: list[int], fig_xywh: list[float], img_arr: np.ndarray, labels: np.ndarray | None, ax) -> object:
    """Draw a magnified inset of a region of ``img_arr`` on ``ax`` and outline it.

    Args:
        xywh: region to magnify, ``[x0, y0, width, height]`` in pixel coordinates.
        fig_xywh: inset bounds ``[x, y, width, height]`` passed to
            ``ax.inset_axes`` (axes-fraction coordinates by default).
        img_arr: image shown in the inset; 2-D (grayscale, drawn with the
            ``binary_r`` colormap) or H x W x C.
        labels: optional label map. When given, the inset instead shows
            ``label2rgb(labels, img_arr, COLORS[1:], kind='overlay', alpha=0.6,
            bg_label=-1)``.
        ax: matplotlib Axes to attach the inset to.

    Returns:
        The inset Axes.
    """
    x0, y0, w, h = xywh

    axin = ax.inset_axes(
        fig_xywh, xlim=(x0, x0 + w), ylim=(y0, y0 + h))
    axin.set_xticks([])
    axin.set_yticks([])
    if labels is not None:
        # NOTE(review): label2rgb picks colours by the rank of the labels
        # present, so colours may differ from other panels if a class is absent.
        inset_data = label2rgb(labels, img_arr, COLORS[1:], kind='overlay', alpha=0.6, bg_label=-1)
        axin.imshow(inset_data)
    else:
        # The x/y limits already crop the view to the region, so the full image
        # is drawn as-is. This avoids the masked copy, which blacked out the
        # half-pixel at the inset's right/bottom edge. cmap only affects 2-D input.
        axin.imshow(img_arr, cmap="binary_r")
    ax.indicate_inset_zoom(axin, edgecolor="black", lw=2)
    # Flip the inset's y-axis so the crop keeps image (row-down) orientation.
    axin.set_ylim((y0 + h, y0))

    axin.patch.set_edgecolor('black')
    axin.patch.set_linewidth(4)

    return axin

In [14]:
%%capture
import os

n_examples = 5
width = 2.75
n_cols = 4
# squeeze=False keeps axs 2-D even if n_examples is set to 1.
fig, axs = plt.subplots(nrows=n_examples, ncols=n_cols, figsize=(width * n_cols, width * n_examples), squeeze=False)

# Evenly spaced subset of the labelled images. max(1, ...) keeps the slice
# step valid when there are fewer than n_examples labelled images.
step = max(1, len(all_label_fnames) // n_examples)
fnames = all_label_fnames[::step][:n_examples]


images = {fname: Image.fromarray(load_image(f"{PATH}/{chosen_dataset}/images/{fname}.tif")).convert('RGB') for fname in all_label_fnames}
labels = {fname: (load_labels(f"{PATH}/{chosen_dataset}/labels/{fname}.tif")) for fname in all_label_fnames}
segs = {fname: (load_labels(f"{PATH}/{chosen_dataset}/segmentations/{fname}.tif")) for fname in all_label_fnames}


def _class_map_to_rgb(class_map: np.ndarray) -> np.ndarray:
    """Colour a class-index map as float RGB in [0, 1].

    Class c maps to COLORS[1 + c % (len(COLORS) - 1)]. Colours are looked up
    by class *value*, so a class keeps the same colour in every panel even when
    another class is absent from an image. skimage's label2rgb assigns colours
    by the rank of the labels present instead.
    """
    class_ids = np.asarray(class_map).astype(np.intp)
    return COLORS[1 + class_ids % (len(COLORS) - 1)]


titles = ["Image + labels", " Ground truth", "Classical", "+HR ViT"]
for ax, title in zip(axs[0], titles):
    ax.set_title(title, fontsize=TITLE_FS)

# Shared zoom settings for all prediction panels:
#   zoom_xywh  - region [x0, y0, w, h] in pixels
#   inset_xywh - inset placement within each panel's axes
zoom_xywh = [45, 110, 100, 100]
inset_xywh = [0.7, 0.15, 0.3, 0.3]

for i, fname in enumerate(fnames):
    img, label, seg = images[fname], labels[fname], segs[fname]
    # Drop a leading singleton axis from (1, H, W) label maps.
    if label.ndim == 3 and label.shape[0] == 1:
        label = label[0]
    overlay_img = apply_labels_as_overlay(label, img, COLORS)

    ground_truth = _class_map_to_rgb(seg)
    classical_pred = _class_map_to_rgb(classical_preds[f"{fname}.tif"])
    deep_pred = _class_map_to_rgb(deep_preds[f"{fname}.tif"])

    axs[i, 0].imshow(overlay_img)
    for ax, panel in zip(axs[i, 1:], (ground_truth, classical_pred, deep_pred)):
        ax.imshow(panel)
        add_inset_zoom(zoom_xywh, inset_xywh, panel, None, ax)

    for ax in axs[i]:
        ax.set_axis_off()

# Hide any rows left empty when fewer than n_examples images were available.
for ax in axs[len(fnames):].ravel():
    ax.set_axis_off()

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("fig_out", exist_ok=True)
plt.savefig(f'fig_out/{chosen_dataset}_max_labels_preds.png', bbox_inches='tight')