In [10]:
from skimage.color import label2rgb
from scipy.ndimage import zoom
from skimage.transform import resize
import matplotlib.pyplot as plt
from os import getcwd

from tifffile import imread

import numpy as np


# Micrographs shown in the comparison figure (file names all end in ".tiff").
examples = [
    "Frame5_4.6x.tiff",             # lighter
    "Box13_P1_JK_4600x_0118.tiff",  # weird shape
    "Box13_P1_JK_4600x_0103.tiff",
    "Box13_P1_JK_4600x_0099.tiff",
    "Box13_O2_JK_4600x_0004.tiff",  # darker
]

PREFIX = ""
# Suffix appended to the image name when reading from the first ("masks") folder.
# NOTE(review): the spelling "nomalized" presumably mirrors the file names on disk -- confirm before changing.
MASK_SUFFIX = "_segmentation.tifnomalized.tif"

# Folders read by the comparison figure: the masks folder first, then one
# "preds/<method>_crf" folder per method.
_CRF_METHODS = (
    "classical",
    "FeatUp",
    "bilinear_trained_sigma_0",
    "DINO-S-8",
    "DINOv2-S-14",
    "hybrid",
)
PRED_FOLDERS = ["masks"] + [f"preds/{method}_crf" for method in _CRF_METHODS]
# Alternative folder sets kept for reference:
# PRED_FOLDERS = ["masks", "preds/dv2_out", "preds/rf_out_no_crf"]
# PRED_FOLDERS = ["masks", "preds/classical_", "preds/FeatUp_", "preds/bilinear_trained_sigma_0_", "preds/DINO-S-8_", "preds/DINOv2-S-14_", "preds/hybrid_"]
# PRED_FOLDERS = ["masks", "preds/classical_", "preds/hybrid_crf"]

# RGB triples on a 0-255 scale: white, green, orange, blue.
color_list = [
    [255, 255, 255],
    [44, 160, 44],
    [255, 127, 14],
    [31, 119, 180],
]
# The same palette as floats in [0, 1].
COLORS = np.asarray(color_list, dtype=np.float64) / 255.0

def remap_label_arr(arr: np.ndarray) -> np.ndarray:
    """Map evenly spaced greyscale label values onto small integer class indices.

    The grey-level spacing is taken to be the gap between the two largest
    distinct values in ``arr``. Every element is floor-divided by that gap and
    then shifted down by one, e.g. ``{85, 170, 255} -> {0, 1, 2}`` and
    ``0 -> -1``.

    Args:
        arr: Array of greyscale label values.

    Returns:
        A new array with the same shape as ``arr``. Unsigned-integer input is
        promoted to ``int64`` first so that a value of 0 becomes -1 instead of
        wrapping around to the dtype's maximum; other dtypes keep NumPy's
        usual arithmetic result type.

    Raises:
        ValueError: If ``arr`` holds fewer than two distinct values, because
            the spacing between levels cannot be determined.
    """
    arr = np.asarray(arr)
    # Unsigned arithmetic would turn "0 - 1" into 255 (uint8) rather than -1.
    if np.issubdtype(arr.dtype, np.unsignedinteger):
        arr = arr.astype(np.int64)

    # np.unique already returns its values in ascending order.
    unq_vals = np.unique(arr)
    if unq_vals.size < 2:
        raise ValueError(
            "remap_label_arr needs at least two distinct label values to infer the spacing"
        )

    div_val = unq_vals[-1] - unq_vals[-2]
    return (arr // div_val) - 1

In [11]:
%%capture
fig, axs = plt.subplots(nrows=len(PRED_FOLDERS), ncols=len(examples))

fig.set_size_inches(5 * len(examples), 5 * len(PRED_FOLDERS))
plt.rcParams["font.family"] = "serif"
#titles = ["Ground Truth", "Deep Features", "Classical Features"]
titles = ["Ground Truth", "Classical", "FeatUp", "Dv2-S-14 (bilinear)", "Ours (D-S-8)", "Ours (Dv2-S-14)", "Ours (Hybrid)"]
#titles = ["Ground Truth", "Classical", "Ours (Hybrid)"]

for col, fname in enumerate(examples):
    original = imread(f"data/{fname[:-1]}")
    

    for row, pred_folder in enumerate(PRED_FOLDERS):
        ax = axs[row, col]

        if (col == 0):
            ax.set_ylabel(titles[row], fontsize=36)

        suffix = MASK_SUFFIX if row == 0 else ""
        
        if row == 0:
            real_fname = fname[:-1]
        else:
            real_fname = fname

        low_res_data = imread(f"{pred_folder}/{real_fname}{suffix}")
        low_res_data = remap_label_arr(low_res_data)
        
        data = resize(low_res_data, (1024, 1024), preserve_range=True) if row > 0 else low_res_data
        data = data.astype(np.uint8)
        overlay = label2rgb(data, original, colors=COLORS[1:], alpha=0.4, bg_label=-1)
        
        ax.imshow(overlay)
        ax.set_xticks([])
        ax.set_yticks([])
        #ax.set_axis_off()
plt.tight_layout()

name = 'full_crf'
plt.savefig(f"figures/pred_comparison_{name}.png", bbox_inches="tight")

In [3]:
%%capture
train_stack = imread("training_data/train_stack_small.tif")
greyscale_labels = imread("training_data/wss_train_labels.tiff").astype(np.int32)

# print(greyscale_labels.shape, greyscale_labels.dtype)

# color_list = [[255, 255, 255,], [31, 119, 180], [255, 127, 14],  [44, 160, 44]]
# COLORS = np.array(color_list) / 255.0

r, c = 2, 3
fig, axs = plt.subplots(nrows=r, ncols=c)
fig.set_size_inches(24, 16)
for i, arr in enumerate(train_stack):
    x, y = i % c, i // c
    ax = axs[y, x]
    greyscale_label = greyscale_labels[i]
    label = remap_label_arr(greyscale_label)
    print(np.unique(label))
    if i == 5:
        label[0, 0] = 0
    #    label  = np.where(label == 0, 0, labels_arr[-1])
    img = label2rgb(label, arr, colors=COLORS[1:], bg_label=-1, alpha=0.4)
    ax.imshow(img)

    ax.set_axis_off()
plt.tight_layout()
plt.savefig(f"figures/labels.png", bbox_inches="tight")

In [4]:
# Parse the scaling results; the first parsed row is discarded by the slice.
scaling_table = np.genfromtxt("figures/scaling.csv", delimiter=",")
data = scaling_table[1:]

# Columns as used below: 0 -> x, 1 -> dv2 mIoU, 2 -> rf mIoU, 3 -> dv2 std, 4 -> rf std.
x, dv2_miou, rf_miou, dv2_std, rf_std = (data[:, col_idx] for col_idx in range(5))

data

array([[ 1.   ,  0.799,  0.349,  0.136,  0.102],
       [ 2.   ,  0.778,  0.38 ,  0.158,  0.102],
       [ 4.   ,  0.799,  0.393,  0.141,  0.103],
       [ 8.   ,  0.827,  0.398,  0.126,  0.112],
       [16.   ,  0.858,  0.407,  0.115,  0.126]])

In [5]:
def format_plot(
    ax,
    title: str,
    xlabel: str,
    ylabel: str,
    title_fontsize: int,
    label_fontsize: int,
    tick_fontsize: int,
    legend_fontsize: int,
) -> None:
    """Apply title, axis labels, tick sizes and a legend to ``ax``.

    Also sets the global ``font.family`` rcParam to ``"serif"``.

    Args:
        ax: Axes object to format.
        title: Title text.
        xlabel: Label for the x axis.
        ylabel: Label for the y axis.
        title_fontsize: Font size of the title.
        label_fontsize: Font size shared by both axis labels.
        tick_fontsize: Label size for the ticks on both axes.
        legend_fontsize: Font size of the legend entries.
    """
    plt.rcParams.update({"font.family": "serif"})

    ax.set_title(title, fontsize=title_fontsize)
    for apply_label, text in ((ax.set_xlabel, xlabel), (ax.set_ylabel, ylabel)):
        apply_label(text, fontsize=label_fontsize)
    ax.tick_params(axis='both', labelsize=tick_fontsize)
    ax.legend(fontsize=legend_fontsize)

In [6]:
%%capture
fig = plt.figure(2) #axs = plt.subplots(1, 2)
fig.set_size_inches(16, 16)

RED = '#d00000'
BLUE = '#1e74fd'
plt.errorbar(x, dv2_miou, dv2_std, label="Deep Features", marker='.', lw=6, ms=30, ecolor=RED, color=RED)
plt.errorbar(x, rf_miou, rf_std, label="Classical Features", marker='.', lw=6, ms=30)
format_plot(fig.gca(), "", "# labelled micrographs", "Class-averaged mIoU", 20, 32, 24, 24)