In [1]:
import numpy as np
import torch

from interactive_seg_backend.file_handling import load_image, load_labels
from interactive_seg_backend.main import featurise, train_and_apply, TrainingConfig, FeatureConfig

from yoeo.main import (
    get_hr_feats,
    get_dv2_model,
    get_upsampler_and_expr,
)
from yoeo.utils import to_numpy
from is_helpers import get_deep_feats, train_model_over_images, apply_model_over_images, eval_preds

from PIL import Image
from skimage.color import label2rgb
import matplotlib.pyplot as plt

from typing import Any, Literal

SEED = 10673
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)

# The original runs used the second GPU ("cuda:1") of a multi-GPU machine.
# Fall back instead of failing with an invalid-device error on machines with
# fewer GPUs. NOTE(review): downstream yoeo helpers may still require CUDA.
if torch.cuda.device_count() > 1:
    DEVICE = "cuda:1"
elif torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

N CPUS: 110


  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


In [2]:
# Checkpoint of the trained upsampler and the JSON config it was trained with.
model_path = "../trained_models/e5000_full_fit_reg.pth"
cfg_path = "../yoeo/models/configs/combined_no_shift.json"

# Backbone (the positional `True` flag is yoeo-specific — TODO confirm meaning),
# then the upsampler plus its accompanying `expr` object, all on DEVICE.
dv2 = get_dv2_model(True, device=DEVICE)
upsampler, expr = get_upsampler_and_expr(model_path, cfg_path, device=DEVICE)

Using cache found in /home/ronan/.cache/torch/hub/ywyue_FiT3D_main


[128, 128, 128, 128]


In [3]:
# Inputs for the classifier-choice supplementary figure: image, sparse user
# labels and the dense ground truth.
PATH = "fig_data/supp_classifier_choice"
img = load_image(f"{PATH}/img.tif")
labels, ground_truth = load_labels(f"{PATH}/labels.tif"), load_labels(f"{PATH}/gt.tif")

In [12]:
featuresets = ("Classical", "+HR ViT")
classifiers = ("linear_regression", "logistic_regression", "random_forest", "xgb", 'mlp')
# One sub-dict per feature set; later cells store the features under "features"
# and each classifier's segmentation under the classifier's name.
results: dict[str, Any] = {featureset: {} for featureset in featuresets}

# Extra classifier kwargs, aligned positionally with `classifiers` (the
# benchmark cell zips the two together).
# NOTE(review): xgboost's XGBClassifier has no documented `class_weight`
# parameter — confirm how interactive_seg_backend consumes it for "xgb".
params = [{}, {}, {'class_weight': 'balanced'}, {'class_weight': 'balanced'}, {}]
assert len(params) == len(classifiers), "params needs exactly one entry per classifier"

In [5]:
# Hand-crafted features from a default feature config.
dummy_tc = TrainingConfig(FeatureConfig())
classical_feats = featurise(img, dummy_tc)
results["Classical"]["features"] = classical_feats

# Deep (upsampled ViT) features, stacked after the classical ones along the
# last axis. The trailing 32 is a get_deep_feats argument — TODO document it.
deep_feats = get_deep_feats(img, dv2, upsampler, expr, 32)
hybrid_feats = np.concatenate([classical_feats, deep_feats], axis=-1)
results["+HR ViT"]["features"] = hybrid_feats

In [6]:
# Train and apply every classifier on each feature set of the single example
# image, storing the segmentation under results[featureset][classifier].
# NOTE(review): unlike the benchmark cell, no `classifier_params` are passed
# here, so random_forest/xgb run without class_weight='balanced' — confirm.
for featureset in featuresets:
    feats = results[featureset]["features"]  # invariant across classifiers
    for classifier in classifiers:
        # Pass `classifier` by keyword, as the benchmark cell does, rather than
        # relying on it being TrainingConfig's second positional field.
        tc = TrainingConfig(FeatureConfig(), classifier=classifier)
        seg, _, _ = train_and_apply(feats, labels, tc)
        results[featureset][classifier] = seg

In [7]:
# Label palette in 0-255 RGB. Entry 0 (white) is never used for a class: the
# plotting helpers always pass COLORS[1:] to label2rgb.
# Previously tried palette: [[255, 255, 255], [31, 119, 180], [255, 127, 14], [44, 160, 44], [255, 0, 0]]
color_list = [
    [255, 255, 255],
    [0, 62, 131],
    [181, 209, 204],
    [250, 43, 0],
    [255, 184, 82],
]
# matplotlib / skimage expect float RGB in [0, 1].
COLORS = np.divide(np.array(color_list), 255.0)

In [8]:
def apply_labels_as_overlay(labels: np.ndarray, img: Image.Image, colors: list, alpha: float=1.0) -> Image.Image:
    """Paint the labelled pixels of ``labels`` over ``img``.

    Args:
        labels: (H, W) label map; pixels equal to 0 are unlabelled and keep the
            original image colour.
        img: PIL image of the same height/width; converted to RGB if needed.
        colors: sequence of RGB colours in [0, 1]. Entry 0 is skipped (it would
            correspond to the unlabelled value); entries 1.. colour the labels.
        alpha: opacity of the label colours over the image (1.0 = opaque).

    Returns:
        A new RGB PIL image.
    """
    base = np.asarray(img.convert("RGB"), dtype=np.float64)

    # label2rgb called without `image` internally forces alpha=1 (so passing
    # `alpha` to it had no effect). Get pure label colours from it and blend
    # against the RGB image here instead.
    label_rgb = label2rgb(labels, colors=colors[1:], kind='overlay', bg_label=0, image_alpha=1) * 255
    blended = alpha * label_rgb + (1.0 - alpha) * base

    is_labelled = np.expand_dims(labels != 0, -1)  # (H, W, 1) broadcasts over RGB
    out = np.where(is_labelled, blended, base)
    # Round rather than truncate so palette colours survive the /255 * 255 trip.
    return Image.fromarray(np.clip(np.rint(out), 0, 255).astype(np.uint8))

def add_inset_zoom(xywh: list[int], fig_xywh: list[float], img_arr: np.ndarray, labels: np.ndarray | None, ax ) -> object:
    """Attach a zoomed inset of a rectangular image region to ``ax``.

    Args:
        xywh: source region in pixel coordinates as (x0, y0, width, height).
        fig_xywh: inset placement passed to ``ax.inset_axes`` as
            (x, y, width, height), in axes-fraction coordinates by default.
        img_arr: image indexed as (row, col, channel) when ``labels`` is None;
            otherwise used as the backdrop for the label overlay.
        labels: optional label map; when given, the inset shows it coloured
            with the module-level ``COLORS`` palette (entries 1..).
        ax: matplotlib Axes that receives the inset and the zoom indicator.

    Returns:
        The inset Axes.
    """
    x0, y0, w, h = xywh

    axin = ax.inset_axes(
        fig_xywh, xlim=(x0, x0 + w), ylim=(y0, y0 + h))
    axin.set_xticks([])
    axin.set_yticks([])

    if labels is not None:
        # bg_label=-1 so no non-negative label is treated as background;
        # alpha=1 draws the label colours opaquely over the image.
        inset_data = label2rgb(labels, img_arr, COLORS[1:], kind='overlay', alpha=1, bg_label=-1)
        axin.imshow(inset_data)
    else:
        # Keep only the zoomed region; everything outside it is zeroed.
        inset_data = np.zeros_like(img_arr)
        inset_data[y0:y0 + h, x0:x0 + w, :] = img_arr[y0:y0 + h, x0:x0 + w, :]
        axin.imshow(inset_data, cmap="binary_r")  # cmap only affects single-channel data

    ax.indicate_inset_zoom(axin, edgecolor="black", lw=2)
    # Flip y so the inset uses image (row-down) orientation like the parent.
    axin.set_ylim((y0 + h, y0))

    axin.patch.set_edgecolor('black')
    axin.patch.set_linewidth(4)

    return axin

In [9]:
%%capture

n_examples = 3
n_cols = len(classifiers)

plt.rcParams["font.family"] = "serif"

TITLE_FS = 25
LABEL_FS = 23
TICK_FS = 21

width = 4
fig, axs = plt.subplots(nrows=n_examples, ncols=n_cols, figsize=(width * n_cols, width * n_examples))

for row in axs:
    for ax in row:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_frame_on(False)

img_with_labels = apply_labels_as_overlay(labels[0], Image.fromarray(img).convert('RGB'), COLORS)
axs[0, 1].set_title('Image + labels', fontsize=TITLE_FS)
axs[0, 1].imshow(img_with_labels)
axs[0, 2].set_title('Ground truth', fontsize=TITLE_FS)
axs[0, 2].imshow(label2rgb(ground_truth + 1, colors=COLORS[1:]))

titles = ("Linear", "Logistic regression", "Random Forest", "XGB", "MLP")
for i, (featureset, sub_dict) in enumerate(results.items()):
    print(featureset)
    for j, classifier in enumerate(classifiers):
        pred = sub_dict[classifier]
        axs[i + 1, j].imshow(label2rgb(pred + 1, colors=COLORS[1:]))
        add_inset_zoom([40, 80, 160, 100], [0.6, 0.1, 0.5, 0.4], np.array(img_with_labels), pred, axs[i + 1, j])

        if i == 0:
            axs[i + 1, j].set_title(titles[j], fontsize=TITLE_FS)
        if j == 0: 
            axs[i + 1, j].set_ylabel(featureset, fontsize=TITLE_FS)

plt.tight_layout()
plt.savefig(f'fig_out/supp_classifier_choice.png', bbox_inches='tight')

In [13]:
PATH = "fig_data/is_benchmark"
AllowedDatasets = Literal["Ni_superalloy_SEM", "T_cell_TEM", "Cu_ore_RLM"]

# Names of the images used for training, per dataset.
TRAIN_IMG_FNAMES: dict[AllowedDatasets, list[str]] = {
    "Cu_ore_RLM": ["004", "028", "049", "077"],
    "Ni_superalloy_SEM": ["000", "001", "005", "007"],
    "T_cell_TEM": ["000", "005", "007", "026"],
}

# Per-classifier outputs of apply_model_over_images, for each feature set.
all_classical_preds: dict[str, dict[str, np.ndarray]] = {k: {} for k in classifiers}
all_deep_preds: dict[str, dict[str, np.ndarray]] = {k: {} for k in classifiers}

chosen_dataset: AllowedDatasets = "Ni_superalloy_SEM"
train_fnames = TRAIN_IMG_FNAMES[chosen_dataset]

# The loop variable must NOT be named `params`: that would rebind the global
# list to its last element ({}), and re-running this cell would then zip over
# an empty dict and silently train nothing.
for classifier, clf_params in zip(classifiers, params):
    feat_cfg = FeatureConfig()

    # Identical classifier settings, without and with the DINO features.
    classical_train_cfg = TrainingConfig(feat_cfg, n_samples=-1, add_dino_features=False, classifier=classifier, classifier_params=clf_params)
    classical_model, _ = train_model_over_images(chosen_dataset, classical_train_cfg, PATH, train_fnames, dv2, upsampler, expr)

    deep_train_cfg = TrainingConfig(feat_cfg, n_samples=-1, add_dino_features=True, classifier=classifier, classifier_params=clf_params)
    deep_model, pca = train_model_over_images(chosen_dataset, deep_train_cfg, PATH, train_fnames, dv2, upsampler, expr)

    classical_preds = apply_model_over_images(chosen_dataset, classical_train_cfg, classical_model, PATH, dv2, upsampler, expr, verbose=False)
    # Reuse the PCA returned by deep training (presumably so test features get
    # the same projection as the training features).
    deep_preds = apply_model_over_images(chosen_dataset, deep_train_cfg, deep_model, PATH, dv2, upsampler, expr, verbose=True, existing_pca=pca)

    all_classical_preds[classifier] = classical_preds
    all_deep_preds[classifier] = deep_preds

Finished featurising
Finished featurising
Finished featurising
Finished featurising
Finished featurising
Finished featurising
Finished featurising
Finished featurising
[00/23] - 000.tif
[10/23] - 010.tif
[20/23] - 020.tif
Finished featurising
Finished featurising
Finished featurising
Finished featurising
Finished featurising
Finished featurising
Finished featurising
Finished featurising
[00/23] - 000.tif
[10/23] - 010.tif
[20/23] - 020.tif
Finished featurising
Finished featurising
Finished featurising
Finished featurising
Finished featurising
Finished featurising
Finished featurising
Finished featurising
[00/23] - 000.tif
[10/23] - 010.tif
[20/23] - 020.tif
Finished featurising
Finished featurising
Finished featurising
Finished featurising
Finished featurising
Finished featurising
Finished featurising
Finished featurising
[00/23] - 000.tif
[10/23] - 010.tif
[20/23] - 020.tif
Finished featurising
Finished featurising
Finished featurising
Finished featurising
Finished featurising
Finishe

In [14]:
# Mean +/- std mIoU on the chosen dataset, classical vs deep features, per classifier.
for clf_name in classifiers:
    scores = {
        feat_kind: eval_preds(chosen_dataset, preds_by_clf[clf_name], PATH)
        for feat_kind, preds_by_clf in (("classical", all_classical_preds), ("deep", all_deep_preds))
    }
    (mean_c, std_c), (mean_d, std_d) = scores["classical"], scores["deep"]
    print(f"======== {clf_name} ========")
    print(f"mIoU_classical: {mean_c:.4f}+/-{std_c:.4f} vs mIoU_deep: {mean_d:.4f}+/-{std_d:.4f}\n")

mIoU_classical: 0.4071+/-0.0768 vs mIoU_deep: 0.6259+/-0.1486

mIoU_classical: 0.3842+/-0.0652 vs mIoU_deep: 0.5864+/-0.0518

mIoU_classical: 0.5322+/-0.1401 vs mIoU_deep: 0.6925+/-0.1142

mIoU_classical: 0.5403+/-0.1249 vs mIoU_deep: 0.6803+/-0.1240

mIoU_classical: 0.5238+/-0.1147 vs mIoU_deep: 0.6341+/-0.0936

