In [1]:
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image

import matplotlib.pyplot as plt
from skimage.color import label2rgb
from tifffile import imread
from os import listdir

from yoeo.main import get_dv2_model, get_upsampler_and_expr, get_hr_feats
from yoeo.utils import to_numpy

from interactive_seg_backend import featurise_
from interactive_seg_backend.configs import FeatureConfig, TrainingConfig
from interactive_seg_backend.classifiers.base import Classifier
from interactive_seg_backend.file_handling import load_labels, load_image
from interactive_seg_backend.core import train, get_training_data, shuffle_sample_training_data, get_model
from interactive_seg_backend.core import apply_

from typing import Literal

# --- Run configuration -------------------------------------------------------
# One seed shared by NumPy's and PyTorch's global generators so that any
# sampling done through them is repeatable across runs.
SEED = 10672
# Torch device string passed to the feature-extraction helpers below.
DEVICE = "cuda:1"

np.random.seed(SEED)
torch.manual_seed(SEED)

N CPUS: 110


In [2]:
# Five-entry hex palette (presumably one colour per class label — TODO confirm).
cmap = ["#fafafa", "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"]

# The same palette as 0-255 RGB triplets.
# NOTE(review): entries 0 and 4 differ from `cmap` (#fafafa is 250,250,250 and
# #d62728 is 214,39,40) — confirm the pure white / pure red here are intentional.
color_list = [
    [255, 255, 255],
    [31, 119, 180],
    [255, 127, 14],
    [44, 160, 44],
    [255, 0, 0],
]

# Float RGB in [0, 1], shape (5, 3).
COLORS = np.array(color_list) / 255.0

In [3]:
# Locations of the trained upsampler checkpoint and its JSON config,
# relative to the notebook's working directory.
model_path = "../trained_models/e5000_full_fit_reg.pth"
cfg_path = "../yoeo/models/configs/combined_no_shift.json"

# Load the DINOv2 backbone and the upsampler (plus its `expr` object, whose
# `n_ch_in` attribute is read in get_deep_feats) onto DEVICE.
dv2 = get_dv2_model(True, device=DEVICE)
upsampler, expr = get_upsampler_and_expr(model_path, cfg_path, device=DEVICE)

Using cache found in /home/ronan/.cache/torch/hub/ywyue_FiT3D_main


In [4]:
# Root folder of the benchmark; each dataset lives in a sub-folder named after it.
PATH = "fig_data/is_benchmark"

# Names of the datasets available under PATH.
AllowedDatasets = Literal["Cu_ore_RLM", "Ni_superalloy_SEM", "T_cell_TEM"]
dataset: tuple[AllowedDatasets, ...] = (
    "Cu_ore_RLM",
    "Ni_superalloy_SEM",
    "T_cell_TEM",
)

# File stems (without the ".tif" extension) of the images used for training,
# keyed by dataset name.
TRAIN_IMG_FNAMES: dict[AllowedDatasets, list[str]] = {
    "Cu_ore_RLM": ["001", "028", "049", "068"],
    "Ni_superalloy_SEM": ["000", "001", "005", "007"],
    "T_cell_TEM": ["000", "027", "021", "105"],
}

In [5]:
def get_deep_feats(img: Image.Image, K: int = 32) -> np.ndarray:
    """Compute upsampled deep features for one image and keep the first K channels.

    Uses the notebook-level `dv2`, `upsampler`, `expr` and `DEVICE` objects.

    Args:
        img: PIL image to featurise.
        K: Number of leading channels to keep along the last axis after the
            transpose.

    Returns:
        NumPy array obtained by permuting the feature axes with (1, 2, 0) and
        slicing the last axis to `[:K]` (presumably (H, W, K) from a (C, H, W)
        input — TODO confirm against `get_hr_feats`).
    """
    upsampled = get_hr_feats(img, dv2, upsampler, DEVICE, n_ch_in=expr.n_ch_in)
    # Move the leading axis to the end so it can be concatenated with
    # other per-pixel features along axis -1.
    channels_last = to_numpy(upsampled).transpose((1, 2, 0))
    return channels_last[:, :, :K]

def train_model_over_images(dataset: AllowedDatasets, train_cfg: TrainingConfig) -> Classifier:
    """Fit a classifier on the training images listed for one dataset.

    For every file stem in `TRAIN_IMG_FNAMES[dataset]` the image and its label
    map are loaded from `{PATH}/{dataset}/images` and `{PATH}/{dataset}/labels`,
    features are computed with `featurise_`, and (when
    `train_cfg.add_dino_features` is truthy) 32 deep-feature channels from
    `get_deep_feats` are appended along the last axis.

    Args:
        dataset: Name of the dataset sub-folder under `PATH`.
        train_cfg: Training configuration; this function reads its
            `feature_config`, `add_dino_features`, `shuffle_data`, `n_samples`,
            `classifier`, `classifier_params` and `use_gpu` attributes.

    Returns:
        The classifier returned by `train`.
    """
    per_image_feats = []
    per_image_labels = []

    for stem in TRAIN_IMG_FNAMES[dataset]:
        image_file = f"{PATH}/{dataset}/images/{stem}.tif"
        label_file = f"{PATH}/{dataset}/labels/{stem}.tif"

        image_arr = load_image(image_file)
        annotation_arr = load_labels(label_file)

        stacked = featurise_(image_arr, train_cfg.feature_config)
        if train_cfg.add_dino_features:
            # The deep-feature extractor takes a PIL RGB image.
            rgb = Image.fromarray(image_arr).convert('RGB')
            stacked = np.concatenate((stacked, get_deep_feats(rgb, 32)), axis=-1)

        per_image_feats.append(stacked)
        per_image_labels.append(annotation_arr)

    print('Finished featurising')

    # Pair features with labels, then shuffle / subsample as configured.
    fit_x, fit_y = get_training_data(per_image_feats, per_image_labels)
    fit_x, fit_y = shuffle_sample_training_data(
        fit_x, fit_y, train_cfg.shuffle_data, train_cfg.n_samples
    )

    untrained = get_model(
        train_cfg.classifier, train_cfg.classifier_params, train_cfg.use_gpu
    )
    return train(untrained, fit_x, fit_y, None)

In [6]:
def apply_model_over_images(dataset: AllowedDatasets, train_cfg: TrainingConfig, model: Classifier, verbose: bool=False, early_cutoff_n: int = -1) -> list[np.ndarray]:
    """Apply a trained classifier to the images of one dataset.

    Images are read from `{PATH}/{dataset}/images` in sorted filename order and
    featurised the same way as in training: `featurise_`, plus 32 deep-feature
    channels from `get_deep_feats` when `train_cfg.add_dino_features` is truthy.

    Args:
        dataset: Name of the dataset sub-folder under `PATH`.
        train_cfg: Training configuration; `feature_config` and
            `add_dino_features` are read here.
        model: Trained classifier passed to `apply_`.
        verbose: If True, print a progress line for every 10th image.
        early_cutoff_n: Process only the first `early_cutoff_n` images. Any
            negative value (including the default -1) means "no cutoff".

    Returns:
        One prediction array per processed image, in sorted filename order.
    """
    img_fnames = sorted(listdir(f"{PATH}/{dataset}/images"))
    # A negative cutoff is a "process everything" sentinel. Using it directly as
    # a slice bound (`[:-1]`) would silently drop the last image.
    if early_cutoff_n >= 0:
        img_fnames = img_fnames[:early_cutoff_n]
    n_imgs = len(img_fnames)

    preds: list[np.ndarray] = []
    for i, fname in enumerate(img_fnames):
        if verbose and i % 10 == 0:
            print(f"[{i:02d}/{n_imgs}] - {fname}")
        img_path = f"{PATH}/{dataset}/images/{fname}"
        img_arr = load_image(img_path)

        feats = featurise_(img_arr, train_cfg.feature_config)
        if train_cfg.add_dino_features:
            # The deep-feature extractor takes a PIL RGB image.
            img = Image.fromarray(img_arr).convert('RGB')
            deep_feats = get_deep_feats(img, 32)
            feats = np.concatenate((feats, deep_feats), axis=-1)

        # apply_ returns a pair; only the first element is kept.
        pred, _ = apply_(model, feats)
        preds.append(pred)
    return preds

In [7]:
# Dataset used by the training and inference cells below; must be one of the
# names in AllowedDatasets.
chosen_dataset: AllowedDatasets = "Ni_superalloy_SEM"

In [15]:
# Default feature configuration, with deep (DINO) features switched off for this run.
feat_cfg = FeatureConfig()
train_cfg = TrainingConfig(
    feat_cfg,
    n_samples=-1,  # presumably "use every labelled sample" — TODO confirm
    add_dino_features=False,
    classifier='xgb',
    # NOTE(review): `class_weight` is a scikit-learn style option; confirm the
    # 'xgb' classifier built by get_model actually honours it.
    classifier_params={"class_weight": "balanced"},
)

model = train_model_over_images(chosen_dataset, train_cfg)

Finished featurising


In [16]:
# Run the trained model over the chosen dataset's image directory;
# verbose=True prints a progress line for every 10th image.
preds = apply_model_over_images(chosen_dataset, train_cfg, model, verbose=True)

[00/23] - 000.tif
[10/23] - 010.tif
[20/23] - 020.tif
