In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.preprocessing import minmax_scale

from interactive_seg_backend.file_handling import load_image, load_labels
from interactive_seg_backend.main import featurise, train_and_apply, TrainingConfig, FeatureConfig

from yoeo.main import (
    get_hr_feats,
    get_dv2_model,
    get_upsampler_and_expr,
)
from yoeo.utils import convert_image, to_numpy, closest_crop, Experiment, do_2D_pca, add_flash_attention
from yoeo.comparisons.lift import ViTLiFTExtractor
from yoeo.comparisons.strided import StridedDv2

from PIL import Image
from skimage.color import label2rgb
import matplotlib.pyplot as plt

from typing import Any, Literal

SEED = 10673
np.random.seed(SEED)
torch.manual_seed(SEED)
DEVICE = "cuda:1"

FeatureTypes = Literal["classical", "dv2_nearest", "dv2_bilinear", "strided", "featup_jbu", "lift", "hr_vit"]

N CPUS: 110


  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


TODO:
- test ni supealloy/cracks (i.e small tertiary phases)
    - cracks is good: it's where we need to align the notions of particleness from deep feats with classical feats 
- overlay onto image


In [2]:
normal_dv2 = get_dv2_model(fit_3d=False, device=DEVICE)
dv2 = get_dv2_model(True, device=DEVICE)

model_path = "../trained_models/e5000_full_fit_reg.pth"
cfg_path = "../yoeo/models/configs/combined_no_shift.json"

upsampler, expr = get_upsampler_and_expr(model_path, cfg_path, device=DEVICE)

Using cache found in /home/ronan/.cache/torch/hub/facebookresearch_dinov2_main
Using cache found in /home/ronan/.cache/torch/hub/ywyue_FiT3D_main


In [3]:
featup_jbu = torch.hub.load("mhamilton723/FeatUp", "dinov2", use_norm=True).to(DEVICE).eval()

Using cache found in /home/ronan/.cache/torch/hub/mhamilton723_FeatUp_main
Using cache found in /home/ronan/.cache/torch/hub/facebookresearch_dinov2_main


In [4]:
strided = StridedDv2('dinov2', 'vits14_reg', 1).to(DEVICE).eval()
strided.model = add_flash_attention(strided.model)

Using cache found in /home/ronan/.cache/torch/hub/facebookresearch_dinov2_main


In [5]:
lift_path = "../trained_models/lift/lift_dino_vits8.pth"
lift = ViTLiFTExtractor('dino_vits8', lift_path=lift_path, channel=384, facet='key', device=DEVICE).eval()

Using cache found in /home/ronan/.cache/torch/hub/facebookresearch_dino_main


Loaded Backbone: dino_vits8
Loaded LiFT module from: ../trained_models/lift/lift_dino_vits8.pth


In [6]:
@torch.no_grad()
def _our_featurise(img: Image.Image, dv2: torch.nn.Module, upsampler: torch.nn.Module, n_ch_in: int) -> np.ndarray:
    hr_feats = get_hr_feats(img, dv2, upsampler, DEVICE, n_ch_in=n_ch_in)
    hr_feats_np = to_numpy(hr_feats)
    reduced_hr = hr_feats_np
    return reduced_hr

@torch.no_grad()
def _original_featurise(img: Image.Image, dv2: torch.nn.Module, resize: str | None="nearest") -> np.ndarray:
    # _img = Image.open(path).convert("RGB")
    _h, _w = img.height, img.width
    tr = closest_crop(_h, _w)

    tensor = convert_image(img, tr, device_str=DEVICE)
    _, _, h, w = tensor.shape
    with torch.autocast("cuda", torch.float16):
        dino_feats = dv2.forward_features(tensor)['x_norm_patchtokens']
    n_patch_w, n_patch_h = w // 14, h // 14
    dino_feats = dino_feats.permute((0, 2, 1))
    dino_feats = dino_feats.reshape((1, -1, n_patch_h, n_patch_w,))
    if resize == "nearest":
        dino_feats = F.interpolate(dino_feats, (_h, _w), mode='nearest')
    elif resize == "bilinear":
        dino_feats = F.interpolate(dino_feats, (_h, _w), mode='bilinear')
    dino_feats_np = to_numpy(dino_feats)
    return dino_feats_np

@torch.no_grad()
def _jbu_featurise(img: Image.Image, jbu: torch.nn.Module) -> np.ndarray:
    # _img = Image.open(path).convert("RGB")
    _h, _w = img.height, img.width
    tr = closest_crop(_h, _w)

    tensor = convert_image(img, tr, device_str=DEVICE)
    with torch.autocast("cuda", torch.float16):
        jbu_feats = jbu(tensor.to(torch.float32))
    jbu_feats = F.interpolate(jbu_feats, (_h, _w))
    jbu_feats_np = to_numpy(jbu_feats)
    return jbu_feats_np

@torch.no_grad()
def _lift_featurise(path: str, img: Image.Image, lift: ViTLiFTExtractor, n: int=3, patch_size: int=8) -> np.ndarray:
    def closest_mult(x: int, p: int=8) -> int:
        return x - (x % p)
    tr_h, tr_w = closest_mult(img.height, 8), closest_mult(img.width, 8)
    _h, _w = img.height, img.width

    image_batch, _ = lift.preprocess(path, (tr_h, tr_w))
    image_batch = image_batch.to(DEVICE)

    with torch.autocast("cuda", torch.float16):
        lift_feats = lift.extract_descriptors_iterative_lift(image_batch, lift_iter=n)
    _, _, c = lift_feats.shape
    sf = int(patch_size / (2**n))
    reshaped = lift_feats.squeeze(0).T.reshape((1, c, tr_h // sf, tr_w //  sf))
    print(reshaped.shape)
    resized = F.interpolate(reshaped, (_h, _w))
    return to_numpy(resized)

@torch.no_grad()
def _strided_featurise(img: Image.Image, strided: StridedDv2) -> np.ndarray:
    # _img = Image.open(path).convert("RGB")
    _h, _w = img.height, img.width
    tr = closest_crop(_h, _w)

    tensor = convert_image(img, tr, device_str=DEVICE)
    with torch.autocast("cuda", torch.float16):
        strided_feats = strided(tensor)
    strided_feats = F.interpolate(strided_feats, (_h, _w))
    strided_feats_feats_np = to_numpy(strided_feats)
    return strided_feats_feats_np

In [7]:
PATH = "fig_data/supp_upsampler_choice"
tc = TrainingConfig(FeatureConfig(), classifier='xgb', classifier_params={'class_weight': 'balanced'})
# img_path = f"{PATH}/img_patch14.tif"
img_path = f"{PATH}/noisy_NMC_cracks.png"
img = Image.fromarray(load_image(f"{img_path}")).convert('RGB')
# labels = load_labels(f"{PATH}/labels_patch14_more.tiff")
labels = load_labels(f"{PATH}/noisy_nmc_labels_more.tiff")

In [8]:
def get_features(img: Image.Image, feat_type: FeatureTypes) -> tuple[np.ndarray, np.ndarray]:
    classical_feats = featurise(np.array(img), tc)
    if feat_type == "classical":
        return (classical_feats, classical_feats)
    elif feat_type == "dv2_nearest":
        deep_feats = _original_featurise(img, normal_dv2, 'nearest')
    elif feat_type == "dv2_bilinear":
        deep_feats = _original_featurise(img, normal_dv2, 'bilinear')
    elif feat_type == "strided":
        deep_feats = _strided_featurise(img, strided)
    elif feat_type == "featup_jbu":
        deep_feats = _jbu_featurise(img, featup_jbu)
    elif feat_type == "lift":
        deep_feats = _lift_featurise(img_path, img, lift)
    elif feat_type == "hr_vit":
        deep_feats = _our_featurise(img, dv2, upsampler, expr.n_ch_in)
    
    deep_feats = np.transpose(deep_feats, (1, 2, 0))
    return (classical_feats, deep_feats)

def rescale(arr: np.ndarray, swap_channels: bool=True) -> np.ndarray:
    if swap_channels:
        arr = np.transpose(arr, (1, 2, 0))
    h, w, c = arr.shape
    flat = arr.reshape((h * w, c))
    rescaled_flat = minmax_scale(flat)
    return rescaled_flat.reshape((h, w, c))

def get_feat_vis(feats: np.ndarray, feat_type: FeatureTypes) -> np.ndarray:
    if feat_type in ("hr_vit", ):
        return rescale(feats[:, :, :3], swap_channels=False)
    else:
        tmp_transpose = feats.transpose((2, 0, 1))
        reduced = do_2D_pca(tmp_transpose, 3)
        reshaped = reduced.transpose((0, 1, 2))
        return rescale(reshaped, swap_channels=False)

In [9]:
features = ("classical", "dv2_nearest", "dv2_bilinear", "strided", "featup_jbu", "lift", "hr_vit")
res: dict[FeatureTypes, dict] = {}

for feat_type in features:
    print(feat_type)
    res[feat_type] = {}
    classical, deep = get_features(img, feat_type)
    feats = np.concatenate((classical, deep), -1)
    # feats = deep
    res[feat_type]["feat_vis"] = get_feat_vis(deep, feat_type)

    pred, _, _ = train_and_apply(feats, labels, tc)
    res[feat_type]["pred"] = pred

classical
dv2_nearest
dv2_bilinear
strided
featup_jbu
lift
torch.Size([1, 384, 576, 680])
hr_vit


In [10]:
# color_list = [[255, 255, 255], [31, 119, 180], [255, 127, 14], [44, 160, 44], [255, 0, 0]]
color_list = [[255, 255, 255], [0, 62, 131], [181, 209, 204], [250, 43, 0], [255, 184, 82]]
COLORS = np.array(color_list) / 255.0

In [11]:
def apply_labels_as_overlay(labels: np.ndarray, img: Image.Image, colors: list, alpha: float=1.0) -> Image.Image:
    labels_unsqueezed = np.expand_dims(labels, -1)

    overlay = label2rgb(labels, colors=colors[1:], kind='overlay', bg_label=0, image_alpha=1, alpha=alpha)
    out = np.where(labels_unsqueezed, overlay * 255, np.array(img)).astype(np.uint8)
    img_with_labels = Image.fromarray(out)
    return img_with_labels

In [12]:
def add_inset_zoom(xywh: list[int], fig_xywh: list[float], img_arr: np.ndarray, labels: np.ndarray | None, ax ) -> object:
    x0, y0, w, h = xywh
    H, W, C = img_arr.shape
    inset_data = np.zeros_like(img_arr)
    inset_data[y0:y0+h, x0:x0+w, :] = img_arr[y0:y0+h, x0:x0+w, :]

    axin = ax.inset_axes(
        fig_xywh, xlim=(x0, x0+w), ylim=(y0, y0+h))
    axin.set_xticks([])
    axin.set_yticks([])
    #axin.set_axis_off()
    if labels is not None:
        inset_data = label2rgb(labels, img_arr, COLORS[1:], kind='overlay', alpha=1, bg_label=-1)
        axin.imshow(inset_data,)
    else:
        axin.imshow(inset_data, cmap="binary_r",) # cmap="binary_r"
    ax.indicate_inset_zoom(axin, edgecolor="black", lw=2)
    axin.set_ylim((y0 + h, y0))

    axin.patch.set_edgecolor('black')  

    axin.patch.set_linewidth(4)  

    return axin

In [13]:
%%capture
from matplotlib.gridspec import GridSpec

plt.rcParams["font.family"] = "serif"

TITLE_FS = 25
LABEL_FS = 23
TICK_FS = 21

width = 4
height = 4
N = len(features) 
#figsize=(width * 2 * 2, height * (N // 2))

half_N = (N // 2)
nrows = 1 + half_N
ncols = 4
fig = plt.figure(figsize=(width * ncols, height * nrows ))
gs = GridSpec(nrows, ncols, figure=fig)


zoom_1_data_loc, zoom_2_data_loc = [140, 310, 90, 90], [140, 20, 120, 120]
zoom_1_fig_loc, zoom_2_fig_loc = [-0.2, 0.0, 0.3, 0.3], [0.8, 0.6, 0.4, 0.4]


ax_top = fig.add_subplot(gs[0, :2])  # Span across both column
img_with_labels = apply_labels_as_overlay(labels, img, colors=COLORS)
add_inset_zoom(zoom_1_data_loc, zoom_1_fig_loc, np.array(img_with_labels), None, ax=ax_top)
add_inset_zoom(zoom_2_data_loc, zoom_2_fig_loc, np.array(img_with_labels), None, ax=ax_top)

ax_top.set_title("Image + labels", fontsize=TITLE_FS)
ax_top.imshow(img_with_labels)
ax_top.set_axis_off()

titles = ["Classical", "+Dv2 nearest", "+Dv2 bilinear", "+Strided", "+Featup (JBU)", "+LiFT", "+HR ViT"]
for i, (feat, subdict) in enumerate(res.items()):
    # if feat in ("strided", "lift", "featup_jbu"):
    #     continue
    if feat == "classical":
        row, col_1, col_2 = 0, 2, 3
    else:
        row =  (i - 1) % half_N + 1
        col_1 = 2 * ((i - 1) // half_N)
        col_2 = 2 * ((i - 1) // half_N) + 1

    feat_ax = fig.add_subplot(gs[row, col_1])
    feat_ax.set_title(titles[i], fontsize=TITLE_FS)
    feat_ax.imshow(subdict["feat_vis"])
    pred_ax = fig.add_subplot(gs[row, col_2])

    pred_recoloured = label2rgb(subdict["pred"] + 1, colors=COLORS[1:])
    pred_ax.imshow(pred_recoloured)

    add_inset_zoom(zoom_1_data_loc, zoom_1_fig_loc, pred_recoloured, None, ax=pred_ax)
    add_inset_zoom(zoom_2_data_loc, zoom_2_fig_loc, pred_recoloured, None, ax=pred_ax)

    for ax in (feat_ax, pred_ax):
        ax.set_axis_off()


plt.tight_layout()
plt.savefig('fig_out/supp_upsampler_choice.png', bbox_inches='tight')