In [None]:
import iss_preprocess as iss
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tifffile

# Registering DAPI_1_1 overviews to reference
### Using a modified `stitch_and_register` that pads the stitched images to a common shape and also returns the unshifted target, so downsampled images can be saved before and after shifting

In [None]:
import warnings
from skimage import transform
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from image_tools.registration import phase_correlation as mpc
from image_tools.similarity_transforms import make_transform, transform_image

def _pad_to_common_shape(first, second):
    """Zero-pad two arrays at the end of each axis so that they share one shape.

    Each axis is grown to the larger of the two sizes, so an image that is taller
    but narrower than the other is handled as well.

    Returns:
        tuple: (padded first, padded second, common shape as a tuple).
    """
    common_shape = tuple(max(a, b) for a, b in zip(first.shape, second.shape))

    def _pad(arr):
        widths = [(0, full - cur) for cur, full in zip(arr.shape, common_shape)]
        if not any(after for _, after in widths):
            return arr  # already at the common shape
        return np.pad(arr, widths, mode="constant", constant_values=0)

    return _pad(first), _pad(second), common_shape


def stitch_and_register(
    data_path,
    target_prefix,
    reference_prefix=None,
    roi=1,
    downsample=3,
    ref_ch=0,
    target_ch=0,
    estimate_scale=False,
    estimate_rotation=True,
    target_projection=None,
    use_masked_correlation=False,
    debug=False,
):
    """Stitch target and reference stacks and align target to reference

    To speed up registration, images are downsampled before estimating registration
    parameters. These parameters are then applied to the full scale image.

    The reference stack always uses the "projection" from ops as suffix. The target
    uses the same by default but that can be specified with `target_projection`.

    If the two stitched images differ in shape, both are zero-padded (at the end of
    each axis) to their common, element-wise maximum shape before registration.

    The estimated transform is saved as `{target_prefix}_roi{roi}_tform_to_ref.npz`
    in the "reg" folder of the processed path.

    This does not use ops['max_shift_rounds'].

    Args:
        data_path (str): Relative path to data.
        target_prefix (str): Acquisition prefix to register.
        reference_prefix (str, optional): Acquisition prefix to register the stitched
            image to. Typically, "genes_round_1_1". If None, uses
            ops["reference_prefix"]. Defaults to None.
        roi (int, optional): ROI ID to register (as specified in MicroManager).
            Defaults to 1.
        downsample (int, optional): Downsample factor for estimating registration
            parameter. Defaults to 3.
        ref_ch (int or list, optional): Channel(s) of the reference image used for
            registration. Several channels are averaged. Defaults to 0.
        target_ch (int or list, optional): Channel(s) of the target image used for
            registration. Several channels are averaged. Defaults to 0.
        estimate_scale (bool, optional): Whether to estimate scaling between target
            and reference images. Only used when `estimate_rotation` is also True.
            Defaults to False.
        estimate_rotation (bool, optional): Whether to estimate rotation between target
            and reference images. Defaults to True.
        target_projection (str, optional): Suffix to use for target stack. If None,
            will use the value from ops. Defaults to None.
        use_masked_correlation (bool, optional): Use masked correlation for registration.
            Defaults to False.
        debug (bool, optional): If True, return full xcorr. Defaults to False.

    Returns:
        numpy.ndarray: Stitched target image before registration.
        numpy.ndarray: Stitched target image after registration.
        numpy.ndarray: Stitched reference image.
        float: Estimate rotation angle.
        numpy.ndarray: Estimated shifts, at full resolution.
        float: Estimated scaling factor.
        dict: Debug information, only if `debug` is True (None when neither rotation
            nor scale is estimated, as plain phase correlation provides none here).
    """
    warnings.warn(
        "stitching is now done on registered tiles", DeprecationWarning, stacklevel=2
    )
    ops = iss.io.load_ops(data_path)

    if target_projection is None:
        target_projection = ops[f"{target_prefix.split('_')[0].lower()}_projection"]
    if reference_prefix is None:
        reference_prefix = ops["reference_prefix"]

    ref_projection = ops[f"{reference_prefix.split('_')[0].lower()}_projection"]
    # Accept numpy integers too (e.g. channel indices taken from an array)
    if isinstance(target_ch, (int, np.integer)):
        target_ch = [target_ch]
    stitched_stack_target = None
    for ch in target_ch:
        stitched = iss.pipeline.stitch_tiles(
            data_path,
            target_prefix,
            suffix=target_projection,
            roi=roi,
            ich=ch,
            shifts_prefix=reference_prefix,
            correct_illumination=True,
            allow_quick_estimate=True,
        ).astype(
            np.single
        )  # to save memory
        if stitched_stack_target is None:
            stitched_stack_target = stitched
        else:
            stitched_stack_target += stitched
    stitched_stack_target /= len(target_ch)

    if isinstance(ref_ch, (int, np.integer)):
        ref_ch = [ref_ch]
    stitched_stack_reference = None
    for ch in ref_ch:
        stitched = iss.pipeline.stitch_tiles(
            data_path,
            prefix=reference_prefix,
            suffix=ref_projection,
            roi=roi,
            ich=ch,
            shifts_prefix=reference_prefix,
            correct_illumination=True,
            register_channels=False,
            allow_quick_estimate=True,
        ).astype(np.single)
        if stitched_stack_reference is None:
            stitched_stack_reference = stitched
        else:
            stitched_stack_reference += stitched
    stitched_stack_reference /= len(ref_ch)

    # If they have different shapes, pad both to the element-wise maximum shape.
    # (Comparing the shape tuples with `<` is lexicographic and yields negative pad
    # widths when one image is larger along one axis but smaller along the other.)
    if stitched_stack_target.shape != stitched_stack_reference.shape:
        warnings.warn("Stitched stacks have different shapes. Padding to match.")
        (
            stitched_stack_target,
            stitched_stack_reference,
            fshape,
        ) = _pad_to_common_shape(stitched_stack_target, stitched_stack_reference)
    else:
        fshape = stitched_stack_target.shape

    def prep_stack(stack, downsample):
        # Non-boolean images are clipped to [0, 99th percentile] and scaled by that
        # percentile; boolean masks are only resized.
        if stack.dtype != bool:
            ma = np.nanpercentile(stack, 99)
            stack = np.clip(stack, 0, ma)
            stack = stack / ma
        # downsample
        new_size = np.array(stack.shape) // downsample
        stack = transform.resize(stack, new_size)
        return stack

    # setup common args for registration
    kwargs = dict(
        angle_range=1.0,
        niter=3,
        nangles=11,
        upsample=False,
        debug=debug,
        max_shift=ops["max_shift2ref"] // downsample,
        min_shift=0,
        reference=prep_stack(stitched_stack_reference, downsample),
        target=prep_stack(stitched_stack_target, downsample),
    )
    if use_masked_correlation:
        kwargs["target_mask"] = prep_stack(stitched_stack_target != 0, downsample)
        kwargs["reference_mask"] = prep_stack(stitched_stack_reference != 0, downsample)

    # Stays None on the translation-only path so that `debug=True` cannot hit an
    # undefined name when building the output.
    debug_dict = None
    if estimate_scale and estimate_rotation:
        out = iss.pipeline.estimate_scale_rotation_translation(
            scale_range=0.01,
            **kwargs,
        )
        if debug:
            angle, shift, scale, debug_dict = out
        else:
            angle, shift, scale = out
    elif estimate_rotation:
        out = iss.pipeline.estimate_rotation_translation(
            **kwargs,
        )
        if debug:
            angle, shift, debug_dict = out
        else:
            angle, shift = out
        scale = 1
    else:
        shift, _, _, _ = mpc.phase_correlation(kwargs["reference"], kwargs["target"])
        scale = 1
        angle = 0
    # Shifts were estimated on downsampled images; convert to full resolution.
    # np.asarray guards against a tuple/list shift, for which `*=` would repeat the
    # sequence instead of scaling it.
    shift = np.asarray(shift) * downsample

    stitched_stack_target_transformed = transform_image(
        stitched_stack_target, scale=scale, angle=angle, shift=shift
    )

    fname = f"{target_prefix}_roi{roi}_tform_to_ref.npz"
    print(f"Saving {fname} in the reg folder")
    reg_dir = iss.io.get_processed_path(data_path) / "reg"
    reg_dir.mkdir(parents=True, exist_ok=True)  # np.savez does not create folders
    np.savez(
        reg_dir / fname,
        angle=angle,
        shift=shift,
        scale=scale,
        stitched_stack_shape=fshape,
    )
    output = [stitched_stack_target, stitched_stack_target_transformed, stitched_stack_reference, angle, shift, scale]
    if debug:
        output.append(debug_dict)
    return tuple(output)


# To run the above in parallel 
### (takes a long time; could be submitted to SLURM instead; do not re-run)

In [None]:
import iss_preprocess as iss
import tifffile
import numpy as np
from skimage import transform
from multiprocessing import Pool

def prep_stack(stack, downsample):
    """Normalise an image (unless it is boolean) and shrink it by `downsample`.

    Non-boolean input is clipped to [0, p99] and divided by p99, where p99 is its
    NaN-ignoring 99th percentile. The result is then resized with
    `skimage.transform.resize` to `shape // downsample`.

    Args:
        stack (numpy.ndarray): Image or boolean mask.
        downsample (int): Integer factor dividing each dimension.

    Returns:
        numpy.ndarray: The resized image.
    """
    is_mask = stack.dtype == bool
    if not is_mask:
        ceiling = np.nanpercentile(stack, 99)
        stack = np.clip(stack, 0, ceiling) / ceiling
    # Integer division drops any remainder pixels
    reduced_shape = np.array(stack.shape) // downsample
    return transform.resize(stack, reduced_shape)

def process_chamber_roi(args):
    """Register DAPI_1_1 to the reference for one chamber/ROI and save previews.

    Calls `stitch_and_register` (translation only, no rotation or scale) and writes
    three downsampled TIFFs to `<processed>/figures/DAPI_1_1`: the unshifted
    target, the shifted target and the reference.

    Args:
        args (tuple): `(chamber, roi)`; `chamber` is the two-digit chamber string
            inserted into the data path, `roi` the ROI number. A single tuple is
            used so the function can be passed to `Pool.map`.
    """
    chamber, roi = args
    target_prefix = "DAPI_1_1"
    # One factor for both the registration estimate and the saved previews
    downsample = 50
    data_path = f"becalia_rabies_barseq/BRAC8498.3e/chamber_{chamber}/"
    processed_path = iss.io.get_processed_path(data_path)
    print(f"Doing registration for chamber {chamber} ROI {roi}")
    (
        target_image,
        target_image_shifted,
        ref_image,
        angle,
        shifts,
        scale
    ) = stitch_and_register(
        data_path,
        target_prefix=target_prefix,
        reference_prefix=None,
        roi=roi,
        downsample=downsample,
        ref_ch=3,
        target_ch=3,
        estimate_scale=False,
        estimate_rotation=False,
        target_projection=None,
        use_masked_correlation=False,
        debug=False,
    )
    save_path = processed_path / "figures" / target_prefix
    # parents=True: the "figures" folder may not exist yet; exist_ok=True: several
    # workers may reach this line for the same chamber.
    save_path.mkdir(parents=True, exist_ok=True)
    previews = (
        (f"DAPI_1_1_roi{roi}.tif", target_image),
        (f"DAPI_1_1_roi{roi}_shifted.tif", target_image_shifted),
        (f"hyb_roi{roi}_ref.tif", ref_image),
    )
    for fname, image in previews:
        # imwrite replaces the deprecated tifffile.imsave alias
        tifffile.imwrite(save_path / fname, prep_stack(image, downsample))
    print(f"Finished registration for chamber {chamber} ROI {roi}")
    print(f"Angle: {angle}, Shifts: {shifts}, Scale: {scale}")

# Disabled on purpose: switch the condition to True to (re)run the registration of
# every chamber/ROI pair in a process pool.
if False:
    # Chambers and ROIs to process
    chambers = ["07", "08", "09", "10"]
    rois = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

    # One (chamber, roi) task per call, chambers varying slowest
    tasks = []
    for chamber in chambers:
        for roi in rois:
            tasks.append((chamber, roi))

    # Adjust the number of processes as needed
    with Pool(processes=20) as pool:
        pool.map(process_chamber_roi, tasks)


## Plot results of registration

In [None]:
import matplotlib.pyplot as plt
import tifffile

chamber = "07"
data_path = f"becalia_rabies_barseq/BRAC8498.3e/chamber_{chamber}/"
processed_path = iss.io.get_processed_path(data_path)
overview_dir = processed_path / "figures" / "DAPI_1_1"

# Define the number of rows and columns for the subplots
num_rois = 10
cols = 5
rows = (num_rois + cols - 1) // cols

# squeeze=False keeps `axs` two-dimensional even when there is a single row
fig, axs = plt.subplots(rows, cols, figsize=(20, 10), dpi=300, squeeze=False)
flat_axes = axs.flatten()

for ax, roi in zip(flat_axes, range(1, num_rois + 1)):
    # Shifted target in green, reference overlaid in semi-transparent red
    overlays = (
        (f"DAPI_1_1_roi{roi}_shifted.tif", "Greens_r", None),
        (f"hyb_roi{roi}_ref.tif", "Reds_r", 0.5),
    )
    for fname, cmap, alpha in overlays:
        try:
            image = tifffile.imread(overview_dir / fname)
        except FileNotFoundError:
            # A missing image leaves this layer empty instead of aborting the figure
            print(f"ROI {roi}: {fname} not found")
            continue
        ax.imshow(image, cmap=cmap, alpha=alpha)
    ax.set_title(f"ROI {roi}")
    ax.axis('off')

# Hide any empty subplots
for unused_ax in flat_axes[num_rois:]:
    fig.delaxes(unused_ax)

plt.tight_layout()
plt.show()

In [None]:
# Manual fixes
# Hand-noted shifts, keyed by (chamber, roi).
# NOTE(review): presumably full-resolution shifts meant to replace the automatic
# estimate for these ROIs, in the same component order as the saved "shift";
# confirm before applying them.
MANUAL_SHIFTS = {
    # ch7 roi 6
    ("07", 6): [-24450, -10900],
    # ch7 roi 8
    ("07", 8): [-24000, -10250],
    # ch8 roi 9
    # raw note: (0.9170526315789473, 0.0, array([ -35.5, -240.5]))
    # the noted shift [-35.5, -240.5] multiplied by 50 gives the values below
    ("08", 9): [-1775, -12025],
    # ch9 roi 7
    ("09", 7): [1425, -17425],
}
MANUAL_SHIFTS

## Make cropped and shifted ara coord images

In [None]:
from image_tools.similarity_transforms import transform_image

def crop_to_non_cval(im, cval: float = 0.0):
    """
    Crop the image to the bounding box of the pixels that differ from cval.

    For a 3D image the first axis is treated as channels: a pixel is kept if any
    channel differs from cval, and all channels are cropped with the same box.
    If every pixel equals cval the input is returned unchanged.

    Args:
        im (numpy.ndarray): Transformed image (can be 2D or 3D with channels first)
        cval (float): Value to fill in for pixels outside of the image. May be NaN,
            in which case NaN pixels are treated as the fill value.
    Returns:
        numpy.ndarray: Cropped image
    """
    # NaN never compares equal to itself, so it needs an explicit test
    if np.isnan(cval):
        filled = ~np.isnan(im)
    else:
        filled = im != cval

    if im.ndim == 3:
        # Assume channels are the first dimension
        filled = filled.any(axis=0)

    # Project the mask onto each axis rather than listing every filled pixel with
    # np.argwhere, which allocates an N x 2 index array for large images.
    rows = np.flatnonzero(filled.any(axis=1))
    if rows.size == 0:
        return im
    cols = np.flatnonzero(filled.any(axis=0))

    y_min, y_max = rows[0], rows[-1] + 1
    x_min, x_max = cols[0], cols[-1] + 1

    # The ellipsis keeps a leading channel axis, if present, untouched
    return im[..., y_min:y_max, x_min:x_max]

# Disabled on purpose: switch the condition to True to regenerate the shifted and
# cropped ARA coordinate images.
if False:
    # The registration shifts are divided by this factor before being applied to
    # the ARA coordinate images.
    ara_downsample_rate = 8

    # Slice numbers in the file names continue across chambers (10 ROIs each)
    ch_offset = 0
    for chamber in ["07", "08", "09", "10"]:
        data_path = f"becalia_rabies_barseq/BRAC8498.3e/chamber_{chamber}/"
        processed_path = iss.io.get_processed_path(data_path)
        coords_dir = processed_path / "register_to_ara" / "ara_coordinates"
        for roi in [1,2,3,4,5,6,7,8,9,10]:
            slice_stem = f"chamber_{chamber}_r{roi}_sl{str(roi+ch_offset).zfill(3)}"
            # account for downsampling; the context manager closes the npz file
            with np.load(processed_path / "reg" / f"DAPI_1_1_roi{roi}_tform_to_ref.npz") as tform:
                ara_shift = tform["shift"] / ara_downsample_rate
            ara_im = tifffile.imread(coords_dir / f"{slice_stem}.ome.tif_Coords.tif")
            target_shifted = transform_image(
                ara_im, scale=1, angle=0, shift=ara_shift
            )
            cropped_image = crop_to_non_cval(target_shifted, cval=0.0)
            # move the first axis to the end (original note: [z,x,y] -> [x,y,z])
            cropped_image = np.moveaxis(cropped_image, 0, -1)
            registered_path = coords_dir / f"{slice_stem}_registered.tif"
            iss.io.write_stack(cropped_image, registered_path, dtype="float32",  clip=False)
            print(f"Saved {registered_path}")
        ch_offset += 10

## Assign ARA coords to spots

In [None]:
import pandas as pd
from tqdm import tqdm
# Annotate the gene and barcode spots of every ROI with ARA information and save
# the result next to the input pickles with an "ara_" prefix.
ara_options = dict(atlas_size=10, acronyms=True, inplace=True)
for chamber in ["chamber_07", "chamber_08", "chamber_09", "chamber_10"]:
    data_path = f"becalia_rabies_barseq/BRAC8498.3e/{chamber}/"
    processed_path = iss.io.get_processed_path(data_path)
    print(f"Doing registration for {data_path}")
    roi_dims = iss.io.get_roi_dimensions(data_path)
    # First column of the ROI dimensions is used as the ROI number
    rois = roi_dims[:, 0]
    print(rois)
    for roi in tqdm(rois, total=len(rois)):
        # Gene spots
        gene_spots = pd.read_pickle(processed_path / f"genes_round_spots_{roi}.pkl")
        ara_gene_spots = iss.pipeline.spots_ara_infos(
            data_path, gene_spots, roi, **ara_options
        )
        pd.to_pickle(ara_gene_spots, processed_path / f"ara_genes_round_spots_{roi}.pkl")

        # Barcode spots
        barcode_spots = pd.read_pickle(processed_path / f"barcode_round_spots_{roi}.pkl")
        ara_barcode_spots = iss.pipeline.spots_ara_infos(
            data_path, barcode_spots, roi, **ara_options
        )
        pd.to_pickle(ara_barcode_spots, processed_path / f"ara_barcode_round_spots_{roi}.pkl")

### Barcodes

In [None]:
import pandas as pd
from tqdm import tqdm
# Annotate the manually curated rabies barcode spots of every ROI with ARA
# information and pickle the result.
spot_columns = ["x", "y","barcode_id", "mask_id", "barcode"]
for chamber in ["chamber_07", "chamber_08", "chamber_09", "chamber_10"]:
    data_path = f"becalia_rabies_barseq/BRAC8498.3e/{chamber}/"
    processed_path = iss.io.get_processed_path(data_path)
    print(f"Doing registration for {data_path}")
    roi_dims = iss.io.get_roi_dimensions(data_path)
    rois = roi_dims[:, 0]
    print(rois)
    click_dir = processed_path / "manual_starter_click"
    for roi in tqdm(rois, total=len(rois)):
        # allow_pickle=True executes pickled content: only load trusted files
        raw_spots = np.load(
            click_dir / f"BRAC8498.3e_{chamber}_{roi}_rabies_spots.npy",
            allow_pickle=True,
        )
        barcode_spots = pd.DataFrame(raw_spots, columns=spot_columns)
        # Coordinates are cast to float64 before the ARA lookup
        barcode_spots = barcode_spots.astype({"y": np.float64, "x": np.float64})
        ara_barcode_filtered_spots = iss.pipeline.spots_ara_infos(
            data_path, barcode_spots, roi, atlas_size=10, acronyms=True, inplace=True
        )
        pd.to_pickle(ara_barcode_filtered_spots, processed_path / f"ara_barcode_filtered_spots_{roi}.pkl")

### Starter

In [None]:
import pandas as pd
from tqdm import tqdm
# Annotate the starter cells listed in the per-ROI CSV files with ARA information
# and pickle the result.
for chamber in ["chamber_07", "chamber_08", "chamber_09", "chamber_10"]:
    data_path = f"becalia_rabies_barseq/BRAC8498.3e/{chamber}/"
    processed_path = iss.io.get_processed_path(data_path)
    print(f"Doing registration for {data_path}")
    roi_dims = iss.io.get_roi_dimensions(data_path)
    rois = roi_dims[:, 0]
    print(rois)
    # The CSV files live in the analysis folder one level above the chamber folder
    starter_dir = processed_path.parent / "analysis" / "starter_cells"
    for roi in tqdm(rois, total=len(rois)):
        csv_path = starter_dir / f"starter_cells_BRAC8498.3e_{chamber}_roi_{roi}.csv"
        starter_cells = pd.read_csv(csv_path, index_col=0)
        # The two data columns are relabelled as y then x
        starter_cells.columns = ["y", "x"]
        ara_starter_cells = iss.pipeline.spots_ara_infos(
            data_path, starter_cells, roi, atlas_size=10, acronyms=True, inplace=True
        )
        pd.to_pickle(ara_starter_cells, processed_path / f"ara_starter_cells_{roi}.pkl")

## Concatenate all spots into one df

In [None]:
import os
from tqdm import tqdm

# Chambers to gather, and the acquisition prefix used to list the ROIs
chambers = ["chamber_07", "chamber_08", "chamber_09", "chamber_10"]
prefix = "DAPI_1_1"
base_path = "becalia_rabies_barseq/BRAC8498.3e/"
# Per-ROI dataframes, concatenated at the end
dataframes = []

for chamber in chambers:
    data_path = f"{base_path}{chamber}/"
    print(f"Loading data for {data_path}")
    processed_path = iss.io.get_processed_path(data_path)
    # First column of the ROI dimensions is used as the ROI number
    roi_dims = iss.io.get_roi_dimensions(data_path, prefix)
    rois = roi_dims[:, 0]

    for roi in tqdm(rois, total=len(rois)):
        pkl_path = os.path.join(processed_path, f"ara_genes_round_spots_{roi}.pkl")
        if not os.path.exists(pkl_path):
            # Missing ROIs are reported and skipped
            print(f"File not found: {pkl_path}")
            continue
        ara_gene_spots = pd.read_pickle(pkl_path)
        # Tag every spot with its origin before pooling
        ara_gene_spots['chamber'] = chamber
        ara_gene_spots['roi'] = roi
        dataframes.append(ara_gene_spots)

# Pool all ROIs into a single dataframe with a fresh index
all_ara_gene_spots = pd.concat(dataframes, ignore_index=True)
all_ara_gene_spots

In [None]:
import os
from tqdm import tqdm
# Chambers to gather, and the acquisition prefix used to list the ROIs
chambers = ["chamber_07", "chamber_08", "chamber_09", "chamber_10"]
prefix = "DAPI_1_1"
base_path = "becalia_rabies_barseq/BRAC8498.3e/"
# Per-ROI dataframes, concatenated at the end
dataframes = []
for chamber in chambers:
    data_path = f"{base_path}{chamber}/"
    print(f"Loading data for {data_path}")
    processed_path = iss.io.get_processed_path(data_path)
    # First column of the ROI dimensions is used as the ROI number
    roi_dims = iss.io.get_roi_dimensions(data_path, prefix)
    rois = roi_dims[:, 0]

    for roi in tqdm(rois, total=len(rois)):
        # Filtered barcode spots annotated with ARA information
        pkl_path = os.path.join(processed_path, f"ara_barcode_filtered_spots_{roi}.pkl")
        if not os.path.exists(pkl_path):
            # Missing ROIs are reported and skipped
            print(f"File not found: {pkl_path}")
            continue
        ara_barcode_spots = pd.read_pickle(pkl_path)
        # Tag every spot with its origin before pooling
        ara_barcode_spots['chamber'] = chamber
        ara_barcode_spots['roi'] = roi
        dataframes.append(ara_barcode_spots)
# Pool all ROIs into a single dataframe with a fresh index
all_ara_barcode_spots = pd.concat(dataframes, ignore_index=True)
all_ara_barcode_spots

In [None]:
import os
from tqdm import tqdm
# Chambers to gather, and the acquisition prefix used to list the ROIs
chambers = ["chamber_07", "chamber_08", "chamber_09", "chamber_10"]
prefix = "DAPI_1_1"
base_path = "becalia_rabies_barseq/BRAC8498.3e/"
# Per-ROI dataframes, concatenated at the end
dataframes = []
for chamber in chambers:
    data_path = f"{base_path}{chamber}/"
    print(f"Loading data for {data_path}")
    processed_path = iss.io.get_processed_path(data_path)
    # First column of the ROI dimensions is used as the ROI number
    roi_dims = iss.io.get_roi_dimensions(data_path, prefix)
    rois = roi_dims[:, 0]

    for roi in tqdm(rois, total=len(rois)):
        # Starter cells annotated with ARA information
        pkl_path = os.path.join(processed_path, f"ara_starter_cells_{roi}.pkl")
        if not os.path.exists(pkl_path):
            # Missing ROIs are reported and skipped
            print(f"File not found: {pkl_path}")
            continue
        ara_starter_cells = pd.read_pickle(pkl_path)
        # Tag every cell with its origin before pooling
        ara_starter_cells['chamber'] = chamber
        ara_starter_cells['roi'] = roi
        dataframes.append(ara_starter_cells)
# Pool all ROIs into a single dataframe with a fresh index
all_ara_starter_cells = pd.concat(dataframes, ignore_index=True)
all_ara_starter_cells

# Transform coords for 2D plotting
### Using chamber_07 ROI 1 as reference, fit by least squares the linear map from its ARA coords to its raw x, y coords, then apply that map to the ARA coords of all other ROIs to place them in a common 2D frame
This allows checking the registration fit between planes.

In [None]:
import numpy as np
from scipy.linalg import lstsq

# Reference plane: its raw x, y coordinates define the common 2D frame
roi_reference = 1
chamber_reference = 'chamber_07'

df = all_ara_gene_spots.copy()
df['z'] = 0

ara_cols = ['ara_x', 'ara_y', 'ara_z']
raw_cols = ['x', 'y', 'z']

# Reference rows are those of the reference ROI *in the reference chamber*; all
# the other rows get transformed. (Selecting the others with `roi != 1` alone
# would silently drop ROI 1 of every other chamber from the result.)
is_reference = (df['roi'] == roi_reference) & (df['chamber'] == chamber_reference)
df_roi_1 = df[is_reference]
# .copy() so that the coordinate columns can be overwritten without assigning
# into a slice of df
df_roi_not_1 = df[~is_reference].copy()

if df_roi_1.empty:
    raise ValueError(
        f"No spots found for reference {chamber_reference} ROI {roi_reference}"
    )

# Least-squares linear map T such that ara coords @ T ~= raw (x, y, 0) on the
# reference plane
A = df_roi_1[ara_cols].to_numpy()
b = df_roi_1[raw_cols].to_numpy()

T, residuals, _, _ = lstsq(A, b)

# Apply the transformation matrix T to the ARA coords of all non-reference rows
transformed_coordinates = np.dot(df_roi_not_1[ara_cols].to_numpy(), T)

# Overwrite 'x', 'y', 'z' with the transformed coordinates
df_roi_not_1['x'] = transformed_coordinates[:, 0]
df_roi_not_1['y'] = transformed_coordinates[:, 1]
df_roi_not_1['z'] = transformed_coordinates[:, 2]

# Combine the reference rows (unchanged) with the transformed rows
df_transformed_genes = pd.concat([df_roi_1, df_roi_not_1])

# Display the transformed DataFrame
df_transformed_genes


In [None]:
# Remove spots outside brain.
# .copy() makes the result an independent frame, so adding columns to it later does
# not trigger pandas' SettingWithCopyWarning.
df_transformed_genes = df_transformed_genes[df_transformed_genes['area_acronym'] != 'outside'].copy()

## Plot all ROIs in 2D if desired
### (slow)

In [None]:
# Chambers (rows of the grid) and ROIs (columns of the grid)
chambers = ["chamber_07", "chamber_08", "chamber_09", "chamber_10"]
rois = list(range(1, 11))

# One panel per chamber/ROI pair
fig, axs = plt.subplots(len(chambers), len(rois), figsize=(20, 15))
last_row = len(chambers) - 1

for i, chamber in tqdm(enumerate(chambers), total=len(chambers)):
    in_chamber = df_transformed_genes['chamber'] == chamber
    for j, roi in tqdm(enumerate(rois), total=len(rois)):
        ax = axs[i, j]
        # Spots of the current chamber and ROI
        df_filtered = df_transformed_genes[in_chamber & (df_transformed_genes['roi'] == roi)]

        # Tiny, nearly transparent markers so that density remains visible
        ax.scatter(df_filtered['x'], df_filtered['y'], s=0.01, alpha=0.05, c="black")
        ax.set_title(f'{chamber}, ROI: {roi}')
        ax.axis('off')
        ax.set_aspect('equal')
        # Axis labels on the bottom row and first column only
        if i == last_row:
            ax.set_xlabel('X')
        if j == 0:
            ax.set_ylabel('Y')

# Adjust layout
plt.tight_layout()
plt.show()

### Or plot any two planes

In [None]:
# Overlay two planes of chamber_07 in the common 2D frame: ROI 1 in the default
# colour and ROI 2 in red.
in_chamber_07 = df_transformed_genes["chamber"] == "chamber_07"

plane_one = df_transformed_genes[(df_transformed_genes["roi"] == 1) & in_chamber_07]
plt.scatter(plane_one["x"], plane_one["y"], s=0.1, alpha=0.1)

plane_two = df_transformed_genes[(df_transformed_genes["roi"] == 2) & in_chamber_07]
plt.scatter(plane_two["x"], plane_two["y"], s=0.1, alpha=0.03, c="red")

# Interactive 3D plotting

In [None]:
import pandas as pd
import numpy as np
import plotly.express as px
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# Requires df_transformed_genes to be defined by an earlier cell

# Factorize the area_id column into consecutive ids, stored as strings so that
# plotly treats them as categories
df_transformed_genes['area_ids_continuous'] = pd.factorize(df_transformed_genes['area_id'])[0].astype(str)

# Maximum number of points to plot, and the seed making the subset reproducible
subset_size = 100000
random_seed = 0

# Randomly select a subset of points; capped at the number of available spots,
# since DataFrame.sample raises when asked for more rows than exist
n_points = min(subset_size, len(df_transformed_genes))
df_subset = df_transformed_genes.sample(n=n_points, random_state=random_seed)
df_subset = df_subset.sort_values(by='area_acronym')
# Create a list of colors using the tab20 colormap from matplotlib
tab20 = plt.get_cmap('tab20').colors
tab20_hex = [mcolors.rgb2hex(color) for color in tab20]

# Generate a list of colors mapped to the unique area_ids_continuous
unique_area_ids = df_subset['area_ids_continuous'].unique()
color_map = {area_id: tab20_hex[i % len(tab20_hex)] for i, area_id in enumerate(unique_area_ids)}

# Apply the color mapping to the DataFrame
df_subset['color'] = df_subset['area_ids_continuous'].map(color_map)

# Create the scatter plot using Plotly Express
fig = px.scatter_3d(
    df_subset,
    x='ara_x',
    y='ara_y',
    z='ara_z',
    color='area_ids_continuous',
    color_discrete_sequence=tab20_hex,
    title='Interactive 3D Scatter Plot Colored by area_ids_continuous',
    width=2000,
    height=1200
)

fig.update_traces(marker=dict(size=2))
# Show area acronyms rather than the factorized ids in the legend
area_acronym_map = df_transformed_genes.set_index('area_ids_continuous')['area_acronym'].to_dict()
fig.for_each_trace(lambda t: t.update(name = area_acronym_map[t.name]))


# Update legend marker size
fig.update_layout(
    legend=dict(
        itemsizing='constant',
        itemclick='toggleothers',
        itemdoubleclick='toggle',
        title_text='Legend',
        font=dict(size=12),
        traceorder='normal'
    )
)

# Remove the background grid and axes (same settings for the three axes)
hidden_axis = dict(
    showbackground=False,
    showgrid=False,
    zeroline=False,
    visible=False
)
fig.update_layout(
    scene=dict(
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        zaxis=dict(hidden_axis),
        bgcolor='white'  # Set the background color to white
    )
)
# Show the plot
fig.show()

### Starters

In [None]:
import pandas as pd
import numpy as np
import plotly.express as px
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# Requires all_ara_starter_cells to be defined by an earlier cell

# Consecutive ids per area_id, stored as strings so plotly treats them as categories
all_ara_starter_cells['area_ids_continuous'] = pd.factorize(all_ara_starter_cells['area_id'])[0].astype(str)

# All starter cells are plotted, ordered by area acronym
df_subset = all_ara_starter_cells.sort_values(by='area_acronym')

# tab20 colours from matplotlib, as hex strings
tab20 = plt.get_cmap('tab20').colors
tab20_hex = [mcolors.rgb2hex(color) for color in tab20]

# One colour per area id, cycling through tab20
unique_area_ids = df_subset['area_ids_continuous'].unique()
color_map = {}
for i, area_id in enumerate(unique_area_ids):
    color_map[area_id] = tab20_hex[i % len(tab20_hex)]

# Store the colour of each cell in the DataFrame
df_subset['color'] = df_subset['area_ids_continuous'].map(color_map)

# 3D scatter in ARA coordinates, one trace per area id
fig = px.scatter_3d(
    df_subset,
    x='ara_x',
    y='ara_y',
    z='ara_z',
    color='area_ids_continuous',
    color_discrete_sequence=tab20_hex,
    title='Interactive 3D Scatter Plot Colored by area_ids_continuous',
    width=2000,
    height=1200
)

fig.update_traces(marker=dict(size=2))
# Legend entries show area acronyms instead of the factorized ids
area_acronym_map = all_ara_starter_cells.set_index('area_ids_continuous')['area_acronym'].to_dict()
fig.for_each_trace(lambda t: t.update(name = area_acronym_map[t.name]))

# Legend appearance and click behaviour
legend_style = dict(
    itemsizing='constant',
    itemclick='toggleothers',
    itemdoubleclick='toggle',
    title_text='Legend',
    font=dict(size=12),
    traceorder='normal'
)
fig.update_layout(legend=legend_style)

# Hide background, grid and axes; the same settings apply to x, y and z
hidden_axis = dict(
    showbackground=False,
    showgrid=False,
    zeroline=False,
    visible=False
)
fig.update_layout(
    scene=dict(
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        zaxis=dict(hidden_axis),
        bgcolor='white'  # white scene background
    )
)
# Show the plot
fig.show()

### Barcodes

In [None]:
import pandas as pd
import numpy as np
import plotly.express as px
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# Requires all_ara_barcode_spots to be defined by an earlier cell

# One id per distinct barcode, stored as strings so plotly treats them as categories
all_ara_barcode_spots['universal_barcode_id'] = pd.factorize(all_ara_barcode_spots['barcode'])[0].astype(str)

# Drop spots whose area_acronym is 'outside'
inside_brain = all_ara_barcode_spots['area_acronym'] != 'outside'
all_ara_barcode_spots = all_ara_barcode_spots[inside_brain]

# All remaining spots are plotted, ordered by area acronym
df_subset = all_ara_barcode_spots.sort_values(by='area_acronym')

# tab20 colours from matplotlib, as hex strings
tab20 = plt.get_cmap('tab20').colors
tab20_hex = [mcolors.rgb2hex(color) for color in tab20]

# One colour per barcode id, cycling through tab20
unique_barcode_ids = df_subset['universal_barcode_id'].unique()
color_map = {}
for i, barcode_id in enumerate(unique_barcode_ids):
    color_map[barcode_id] = tab20_hex[i % len(tab20_hex)]

# Store the colour of each spot in the DataFrame
df_subset['color'] = df_subset['universal_barcode_id'].map(color_map)

# Barcode ids ordered by number of spots (value_counts sorts by count, descending)
barcode_counts = df_subset['universal_barcode_id'].value_counts()
sorted_barcodes = barcode_counts.index.tolist()

# 3D scatter in ARA coordinates, one trace per barcode in the order above
fig = px.scatter_3d(
    df_subset,
    x='ara_x',
    y='ara_y',
    z='ara_z',
    color='universal_barcode_id',
    category_orders={'universal_barcode_id': sorted_barcodes},
    color_discrete_sequence=tab20_hex,
    title='Interactive 3D Scatter Plot Colored by universal_barcode_id',
    width=2000,
    height=1200
)
fig.update_traces(marker=dict(size=2))

# Legend entries show the barcode sequences instead of the factorized ids
barcode_id_map = all_ara_barcode_spots.set_index('universal_barcode_id')['barcode'].to_dict()
fig.for_each_trace(lambda t: t.update(name=barcode_id_map[t.name]))

# Legend appearance and click behaviour
legend_style = dict(
    itemsizing='constant',
    itemclick='toggleothers',
    itemdoubleclick='toggle',
    title_text='Legend',
    font=dict(size=12),
)
fig.update_layout(legend=legend_style)

# Hide background, grid and axes; the same settings apply to x, y and z
hidden_axis = dict(
    showbackground=False,
    showgrid=False,
    zeroline=False,
    visible=False
)
fig.update_layout(
    scene=dict(
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        zaxis=dict(hidden_axis),
        bgcolor='white'  # white scene background
    )
)

# Show the plot
fig.show()


In [None]:
import iss_analysis as issa
# Identify the dataset: project / mouse and the error-corrected barcode dataset name.
project = "becalia_rabies_barseq"
mouse = "BRAC8498.3e"

error_correction_ds_name = "BRAC8498.3e_error_corrected_barcodes_16"
# Mouse-level data path; the "analysis" folder lives under its processed path.
data_path = f"{project}/{mouse}"
analysis_folder = iss.io.get_processed_path(data_path) / "analysis"

# get_barcode_in_cells returns three objects, unpacked here.
# NOTE(review): presumably per-spot table, per-cell barcode table and per-cell
# properties (going by the variable names) — confirm against iss_analysis.
(
    rab_spot_df,
    rab_cells_barcodes,
    rab_cells_properties,
) = issa.segment.get_barcode_in_cells(
    project,
    mouse,
    error_correction_ds_name,
    valid_chambers=None,
    save_folder=None,
    verbose=True,
)
# Load the starter cell positions and match them to the barcoded cells.
# `redo=False` is passed through to match_starter_to_barcodes.
starters_positions = issa.io.get_starter_cells(project, mouse)
rabies_cell_properties = issa.segment.match_starter_to_barcodes(
    project,
    mouse,
    rab_cells_properties,
    rab_spot_df,
    starters=starters_positions,
    redo=False,
)
# Preview the first rows of the matched table.
rabies_cell_properties.head()


In [None]:
import pandas as pd
from tqdm import tqdm
# For every chamber and ROI, select the matching rows of `rabies_cell_properties`,
# run `iss.pipeline.spots_ara_infos` on them and cache the result as one pickle per
# ROI in the chamber's processed folder.
for chamber in ["chamber_07", "chamber_08", "chamber_09", "chamber_10"]:
    data_path = "becalia_rabies_barseq/BRAC8498.3e/" + chamber + "/"
    processed_path = iss.io.get_processed_path(data_path)
    print(f"Doing registration for {data_path}")
    roi_dims = iss.io.get_roi_dimensions(data_path)
    # First column of the ROI dimensions array holds the ROI ids.
    rois = roi_dims[:, 0]
    print(rois)
    for roi in tqdm(rois, total=len(rois)):
        in_this_roi = (rabies_cell_properties["chamber"] == chamber) & (
            rabies_cell_properties["roi"] == roi
        )
        # Explicit copy: `inplace=True` below writes into the frame it receives, and
        # writing into a boolean-mask slice of `rabies_cell_properties` would trigger
        # pandas' SettingWithCopyWarning.
        starter_cells = rabies_cell_properties[in_this_roi].copy()
        ara_starter_cells = iss.pipeline.spots_ara_infos(
            data_path, starter_cells, roi, atlas_size=10, acronyms=True, inplace=True
        )
        pd.to_pickle(ara_starter_cells, processed_path / f"ara_starter_cells_{roi}.pkl")

In [None]:
import os
from tqdm import tqdm
# Reload the per-ROI `ara_starter_cells_<roi>.pkl` files of every chamber and
# concatenate them into a single table, `all_ara_starter_cells`.

# Define the chambers and the prefix
chambers = ["chamber_07", "chamber_08", "chamber_09", "chamber_10"]
prefix = "DAPI_1_1"
base_path = "becalia_rabies_barseq/BRAC8498.3e/"
# Initialize an empty list to collect dataframes
dataframes = []
for chamber in chambers:
    # Build the chamber path from `base_path` instead of repeating the literal.
    data_path = base_path + chamber + "/"
    print(f"Loading data for {data_path}")
    processed_path = iss.io.get_processed_path(data_path)
    # Get the ROI dimensions; the first column holds the ROI ids
    roi_dims = iss.io.get_roi_dimensions(data_path, prefix)
    rois = roi_dims[:, 0]

    for roi in tqdm(rois, total=len(rois)):
        # Load the starter-cell pickle for this ROI
        pkl_path = os.path.join(processed_path, f"ara_starter_cells_{roi}.pkl")
        if os.path.exists(pkl_path):
            ara_starter_cells = pd.read_pickle(pkl_path)
            # Add (or overwrite) columns for chamber and roi
            ara_starter_cells['chamber'] = chamber
            ara_starter_cells['roi'] = roi
            # Append the dataframe to the list
            dataframes.append(ara_starter_cells)
        else:
            print(f"File not found: {pkl_path}")

# pd.concat raises an opaque ValueError on an empty list; fail with an explicit
# message (same exception type) when no pickle was found at all.
if not dataframes:
    raise ValueError(
        f"No ara_starter_cells_<roi>.pkl file found for any chamber under {base_path}"
    )
# Concatenate all dataframes into a single dataframe
all_ara_starter_cells = pd.concat(dataframes, ignore_index=True)
all_ara_starter_cells

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# `starters_positions` is expected to be a pandas DataFrame with 'chamber' and
# 'roi' columns; count its rows for every (chamber, roi) pair.
cell_counts = (
    starters_positions
    .groupby(['chamber', 'roi'])
    .size()
    .reset_index(name='counts')
)

# One bar per chamber/ROI pair, labelled "<chamber>-<roi>".
bar_labels = cell_counts['chamber'] + '-' + cell_counts['roi'].astype(str)
counts_fig, counts_ax = plt.subplots(figsize=(10, 6))
counts_ax.bar(bar_labels, cell_counts['counts'])
counts_ax.set_xlabel('Chamber-ROI')
counts_ax.set_ylabel('Number of Cells')
counts_ax.set_title('Number of Starter cells per Chamber/ROI')
# Vertical tick labels so the chamber/ROI names do not overlap.
plt.xticks(rotation=90)
plt.show()

In [None]:
import pandas as pd
import numpy as np
import plotly.express as px
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# Interactive 3D scatter of barcode spots in ARA coordinates, coloured by barcode,
# with the starter cells overlaid as large black markers.
# Requires `all_ara_barcode_spots` and `all_ara_starter_cells` from earlier cells.

# Factorize the barcode column to create universal_barcode_id (one string id per
# distinct barcode), then drop rows where area_acronym is 'outside'.
# A fresh frame is built (assign + copy) instead of writing into the existing one,
# so later column assignments do not hit a filtered slice.
all_ara_barcode_spots = (
    all_ara_barcode_spots
    .assign(universal_barcode_id=lambda df: pd.factorize(df['barcode'])[0].astype(str))
    .loc[lambda df: df['area_acronym'] != 'outside']
    .copy()
)

df_subset = all_ara_barcode_spots.sort_values(by='area_acronym')

# Create a list of colors using the tab20 colormap from matplotlib
tab20 = plt.get_cmap('tab20').colors
tab20_hex = [mcolors.rgb2hex(color) for color in tab20]

# Generate a list of colors mapped to the unique universal_barcode_id
unique_barcode_ids = df_subset['universal_barcode_id'].unique()
color_map = {barcode_id: tab20_hex[i % len(tab20_hex)] for i, barcode_id in enumerate(unique_barcode_ids)}

# Store the mapped color on the DataFrame. The scatter call below colors by
# 'universal_barcode_id' with `color_discrete_sequence`, not by this column.
df_subset['color'] = df_subset['universal_barcode_id'].map(color_map)

# Barcode ids ordered by number of spots (value_counts sorts descending)
barcode_counts = df_subset['universal_barcode_id'].value_counts()
sorted_barcodes = barcode_counts.index.tolist()

# Filter starter cells to include only those with starter == True.
# `.copy()` so the helper columns below are added to an independent frame rather
# than to a slice of `all_ara_starter_cells` (avoids SettingWithCopyWarning).
starter_cells = all_ara_starter_cells[all_ara_starter_cells['starter'] == True].copy()
starter_cells["size"] = 20
starter_cells["color"] = "black"
# NOTE(review): the divisor 8 is undocumented and this column is not used by the
# plot below (which uses 'ara_z') — confirm its meaning.
starter_cells["corrected_ara_z"] = starter_cells["ara_z"] / 8

# Create the scatter plot using Plotly Express
fig = px.scatter_3d(
    df_subset,
    x='ara_x',
    y='ara_y',
    z='ara_z',
    color='universal_barcode_id',
    category_orders={'universal_barcode_id': sorted_barcodes},
    color_discrete_sequence=tab20_hex,
    title='Interactive 3D Scatter Plot Colored by universal_barcode_id',
    width=2000,
    height=1200
)
fig.update_traces(marker=dict(size=2))

# Add starter cells to the plot.
# Plotly Express treats the 'color' column as a category, so without an explicit
# map the value "black" would be drawn with the first palette color instead of
# black. Only the trace is reused, so a figure title would be discarded; the
# trace is named explicitly for the legend instead.
starter_trace = px.scatter_3d(
    starter_cells,
    x='ara_x',
    y='ara_y',
    z='ara_z',
    color="color",
    color_discrete_map={"black": "black"},
    size="size",
).data[0]
starter_trace.name = "Starter cells"
fig.add_trace(starter_trace)

# Map barcode IDs to the actual barcode sequences (one row per id is enough);
# traces whose name is not a barcode id (the starter cells) keep their name.
barcode_id_map = (
    all_ara_barcode_spots
    .drop_duplicates('universal_barcode_id')
    .set_index('universal_barcode_id')['barcode']
    .to_dict()
)
fig.for_each_trace(lambda t: t.update(name=barcode_id_map.get(t.name, t.name)))

# Update legend marker size
fig.update_layout(
    legend=dict(
        itemsizing='constant',
        itemclick='toggleothers',
        itemdoubleclick='toggle',
        title_text='Legend',
        font=dict(size=12),
    )
)

# Calculate the ranges for each axis
x_range = df_subset['ara_x'].max() - df_subset['ara_x'].min()
y_range = df_subset['ara_y'].max() - df_subset['ara_y'].min()
z_range = df_subset['ara_z'].max() - df_subset['ara_z'].min()

# Aspect ratio proportional to the data extent of each axis.
# NOTE(review): the raw ranges are passed unchanged as `aspectratio` — confirm the
# resulting scene scale/zoom is the intended one.
aspect_ratio = dict(x=x_range, y=y_range, z=z_range)

# Same styling for the three axes: no background, grid, zero line or axis.
hidden_axis = dict(
    showbackground=False,
    showgrid=False,
    zeroline=False,
    visible=False
)

# Update the layout with equal aspect ratio
fig.update_layout(
    scene=dict(
        aspectmode='manual',
        aspectratio=aspect_ratio,
        xaxis=hidden_axis,
        yaxis=hidden_axis,
        zaxis=hidden_axis,
        bgcolor='white'  # Set the background color to white
    )
)

#Show the plot
fig.show()


In [None]:
# Add a rescaled z coordinate column. `.assign` returns a new frame, so this does
# not write into a boolean-filtered slice (no SettingWithCopyWarning).
# NOTE(review): the divisor 8 is undocumented — confirm what it corrects for.
all_ara_barcode_spots = all_ara_barcode_spots.assign(
    corrected_ara_z=all_ara_barcode_spots["ara_z"] / 8
)

In [None]:
# Display the current `starter_cells` DataFrame for inspection.
starter_cells

In [None]:
#save all the concatenated dataframes
all_ara_gene_spots.to_pickle(processed_path.parent / "analysis" / "ara_gene_spots.pkl")
all_ara_barcode_spots.to_pickle(processed_path.parent / "analysis" / "ara_barcode_spots.pkl")
all_ara_starter_cells.to_pickle(processed_path.parent / "analysis" / "ara_starter_cells.pkl")

In [None]:
# Display the concatenated `all_ara_starter_cells` DataFrame for inspection.
all_ara_starter_cells