In [None]:
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
from matplotlib.ticker import FixedLocator
from tqdm import tqdm, trange

from iss_preprocess.io import (
    get_processed_path,
    get_roi_dimensions,
    load_micromanager_metadata,
    load_ops,
)

In [None]:
def _configure_tile_grid(ax, dim, n_tiles, ops):
    """Draw approximate tile boundaries on `ax` and label its ticks with tile indices.

    Args:
        ax: Matplotlib axes to decorate.
        dim: (horizontal extent, vertical extent) of the plotted area, in plot units.
        n_tiles: (number of tiles along the horizontal axis, along the vertical axis).
        ops: Options dict; its "x_tile_direction" and "y_tile_direction" entries
            decide whether the tile labels run in reverse.
    """
    for axis_name, extent, n in zip("xy", dim, n_tiles):
        tile_size = extent / n
        # Positions come from integer tile indices rather than a float-step
        # np.arange, which can emit an extra element through rounding error
        # (and thus an extra tick label).
        edges = np.arange(n) * tile_size
        getattr(ax, f"set_{axis_name}lim")(0, extent)
        # Major ticks sit at tile centres and carry the labels; minor ticks sit
        # on tile edges and carry the grid lines.
        getattr(ax, f"set_{axis_name}ticks")(edges + tile_size / 2)
        getattr(ax, f"{axis_name}axis").set_minor_locator(FixedLocator(edges))

    x_labels = np.arange(n_tiles[0])
    if ops["x_tile_direction"] == "left_to_right":
        x_labels = x_labels[::-1]
    y_labels = np.arange(n_tiles[1])
    if ops["y_tile_direction"] != "top_to_bottom":
        y_labels = y_labels[::-1]
    ax.set_xticklabels(x_labels, rotation=90)
    ax.set_yticklabels(y_labels)

    ax.grid(which="minor", color="lightgrey")
    ax.tick_params(
        top=False,
        bottom=False,
        left=False,
        right=False,
        labelleft=True,
        labelbottom=True,
    )


def plot_genes(gene_data, data_path, roi=1, plot_grid=True):
    """Plot the spatial distribution of every gene in one ROI, one panel per gene.

    Panels are laid out three per row, sorted by gene name. Spots are drawn
    with stage ``y`` on the horizontal axis and ``dim_x - x`` on the vertical
    axis, which is then inverted. The figure is saved to
    ``<processed_path>/figures/round_overviews/gene_spots/ara_gene_spots_<roi>.png``
    and shown.

    Args:
        gene_data (pd.DataFrame): Spot table with at least ``"Gene"``, ``"x"``
            and ``"y"`` columns.
        data_path (str or Path): Dataset path understood by the
            ``iss_preprocess.io`` helpers.
        roi (int): 1-based ROI number. Used to look up the ROI dimensions and
            in the figure title and file name.
        plot_grid (bool): If True, draw approximate tile boundaries and label
            the ticks with tile indices. If False, hide the axes.

    Raises:
        ValueError: If ``gene_data`` contains no genes to plot.
    """
    plt.rcParams["figure.facecolor"] = "white"
    processed_path = get_processed_path(data_path)
    # Loaded once: needed for the tile-label directions and the ref tile in the title.
    ops = load_ops(data_path)
    roi_dims = get_roi_dimensions(data_path)[roi - 1]
    nx = roi_dims[1] + 1
    ny = roi_dims[2] + 1
    metadata = load_micromanager_metadata(data_path, "genes_round_1_1")
    # Fields 2 and 3 of the "-"-separated ROI string are used as the per-tile
    # extent (presumably width/height in pixels — TODO confirm).
    roi_fields = metadata["FrameKey-0-0-0"]["ROI"].split("-")
    x_dim = int(roi_fields[2])
    y_dim = int(roi_fields[3])
    # Approximate stitched extent, assuming 10% overlap between adjacent tiles.
    dim_x = ((nx - 1) * x_dim * 0.9) + x_dim
    dim_y = ((ny - 1) * y_dim * 0.9) + y_dim

    genes = np.sort(gene_data["Gene"].unique())
    num_genes = len(genes)
    if num_genes == 0:
        raise ValueError(f"No gene spots to plot for ROI {roi}")
    num_cols = 3
    num_rows = math.ceil(num_genes / num_cols)
    # squeeze=False keeps `axes` 2-D even when all genes fit on a single row.
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(25, 6 * num_rows), squeeze=False
    )
    for i, gene in tqdm(enumerate(genes), desc="Genes computed", total=num_genes):
        ax = axes[i // num_cols, i % num_cols]
        spots = gene_data[gene_data["Gene"] == gene]
        # Rotate the data 90 degrees counterclockwise
        rotated_x = spots["y"].values
        rotated_y = dim_x - spots["x"].values
        ax.plot(rotated_x, rotated_y, "o", c="black", markersize=0.1)
        ax.set_title(gene, fontsize=20)
        ax.set_xlabel(f"Total counts = {len(rotated_x)}")
        ax.set_aspect("equal", "box")
        ax.set_xlim(0, dim_y)
        ax.set_ylim(0, dim_x)

        if plot_grid:
            # Horizontal axis spans the ny tiles along y; vertical spans nx along x.
            _configure_tile_grid(ax, (dim_y, dim_x), (ny, nx), ops)
            ax.set_ylabel("x_coords")
        else:
            ax.axis("off")
        # Must come after the limits are (re)set above, or set_ylim undoes it.
        ax.invert_yaxis()

    # Remove the unused panels at the end of the last row.
    for i in range(num_genes, num_rows * num_cols):
        fig.delaxes(axes[i // num_cols, i % num_cols])
    if not plot_grid:
        fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None, hspace=None)
    fig.tight_layout()
    fig.subplots_adjust(hspace=0.3)
    ref_tile = ops["ref_tile"]
    mouse_name = processed_path.parent.name
    chamber_name = processed_path.name
    fig.suptitle(
        f"{mouse_name} {chamber_name} ROI: {roi}, ref tile: {ref_tile}",
        fontsize=40,
        y=1.01,
    )
    save_path = processed_path / "figures" / "round_overviews" / "gene_spots"
    save_path.mkdir(exist_ok=True, parents=True)
    print(f"Saving to: {save_path}")
    fig.savefig(save_path / f"ara_gene_spots_{roi}.png", bbox_inches="tight")
    plt.show()

In [None]:
data_path = "becalia_rabies_barseq/BRAC8498.3e/chamber_08/"
roi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# Only ROI 1 is plotted for a quick look; set `rois_to_plot = roi` to plot every ROI.
rois_to_plot = roi[:1]
# Spots scoring at or below this threshold are dropped before plotting.
min_spot_score = 0.2

processed_path = get_processed_path(data_path)
for iroi in tqdm(rois_to_plot, desc="ROI"):
    gene_spots = pd.read_pickle(
        processed_path / f"ara_genes_round_spots_{iroi}.pkl"
    ).rename(columns={"gene": "Gene"})
    gene_spots = gene_spots[gene_spots["spot_score"] > min_spot_score]
    plot_genes(gene_spots, data_path, roi=iroi, plot_grid=True)

In [None]:
from pathlib import Path


def _draw_placeholder(ax, message, **text_kwargs):
    """Write `message` centred in `ax` (axes coordinates) in place of a plot.

    Extra keyword arguments override the default text style
    (fontsize 12, alpha 0.5).
    """
    style = {"fontsize": 12, "alpha": 0.5}
    style.update(text_kwargs)
    ax.text(
        0.5, 0.5, message, ha="center", va="center", transform=ax.transAxes, **style
    )


def _hide_axis_frame(ax):
    """Remove ticks, tick labels, spines and background from `ax`.

    Unlike ``ax.axis("off")``, this keeps the title and axis labels visible,
    so the chamber and ROI headers still render.
    """
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.patch.set_visible(False)


def plot_gene_across_chambers_rois(
    base_path,
    chambers,
    rois,
    gene_to_plot,
    min_spot_score=0.2,
    xlim=(6, 11),
    ylim=(0, 4),
):
    """
    Plots the spatial distribution of a single gene across multiple chambers and ROIs
    in a single figure.

    Each column represents a chamber, and each row represents an ROI. Spots are
    drawn with ``ara_z`` on the horizontal axis and ``ara_y`` on the (inverted)
    vertical axis. Panels without data show a placeholder message instead.

    The figure is saved to
    ``<base_path>/figures/gene_overviews/<gene_to_plot>_expression_grid.png``.

    Args:
        base_path (str or Path): The base directory path containing chamber folders
            (e.g., 'becalia_rabies_barseq/BRAC8498.3e').
        chambers (list): A list of chamber names as strings
            (e.g., ['07', '08', '09', '10']).
        rois (list): A list of ROI numbers (e.g., [1, 2, ..., 10]).
        gene_to_plot (str): The name of the gene to plot (e.g., "Rorb").
        min_spot_score (float): Only spots with ``spot_score`` strictly above
            this value are plotted.
        xlim (tuple): Horizontal (``ara_z``) limits applied to every panel.
        ylim (tuple): Vertical (``ara_y``) limits applied to every panel
            before the axis is inverted.

    Raises:
        ValueError: If ``chambers`` or ``rois`` is empty.
    """
    if not chambers or not rois:
        raise ValueError("chambers and rois must both be non-empty")

    plt.rcParams["figure.facecolor"] = "white"
    num_rows = len(rois)
    num_cols = len(chambers)

    # squeeze=False ensures that 'axes' is always a 2D array, even for a single row/col.
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(4 * num_cols, 4 * num_rows), squeeze=False
    )

    print(f"Generating {num_rows}x{num_cols} grid for gene '{gene_to_plot}'...")

    # Columns are chambers, rows are ROIs.
    for col_idx, chamber_name in enumerate(tqdm(chambers, desc="Processing Chambers")):
        data_path = Path(base_path) / f"chamber_{chamber_name}"
        for row_idx, roi_num in enumerate(rois):
            ax = axes[row_idx, col_idx]

            # Headers are set before any early `continue`, so they appear even on
            # placeholder panels.
            if row_idx == 0:
                ax.set_title(f"Chamber {chamber_name}", fontsize=18, pad=20)
            if col_idx == 0:
                ax.set_ylabel(f"ROI {roi_num}", fontsize=18, labelpad=30)
            _hide_axis_frame(ax)

            try:
                processed_path = get_processed_path(data_path)
                gene_spots_file = (
                    processed_path / f"ara_genes_round_spots_{roi_num}.pkl"
                )
                if not gene_spots_file.exists():
                    _draw_placeholder(ax, "No Data")
                    continue

                # NOTE: read_pickle can execute arbitrary code; only point this
                # at trusted pipeline output.
                gene_spots = pd.read_pickle(gene_spots_file).rename(
                    columns={"gene": "Gene"}
                )
                gene_data = gene_spots[
                    (gene_spots["Gene"] == gene_to_plot)
                    & (gene_spots["spot_score"] > min_spot_score)
                ]
                if gene_data.empty:
                    _draw_placeholder(ax, "0 counts")
                    continue

                ax.plot(
                    gene_data["ara_z"].values,
                    gene_data["ara_y"].values,
                    "o",
                    c="black",
                    markersize=0.2,
                )
                ax.set_xlim(*xlim)
                ax.set_ylim(*ylim)

                # Total counts as text inside the panel for quick reference.
                ax.text(
                    0.95,
                    0.05,
                    f"N={len(gene_data)}",
                    transform=ax.transAxes,
                    fontsize=10,
                    verticalalignment="bottom",
                    horizontalalignment="right",
                )
                ax.set_aspect("equal", "box")
                ax.invert_yaxis()

            except FileNotFoundError:
                _draw_placeholder(ax, "File Not Found")
                print(f"Data file not found for Chamber {chamber_name}, ROI {roi_num}")
            except Exception as e:
                # Deliberately best-effort: one bad panel should not abort the grid.
                _draw_placeholder(ax, "Error", alpha=None, c="red")
                print(
                    f"An error occurred processing Ch{chamber_name}, ROI {roi_num}: {e}"
                )

    fig.suptitle(
        f"Gene Expression of '{gene_to_plot}' Across Chambers and ROIs",
        fontsize=30,
        y=1.0,
    )
    fig.tight_layout(rect=[0, 0, 1, 0.97])  # Leave space for the suptitle

    save_dir = Path(base_path) / "figures" / "gene_overviews"
    save_dir.mkdir(exist_ok=True, parents=True)
    save_file = save_dir / f"{gene_to_plot}_expression_grid.png"

    print(f"\nSaving plot to: {save_file}")
    fig.savefig(save_file, bbox_inches="tight", dpi=300)
    plt.show()

In [None]:
# Grid of one gene's spots: every chamber of this brain (columns) x every ROI (rows).
BASE_DATA_PATH = "becalia_rabies_barseq/BRAC8498.3e"
CHAMBER_LIST = [f"{chamber:02d}" for chamber in range(7, 11)]  # "07" .. "10"
ROI_LIST = [*range(1, 11)]  # ROIs 1..10 inclusive
GENE_TO_PLOT = "Rorb"

plot_gene_across_chambers_rois(
    gene_to_plot=GENE_TO_PLOT,
    base_path=BASE_DATA_PATH,
    chambers=CHAMBER_LIST,
    rois=ROI_LIST,
)

In [None]:
data_path = "karapir_barseq_5ht/FIAA53.3b/chamber_02/"
roi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# Only ROI 1 is plotted for a quick look; set `rois_to_plot = roi` to plot every ROI.
rois_to_plot = roi[:1]
# Spots scoring at or below this threshold are dropped before plotting.
min_spot_score = 0.2

processed_path = get_processed_path(data_path)
for iroi in tqdm(rois_to_plot, desc="ROI"):
    # `gene_spots` is left in the namespace on purpose: the interactive plot
    # cell below reads it.
    gene_spots = pd.read_pickle(
        processed_path / f"genes_round_spots_{iroi}.pkl"
    ).rename(columns={"gene": "Gene"})
    gene_spots = gene_spots[gene_spots["spot_score"] > min_spot_score]
    plot_genes(gene_spots, data_path, roi=iroi, plot_grid=True)

In [None]:
# Interactive view of `gene_spots` (the last ROI loaded by the cell above):
# one colour per gene, with per-spot QC values shown on hover.
fig = px.scatter(
    gene_spots,
    x="y",
    y="x",
    color="Gene",
    hover_data=["spot_score", "pos_pixels", "neg_pixels", "tile"],
)
fig.update_traces(marker_size=3)

fig.update_layout(
    height=900,
    plot_bgcolor="white",
    # Both axes get black grid/zero lines; y is scale-locked to x so the
    # aspect ratio is 1:1.
    xaxis=dict(
        constrain="domain",
        showgrid=True,
        gridcolor="black",
        zerolinecolor="black",
    ),
    yaxis=dict(
        scaleanchor="x",
        scaleratio=1,
        constrain="domain",
        showgrid=True,
        gridcolor="black",
        zerolinecolor="black",
    ),
    # Legend markers use a constant size regardless of the tiny data markers.
    # A single click isolates a gene; a double click toggles it.
    legend=dict(
        itemsizing="constant",
        itemclick="toggleothers",
        itemdoubleclick="toggle",
        title_text="Legend",
        font=dict(size=12),
        traceorder="normal",
    ),
)

fig.show()