In [None]:
import bioformats
import javabridge
import numpy as np
import tifffile
from tqdm import tqdm, trange

from basicpy import BaSiC

# Start the Java VM that python-bioformats needs to read the .vsi files;
# the matching javabridge.kill_vm() call is in the notebook's last cell.
# NOTE(review): javabridge is generally understood to allow only one VM start per
# Python process, so re-running this cell presumably requires a kernel restart — confirm.
javabridge.start_vm(class_path=bioformats.JARS)

In [None]:
# Load every z-plane of each selected .vsi file into one uint16 array per file.
# Resulting layout: (Z, Y, X, C) — a leading z axis stacked in front of what
# reader.read returns per plane (assumed (Y, X, C) for multi-channel data — confirm).
ind_include = [62, 61, 72, 71]
UINT16_MAX = np.iinfo(np.uint16).max


def read_vsi_zstack(path):
    """Read all z-planes of series 0 of a Bio-Formats file as a uint16 stack.

    ``reader.read`` is called with its default arguments; its output is presumably
    rescaled to [0, 1] (python-bioformats' default ``rescale=True``), so it is mapped
    back onto the uint16 range exactly as before (``* 65535``).

    Parameters
    ----------
    path : str
        Path of the image file to read.

    Returns
    -------
    numpy.ndarray
        uint16 array with the z-planes stacked along axis 0.

    Raises
    ------
    ValueError
        If the file reports no z-planes.
    """
    meta_xml = bioformats.get_omexml_metadata(path)
    size_z = bioformats.omexml.OMEXML(meta_xml).image(0).Pixels.SizeZ
    planes = []
    with bioformats.ImageReader(path) as reader:
        for z in trange(size_z):
            scaled = reader.read(z=z) * UINT16_MAX
            # Round before casting: astype alone truncates, so float error such as
            # 999.9999 would otherwise become 999. Clip guards the uint16 range.
            planes.append(np.clip(np.rint(scaled), 0, UINT16_MAX).astype(np.uint16))
    if not planes:
        raise ValueError(f"{path} contains no z-planes")
    # Stack once at the end instead of concatenating inside the loop (O(n^2) copies).
    return np.stack(planes, axis=0)


stack_list = [
    read_vsi_zstack(f"./250317_ChickEmbryo18h_SOX2_EOMES_TBXT_{ind}.vsi")
    for ind in tqdm(ind_include)
]

In [None]:
# Pool the z-planes of every loaded stack along axis 0 so the BaSiC models can be
# fitted on all tiles at once; requires all stacks to share the remaining dimensions.
full_stack = np.vstack(stack_list)

In [None]:
# Flat-/dark-field correction with BaSiC.
# One model is fitted per channel on the z-planes of ALL stacks pooled together
# (full_stack); that same per-channel model is then applied to every stack.
UINT16_MAX = np.iinfo(np.uint16).max

basic_models = []
for ch in trange(full_stack.shape[3], desc="fitting BaSiC"):
    basic = BaSiC(get_darkfield=True)
    basic.fit(full_stack[:, :, :, ch])
    # Alternative: basic.autotune(full_stack[:, :, :, ch], early_stop=True, n_iter=100)
    basic_models.append(basic)

stack_correct_list = []
for stack_test in tqdm(stack_list, desc="correcting"):
    stack_correct = np.zeros_like(stack_test)  # uint16, same shape as the input stack
    for ch, basic in enumerate(basic_models):
        corrected = np.asarray(basic.transform(stack_test[:, :, :, ch]))
        # transform's output may be float and can presumably fall below 0 after
        # darkfield subtraction (or exceed the uint16 max); assigning such values
        # directly into a uint16 array wraps them around, so floor and clip first.
        stack_correct[:, :, :, ch] = np.clip(np.floor(corrected), 0, UINT16_MAX)
    stack_correct_list.append(stack_correct)

In [None]:
def save_imagej_compatible(stack, filename, axes="ZYXC"):
    """
    Save a stack as an ImageJ hyperstack TIFF.

    tifffile only writes ImageJ metadata when ``imagej=True``, and ImageJ expects
    dimensions in TZCYX order, so the data is transposed from the caller's
    ``axes`` layout into that order before writing. Values are clipped to the
    uint16 range before casting so out-of-range values saturate instead of wrapping.

    Parameters
    ----------
    stack : array-like
        Image data laid out as described by ``axes``.
    filename : str or path-like
        Output TIFF path.
    axes : str, optional
        Axis codes of ``stack`` (a subset of "TZCYX" that includes Y and X),
        default "ZYXC".

    Raises
    ------
    ValueError
        If ``axes`` does not match ``stack.ndim``, repeats a code, lacks Y/X,
        or uses a code other than T, Z, C, Y, X.
    """
    stack = np.asarray(stack)
    axes = axes.upper()
    if len(axes) != stack.ndim or len(set(axes)) != len(axes):
        raise ValueError(f"axes {axes!r} does not describe an array of shape {stack.shape}")
    unsupported = set(axes) - set("TZCYX")
    if unsupported:
        raise ValueError(f"unsupported axis codes {sorted(unsupported)} in {axes!r}")
    if not {"Y", "X"} <= set(axes):
        raise ValueError(f"axes {axes!r} must contain both Y and X")

    # Reorder into ImageJ's canonical TZCYX order (keeping only the axes present).
    imagej_axes = "".join(code for code in "TZCYX" if code in axes)
    ordered = np.transpose(stack, [axes.index(code) for code in imagej_axes])

    uint16_max = np.iinfo(np.uint16).max
    stack_uint16 = np.ascontiguousarray(np.clip(ordered, 0, uint16_max).astype(np.uint16))

    metadata = {"axes": imagej_axes}
    if "C" in imagej_axes:
        metadata["mode"] = "composite"  # For multi-channel

    tifffile.imwrite(filename, stack_uint16, imagej=True, metadata=metadata)


# Write each corrected stack to disk; the numeric suffix is the stack's position in
# stack_correct_list, i.e. the order of ind_include in the loading cell.
for file_idx, corrected_stack in enumerate(stack_correct_list):
    out_path = f"./test_corrected_{file_idx}.tif"
    save_imagej_compatible(np.floor(corrected_stack), out_path)

In [None]:
# Shut down the Java VM started by javabridge.start_vm in the first cell.
# NOTE(review): javabridge presumably cannot restart the VM in the same process,
# so run this only once all Bio-Formats reading is finished — confirm.
javabridge.kill_vm()