In [None]:
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.ticker import FixedLocator
from tqdm import tqdm, trange

import iss_preprocess as iss

In [None]:
def _draw_tile_grid(ax, extents, tile_counts, ops):
    """Draw approximate tile boundaries on ``ax`` and label tile centres with tile indices.

    Args:
        ax: matplotlib Axes to decorate.
        extents: (horizontal, vertical) axis extents in data units.
        tile_counts: (horizontal, vertical) number of tiles along each axis.
        ops: ops dict; ``ops["x_tile_direction"]`` and ``ops["y_tile_direction"]``
            decide whether the tile indices are listed forwards or reversed.
    """
    for axis_name, extent, n_tiles in zip("xy", extents, tile_counts):
        tile_size = extent / n_tiles
        # Build positions from the integer tile count: np.arange(0, extent, step)
        # with a float step can produce one spurious extra tick.
        boundaries = np.arange(n_tiles) * tile_size
        getattr(ax, f"set_{axis_name}lim")(0, extent)
        # Major ticks sit at tile centres so their labels fall between gridlines.
        getattr(ax, f"set_{axis_name}ticks")(boundaries + tile_size / 2)
        # Minor ticks mark tile boundaries; the grid is drawn on them below.
        getattr(ax, f"{axis_name}axis").set_minor_locator(FixedLocator(boundaries))

    # NOTE(review): after the rotation in plot_genes the horizontal axis shows the
    # spots' "y" coordinate, yet its labels follow ops["x_tile_direction"] — confirm.
    x_labels = np.arange(tile_counts[0])
    if ops["x_tile_direction"] == "left_to_right":
        x_labels = x_labels[::-1]
    ax.set_xticklabels(x_labels, rotation=90)

    y_labels = np.arange(tile_counts[1])
    if ops["y_tile_direction"] != "top_to_bottom":
        y_labels = y_labels[::-1]
    ax.set_yticklabels(y_labels)

    ax.grid(which="minor", color="lightgrey")
    # Hide the tick marks themselves but keep the tile-index labels.
    ax.tick_params(
        top=False,
        bottom=False,
        left=False,
        right=False,
        labelleft=True,
        labelbottom=True,
    )
    ax.set_ylabel("x_coords")


def plot_genes(
    gene_data,
    data_path,
    roi=1,
    plot_grid=True,
    num_cols=3,
    tile_overlap=0.1,
    metadata_prefix="genes_round_1_1",
):
    """Plot every gene's spots for one ROI in a grid of subplots, save and show it.

    One panel per gene (genes sorted), ``num_cols`` panels per row. Spots are drawn
    rotated 90 degrees counterclockwise: the horizontal axis is the spots' ``y``
    and the vertical axis is ``dim_x - x``, after which the vertical axis is
    inverted. The figure is saved to
    ``<processed_path>/figures/round_overviews/gene_spots/roi_<roi>_gene_spots.png``.

    Args:
        gene_data (pd.DataFrame): spots with at least ``"Gene"``, ``"x"`` and ``"y"``
            columns (x/y presumably in stitched-ROI pixels — TODO confirm).
        data_path (str): dataset path understood by ``iss.io``.
        roi (int): 1-based ROI number; used as ``get_roi_dimensions(...)[roi - 1]``.
        plot_grid (bool): if True, draw approximate tile boundaries and tile-index
            labels; otherwise hide the axes.
        num_cols (int): number of subplot columns.
        tile_overlap (float): fractional overlap between neighbouring tiles, used
            to estimate the stitched ROI extent.
        metadata_prefix (str): acquisition whose micromanager metadata provides
            the tile size in pixels.

    Raises:
        ValueError: if ``gene_data`` contains no genes.
    """
    genes = np.sort(gene_data["Gene"].unique())
    num_genes = len(genes)
    if num_genes == 0:
        raise ValueError("gene_data contains no genes to plot")

    processed_path = iss.io.get_processed_path(data_path)
    ops = iss.io.load_ops(data_path)  # loaded once; reused for every panel
    # assumes roi_dims is (roi id, last x tile index, last y tile index) — TODO confirm
    roi_dims = iss.io.get_roi_dimensions(data_path)[roi - 1]
    nx = roi_dims[1] + 1
    ny = roi_dims[2] + 1

    # Tile size in pixels from the "ROI" field, presumably "x-y-width-height".
    metadata = iss.io.load_micromanager_metadata(data_path, metadata_prefix)
    roi_fields = metadata["FrameKey-0-0-0"]["ROI"].split("-")
    x_dim, y_dim = int(roi_fields[2]), int(roi_fields[3])
    # Approximate stitched extent: every tile after the first adds (1 - overlap) of a tile.
    tile_step = 1 - tile_overlap
    dim_x = ((nx - 1) * x_dim * tile_step) + x_dim
    dim_y = ((ny - 1) * y_dim * tile_step) + y_dim

    num_rows = math.ceil(num_genes / num_cols)
    # squeeze=False keeps `axes` 2-D even when all genes fit in a single row.
    fig, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(25, 6 * num_rows),
        squeeze=False,
        facecolor="white",
    )
    gene_column = gene_data["Gene"]
    for i, gene in tqdm(enumerate(genes), desc="Genes computed", total=num_genes):
        ax = axes[i // num_cols, i % num_cols]
        spots = gene_data.loc[gene_column == gene]
        # Rotate the data 90 degrees counterclockwise
        rotated_x = spots["y"].to_numpy()
        rotated_y = dim_x - spots["x"].to_numpy()
        ax.plot(rotated_x, rotated_y, "o", c="black", markersize=0.1)
        ax.set_title(gene, fontsize=20)
        ax.set_xlabel(f"Total counts = {len(rotated_x)}")
        ax.set_aspect("equal", "box")
        ax.set_xlim(0, dim_y)
        ax.set_ylim(0, dim_x)
        if plot_grid:
            _draw_tile_grid(ax, (dim_y, dim_x), (ny, nx), ops)
        else:
            ax.axis("off")
        # Invert only after all limits are set, so nothing undoes the inversion.
        ax.invert_yaxis()

    if not plot_grid:
        # NOTE(review): fig.tight_layout() below recomputes the subplot parameters,
        # so this adjustment is probably overridden.
        fig.subplots_adjust(left=0, bottom=0, right=1, top=1)

    # Remove the unused panels at the end of the last row.
    for i in range(num_genes, num_rows * num_cols):
        fig.delaxes(axes[i // num_cols, i % num_cols])
    fig.tight_layout()
    fig.subplots_adjust(hspace=0.3)

    ref_tile = ops["ref_tile"]
    mouse_name = processed_path.parent.name
    chamber_name = processed_path.name
    fig.suptitle(
        f"{mouse_name} {chamber_name} ROI: {roi}, ref tile: {ref_tile}",
        fontsize=40,
        y=1.01,
    )
    save_path = processed_path / "figures" / "round_overviews" / "gene_spots"
    save_path.mkdir(exist_ok=True, parents=True)
    print(f"Saving to: {save_path}")
    fig.savefig(save_path / f"roi_{roi}_gene_spots.png", bbox_inches="tight")
    plt.show()
    # Release the figure so plotting many ROIs in a loop does not accumulate memory.
    plt.close(fig)

In [None]:
data_path = "karapir_barseq_5ht/FIAA53.3b/chamber_02/"
roi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]  # every ROI in this chamber
# ROIs actually plotted; only ROI 1 for now — set `rois_to_plot = roi` to plot all.
rois_to_plot = [1]
min_spot_score = 0.2  # keep only spots whose spot_score exceeds this threshold

processed_path = iss.io.get_processed_path(data_path)
for iroi in tqdm(rois_to_plot, desc="ROI"):
    # NOTE: read_pickle can execute arbitrary code; only load trusted pipeline output.
    gene_spots = pd.read_pickle(
        processed_path / f"genes_round_spots_{iroi}.pkl"
    ).rename(columns={"gene": "Gene"})
    gene_spots = gene_spots[gene_spots["spot_score"] > min_spot_score]
    plot_genes(gene_spots, data_path, roi=iroi, plot_grid=True)

In [None]:
import plotly.express as px

# Interactive view of `gene_spots`, i.e. the spots the ROI loop above loaded last.
# Spots are coloured by gene; hovering shows per-spot QC fields.
hover_columns = ["spot_score", "pos_pixels", "neg_pixels", "tile"]
fig = px.scatter(gene_spots, x="y", y="x", color="Gene", hover_data=hover_columns)
fig.update_traces(marker_size=3)

# Black grid and zero lines on a white background, shared by both axes.
grid_axis = dict(showgrid=True, gridcolor="black", zerolinecolor="black")
fig.update_layout(
    height=900,
    plot_bgcolor="white",
    xaxis=dict(constrain="domain", **grid_axis),
    # Anchoring y to x at ratio 1 gives equal scale on both axes.
    yaxis=dict(scaleanchor="x", scaleratio=1, constrain="domain", **grid_axis),
    # Legend markers keep a constant size; a click isolates one gene,
    # a double-click toggles it.
    legend=dict(
        itemsizing="constant",
        itemclick="toggleothers",
        itemdoubleclick="toggle",
        title_text="Legend",
        font=dict(size=12),
        traceorder="normal",
    ),
)
fig.show()