In [31]:
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from sklearn.preprocessing import MinMaxScaler

import matplotlib.pyplot as plt
from skimage.color import label2rgb
from tifffile import imread

from yoeo.main import get_dv2_model, get_upsampler_and_expr, get_hr_feats
from yoeo.utils import to_numpy
# from interactive_seg_backend import featurise_, FeatureConfig

from types import NoneType


# Single seed shared by every random number generator used in this notebook.
SEED = 10672
# Device string passed to the model loaders and the feature extraction below.
DEVICE = "cuda:1"

# Seed torch and numpy's legacy global RNG so reruns are repeatable.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [32]:
# Hex colour per class index (index 0 first).
cmap = ["#fafafa", "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"]

# 8-bit RGB triple per class index.
# NOTE(review): the first and last triples are not the RGB equivalents of the
# first and last hex strings in `cmap` - confirm this is intentional.
color_list = [
    [255, 255, 255],
    [31, 119, 180],
    [255, 127, 14],
    [44, 160, 44],
    [255, 0, 0],
]

# Same colours as floats in [0, 1], one row per class.
COLORS = np.divide(np.asarray(color_list), 255.0)

In [34]:
def apply_labels_as_overlay(labels: np.ndarray, img: Image.Image, colors: list, alpha: float=1.0) -> Image.Image:
    """Paint every labelled pixel of ``img`` in its class colour.

    Pixels whose label is 0 keep the original image value; every non-zero
    label is replaced by the colour ``label2rgb`` assigns to it.

    Args:
        labels: Label map; 0 is treated as background. Presumably the same
            height/width as ``img`` - verify against caller.
        img: Image to draw on; converted with ``np.array``.
        colors: Colour table whose first entry is skipped (only
            ``colors[1:]`` is handed to ``label2rgb``). Values are multiplied
            by 255 afterwards, so they are assumed to be in [0, 1].
        alpha: Opacity forwarded to ``label2rgb``.

    Returns:
        A new ``uint8`` PIL image with the labels drawn in.
    """
    coloured_labels = label2rgb(labels, colors=colors[1:], kind='overlay', bg_label=0, image_alpha=1, alpha=alpha)

    # Trailing axis lets the 2D label mask broadcast against the RGB channels.
    is_labelled = np.expand_dims(labels, axis=-1)
    composite = np.where(is_labelled, coloured_labels * 255, np.array(img))
    return Image.fromarray(composite.astype(np.uint8))

def rescale(arr: np.ndarray, swap_channels: bool=True) -> np.ndarray:
    """Min-max scale each channel of a 3D array independently.

    Args:
        arr: 3D array. Treated as channel-first when ``swap_channels`` is
            true, otherwise as channel-last.
        swap_channels: If true, transpose axes ``(0, 1, 2) -> (1, 2, 0)``
            before scaling.

    Returns:
        Channel-last array of shape ``(h, w, c)`` holding the scaled values.
    """
    channel_last = np.transpose(arr, (1, 2, 0)) if swap_channels else arr
    height, width, n_channels = channel_last.shape

    # One row per pixel, one column per channel, so the scaler fits each
    # channel's min/max separately.
    pixels = channel_last.reshape((height * width, n_channels))
    scaler = MinMaxScaler(clip=True)
    scaled_pixels = scaler.fit_transform(pixels)
    return scaled_pixels.reshape((height, width, n_channels))

def add_inset_zoom(xywh: list[int], fig_xywh: list[float], img_arr: np.ndarray, labels: np.ndarray | None, ax ) -> object:
    x0, y0, w, h = xywh
    H, W, C = img_arr.shape
    inset_data = np.zeros_like(img_arr)
    inset_data[y0:y0+h, x0:x0+w, :] = img_arr[y0:y0+h, x0:x0+w, :]

    extent = (0, H, W, 0)
    axin = ax.inset_axes(
        fig_xywh, xlim=(x0, x0+w), ylim=(y0, y0+h))
    axin.set_xticks([])
    axin.set_yticks([])
    #axin.set_axis_off()
    if type(labels) != NoneType:
        inset_data = label2rgb(labels, img_arr, COLORS[1:], kind='overlay', alpha=0.6, bg_label=-1)
        axin.imshow(inset_data,)
    else:
        axin.imshow(inset_data, cmap="binary_r",) # cmap="binary_r"
    ax.indicate_inset_zoom(axin, edgecolor="black", lw=2)
    axin.set_ylim((y0 + h, y0))

    axin.patch.set_edgecolor('black')  

    axin.patch.set_linewidth(4)  

    return axin

In [12]:
# Directory holding the example images, their label maps and a `segs/` subfolder.
DATA_PATH = "fig_data/is_examples"

# File names, one per example; every list below follows the same order.
img_paths = ["Battery2.png", "biphase_steel_crop.png", "cells.jpg"]
label_paths = ["2_am_2_labels.tiff", "biphase_steel_crop_labels.tiff", "cells_lots_labels.tiff"]
classical_seg_paths = ["2_am_2_classical.tiff", "biphase_steel_crop_classical.tiff", "cells_classical.tiff"]
deep_seg_paths = ["2_am_2.tiff", "biphase_steel_crop.tiff", "cells_.tiff"]

# Input images (forced to RGB) and their label maps.
imgs = [Image.open(f"{DATA_PATH}/{path}").convert('RGB') for path in img_paths]
labels = [imread(f"{DATA_PATH}/{path}") for path in label_paths]

# Images with the labels painted on top.
imgs_with_labels = [apply_labels_as_overlay(*label_and_img, COLORS) for label_and_img in zip(labels, imgs)]

# Segmentations stored under `segs/`, one list per *_seg_paths list above.
classical_segs = [imread(f"{DATA_PATH}/segs/{path}") for path in classical_seg_paths]
deep_segs = [imread(f"{DATA_PATH}/segs/{path}") for path in deep_seg_paths]

# RGB renderings of the segmentations; the first COLORS entry is skipped.
classical_seg_imgs = [label2rgb(seg, colors=COLORS[1:]) for seg in classical_segs]
deep_seg_imgs = [label2rgb(seg, colors=COLORS[1:]) for seg in deep_segs]

In [5]:
# Checkpoint and config used to build the upsampler.
model_path = "../trained_models/e5000_full_fit_reg.pth"
cfg_path = "../yoeo/models/configs/combined_no_shift.json"

# NOTE(review): the meaning of the positional `True` is defined in
# yoeo.main.get_dv2_model and is not visible here - confirm.
dv2 = get_dv2_model(True, device=DEVICE)

# Both returned objects are used later by get_feats.
upsampler, expr = get_upsampler_and_expr(model_path, cfg_path, device=DEVICE)

Using cache found in /home/ronan/.cache/torch/hub/ywyue_FiT3D_main


In [6]:
def get_feats(img: Image.Image) -> np.ndarray:
    """Return the first three rescaled feature channels for ``img``.

    Features come from ``get_hr_feats`` using the module-level ``dv2``,
    ``upsampler`` and ``expr``. They are converted with ``to_numpy``,
    rescaled per channel with ``rescale``, and cut down to the first three
    channels.

    Args:
        img: Input image passed straight to ``get_hr_feats``.

    Returns:
        ``float64`` array of shape ``(h, w, 3)``.
    """
    upsampled = get_hr_feats(img, dv2, upsampler, DEVICE, n_ch_in=expr.n_ch_in)
    # assumes to_numpy yields a channel-first (C, H, W) array, since rescale
    # transposes by default - TODO confirm
    as_numpy = to_numpy(upsampled)
    scaled = rescale(as_numpy)
    first_three = scaled[:, :, :3].astype(np.float64)
    # Release cached GPU memory between images.
    torch.cuda.empty_cache()
    return first_three

In [7]:
# One (h, w, 3) feature array per example image, in the same order as `imgs`.
feats = list(map(get_feats, imgs))

In [47]:
%%capture
# Font sizes and label settings for the figure.
TITLE_FS = 25
LABEL_FS = 23
TICK_FS = 21

LABEL_PAD = 0
ROT = 90
plt.rcParams["font.family"] = "serif"

# Column headers, row labels, and the (x0, y0, w, h) zoom region for each row.
titles = ["Image + labels", "Classical Features", "Classical + Deep Features"]
materials = ["Cathode (2 AM)", "Biphase Steel", "Plant Cells"]
inset_zoom_locs = [(400, 300, 400, 400), (300, 100, 300, 300), (600, 75, 300, 300)]

# One row per example image, three columns.
fig, axs = plt.subplots(nrows=len(img_paths), ncols=3)
fig.set_size_inches((18, 10))

for y, row in enumerate(axs):
    img_arr = np.array(imgs_with_labels[y])
    classical_seg_arr = classical_segs[y]
    deep_seg_arr = deep_segs[y]

    # What each column shows, and which label map (if any) its inset overlays.
    panels = [img_arr, classical_seg_imgs[y], deep_seg_imgs[y]]
    inset_data_arrs = [None, classical_seg_arr, deep_seg_arr]

    for x, ax in enumerate(row):
        # Column titles go on the top row only.
        if y == 0:
            ax.set_title(titles[x], fontsize=TITLE_FS)

        ax.imshow(panels[x])

        # The row label goes on the left-most panel only.
        if x == 0:
            ax.set_ylabel(materials[y], fontsize=TITLE_FS)

        ax.set_xticklabels([])
        ax.set_yticklabels([])

        add_inset_zoom(inset_zoom_locs[y], [0.65, 0.05, 0.55, 0.55], img_arr, inset_data_arrs[x], ax)

plt.tight_layout()
plt.savefig('fig_out/is_examples.png', bbox_inches='tight')