In [None]:
import os
import pandas as pd
import numpy as np
from pathlib import Path
import natsort
import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
import sys

sys.path.append(
    "/nemo/lab/znamenskiyp/home/users/becalia/code/multi_padlock_design/lib"
)
sys.path.append("/nemo/lab/znamenskiyp/home/users/becalia/code/multi_padlock_design")
sys.path.append(
    "/nemo/lab/znamenskiyp/home/users/becalia/code/multi_padlock_design/notebooks"
)
from lib import check_padlocks
import config

In [None]:
tqdm.pandas()  # enable tqdm's pandas integration (progress_apply etc.)

# Directory holding the per-gene "<gene>_query_blast.out" BLAST outputs; the
# "*_off_targets.out" files read in the next step are presumably written here
# by find_off_targets.
blast_file_dir = Path("/nemo/lab/znamenskiyp/scratch/olfr_queries")
ref_path = Path("/nemo/lab/znamenskiyp/home/shared/resources/refseq/")
check_padlocks.loaddb("mouse", config)

# Gene name = file-name prefix before "_query". os.listdir returns bare names in
# arbitrary order, so natural-sort them to make processing/printing reproducible.
genes = natsort.natsorted(
    [
        file.split("_query")[0]
        for file in os.listdir(blast_file_dir)
        if file.endswith("query_blast.out")
    ]
)
print(genes)
print(len(genes))
for gene in tqdm(genes, desc="finding off-targets"):
    check_padlocks.find_off_targets(gene, blast_file_dir, ref_path)

# Combine every "*_off_targets.out" table in blast_file_dir into one DataFrame.
# Sort the file list so the concatenation order does not depend on the filesystem.
files = sorted(glob.glob(str(blast_file_dir / "*_off_targets.out")))
if not files:
    raise FileNotFoundError(f"No *_off_targets.out files found in {blast_file_dir}")
# header=0 takes the first row as column names ("Unnamed: 0", "0", "1", ...).
# ignore_index=True gives the combined frame unique index labels; each file's
# own index restarts at 0, and later cells look rows up by index label.
df = pd.concat([pd.read_csv(file, header=0) for file in files], ignore_index=True)
# Natural-sort rows by column "0" (the query name), e.g. "x_2" before "x_10".
df = df.iloc[natsort.index_natsorted(df["0"])]
header = [
    "query",
    "subject",
    "percentage identity",
    "length",
    "mismatches",
    "gaps",
    "qstart",
    "qend",
    "sstart",
    "send",
    "evalue",
    "bitscore",
    "qseq",
    "sseq",
    "homology_candidate_hit",
    "gene",
]
# Drop the saved index column, then apply the descriptive column names.
df = df.drop(columns=["Unnamed: 0"])
df.columns = header
# A stable sort keeps the natural query order among rows with equal evalue,
# instead of discarding the natsort above.
df = df.sort_values(by="evalue", kind="stable")

# Precompute per-variant data once (structure is internal to check_padlocks) and
# hand it to process_dataframe so it is not rebuilt for every BLAST hit.
precomputed_variants = check_padlocks.precompute_variants(df, ref_path)

# Padlock design parameters passed through to check_padlocks.
process_kwargs = {
    "armlength": 20,
    "tm_threshold": 37,
    "precomputed_variants": precomputed_variants,
}
df = check_padlocks.process_dataframe(df, **process_kwargs)

# Variant ID = text of "subject" before the first "." (presumably dropping an
# accession version suffix).
subject_variants = df["subject"].str.split(".", n=1).str[0]

# Resolve each distinct variant once, not once per BLAST hit.
unique_variants = subject_variants.unique()
genes_dict = check_padlocks.find_genes_from_variants(unique_variants, "mouse", config)

# Keep only the first gene reported for a variant; non-list or empty results become None.
df["blast_target"] = subject_variants.map(genes_dict).apply(
    lambda hits: hits[0] if isinstance(hits, list) and len(hits) > 0 else None
)

# Load the gene-name conversion table (path is relative to the current working directory).
conversion_table = pd.read_csv(Path().cwd().parent / "data/updated_idmap.csv")

# Keep rows that have an alias entry. Also require a "symbol" entry: that is the
# column split below, and a missing value there (NaN) has no .split and would
# raise AttributeError.
conversion_table = conversion_table.dropna(subset=["alias"])
conversion_table = conversion_table.dropna(subset=["symbol"])

# Map every ", "-separated name in "symbol" to the row's "query" name.
# Later rows win when the same name appears more than once.
# NOTE(review): despite its name, this dict maps "symbol" entries to "query",
# not alias -> symbol; the name is kept because later cells use it.
alias_to_symbol = {}
for symbols, query_name in zip(conversion_table["symbol"], conversion_table["query"]):
    for alias in symbols.split(", "):
        alias_to_symbol[alias] = query_name


# Translate a gene name through the lookup table, falling back to the name itself.
def find_symbol(gene_name, alias_to_symbol):
    """Return the mapped name for ``gene_name``, or ``gene_name`` if unmapped.

    Parameters
    ----------
    gene_name : hashable
        Name to translate (may be None, e.g. when no BLAST target was found).
    alias_to_symbol : dict
        Lookup table from names to their replacement.

    Returns
    -------
    object
        ``alias_to_symbol[gene_name]`` when the key exists, else ``gene_name``.
    """
    if gene_name in alias_to_symbol:
        return alias_to_symbol[gene_name]
    return gene_name


# Translate both the BLAST hit's gene and the padlock's own target gene
# through the same lookup table.
for source_column, converted_column in (
    ("blast_target", "converted_blast_target"),
    ("gene", "converted_gene_name"),
):
    df[converted_column] = df[source_column].apply(find_symbol, args=(alias_to_symbol,))

# Calculate melting temperatures (requires MELTING 5)

In [None]:
# True: compute Melting5 Tms now. False: reload a previously saved results CSV
# (nothing in this notebook writes that file).
melting_5 = True
df = (
    check_padlocks.process_dataframe_in_batches(df, batch_size=300)
    if melting_5
    else pd.read_csv("blast_results_olfr_full_across_tms.csv")
)

# Plot number of valid padlocks vs Melting5 Tm threshold

In [None]:
# Tm values with a boolean "valid_<Tm>" column in df. Defined once so the
# columns and the x-axis labels cannot drift apart.
# NOTE(review): assumes check_padlocks produced columns for 20..50 inclusive.
tm_thresholds = range(20, 51)
valid_columns = [f"valid_{tm}" for tm in tm_thresholds]
# Summing a boolean column counts its True values.
true_sums = df[valid_columns].sum()

fig, ax = plt.subplots(figsize=(8, 5), dpi=200)
ax.plot(list(tm_thresholds), true_sums.to_numpy(), marker="o")
ax.set_xlabel("Melting5 Tm")
ax.set_ylabel("Number of valid padlocks")
ax.set_xticks(list(tm_thresholds))
fig.tight_layout()

plt.show()

# Plot number of valid padlocks per gene

In [None]:
# Summarise padlock specificity per query (padlock) and per gene, based on the
# Melting5 validity call.
# NOTE(review): "specific" comes from check_padlocks and is treated here as a
# boolean per BLAST hit; confirm its exact meaning there.
hit_flags = df.assign(
    valid_specific=df["valid_probe_melting"] & df["specific"],
    valid_non_specific=df["valid_probe_melting"] & ~df["specific"],
)

# A query is flagged if ANY of its hits meets the condition. Aggregating the
# precomputed columns avoids df.loc[<group index>] lookups, which also pick up
# rows from other queries when df's index contains repeated labels.
df_padlocks = (
    hit_flags.groupby("query")[["valid_specific", "valid_non_specific"]]
    .any()
    .reset_index()
)

# Attach each query's target gene.
df_merged = pd.merge(
    df_padlocks, df[["query", "gene"]], on="query", how="left"
).drop_duplicates()

# Per gene: does at least one padlock qualify?
df_grouped_by_gene = (
    df_merged.groupby("gene")
    .agg(
        valid_specific=("valid_specific", "any"),
        valid_non_specific=("valid_non_specific", "any"),
    )
    .reset_index()
)

# Per gene: how many padlocks qualify (summing booleans counts True values).
df_counts = (
    df_merged.groupby("gene")
    .agg(
        number_of_specific_padlocks=("valid_specific", "sum"),
        number_of_non_specific_padlocks=("valid_non_specific", "sum"),
    )
    .reset_index()
)

df_padlocks.to_csv("padlock_specificity_olfrs.csv", index=False)
df_grouped_by_gene.to_csv("gene_specificity_olfrs.csv", index=False)
df_counts.to_csv("number_of_specific_padlocks_olfrs.csv", index=False)

# Both counts share df_counts' gene index, so no reindexing/union is needed.
combined_counts = df_counts.set_index("gene").rename(
    columns={
        "number_of_specific_padlocks": "Specific Padlocks",
        "number_of_non_specific_padlocks": "Non-Specific Padlocks",
    }
)

# Grouped bars: black = specific, red = non-specific (each bar 0.4 wide).
fig, ax = plt.subplots(figsize=(60, 6))
combined_counts.plot(kind="bar", color=["black", "red"], width=0.8, ax=ax)

ax.set_title("Specific and Non-Specific Padlocks per Gene")
ax.set_xlabel("Gene")
ax.set_ylabel("Number of Padlocks")
ax.legend(["Specific Padlocks", "Non-Specific Padlocks"])
ax.tick_params(axis="x", labelrotation=90)
fig.tight_layout()

plt.show()

# Compare Tm estimates from each method (NN vs Melting5)

In [None]:
# NN Tm vs Melting5 Tm for both padlock arms of every BLAST hit, coloured by the
# "mismatches" column.
# NOTE(review): the colour-bar label mentions gaps, but only "mismatches" is
# used (df also has a separate "gaps" column); confirm which was intended.
fig, ax = plt.subplots(figsize=(8, 8), dpi=200)
ax.set_title("Including ligation site missmatch seqs")
for x_column, y_column in (
    ("tm_left_NN", "tm_left_melting"),
    ("tm_right_NN", "tm_right_melting"),
):
    arm_points = ax.scatter(
        df[x_column], df[y_column], c=df["mismatches"], s=0.1, vmin=0, vmax=6
    )
# The colour bar is tied to the last (right-arm) scatter; both share vmin/vmax.
cbar = fig.colorbar(
    arm_points, ax=ax, label="No. total mismatches/gaps", fraction=0.046, pad=0.04
)

# y = x reference line: points on it have identical estimates.
ax.plot([-50, 100], [-50, 100], color="red", lw=1, alpha=0.3)
ax.set_xlabel("Tm_NN (C)")
ax.set_ylabel("Tm Melting5 (C)")
ax.set_aspect("equal", adjustable="box")

# Gridlines every 10 degrees, with emphasised axes at 0.
ticks = np.arange(-50, 110, 10)
ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.axvline(0, color="black", lw=0.5)
ax.axhline(0, color="black", lw=0.5)
ax.set_ylim(-50, 80)
ax.set_xlim(-50, 70)
ax.grid(True, which="both", linestyle="--", lw=0.5)
plt.show()

In [None]:
# NN Tm vs Melting5 Tm for both arms, coloured by the "valid_probe_NN" call.
fig, ax = plt.subplots(figsize=(8, 8), dpi=200)
ax.set_title("NN probe cutoff")

# Invalid probes as small black dots first, then valid probes as larger lime
# dots on top. Only the left-arm scatter of each group gets a legend entry.
for flag, colour, size, legend_label in (
    (False, "black", 0.1, "Valid Probe NN (False)"),
    (True, "lime", 2, "Valid Probe NN (True)"),
):
    subset = df[df["valid_probe_NN"].eq(flag)]
    ax.scatter(
        subset["tm_left_NN"],
        subset["tm_left_melting"],
        c=colour,
        s=size,
        label=legend_label,
    )
    ax.scatter(subset["tm_right_NN"], subset["tm_right_melting"], c=colour, s=size)

# y = x reference line.
ax.plot([-50, 100], [-50, 100], color="red", lw=1, alpha=0.3)

# Vertical marker at 37 C on the NN axis (same value as the tm_threshold
# passed to process_dataframe; presumably the NN cutoff).
ax.axvline(37, color="red", lw=0.5)

ax.set_xlabel("Tm_NN (C)")
ax.set_ylabel("Tm Melting5 (C)")
ax.set_aspect("equal", adjustable="box")

# Gridlines every 10 degrees, with emphasised axes at 0.
ticks = np.arange(-50, 110, 10)
ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.axvline(0, color="black", lw=0.5)
ax.axhline(0, color="black", lw=0.5)

ax.set_ylim(-50, 80)
ax.set_xlim(-50, 70)
ax.grid(True, which="both", linestyle="--", lw=0.5)
ax.legend()

plt.show()

In [None]:
# NN Tm vs Melting5 Tm for both arms, coloured by the "valid_probe_melting" call.
fig, ax = plt.subplots(figsize=(8, 8), dpi=200)
ax.set_title("Melting5 probe cutoff")

# Probes failing the Melting5 check as small black dots first, then passing
# probes as larger lime dots on top. Only the left-arm scatter of each group
# gets a legend entry.
for flag, colour, size, legend_label in (
    (False, "black", 0.1, "Valid Probe Melting5 (False)"),
    (True, "lime", 2, "Valid Probe Melting5 (True)"),
):
    subset = df[df["valid_probe_melting"].eq(flag)]
    ax.scatter(
        subset["tm_left_NN"],
        subset["tm_left_melting"],
        c=colour,
        s=size,
        label=legend_label,
    )
    ax.scatter(subset["tm_right_NN"], subset["tm_right_melting"], c=colour, s=size)

# Horizontal marker at 37 C on the Melting5 axis (same value as the tm_threshold
# passed to process_dataframe; presumably the Melting5 cutoff).
ax.axhline(37, color="red", lw=0.5)

# y = x reference line.
ax.plot([-50, 100], [-50, 100], color="red", lw=1, alpha=0.3)

ax.set_xlabel("Tm_NN (C)")
ax.set_ylabel("Tm Melting5 (C)")
ax.set_aspect("equal", adjustable="box")

# Gridlines every 10 degrees, with emphasised axes at 0.
ticks = np.arange(-50, 110, 10)
ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.axvline(0, color="black", lw=0.5)
ax.axhline(0, color="black", lw=0.5)

ax.set_ylim(-50, 80)
ax.set_xlim(-50, 70)
ax.grid(True, which="both", linestyle="--", lw=0.5)
ax.legend()

plt.show()

In [None]:
from matplotlib_venn import venn3
from matplotlib_venn.layout.venn3 import DefaultLayoutAlgorithm

# Row positions (not index labels) of probes passing each filter. Positions are
# always unique, whereas a set of index labels silently merges rows that share
# a label (e.g. after concatenating files without ignore_index), which would
# undercount every region of the diagram. .eq(True) treats missing values as
# "not passing" instead of failing on a NaN mask.
set_NN = set(np.flatnonzero(df["valid_probe_NN"].eq(True).to_numpy()))
set_melting = set(np.flatnonzero(df["valid_probe_melting"].eq(True).to_numpy()))
set_old_filters = set(
    np.flatnonzero(df["valid_probe_old_filters"].eq(True).to_numpy())
)

# Venn diagram of the overlap between the three validity filters.
plt.figure(figsize=(8, 8), dpi=200)
venn3(
    [set_NN, set_melting, set_old_filters],
    set_labels=("NN", "Melting", "Old Filters"),
    layout_algorithm=DefaultLayoutAlgorithm(normalize_to=1),
)

plt.title("Venn Diagram of Valid Probes")
plt.show()