In [None]:
import json
import logging
import multiprocessing
from itertools import islice
from pathlib import Path

import fibsem_tools as fst
import matplotlib.pyplot as plt
import numpy as np
import yaml
from mpl_interactions import hyperslicer
from scipy.spatial.distance import dice
from skimage.morphology import ball, binary_dilation, binary_erosion

%matplotlib ipympl

In [None]:
def make_fake_mem(combo, radius, return_lumen=False):
    """Build a synthetic membrane as the outer shell of a 3-D organelle mask.

    The shell is every foreground voxel of ``combo`` that is removed by a binary
    erosion with a ball footprint, i.e. ``combo`` minus its eroded interior.
    Half-integer radii are supported: for ``radius = k + 0.5`` a ball of integer
    radius ``k`` is built with ``strict_radius=False``. Per scikit-image, that
    extends the ball's radius by 0.5.

    Parameters
    ----------
    combo : array-like
        3-D mask of the whole organelle (membrane + lumen). Any nonzero voxel is
        foreground. Array-likes such as zarr arrays are read into memory.
    radius : float
        Shell thickness in voxels. Expected to be a non-negative integer or
        half-integer.
    return_lumen : bool, optional
        If True, also return the eroded interior (the part of ``combo`` that is
        not membrane).

    Returns
    -------
    numpy.ndarray of bool, or tuple of two such arrays
        ``mem`` alone, or ``(mem, lumen)`` when ``return_lumen`` is True.
    """
    # Boolean conversion makes label values other than 0/1 behave as foreground
    # instead of corrupting the XOR/AND below.
    combo = np.asarray(combo[:], dtype=bool)
    base_radius = int(radius)  # floor for non-negative radii
    # Integer radius -> strict ball; half-integer -> ball grown by 0.5 voxel.
    strict = (radius * 2) % 2 == 0
    lumen = binary_erosion(combo, ball(base_radius, strict_radius=strict))
    mem = combo & ~lumen
    if return_lumen:
        return mem, lumen
    return mem


def get_best_fake_mem(combo, mem, *, return_arr=True, radii=None):
    """Find the erosion radius whose synthetic membrane best matches ``mem``.

    For each candidate radius, ``make_fake_mem`` builds a shell from ``combo``.
    The shell is compared with the annotated membrane ``mem`` using
    ``scipy.spatial.distance.dice``, a dissimilarity where 0 means identical.
    The radius with the lowest dissimilarity wins; on ties the earliest
    candidate is kept.

    Parameters
    ----------
    combo, mem : array-like
        Organelle and membrane masks of the same shape. Nonzero voxels are
        foreground. Each is read into memory once.
    return_arr : bool, optional
        If True, also return the best synthetic membrane array.
    radii : iterable of float, optional
        Candidate radii. Defaults to ``np.arange(1, 11, 0.5)`` (1.0 ... 10.5).

    Returns
    -------
    tuple
        ``(fake_mem, best_radius, similarity)`` if ``return_arr``, otherwise
        ``(best_radius, similarity)``. Here ``similarity = 1 - dice``, i.e. the
        Dice coefficient. If no candidate scores below 1 (no overlap, or NaN
        from two empty masks), ``best_radius`` stays 0 and ``similarity`` is 0.
    """
    combo = np.asarray(combo[:], dtype=bool)
    # The target never changes across candidates, so flatten it once.
    mem_flat = np.asarray(mem[:], dtype=bool).ravel()
    if radii is None:
        radii = np.arange(1, 11, 0.5)

    best_match_score = 1.0  # Dice dissimilarity; 1.0 means no overlap
    best_radius = 0
    for radius in radii:
        score = dice(make_fake_mem(combo, radius).ravel(), mem_flat)
        # Strict "<" keeps the earliest radius among equal scores.
        if score < best_match_score:
            best_radius, best_match_score = radius, score

    if return_arr:
        fake_mem = make_fake_mem(combo, best_radius)
        return fake_mem, best_radius, 1 - best_match_score
    return best_radius, 1 - best_match_score

In [None]:
# Dataset configuration: the ground-truth root ("gt_path") and, per dataset, its
# crop names ("crops") and raw-volume location ("raw").
data = yaml.safe_load(Path("../selected_data_8nm_mem+org.yaml").read_text())

In [None]:
# One ground-truth crop path per crop: <gt_path>/<dataset>/groundtruth.zarr/<crop>.
# A "crops" entry may hold several comma-separated crop names. Whitespace around
# a name is ignored, as are empty names (e.g. from a trailing comma), so
# "crop1, crop2," yields crop1 and crop2.
gt_root = Path(data["gt_path"])
all_paths = [
    gt_root / dataset / "groundtruth.zarr" / crop.strip()
    for dataset, dataset_info in data["datasets"].items()
    for crops in dataset_info["crops"]
    for crop in str(crops).split(",")
    if crop.strip()
]

In [None]:
def process_func(crop_path, combo_name="er", mem_name="er_mem_all"):
    """Fit the synthetic-membrane radius for one organelle class in one crop.

    Parameters
    ----------
    crop_path : pathlib.Path
        Ground-truth crop group. Class arrays are read from
        ``crop_path / <class> / "s0"``.
    combo_name, mem_name : str
        Class names of the whole organelle and of its membrane.

    Returns
    -------
    tuple
        ``(best_radius, dice_similarity)`` from ``get_best_fake_mem``, or
        ``(None, None)`` when the crop's ``class_names`` lack either class, or
        when the ``complement_counts`` of ``combo_name`` add up to (at least)
        the array's voxel count.
    """
    annotated = fst.read(crop_path).attrs["cellmap"]["annotation"]["class_names"]
    if not all(cl in annotated for cl in (combo_name, mem_name)):
        return None, None

    combo = fst.read(crop_path / combo_name / "s0")
    complement_total = sum(combo.attrs["cellmap"]["annotation"]["complement_counts"].values())
    # Voxels not covered by complement_counts are taken to contain the class.
    if np.prod(combo.shape) - complement_total <= 0:
        return None, None

    mem = fst.read(crop_path / mem_name / "s0")
    radius, similarity = get_best_fake_mem(combo, mem, return_arr=False)
    return radius, similarity

In [None]:
def get_corresponding_raw(crop_path, data):
    """Return the raw volume of a crop's dataset, resampled onto the crop's grid.

    The dataset is the first key of ``data["datasets"]`` that appears as whole
    path segment(s) of ``crop_path``. Crop paths are built as
    ``<gt_path>/<dataset>/groundtruth.zarr/<crop>``.

    Parameters
    ----------
    crop_path : str or pathlib.Path
        Path of a ground-truth crop group containing an ``all/s0`` array.
    data : dict
        Loaded dataset config. ``data["datasets"][name]["raw"]`` is the raw
        source passed to ``fst.read_xarray``.

    Returns
    -------
    The ``"s0"`` entry of the raw source interpolated (``.interp`` with its
    default method) at the coordinates of ``crop_path/all/s0``.

    Raises
    ------
    ValueError
        If no configured dataset occurs in ``crop_path``.
    """
    # Slash padding makes "/<dataset>/" match complete path segments only. A
    # bare substring test would let "jrc_hela-2" match ".../jrc_hela-21/...".
    # It also accepts pathlib.Path input, where `str in Path` raises TypeError.
    padded_path = f"/{Path(crop_path).as_posix()}/"
    for dataset, dataset_info in data["datasets"].items():
        if f"/{dataset}/" in padded_path:
            raw = fst.read_xarray(dataset_info["raw"])
            crop = fst.read_xarray(Path(crop_path) / "all" / "s0")
            return raw["s0"].interp(crop.coords)
    msg = f"Did not find raw for {crop_path}"
    raise ValueError(msg)


In [None]:
# (organelle class, membrane class) pairs whose membrane thickness is fitted.
pairs = [
    ("er", "er_mem_all"),
    ("golgi", "golgi_mem"),
    ("endo", "endo_mem"),
    ("ves", "ves_mem"),
    ("lyso", "lyso_mem"),
    ("ld", "ld_mem"),
    ("perox", "perox_mem"),
]


def all_process_func(path):
    """Run ``process_func`` on one crop for every pair in ``pairs``.

    Returns a dict mapping each organelle class name to its
    ``(best_radius, dice_similarity)`` or ``(None, None)`` result, in
    ``pairs`` order.
    """
    return {
        combo_name: process_func(path, combo_name=combo_name, mem_name=mem_name)
        for combo_name, mem_name in pairs
    }


In [None]:
# Per-crop fitted membrane radii. Delete this file to force recomputation.
EROSIONS_PATH = Path("../erosions.json")
N_WORKERS = 44  # upper bound on worker processes; capped at the CPU count below

if EROSIONS_PATH.exists():
    with open(EROSIONS_PATH) as f:
        all_results = json.load(f)
else:
    # The context manager shuts the worker processes down once map() has
    # returned all results, instead of leaving them alive in the kernel.
    with multiprocessing.Pool(min(N_WORKERS, multiprocessing.cpu_count())) as pool:
        ans = pool.map(all_process_func, all_paths)
    # String keys so the in-memory dict matches what json.load returns later.
    all_results = {str(crop_path): res for crop_path, res in zip(all_paths, ans)}
    with open(EROSIONS_PATH, "w") as f:
        json.dump(all_results, f)

In [None]:
class ResultsIterator:
    """Single-pass iterator over ``(crop_path, result)`` items of a results dict.

    Only crops whose ``attrs["cellmap"]["annotation"]["class_names"]`` contain
    every class in ``required_classes`` are yielded. Crops whose metadata cannot
    be read (``OSError``, ``KeyError`` or ``AttributeError``) are logged and
    skipped. Once ``data`` is exhausted, every further ``next()`` raises
    ``StopIteration``.

    Parameters
    ----------
    data : dict
        Maps crop path strings to per-crop results.
    required_classes : tuple of str
        Class names that must all be annotated in a crop for it to be yielded.
    """

    def __init__(self, data, required_classes=("mito", "mito_mem")):
        self.data = data
        self.iterator = iter(data.items())
        self.required_classes = required_classes

    def __iter__(self):
        return self

    def __next__(self):
        for path, result in self.iterator:
            try:
                # Open the crop once and test every required class against it.
                class_names = fst.read(Path(path)).attrs["cellmap"]["annotation"]["class_names"]
            except (OSError, KeyError, AttributeError) as e:
                # Path doesn't exist or metadata is not structured as expected.
                logging.warning(f"Skipping {path} due to error: {e}")
                continue
            if all(cl in class_names for cl in self.required_classes):
                return path, result
        raise StopIteration


class PairsIterator:
    """Iterate over ``(combo_name, mem_name)`` pairs that have a complete fit.

    A pair is yielded when ``result[combo_name]`` has a non-None radius at index 0
    and a non-None score at index 1. Pairs whose ``combo_name`` is missing from
    ``result`` are logged and skipped. The iterator is single-pass.
    """

    def __init__(self, pairs, result):
        self.iterator = iter(pairs)
        self.result = result

    def __iter__(self):
        return self

    def __next__(self):
        for combo_name, mem_name in self.iterator:
            try:
                entry = self.result[combo_name]
                complete = entry[0] is not None and entry[1] is not None
            except KeyError as e:
                # combo_name was never fitted for this crop.
                logging.warning(f"Skipping {combo_name} due to error: {e}")
                continue
            if complete:
                return combo_name, mem_name
        raise StopIteration


In [None]:
# Walk only the crops that annotate both mito and its membrane.
results_iter = ResultsIterator(all_results, required_classes=("mito", "mito_mem"))

In [None]:
# Preview the next 10 crops that pass the filter. This consumes them from
# results_iter, so the following cell starts after them. islice stops cleanly
# when fewer than 10 remain instead of raising StopIteration.
for path, result in islice(results_iter, 10):
    print(path)

In [None]:
# Advance to the next mito crop and load its masks and the matching raw data.
path, result = next(results_iter)
mito = fst.read(Path(path) / "mito" / "s0")
# Materialize the membrane mask once. It is summed below and eroded once per
# candidate radius in the plotting cell; fst.read presumably returns a lazily
# loaded array, which would otherwise be re-read each time.
mito_mem = fst.read(Path(path) / "mito_mem" / "s0")[:]
raw = np.array(get_corresponding_raw(path, data).data)
print(f"path: {path}; sum: {np.sum(mito_mem)}")

In [None]:
# Group organelle classes by fitted radius so each distinct radius is drawn once.
radii_to_lbl = {}
for combo_name, _ in PairsIterator(pairs, result):
    radii_to_lbl.setdefault(result[combo_name][0], []).append(combo_name)

# Synthetic mito membrane for every distinct radius.
radii_to_mito_mem_mems = {radius: make_fake_mem(mito_mem, radius) for radius in radii_to_lbl}

n_cols = 1 + len(radii_to_lbl)
# squeeze=False keeps axs 2-D even with a single column (no fitted radii), so
# axs[row, col] indexing below always works.
fig, axs = plt.subplots(2, n_cols, figsize=(5 * n_cols, 15), squeeze=False)
fig.suptitle(f"{path}")
ssp = 2  # subsampling step along the first axis, for display only
slicer_kwargs = dict(play_buttons=True, play_button_pos="left")

# Column 0: raw data (top) and the annotated membrane (bottom). All later
# panels share control1's slider.
control1 = hyperslicer(raw[::ssp, ...], ax=axs[0, 0], title="raw", cmap="Greys_r", vmin=0, vmax=255, **slicer_kwargs)
_ = hyperslicer(mito_mem[::ssp, ...], ax=axs[1, 0], controls=control1, title="mito_mem", cmap="inferno", **slicer_kwargs)
for col, (radius, lbl) in enumerate(radii_to_lbl.items(), start=1):
    fake_mem = radii_to_mito_mem_mems[radius][::ssp, ...]
    panel_title = f"{radius=},{lbl}"
    # Top row: raw with the synthetic membrane overlaid at 50% opacity.
    _ = hyperslicer(raw[::ssp, ...], ax=axs[0, col], cmap="Greys_r", controls=control1, **slicer_kwargs)
    _ = hyperslicer(fake_mem, ax=axs[0, col], controls=control1, title=panel_title, cmap="inferno", alpha=0.5, **slicer_kwargs)
    # Bottom row: the synthetic membrane alone.
    _ = hyperslicer(fake_mem, ax=axs[1, col], controls=control1, title=panel_title, cmap="inferno", **slicer_kwargs)


In [None]:
# Release every open figure (and its interactive widgets) before re-plotting.
plt.close(fig="all")