Comparison over:
    - different upsamplers
    - different classifiers
for the benchmark images.
Two parts: compute all results first, then plot them.

In [1]:
import numpy as np
from tifffile import imread
from PIL import Image
import matplotlib.pyplot as plt
from skimage.color import label2rgb
from skimage.transform import resize

from hr_dv2.wss.featurisers import get_featuriser_classifier
from hr_dv2.wss.train import get_classifier, train_model, apply_model

  from .autonotebook import tqdm as notebook_tqdm


N CPUS: 110


In [2]:
def load_image(image_path: str) -> Image.Image:
    """Load an image file as a PIL image, choosing the reader from its extension.

    ``.jpg``/``.jpeg``/``.png`` files (any case) are opened with Pillow and
    converted to RGB; ``.tif``/``.tiff`` files are read with ``tifffile.imread``
    and wrapped with ``Image.fromarray`` without any mode conversion.

    Args:
        image_path: path to the image; only the text after the last ``.`` is
            used to pick the reader.

    Returns:
        The loaded image.

    Raises:
        ValueError: if the extension is not one of the supported formats.
    """
    extension = image_path.split('.')[-1].lower()
    if extension in ('jpg', 'jpeg', 'png'):
        # Image.open is lazy and keeps the file handle open; convert() loads the
        # pixels into a new image, so the source can be closed right away.
        with Image.open(image_path) as src:
            return src.convert('RGB')
    if extension in ('tiff', 'tif'):
        return Image.fromarray(imread(image_path))
    raise ValueError(f"Unsupported image format: {extension}")

def tiff_to_labels(tiff: np.ndarray, rev: bool=False) -> np.ndarray:
    """Relabel an array so its distinct values become consecutive ids 0..k-1.

    The k distinct values of ``tiff`` are ranked in ascending order (descending
    when ``rev`` is True), and every element is replaced by the rank of its
    value. If the first axis has length 1 it is dropped first.

    Args:
        tiff: label array, e.g. as returned by ``tifffile.imread``.
        rev: rank values in descending instead of ascending order.

    Returns:
        Array with the (possibly squeezed) shape of ``tiff`` and its dtype.
    """
    if tiff.shape[0] == 1:
        tiff = tiff[0]
    # One sort gives every element's rank directly, instead of one full-array
    # np.where pass per distinct value.
    vals, inverse = np.unique(tiff, return_inverse=True)
    # reshape: older/newer NumPy versions differ on whether inverse is flat.
    ranks = inverse.reshape(tiff.shape)
    if rev:
        ranks = (len(vals) - 1) - ranks
    return ranks.astype(tiff.dtype, copy=False)

def resize_longest_side(img: Image.Image, l: int, patch_size: int = 14) -> Image.Image:
    """Rescale ``img`` so its longest side is ~``l`` px and both sides are patch multiples.

    Each side is multiplied by ``l / max(height, width)``, rounded half-up to an
    integer, then rounded down to the nearest multiple of ``patch_size``. The
    aspect ratio is therefore only approximately preserved.
    """
    factor = l * 1.0 / max(img.height, img.width)
    target = []
    for side in (img.width, img.height):
        px = int(side * factor + 0.5)
        target.append(px - px % patch_size)
    # PIL expects (width, height).
    return img.resize(tuple(target))


def resize_longest_side_arr(arr: np.ndarray, l: int, patch_size: int = 14) -> np.ndarray:
    """Array counterpart of ``resize_longest_side``, intended for label maps.

    The first two axes are scaled by ``l / max(height, width)``, rounded half-up,
    then rounded down to a multiple of ``patch_size``. Resizing uses
    nearest-neighbour interpolation (``order=0``) without anti-aliasing, so no
    new label values are created by blending; ``preserve_range=True`` keeps the
    original value range instead of rescaling to [0, 1].

    Returns:
        The resized array from ``skimage.transform.resize``.
        NOTE(review): skimage may return a float dtype even for integer labels —
        confirm for the installed version.
    """
    oldh, oldw = arr.shape[:2]
    scale = l * 1.0 / max(oldh, oldw)
    neww = int(oldw * scale + 0.5)
    newh = int(oldh * scale + 0.5)
    neww -= neww % patch_size
    newh -= newh % patch_size

    return resize(arr, (newh, neww), order=0, anti_aliasing=False, clip=True, preserve_range=True)

In [3]:
folder_names = [
    "nmc_cathode",
    "sic_barrier_coating",
    "micrograph_394",
    "glutamic_acid",
    "nickel_superalloy",
    "gold_nanoparticles",
]

# Benchmark images and their label maps, each resized so the longest side is
# 518 px (a multiple of the default 14 px patch size).
images = [resize_longest_side(load_image(f"data/wss/{name}/image.png"), 518) for name in folder_names]
labels = [
    resize_longest_side_arr(tiff_to_labels(imread(f"data/wss/{name}/labels.tiff")), 518)
    for name in folder_names
]

In [4]:
# Featurisers compared in the hybrid ablation.
# Alternative set: featurisers = ["bilinear", "FeatUp", "DINOv2-S-14"]
featurisers = [
    "hybrid_featup",
    "hybrid",
]

In [5]:
%%capture
all_results = {}
for i, (img, label) in enumerate(zip(images, labels)):
    img_name = folder_names[i]
    print(f"Processing {img_name}")
    all_results[img_name] = {}

    for featuriser in featurisers:
        print(f"Featuriser: {featuriser}")
        classifier = get_classifier("rf")
        trs = "both" if featuriser in ("hybrid", "DINOv2-S-14") else None
        model = get_featuriser_classifier(featuriser, None, None, trs)
        print(model)
        
        image_arr = np.array(img)
        label_arr = np.array(label)
        features = model.img_to_features(img)
        print(features.shape)
        # Train the model
        classifier = train_model(classifier, [features], [label])
        
        # Apply the model to the image
        predictions = apply_model(classifier, features)

        all_results[img_name][featuriser] = predictions


===============PLOTTING===============

In [6]:
def add_inset_zoom(xywh: list[int], fig_xywh: list[float], img_arr: np.ndarray, labels: np.ndarray | None, ax, *, colors: np.ndarray | None = None) -> object:
    """Add a zoomed inset of the pixel region ``xywh`` of ``img_arr`` to ``ax``.

    Args:
        xywh: ``(x0, y0, w, h)`` of the region to magnify, in pixel coordinates.
        fig_xywh: inset bounds passed to ``ax.inset_axes`` (axes-fraction coordinates).
        img_arr: image as an (H, W) or (H, W, C) array.
        labels: optional label map; if given, the inset shows it overlaid on
            ``img_arr`` with ``label2rgb`` instead of the raw image.
        ax: matplotlib Axes that receives the inset.
        colors: overlay colours for ``labels``; defaults to the module-level
            ``COLORS[1:]``.

    Returns:
        The inset Axes.
    """
    x0, y0, w, h = xywh
    rows, cols = slice(y0, y0 + h), slice(x0, x0 + w)

    axin = ax.inset_axes(
        fig_xywh, xlim=(x0, x0 + w), ylim=(y0, y0 + h), zorder=10)
    axin.set_xticks([])
    axin.set_yticks([])
    if labels is not None:
        if colors is None:
            colors = COLORS[1:]
        inset_data = label2rgb(labels, img_arr, colors, kind='overlay', alpha=0.6, bg_label=-1)
        axin.imshow(inset_data)
    else:
        # Blank everything outside the region. Indexing only the first two axes
        # works for both grayscale (H, W) and colour (H, W, C) arrays.
        inset_data = np.zeros_like(img_arr)
        inset_data[rows, cols] = img_arr[rows, cols]
        # cmap only takes effect for 2-D (grayscale) data.
        axin.imshow(inset_data, cmap="binary_r")
    ax.indicate_inset_zoom(axin, edgecolor="black", lw=2)
    # Flip y so the inset uses image orientation (row 0 at the top).
    axin.set_ylim((y0 + h, y0))

    axin.patch.set_edgecolor('black')
    axin.patch.set_linewidth(4)

    return axin

def add_labels_overlay(img: Image.Image, labels: np.ndarray, colors: np.ndarray, alpha: float=0.5) -> np.ndarray:
    """Colour the labelled pixels of ``img`` and leave label-0 pixels as in the original.

    ``label2rgb`` (``kind='overlay'``, ``bg_label=0``) blends the class colours
    over a grayscale version of the image. Pixels whose label is 0 are then
    restored to the original image values, scaled by 1/255.

    Args:
        img: image to annotate; grayscale and RGB(A) inputs are accepted.
        labels: label map with the same height and width as ``img``.
        colors: colours used for labels 1, 2, ... by ``label2rgb``.
        alpha: opacity of the label colours.

    Returns:
        Float array of shape (H, W, 3).
    """
    img_arr = np.array(img)
    overlay = label2rgb(labels, img_arr, colors, kind='overlay', alpha=alpha, bg_label=0)
    # NOTE(review): assumes 8-bit image data; 16-bit TIFFs would exceed 1.0 — confirm.
    base = img_arr / 255.0
    if base.ndim == 2:
        # Grayscale: replicate to three channels so it broadcasts against the RGB overlay.
        base = np.stack([base] * 3, axis=-1)
    else:
        # Drop an alpha channel if present.
        base = base[..., :3]
    background = np.expand_dims(labels, -1) == 0
    return np.where(background, base, overlay)

def centre_crop(img_or_arr: np.ndarray |  Image.Image, cw: int, ch: int) -> np.ndarray | Image.Image:
    """Cut a ``cw`` x ``ch`` window out of the centre of an array or PIL image.

    Arrays are sliced on their first two (row, column) axes, so both 2-D label
    maps and (H, W, C) images work; the result is a view of the input. PIL
    images are cropped with ``Image.crop``. Offsets are floored, so an odd
    leftover margin ends up on the bottom/right.

    Args:
        img_or_arr: image or array to crop.
        cw: crop width in pixels.
        ch: crop height in pixels.

    Returns:
        The cropped array (view) or PIL image.
    """
    if isinstance(img_or_arr, np.ndarray):
        ih, iw = img_or_arr.shape[:2]
        oy, ox = (ih - ch) // 2, (iw - cw) // 2
        return img_or_arr[oy:oy + ch, ox:ox + cw]

    ih, iw = img_or_arr.height, img_or_arr.width
    oy, ox = (ih - ch) // 2, (iw - cw) // 2
    return img_or_arr.crop((ox, oy, ox + cw, oy + ch))

def swap_label_vals(arr: np.ndarray, val_0: int, val_1: int, swap_val: int = 100) -> np.ndarray:
    """Return a copy of ``arr`` in which the values ``val_0`` and ``val_1`` are exchanged.

    Both masks are computed from the unmodified input before anything is
    written, so no placeholder value is needed. A placeholder would collide with
    real pixels that already hold it. ``swap_val`` is accepted for backward
    compatibility only and is ignored.

    Args:
        arr: label array (it is not modified).
        val_0: first value to swap.
        val_1: second value to swap.
        swap_val: unused.

    Returns:
        New array with the same shape and dtype as ``arr``.
    """
    arr = np.asarray(arr)
    is_0 = arr == val_0
    is_1 = arr == val_1
    out = arr.copy()
    out[is_0] = val_1
    out[is_1] = val_0
    return out


In [7]:
# Plot palettes. `cmap` is a hex palette; `color_list` / `COLORS` is the RGB
# palette used for the label overlays (as COLORS[1:]). The two differ in their
# first entry (#fafafa vs white) and last entry (#d62728 vs pure red).
cmap = [
    "#fafafa",
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
]
color_list = [
    [255, 255, 255],
    [31, 119, 180],
    [255, 127, 14],
    [44, 160, 44],
    [255, 0, 0],
]
# RGB palette scaled to [0, 1], shape (5, 3).
COLORS = np.divide(color_list, 255.0)

In [8]:
%%capture
plt.rcParams["font.family"] = "serif"
N_ROWS = len(images)
N_EXAMPLES = len(featurisers)
N_COLS = 1 + N_EXAMPLES

titles = ["NMC cathode", "SiC barrier coating", "Micrograph 394", "Glutamic Acid", "Nickel Superalloy", "Gold nanoparticles"]
# row_titles = ["User Labels", "Bilinear (DINOv2-S-14)", "FeatUp", "Ours (DINOv2-S-14)"]
row_titles = ["User Labels", "FeatUp (Hybrid)", "Ours (Hybrid)"]

titles = ["NMC cathode", "SiC barrier coating", "Micrograph 394", "Glutamic Acid", "Nickel Superalloy", "Gold nanoparticles"]

bat_bbox = [110, 20, 50, 50]
ebc_bbox = [370, 820 // 2, 50, 50]
nmc_bbox = [90, 20, 60, 50]
glut_bbox = [240, 250, 50, 50]
ni_bbox = [330, 360, 50, 50]
tem_bbox =  [420, 320, 50, 50]
bboxes = [bat_bbox, ebc_bbox, nmc_bbox, glut_bbox, ni_bbox, tem_bbox]

fig_bbox = [-0.0, 0.0, 0.4, 0.4]

fig, axs = plt.subplots(N_ROWS, N_COLS, figsize=(N_COLS * 4.5, N_ROWS * 4.5))
for row in range(N_ROWS):
    img_name = folder_names[row]
    img = images[row]
    label = labels[row]

    if img_name == "micrograph_394":
        h, w = label.shape
        img = centre_crop(img, h, h)
        label = centre_crop(label, h, h)
    elif img_name == "nickel_superalloy":
        label = swap_label_vals(label, 2, 3)

    # img_with_labels = label2rgb(labels[row], np.array(images[row]), COLORS[1:], kind='overlay', alpha=0.2, bg_label=0)
    img_with_labels = add_labels_overlay(img, label, COLORS[1:], alpha=1)
    axs[row, 0].imshow(img_with_labels)
    img_name = folder_names[row]

    for col in range(N_COLS):
        ax = axs[row, col]
        if row == 0:
            ax.set_title(row_titles[col], fontsize=28)
        ax.set_xticks([])
        ax.set_yticks([])

        img_ax = axs[row, 0]

        # inset_ax = add_inset_zoom(bboxes[row], fig_bbox, np.array(img), None, img_ax )
        img_ax.set_ylabel(titles[row], fontsize=28)

        if col == 0:
            continue

        featuriser = featurisers[col - 1]
        predictions = all_results[img_name][featuriser]
        if img_name == "micrograph_394":
            h, w = predictions.shape
            pred = centre_crop(predictions, h, h)
        elif img_name == "nickel_superalloy":
            pred = swap_label_vals(predictions, 2, 3)
        else:
            img = images[row]
            pred = predictions
        
        
        img_with_labels = add_labels_overlay(img, pred, COLORS[1:], alpha=0.7)
        ax.imshow(img_with_labels)
plt.tight_layout()
plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, 
            hspace = -0.20, wspace = 0.1)
# plt.savefig("../fig_out/upsampler_ablation__.png", bbox_inches='tight')
plt.savefig("../fig_out/hybrid_ablation.png", bbox_inches='tight')

In [9]:
folder_names = [
    "nmc_cathode",
    "sic_barrier_coating",
    "micrograph_394",
    "glutamic_acid",
    "nickel_superalloy",
    "gold_nanoparticles",
]

# Reload the benchmark images and label maps, each resized so the longest side
# is 518 px (a multiple of the default 14 px patch size).
images = [resize_longest_side(load_image(f"data/wss/{name}/image.png"), 518) for name in folder_names]
labels = [
    resize_longest_side_arr(tiff_to_labels(imread(f"data/wss/{name}/labels.tiff")), 518)
    for name in folder_names
]

In [10]:
# Featurisers for the classical-vs-learned comparison figure
# ("weka" is shown under the "Classical Features" column).
featurisers = [
    "weka",
    "DINOv2-S-14",
    "hybrid",
]

In [11]:
%%capture
all_results = {}
for i, (img, label) in enumerate(zip(images, labels)):
    img_name = folder_names[i]
    print(f"Processing {img_name}")
    all_results[img_name] = {}


    for featuriser in featurisers:
        print(f"Featuriser: {featuriser}")
        classifier = get_classifier("rf")
        trs = "both" if featuriser in ("hybrid", "DINOv2-S-14") else None
        model = get_featuriser_classifier(featuriser, None, None, trs)
        
        image_arr = np.array(img)
        label_arr = np.array(label)
        features = model.img_to_features(img)
        print(features.shape)
        # Train the model
        classifier = train_model(classifier, [features], [label])
        
        # Apply the model to the image
        predictions = apply_model(classifier, features)

        all_results[img_name][featuriser] = predictions

In [12]:
%%capture
plt.rcParams["font.family"] = "serif"
N_ROWS = len(images)
N_EXAMPLES = len(featurisers)
N_COLS = 1 + N_EXAMPLES

titles = ["NMC cathode", "SiC barrier coating", "Micrograph 394", "Glutamic Acid", "Nickel Superalloy", "Gold nanoparticles"]
row_titles = ["User Labels", "Classical Features", "Ours (DINOv2-S-14)", "Ours (Hybrid)"]

bat_bbox = [110, 20, 50, 50]
ebc_bbox = [370, 820 // 2, 50, 50]
nmc_bbox = [90, 20, 60, 50]
glut_bbox = [240, 250, 50, 50]
ni_bbox = [330, 360, 50, 50]
tem_bbox =  [420, 320, 50, 50]
bboxes = [bat_bbox, ebc_bbox, nmc_bbox, glut_bbox, ni_bbox, tem_bbox]

fig_bbox = [-0.0, 0.0, 0.4, 0.4]

fig, axs = plt.subplots(N_ROWS, N_COLS, figsize=(N_COLS * 4.5, N_ROWS * 4.5))
for row in range(N_ROWS):
    img_name = folder_names[row]
    img = images[row]
    label = labels[row]

    if img_name == "micrograph_394":
        h, w = label.shape
        img = centre_crop(img, h, h)
        pred = centre_crop(predictions, h, h)
        label = centre_crop(label, h, h)
    elif img_name == "nickel_superalloy":
        label = swap_label_vals(label, 2, 3)

    img_with_labels = add_labels_overlay(img, label, COLORS[1:], alpha=0.2)
    img_ax = axs[row, 0]
    img_ax.imshow(img_with_labels)

    inset_ax = add_inset_zoom(bboxes[row], fig_bbox, np.array(img), None, img_ax )
    img_ax.set_ylabel(titles[row], fontsize=28)
    
    for col in range(N_COLS):
        ax = axs[row, col]
        # ax.set_axis_off()
        if row == 0:
            ax.set_title(row_titles[col], fontsize=28)
        ax.set_xticks([])
        ax.set_yticks([])

        if col == 0:
            continue
        
        featuriser = featurisers[col - 1]
        predictions = all_results[img_name][featuriser]

        if img_name == "micrograph_394":
            h, w = predictions.shape
            pred = centre_crop(predictions, h, h)
        elif img_name == "nickel_superalloy":
            pred = swap_label_vals(predictions, 2, 3)
        else:
            img = images[row]
            pred = predictions

        img_with_labels = add_labels_overlay(img, pred, COLORS[1:], alpha=0.7)
        ax.imshow(img_with_labels)
        inset_ax = add_inset_zoom(bboxes[row], fig_bbox, np.array(img), pred, ax )

plt.tight_layout()
plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, 
            hspace = -0.20, wspace = 0.1)
plt.savefig("../fig_out/wss_classical_hybrid.png", bbox_inches='tight')