In [12]:
from pathlib import Path

import numpy as np
from tifffile import imread
from PIL import Image
import matplotlib.pyplot as plt
from skimage.color import label2rgb
from skimage.transform import resize
import matplotlib.gridspec as gridspec

from hr_dv2.wss.featurisers import get_featuriser_classifier
from hr_dv2.wss.train import Classifiers, get_classifier, train_model, apply_model

In [2]:
def load_image(image_path: str) -> Image.Image:
    """Load an image file as a PIL image.

    JPEG/PNG files are decoded and converted to 3-channel RGB. TIFF files are
    read with ``tifffile.imread`` and wrapped with ``Image.fromarray`` without
    any mode conversion, so their mode follows the stored array.

    Args:
        image_path: path to the image file.

    Returns:
        The loaded PIL image.

    Raises:
        ValueError: if the file extension is not jpg/jpeg/png/tif/tiff.
    """
    # Use the real file suffix so dots in directory names (e.g. "data.v2/img")
    # are not mistaken for the extension.
    extension = Path(image_path).suffix.lstrip('.').lower()
    if extension in ('jpg', 'jpeg', 'png'):
        # Image.open is lazy and holds the file open; convert() returns a new,
        # fully-loaded image, so the source can be closed immediately.
        with Image.open(image_path) as src:
            return src.convert('RGB')
    if extension in ('tiff', 'tif'):
        return Image.fromarray(imread(image_path))
    raise ValueError(f"Unsupported image format: {extension}")

def tiff_to_labels(tiff: np.ndarray, rev: bool=False) -> np.ndarray:
    """Replace each distinct value in a label image by its rank.

    A leading axis of length 1 (e.g. a single-page TIFF stack) is dropped
    first. Distinct values are ranked ascending, or descending when ``rev`` is
    True, and every pixel receives the rank (0, 1, 2, ...) of its value.
    """
    source = tiff[0] if tiff.shape[0] == 1 else tiff
    ordered_vals = np.unique(source)
    if rev:
        ordered_vals = ordered_vals[::-1]
    relabelled = source
    # Masks are always taken against `source`, so earlier replacements can
    # never be confused with original values.
    for rank, value in zip(range(len(ordered_vals)), ordered_vals):
        relabelled = np.where(source == value, rank, relabelled)
    return relabelled

def _patch_aligned_size(oldh: int, oldw: int, l: int, patch_size: int) -> tuple[int, int]:
    """Return ``(newh, neww)``: ``(oldh, oldw)`` scaled so the longer side is ~``l``.

    Each side is rounded to the nearest pixel and then floored to a multiple of
    ``patch_size``. It is never allowed below one patch, because a 0-pixel side
    would make the subsequent resize fail.
    """
    scale = l / max(oldh, oldw)
    newh = int(oldh * scale + 0.5)
    neww = int(oldw * scale + 0.5)
    newh = max(patch_size, newh - newh % patch_size)
    neww = max(patch_size, neww - neww % patch_size)
    return newh, neww


def resize_longest_side(img: Image.Image, l: int, patch_size: int = 14) -> Image.Image:
    """Resize a PIL image so its longest side is ~``l`` and both sides are multiples of ``patch_size``.

    Uses PIL's default resampling filter for ``Image.resize``.
    """
    newh, neww = _patch_aligned_size(img.height, img.width, l, patch_size)
    return img.resize((neww, newh))


def resize_longest_side_arr(arr: np.ndarray, l: int, patch_size: int = 14) -> np.ndarray:
    """Array counterpart of ``resize_longest_side`` using nearest-neighbour sampling.

    Intended for label maps: ``order=0`` with no anti-aliasing never invents new
    values. The first two axes are treated as (height, width), and the result
    has the same dtype as ``arr``.
    """
    oldh, oldw = arr.shape[:2]
    newh, neww = _patch_aligned_size(oldh, oldw, l, patch_size)
    resized = resize(arr, (newh, neww), order=0, anti_aliasing=False, clip=True, preserve_range=True)
    # skimage may return float64 even with preserve_range=True. With order=0
    # every output value is copied from an input pixel, so casting back is
    # lossless and keeps label maps integer-typed.
    return resized.astype(arr.dtype, copy=False)

In [3]:
# Target longest edge in pixels. The resize helpers floor each side to a
# multiple of their patch size (14 by default).
LONGEST_SIDE = 518

folder_names = ["nickel_superalloy"]

images = [load_image(f"data/wss/{folder_name}/image.png") for folder_name in folder_names]
images = [resize_longest_side(im, LONGEST_SIDE) for im in images]
labels = [tiff_to_labels(imread(f"data/wss/{folder_name}/labels.tiff")) for folder_name in folder_names]
labels = [resize_longest_side_arr(label, LONGEST_SIDE) for label in labels]

# Image and label map must line up pixel-for-pixel for training and overlays.
for _name, _im, _lab in zip(folder_names, images, labels):
    assert (_im.height, _im.width) == tuple(_lab.shape[:2]), f"{_name}: image {_im.size} vs labels {_lab.shape}"

In [4]:
# Feature extractors compared in the ablation (one figure row each).
featurisers = [
    "weka",
    "FeatUp",
    "hybrid",
]
# Classifier heads trained on each featuriser's output (one figure column each).
classifiers: list[Classifiers] = [
    'linear',
    'logistic',
    'rf',
    'mlp',
]

In [5]:
%%capture
all_results = {}
for i, (img, label) in enumerate(zip(images, labels)):
    img_name = folder_names[i]
    print(f"Processing {img_name}")
    all_results[img_name] = {}

    for featuriser in featurisers:
        all_results[img_name][featuriser] = {}
        print(f"Featuriser: {featuriser}")
        for classifier_name in classifiers:
            classifier = get_classifier(classifier_name)
            trs = "both" if featuriser in ("hybrid", "DINOv2-S-14") else None
            model = get_featuriser_classifier(featuriser, None, None, trs)
            
            image_arr = np.array(img)
            label_arr = np.array(label)
            features = model.img_to_features(img)
            # Train the model
            classifier = train_model(classifier, [features], [label])
            
            # Apply the model to the image
            predictions = apply_model(classifier, features)

            all_results[img_name][featuriser][classifier_name] = predictions

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


In [111]:
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset

def add_inset_zoom(xywh: list[int], fig_xywh: list[float], img_arr: np.ndarray, labels: np.ndarray | None, ax) -> object:
    """Add a zoomed-in inset of the region ``xywh`` of ``img_arr`` to ``ax``.

    Args:
        xywh: zoom window ``[x0, y0, w, h]`` in pixel (column, row) coordinates.
        fig_xywh: unused. The inset placement is fixed below (35% of ``ax``,
            ``bbox_to_anchor=(0, 0, 1.15, 1)``); kept for call-site compatibility.
        img_arr: image shown in the inset, either (H, W) or (H, W, C).
        labels: optional label map. When given, the inset shows ``label2rgb`` of
            ``labels`` over ``img_arr`` using the module-level ``COLORS[1:]``
            palette instead of the raw image.
        ax: parent axes holding the full image. ``mark_inset`` draws the zoom
            rectangle and connector lines on it.

    Returns:
        The inset axes.
    """
    x0, y0, w, h = xywh
    # Keep only the zoom window and zero everything else. Indexing only the
    # first two axes works for both (H, W) and (H, W, C) arrays.
    inset_data = np.zeros_like(img_arr)
    inset_data[y0:y0 + h, x0:x0 + w] = img_arr[y0:y0 + h, x0:x0 + w]

    axin = inset_axes(ax, width="35%", height="35%", bbox_transform=ax.transAxes, bbox_to_anchor=(0, 0, 1.15, 1))
    axin.set_xticks([])
    axin.set_yticks([])
    if labels is not None:
        inset_data = label2rgb(labels, img_arr, COLORS[1:], kind='overlay', alpha=1, bg_label=-1)
        axin.imshow(inset_data)
    else:
        # cmap only has an effect for single-channel (2-D) data.
        axin.imshow(inset_data, cmap="binary_r")
    axin.set_xlim((x0, x0 + w))
    # imshow defaults to origin="upper" (row 0 at the top), so the y-limits must
    # run from the larger row index to the smaller one. Passing (y0, y0 + h)
    # would show the inset upside-down relative to the parent image.
    axin.set_ylim((y0 + h, y0))
    mark_inset(ax, axin, loc1=2, loc2=4, fc="none", ec="black", lw=2)

    axin.patch.set_edgecolor('black')
    axin.patch.set_linewidth(4)

    return axin

def add_labels_overlay(img: Image.Image, labels: np.ndarray, colors: np.ndarray, alpha: float=0.5) -> np.ndarray:
    """Paint labelled pixels of ``img`` in solid colours; label-0 pixels keep the image.

    Args:
        img: image to draw on, converted with ``np.array``. It is divided by 255
            for the background pixels, so it is presumably 8-bit RGB — TODO
            confirm for other dtypes.
        labels: label map with the same height/width as ``img``; 0 is background.
        colors: colour cycle passed to ``label2rgb`` for the non-background labels.
        alpha: NOTE(review): currently ignored. The overlay is always drawn with
            ``alpha=1``; left unchanged so existing figures stay identical.

    Returns:
        Float array combining ``label2rgb``'s output with the rescaled image.
    """
    img_arr = np.array(img)
    overlay = label2rgb(labels, img_arr, colors, kind='overlay', alpha=1, bg_label=0)
    # Background pixels show the original colour image, rescaled to [0, 1] to
    # match label2rgb's float output. label2rgb's own background rendering is
    # discarded.
    is_background = np.expand_dims(labels, -1) == 0
    return np.where(is_background, img_arr / 255.0, overlay)

def centre_crop(img_or_arr: "np.ndarray | Image.Image", cw: int, ch: int) -> "np.ndarray | Image.Image":
    """Crop a ``cw`` (width) x ``ch`` (height) window from the centre of an array or PIL image.

    Arrays are indexed as (rows, cols, ...), so (H, W) and (H, W, C) inputs are
    both supported and trailing axes are kept whole. If the crop is larger than
    the array along an axis, the offset is clamped to 0 and that axis is
    returned in full. Without the clamp, a negative start index would wrap
    around and return the wrong rows/columns.

    PIL images are cropped with ``Image.crop`` using the centred box.

    Returns:
        The same kind of object as the input.
    """
    if isinstance(img_or_arr, np.ndarray):
        ih, iw = img_or_arr.shape[:2]
        oy = max(0, (ih - ch) // 2)
        ox = max(0, (iw - cw) // 2)
        return img_or_arr[oy:oy + ch, ox:ox + cw]
    img = img_or_arr
    ih, iw = img.height, img.width
    oy, ox = (ih - ch) // 2, (iw - cw) // 2
    return img.crop((ox, oy, ox + cw, oy + ch))

def swap_label_vals(arr: np.ndarray, val_0: int, val_1: int, swap_val: int = 100) -> np.ndarray:
    """Return a copy of ``arr`` with every ``val_0`` replaced by ``val_1`` and vice versa.

    Both masks are computed from the original ``arr``, so no temporary marker
    value is needed and all other values are left untouched.

    Args:
        arr: label array (not modified).
        val_0, val_1: the two values to exchange.
        swap_val: unused; kept for backward compatibility. The previous
            implementation used it as a temporary marker. Because it then looked
            the marker up in the original ``arr``, ``val_0`` pixels came out as
            ``swap_val``, and pixels already equal to ``swap_val`` became ``val_1``.
    """
    return np.where(arr == val_0, val_1, np.where(arr == val_1, val_0, arr))


In [112]:
# Hex palette. NOTE(review): entries 0 and 4 differ from `color_list` below
# ("#fafafa" vs pure white, "#d62728" vs pure red) — confirm which is intended.
cmap = ["#fafafa", "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"]

# RGB palette in 0-255. Index 0 is white; plotting code passes COLORS[1:].
color_list = [
    [255, 255, 255],
    [31, 119, 180],
    [255, 127, 14],
    [44, 160, 44],
    [255, 0, 0],
]
# The same palette as floats in [0, 1], for matplotlib / skimage.
COLORS = np.divide(color_list, 255.0)

In [114]:
%%capture
TITLE_FS = 24
plt.rcParams["font.family"] = "serif"
N_ROWS = len(featurisers) + 1
N_EXAMPLES = len(classifiers)
N_COLS = N_EXAMPLES

data_coords = [10, 410, 75, 75]
fig_coords = [0.75, -0.05, 0.35, 0.35]

# fig, axs = plt.subplots(nrows=N_ROWS, ncols=N_COLS, figsize=(N_COLS * 4.5, N_ROWS * 4.5))
fig = plt.figure(figsize=(N_COLS * 4, N_ROWS * 4))

gs = gridspec.GridSpec(4, 4, height_ratios=[1.15, 1, 1, 1])

img_ax = fig.add_subplot(gs[0, 1:3])
img_with_labels = add_labels_overlay(img, label, COLORS[1:], alpha=0.2)
img_ax.imshow(img_with_labels)
img_ax.set_title("Image + labels", fontsize=TITLE_FS)
img_ax.set_xticks([])
img_ax.set_yticks([])

add_inset_zoom(data_coords, fig_coords, img_with_labels, None, img_ax)  


col_titles = ["Ridge", "Logistic", "Random Forest", "MLP"]
row_titles = ["Classical", "Ours (DINOv2-S-14)", "Ours (Hybrid)"]

img_name = folder_names[0]
for i, featuriser in enumerate(featurisers):
    for j, classifier in enumerate(classifiers):
        ax = fig.add_subplot(gs[i + 1, j]) 
        # ax = axs[i, j]
        if i == 0:
            ax.set_title(col_titles[j], fontsize=TITLE_FS)
        if j == 0:
            ax.set_ylabel(row_titles[i], fontsize=TITLE_FS)
        pred = all_results[img_name][featuriser][classifier]

        img_with_labels = add_labels_overlay(img, pred, COLORS[1:], alpha=0.2)
        add_inset_zoom(data_coords, fig_coords, img_with_labels, None, ax)  

        ax.imshow(img_with_labels)
        ax.set_xticks([])
        ax.set_yticks([])

plt.tight_layout()
plt.savefig("../fig_out/wss_classifier_ablation.png", bbox_inches='tight')