In [2]:
%load_ext autoreload
%autoreload 2

import os
import pandas as pd
import numpy as np
from pathlib import Path
import natsort
import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
import sys

from lib import check_padlocks
import config

In [None]:
# Enable tqdm progress bars for pandas apply-style calls.
tqdm.pandas()
blast_file_dir = Path("/nemo/lab/znamenskiyp/scratch/olfrs_monahan/blast_queries/")
check_padlocks.loaddb("mouse", config)

# One gene per BLAST output file found in the query directory.
genes = [
    os.path.basename(file).split("_query")[0]
    for file in os.listdir(blast_file_dir)
    if file.endswith("query_blast.out")
]
print(genes)
print(len(genes))

# Run the off-target search per gene and keep whatever the helper reports as failed.
failed_matches = []
for gene in genes:
    failed_match = check_padlocks.find_off_targets(gene, blast_file_dir)
    if failed_match:
        failed_matches.append(failed_match)
if failed_matches:
    # Surface failures instead of silently collecting them.
    print(f"find_off_targets reported {len(failed_matches)} failed match(es): {failed_matches}")

# Build one DataFrame from all the files that end with _off_targets.out
files = glob.glob(str(blast_file_dir / "*_off_targets.out"))
if not files:
    raise FileNotFoundError(f"No *_off_targets.out files found in {blast_file_dir}")
# The first row of each file is its header. ignore_index=True gives every row a unique
# label; without it each file contributes labels 0..n and label-based lookups further
# down the notebook (df.loc[...], set(df.index[...])) would mix rows from different files.
df = pd.concat([pd.read_csv(file, header=0) for file in files], ignore_index=True)
# Natural-sort the rows on the first data column (named "0" in the files)
df = df.iloc[natsort.index_natsorted(df["0"])]
header = [
    "query",
    "subject",
    "percentage identity",
    "length",
    "mismatches",
    "gaps",
    "qstart",
    "qend",
    "sstart",
    "send",
    "evalue",
    "bitscore",
    "qseq",
    "sseq",
    "homology_candidate_hit",
    "gene",
]
# Drop the saved index column, then apply readable column names
df = df.drop(columns=["Unnamed: 0"])
df.columns = header
df = df.sort_values(by=["evalue"])

precomputed_variants = check_padlocks.precompute_variants(df)

# Get unique variants to reduce the number of lookups
unique_variants = df["subject"].unique()

# Get genes for unique variants
genes_dict = check_padlocks.find_genes_from_variants(unique_variants, "mouse", config)

# Map genes back to the DataFrame
df["blast_target"] = df["subject"].map(genes_dict)

# Keep only the first gene of each non-empty list; anything else becomes None
df["blast_target"] = df["blast_target"].apply(
    lambda x: x[0] if isinstance(x, list) and len(x) > 0 else None
)

# Running the optimized processing function
df = check_padlocks.process_dataframe(
    df, armlength=20, tm_threshold=37, precomputed_variants=precomputed_variants
)

In [None]:
from pathlib import Path


def create_missing_fasta_list(fasta_list_path):
    """Write ``fasta_list2.txt`` listing query FASTA files that have no BLAST output yet.

    The list file holds one FASTA path per line (blank lines are skipped). The
    directory of the first entry is taken as the working folder: BLAST outputs
    (``*_query_blast.out``) are looked up there and ``fasta_list2.txt`` is written
    there. Only entries named ``<gene>_query.fasta`` are considered.

    Args:
        fasta_list_path: Path to the list file; ``~`` is expanded.

    Raises:
        ValueError: If the list file has no non-blank lines.
    """
    list_file = Path(fasta_list_path).expanduser().resolve()

    # Non-blank lines only, each interpreted as a path
    with open(list_file) as handle:
        stripped_lines = (raw.strip() for raw in handle)
        fasta_paths = [Path(text) for text in stripped_lines if text]

    if len(fasta_paths) == 0:
        raise ValueError("fasta_list.txt is empty.")

    folder = fasta_paths[0].parent

    # Gene names taken from the listed FASTA files
    genes_in_fasta = {
        entry.name.rsplit("_query.fasta", 1)[0]
        for entry in fasta_paths
        if entry.name.endswith("_query.fasta")
    }

    # Gene names that already have a BLAST output in the folder
    genes_with_blast = set()
    for blast_file in folder.glob("*_query_blast.out"):
        genes_with_blast.add(blast_file.name.rsplit("_query_blast.out", 1)[0])

    missing_genes = sorted(genes_in_fasta.difference(genes_with_blast))

    out_path = folder / "fasta_list2.txt"
    out_lines = [str(folder / f"{gene}_query.fasta") + "\n" for gene in missing_genes]
    with open(out_path, "w") as out:
        out.writelines(out_lines)

    already_done = genes_with_blast.intersection(genes_in_fasta)
    print(f"Total genes in fasta_list: {len(genes_in_fasta)}")
    print(f"Genes with BLAST output: {len(already_done)}")
    print(f"Missing BLAST outputs:   {len(missing_genes)}")
    print(f"Wrote: {out_path}")


# Run on the shared BLAST query directory (update the path if needed).
# Writes fasta_list2.txt into the folder of the first FASTA listed and prints a summary.
create_missing_fasta_list(
    "/nemo/lab/znamenskiyp/scratch/olfrs_monahan/blast_queries/fasta_list.txt"
)

In [None]:
# Save the current `df` without its index column. The Melting 5 cell below writes to
# this same file name again when `melting_5` is True.
df.to_csv("Olfrs_monahan_BLAST_Tm_results_split_arms.csv", index=False)

# Calculate Melting temp (requires Melting 5)

In [None]:
# Toggle: when True, run check_padlocks.process_dataframe_in_batches on `df`
# (300 rows per batch) and save the result; when False, load a saved table instead.
# NOTE(review): the else-branch reads "blast_results_olfr_full_across_tms.csv", which is
# not the file written by the if-branch — confirm this is the intended cached table.
melting_5 = True
if melting_5:
    df = check_padlocks.process_dataframe_in_batches(df, batch_size=300)
    df.to_csv("Olfrs_monahan_BLAST_Tm_results_split_arms.csv", index=False)
else:
    df = pd.read_csv("blast_results_olfr_full_across_tms.csv")

# Plot Tm vs no. padlocks

In [None]:
# Column-wise sum of the valid_20 ... valid_50 flags (one column per Tm value)
tm_values = range(20, 51)
valid_columns = [f"valid_{tm}" for tm in tm_values]
true_sums = df[valid_columns].sum()

# Line plot of the sums; x positions are the column names, relabelled with the bare Tm
fig, ax = plt.subplots(figsize=(8, 5), dpi=200)
ax.plot(valid_columns, true_sums, marker="o")

ax.set_xlabel("Melting5 Tm")
ax.set_ylabel("Number of valid padlocks")
# ax.grid(True)
ax.set_xticks(range(len(valid_columns)))
ax.set_xticklabels(tm_values)
fig.tight_layout()

plt.show()

# Plot number of valid padlocks per gene

In [None]:
# Kept so that any cell relying on this name still finds it.
grouped_queries = df.groupby("query")

# Row-wise flags first, then a plain per-query `any`. This avoids looking rows up by
# index label inside a per-group lambda, which returns extra rows whenever df's index
# labels are not unique and is slow on large tables.
passes_nn = df["valid_probe_NN"]
is_specific = df["specific"]
row_flags = df[["query"]].assign(
    valid_specific=passes_nn & is_specific,
    valid_non_specific=passes_nn & ~is_specific,
)
df_padlocks = (
    row_flags.groupby("query")[["valid_specific", "valid_non_specific"]]
    .any()
    .reset_index()
)

# Attach the gene of every query (one row per distinct query/gene/flag combination)
df_merged = pd.merge(
    df_padlocks, df[["query", "gene"]], on="query", how="left"
).drop_duplicates()

# Per gene: is there at least one specific / non-specific valid padlock?
df_grouped_by_gene = (
    df_merged.groupby("gene")
    .agg(
        valid_specific=("valid_specific", "any"),
        valid_non_specific=("valid_non_specific", "any"),
    )
    .reset_index()
)

# Per gene: how many padlocks fall in each class?
df_counts = (
    df_merged.groupby("gene")
    .agg(
        number_of_specific_padlocks=("valid_specific", "sum"),
        number_of_non_specific_padlocks=("valid_non_specific", "sum"),
    )
    .reset_index()
)

df_padlocks.to_csv("padlock_specificity_olfrs.csv", index=False)
df_grouped_by_gene.to_csv("gene_specificity_olfrs.csv", index=False)
df_counts.to_csv("number_of_specific_padlocks_olfrs.csv", index=False)

# Using df_counts to get the counts of specific and non-specific padlocks
padlocks_per_gene_specific = df_counts.set_index("gene")["number_of_specific_padlocks"]
padlocks_per_gene_non_specific = df_counts.set_index("gene")[
    "number_of_non_specific_padlocks"
]

# Combine the counts into one DataFrame for plotting.
# NOTE: `genes` is rebound here from the gene-name list of the loading cell to a pandas Index.
genes = padlocks_per_gene_specific.index.union(padlocks_per_gene_non_specific.index)
combined_counts = pd.DataFrame(
    {
        "Specific Padlocks": padlocks_per_gene_specific.reindex(genes, fill_value=0),
        "Non-Specific Padlocks": padlocks_per_gene_non_specific.reindex(
            genes, fill_value=0
        ),
    }
)

# Side-by-side bars per gene: black = specific, red = non-specific
fig, ax = plt.subplots(figsize=(60, 6))
combined_counts["Specific Padlocks"].plot(
    kind="bar", color="black", ax=ax, position=0, width=0.4
)
combined_counts["Non-Specific Padlocks"].plot(
    kind="bar", color="red", ax=ax, position=1, width=0.4
)

# Customizing the plot
ax.set_title("Specific and Non-Specific Padlocks per Gene")
ax.set_xlabel("Gene")
ax.set_ylabel("Number of Padlocks")
ax.legend(["Specific Padlocks", "Non-Specific Padlocks"])
plt.xticks(rotation=90)
plt.tight_layout()

plt.show()

# Plot Tm of the left and right arms

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Left-arm vs right-arm nearest-neighbour Tm for every row of df, coloured by the
# "mismatches" column, with density contours drawn on top.
plt.figure(figsize=(8, 8), dpi=200)
plt.title("Including ligation site missmatch seqs")

sc = plt.scatter(
    df["tm_right_NN"],
    df["tm_left_NN"],  # df["tm_right_melting"],
    c=df["mismatches"],
    s=0.1,
    vmin=0,
    vmax=6,
)

# The KDE is fitted on at most kde_n rows (fixed seed) to keep it fast
kde_n = 100000
df_kde = df[["tm_right_NN", "tm_left_NN"]].dropna()
if len(df_kde) > kde_n:
    df_kde = df_kde.sample(n=kde_n, random_state=42)

sns.kdeplot(
    x=df_kde["tm_right_NN"],
    y=df_kde["tm_left_NN"],
    levels=10,
    gridsize=120,  # lower = faster; default is 200
    cut=0,  # don't evaluate beyond data range
    bw_adjust=0.4,
    # `palette` only applies to hue mapping and is ignored here; a bivariate KDE
    # without hue takes its contour colours from `cmap`.
    cmap="Reds",
)

# Colorbar tied to the scatter (so it reflects mismatches)
cbar = plt.colorbar(sc, label="No. total mismatches/gaps", fraction=0.046, pad=0.04)

# Line y=x
plt.plot([-50, 100], [-50, 100], color="red", lw=1, alpha=0.3)

plt.xlabel("Tm_NN right (C)")
plt.ylabel("Tm_NN left (C)")
plt.gca().set_aspect("equal", adjustable="box")
# Ticks/grid every 10 degrees, thin black axes through 0
plt.xticks(np.arange(-50, 110, 10))
plt.yticks(np.arange(-50, 110, 10))
plt.axvline(0, color="black", lw=0.5)
plt.axhline(0, color="black", lw=0.5)
plt.ylim(-50, 80)
plt.xlim(-50, 70)
plt.grid(True, which="both", linestyle="--", lw=0.5)
plt.show()

In [None]:
# NN Tm (x) against Melting5 Tm (y) for both arms, split by the valid_probe_NN flag
fig, ax = plt.subplots(figsize=(8, 8), dpi=200)
ax.set_title("NN probe cutoff")

valid_false = df[df["valid_probe_NN"].eq(False)]
valid_true = df[df["valid_probe_NN"].eq(True)]

# (rows, colour, marker size, legend label). Rows failing the filter are drawn first
# in black; rows passing it are drawn on top in green.
scatter_layers = (
    (valid_false, "black", 0.1, "Valid Probe NN (False)"),
    (valid_true, "lime", 2, "Valid Probe NN (True)"),
)
for rows, colour, size, legend_label in scatter_layers:
    # Only the left-arm scatter carries the legend entry
    ax.scatter(
        rows["tm_left_NN"],
        rows["tm_left_melting"],
        c=colour,
        s=size,
        label=legend_label,
    )
    ax.scatter(rows["tm_right_NN"], rows["tm_right_melting"], c=colour, s=size)

# Reference lines: y=x, then a red vertical line at 37 degrees
ax.plot([-50, 100], [-50, 100], color="red", lw=1, alpha=0.3)
ax.axvline(37, color="red", lw=0.5)

# Labels and equal aspect
ax.set_xlabel("Tm_NN (C)")
ax.set_ylabel("Tm Melting5 (C)")
ax.set_aspect("equal", adjustable="box")

# Ticks every 10 degrees
tick_positions = np.arange(-50, 110, 10)
ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)

# Thin black axes through 0
ax.axvline(0, color="black", lw=0.5)
ax.axhline(0, color="black", lw=0.5)

# Plot limits
ax.set_ylim(-50, 80)
ax.set_xlim(-50, 70)

ax.grid(True, which="both", linestyle="--", lw=0.5)
ax.legend()

plt.show()

In [None]:
# NN Tm (x) against Melting5 Tm (y) for both arms, split by the valid_probe_melting flag
fig, ax = plt.subplots(figsize=(8, 8), dpi=200)
ax.set_title("Melting5 probe cutoff")

valid_false = df[df["valid_probe_melting"].eq(False)]
valid_true = df[df["valid_probe_melting"].eq(True)]

# (rows, colour, marker size, legend label). Rows failing the filter are drawn first
# in black; rows passing it are drawn on top in green.
scatter_layers = (
    (valid_false, "black", 0.1, "Valid Probe Melting5 (False)"),
    (valid_true, "lime", 2, "Valid Probe Melting5 (True)"),
)
for rows, colour, size, legend_label in scatter_layers:
    # Only the left-arm scatter carries the legend entry
    ax.scatter(
        rows["tm_left_NN"],
        rows["tm_left_melting"],
        c=colour,
        s=size,
        label=legend_label,
    )
    ax.scatter(rows["tm_right_NN"], rows["tm_right_melting"], c=colour, s=size)

# Reference lines: red horizontal line at 37 degrees, then y=x
ax.axhline(37, color="red", lw=0.5)
ax.plot([-50, 100], [-50, 100], color="red", lw=1, alpha=0.3)

# Labels and equal aspect
ax.set_xlabel("Tm_NN (C)")
ax.set_ylabel("Tm Melting5 (C)")
ax.set_aspect("equal", adjustable="box")

# Ticks every 10 degrees
tick_positions = np.arange(-50, 110, 10)
ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)

# Thin black axes through 0
ax.axvline(0, color="black", lw=0.5)
ax.axhline(0, color="black", lw=0.5)

# Plot limits
ax.set_ylim(-50, 80)
ax.set_xlim(-50, 70)

ax.grid(True, which="both", linestyle="--", lw=0.5)
ax.legend()

plt.show()

In [None]:
from matplotlib_venn import venn3
from matplotlib_venn.layout.venn3 import DefaultLayoutAlgorithm

# Create sets of df index labels for the rows that pass each validity filter.
# NOTE(review): the sets hold index labels, so rows sharing a label would collapse into
# one element — confirm df's index is unique before trusting the region counts.
set_NN = set(df.index[df["valid_probe_NN"]])
set_melting = set(df.index[df["valid_probe_melting"]])
set_old_filters = set(df.index[df["valid_probe_old_filters"]])

# Create Venn Diagram of the three sets (order matches set_labels)
plt.figure(figsize=(8, 8), dpi=200)
venn3(
    [set_NN, set_melting, set_old_filters],
    set_labels=("NN", "Melting", "Old Filters"),
    layout_algorithm=DefaultLayoutAlgorithm(normalize_to=1),
)

plt.title("Venn Diagram of Valid Probes")
plt.show()

# Calculate self dimers, hairpins and heterodimers

In [None]:
# Annotate the barcoded panel with thermodynamic metrics. The input must have a
# 'padlock' column, and optionally a 'name' column.
# NOTE: this rebinds `df` from the BLAST table used above to the panel table.
df = pd.read_csv("monahan_panel_barcoded.csv")
result_df = check_padlocks.annotate_with_thermo(df, n_jobs=20)
result_df.to_csv("monahan_panel_barcoded_with_thermo_trimmed.csv", index=False)
# NOTE(review): later cells read the heterodimer columns from `df`; confirm that
# annotate_with_thermo adds them to `df` in place, otherwise rebind df = result_df.


# Left-tail counts beyond 2σ and 3σ for three ΔG metrics.
# The metrics are read from the annotated frame returned by annotate_with_thermo,
# not from the freshly loaded input, which is not guaranteed to carry these columns.
fig, axes = plt.subplots(1, 3, figsize=(18, 5), dpi=200)
hist_data = [
    (result_df["hairpin_dg_kcalmol"], "Hairpin ΔG (kcal/mol)"),
    (result_df["homodimer_dg_kcalmol"], "Homodimer ΔG (kcal/mol)"),
    (result_df["best_heterodimer_dg_kcalmol"], "Best Heterodimer ΔG (kcal/mol)"),
]

for ax, (data, title) in zip(axes, hist_data):
    data_clean = data.dropna().values

    # Nothing to plot: label the empty panel and move on
    if data_clean.size == 0:
        ax.set_title(title + " (no data)")
        ax.set_xlabel("ΔG (kcal/mol)")
        ax.set_ylabel("Count")
        continue

    # Only the bin edges are needed later (left edge of the shaded tail)
    _, bins, _ = ax.hist(
        data_clean, bins=50, color="skyblue", edgecolor="black", alpha=0.7
    )
    ax.set_title(title + " — left tail 2σ/3σ counts")
    ax.set_xlabel("ΔG (kcal/mol)")
    ax.set_ylabel("Count")

    # Mean and sample standard deviation (σ is 0 for a single value)
    mu = float(np.mean(data_clean))
    sigma = float(np.std(data_clean, ddof=1)) if data_clean.size > 1 else 0.0

    y_max = ax.get_ylim()[1]
    if sigma > 0:
        t2 = mu - 2 * sigma
        t3 = mu - 3 * sigma

        # Left-tail counts
        n2 = int((data_clean <= t2).sum())
        n3 = int((data_clean <= t3).sum())

        # Annotate thresholds
        ax.axvline(t2, color="red", linestyle="--", lw=1, label="μ-2σ")
        ax.axvline(t3, color="purple", linestyle="--", lw=1, label="μ-3σ")
        ax.text(
            t2,
            0.9 * y_max,
            "μ-2σ",
            color="red",
            rotation=90,
            va="top",
            ha="right",
            fontsize=8,
        )
        ax.text(
            t3,
            0.9 * y_max,
            "μ-3σ",
            color="purple",
            rotation=90,
            va="top",
            ha="right",
            fontsize=8,
        )

        # Shade left-tail regions (from the first bin edge up to each threshold)
        ax.fill_betweenx([0, y_max], bins[0], t2, color="red", alpha=0.06)
        ax.fill_betweenx([0, y_max], bins[0], t3, color="purple", alpha=0.06)

        # Display counts
        ax.text(
            bins[0], 0.80 * y_max, f"≤ μ-2σ: n={n2}", color="red", ha="left", fontsize=9
        )
        ax.text(
            bins[0],
            0.70 * y_max,
            f"≤ μ-3σ: n={n3}",
            color="purple",
            ha="left",
            fontsize=9,
        )

        ax.legend(loc="upper right", fontsize=8)
    else:
        ax.text(
            0.5, 0.5, "σ = 0 or insufficient data", transform=ax.transAxes, ha="center"
        )

plt.tight_layout()
plt.show()

# Compute and plot graph of worst heterodimer offenders

In [None]:
# Heterodimer interaction network: highlight oligos responsible for strongest interactions (with progress + faster layout options + outlier filtering + log-scale bar + log colourscale)
import time
from ast import literal_eval

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import cm, colors
from tqdm.auto import tqdm

try:
    import networkx as nx
except ImportError:
    raise RuntimeError(
        "networkx is required for this visualization. Please install it (e.g., pip install networkx)."
    )

# Wall-clock start for the whole cell; the final summary line reports the total.
start_all = time.perf_counter()
print("[0/8] Starting heterodimer network build...", flush=True)

# Config knobs: which pair interactions count as "strong" and how many offenders to list
use_percentile_threshold = True  # if True, use percentile instead of fixed threshold
percentile_cut = 5  # keep edges at or below this ΔG percentile (most negative)
fixed_threshold_dg = -8.0  # kcal/mol, used if use_percentile_threshold is False
top_k = 100  # number of top offenders to list/bar-plot

# Visualization performance knobs (affect only what is drawn, not the offense scores)
layout_mode = "auto"  # one of: "auto"|"sfdp"|"fa2"|"spring"|"spectral"
spring_iterations = 20  # fewer iterations to speed up
limit_nodes_for_plot = True  # draw only the top-N offending oligos to speed up layout
max_nodes_to_draw = 1000  # if graph is larger than this, plot a top-N subgraph

# Layout outlier filtering (post-layout)
hide_position_outliers = (
    True  # if True, drop nodes with extreme x/y positions from the drawn graph
)
pos_outlier_quantile = 0.005  # drop nodes outside these lower/upper quantiles on x or y

# Expected columns of df: oligo name, list of partner names, list of matching ΔG values
col_name = "name"
col_partners = "heterodimer_candidate_partners"
col_dgs = "heterodimer_candidate_dgs_kcalmol"

# Safety: turn potential stringified lists back into Python lists
def ensure_list(x):
    """Coerce a cell value to a plain Python list.

    Lists, tuples and NumPy arrays are converted directly. Strings are parsed with
    ``ast.literal_eval`` and accepted only if they evaluate to a list or tuple.
    Anything else, including strings that fail to parse, yields an empty list.
    """
    if isinstance(x, (list, tuple, np.ndarray)):
        return list(x)
    if not isinstance(x, str):
        return []
    try:
        parsed = literal_eval(x)
    except Exception:
        # Unparseable text is treated as "no entries"
        return []
    if isinstance(parsed, (list, tuple, np.ndarray)):
        return list(parsed)
    return []


# Sanity checks: all three input columns must exist
required_cols = {col_name, col_partners, col_dgs}
if not required_cols.issubset(df.columns):
    missing = required_cols - set(df.columns)
    raise ValueError(f"Missing required columns in df: {missing}")
print(f"DataFrame rows: {len(df):,}", flush=True)

# Stage 1: Build an undirected edge list with best (most negative) ΔG per pair
t0 = time.perf_counter()
print("[1/8] Parsing candidates and computing best ΔG per oligo pair...", flush=True)
edge_best_dg = {}  # key: (a,b) sorted tuple, value: most negative dg
names_seen = set()
length_mismatches = 0  # rows whose partner and ΔG lists differ in length

# Iterate the three columns directly; this avoids building a Series per row
# as DataFrame.iterrows does.
for a, partners_raw, dgs_raw in tqdm(
    zip(df[col_name], df[col_partners], df[col_dgs]),
    total=len(df),
    desc="Build best ΔG per pair",
    leave=False,
):
    names_seen.add(a)
    partners = ensure_list(partners_raw)
    dgs = ensure_list(dgs_raw)
    if len(partners) != len(dgs):
        # zip below silently truncates to the shorter list; count it so it can be reported
        length_mismatches += 1
    for b, dg in zip(partners, dgs):
        # Skip missing partners and self-pairs
        if b is None or pd.isna(b) or a == b:
            continue
        try:
            dg_val = float(dg)
        except Exception:
            continue
        if not np.isfinite(dg_val):
            continue
        # Sorted tuple makes (a, b) and (b, a) the same undirected edge
        pair = tuple(sorted((a, b)))
        if pair not in edge_best_dg or dg_val < edge_best_dg[pair]:
            edge_best_dg[pair] = dg_val
print(
    f"  Unique pairs seen: {len(edge_best_dg):,}; unique oligos: {len(names_seen):,}",
    flush=True,
)
if length_mismatches:
    print(
        f"  Warning: {length_mismatches:,} rows had partner/ΔG lists of different length; extra entries were ignored",
        flush=True,
    )
print(f"  Stage 1 took {time.perf_counter()-t0:.2f}s", flush=True)

if not edge_best_dg:
    raise ValueError(
        "No heterodimer candidate edges found after parsing. Check the input columns."
    )

# Stage 2: Compute threshold (percentile of all pair ΔG values, or the fixed value)
t0 = time.perf_counter()
print("[2/8] Computing strong-interaction threshold...", flush=True)
all_dgs = np.array(list(edge_best_dg.values()), dtype=float)
if use_percentile_threshold:
    threshold_dg = np.percentile(all_dgs, percentile_cut)
    print(
        f"  Using percentile cut {percentile_cut}% -> threshold ΔG ≤ {threshold_dg:.3f} kcal/mol",
        flush=True,
    )
else:
    threshold_dg = fixed_threshold_dg
    print(f"  Using fixed threshold ΔG ≤ {threshold_dg:.3f} kcal/mol", flush=True)

# Stage 3: Keep only strong interactions (ΔG at or below the threshold)
print("[3/8] Selecting strong edges...", flush=True)
strong_edges = {pair: dg for pair, dg in edge_best_dg.items() if dg <= threshold_dg}
print(
    f"  Strong edges: {len(strong_edges):,} out of {len(edge_best_dg):,} total pairs",
    flush=True,
)
print(f"  Stage 2+3 took {time.perf_counter()-t0:.2f}s", flush=True)

# Stage 4: Build graph with strong edges only
t0 = time.perf_counter()
print("[4/8] Building graph from strong edges...", flush=True)
G = nx.Graph()
G.add_nodes_from(names_seen)
for pair, dg in tqdm(
    strong_edges.items(), total=len(strong_edges), desc="Add strong edges", leave=False
):
    u, v = pair
    # strength is the negated ΔG: more negative dg => larger strength
    G.add_edge(u, v, dg=dg, strength=-dg)
print(
    f"  Graph now has {G.number_of_nodes():,} nodes and {G.number_of_edges():,} edges",
    flush=True,
)
print(f"  Stage 4 took {time.perf_counter()-t0:.2f}s", flush=True)

# Stage 5: Offense score per node = sum of the strengths of its strong edges
t0 = time.perf_counter()
print("[5/8] Computing offense scores...", flush=True)
offense_score = dict.fromkeys(G.nodes(), 0.0)
for u, v, data in tqdm(
    G.edges(data=True),
    total=G.number_of_edges(),
    desc="Accumulate edge strengths",
    leave=False,
):
    s = data.get("strength", 0.0)
    for endpoint in (u, v):
        offense_score[endpoint] += s

# Nodes without any strong edge are removed from both the graph and the score table
isolates = [n for n, degree in G.degree() if degree == 0]
G.remove_nodes_from(isolates)
isolate_set = set(isolates)
offense_score = {n: s for n, s in offense_score.items() if n not in isolate_set}
print(
    f"  Removed isolates: {len(isolates):,}; remaining nodes: {G.number_of_nodes():,}",
    flush=True,
)
print(f"  Stage 5 took {time.perf_counter()-t0:.2f}s", flush=True)

if G.number_of_edges() == 0:
    raise ValueError(
        "No edges passed the strong-interaction threshold. Try relaxing the threshold or percentile."
    )

# Stage 6: Prepare node sizes and colors (for the graph we'll draw)
t0 = time.perf_counter()
print("[6/8] Preparing node sizes/colors...", flush=True)
# Optionally focus visualization on top-N offenders to speed up layout
draw_nodes = list(G.nodes())
if limit_nodes_for_plot and len(G) > max_nodes_to_draw:
    draw_nodes = [
        n
        for n, _ in sorted(offense_score.items(), key=lambda x: x[1], reverse=True)[
            :max_nodes_to_draw
        ]
    ]
    print(
        f"  Limiting draw to top-{len(draw_nodes)} nodes by offense score (of {len(G):,})",
        flush=True,
    )
# Independent copy so later filtering of the drawn graph never touches G
draw_G = G.subgraph(draw_nodes).copy()

scores = np.array([offense_score[n] for n in draw_G.nodes()], dtype=float)
if scores.size == 0:
    raise ValueError(
        "No nodes with offense scores in draw graph. Check the thresholds."
    )
score_max = float(scores.max())
score_range = score_max if score_max > 0 else 1.0
# Node area grows linearly with the score, from 80 up to 600 for the top score
node_sizes = {n: 80 + 520 * (offense_score[n] / score_range) for n in draw_G.nodes()}
# pyplot.get_cmap is used because matplotlib.cm.get_cmap was deprecated in
# Matplotlib 3.7 and removed in 3.9.
cmap = plt.get_cmap("Reds")

# Use a log-based colour scale if possible, otherwise fall back to linear
used_log_scale = False
if np.any(scores > 0) and score_max > 0:
    min_pos = float(scores[scores > 0].min())
    vmin_log = max(min_pos, score_max * 1e-6)  # ensure strictly > 0 and < vmax
    if vmin_log < score_max:
        norm = colors.LogNorm(vmin=vmin_log, vmax=score_max)
        used_log_scale = True
    else:
        # All positive scores are equal: a log range would be empty
        norm = colors.Normalize(vmin=0, vmax=score_max)
else:
    norm = colors.Normalize(vmin=0, vmax=max(1.0, score_max))

node_colors = {n: cmap(norm(offense_score[n])) for n in draw_G.nodes()}
# Edge width grows linearly with strength relative to the strongest drawn edge
edge_strengths = np.array(
    [data["strength"] for _, _, data in draw_G.edges(data=True)], dtype=float
)
s_max = float(edge_strengths.max()) if edge_strengths.size else 1.0
edge_widths = [
    0.5 + 3.0 * (data["strength"] / s_max) for _, _, data in draw_G.edges(data=True)
]
print(
    f"  Draw graph: {draw_G.number_of_nodes():,} nodes, {draw_G.number_of_edges():,} edges",
    flush=True,
)
print(f"  Score max: {score_max:.3f}; edge strength max: {s_max:.3f}", flush=True)
print(f"  Stage 6 took {time.perf_counter()-t0:.2f}s", flush=True)

# Stage 7: Layout (choose a faster option if available)
def _layout_sfdp(graph):
    """Graphviz sfdp layout; requires pygraphviz + graphviz."""
    from networkx.drawing.nx_agraph import graphviz_layout

    return graphviz_layout(graph, prog="sfdp")


def _layout_fa2(graph):
    """ForceAtlas2 layout via the optional `fa2` package (pip install fa2)."""
    from fa2 import ForceAtlas2

    engine = ForceAtlas2(
        outboundAttractionDistribution=False,  # FA2 defaults
        edgeWeightInfluence=1.0,
        jitterTolerance=1.0,
        barnesHutOptimize=True,
        barnesHutTheta=1.2,
        scalingRatio=2.0,
        gravity=1.0,
        verbose=False,
    )
    return engine.forceatlas2_networkx_layout(graph, pos=None, iterations=300)


def _layout_spring(graph, initial=None):
    """Seeded spring layout weighted by edge strength, optionally started from `initial`."""
    return nx.spring_layout(
        graph,
        seed=42,
        weight="strength",
        iterations=spring_iterations,
        pos=initial,
    )


def _layout_spring_from_spectral(graph):
    """Spectral positions refined by a few spring iterations."""
    return _layout_spring(graph, initial=nx.spectral_layout(graph))


t0 = time.perf_counter()
print("[7/8] Computing layout...", flush=True)
chosen_mode = layout_mode
pos = None
if chosen_mode == "auto":
    # Try sfdp, then FA2, then spectral-initialised spring. Any failure (for example a
    # missing optional package) moves on to the next option; plain spring is the last resort.
    auto_candidates = (
        ("sfdp", _layout_sfdp),
        ("fa2", _layout_fa2),
        ("spring+spectral-init", _layout_spring_from_spectral),
    )
    for candidate_mode, candidate_layout in auto_candidates:
        try:
            pos = candidate_layout(draw_G)
        except Exception:
            continue
        chosen_mode = candidate_mode
        break
    else:
        pos = _layout_spring(draw_G)
        chosen_mode = "spring"
else:
    explicit_layouts = {
        "sfdp": _layout_sfdp,
        "fa2": _layout_fa2,
        "spring": _layout_spring,
        "spectral": nx.spectral_layout,
    }
    # Unrecognised mode names fall back to the spring layout
    if chosen_mode not in explicit_layouts:
        chosen_mode = "spring"
    pos = explicit_layouts[chosen_mode](draw_G)
print(f"  Layout mode: {chosen_mode}", flush=True)
print(f"  Stage 7 took {time.perf_counter()-t0:.2f}s", flush=True)

# Optional Stage 7b: Filter positional outliers to avoid squashed views
if hide_position_outliers and len(draw_G) > 0:
    print("[7b] Filtering positional outliers from layout...", flush=True)
    layout_nodes = list(draw_G.nodes())
    xs = np.array([pos[n][0] for n in layout_nodes], dtype=float)
    ys = np.array([pos[n][1] for n in layout_nodes], dtype=float)
    q = float(pos_outlier_quantile)
    x_lo, x_hi = np.quantile(xs, [q, 1 - q])
    y_lo, y_hi = np.quantile(ys, [q, 1 - q])
    # A node is kept only if both coordinates lie inside the quantile box
    inside_box = (xs >= x_lo) & (xs <= x_hi) & (ys >= y_lo) & (ys <= y_hi)
    keep_nodes = [n for n, inside in zip(layout_nodes, inside_box) if inside]
    removed = len(layout_nodes) - len(keep_nodes)
    if removed == 0:
        print("  No positional outliers removed", flush=True)
    else:
        draw_G = draw_G.subgraph(keep_nodes).copy()
        # Restrict the per-node visuals and positions to the surviving nodes
        surviving = list(draw_G.nodes())
        node_sizes = {n: node_sizes[n] for n in surviving}
        node_colors = {n: node_colors[n] for n in surviving}
        pos = {n: pos[n] for n in surviving}
        # Edge widths are rescaled against the strongest remaining edge
        surviving_edges = list(draw_G.edges(data=True))
        edge_strengths = np.array(
            [data["strength"] for _, _, data in surviving_edges], dtype=float
        )
        s_max = float(edge_strengths.max()) if edge_strengths.size else 1.0
        edge_widths = [
            0.5 + 3.0 * (data["strength"] / s_max) for _, _, data in surviving_edges
        ]
        print(
            f"  Removed {removed} positional outliers (kept {len(keep_nodes)} nodes)",
            flush=True,
        )

# Stage 8: Render (network on the left, top-offender bar chart on the right)
t0 = time.perf_counter()
print("[8/8] Rendering network and top-offenders plot...", flush=True)
fig = plt.figure(figsize=(16, 12), dpi=200)
gs = fig.add_gridspec(1, 2, width_ratios=[2, 1])
ax_net = fig.add_subplot(gs[0, 0])
ax_bar = fig.add_subplot(gs[0, 1])

drawn_nodes = list(draw_G.nodes())
nx.draw_networkx_edges(
    draw_G, pos, ax=ax_net, width=edge_widths, edge_color="gray", alpha=0.6
)
nx.draw_networkx_nodes(
    draw_G,
    pos,
    ax=ax_net,
    node_size=[node_sizes[n] for n in drawn_nodes],
    node_color=[node_colors[n] for n in drawn_nodes],
    linewidths=0.3,
    edgecolors="black",
    alpha=0.95,
)

# Only the ten highest-scoring drawn nodes get a visible label; the rest get ""
ranked_drawn = sorted(drawn_nodes, key=lambda n: offense_score[n], reverse=True)
top_labels = [(n, offense_score[n]) for n in ranked_drawn[:10]]
label_nodes = {n for n, _ in top_labels}
labels = {}
for n in drawn_nodes:
    labels[n] = n if n in label_nodes else ""
nx.draw_networkx_labels(draw_G, pos, labels=labels, font_size=8, ax=ax_net)

# Colorbar built from the same cmap/norm used for the node colours
sm = cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar = fig.colorbar(sm, ax=ax_net, fraction=0.046, pad=0.04)
if used_log_scale:
    cbar_suffix = "; log scale)"
else:
    cbar_suffix = ")"
cbar.set_label("Offense score (sum of strong interaction strengths" + cbar_suffix)

ax_net.set_title(
    f"Heterodimer interaction network (drawn: {draw_G.number_of_nodes()} nodes)\nStrong edges: ΔG ≤ {threshold_dg:.2f} kcal/mol; nodes sized/colored by offense score"
)
ax_net.axis("off")

# The bar chart uses scores from the full strong-edge graph, not only the drawn nodes
all_scores = pd.Series(offense_score)
top_series = all_scores.sort_values(ascending=False).head(top_k)
top_series.plot(kind="barh", ax=ax_bar, color="crimson", alpha=0.8)
ax_bar.invert_yaxis()
ax_bar.set_xscale("log")  # log scale for offense scores
ax_bar.set_xlabel("Offense score (sum of -ΔG over strong edges, log scale)")
ax_bar.set_title(f"Top {min(top_k, len(top_series))} offending oligos")
ax_bar.grid(axis="x", which="both", linestyle="--", alpha=0.3)
plt.tight_layout()
plt.show()
print(f"  Stage 8 took {time.perf_counter()-t0:.2f}s", flush=True)

# Summary and optional export (still computed on the full strong-edge graph G)
edge_records = []
for a, b, data in G.edges(data=True):
    edge_records.append(
        {
            "oligo_a": a,
            "oligo_b": b,
            "dg_kcalmol": data["dg"],
            "strength": data["strength"],
        }
    )
# Most negative ΔG first
edges_export = pd.DataFrame(edge_records).sort_values("dg_kcalmol")
edges_export = edges_export.reset_index(drop=True)
print(
    f"Done in {time.perf_counter()-start_all:.2f}s. Strong interactions: {len(edges_export)} edges among {G.number_of_nodes()} nodes (threshold ΔG ≤ {threshold_dg:.2f}).",
    flush=True,
)
# edges_export.to_csv('strong_heterodimer_edges.csv', index=False)

# Attempt to filter probes based solely on hairpin, homodimer and heterodimer values

In [None]:
# Heterodimer-aware filtering: keep up to N padlocks per gene using offense score + homo/hairpin ΔG

# --- Parameters ---
keep_per_gene = 10  # target per gene

# Choose filtering method: 'two_stage' or 'combined'
method = "combined"  # or 'two_stage'

# Columns
name_col = "name"
gene_col = "gene_name"
homodimer_col = "homodimer_dg_kcalmol"
hairpin_col = "hairpin_dg_kcalmol"
# The first candidate present in df is used as the heterodimer ΔG column (None if neither exists)
heterodimer_candidates = ["best_heterodimer_dg_kcalmol", "heterodimer_dg_kcalmol"]
heterodimer_col = next((c for c in heterodimer_candidates if c in df.columns), None)

# Two-stage config (used when method == 'two_stage')
drop_top_offenders_n = 1000  # drop this many worst global offenders (by offense_score)
drop_top_offenders_frac = (
    0.0  # or drop this fraction (0..1) of unique names; ignored if n > 0
)

# Per-gene weights (ranks)
w_homo = 1.0
w_hair = 1.0
w_offense = 1.0  # only used for method == 'combined'

# Offense score requirements
require_offense = True  # if True, error if offense_score is not available
offense_default = 0.0  # default if not required

# --- Validation & setup ---
# Resolve the gene column: if the configured one is absent, take the first fallback present
if gene_col not in df.columns:
    gene_col = next(
        (alt for alt in ["gene", "padlock_target_gene", "query"] if alt in df.columns),
        gene_col,
    )
if gene_col not in df.columns:
    raise ValueError(
        f"gene column '{gene_col}' not found; tried fallbacks ['gene','padlock_target_gene','query']"
    )

if heterodimer_col is None:
    raise ValueError(
        "No heterodimer ΔG column found (expected one of: "
        + ", ".join(heterodimer_candidates)
        + ")"
    )
needed = [name_col, gene_col, homodimer_col, hairpin_col, heterodimer_col]
missing = [c for c in needed if c not in df.columns]
if missing:
    raise ValueError(f"Missing required columns: {missing}")

# Bring offense_score into the DataFrame. The per-oligo dict comes from the heterodimer
# network cell; names without an entry receive offense_default.
offense_available = "offense_score" in globals()
if require_offense and not offense_available:
    raise ValueError(
        "offense_score not found in the notebook environment. Run the heterodimer network cell first."
    )
df = df.copy()
if offense_available:
    df["offense_score"] = df[name_col].map(offense_score).fillna(offense_default)
else:
    df["offense_score"] = offense_default

# Precompute a global offense penalty in [0,1] (higher is worse).
# Dense ranks are taken in ascending score order, so the lowest offense_score gets
# rank 1 (penalty 0) and the highest gets the top rank (penalty 1). Ranking in
# descending order would invert this and reward the worst offenders in the
# "lower is better" composite score.
if len(df) > 1:
    offense_rank = df["offense_score"].rank(
        ascending=True, method="dense"
    )  # 1 = mildest offender
    max_rank = float(offense_rank.max()) if offense_rank.notna().any() else 1.0
    # max(1.0, ...) keeps the division safe when every score is identical
    df["_offense_penalty"] = (offense_rank - 1.0) / max(1.0, max_rank - 1.0)
else:
    df["_offense_penalty"] = 0.0


# --- Helpers ---
def _rank_per_gene(g: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of one gene's rows with numeric ΔG columns and dense ranks.

    Added columns:
      * ``_hom``, ``_hai``, ``_het``: the homodimer, hairpin and heterodimer
        columns coerced to numbers; values that cannot be parsed or are missing
        become -inf, i.e. they are treated as the strongest (worst) structure.
      * ``_rank_hom``, ``_rank_hai``: dense rank within the group where rank 1
        is the highest (weakest, best) ΔG.

    Ranks are not normalised here; callers normalise them when needed.
    """
    ranked = g.copy()
    numeric_sources = {
        "_hom": homodimer_col,
        "_hai": hairpin_col,
        "_het": heterodimer_col,
    }
    for numeric_name, source_col in numeric_sources.items():
        coerced = pd.to_numeric(ranked[source_col], errors="coerce")
        ranked[numeric_name] = coerced.fillna(-np.inf)
    # Descending rank: larger (less negative) ΔG receives the smaller rank index
    for numeric_name, rank_name in (("_hom", "_rank_hom"), ("_hai", "_rank_hai")):
        ranked[rank_name] = ranked[numeric_name].rank(ascending=False, method="dense")
    return ranked


def _select_by_two_stage(g: pd.DataFrame, k: int) -> pd.DataFrame:
    """Keep the k best padlocks of one gene using homodimer/hairpin ranks only.

    The offense score plays no role here: in the two-stage method the global
    worst offenders have already been removed before this function is called.

    Parameters
    ----------
    g : rows of a single gene.
    k : maximum number of rows to return.

    Returns
    -------
    The (up to) k rows with the lowest weighted rank sum ``_score``, including
    the helper columns added by ``_rank_per_gene``.
    """
    g = _rank_per_gene(g)
    # Weighted sum of per-gene ranks; lower is better
    g["_score"] = w_homo * g["_rank_hom"] + w_hair * g["_rank_hai"]
    # Break ties on the coerced numeric ΔG values (_hom/_hai) rather than the
    # raw columns: the raw columns may hold text, which would sort
    # lexicographically or fail to compare, whereas the coerced values are the
    # ones the ranks were computed from. Higher (weaker) ΔG wins a tie.
    g = g.sort_values(["_score", "_hom", "_hai"], ascending=[True, False, False])
    return g.head(k)


def _select_by_combined(g: pd.DataFrame, k: int) -> pd.DataFrame:
    """Keep the k best padlocks of one gene by a single composite score.

    The composite ``_combo`` is the weighted sum of the per-gene homodimer and
    hairpin ranks (scaled by the group size so they lie in [0, 1]) and the
    global ``_offense_penalty`` column, which must already be present in ``g``.
    Lower ``_combo`` is better.

    Parameters
    ----------
    g : rows of a single gene, carrying an ``_offense_penalty`` column.
    k : maximum number of rows to return.

    Returns
    -------
    The (up to) k rows with the lowest ``_combo``, including the helper
    columns added along the way.
    """
    g = _rank_per_gene(g)
    # Scale per-gene ranks to [0,1] so they are comparable with the offense
    # penalty; max(1.0, ...) avoids dividing by zero for a single-row group.
    n = max(1.0, float(len(g) - 1))
    g["_rank_hom_n"] = (g["_rank_hom"] - 1.0) / n
    g["_rank_hai_n"] = (g["_rank_hai"] - 1.0) / n
    # Composite score: lower is better
    g["_combo"] = (
        w_homo * g["_rank_hom_n"]
        + w_hair * g["_rank_hai_n"]
        + w_offense * g["_offense_penalty"]
    )
    # Break ties on the coerced numeric ΔG values (_hom/_hai/_het) rather than
    # the raw columns: the raw columns may hold text, which would sort
    # lexicographically or fail to compare. Higher (weaker) ΔG wins a tie.
    g = g.sort_values(
        ["_combo", "_hom", "_hai", "_het"],
        ascending=[True, False, False, False],
    )
    return g.head(k)


# --- Two-stage prefilter (drop global worst offenders) ---
# dff is the working copy used by the per-gene selection; it differs from df
# only when method == "two_stage" and at least one name is dropped here.
dff = df.copy()
dropped_names = []
if method == "two_stage":
    # One row per padlock name, highest offense_score (worst offender) first
    unique_names = dff.drop_duplicates(subset=[name_col])[
        [name_col, "offense_score"]
    ].sort_values("offense_score", ascending=False)

    # An absolute count takes precedence over the fraction
    to_drop = []
    if drop_top_offenders_n and drop_top_offenders_n > 0:
        to_drop = unique_names[name_col].head(int(drop_top_offenders_n)).tolist()
    elif drop_top_offenders_frac and 0 < drop_top_offenders_frac < 1:
        m = int(len(unique_names) * drop_top_offenders_frac)
        to_drop = unique_names[name_col].head(m).tolist()

    if to_drop:
        dropped_names = to_drop
        dff = dff.loc[~dff[name_col].isin(to_drop)]
        print(f"Dropped {len(to_drop)} global worst offenders (by offense_score)")

# --- Apply per-gene selection ---
# Genes with at most keep_per_gene padlocks are kept whole; larger genes are
# trimmed to their keep_per_gene best padlocks by the chosen method.
# dropna=False makes rows with a missing gene form their own group.
sizes = dff.groupby(gene_col, dropna=False).size()
genes_over = sizes[sizes > keep_per_gene].index
genes_small = sizes[sizes <= keep_per_gene].index

kept_frames = []
if len(genes_small) > 0:
    kept_frames.append(dff[dff[gene_col].isin(genes_small)])
if len(genes_over) > 0:
    # The method is only validated when there is something to trim
    if method == "two_stage":
        selector = _select_by_two_stage
    elif method == "combined":
        selector = _select_by_combined
    else:
        raise ValueError("Unknown method; use 'two_stage' or 'combined'")
    # dropna=False must match the size computation above; with the default
    # (dropna=True) an oversized group of rows with a missing gene would be
    # skipped by apply and all of its rows silently discarded.
    selected = (
        dff[dff[gene_col].isin(genes_over)]
        .groupby(gene_col, group_keys=False, dropna=False)
        .apply(lambda g: selector(g, keep_per_gene))
    )
    kept_frames.append(selected)

# pd.concat raises on an empty list, which happens when the prefilter removed
# every row; fall back to an empty frame with dff's columns in that case.
if kept_frames:
    df_kept = pd.concat(kept_frames, axis=0).sort_index()
else:
    df_kept = dff.iloc[0:0].copy()
# Rows of dff that were not selected. Rows already dropped by the global
# offender prefilter are not part of dff and therefore not included here.
df_removed2 = dff.drop(df_kept.index)

# --- Summaries ---
# Gene count is taken from df (before any prefiltering); a missing gene counts
# as its own gene because of dropna=False.
total_genes = int(df[gene_col].nunique(dropna=False))
print(f"Genes total: {total_genes}")
print(f"Method: {method}")
if method == "two_stage":
    print(f"Global offenders dropped: {len(dropped_names)}")

# "removed" is relative to df, so it also counts rows dropped by the prefilter
n_kept, n_total = len(df_kept), len(df)
print(f"Padlocks kept: {n_kept} / {n_total} (removed {n_total - n_kept})")

# Ten genes with the most padlocks remaining after selection
kept_per_gene = df_kept.groupby(gene_col, dropna=False).size()
print(
    "Kept per gene (top 10):\n",
    kept_per_gene.sort_values(ascending=False).head(10),
)

# Optional outputs
# df_kept.to_csv('padlocks_filtered_kept_heteroaware.csv', index=False)
# df_removed2.to_csv('padlocks_filtered_removed_heteroaware.csv', index=False)

# To make the filtered set your working df, uncomment:
# df = df_kept.copy()