In [59]:
import numpy as np
import torch
from random import seed
from os import listdir
from time import time
import matplotlib.pyplot as plt
from skimage.color import label2rgb

from yoeo.main import get_dv2_model, get_upsampler_and_expr
from interactive_seg_backend.configs import FeatureConfig, TrainingConfig
from is_helpers import AllowedDatasets, eval_preds, get_pca_over_images_or_dir, get_and_cache_features_over_images, train_model_over_images, apply_model_over_images
    
from typing import Literal

# Seed every RNG source this notebook touches from one constant, in the order
# stdlib `random` -> NumPy's global RNG -> torch.
SEED = 10672
for _seed_fn in (seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn

# Device string handed to every model-loading / inference call below.
DEVICE = "cuda:1"

In [2]:
# Checkpoint and model-config files passed to get_upsampler_and_expr.
model_path = "../trained_models/e5000_full_fit_reg.pth"
cfg_path = "../yoeo/models/configs/combined_no_shift.json"

# Backbone model (the positional `True` flag is interpreted by yoeo.main.get_dv2_model —
# NOTE(review): confirm its meaning there), then the upsampler/expr pair, all on DEVICE.
dv2 = get_dv2_model(True, device=DEVICE)
upsampler, expr = get_upsampler_and_expr(model_path, cfg_path, device=DEVICE)

Using cache found in /home/ronan/.cache/torch/hub/ywyue_FiT3D_main


[128, 128, 128, 128]


In [3]:
# NOTE(review): SAVE is not read by any cell in this notebook; the figure cell calls
# plt.savefig unconditionally.
SAVE: bool = False
PATH = "fig_data/is_benchmark"

# `AllowedDatasets` is the alias imported from `is_helpers` in the first cell. It is not
# re-declared here, so this notebook cannot shadow or drift from the helper module's definition.
dataset: tuple[AllowedDatasets, ...] = ("Ni_superalloy_SEM", "T_cell_TEM", "Cu_ore_RLM")

chosen_dataset: AllowedDatasets = "Ni_superalloy_SEM"
# Type annotations are not enforced at runtime, so validate the choice explicitly.
if chosen_dataset not in dataset:
    raise ValueError(f"chosen_dataset={chosen_dataset!r} is not one of {dataset}")

# Sorted image filenames and their full paths (parallel lists).
img_dir = f"{PATH}/{chosen_dataset}/images/"
fnames = sorted(listdir(img_dir))
images = [f"{img_dir}{fname}" for fname in fnames]

In [4]:
pca = get_pca_over_images_or_dir(images, dv2)

feat_cfg = FeatureConfig()
# Two otherwise-identical xgb configs: the first appends DINO features (used by the IS runs
# below), the second does not ("classical"). Each call gets its own fresh params dict.
train_cfg, classical_train_cfg = (
    TrainingConfig(
        feat_cfg,
        n_samples=-1,
        add_dino_features=use_dino,
        classifier='xgb',
        classifier_params={"class_weight": "balanced", "max_depth": 32},
    )
    for use_dino in (True, False)
)

In [5]:
# Feature extraction/caching is shared by both IS runs, so it is timed once here and the
# elapsed time is added to each run's own timing further down.
# NOTE(review): features are presumably cached under f"{PATH}/.tmp" — later cells read them from there.
start_feat_t = time()
get_and_cache_features_over_images(
    chosen_dataset, train_cfg, '.tmp', PATH,
    dv2, upsampler, expr, pca,
)
end_feat_t = time()

In [6]:
# Filename stems of the images whose labels are used for training, per dataset.
TRAIN_IMG_FNAMES: dict[AllowedDatasets, list[str]] = dict(
    Cu_ore_RLM=["004", "028", "049", "077"],
    Ni_superalloy_SEM=["000", "001", "005", "007"],
    T_cell_TEM=["000", "005", "007", "026"],
)

In [7]:
# Label stems ordered so the training images come first, followed by every other labelled
# image in sorted order. NOTE(review): `label_fnames` is not used by any later cell shown here.
base_labels = TRAIN_IMG_FNAMES[chosen_dataset]
all_label_paths = sorted(listdir(f"{PATH}/{chosen_dataset}/labels"))
all_label_fnames = [path.partition('.')[0] for path in all_label_paths]

label_fnames = [*base_labels, *(stem for stem in all_label_fnames if stem not in base_labels)]

In [8]:
# IS run with overwrite_with_gt=True ("sparse"): train on the TRAIN_IMG_FNAMES images, then
# predict every cached feature file and score the predictions. Everything between the two
# time() calls (including eval) counts towards this run's time.
start_sparse_train_apply_t = time()

selected_labels = TRAIN_IMG_FNAMES[chosen_dataset]
feat_paths = [f"{PATH}/.tmp/{stem.split('.')[0]}.npy" for stem in selected_labels]
classifier, _ = train_model_over_images(
    chosen_dataset, train_cfg, PATH, selected_labels,
    dv2, upsampler, expr, feat_paths,
    overwrite_with_gt=True,
)

all_feat_fnames = [f"{PATH}/.tmp/{entry}" for entry in sorted(listdir(f"{PATH}/.tmp"))]
sparse_deep_preds = apply_model_over_images(
    chosen_dataset, train_cfg, classifier, PATH,
    dv2, upsampler, expr, False, -1, pca, all_feat_fnames,
)
sparse_miou, sparse_std_miou = eval_preds(chosen_dataset, sparse_deep_preds, PATH)
# NOTE(review): "(4/22)" is a hard-coded tag (presumably train images / total images).
print(f"(4/22): {sparse_miou:.4f} +/-{sparse_std_miou:.4f}")

end_sparse_train_apply_t = time()

(4/22): 0.7523 +/-0.1067


In [9]:
# IS run with reveal_all=True ("full"): same training images and prediction/scoring as the
# sparse run above; everything between the two time() calls counts towards this run's time.
start_full_train_apply_t = time()

selected_labels = TRAIN_IMG_FNAMES[chosen_dataset]
feat_paths = [f"{PATH}/.tmp/{stem.split('.')[0]}.npy" for stem in selected_labels]
classifier, _ = train_model_over_images(
    chosen_dataset, train_cfg, PATH, selected_labels,
    dv2, upsampler, expr, feat_paths,
    reveal_all=True,
)

all_feat_fnames = [f"{PATH}/.tmp/{entry}" for entry in sorted(listdir(f"{PATH}/.tmp"))]
full_deep_preds = apply_model_over_images(
    chosen_dataset, train_cfg, classifier, PATH,
    dv2, upsampler, expr, False, -1, pca, all_feat_fnames,
)
full_miou, full_std_miou = eval_preds(chosen_dataset, full_deep_preds, PATH)
# NOTE(review): "(4/22)" is a hard-coded tag (presumably train images / total images).
print(f"(4/22): {full_miou:.4f} +/-{full_std_miou:.4f}")

end_full_train_apply_t = time()

(4/22): 0.7876 +/-0.1477


In [10]:
# Stored CNN baseline results after n_epochs of training, for the full- and sparse-label runs.
n_epochs = 200
out_dir = "fig_data/CNN_comparison/ni_superalloy/stored_CNN_results/"

# out_dir is hard-coded to the Ni superalloy results; refuse to compare them against IS results
# from a different dataset instead of silently producing a mismatched figure.
if chosen_dataset != "Ni_superalloy_SEM":
    raise ValueError(
        f"stored CNN results in {out_dir!r} are for 'Ni_superalloy_SEM', "
        f"but chosen_dataset={chosen_dataset!r}"
    )

# allow_pickle=True is required because each file holds an object array of dicts (later cells
# read d["tot_time"] / d["miou"]). Unpickling can execute code: only load files you produced.
cnn_full_data = np.load(f"{out_dir}/full_data_4_imgs_e{n_epochs}.npy", allow_pickle=True)
cnn_sparse_data = np.load(f"{out_dir}/sparse_data_4_imgs_e{n_epochs}.npy", allow_pickle=True)

In [11]:
# Time / mIoU curves for the CNN runs, each prefixed with a (t=0, mIoU=0) origin point.
cnn_full_times = [0, *(rec["tot_time"] for rec in cnn_full_data)]
cnn_full_mious = [0, *(rec["miou"] for rec in cnn_full_data)]
cnn_sparse_times = [0, *(rec["tot_time"] for rec in cnn_sparse_data)]
cnn_sparse_mious = [0, *(rec["miou"] for rec in cnn_sparse_data)]

In [12]:
# Total IS wall-clock time per run = shared feature extraction + that run's train/apply/eval.
sparse_t, full_t = (
    (end_feat_t - start_feat_t) + (run_end - run_start)
    for run_start, run_end in (
        (start_sparse_train_apply_t, end_sparse_train_apply_t),
        (start_full_train_apply_t, end_full_train_apply_t),
    )
)

In [13]:
# IS curves drawn as steps: mIoU is 0 until the IS run finishes (t = sparse_t / full_t; the
# 0.1 s offset makes the jump near-vertical), then holds its final value out to the end of the
# corresponding CNN curve.
deep_sparse_times = [0, sparse_t-0.1, sparse_t, cnn_sparse_times[-1]]
deep_sparse_mious = [0, 0, sparse_miou, sparse_miou]

deep_full_times = [0, full_t-0.1, full_t, cnn_full_times[-1]]
deep_full_mious = [0, 0, full_miou, full_miou]

# Epochs between stored CNN results. NOTE(review): inferred from the e{epoch} file naming; confirm.
save_per = 10


def first_index_exceeding(times, threshold):
    """Return the index of the first value in ``times`` strictly greater than ``threshold``.

    Raises:
        ValueError: if no value exceeds ``threshold``. A bare search loop would instead leave the
            result variable unbound on a fresh kernel, or silently reuse a stale value from an
            earlier run.
    """
    for i, t in enumerate(times):
        if t > threshold:
            return i
    raise ValueError(f"no time in the sequence exceeds {threshold}")


# Map each IS run time onto the stored CNN results. Index i in the *_times lists is stored entry
# i-1 (index 0 is the prepended origin), so (i - 1) * save_per selects the stored CNN file loaded
# next. NOTE(review): confirm this matches how the CNN results were written.
sparse_train_to_idx = first_index_exceeding(cnn_sparse_times, sparse_t)
sparse_train_to_epoch = (sparse_train_to_idx - 1) * save_per

full_train_to_idx = first_index_exceeding(cnn_full_times, full_t)
full_train_to_epoch = (full_train_to_idx - 1) * save_per
print(sparse_train_to_epoch, full_train_to_epoch)

20 30


In [14]:
# Reload the stored CNN results at the epochs matched to the IS run times (full first, then sparse).
cnn_full_data_at_is_train_end, cnn_sparse_data_at_is_train_end = (
    np.load(f"{out_dir}/{kind}_data_4_imgs_e{epoch}.npy", allow_pickle=True)
    for kind, epoch in (("full", full_train_to_epoch), ("sparse", sparse_train_to_epoch))
)

In [60]:
# 8-bit RGB palette, scaled to [0, 1] floats for matplotlib/skimage. Row 0 is white; the
# figure cells pass COLORS[1:] to label2rgb.
color_list = [
    [255, 255, 255],
    [0, 62, 131],
    [181, 209, 204],
    [250, 43, 0],
    [255, 184, 82],
]
COLORS = np.asarray(color_list) / 255.0

In [178]:
# Font sizes (title / axis label / tick label) for the comparison figure.
TITLE_FS, LABEL_FS, TICK_FS = 25, 23, 21
# Paddings: TITLE_PAD is passed as set_title(pad=...). NOTE(review): PAD is not read by any cell here.
PAD, TITLE_PAD = 60, 25

def hide_axis_ticks(ax, frameoff: bool=True):
    """Remove tick marks and tick labels from ``ax`` in place, optionally hiding its frame.

    Args:
        ax: matplotlib Axes to modify.
        frameoff: when True (default), also turn the Axes frame off via ``set_frame_on(False)``.
    """
    # Major and minor ticks on all four sides.
    no_ticks = dict.fromkeys(("bottom", "top", "left", "right"), False)
    ax.tick_params(which="both", **no_ticks)
    for clear_labels in (ax.set_xticklabels, ax.set_yticklabels):
        clear_labels([])

    if not frameoff:
        return
    ax.set_frame_on(False)

In [120]:
from tifffile import imread

# Validation image shown in the figure, selected by name rather than position: `images[9]` is
# only '009' if exactly nine files sort before it. This keeps the image consistent with the
# '009.tif' label and predictions used below (raises ValueError if '009' is missing).
val_idx = [fname.split('.')[0] for fname in fnames].index("009")
img = imread(images[val_idx])
label = imread(f"{PATH}/{chosen_dataset}/segmentations/009.tif")

In [121]:
# Predictions shown in the figure. CNN entries come from the last stored record at the matched
# epoch. NOTE(review): entry 2 of 'preds'/'gts' is assumed to be image 009 — confirm against
# the CNN evaluation ordering.
cnn_sparse_final_dict = cnn_sparse_data_at_is_train_end[-1]
cnn_full_final_dict = cnn_full_data_at_is_train_end[-1]

cnn_sparse_pred, cnn_full_pred = (rec['preds'][2] for rec in (cnn_sparse_final_dict, cnn_full_final_dict))
gt = cnn_full_final_dict['gts'][2]  # not used by the figure cell, which plots `label` as ground truth

is_sparse_pred, is_full_pred = (preds['009.tif'] for preds in (sparse_deep_preds, full_deep_preds))

In [180]:
%%capture
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches

# Global rcParams change: also affects every figure created after this cell.
plt.rcParams["font.family"] = "serif"

# Create the figure
fig = plt.figure(figsize=(20, 8))

# 2 x 5 grid. Column 0 (the time-vs-mIoU panel) is 2.5x the width of an image column;
# column 2 (width 0.2) is never populated and acts as a spacer between the
# image/ground-truth column (1) and the CNN / HR ViT prediction columns (3, 4).
gs = gridspec.GridSpec(2, 5, width_ratios=[2.5, 1, 0.2, 1, 1])

# Add the large plot spanning 2 rows and 1 column (first column)
ax_large = fig.add_subplot(gs[:, 0])  # All rows, column 0

# CNN curves in C0, IS curves in C1; solid = full labels, dashed = sparse labels.
ax_large.plot(cnn_full_times, cnn_full_mious, color='C0', ls='-', lw=2, label='CNN, full')
ax_large.plot(cnn_sparse_times, cnn_sparse_mious, color='C0', ls='--', lw=2, label='CNN, sparse')

ax_large.plot(deep_full_times, deep_full_mious, color='C1', ls='-', lw=2, label='IS, full')
ax_large.plot(deep_sparse_times, deep_sparse_mious, color='C1', ls='--', lw=2, label='IS, sparse')

ax_large.set_ylabel('mIoU', fontsize=LABEL_FS)
ax_large.set_xlabel('Time (s)', fontsize=LABEL_FS)

ax_large.tick_params(axis='both', labelsize=TICK_FS)

ax_large.legend(fontsize=TICK_FS)

# Column 1: validation image (top) and its ground truth (bottom).
ax_1 = fig.add_subplot(gs[0, 1])
ax_1.imshow(img, cmap='gist_grey')
hide_axis_ticks(ax_1)
ax_1.set_title('Val. Image', fontsize=TITLE_FS)


# Columns 3-4: CNN vs HR ViT (IS) predictions; top row = sparse labels, bottom row = full labels.
# The row titles are set as y-labels on the CNN axes; hide_axis_ticks only removes ticks and
# the frame, so these labels remain visible.
ax_2 = fig.add_subplot(gs[0, 3])
ax_2.set_ylabel('Sparse labels', fontsize=LABEL_FS)
ax_2.set_title('CNN', fontsize=TITLE_FS, pad=TITLE_PAD)
ax_2.imshow(label2rgb(cnn_sparse_pred, colors=COLORS[1:], bg_label=-1))
hide_axis_ticks(ax_2)

ax_3 = fig.add_subplot(gs[0, 4])
ax_3.set_title('HR ViT', fontsize=TITLE_FS, pad=TITLE_PAD)
ax_3.imshow(label2rgb(is_sparse_pred, colors=COLORS[1:], bg_label=-1))
hide_axis_ticks(ax_3)


ax_4 = fig.add_subplot(gs[1, 1])
ax_4.set_title('Ground truth', fontsize=TITLE_FS)
ax_4.imshow(label2rgb(label, colors=COLORS[1:], bg_label=-1))
hide_axis_ticks(ax_4)

ax_5 = fig.add_subplot(gs[1, 3])
ax_5.set_ylabel('Full labels', fontsize=LABEL_FS)
ax_5.imshow(label2rgb(cnn_full_pred, colors=COLORS[1:], bg_label=-1))
hide_axis_ticks(ax_5)

ax_6 = fig.add_subplot(gs[1, 4])
ax_6.imshow(label2rgb(is_full_pred, colors=COLORS[1:], bg_label=-1))
hide_axis_ticks(ax_6)

# Panel letters "a." (curves) and "b." (image block), placed in each axes' own coordinates.
for key, ax in zip(('a', 'b'), (ax_large, ax_1,)):
    y = 1.20 if key =='b' else 1.05
    x = -0.20 if key =='b' else -0.15
    ax.text(x, y, f"{key}.", transform=ax.transAxes, 
            size=LABEL_FS + 4, weight='bold')



# Light-grey box behind each CNN / HR ViT prediction pair. Positions come from get_position()
# before any draw, i.e. from the GridSpec layout, so this must run after all axes are created.
for ax_pair in ((ax_2, ax_3), (ax_5, ax_6)):
    ax1, ax2 = ax_pair
    pos1 = ax1.get_position()
    pos2 = ax2.get_position()

    # Compute bounding box that covers both axes
    x0 = pos1.x0
    y0 = min(pos1.y0, pos2.y0)
    x1 = pos2.x1
    y1 = max(pos1.y1, pos2.y1)

    # Add rectangle to figure behind those axes
    rect = patches.FancyBboxPatch(
        (x0 - 0.01, y0 - 0.01),           # Bottom left corner
        (x1 - x0) + 0.01,                 # Width
        (y1 - y0) + 0.02,                 # Height
        boxstyle="square,pad=0.02",
        edgecolor="lightgray",
        facecolor="lightgray",
        linewidth=2,
        transform=fig.transFigure,        # Important: coordinates in figure space
        zorder=-10                          # Put behind axes
    )
    # NOTE(review): Figure.add_artist(rect) is the documented API; appending to
    # fig.patches relies on it being a plain list.
    fig.patches.append(rect)


# NOTE(review): saves unconditionally (the SAVE flag from the config cell is not consulted)
# and requires fig_out/ to exist. Under %%capture the shown figure is captured, not displayed.
plt.savefig('fig_out/CNN_comparison.png', bbox_inches='tight')
plt.show()

In [123]:
%%capture
# Sanity check (sparse labels): GT vs CNN vs IS prediction side by side. Uses dedicated names
# so it does not rebind `fig` / `gt` from the figure cells above.
# NOTE(review): assumes entry 2 of the CNN 'gts'/'preds' lists is image 009 — confirm.
sparse_check = cnn_sparse_data_at_is_train_end[-1]
check_fig, check_axs = plt.subplots(nrows=1, ncols=3)
check_axs[0].imshow(sparse_check['gts'][2])
check_axs[0].set_title('GT')
check_axs[1].imshow(sparse_check['preds'][2])
check_axs[1].set_title('CNN, sparse')
check_axs[2].imshow(sparse_deep_preds['009.tif'])
check_axs[2].set_title('IS, sparse')

In [124]:
%%capture
# Sanity check (full labels): GT vs CNN vs IS prediction side by side. Uses dedicated names
# so it does not rebind `fig` / `gt` from the figure cells above.
# NOTE(review): assumes entry 2 of the CNN 'gts'/'preds' lists is image 009 — confirm.
full_check = cnn_full_data_at_is_train_end[-1]
check_fig, check_axs = plt.subplots(nrows=1, ncols=3)
check_axs[0].imshow(full_check['gts'][2])
check_axs[0].set_title('GT')
check_axs[1].imshow(full_check['preds'][2])
check_axs[1].set_title('CNN, full')
check_axs[2].imshow(full_deep_preds['009.tif'])
check_axs[2].set_title('IS, full')

In [125]:
%%capture
# Sanity check (full labels) on a second image. Uses dedicated names so it does not rebind
# `fig` / `gt` from the figure cells above.
# NOTE(review): assumes entry 1 of the CNN 'gts'/'preds' lists is image 019 — confirm.
full_check_019 = cnn_full_data_at_is_train_end[-1]
check_fig, check_axs = plt.subplots(nrows=1, ncols=3)
check_axs[0].imshow(full_check_019['gts'][1])
check_axs[0].set_title('GT')
check_axs[1].imshow(full_check_019['preds'][1])
check_axs[1].set_title('CNN, full')
check_axs[2].imshow(full_deep_preds['019.tif'])
check_axs[2].set_title('IS, full')

In [126]:
%%capture
# Standalone time-vs-mIoU plot. Uses an explicit Figure/Axes so it never draws onto whatever
# figure happens to be current, and labels each curve (same labels as the main figure) so the
# solid/dashed, C0/C1 encoding is readable.
curve_fig, curve_ax = plt.subplots()
curve_ax.plot(cnn_full_times, cnn_full_mious, color='C0', ls='-', label='CNN, full')
curve_ax.plot(cnn_sparse_times, cnn_sparse_mious, color='C0', ls='--', label='CNN, sparse')

curve_ax.plot(deep_full_times, deep_full_mious, color='C1', ls='-', label='IS, full')
curve_ax.plot(deep_sparse_times, deep_sparse_mious, color='C1', ls='--', label='IS, sparse')

curve_ax.set_ylabel('mIoU')
curve_ax.set_xlabel('Time (s)')
curve_ax.legend()