In [116]:
import numpy as np
import torch
from sklearn.preprocessing import minmax_scale

from interactive_seg_backend.file_handling import load_image, load_labels
from interactive_seg_backend.main import featurise, train_and_apply, TrainingConfig, FeatureConfig

from yoeo.main import (
    get_hr_feats,
    get_dv2_model,
    get_upsampler_and_expr,
)
from yoeo.utils import to_numpy

from PIL import Image
from skimage.color import label2rgb
import matplotlib.pyplot as plt

# Seed both NumPy's and PyTorch's global RNGs so the run is repeatable.
SEED = 10673
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the second GPU (the original setting). A hard-coded "cuda:1" fails on
# hosts with fewer than two GPUs, so fall back to the first GPU and then the CPU.
if torch.cuda.device_count() > 1:
    DEVICE = "cuda:1"
elif torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

In [73]:
# Build the backbone model on DEVICE (presumably DINOv2, going by the helper's name).
# NOTE(review): the meaning of the positional `True` is not visible in this
# notebook — confirm against yoeo.main.get_dv2_model.
dv2 = get_dv2_model(True, device=DEVICE)

# Checkpoint and JSON config handed to get_upsampler_and_expr; both are relative
# paths, so they resolve against the notebook's working directory.
model_path = "../trained_models/e5000_full_fit_reg.pth"
cfg_path = "../yoeo/models/configs/combined_no_shift.json"

# Returns the upsampler and a second object whose `n_ch_in` attribute is read
# later when computing the high-resolution features.
upsampler, expr = get_upsampler_and_expr(model_path, cfg_path, device=DEVICE)

Using cache found in /home/ronan/.cache/torch/hub/ywyue_FiT3D_main


In [74]:
# Directory (relative path) holding the image/label pair used for this figure.
PATH = "fig_data/supp_pos_bias"
img = load_image(f"{PATH}/outside3_crop.tif")
# Labels for the same crop. Note the differing extension (.tiff vs .tif above).
labels = load_labels(f"{PATH}/outside3_crop_labels.tiff")

In [111]:
# High-resolution features for the image from the backbone + upsampler pair.
hr_feats = get_hr_feats(img, dv2, upsampler, DEVICE, n_ch_in=expr.n_ch_in)
# Move axis 0 to the end — assumes to_numpy returns a (C, H, W) array; TODO confirm.
hr_feats_np = to_numpy(hr_feats).transpose((1, 2, 0))
# Keep only the first three channels so they can be displayed as an RGB image.
hr_feats_rgb = hr_feats_np[:, :, :3]

In [112]:
# One training configuration (built from a default FeatureConfig) shared by both runs.
tc = TrainingConfig(FeatureConfig())
classical_feats = featurise(img, tc)

# Train/apply on the same labels with each feature stack. Only the first of the
# three returned values is kept; it is plotted as the segmentation map later on.
seg_with_classical, _, _ = train_and_apply(classical_feats, labels, tc)
seg_with_deep, _, _ = train_and_apply(hr_feats_np, labels, tc)

In [113]:
# Min-max scale the three feature channels for display: flatten to
# (pixels, channels), scale (minmax_scale works column-wise by default, i.e. per
# channel here), then restore the (h, w, c) shape.
# NOTE(review): this repeats what rescale(..., swap_channels=False), defined later
# in this notebook, does.
h, w, c = hr_feats_rgb.shape
rescaled_flat = minmax_scale(hr_feats_rgb.reshape((h * w, c)))
rescaled = rescaled_flat.reshape((h , w, c))

In [118]:
# Hex palette, one entry per class index.
cmap = [
    "#fafafa",
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
]

# The palette as 8-bit RGB triplets, one row per class index.
# NOTE(review): the first and last rows ([255, 255, 255] and [255, 0, 0]) do not
# match the corresponding hex entries in `cmap` (#fafafa and #d62728) — confirm
# which of the two is intended.
color_list = [
    [255, 255, 255],
    [31, 119, 180],
    [255, 127, 14],
    [44, 160, 44],
    [255, 0, 0],
]

# Same palette as floats in [0, 1].
COLORS = np.array(color_list) / 255.0

In [130]:
def apply_labels_as_overlay(labels: np.ndarray, img: Image.Image, colors: list, alpha: float=1.0) -> Image.Image:
    """Paint labelled pixels onto an image and return the result as a PIL image.

    Args:
        labels: integer label map; pixels with value 0 are left showing `img`.
        img: the underlying image (converted with `np.array`).
        colors: palette indexed by class; entry 0 is skipped, so `colors[1:]`
            is what gets handed to `label2rgb`.
        alpha: forwarded unchanged to `label2rgb`.

    Returns:
        A uint8 PIL image in which every non-zero-labelled pixel carries its
        class colour and every other pixel is copied from `img`.
    """
    # label2rgb yields float colours; they are multiplied by 255 below before
    # the uint8 cast.
    label_colours = label2rgb(
        labels,
        colors=colors[1:],
        kind='overlay',
        bg_label=0,
        image_alpha=1,
        alpha=alpha,
    )

    # Trailing singleton axis so the mask broadcasts across the colour channels.
    is_labelled = np.expand_dims(labels, axis=-1)
    base_pixels = np.array(img)

    composited = np.where(is_labelled, label_colours * 255, base_pixels)
    return Image.fromarray(composited.astype(np.uint8))

def rescale(arr: np.ndarray, swap_channels: bool=True) -> np.ndarray:
    """Min-max scale a 3-D array channel by channel.

    Args:
        arr: 3-D array. With `swap_channels=True` its axes are reordered with
            `(1, 2, 0)` first, i.e. axis 0 is moved to the end.
        swap_channels: whether to apply that axis reordering before scaling.

    Returns:
        Array with the (possibly reordered) shape of the input, where the last
        axis has been scaled by `minmax_scale` (column-wise by default, so each
        channel is scaled on its own).
    """
    channels_last = np.transpose(arr, (1, 2, 0)) if swap_channels else arr

    n_rows, n_cols, n_channels = channels_last.shape
    # One row per pixel, one column per channel, as minmax_scale expects.
    per_pixel = channels_last.reshape((n_rows * n_cols, n_channels))

    return minmax_scale(per_pixel).reshape((n_rows, n_cols, n_channels))

def add_inset_zoom(xywh: list[int], fig_xywh: list[float], img_arr: np.ndarray, labels: np.ndarray | None, ax ) -> object:
    """Add a magnified inset of a rectangular image region to `ax`.

    Args:
        xywh: region to magnify, in pixel coordinates, as [x0, y0, width, height].
        fig_xywh: bounds of the inset, passed straight to `ax.inset_axes`.
        img_arr: image array, either (H, W) grayscale or (H, W, C).
        labels: optional label map. When given, it is blended over `img_arr`
            with `label2rgb` (alpha 0.6, colours from the module-level
            `COLORS[1:]`) and that blend is what the inset shows.
        ax: the matplotlib axes that already displays the full image.

    Returns:
        The newly created inset axes.
    """
    x0, y0, w, h = xywh

    axin = ax.inset_axes(fig_xywh, xlim=(x0, x0 + w), ylim=(y0, y0 + h))
    axin.set_xticks([])
    axin.set_yticks([])

    if labels is not None:
        inset_data = label2rgb(labels, img_arr, COLORS[1:], kind='overlay', alpha=0.6, bg_label=-1)
        axin.imshow(inset_data)
    else:
        # Show only the selected region; everything outside it stays zero.
        # The Ellipsis keeps this working for 2-D grayscale arrays as well as
        # (H, W, C) ones.
        inset_data = np.zeros_like(img_arr)
        inset_data[y0:y0 + h, x0:x0 + w, ...] = img_arr[y0:y0 + h, x0:x0 + w, ...]
        axin.imshow(inset_data, cmap="binary_r")

    # Outline the source region on the parent axes and connect it to the inset.
    ax.indicate_inset_zoom(axin, edgecolor="black", lw=2)

    # Reversed y-limits so the inset keeps the image's top-down row order.
    axin.set_ylim((y0 + h, y0))

    # Heavy black outline around the inset itself.
    axin.patch.set_edgecolor('black')
    axin.patch.set_linewidth(4)

    return axin

In [131]:
# Paint the labelled pixels onto the image with the shared palette
# (alpha is left at the function's default of 1.0).
img_with_labels = apply_labels_as_overlay(labels, img, COLORS)

In [155]:
%%capture
# Font sizes for figure text; only TITLE_FS is used in this cell.
TITLE_FS = 25
LABEL_FS = 23
TICK_FS = 21

plt.rcParams["font.family"] = "serif"

# Panel titles, left to right.
titles = ["Image + (biased) labels", "HR ViT PCA", 'HR ViT only', "Classical"]

fig, axs = plt.subplots(ncols=4, figsize=(18, 10))

# Panel 0: the labelled image, with a magnified inset of the region
# x = 275..575, y = 125..175.
axs[0].imshow(img_with_labels, cmap='binary_r')
add_inset_zoom(
    xywh=[275, 125, 300, 50],
    fig_xywh=[0.4, 0.5, 0.6, 0.3],
    img_arr=np.array(img_with_labels),
    labels=None,
    ax=axs[0],
)

# Panels 1-3: rescaled features, then the two segmentations. The segmentations
# are shifted by +1 before colouring — presumably so class 0 is not treated as
# label2rgb's background; confirm against the label convention.
remaining_panels = (
    rescaled,
    label2rgb(seg_with_deep + 1, colors=COLORS[1:]),
    label2rgb(seg_with_classical + 1, colors=COLORS[1:]),
)
for ax, panel in zip(axs[1:], remaining_panels):
    ax.imshow(panel)

for ax, title in zip(axs, titles):
    ax.set_axis_off()
    ax.set_title(title, fontsize=TITLE_FS)

plt.tight_layout()
plt.savefig('fig_out/supp_pos_bias.png', bbox_inches='tight')