# Imports

In [None]:
import glob
import itertools as it
from functools import partial
from pathlib import Path

import basicpy
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import scipy
import skimage
import zarr
from tqdm.auto import tqdm, trange

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from paulssonlab.image_analysis import workflow
from paulssonlab.image_analysis.ui import display_image

# Config

In [None]:
# nd2_filename = "/home/jqs1/scratch/jqs1/microscopy/221227daniel/Experiment.nd2"
# nd2_filename = "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2"
# Input acquisition: a split ND2 file whose pieces match the glob below.
# The pieces are passed to SplitFilename in lexicographic filename order.
# NOTE(review): plain sorted() orders "...nd2.10" before "...nd2.2"; confirm the
# piece names sort correctly if there are ever more than 9 of them.
nd2_pattern = (
    "/home/jqs1/scratch/jqs1/microscopy/230619/230619_NAO745_repressilators_split.nd2*"
)
nd2_parts = sorted(glob.glob(nd2_pattern))
# Fail here with a clear message rather than deep inside the ND2 reader when
# the path is wrong or the scratch volume is not mounted.
if not nd2_parts:
    raise FileNotFoundError(f"no files match {nd2_pattern!r}")
nd2_filename = workflow.SplitFilename(nd2_parts)

In [None]:
# Open the acquisition named by nd2_filename through the project's reader helper.
nd2 = workflow.get_nd2_reader(nd2_filename)

In [None]:
# Show the channel names; the dark/flat sampling below reads channel index c=0.
nd2.metadata["channels"]

# Dark-field / flat-field estimation (BaSiC)

In [None]:
# Show axis sizes. Check them against the hard-coded v < 260 and t < 199
# bounds used when sampling frames for BaSiC.
nd2.sizes

In [None]:
%%time
images = np.stack(
    [
        nd2.get_frame_2D(v=v, t=t, c=0)
        for v in trange(0, 260, 4)
        for t in range(0, 199, 40)
    ]
)

In [None]:
# Otsu threshold computed from the first sampled frame only. It is then applied
# to every frame when building image_weights.
# NOTE(review): frames whose intensity differs noticeably from images[0] may be
# mis-segmented; consider a per-frame or whole-stack threshold.
threshold = skimage.filters.threshold_otsu(images[0])

In [None]:
%%time
image_weights = np.stack(
    [(image < threshold).astype(np.float32) for image in tqdm(images)]
)

In [None]:
%%time
image_weights = np.stack(
    [
        (~scipy.ndimage.binary_dilation(image >= threshold, iterations=8)).astype(
            np.float32
        )
        for image in tqdm(images)
    ]
)

In [None]:
# IPython help: show BaSiC's docstring and constructor parameters.
?basicpy.BaSiC

In [None]:
%%time
basic = basicpy.BaSiC(
    get_darkfield=True, smoothness_flatfield=1, smoothness_darkfield=1
)
basic.fit(images, image_weights)

In [None]:
# Inspect the first sampled frame (v=0, t=0, c=0).
display_image(images[0], downsample=4, scale=0.9)

In [None]:
# Display the Otsu threshold of the first frame (the same computation as `threshold`).
skimage.filters.threshold_otsu(images[0])

In [None]:
# Raw boolean foreground mask of the first frame (repr only, not an image view).
images[0] > skimage.filters.threshold_otsu(images[0])

In [None]:
# Inspect sampled frame 20. With 5 timepoints per position, this is v=16, t=0.
display_image(images[20], downsample=4, scale=0.9)

In [None]:
%%time
weighted_sum = (images * image_weights).sum(axis=0)

In [None]:
%%time
sum_of_weights = image_weights.sum(axis=0)

In [None]:
# Weighted per-pixel mean over the stack. A pixel whose weights sum to zero
# (masked out in every frame) would otherwise give 0/0 with a RuntimeWarning;
# np.divide with `where` leaves such pixels as NaN explicitly instead.
# The output dtype matches plain true division (float32 for float32 inputs).
avg = np.divide(
    weighted_sum,
    sum_of_weights,
    out=np.full_like(
        weighted_sum,
        np.nan,
        dtype=np.result_type(weighted_sum, sum_of_weights, np.float32),
    ),
    where=sum_of_weights != 0,
)

In [None]:
# Raw per-pixel weight totals.
sum_of_weights

In [None]:
# Highlight pixels with low total weight. The sampling yields 65 * 5 = 325
# frames, so with 0/1 weights this presumably marks pixels masked out in more
# than 5 frames.
# NOTE(review): 320 is hard-coded; confirm the intent and update it if the
# sampling changes.
display_image(sum_of_weights < 320, downsample=4)

In [None]:
# Per-pixel weight totals as an image.
display_image(sum_of_weights, downsample=4, scale=0.9)

In [None]:
# Weighted background average. Pixels with zero total weight are NaN when
# computed by division; how display_image renders NaN is not shown here — TODO confirm.
display_image(avg, downsample=4, scale=0.9)

In [None]:
# Coverage map: zero exactly where a pixel has weight 0 in every frame, i.e.
# where a weighted average over the stack is undefined.
display_image(image_weights.max(axis=0), downsample=4)

In [None]:
%%time
avg = np.average(images, axis=0, weights=image_weights)

In [None]:
# Weight mask of the first frame (1.0 = counted as background).
display_image(image_weights[0], downsample=4, scale=0.9)

In [None]:
# Foreground mask for frame 20 (v=16, t=0), cast to uint8 for display.
# The threshold is still the one computed from images[0].
display_image(
    (images[20] > skimage.filters.threshold_otsu(images[0])).astype(np.uint8),
    downsample=4,
    scale=0.9,
)

In [None]:
# Fitted BaSiC flat-field profile.
display_image(basic.flatfield, downsample=4, scale=0.9)

In [None]:
# Fitted BaSiC flat-field profile. NOTE(review): this is an exact duplicate of
# the cell above; possibly left over from re-running the fit.
display_image(basic.flatfield, downsample=4, scale=0.9)

In [None]:
# Fitted BaSiC dark-field profile (estimated because get_darkfield=True).
display_image(basic.darkfield, downsample=4, scale=0.9)