In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import iss_preprocess as iss
import glob
from sklearn.preprocessing import StandardScaler
# Shared StandardScaler instance: it is fitted on the min-max-normalised
# feature table in the filtering cell (fit_transform) and reused in the GMM
# cell to map the fitted cluster centres back (inverse_transform).
scaler = StandardScaler()

In [None]:
# Load the per-tile mask DataFrames and stack them into a single table.
data_path = "becalia_rabies_barseq/BRAC8498.3e/chamber_07/"
processed_path = iss.io.get_processed_path(data_path)
prefix = "mCherry_1"
ops = iss.io.load_ops(data_path)
df_dir = processed_path / "cells"
# glob returns files in arbitrary (filesystem-dependent) order; sort so the
# row order of the concatenated table is the same on every run.
df_files = sorted(glob.glob(str(df_dir / "*.pkl")))
if not df_files:
    # pd.concat([]) would otherwise fail with "No objects to concatenate",
    # which does not say which directory was empty.
    raise FileNotFoundError(f"No per-tile .pkl files found in {df_dir}")
# NOTE: unpickling can execute arbitrary code - only load files that were
# written by a trusted source.
dfs = [pd.read_pickle(f) for f in df_files]
# Row indices of the individual tables are kept as-is (they may repeat).
df = pd.concat(dfs)

In [None]:
# Show the mCherry reference tile, loaded with filter_r=False (no filtering)
from iss_preprocess.pipeline import load_and_register_tile

ops = iss.io.load_ops(data_path)
ref_tile = ops["mcherry_ref_tile"]
print(ref_tile)

# The reference tile is stored as (roi, tile x, tile y)
roi, tilex, tiley = ref_tile
stack, _ = load_and_register_tile(
    data_path,
    tile_coors=(roi, tilex, tiley),
    prefix=prefix,
    filter_r=False,
)

# Channel index 2 is displayed; vmax caps the colour scale at 300
plt.title("Reference tile raw mCherry channel")
plt.imshow(stack[:, :, 2], vmax=300)

In [None]:
from sklearn.linear_model import LinearRegression

# Unmix the signal channel of the reference tile: regress it against the
# background channel on suitably bright pixels, then subtract the predicted
# background contribution (scaled by ops["background_coef"]).
suffix = "max"
background_ch = 3
signal_ch = 2
# Pixels at or above this value are excluded from the fit (too bright).
saturation_level = 4090

processed_path = iss.io.get_processed_path(data_path)
fname = (
    f"{prefix}_MMStack_{roi}-"
    + f"Pos{str(tilex).zfill(3)}_{str(tiley).zfill(3)}_{suffix}.tif"
)
# Built for reference only; nothing in this cell reads the file.
image_path = processed_path / prefix / fname

# Index 0 on the last axis of the registered stack is used for both channels
background_image = stack[:, :, background_ch, 0]
mixed_signal_image = stack[:, :, signal_ch, 0]

ops = iss.io.load_ops(data_path)
background_coef = ops["background_coef"]
threshold_background = ops["threshold_background"]

# Flatten to 1D arrays for the regression model
background_flat = background_image.ravel()
mixed_signal_flat = mixed_signal_image.ravel()

# Keep only pixels that are neither too dark nor too bright in BOTH channels
bright_pixels = (
    (background_flat > threshold_background)
    & (background_flat < saturation_level)
    & (mixed_signal_flat > threshold_background)
    & (mixed_signal_flat < saturation_level)
)
if not bright_pixels.any():
    # Fail early with a precise message instead of relying on the fit to raise
    raise ValueError("Not enough data passing background threshold to fit model")
background_flat = background_flat[bright_pixels].reshape(-1, 1)
mixed_signal_flat = mixed_signal_flat[bright_pixels]

# Fit signal ~ coef * background + intercept with a non-negative coefficient.
# Only the fit itself is guarded: wrapping the later steps as well would
# relabel unrelated ValueErrors as "not enough data".
model = LinearRegression(positive=True)
try:
    model.fit(background_flat, mixed_signal_flat)
except ValueError as err:
    raise ValueError(
        "Not enough data passing background threshold to fit model"
    ) from err

coef = model.coef_[0]
intercept = model.intercept_

# Predict the background component for every pixel of the tile (not just the
# pixels used for fitting) and restore the image shape
predicted_background_flat = model.predict(background_image.ravel().reshape(-1, 1))
predicted_background = predicted_background_flat.reshape(background_image.shape)

# Subtract the predicted background from the mixed signal to get the signal
# image; negative values are clipped to zero
signal_image = mixed_signal_image - (
    predicted_background * background_coef
)  # TODO: Remove fudge factor
signal_image = np.clip(signal_image, 0, None)
print(f"Image unmixed with coefficient: {coef}, intercept: {intercept}")

In [None]:
# Plot the fitted background -> signal regression on top of the pixels that
# were used for the fit
plt.figure()
plt.scatter(background_flat, mixed_signal_flat, s=1)
# The fit is a straight line, so its two end points describe it completely;
# drawing it through every (unsorted) pixel would render one overlapping
# segment per pixel for the same visual result.
line_x = np.array([[background_flat.min()], [background_flat.max()]])
plt.plot(line_x, model.predict(line_x), color="red")
plt.xlabel("Background")
plt.ylabel("Signal")
plt.title("Linear Regression")
# Annotate the fitted equation in axes coordinates (top centre of the plot)
plt.text(
    0.5,
    0.9,
    f"y = {coef:.2f}x + {intercept:.2f}",
    horizontalalignment="center",
    verticalalignment="center",
    transform=plt.gca().transAxes,
)
plt.show()

In [None]:
from skimage.filters import threshold_triangle

# Show the unmixed signal image next to its thresholded (binary) version.
# plt.subplots creates the figure itself; a separate plt.figure() call before
# it would only leave an empty extra figure in the cell output.
fig, ax = plt.subplots(1, 2, dpi=300)
ax[0].imshow(signal_image, vmax=300)
ax[0].set_title("Signal Image")
# Add a trailing singleton axis before filtering
# (presumably filter_stack works on stacks - TODO confirm)
signal = signal_image[:, :, np.newaxis]
filt = iss.image.filter_stack(
    signal, r1=ops["mcherry_r1"], r2=ops["mcherry_r2"], dtype=float
)
# Triangle threshold on the filtered stack, then drop the singleton axis again
binary = (filt > threshold_triangle(filt))[:, :, 0]
ax[1].set_title("Binary Image")
ax[1].imshow(binary)

In [None]:
# Filter the masks on size, shape and channel-3 mean intensity, then build the
# scaled feature table used for clustering.
keep = (
    (df["area"] > 600)
    & (df["area"] < 5000)
    & (df["circularity"] >= 0.7)
    & (df["circularity"] <= 1)
    & (df["solidity"] >= 0.9)
    & (df["solidity"] < 1)
    & (df["eccentricity"] <= 0.99)
    & (df["intensity_mean-3"] < 200)
)
# Explicit copy: a cluster_label column is added to df_thresh further down,
# and writing into a slice of df would raise SettingWithCopyWarning.
df_thresh = df[keep].copy()
if df_thresh.empty:
    raise ValueError("No masks pass the area/shape/intensity filters")

# Features fed to the clustering (column order matters for the cluster centres)
features = ['area', 'circularity', 'solidity', 'intensity_mean-3', 'intensity_mean-2']

# Min-max normalise each feature to [0, 1] ...
feature_min = df_thresh[features].min()
feature_range = df_thresh[features].max() - feature_min
df_thresh_norm = (df_thresh[features] - feature_min) / feature_range
# ... then standardise with the shared scaler (fitted here, reused later to
# map cluster centres back)
scaled_features = scaler.fit_transform(df_thresh_norm[features])
# Note: this table gets a fresh 0..n-1 index, so it lines up with df_thresh by
# position, not by index label.
df_scaled_features = pd.DataFrame(scaled_features, columns=features)
df_thresh

In [None]:
# Count how many masks carry each cluster label.
# NOTE(review): df_thresh["cluster_label"] is written by the GMM cell further
# down; unless the loaded tables already contain that column, this cell fails
# on a fresh top-to-bottom run - confirm, and run it after the GMM cell if so.
cluster_ids, cluster_sizes = np.unique(df_thresh["cluster_label"], return_counts=True)
cluster_ids, cluster_sizes

In [None]:
from sklearn.mixture import GaussianMixture

# Starting means for the mixture components in scaled feature space:
# one row per component, one column per entry of `features`.
initial_means = np.array(
    [
        [-0.81560289, -1.16570977, -1.16885992, 0.68591332, -0.47768646],
        [-0.08201876, 0.48188625, 0.38447341, -0.41695244, -0.42873761],
        [0.97601349, 0.61105821, 0.74187513, -0.18499336, 1.06977134],
    ]
)
# Derive the component count from the initial means so the two cannot
# disagree (original note: for good masks vs debris, adjust as necessary).
n_components = initial_means.shape[0]
gmm = GaussianMixture(
    n_components=n_components,
    means_init=initial_means,  # priors about the cluster centres
    random_state=42,
    verbose=2,
)

# Fit the model and predict a 0-based cluster label for every row
gmm.fit(df_scaled_features[features])
labels = gmm.predict(df_scaled_features[features])

# Fitted component means, still in scaled feature space
cluster_centers_scaled = gmm.means_

# Undo the StandardScaler step -> min-max-normalised space
cluster_centers_norm = scaler.inverse_transform(cluster_centers_scaled)

# Undo the min-max normalisation -> original feature units
min_values = df_thresh[features].min().values
max_values = df_thresh[features].max().values
cluster_centers = cluster_centers_norm * (max_values - min_values) + min_values

# Store 1-based labels. assign() builds a new frame instead of writing a
# column into what may be a slice of df (avoids SettingWithCopyWarning).
df_thresh = df_thresh.assign(cluster_label=labels + 1)
# Rows belonging to the reference tile
image_df = df_thresh[(df_thresh["roi"] == roi) & (df_thresh["tilex"] == tilex) & (df_thresh["tiley"] == tiley)]

In [None]:
import warnings

import seaborn as sns

# Suppress FutureWarning ("is_categorical_dtype is deprecated and will be
# removed in a future version.") BEFORE plotting: a filter installed after
# sns.pairplot has run cannot hide the warnings that call already emitted.
warnings.simplefilter(action='ignore', category=FutureWarning)

# Same columns, in the same order, as the features the clustering was fitted on
feature_names = ['area', 'circularity', 'solidity', 'intensity_mean-3', 'intensity_mean-2']

# Scatter matrix of the filtered masks, coloured by predicted cluster label.
# diag_kind=None is passed, so no separate diagonal plot kind is requested.
pairplot_fig = sns.pairplot(
    df_thresh[feature_names],
    diag_kind=None,
    plot_kws={"s": 5, "alpha": 0.3, "c": labels, "cmap": "tab10"},
)

# Overlay the cluster centres (original feature units) on every panel:
# row i uses feature i on the y axis, column j uses feature j on the x axis.
axes = pairplot_fig.axes
n_features = len(feature_names)
for i in range(n_features):
    for j in range(n_features):
        axes[i, j].scatter(cluster_centers[:, j], cluster_centers[:, i], c='red', s=50)

plt.show()

In [None]:
# Load the saved mask arrays of the reference tile.
# NOTE: allow_pickle=True lets np.load unpickle Python objects, which can
# execute arbitrary code - only use it on trusted files.
all_masks = np.load(df_dir / f"mCherry_1_masks_{roi}_{tilex}_{tiley}.npy", allow_pickle=True)
cell_masks = np.load(df_dir / f"mCherry_1_cell_masks_{roi}_{tilex}_{tiley}.npy", allow_pickle=True)

# plt.subplots creates the figure itself; a separate plt.figure() call before
# it would only leave an empty extra figure in the cell output.
fig, ax = plt.subplots(1, 2, dpi=300)
ax[0].set_title("All masks")
# NOTE(review): this panel shows `binary` (the thresholded image), while
# `all_masks` is loaded above but never displayed - confirm which is intended.
ax[0].imshow(binary)

ax[1].set_title("Cell masks")
# vmax=1 saturates the colour scale at 1
ax[1].imshow(cell_masks, vmax=1)