In [1]:
import torch
import numpy as np
from PIL import Image
from random import seed

import matplotlib.pyplot as plt
from skimage.color import label2rgb
from tifffile import imread
from os import listdir

from vulture.main import get_dv2_model, get_upsampler_and_expr

from interactive_seg_backend.configs import FeatureConfig, TrainingConfig
from interactive_seg_backend.file_handling import load_labels
from is_helpers import train_model_over_images, apply_model_over_images, eval_preds

from interactive_seg_backend.classifiers import XGBCPU
import pandas as pd

from typing import Literal

SEED = 10672
np.random.seed(SEED)
torch.manual_seed(SEED)
seed(SEED)
DEVICE = "cuda:1"

N CPUS: 110


In [2]:
dv2 = get_dv2_model(True, device=DEVICE)

model_path = "../trained_models/fit_reg_f128.pth"

upsampler, expr = get_upsampler_and_expr(model_path, None, device=DEVICE)

Using cache found in /home/ronan/.cache/torch/hub/ywyue_FiT3D_main


In [3]:
PATH = "fig_data/is_benchmark"
AllowedDatasets = Literal["Ni_superalloy_SEM", "T_cell_TEM", "Cu_ore_RLM"]
dataset: tuple[AllowedDatasets, ...] = ("Ni_superalloy_SEM",)

TRAIN_IMG_FNAMES: dict[AllowedDatasets, list[str]] = {"Cu_ore_RLM": ["004", "028", "049", "077"], 
                                                      "Ni_superalloy_SEM": ["000", "001", "005", "007"], 
                                                      "T_cell_TEM": ["000", "005", "007", "026"]
                                                      }

all_classical_preds: dict[AllowedDatasets, dict[str, np.ndarray]] = {k: {} for k in dataset}
all_deep_preds: dict[AllowedDatasets, dict[str, np.ndarray]] = {k: {} for k in dataset}

In [4]:
chosen_dataset = "T_cell_TEM"
feat_cfg = FeatureConfig()

classical_train_cfg = TrainingConfig(feat_cfg, n_samples=-1, add_dino_features=False, classifier='xgb', classifier_params = {"class_weight": "balanced", "max_depth": 32},)
classical_model, _ = train_model_over_images(chosen_dataset, classical_train_cfg, PATH, TRAIN_IMG_FNAMES[chosen_dataset], dv2, upsampler, expr, overwrite_with_gt=True)

Finished featurising
Finished featurising
Finished featurising
Finished featurising


In [5]:
deep_train_cfg = TrainingConfig(feat_cfg, n_samples=-1, add_dino_features=True, classifier='xgb', classifier_params = {"class_weight": "balanced", "max_depth": 32},)
deep_model, pca = train_model_over_images(chosen_dataset, deep_train_cfg, PATH, TRAIN_IMG_FNAMES[chosen_dataset], dv2, upsampler, expr, overwrite_with_gt=True)

Finished featurising
Finished featurising
Finished featurising
Finished featurising


In [None]:
def get_max_depths(df: pd.DataFrame):
    """
    Returns a list of max depths for each tree in the XGBoost tree DataFrame.
    """
    max_depths = []
    for tree_id in df['Tree'].unique():
        tree_df = df[df['Tree'] == tree_id].set_index('ID')
        print(tree_df.head(1))
        # problem: not always 0-0
        stack = [('0-0', 0)]  # (node_id, depth)
        print(tree_df.head())
        max_depth = 0
        while stack:
            node_id, depth = stack.pop()
            print(node_id, depth)
            node = tree_df.loc[node_id]
            if node['Feature'] == 'Leaf':
                max_depth = max(max_depth, depth)
            else:
                stack.append((node['Yes'], depth + 1))
                stack.append((node['No'], depth + 1))
        max_depths.append(max_depth)
    return max_depths

def get_tree_stats(model: XGBCPU):
    booster = model.model.get_booster()

    df: pd.DataFrame = booster.trees_to_dataframe()

    trees = df['Tree'].unique()

    avg_depth = np.log(df.groupby("Tree")["Node"].count())

    max_depths = get_max_depths(df)
    # print(max_depths)

    avg_nodes = df.groupby("Tree")["Node"].count().mean()

    leaves = df[df["Feature"] == "Leaf"].groupby("Tree")
    # Average number of leaves per tree
    avg_leaves = leaves.size().mean()

    # print(leaves.size())
    depth = np.mean(np.log2(leaves.size()))
    # print(depth)

    # print(df.head(10))

    # print(f"Average depth: {avg_depth:.2f}")
    print(f"Average nodes: {avg_nodes:.2f}")
    print(f"Average leaves: {avg_leaves:.2f}")

In [21]:
# get_tree_stats(classical_model)

In [30]:
get_tree_stats(deep_model)

     Tree  Node Feature     Split  Yes   No Missing        Gain       Cover  \
ID                                                                            
0-0     0     0     f84  0.000208  0-1  0-2     0-2  132116.969  247923.547   

     Category  
ID             
0-0       NaN  
     Tree  Node Feature     Split  Yes    No Missing          Gain  \
ID                                                                   
0-0     0     0     f84  0.000208  0-1   0-2     0-2  132116.96900   
0-1     0     1     f58  0.113342  0-3   0-4     0-4   27345.17190   
0-2     0     2     f59  0.061768  0-5   0-6     0-6   19102.42380   
0-3     0     3     f66 -0.059845  0-7   0-8     0-8    2224.18750   
0-4     0     4     f62  0.030518  0-9  0-10    0-10    5992.72607   

           Cover  Category  
ID                          
0-0  247923.5470       NaN  
0-1  135476.4380       NaN  
0-2  112447.1020       NaN  
0-3  101494.6640       NaN  
0-4   33981.7773       NaN  
0-0 0
0-2 1
0-6 2
0-

KeyError: '0-0'