# Process old probes and blast them

In [None]:
import os
import random
import re
from itertools import product
from pathlib import Path

import config
import lib.checkinput as checkinput
import lib.finalizeprobes as finalizeprobes
import lib.formatrefseq as formatrefseq
import lib.makesnails as makesnails
import lib.parblast as parblast
import lib.parmsa as parmsa
import lib.readblast as readblast
import lib.readfastafile as readfastafile
import lib.retrieveseq as retrieveseq
import lib.screenseq as screenseq
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from Bio import SeqIO
from tqdm import tqdm

## Read binding sites

In [None]:
# Load the pilot SNAIL probe table; every NaN (in any column, e.g. an empty
# "Total pad") is replaced by 0 so the arithmetic below never sees NaN.
probeset = pd.read_csv("pilot_snail_probes.csv").fillna(0)

In [None]:
# Inspect the loaded probe table.
probeset

In [None]:
# For every primer-half oligo (two-digit index >= 10) keep its 3'-terminal
# 12 nt and the melting temperature reported by readblast.calc_tm_NN.
primer = []
melting_temp = []
for name, full_seq in zip(probeset["PROBE"], probeset["SEQUENCE"]):
    if int(name[-2:]) >= 10:
        primer.append(full_seq[-12:])
        melting_temp.append(readblast.calc_tm_NN(primer[-1]))

In [None]:
# Distribution of the 3' primer melting temperatures computed above.
plt.hist(melting_temp)
plt.xlabel("Melting temperature of 3' primer")
plt.ylabel("Number of probes")
plt.show()

In [None]:
# Inspect the binding region of the first probe (row 0).
i = 0
probeset["SEQUENCE"][i][
    12 : -int(31 + probeset["Total pad"][i])
]  # assuming a 6 bp binding region at the beginning,
# 6 bp at the end for binding, a barcode of 7 and a primer of 18 (6 + 7 + 18 = 31).
# NOTE(review): the slice skips 12 leading characters, not 6 — confirm the
# 5' layout (the example probe in the next cell carries a 7-char "/5Phos/" prefix).

In [None]:
# Hand-check of the slicing on one ordered padlock sequence, including its
# "/5Phos/" modification prefix.
probe = "/5Phos/GGGCACTCGAGTTGGCAGACCTCTGCGTCGGACTGTAGAACTCTATTAGTTTCGCACAA"
binding = probe[13:-33]  # NOTE(review): offsets 13/33 differ from 12/31 used above — confirm
binding

In [None]:
def extract_binding(probe_name, sequence, total_pad):
    """Return the target-binding part of one probe half.

    The tens digit of the probe's two-digit index selects the oligo layout:
    '0' (indices 00-09) is a padlock half, '1' (indices 10-19) a primer half.

    Args:
        probe_name: probe name ending in a two-digit index, e.g. 'Gad1_01'.
        sequence: full ordered oligo sequence.
        total_pad: padding length taken from the "Total pad" column.

    Returns:
        The binding subsequence.

    Raises:
        ValueError: if the tens digit is neither '0' nor '1'.
    """
    arm = probe_name[-2]
    if arm == "0":
        # Padlock: drop the first 12 characters and the 3' 31 nt plus padding.
        return sequence[12 : -int(31 + total_pad)]
    if arm == "1":
        # Primer: drop the last 13 characters.
        return sequence[0:-13]
    # Unknown layout: fail loudly rather than reuse another row's binding.
    raise ValueError(f"catastrophic error: unexpected probe index in {probe_name!r}")


# Probe names and their binding sequences, in probeset row order.
genelist = probeset["PROBE"].tolist()
bindlist = [
    extract_binding(name, seq, pad)
    for name, seq, pad in zip(
        probeset["PROBE"], probeset["SEQUENCE"], probeset["Total pad"]
    )
]

In [None]:
# Two-column table: probe name -> binding sequence.
bindset = pd.DataFrame({"probe": genelist, "binding": bindlist})

In [None]:
# Inspect the probe-name / binding-sequence table.
bindset

In [None]:
def merge_probe_primer(df: pd.DataFrame) -> pd.DataFrame:
    """
    Pair padlock and primer halves and concatenate their binding sequences.

    Probe names must look like '<gene>_<index>' (e.g. 'Gad1_01', 'Gad1_11').
    Index k (0-9) is the padlock half and index k + 10 the primer half of the
    same pair (01<->11, 02<->12, ...). Each pair's binding is primer + padlock.

    Args:
        df: DataFrame with columns ['probe', 'binding'].

    Returns:
        DataFrame with columns ['probe', 'binding'], one row per complete pair,
        named '<gene>_<k>' without leading zeros (e.g. 'Gad1_1') and sorted by
        that name.

    Raises:
        ValueError: if a probe name does not match '<gene>_<digits>'.
        pandas.errors.MergeError: if a (gene, k) half occurs more than once.

    Warns:
        UserWarning: when halves without a partner are dropped.
    """
    import warnings

    # Parse gene and numeric index from probe names
    parsed = df["probe"].str.extract(r"^(?P<gene>.+)_(?P<idx>\d+)$")
    if parsed.isna().any().any():
        raise ValueError(
            "Some probe names do not match the expected pattern '<gene>_<2-digit index>'"
        )

    work = df.copy()
    work["gene"] = parsed["gene"]
    work["idx"] = parsed["idx"].astype(int)

    # idx 0-9 are padlock halves (k = idx); idx >= 10 are primer halves
    # (k = idx - 10). Only 10-19 can pair: idx >= 20 gives k >= 10, which has
    # no padlock partner and is reported as unpaired below.
    probes = work[work["idx"] < 10].rename(columns={"binding": "probe_seq"})
    primers = work[work["idx"] >= 10].rename(columns={"binding": "primer_seq"})
    probes["k"] = probes["idx"]
    primers["k"] = primers["idx"] - 10

    # The inner join drops halves without a partner; make that visible.
    probe_keys = {(g, int(k)) for g, k in zip(probes["gene"], probes["k"])}
    primer_keys = {(g, int(k)) for g, k in zip(primers["gene"], primers["k"])}
    unpaired = sorted(probe_keys ^ primer_keys)
    if unpaired:
        warnings.warn(
            f"Dropping {len(unpaired)} unpaired probe halves (gene, k): {unpaired}",
            stacklevel=2,
        )

    # Join on (gene, k)
    merged = pd.merge(
        probes[["gene", "k", "probe_seq"]],
        primers[["gene", "k", "primer_seq"]],
        on=["gene", "k"],
        how="inner",
        validate="one_to_one",
    )

    # Concatenate sequences: primer then probe
    merged["binding"] = merged["primer_seq"] + merged["probe_seq"]

    # Name as '<gene>_<k>' without leading zeros. Vectorised string building
    # also works on an empty frame, where a row-wise apply returns a DataFrame.
    merged["probe"] = merged["gene"] + "_" + merged["k"].astype(str)

    return merged[["probe", "binding"]].sort_values(["probe"]).reset_index(drop=True)

In [None]:
# One concatenated primer+padlock binding sequence per probe pair.
merged_sequences = merge_probe_primer(bindset)
merged_sequences

In [None]:
# Length (in characters) of each concatenated binding sequence.
merged_sequences["length"] = merged_sequences["binding"].str.len()

In [None]:
# Distribution of concatenated primer+padlock binding lengths.
plt.hist(merged_sequences["length"], bins=20)
plt.title("DNA binding length")
plt.xlabel("Length (bp)")
plt.ylabel("Number of probes")

Some are exceptionally short!

In [None]:
probe = "AAATTTT"

# Built once: complements upper- and lower-case A/C/G/T and maps N to N;
# any other character is passed through unchanged.
_DNA_COMPLEMENT = str.maketrans("ATCGNatcgn", "TAGCNtagcn")


def revcomp(seq):
    """Return the reverse complement of a DNA sequence.

    Case is preserved (lower-case bases are complemented to lower case), N
    stays N, and characters outside ACGTN/acgtn are kept as-is.

    >>> revcomp("AAATTTT")
    'AAAATTT'
    """
    return seq.translate(_DNA_COMPLEMENT)[::-1]


revcomp(probe)

In [None]:
# Reverse complement of every binding sequence; this is what gets BLASTed.
merged_sequences["target_rc"] = [revcomp(seq) for seq in merged_sequences["binding"]]
merged_sequences

## Blasting

In [None]:
# Shared reference resources and the BLAST database prefix built from them.
BASE_DIR = os.path.abspath("/nemo/lab/znamenskiyp/home/shared/resources/")
reference_transcriptome = "ensembl"

blast_db_file = os.path.join(BASE_DIR, reference_transcriptome, "mouse.selected")


def write_blast_script(
    df,
    db_path,
    outdir="blast_results",
    script_name="run_blast.sh",
    word_size=7,
    strand="plus",
    seq_column="target_rc",
):
    """
    Write one FASTA per probe and a shell script with one blastn command each.

    For every row, the query sequence from ``seq_column`` is written to
    ``<outdir>/<probe>.fasta``. The script sends results to
    ``<outdir>/<probe>_blast.txt`` as header-less CSV
    (``-outfmt '10 std qseq sseq'``). Paths in the script are absolute and
    shell-quoted, so it can be run or submitted from any working directory.

    Args:
        df (pd.DataFrame): must have a 'probe' column and ``seq_column``
        db_path (str): path to blast database (without suffixes)
        outdir (str): directory for fasta, script and result files
        script_name (str): name of the shell script (created inside outdir)
        word_size (int): blastn word size
        strand (str): strand option for blastn
        seq_column (str): column holding the query sequence; defaults to
            'target_rc', the reverse complement of the binding sequence

    Returns:
        str: path of the (executable) shell script.
    """
    import shlex  # quote paths so unusual probe names cannot break the script

    os.makedirs(outdir, exist_ok=True)
    script_path = os.path.join(outdir, script_name)

    with open(script_path, "w") as sh:
        sh.write("#!/bin/bash\n\n")
        # SLURM directives; plain comments when the script is run with bash.
        sh.write("#SBATCH --job-name=blast_job\n")
        sh.write("#SBATCH --output=blast_job.out\n\n")

        for probe, seq in zip(df["probe"], df[seq_column]):
            fasta_file = os.path.abspath(os.path.join(outdir, f"{probe}.fasta"))
            out_file = os.path.abspath(os.path.join(outdir, f"{probe}_blast.txt"))

            # write fasta
            with open(fasta_file, "w") as f:
                f.write(f">{probe}\n{seq}\n")

            # blast command
            cmd = (
                f"blastn -query {shlex.quote(fasta_file)} "
                f"-db {shlex.quote(db_path)} "
                "-outfmt '10 std qseq sseq' "
                f"-out {shlex.quote(out_file)} "
                f"-word_size {word_size} -strand {strand}"
            )
            sh.write(cmd + "\n")

    # Make the script executable
    os.chmod(script_path, 0o755)
    return script_path

In [None]:
# Write one FASTA per probe pair plus blast_results/run_blast.sh; the script
# has to be run before the results cell below can read its output files.
script_path = write_blast_script(
    merged_sequences,
    db_path=blast_db_file,
    outdir="blast_results",
    script_name="run_blast.sh",
    word_size=7,
    strand="plus",
)

## Show results

In [None]:
# BLAST was run with -outfmt '10 std qseq sseq': header-less CSV whose columns
# are the 12 "std" fields followed by qseq and sseq.
BLAST_FMT10_COLUMNS = [
    "qaccver", "saccver", "pident", "length", "mismatch", "gapopen",
    "qstart", "qend", "sstart", "send", "evalue", "bitscore", "qseq", "sseq",
]

# Best-hit statistics per probe pair; NaN marks probes with no BLAST hit.
# Note: this replaces the binding-length "length" column computed earlier.
merged_sequences["pident"] = np.nan
merged_sequences["length"] = np.nan
merged_sequences["gap"] = np.nan

for idx, row in merged_sequences.iterrows():
    result_file = os.path.join("blast_results", f"{row.probe}_blast.txt")
    try:
        # header=None is essential: otherwise the best hit is consumed as the
        # header row and iloc[0] silently returns the second hit.
        results = pd.read_csv(result_file, header=None, names=BLAST_FMT10_COLUMNS)
    except pd.errors.EmptyDataError:
        results = pd.DataFrame(columns=BLAST_FMT10_COLUMNS)
    if results.empty:
        print(f"No BLAST hit for {row.probe}")
        continue
    first_hit = results.iloc[0]  # BLAST reports hits best-first
    merged_sequences.loc[idx, "pident"] = first_hit["pident"]
    merged_sequences.loc[idx, "length"] = first_hit["length"]
    # Number of gap openings in the best alignment (a count, not a length).
    merged_sequences.loc[idx, "gap"] = first_hit["gapopen"]
merged_sequences

In [None]:
# "gap" holds BLAST's gapopen field for the best hit: the number of gap
# openings in that alignment, i.e. a count rather than a length in bp.
# Presumably a non-zero value flags a break at the primer/probe junction of
# the concatenated query — confirm.
plt.hist(merged_sequences["gap"], bins=20)
plt.title("Gap openings in first hit")
plt.xlabel("Number of gap openings")
plt.ylabel("Number of probes")

In [None]:
# Alignment length of the best BLAST hit per probe pair ("length" was
# overwritten with BLAST's length field when the results were read).
plt.hist(merged_sequences["length"], bins=20)
plt.title("Length of first hit")
plt.xlabel("length (bp)")
plt.ylabel("Number of probes")

In [None]:
# x: alignment length of the best BLAST hit; y: its number of gap openings
# (BLAST gapopen), which is a count, not a gap length in bp.
plt.scatter(merged_sequences["length"], merged_sequences["gap"])
plt.title("Length vs gap openings of first hit")
plt.xlabel("Length (bp)")
plt.ylabel("Number of gap openings")

## Running the first trial panel

Running the spapros_10K_box_sparseness panel, dropping Vip and Sst because they will be handled by hybridization, and dropping Npsr1 and Rrad because they performed very poorly. We include the four closest markers in the list: Cxcl14, Cort, Pld5 and Htr2c. Running a new panel that excludes these four directly, using an entropy constraint, is probably the better approach; this first run will still help iron out the kinks.

We'll need to build a folder with a csv for each of these genes and run the pipeline. 

In [None]:
panelpath = "/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel/panels/spapros_10K_taylored_2021/probeset.csv"
# Keep only rows flagged as selected; the gene names live in the CSV's
# unnamed first column, which pandas calls "Unnamed: 0".
panel = pd.read_csv(panelpath)
selected = panel["selection"].eq(True)
panel = panel.loc[selected].copy()
genes = panel["Unnamed: 0"].tolist()
genes, len(genes)

In [None]:
# Make directory
outdir = "spapros_10K_taylored_2021"
os.makedirs(outdir, exist_ok=True)

# Write one CSV per gene (only the gene name, no header, no index)
for g in genes:
    filepath = os.path.join(outdir, f"{g}.csv")
    pd.Series([g]).to_csv(filepath, index=False, header=False)

## Debugging process

In [None]:
# Output and temporary folders of one earlier pipeline run (gene Adamts2),
# reused to step through the pipeline by hand.
outdir_temp = "/nemo/lab/znamenskiyp/scratch/spapros_10K_taylored_2021_16_ensembl/Adamts2/TempFolder20251027140913"
outdir = "/nemo/lab/znamenskiyp/scratch/spapros_10K_taylored_2021_16_ensembl/Adamts2"

In [None]:
# Design parameters passed to the pipeline helpers in later cells. There,
# [0] is handed to formatrefseq.blastdb (the species) and [6] to
# readblast.getcandidates as the total length; the meaning of the other
# entries is not visible here — see the pipeline's parameter parsing.
designpars = ("mouse", 20, 2, 60, 78, 20, 44)

# (output directory, temporary directory)
outpars = (outdir, outdir_temp)
outpars[1]

In [None]:
genefile = "/nemo/lab/znamenskiyp/home/users/colasa/code/multi_padlock_design/spapros_10K_taylored_2021_16_ensembl/Adamts2.csv"
species = "mouse"

# Reproduce the pipeline's input stage: read either a gene list (genefile)
# or a sequence file whose headers are mapped to gene names.
if len(genefile):
    # Gene-list input: no sequence file was read. The gene-lookup cell below
    # (`if not success_f:`) only runs when success_f is falsy, so define it.
    success_f = False
    genefile = checkinput.correctinput(genefile)
    success_g, allgenes, alllinkers = checkinput.readgenefile(genefile)
    # Remove gene name duplicates, keeping the first occurrence and its linker.
    genes = []
    linkers = []
    seen_genes = set()
    for c, name in enumerate(allgenes):
        if name not in seen_genes:
            seen_genes.add(name)
            genes.append(name)
            linkers.append(alllinkers[c])
else:
    success_g = True
    (
        success_f,
        seqfile,
        headers,
        sequences,
        headers_wpos,
        basepos,
    ) = checkinput.readseqfile()
    print(success_f)
    print(f"Headers: {headers}")
    # Gene names from sequence headers: drop characters that are unsafe in
    # file names and turn commas into dashes.
    unsafe_chars = str.maketrans("", "", ':/\\[]?" <>')
    genes = [header.translate(unsafe_chars).replace(",", "-") for header in headers]

    # Convert to Ensembl gene names
    # Load the conversion table
    conversion_table = pd.read_csv(
        "/nemo/lab/znamenskiyp/home/users/becalia/code/multi_padlock_design/data/olfr_consensus_cds_3utr_annotations.csv"
    )

    # Build symbol(lowercased) -> first-seen alias mapping
    symbol_to_alias = {}
    for symbol, alias_field in zip(
        conversion_table["gene_name"], conversion_table["gene_symbol"]
    ):
        aliases = [a.strip() for a in str(alias_field).split(",") if a.strip()]
        key = str(symbol).lower()
        if aliases and key not in symbol_to_alias:
            symbol_to_alias[key] = aliases[0]

    def to_alias(name: str) -> str:
        """Return the alias for a given gene name, or the original if none found."""
        return symbol_to_alias.get(name.lower(), name)

    genes = [to_alias(g) for g in genes]
    print(f"Genes: {genes}")

    # --- Retrieve sequences / hits ---
    hits = retrieveseq.querygenes(genes, species)
    print(f"Hits: {hits}")

    retrieveseq.loaddb(species)
    Headers = retrieveseq.Headers

    # --- Collect variants and multi-hit headers for MSA ---
    variants = []
    headersMSA = []

    def variant_id_from_header(h: str) -> str:
        """Return the id in a header: strip a leading '>' (only if present)
        and everything from the first '.' on."""
        if h.startswith(">"):
            h = h[1:]
        return h.split(".", 1)[0]

    for hit in hits:
        if len(hit) == 1:
            variants.append(variant_id_from_header(Headers[hit[0]]))
        elif len(hit) > 1:
            print("Checking hits")
            tempheaders = [Headers[i] for i in hit]
            headersMSA.append(tempheaders)
            variants.extend(variant_id_from_header(h) for h in tempheaders)
        # len(hit) == 0: no database hit for this gene, nothing to record

    # One (empty) linker entry per input header
    linkers = [[] for _ in headers]
    variants_matching_sequence = variants

    print(f"Variants: {variants}")

In [None]:
# Hard-coded candidate site positions kept for debugging — presumably a pasted
# copy of a screenseq.thresholdtm result. Written as the contiguous runs the
# original literal consisted of (1584 positions); the thresholdtm call further
# down reassigns siteChopped.
siteChopped = [
    *range(0, 54),
    *range(94, 184),
    195,
    196,
    197,
    204,
    *range(290, 334),
    *range(338, 1221),
    *range(1296, 1376),
    *range(1417, 1625),
    *range(1665, 1886),
]
# Plot each kept position against its list index; jumps mark excluded stretches.
plt.plot(siteChopped)

In [None]:
blastpath = "/nemo/lab/znamenskiyp/scratch/spapros_10K_taylored_2021_20_ensembl/Lamp5/TempFolder20251031115743/ENSMUST00000057503.7cdnachromo_target_431_blast.txt"
# Print subject and first scores of each hit. Lines mentioning "XR" or "XM"
# (presumably predicted RefSeq XR_/XM_ accessions) are skipped.
with open(blastpath) as blast_file:
    for record in blast_file:
        if "XR" in record or "XM" in record:
            continue
        fields = record.split(",")
        hit = fields[1].strip()
        scores = fields[2:]
        print(f"Hit: {hit}")
        print(f"Scores: {scores[0]}")
        print(scores[1])

In [None]:
# Gene-list path of the pipeline: look genes up, find (consensus) sequences,
# reorder MSA genes last and write 1.InputSeq_<t>.fasta.
# NOTE(review): success_f is only assigned in the sequence-file branch of the
# input cell above; with a gene file it must already exist in the kernel.
# NOTE(review): `t` (file-name suffix) is not assigned in any visible cell —
# define it before running this cell.
if not success_f:
    # find genes in the database
    hits = retrieveseq.querygenes(genes, species)

    # if any gene is not found in the RefSeq acronym list, write to file
    nohit = [i for i in range(len(genes)) if len(hits[i]) == 0]
    if len(nohit):
        # try to convert gene acronyms using Ensembl REST API
        # NOTE(review): debugging leftovers — genes[1] is hard-coded (IndexError
        # with a single gene, ignores which genes are in nohit), translated_id
        # replaces the whole gene list in querygenes, and the report below
        # always writes "Gabra1" as the target name.
        print(genes[1])
        translated_id = makesnails.get_gene_synonyms(species, genes[1])
        print(translated_id)
        if translated_id is not None:
            print(f"Converted {genes} to {translated_id}")
            hits = retrieveseq.querygenes(translated_id, species)
            with open(
                os.path.join(outdir, "0.AcronymNotFound_" + t + ".txt"), "w"
            ) as f:
                f.write(
                    "Some gene acronyms were converted using Ensembl REST API. Please check the following list:\n"
                )
                f.write("%s -> %s\n" % (translated_id, "Gabra1"))
                f.write("Please make sure these are the intended genes.\n\n")
        else:
            with open(
                os.path.join(outdir, "0.AcronymNotFound_" + t + ".txt"), "w"
            ) as f:
                # iterate backwards so deleting by index keeps earlier indices valid
                for i in nohit[::-1]:
                    f.write("%s\n" % genes[i])
                    del genes[i]  # remove genes that are not found
                    del linkers[i]
                    del hits[i]

    # find sequences (MSA included if multiple variants)
    headers, basepos, sequences, msa, nocommon, variants = retrieveseq.findseq(
        genes, hits, outdir_temp
    )

    # genes that have no common sequence among all variants
    nocommon = [c for c, i in enumerate(genes) if i in nocommon]
    if len(nocommon):
        with open(
            os.path.join(outdir, "0.NoConsensusSequence_" + t + ".txt"), "w"
        ) as f:
            for i in nocommon[::-1]:
                f.write("%s\n" % genes[i])
                del genes[i]  # remove genes without a consensus sequence
                del linkers[i]
                del hits[i]
                del variants[i]

    # Reorder so that genes handled by MSA come after all the others.
    idxmsa = [c for c, i in enumerate(genes) if i in msa]
    geneorder = [c for c, i in enumerate(genes) if i not in msa]
    [geneorder.append(i) for i in idxmsa]  # comprehension used only for its appends
    linkers = [linkers[i] for i in geneorder]
    genes = [genes[i] for i in geneorder]
    variants = [variants[i] for i in geneorder]

    # replicates variants so that they match headers_wpos
    variants_matching_sequence = []
    for c, i in enumerate(variants):
        if isinstance(basepos[c][0], int):
            variants_matching_sequence.append(i)
        else:
            for j in range(0, len(basepos[c])):
                variants_matching_sequence.append(i)

    # write found sequences to output file (positions written +1, presumably
    # converting 0-based to 1-based)
    with open(os.path.join(outdir, "1.InputSeq_" + t + ".fasta"), "w") as f:
        for i, base in enumerate(basepos):
            if isinstance(base[0], int):  # no variants found
                f.write("%s, %d to %d\n" % (headers[i], base[0] + 1, base[1] + 1))
                f.write("%s\n\n" % sequences[i])
            else:  # more than one variants found
                for j, subbase in enumerate(base):
                    f.write(
                        "%s, %d to %d\n" % (headers[i], subbase[0] + 1, subbase[1] + 1)
                    )
                    f.write("%s\n\n" % sequences[i][j])
    # Re-read the file just written to get headers with positions + sequences.
    temp = readfastafile.readfasta(os.path.join(outdir, "1.InputSeq_" + t + ".fasta"))
    headers_wpos = temp[1]
    sequences = temp[2]

In [None]:
hit = "ENSMUST00000001051"
variants = [hit]

# Compare accession ids with their ".N" version suffix removed.
hit_core = hit.partition(".")[0]
first_variant = variants[0]
if isinstance(first_variant, list):  # several variants of the gene
    variant_cores = [v.partition(".")[0] for v in first_variant]
elif isinstance(first_variant, str):  # a single variant
    variant_cores = [first_variant.partition(".")[0]]

hit_core, variant_cores

In [None]:
# Bundle the gene-level inputs as (genes, linkers, headers, variants) and display them.
genepars = (genes, linkers, headers, variants)
genes, linkers, headers, variants

In [None]:
variants  # bare expression: no effect unless it is the last line of the cell
# Version-stripped id of the first variant only.
variant_cores = [variants[0].split(".")[0]]
variant_cores

In [None]:
# Debug check: prints a marker when config.annotation_file is set (truthy).
if config.annotation_file:
    print("bananas")

In [None]:
# Flag a BLAST hit as non-specific when its accession (without ".N" version)
# is not one of the target's own variants.
hit_core = hit.split(".")[0]
if variants and isinstance(variants[0], list):
    # Presumably per-gene layout: variants[0] lists every variant of the gene.
    variant_cores = [v.split(".")[0] for v in variants[0]]
else:
    # Flat list of variant ids (one or many). Iterating variants[0] here, as
    # a length-based check would, splits the first id into characters.
    variant_cores = [v.split(".")[0] for v in variants]

print(hit_core, variant_cores)
if hit_core not in variant_cores:
    print("Non-specific hit found:", hit)

In [None]:
# designinput = (basepos, headers_wpos, sequences, variants_matching_sequence);
# later cells use [1] headers_wpos, [2] sequences and [3] the variants.
designinput = (basepos, headers_wpos, sequences, variants_matching_sequence)
basepos, headers_wpos, sequences, variants_matching_sequence

In [None]:
# Tm screening of headers_wpos/sequences with the temp folder (outpars[1]).
# NOTE(review): returns (Tm, siteChopped) — presumably Tm values and the kept
# candidate positions; this reassigns the hard-coded siteChopped above.
Tm, siteChopped = screenseq.thresholdtm(
    designinput[1],
    designinput[2],
    outpars[1],
    designpars,  # armlen used, but not important, only total length
)

In [None]:
# Make blast database if it doesn't exist
formatrefseq.blastdb(designpars[0])
# Run blast in parallel
parblast.continueblast(siteChopped, designinput[1], outpars[1], designpars)

In [None]:
# Read the BLAST output back and keep specific candidate sites.
# NOTE(review): by their names, the return values are the specific candidates
# and the sites that could not be mapped — confirm in lib.readblast.
siteCandidates, notMapped = readblast.getcandidates(
    siteChopped,
    designinput[1],
    outpars,
    designpars[1],
    designinput[3],
    config.specificity_by_tm,
    designpars[6],  # total length
)  # used in readblastout

## Building SNAIL probes

I need to:

- Split the binding regions
- Source barcodes
- Source a sequencing primer
- Build random primer-padlock overlaps
- Assemble

In [None]:
# source all specific targets
# source all specific targets
# Base scratch directory and the pipeline run (panel) whose per-gene output
# folders are read below.
outdir = Path("/nemo/lab/znamenskiyp/scratch")
panel_name = "spapros_10K_taylored_2021_21_ensembl"

In [None]:
def read_candidates(outdir, panel_name, probefile="3.AllSpecificTargets_*.csv"):
    """
    Read and concatenate the per-gene target tables of a panel.

    Each immediate subdirectory of ``outdir / panel_name`` is treated as one
    gene and must contain exactly one file matching ``probefile``. Lines that
    start with ">" are dropped. The remaining lines are split on commas that
    are not inside ``[...]`` (via ``smart_split``), and the first of them is
    used as the header.

    Args:
        outdir (Path): Base output directory containing the panel subdirectory.
        panel_name (str): Name of the panel subdirectory.
        probefile (str): Glob pattern selecting the table file in each gene
            subdirectory.

    Returns:
        pd.DataFrame: Rows of all tables, with the file's header columns (values
        kept as strings) plus a ``gene`` column holding the subdirectory name.
        Rows whose ``acronym`` is missing or empty are dropped, and the index
        is reset to 0..n-1. An empty DataFrame is returned when no table
        could be read.
    """
    all_dfs = []
    panel_dir = Path(outdir) / panel_name

    # Sorted so that row order (and any seeded sampling downstream) does not
    # depend on filesystem iteration order.
    for subdir in sorted(panel_dir.iterdir()):
        if not subdir.is_dir():
            continue
        print(f"Processing {subdir.name}")

        target_files = list(subdir.glob(probefile))
        if len(target_files) != 1:
            print(
                f"Warning: Expected exactly one target file in {subdir}, found {len(target_files)}. Skipping."
            )
            continue
        tf = target_files[0]

        with open(tf) as f:
            clean_lines = [
                smart_split(line.strip()) for line in f if not line.startswith(">")
            ]

        # An empty file has no header line at all; unpacking it would fail.
        if not clean_lines:
            print(f"Warning: No data found in {tf}, skipping.")
            continue
        header, *data = clean_lines
        if len(data) == 0 or len(data[0]) < 2:
            print(f"Warning: No data found in {tf}, skipping.")
            continue

        # Tag every row with the gene name (= subdirectory name).
        df = pd.DataFrame(data, columns=header)
        df["gene"] = subdir.name
        all_dfs.append(df)

    # Indexing "acronym" on an empty DataFrame would raise KeyError, so bail out early.
    if not all_dfs:
        return pd.DataFrame()

    big_df = pd.concat(all_dfs, ignore_index=True)
    keep = big_df["acronym"].notna() & (big_df["acronym"] != "")
    # Reset so downstream positional code sees a contiguous 0..n-1 index.
    return big_df[keep].reset_index(drop=True)


_COMMA_OUTSIDE_BRACKETS = re.compile(r",(?![^[]*\])")


def smart_split(line):
    """Split *line* on commas, keeping commas inside ``[...]`` intact."""
    # A comma is protected when a "]" follows it before any "[" does.
    return _COMMA_OUTSIDE_BRACKETS.split(line)


rows = []  # not used in the cells visible here

# Load the "6.ProbesRandomSubset*" tables of every gene in the panel.
big_df = read_candidates(outdir, panel_name, probefile="6.ProbesRandomSubset*.csv")
big_df

In [None]:
# Length of one example target sequence.
len("AATGACCAACTCCTCCTCCACTTCGACCTCCACCACTACCGGGG")

In [None]:
# Number of candidate targets available per gene acronym, largest first.
per_gene = big_df.groupby("acronym")["target"].count()
per_gene_sorted = per_gene.sort_values(ascending=False)

# One bar per gene: the figure is widened in proportion to the gene count.
fig_width = len(per_gene_sorted) * 0.3
plt.figure(figsize=(fig_width, 10))
per_gene_sorted.plot(kind="bar")

# Red reference line at 4 targets per gene.
plt.axhline(y=4, color="red")
plt.xlabel("Gene Acronym")
plt.ylabel("Target Count")
plt.title("Target Counts per Gene Acronym")
plt.tight_layout()
plt.show()

- Design unique overlaps per probe-primer pair
- Split the target, build skeleton
- Add barcode
- Add UMI
- Add random specifier
- Add sequencing primer
- Verify

So let's start by defining the unique overlaps and building the skeleton.

In [None]:
# Probe IDs per gene: 1-4 are padlocks, 11-14 the paired primers (pair + 10).
probe_ids = [1, 2, 3, 4, 11, 12, 13, 14]
genes = big_df["acronym"].unique().tolist()

# product(genes, ids) iterates gene-major, the same order as nested loops.
probes = np.array([f"{gene}_{pid}" for gene, pid in product(genes, probe_ids)])
padlocks = np.array([f"{gene}_{pid}" for gene, pid in product(genes, [1, 2, 3, 4])])
gene_array = np.array([gene for gene, _ in product(genes, probe_ids)])

In [None]:
# Total number of probes (8 per gene).
len(probes)

In [None]:
# Random 12-nt overlap sequences, one per padlock (4 per gene), labelled with the
# padlock names. NOTE(review): positional arguments assumed to be
# (min_barcode, length, num_sequences, for_overlap) as in the local
# generate_random_dna_sequences - confirm against lib.makesnails.
overlaps = makesnails.generate_random_dna_sequences(
    12, 12, len(genes) * 4, True, genes=padlocks
)

In [None]:
# Padlock k and primer k+10 of a pair must share one overlap sequence. Each
# block of 4 padlock rows (one gene) is therefore followed by a copy whose IDs
# are bumped by 10, e.g. "Gene_1" -> "Gene_11".
new_rows = []

for start in range(0, len(overlaps), 4):
    padlock_rows = overlaps.iloc[start : start + 4]
    primer_rows = padlock_rows.copy()
    primer_rows["Overlap Primer Sequence"] = primer_rows[
        "Overlap Primer Sequence"
    ].str.replace(r"_(\d+)$", lambda m: f"_{int(m.group(1)) + 10}", regex=True)
    new_rows += [padlock_rows, primer_rows]

overlaps = pd.concat(new_rows, ignore_index=True)

In [None]:
overlaps  # padlock rows followed by their primer copies, per gene

In [None]:
# One row per probe (padlock or primer). Overlap columns are taken row-wise from
# `overlaps`, whose order matches `probes`: per gene, four padlocks, then their
# four primers. Target-derived columns start as 0 placeholders and are filled in
# per gene during target selection. The target tables are parsed as text, so
# the values written there are strings. Writing strings into a float/int column
# is deprecated in pandas and will eventually raise, so these columns are
# created with object dtype while keeping the same 0 placeholders.
n_probes = len(probes)
probes_df = pd.DataFrame(
    {
        "probe_id": probes,
        "gene": gene_array,
        "overlap_sequence": overlaps["Sequence"],
        "overlaps_reverse_complement": overlaps["Reverse Complement"],
        "overlaps_Tms": overlaps["Tms"],
        "Tm": np.zeros(n_probes).astype(object),
        "startpos": np.zeros(n_probes, dtype=int).astype(object),
        "endpos": np.zeros(n_probes, dtype=int).astype(object),
        "transcript_region": np.zeros(n_probes, dtype=int).astype(object),
        "target": np.zeros(n_probes).astype(object),
    }
)
probes_df.set_index("probe_id", inplace=True)

Now, for each gene we can choose four targets and assemble:

In [None]:
# For every gene with at least four candidate targets, sample four of them and
# copy each target's fields onto both members of its probe pair:
# padlock "<gene>_<k>" and primer "<gene>_<k+10>", for k = 1..4.
seed = 42  # set None for non-deterministic picks
selected = []  # not filled below
target_fields = ["target", "Tm", "startpos", "endpos", "transcript_region"]

for gene in genes:
    candidates = big_df[big_df["acronym"] == gene]
    if candidates.empty:
        print(f"Warning: no targets for {gene}")
        continue
    if candidates.shape[0] < 4:
        print(f"Warning: only {candidates.shape[0]} targets for {gene}, fewer than 4")
        continue
    # Same seed for every gene, so the picks are reproducible across runs.
    chosen = candidates.sample(n=4, replace=False, random_state=seed)
    for pair in [1, 2, 3, 4]:
        target = chosen.iloc[pair - 1]
        for probe_name, probe_type in (
            (f"{gene}_{pair}", "padlock"),
            (f"{gene}_{pair+10}", "primer"),
        ):
            for field in target_fields:
                probes_df.loc[probe_name, field] = target[field]
            probes_df.loc[probe_name, "type"] = probe_type

In [None]:
# The same 18-nt sequencing-primer sequence on every probe row.
probes_df["sequencing_primer"] = "GTCGGACTGTAGAACTCT"

Sample from the chosen barcode set. In this case, we will use the constrained panel that we optimised by greedy search of the linked loss function. 

In [None]:
genes  # gene acronyms that need a barcode

In [None]:
# Pair each gene with a barcode. Barcodes from the constrained "prefix7" subset
# come first, so genes receive those before any of the remaining 10-nt barcodes.
subset_7_path = "/nemo/lab/znamenskiyp/home/users/colasa/code/multi_padlock_design/constrained_panel/224_subset_prefix7_d3_seed59.txt"
all_10_path = "/nemo/lab/znamenskiyp/home/users/colasa/code/multi_padlock_design/constrained_panel/1800_d4_10bp_barcodes_improved.txt"


def _read_barcode_column(path):
    # Each non-blank line is "id,barcode"; keep the barcode field.
    with open(path) as f:
        return [line.split(",")[1].strip() for line in f if line.strip()]


short_barcodes = _read_barcode_column(subset_7_path)
all_barcodes = _read_barcode_column(all_10_path)

short_set = set(short_barcodes)  # O(1) membership instead of scanning the list
extra_barcodes = [barcode for barcode in all_barcodes if barcode not in short_set]
barcode_pool = short_barcodes + extra_barcodes

if len(genes) > len(barcode_pool):
    raise ValueError(
        f"{len(genes)} genes but only {len(barcode_pool)} barcodes available"
    )

gene_dict = pd.DataFrame(
    {
        "genes": genes + [np.nan] * (len(barcode_pool) - len(genes)),
        "barcodes": barcode_pool,
        "is_7_mer": [True] * len(short_barcodes) + [False] * len(extra_barcodes),
    }
)

# Sanity check: the is_7_mer flag matches the file each barcode came from.
assert gene_dict.loc[gene_dict["is_7_mer"], "barcodes"].isin(short_barcodes).all()
assert gene_dict.loc[~gene_dict["is_7_mer"], "barcodes"].isin(extra_barcodes).all()

gene_dict

In [None]:
probes_df  # inspect the probe table after target selection

In [None]:
import itertools

# Short 4-nt tags, one per gene, drawn from all 4-mers that contain no run of
# three identical bases.
homopolymers = ["AAA", "TTT", "CCC", "GGG"]
all_combinations = list(map("".join, itertools.product("ATCG", repeat=4)))
filtered_combinations = [
    seq for seq in all_combinations if all(hp not in seq for hp in homopolymers)
]

# Choose random n sequences
# NOTE(review): random.sample is unseeded here, so the tags change between runs.
n = len(genes)  # Change this to however many you want
random_selection = random.sample(filtered_combinations, n)

extra_redundancy = pd.DataFrame({"genes": genes, "barcode": random_selection})
extra_redundancy = extra_redundancy.set_index("genes")

In [None]:
# Draft assembly of the per-gene padlock rows. Fixes relative to the first draft:
# call .copy(), look barcodes up in the "genes" column of gene_dict, use scalar
# lookups instead of Series, index the "target" column by name, and collect
# each gene's rows in new_probes.
new_probes = []
UMI = ["A", "T", "C", "G"]  # one base per padlock pair 1-4
for gene in genes:
    subset = probes_df[probes_df["gene"] == gene].copy()
    gene_barcodes = gene_dict.loc[gene_dict["genes"] == gene, "barcodes"]
    if gene_barcodes.empty:
        print(f"Warning: no barcode assigned to {gene}, skipping")
        continue
    subset["barcode"] = gene_barcodes.iloc[0]
    for pair in [1, 2, 3, 4]:
        padlock = f"{gene}_{pair}"
        target_seq = subset.loc[padlock, "target"]
        # Genes skipped during target selection still hold the 0 placeholder.
        if not isinstance(target_seq, str):
            continue

        subset.loc[padlock, "extra_redundancy"] = extra_redundancy.loc[gene, "barcode"]
        # NOTE(review): this reverse-complements the first 22 nt of the target,
        # whereas the big_df split takes the first 22 nt of the reverse
        # complement (= the last 22 nt of the target). Confirm which is intended.
        subset.loc[padlock, "target_seq"] = makesnails.revcomp(target_seq[0:22])
        subset.loc[padlock, "UMI"] = UMI[pair - 1]
        # TODO: "overlap_5_seq" was left unfinished in the draft
        # (`subset.loc[padlock,]` selects the whole row); fill it once the 5'
        # overlap definition is settled.
    new_probes.append(subset)

For each site, take the reverse complement and break it after position 22 into the padlock arm (first 22 nt), a 2-nt spacer and the primer arm (the remainder). Then calculate the Tm difference between the primer and padlock arms and add it as a column.

In [None]:
big_df["reverse_comp"] = big_df["target"].map(makesnails.revcomp)  # reverse complement of each target

In [None]:
# Split each reverse-complemented target: first 22 nt -> padlock arm, next 2 nt
# -> spacer, remainder -> primer arm. Vectorised .str slicing works on the
# values, not the index labels. The previous `big_df[col][i]` lookups were
# label-based and raised KeyError once rows had been filtered out of big_df.
rc = big_df["reverse_comp"].str
big_df["padlock"] = rc[0:22]
big_df["spacer"] = rc[22:24]
big_df["primer"] = rc[24:]

big_df.head()

In [None]:
# Melting temperature of each arm, via readblast.calc_tm_NN. Series.map works
# on the values, so gaps in big_df's index (left by earlier filtering) do not
# break the lookup the way label-based `big_df[col][i]` indexing did.
big_df["padlock_tm"] = big_df["padlock"].map(readblast.calc_tm_NN)
big_df["primer_tm"] = big_df["primer"].map(readblast.calc_tm_NN)

In [None]:
import random

import pandas as pd


def generate_random_dna_sequences(
    min_barcode,
    length,
    num_sequences,
    for_overlap,
    genes=None,
    *,
    tm_func=None,
    tm_range=(25, 35),
    rng=None,
    max_attempts=None,
):
    """
    Draw random DNA sequences that satisfy the overlap/barcode design rules.

    A random ``length``-nt sequence is kept when all of these hold:
      * if ``for_overlap``, bases 3:6 are not one of the blocked triplets;
      * it contains no "AAA", "TTT", "CCC" or "GGG" run;
      * its Tm (``tm_func``) lies within ``tm_range``, inclusive;
      * compared with every sequence already kept, its first ``min_barcode``
        bases differ at >= 2 positions and the whole sequence at >= 4 positions.

    Args:
        min_barcode (int): Prefix length used for the prefix-distance check.
        length (int): Length of each sequence.
        num_sequences (int): Number of sequences to return.
        for_overlap (bool): Apply the ligation-site triplet filter.
        genes (iterable, optional): Row labels. They are truncated, or padded
            with "no_gene", to ``num_sequences``.
        tm_func (callable, optional): Tm calculator. Defaults to
            ``readblast.calc_tm_NN``.
        tm_range (tuple): Inclusive (min, max) Tm bounds.
        rng (random.Random, optional): Source of randomness. Defaults to the
            module-level ``random`` functions.
        max_attempts (int, optional): Give up after this many random draws.
            ``None`` (the default) keeps drawing until enough sequences are found.

    Returns:
        pd.DataFrame: Columns "Sequence", "Reverse Complement", "Tms",
        "Overlap Primer Sequence" (labels) and "ID" (0..n-1).

    Raises:
        RuntimeError: If ``max_attempts`` draws did not yield enough sequences.
    """
    if tm_func is None:
        tm_func = readblast.calc_tm_NN
    chooser = random if rng is None else rng
    tm_min, tm_max = tm_range

    bases = ["A", "T", "C", "G"]
    complement = {"A": "T", "T": "A", "C": "G", "G": "C"}
    # ligation wont happen when reverse complement is TXA In 6:9 position for overlap
    blocked = {"TAA", "TCA", "TGA", "TTA", "ATT", "AGT", "ACT", "AAT"}
    homopolymers = ("AAA", "TTT", "CCC", "GGG")

    sequences = []
    reverse_complements = []
    Tms = []
    attempts = 0

    while len(sequences) < num_sequences:
        if max_attempts is not None and attempts >= max_attempts:
            raise RuntimeError(
                f"Only {len(sequences)} of {num_sequences} sequences found "
                f"after {max_attempts} attempts"
            )
        attempts += 1
        sequence = "".join(chooser.choice(bases) for _ in range(length))

        if for_overlap and sequence[3:6] in blocked:
            continue
        if any(hp in sequence for hp in homopolymers):
            continue

        Tm = tm_func(sequence)
        if Tm < tm_min or Tm > tm_max:
            continue

        if all(
            count_diff(seq[:min_barcode], sequence[:min_barcode]) >= 2
            for seq in sequences
        ) and all(count_diff(seq, sequence) >= 4 for seq in sequences):
            sequences.append(sequence)
            reverse_complements.append(
                "".join(complement[base] for base in reversed(sequence))
            )
            Tms.append(Tm)

    # Sequences are pairwise distinct by construction, so no de-duplication is needed.
    df = pd.DataFrame(
        {"Sequence": sequences, "Reverse Complement": reverse_complements, "Tms": Tms}
    )

    # Any iterable of labels is accepted; it is padded or truncated to the row count.
    n_rows = len(df)
    if genes is None:
        labels = ["no_gene"] * n_rows
    else:
        labels = list(genes)[:n_rows]
        labels += ["no_gene"] * (n_rows - len(labels))

    df["Overlap Primer Sequence"] = labels
    df["ID"] = range(n_rows)  # unique IDs

    return df


def count_diff(seq1, seq2):
    """Number of mismatched positions between seq1 and seq2 (zip stops at the shorter)."""
    return sum(base1 != base2 for base1, base2 in zip(seq1, seq2))


# Example usage: 14 overlaps labelled with gene acronyms. The labels must be a
# sequence of names; passing the DataFrame itself failed when the padded or
# truncated labels were assigned as a column.
example_genes = big_df["acronym"].unique().tolist()

overlap_prims = generate_random_dna_sequences(
    12, 12, 14, for_overlap=True, genes=example_genes
)
print(overlap_prims)

In [None]:
# Here `genes` holds big_df's row index labels, not gene names: one overlap per row.
genes = list(big_df.index)  # make a unique overlap for all sequences
overlap_regions = makesnails.generate_random_dna_sequences(
    12, 12, len(genes), True, genes=genes
)

## Barcodes


Check the Hamming distance matrix between all the candidates I have, see how many fit in the 7-bp barcode space, and decide how much room I want to leave for a larger panel.

In [None]:
# Previous 7-nt barcode set. The CSV has no header row (names= is supplied);
# its columns are ID and barcode.
barcodes = pd.read_csv(
    "/nemo/lab/znamenskiyp/home/users/colasa/code/multi_padlock_design/old_7bp_barcodes.csv",
    names=["ID", "barcode"],
)
barcodes.head()

In [None]:
barcode_path = "/nemo/lab/znamenskiyp/home/users/colasa/code/multi_padlock_design/nested_barcodes_parallel/2097_d4_10bp_barcodes_seed105.txt"
# Each non-blank line is "ID,barcode".
with open(barcode_path) as f:
    lines = [line.strip().split(",") for line in f if line.strip()]

ids = [fields[0] for fields in lines]
all_code = [fields[1] for fields in lines]
barcodes = pd.DataFrame({"ID": ids, "barcode": all_code})

# Longest and shortest barcode length (equal for a fixed-length set).
code_lengths = [len(code) for code in barcodes["barcode"]]
print(max(code_lengths))
print(min(code_lengths))

In [None]:
# Convert barcodes to a 2D numpy array of characters
barcodes_array = np.array([list(seq) for seq in barcodes["barcode"]])

# Broadcast and compare all pairs at once
# shape: (n, n, L) where n = number of barcodes, L = length of each
diffs = barcodes_array[:, None, :] != barcodes_array[None, :, :]

# Sum mismatches along sequence length (axis=2)
dist_matrix = diffs.sum(axis=2)

# Put result in a dataframe
dist_df = pd.DataFrame(dist_matrix, index=barcodes["ID"], columns=barcodes["ID"])
print(dist_df)

In [None]:
# Heatmap of the pairwise barcode distance matrix.
plt.imshow(dist_df.values, cmap="viridis", interpolation="nearest")

In [None]:
# Histogram of pairwise Hamming distances, one bin per integer value. The edges
# run to max + 1 so the largest distance gets its own bin. Matplotlib closes only
# the final bin, so the old fixed range(0, 9) merged distances 7 and 8 into one
# bar and dropped anything above 8, which is possible for 10-nt barcodes.
max_dist = int(dist_df.values.max())
plt.hist(dist_df.values.flatten(), bins=range(0, max_dist + 2), align="left")

In [None]:
np.power(4, 7)  # number of distinct 7-nt sequences

In [None]:
np.power(4, 3)  # number of distinct 3-nt sequences

In [None]:
alphabet = "ACGT"
n = 7

# Every length-n word over the alphabet, in lexicographic order.
all_barcodes = list(map("".join, product(alphabet, repeat=n)))
print(len(all_barcodes))  # 16384
print(all_barcodes[:5])  # first few

In [None]:
def greedy_barcode_set(candidates, d):
    """
    Greedily keep candidates whose Hamming distance to every kept one is >= d.

    Candidates are visited in the given order, and the first is always kept.

    Args:
        candidates: Sequence of equal-length strings.
        d (int): Minimum pairwise Hamming distance.

    Returns:
        list: The kept candidates, in input order.
    """
    chars = np.array([list(s) for s in candidates])  # (K, n) array of characters
    kept_idx = []
    for idx, row in enumerate(chars):
        if kept_idx:
            mismatches = (chars[kept_idx] != row).sum(axis=1)
            if not np.all(mismatches >= d):
                continue
        kept_idx.append(idx)
    return [candidates[i] for i in kept_idx]


candidates = greedy_barcode_set(all_barcodes, 3)  # greedy code over all 7-mers, min distance 3

In [None]:
def hamming_distance(a, b):
    """Number of positions where a and b differ (zip stops at the shorter string)."""
    mismatches = 0
    for x, y in zip(a, b):
        if x != y:
            mismatches += 1
    return mismatches


def greedy_once(candidates, d):
    """Keep, in order, every candidate that is at least d away from all those already kept."""
    chosen = []
    for cand in candidates:
        too_close = any(hamming_distance(cand, c) < d for c in chosen)
        if not too_close:
            chosen.append(cand)
    return chosen


def multi_restart_greedy(n=7, alphabet="ACGT", d=3, restarts=20, seed=None):
    """
    Run greedy_once on successive shuffles of all length-n words; return the largest code.

    The same candidate list is reshuffled in place at each restart, using one
    random.Random(seed) generator.
    """
    rng = random.Random(seed)
    candidates = ["".join(p) for p in product(alphabet, repeat=n)]
    best = []
    for _ in tqdm(range(restarts)):
        rng.shuffle(candidates)
        chosen = greedy_once(candidates, d)
        best = chosen if len(chosen) > len(best) else best
    return best


# Build a code of 10-nt words with minimum pairwise Hamming distance 4
# (50 greedy restarts, seed 55).
barcodes = multi_restart_greedy(n=10, d=4, restarts=50, seed=55)
print("Size:", len(barcodes))
print("First few:", barcodes[:10])

In [None]:
# NOTE(review): if the previous cell ran, `barcodes` is the n=10, d=4 code,
# which the "_len_3" suffix does not describe; confirm the intended name.
barcodes_len_3 = barcodes

In [None]:
def hamming_distance(a, b):
    """Number of differing positions between a and b (zip stops at the shorter)."""
    return sum(1 for x, y in zip(a, b) if x != y)


def local_improvement(chosen, candidates, d, max_passes=3):
    """
    Try to enlarge a code with minimum Hamming distance d using local moves.

    Each pass walks `candidates` in order and skips words already in the code.
    A word with no conflicts (distance < d) is added. A word that conflicts with
    exactly one codeword replaces it if it clears every other codeword;
    otherwise that codeword is put back, at the end of the list. Passes stop
    early once a pass makes no change. A shallow copy of `chosen` is edited and
    returned.
    """
    code = chosen[:]
    for _ in range(max_passes):
        changed = False
        for x in candidates:
            if x in code:
                continue
            clashes = [c for c in code if hamming_distance(x, c) < d]
            if len(clashes) == 0:
                code.append(x)
                changed = True
            elif len(clashes) == 1:
                blocker = clashes[0]
                code.remove(blocker)
                if all(hamming_distance(x, other) >= d for other in code):
                    code.append(x)
                    changed = True
                else:
                    code.append(blocker)
        if not changed:
            break
    return code

In [None]:
# Every 7-nt word over ACGT is a candidate for insertion.
candidates = ["".join(word) for word in product("ACGT", repeat=7)]
barcodes = pd.read_csv(
    "/nemo/lab/znamenskiyp/home/users/colasa/code/multi_padlock_design/old_7bp_barcodes.csv",
    names=["ID", "barcode"],
)
# Start from the previous 7-nt barcode set.
greedy_code = list(barcodes["barcode"].astype(str))
old_size = len(greedy_code)
print("Old size:", old_size)

improved_code = local_improvement(greedy_code, candidates, d=3, max_passes=5)
print("Improved size:", len(improved_code))

In [None]:
# One "index,barcode" line per codeword.
# NOTE(review): the file name hard-codes a size of 369; it is not derived from
# len(improved_code).
with open("369_d3_7bp_barcodes.txt", "w") as f:
    f.writelines(f"{i},{code}\n" for i, code in enumerate(improved_code))

In [None]:
base2int = {"A": 0, "C": 1, "G": 2, "T": 3}


def encode_barcode(bc):
    """Encode a DNA string as a uint8 array (A=0, C=1, G=2, T=3)."""
    return np.fromiter((base2int[b] for b in bc), dtype=np.uint8)


def encode_list(barcodes):
    """Stack encoded barcodes into a (len(barcodes), L) array; all must share length L."""
    return np.vstack([encode_barcode(bc) for bc in barcodes])


def hamming_vectorized(X, y):
    """Hamming distance from each row of X to the encoded barcode y."""
    return (X != y).sum(axis=1)


def local_improvement_fast(chosen, candidates, d, max_passes=3):
    """
    Vectorised local search that tries to enlarge a minimum-distance-d code.

    Each pass walks `candidates` in order and skips words already in the code.
    A word with no conflicts (distance < d) is appended. A word that conflicts
    with exactly one codeword replaces it when it clears all the others. The
    search stops after `max_passes` passes, or earlier once a pass makes no
    change.

    Args:
        chosen (list[str]): Starting code. It is not modified.
        candidates (Sequence[str]): Words to try, all the same length as the code.
        d (int): Minimum pairwise Hamming distance.
        max_passes (int): Maximum number of passes over `candidates`.

    Returns:
        list[str]: The improved code (a new list).
    """
    # Work on a copy: the caller's list used to be mutated in place.
    chosen = list(chosen)
    if len(candidates) == 0:
        return chosen
    cand_enc = encode_list(candidates)
    if chosen:
        chosen_enc = encode_list(chosen)
    else:
        # np.vstack([]) raises, so start from an empty (0, L) array instead.
        chosen_enc = np.empty((0, cand_enc.shape[1]), dtype=np.uint8)
    chosen_set = set(chosen)  # quick membership check

    for _ in range(max_passes):
        improved = False
        for x_str, x_enc in zip(candidates, cand_enc):
            if x_str in chosen_set:
                continue
            dists = hamming_vectorized(chosen_enc, x_enc)
            conflicts = np.flatnonzero(dists < d)

            if len(conflicts) == 0:
                chosen.append(x_str)
                chosen_enc = np.vstack([chosen_enc, x_enc])
                chosen_set.add(x_str)
                improved = True
            elif len(conflicts) == 1:
                # Swap out the single blocking codeword if x then fits.
                idx = conflicts[0]
                remaining = np.delete(chosen_enc, idx, axis=0)
                if (hamming_vectorized(remaining, x_enc) >= d).all():
                    removed = chosen.pop(idx)
                    chosen.append(x_str)
                    chosen_enc = np.vstack([remaining, x_enc])
                    chosen_set.discard(removed)
                    chosen_set.add(x_str)
                    improved = True
        if not improved:
            break
    return chosen


# generate all 10-mers (4**10 candidates)
candidates = ["".join(p) for p in product("ACGT", repeat=10)]

# your initial set
# NOTE(review): greedy_code must already hold 10-nt barcodes here. The 7-nt set
# loaded in an earlier cell would make the encoded distance comparison fail on
# mismatched lengths. The self-assignment below is a no-op.
greedy_code = greedy_code  # from previous greedy step

print("Old size:", len(greedy_code))
improved_code = local_improvement_fast(greedy_code, candidates, d=4, max_passes=35)
print("Improved size:", len(improved_code))

In [None]:
# One barcode per non-blank line (no ID column).
with open(
    "/nemo/lab/znamenskiyp/home/users/colasa/code/multi_padlock_design/barcodes_n10_d4.txt"
) as f:
    greedy_code = [line.strip() for line in f if line.strip()]

In [None]:
# "ID,barcode" per non-blank line; keep the barcode field.
with open(
    "/nemo/lab/znamenskiyp/home/users/colasa/code/multi_padlock_design/2097_d4_10bp_barcodes.txt"
) as f:
    all_code = [line.strip().split(",")[1] for line in f if line.strip()]

In [None]:
# NOTE(review): `os` has no attribute `mk`, so this cell raises AttributeError.
# It looks like an unfinished call (perhaps os.makedirs); complete or remove it.
os.mk

In [None]:
# Filtering: drop 10-mers that contain a run of three identical bases
# (AAA/TTT/CCC/GGG) and keep only those that use at least three distinct
# bases.

# look at all 10-mers (4**10 = 1,048,576 candidates)
candidates = np.array(["".join(p) for p in product("ACGT", repeat=10)])
print(f"Length before constraints {len(candidates)}")


# homopolymer constraint
def noHomo(barcode):
    """
    Return True if *barcode* has no run of three identical bases.

    Args:
        barcode (str): DNA sequence to check.

    Returns:
        bool: False if it contains "AAA", "TTT", "CCC" or "GGG", else True.
    """
    # any() short-circuits, unlike the previous `|` chain, and no local
    # variable shadows the function's own name.
    return not any(run in barcode for run in ("AAA", "TTT", "CCC", "GGG"))


# Drop candidates that contain a three-base homopolymer run.
noHomo_mask = list(map(noHomo, candidates))
candidates = candidates[noHomo_mask]
print(f"length after homopolymers: {len(candidates)}")

# Keep candidates that use at least three distinct bases.
diversity_mask = [len(set(barcode)) >= 3 for barcode in candidates]
candidates = candidates[diversity_mask]
print(f"length after diversity: {len(candidates)}")

In [None]:
# Scratch check of `|` on boolean membership tests. Both operands test the same
# substring, so this is equivalent to `"AAA" in a`.
a = "AAATCT"
if ("AAA" in a) | ("AAA" in a):
    print("bananas")

## Debug finding regions

In [None]:
# Example Ensembl cDNA FASTA header (Cidea) used for debugging region finding.
headers = [
    ">ENSMUST00000025404.10 cdna chromosome:GRCm39:18:67476674:67500855:1 gene:ENSMUSG00000024526.10 gene_biotype:protein_coding transcript_biotype:protein_coding gene_symbol:Cidea description:cell death-inducing DNA fragmentation factor, alpha subunit-like effector A [Source:MGI Symbol;Acc:MGI:1270845]"
]

In [None]:
# For each header, derive the transcript ID and look up its start/end sites.
# Only the result for the last header is kept in `df`.
for header in headers:
    first = header.split()[0]  # take the first token
    print(first)
    header = first.lstrip(">")  # only strip '>' if it exists
    print(header)
    transcript = header
    print(transcript)
    cds_file = config.cds_file
    cdna_file = config.cdna_file
    df = finalizeprobes.find_start_end_sites(cds_file, cdna_file, transcript)

In [None]:
# Example header for Blnk (ENSMUST00000117695.8); note it has no leading ">".
headers = [
    "ENSMUST00000117695.8 cdna chromosome:GRCm39:19:40917371:40982591:-1 gene:ENSMUSG00000061132.14 gene_biotype:protein_coding transcript_biotype:protein_coding gene_symbol:Blnk description:B cell linker [Source:MGI Symbol;Acc:MGI:96878]"
]

In [None]:
# Load the "5.ProbesDBMappable_*" tables of another panel and keep the Cidea
# targets. This rebinds probes_df and probes, which were used for the SNAIL
# assembly above.
outdir = Path("/nemo/lab/znamenskiyp/scratch")
panel_name = "spapros_10K_taylored_2021_16_ensembl"
probes_df = makesnails.read_candidates(
    outdir, panel_name, probefile="5.ProbesDBMappable_*.csv"
)
probes_df = probes_df[probes_df["acronym"] == "Cidea"]
probes = list(probes_df["target"])
probes

In [None]:
# Hard-coded design temp folder for Cidea, searched for variant FASTA files by
# the debugging loop in this cell.
tempdir = "/nemo/lab/znamenskiyp/scratch/spapros_10K_taylored_2021_16_ensembl/Cidea/TempFolder20251027141110"
import glob
from difflib import SequenceMatcher
from pathlib import Path

import config
from lib import finalizeprobes, makesnails

# Header(s) to process in the loop below.
headers = [
    ">ENSMUST00000025404.10 cdna chromosome:GRCm39:18:67476674:67500855:1 gene:ENSMUSG00000024526.10 gene_biotype:protein_coding transcript_biotype:protein_coding gene_symbol:Cidea description:cell death-inducing DNA fragmentation factor, alpha subunit-like effector A [Source:MGI Symbol;Acc:MGI:1270845]"
]
import os

regions = []  # not filled in the code visible here


def align_to_reference_variant(target_sequence, reference_variant_seq):
    """
    Locate the longest exact common block between a target and a reference cDNA.

    Args:
        target_sequence (str): Short probe/target sequence.
        reference_variant_seq (str): Longer reference (cDNA) sequence.

    Returns:
        tuple: (start index in ``reference_variant_seq``, match length, matched
        substring). A length of 0 means there is no common block.
    """
    # autojunk=False: once the second sequence has 200+ characters, the default
    # heuristic treats frequently occurring elements as junk. For a 4-letter DNA
    # alphabet that usually means every base, so no match would be found.
    matcher = SequenceMatcher(
        None, reference_variant_seq, target_sequence, autojunk=False
    )
    match = matcher.find_longest_match(
        0, len(reference_variant_seq), 0, len(target_sequence)
    )
    return match.a, match.size, reference_variant_seq[match.a : match.a + match.size]


def get_gene_symbol(header):
    """
    Return the value of the first ``gene:`` token in a FASTA header, or None.

    Despite the name, this reads the ``gene:`` field (e.g. ``gene:ENSMUSG...``
    in Ensembl cDNA headers), not the ``gene_symbol:`` field.
    """
    for token in header.split():
        if token.startswith("gene:"):
            return token.split(":", 1)[1]
    return None


def find_latest_variants_fasta(tempdir):
    """
    Return the transcript IDs listed in the newest variants FASTA in ``tempdir``.

    The function prefers ``*variants_round<N>.fasta`` with the highest N; files
    whose name does not end in exactly that form rank lowest. Otherwise it falls
    back to the first ``*variants.fasta`` found. IDs are the first token of each
    ">" header line. Returns None when no matching file exists.
    """

    def read_ids(path):
        with open(path, "r") as handle:
            header_lines = [ln.strip() for ln in handle.readlines() if ln.startswith(">")]
        print(header_lines)
        return [h.lstrip(">").split()[0] for h in header_lines]

    def round_num(fpath):
        m = re.search(r"variants_round(\d+)\.fasta$", os.path.basename(fpath))
        return int(m.group(1)) if m else -1

    round_files = glob.glob(os.path.join(tempdir, "*variants_round*.fasta"))
    if round_files:
        print("found some variant files")
        latest = max(round_files, key=round_num)
        print(f"picking latest: {latest}")
        return read_ids(latest)

    # Fallback: look for variants.fasta
    plain_files = glob.glob(os.path.join(tempdir, "*variants.fasta"))
    if not plain_files:
        print("No variant files found")
        return None
    plain_file = plain_files[0]
    if os.path.exists(plain_file):
        return read_ids(plain_file)
    return None


def find_cds_entries(cds_headers_file, gene_symbol):
    """
    Return the header lines whose ``gene:`` token refers to ``gene_symbol``.

    The file holds one header per line; surrounding whitespace is stripped.
    A whitespace-separated token matches when it equals ``gene:<gene_symbol>``
    or is that value followed by a ``.<version>`` suffix. A plain substring test
    also matched longer versions, e.g. ``gene:G1.1`` inside ``gene:G1.10``.

    Args:
        cds_headers_file (str | Path): Text file of CDS FASTA headers.
        gene_symbol (str | None): Gene ID, with or without version.

    Returns:
        list[str]: Matching header lines, in file order ([] when gene_symbol is None).
    """
    if gene_symbol is None:
        return []
    wanted = f"gene:{gene_symbol}"
    versioned_prefix = wanted + "."
    with open(cds_headers_file, "r") as f:
        cds_headers = [line.strip() for line in f]
    return [
        header
        for header in cds_headers
        if any(tok == wanted or tok.startswith(versioned_prefix) for tok in header.split())
    ]


def get_cdna(transcript_id):
    """
    Return the cDNA sequence string for ``transcript_id`` from ``config.cdna_file``.

    Scans the multi-record FASTA and returns the first record whose ID matches.

    Raises:
        ValueError: If no record has that ID.
    """
    records = SeqIO.parse(config.cdna_file, "fasta")
    match = next((rec for rec in records if rec.id == transcript_id), None)
    if match is None:
        raise ValueError(f"Transcript ID {transcript_id} not found in {config.cdna_file}")
    return str(match.seq)


for header in headers:
    print(header)
    # Gene ID taken from the header's "gene:" token.
    gene_symbol = get_gene_symbol(header)
    # Find this gene's CDS-annotated transcripts in the CDS header list.
    cds_headers = str(config.cds_file) + "_headers.txt"
    print(cds_headers)
    cds_entries = find_cds_entries(cds_headers, gene_symbol)
    print(cds_entries)
    # TODO: obtain tempdir from outpars instead of the hard-coded value.
    variants = find_latest_variants_fasta(tempdir)
    print(variants)
    if not variants:
        print(f"No variant list found in {tempdir}; skipping {gene_symbol}")
        continue
    # Reference = first CDS-annotated transcript that is also a retained variant.
    # Reset for every header so a match from an earlier header is never reused.
    reference_variant = None
    for cds_variants in cds_entries:
        variant_id = cds_variants.split()[0].lstrip(">")
        if variant_id in variants:
            print(f"Found valid (some annotated cds) variant: {variant_id}")
            reference_variant = variant_id
            break
        print(f"No valid (some annotated cds) variant for: {variant_id}")
    if reference_variant is None:
        print(f"No CDS-annotated variant among the retained variants for {gene_symbol}; skipping")
        continue
    # map the target to the reference variant sequence
    reference_variant_sequence = get_cdna(reference_variant)

In [None]:
makesnails.revcomp("ATGGACAAGCTGAATAAGATAACTGTCCCTGCCAG")  # quick manual reverse-complement check

Trying target `GAGCCTTACCATGCTGCTCCCCAGGAAGTCCAGGAGCTGCTGAC`. It is not in the CDS; it is actually at the very beginning of the transcript, at position 10 of `ENSMUST00000117695.8`.

In [None]:
# Length of the pasted transcript sequence. All whitespace is removed, not just
# newlines: the continuation lines of the triple-quoted string are indented, and
# with .replace("\n", "") those indentation spaces were counted as bases.
cdna_fragment = "".join(
    """CTTGAGGCCAGAGCCTTACCATGCTGCTCCCCAGGAAGTCCAGGAGCTGCTGACACCCCC
    CTGGACAGCGACACATCCTCTCTCATGGACAAGCTGAATAAGATAACTGTAAGAAACCCT
    GCCAGCCAGAAGCTGAGACAGCTTCAAAAGATGGTCCATGATATTAAGAACAATGAAGGT
    GGAATAATGGACAAGATAAAAAAGCTAAAAGTCAAAGGCCCTCCAAGTGTTCCTCGAAGG
    GACTACGCATTAGACAGCCCTGCAGATGAAGAGGAGCAGTGGTCAGATGACTTCGACAGT
    GACTATGAAAATCCAGATGAACATTCGGACTCCGAGATGTATGTGATGCCTGCCGAGGAG
    ACGGGCGACGATTCCTATGAACCGCCTCCCGCTGAGCAGCAGACACGGGTGGTCCATCCA
    GCCCTGCCCTTCACGAGGGGCGAGTATGTAGATAATCGATCCAGCCAGCGGCACTCTCCG
    CCCTTCAGCAAGACACTTCCCAGTAAGCCCAGCTGGCCTTCAGCGAAAGCGAGGCTGGCC
    TCCACTCTGCCAGCCCCCAACTCTCTACAGAAGCCTCAAGTCCCCCCCAAGCCCAAAGAC
    CTCCTTGAGGATGAGGCTGATTATGTGGTCCCTGTGGAAGATAACGATGAAAACTATATC
    CATCCCAGAGAAAGCAGCCCGCCGCCTGCTGAGAAGGCTCCCATGGTGAATAGATCAACC
    AAGCCAAACAGTTCCTCAAAGCACATGTCGCCTCCAGGGACTGTCGCAGGTCGAAACAGT
    GGGGTCTGGGACTCCAAGTCATCTTTGCCTGCCGCACCATCCCCACTACCACGGGCTGGG
    AAGAAGCCAGCTACACCACTTAAGACTACTCCCGTTCCTCCCCTACCGAATGCATCAAAT
    GTTTGTGAAGAAAAGCCTGTTCCTGCTGAGCGCCACCGAGGGTCTAGTCACAGACAAGAC
    ACTGTACAGTCACCAGTGTTTCCTCCCACCCAGAAACCTGTCCATCAAAAGCCTGTACCC
    TTGCCAAAAGCGGGGAGCCCAGCTGCAGATGGACCGTTCCACAGCTTCCCATTTAATTCG
    ACGTTTGCAGACCAGGAGGCTGAACTGCTCGGTAAGCCCTGGTATGCTGGCGCCTGTGAC
    CGCAAGTCTGCTGAAGAGGCCTTGCACAGATCCAACAAGGATGGATCGTTTCTTATTCGG
    AAGAGCTCTGGCCATGATTCCAAGCAGCCGTACACCCTAGTTGCGTTCTTTAACAAGCGA
    GTGTATAATATTCCTGTACGGTTTATTGAAGCAACCAAACAGTATGCTTTGGGAAAGAAG
    AAAAATGGTGAAGAGTACTTCGGAAGTGTTGTGGAAATCGTCAACAGTCACCAGCACAAC
    CCCCTGGTTCTTATTGACAGTCAGAATAACACGAAAGATTCCACGAGACTGAAATATGCT
    GTGAAGGTTTCATAACGATACCACGGTTCCAGACATGTCCTCTGTTTCTTCTTTTGAGAA
    AACATCATATTCTGGCTATGACTCCTCAGCAGTAAGAGAGAAAAGATGAATGAAGCCACT
    GAGGCTTCGTGAATGAATGAATCTACTCCTTCCTAGGGCGTTCACACGAGCTTTTCTATC
    ACCTGACCTGACGAAGTCATAGCTGGGGAGGTTCGGTTACTATGATACTAATATTGTCCA
    AATAAATGTATTTAAAAGCAA""".split()
)
len(cdna_fragment)

## Filtering variants by `GENCODE Basic` subset

In [None]:
def extract_features(line, feature):
    """Return the quoted value of attribute *feature* (``feature "value"``) in *line*, or None."""
    found = re.search(rf'{feature} "([^"]+)"', line)
    return found.group(1) if found else None


gene_ids = []
transcript_ids = []
transcript_support_levels = []
feature_type = []  # not filled below

# Collect one row per "transcript" feature. The file looks GTF-style:
# tab-separated columns, feature type in column 3, `key "value"` attributes.
with open(config.gencode_file, "r") as f:
    for line in f:
        if line.startswith("#"):
            continue  # skip header lines
        fields = line.split("\t")
        # Blank or short lines used to raise IndexError on fields[2].
        if len(fields) < 3 or fields[2] != "transcript":
            continue  # skip gene/exon/... lines
        line = line.strip()
        gene_id = extract_features(line, "gene_id")
        transcript_id = extract_features(line, "transcript_id")
        transcript_support_level = extract_features(line, "transcript_support_level")
        if gene_id or transcript_id:
            gene_ids.append(gene_id)
            transcript_ids.append(transcript_id)
            transcript_support_levels.append(transcript_support_level)

gencode = pd.DataFrame(
    {
        "gene_id": gene_ids,
        "transcript_id": transcript_ids,
        "transcript_support_level": transcript_support_levels,
    }
)
print(gencode.head())

In [None]:
# Count unique transcripts per gene
counts = (
    gencode.groupby("gene_id")["transcript_id"]
    .nunique()
    .reset_index(name="n_transcripts")
)

# Plot histogram
plt.figure(figsize=(8, 5))
plt.hist(counts["n_transcripts"], bins=50, log=True, color="skyblue", edgecolor="black")
plt.xlabel("Number of unique transcript IDs per gene")
plt.ylabel("Number of genes")
plt.title("Distribution of GENCODE Basic transcript counts per gene")
plt.show()

In [None]:
# Safely convert to numeric; non-numeric become NaN
tsl = (
    pd.to_numeric(gencode["transcript_support_level"], errors="coerce")
    .fillna(0)
    .astype(int)
)

plt.figure(figsize=(7, 5))
plt.hist(tsl, bins=range(0, 6), align="left", color="lightgreen", edgecolor="black")
plt.xticks(range(0, 6))
plt.xlabel("Transcript Support Level")
plt.ylabel("Number of Transcripts")
plt.title("Distribution of Transcript Support Levels (NA=0)")
plt.show()

`transcript_support_level`: one of {1, 2, 3, 4, 5, NA}, with at most one value per transcript.  
Transcripts are scored according to how well mRNA and EST alignments match over their full length:  
1 (all splice junctions of the transcript are supported by at least one non-suspect mRNA),  
2 (the best supporting mRNA is flagged as suspect or the support is from multiple ESTs),  
3 (the only support is from a single EST),   
4 (the best supporting EST is flagged as suspect),   
5 (no single transcript supports the model structure),   
NA (the transcript was not analyzed)

In [None]:
from gffpandas import gffpandas as gffpd

# Parse the annotation at config.gencode_file with gffpandas. The file is large
# (close to 1 GB per the original note), so this is slow: run the cell once.
# NOTE(review): read_gff3 expects GFF3 attribute syntax; confirm that
# config.gencode_file is a GFF3 rather than a GTF file.
gencode = gffpd.read_gff3(config.gencode_file)

In [None]:
# gffpandas: expand the GFF attributes field into separate columns, then preview
# the first rows (last expression of the cell is displayed by the notebook).
gencode_attributes = gencode.attributes_to_columns()
gencode_attributes.head()

In [None]:
# Look up the gene(s) of interest in the reference; querygenes returns a
# (hits, hitmask) pair.
genes, species = ["Adamts2"], config.species
hits, hitmask = retrieveseq.querygenes(genes, species)

In [None]:
# Gene list and species used by the following cells.
genes, species = ["Adamts2"], config.species

In [None]:
# Load the formatted reference-transcriptome files (acronym headers, selected
# headers, selected sequences) for `species`, building them first when the
# acronym-header file does not exist yet. On success this sets the module-level
# names Acronyms, Headers, Seq and Gencode; a missing file is reported with a
# message instead of an exception.
ref_kind = config.reference_transcriptome.lower()

try:
    fastadir = os.path.join(config.BASE_DIR, config.reference_transcriptome)
    acr_path = os.path.join(fastadir, species + ".acronymheaders.txt")

    if ref_kind == "refseq":
        # Expect config.fasta_pre_suffix_refseq = (prefix, suffix), config.fasta_filenum_refseq = int
        pre_suf = getattr(config, "fasta_pre_suffix_refseq", None)
        filenum = getattr(config, "fasta_filenum_refseq", None)
        if not os.path.isfile(acr_path):
            print("Processing fasta database files..")
            if not (
                isinstance(pre_suf, (list, tuple))
                and len(pre_suf) == 2
                and isinstance(filenum, int)
            ):
                raise ValueError(
                    "Invalid RefSeq config: expected fasta_pre_suffix_refseq=(prefix, suffix) and fasta_filenum_refseq=int"
                )
            # Call through the `formatrefseq` module imported at the top of the
            # notebook, so this cell does not rely on a bare `fastadb` import
            # (which only appears in a later cell).
            formatrefseq.fastadb(
                fastadir,
                (pre_suf[0], pre_suf[1], filenum),
                species,
                source="refseq",
                keep_nm_nr_only=True,
            )

    elif ref_kind == "ensembl":
        if not os.path.isfile(acr_path):
            print("Processing fasta database files..")
            # The cDNA FASTA plus any optional extra files.
            files = [config.cdna_file] + list(getattr(config, "extra_files", None) or [])
            formatrefseq.fastadb(
                fastadir,
                files,
                species,
                source="ensembl",
                keep_nm_nr_only=False,
            )

    # load database files (one entry per line, trailing newline removed)
    with open(acr_path, "r") as f:
        Acronyms = [line.rstrip("\n") for line in f]
    with open(os.path.join(fastadir, species + ".selectedheaders.txt"), "r") as f:
        Headers = [line.rstrip("\n") for line in f]
    with open(os.path.join(fastadir, species + ".selectedseqs.txt"), "r") as f:
        Seq = [line.rstrip("\n") for line in f]
    Gencode = retrieveseq.load_gencode()

except FileNotFoundError as err:
    # Report which file was missing so the cause is visible.
    print(f"Cannot load fasta database due to mismatch in species. ({err})")

In [None]:
# Index (into Headers/Seq) of the first hit for the first queried gene.
hits[0][0]

In [None]:
# Header line of the first hit for the first queried gene.
Headers[hits[0][0]]

In [None]:
def get_ensembl_id(header):
    """Return the Ensembl gene ID from a header's ``gene:<id>`` field, or None.

    The ID is the text following ``gene:`` up to (not including) the next space.
    """
    found = re.search(r"gene:([^ ]+)", header)
    return found.group(1) if found else None


# Ensembl gene ID parsed from the header of the first hit for the first queried gene.
get_ensembl_id(Headers[hits[0][0]])

In [None]:
# Query one gene by hand: find its variants in the reference, look up which
# transcripts GENCODE lists for its Ensembl gene ID, and build the
# (hits, hit_masked) pair that findseq expects for the Ensembl reference.
gene = genes[0]

# Variants of the gene: exact match on the acronym list, which is index-aligned
# with Headers. Matching `gene in header` on full headers is a substring test and
# also hits other genes (e.g. "Adamts2" is contained in "Adamts20").
hit = [c for c, acronym in enumerate(Acronyms) if acronym == gene]
if not hit:
    raise ValueError(f"{gene} not found in the reference acronym list")
ensembl_ID = get_ensembl_id(Headers[hit[0]])
print(f"Querying for {ensembl_ID}")

# filter by Gencode basic annotation
gene_rows = Gencode[Gencode["gene_id"] == ensembl_ID]
allowed_variants = gene_rows["transcript_id"].tolist()
supporting_evidence = gene_rows["transcript_support_level"].tolist()
print(
    f"allowed variants: {allowed_variants}, supporting evidence: {supporting_evidence} (1 is best)"
)
# Reference indices of all headers mentioning an allowed transcript ID.
allowed_hit = [
    i for i, line in enumerate(Headers) if any(tid in line for tid in allowed_variants)
]

# Mask over `hit`: True when the variant's ID (first header token without ">")
# is one of the allowed transcripts.
allowed_set = set(allowed_variants)
hit_masked = [Headers[h].split()[0][1:] in allowed_set for h in hit]

hits = ([hit], hit_masked)

In [None]:
# Reference indices of the variants that pass the GENCODE mask. hits[0] holds
# one index list per gene, so mask the first gene's list: masking the (1, n)
# array built from hits[0] raises IndexError whenever n > 1.
np.array(hits[0][0])[hit_masked]

In [None]:
# Print the ID (first header token without ">") of every variant of the first
# queried gene. hits[0] is the list of per-gene index lists, so iterate over
# hits[0][0]: iterating hits[0] yields lists, which cannot index Headers.
for idx in hits[0][0]:
    print(Headers[idx].split()[0][1:])

In [None]:
import os

import config
from lib import retrieveseq
from lib.formatrefseq import fastadb

# Gene list and species for the findseq run below.
genes, species = ["Adamts2"], config.species


def get_ensembl_id(header):
    """Extract the Ensembl gene ID from a header.

    Looks for ``gene:`` followed by one or more non-space characters and returns
    those characters; returns None when no such field is present.
    """
    match = re.search(r"gene:([^ ]+)", header)
    if match is None:
        return None
    return match.group(1)


# Directory where findseq writes this gene's FASTA / alignment files; hard-coded
# to the temp folder of one earlier run.
dirname = (
    "/nemo/lab/znamenskiyp/scratch/spapros_10K_taylored_2021_16_ensembl"
    "/Adamts2/TempFolder20251027140913"
)

# NOTE(review): for the "ensembl" reference, findseq unpacks this as a
# (hits, hitmask) pair.
hits = retrieveseq.querygenes(genes, species)

# Load the formatted reference files (acronym headers, selected headers, selected
# sequences) for `species`, running fastadb first when the acronym-header file is
# missing. Sets the module-level names Acronyms, Headers, Seq and Gencode; a
# FileNotFoundError is reported with a message rather than raised.
try:
    fastadir = os.path.join(config.BASE_DIR, config.reference_transcriptome)
    ref_kind = getattr(config, "reference_transcriptome", "").lower()

    if ref_kind == "refseq":
        # RefSeq input: a (prefix, suffix) pair plus an integer file count.
        pre_suf = getattr(config, "fasta_pre_suffix_refseq", None)
        filenum = getattr(config, "fasta_filenum_refseq", None)
        acr_path = os.path.join(fastadir, species + ".acronymheaders.txt")
        if not os.path.isfile(acr_path):
            print("Processing fasta database files..")
            config_ok = (
                isinstance(pre_suf, (list, tuple))
                and len(pre_suf) == 2
                and isinstance(filenum, int)
            )
            if not config_ok:
                raise ValueError(
                    "Invalid RefSeq config: expected fasta_pre_suffix_refseq=(prefix, suffix) and fasta_filenum_refseq=int"
                )
            prefix, suffix = pre_suf[0], pre_suf[1]
            fastadb(
                fastadir,
                (prefix, suffix, filenum),
                species,
                source="refseq",
                keep_nm_nr_only=True,
            )

    elif ref_kind == "ensembl":
        # Ensembl input: the cDNA FASTA plus any non-empty extra_files.
        files = [config.cdna_file]
        files += getattr(config, "extra_files", None) or []
        acr_path = os.path.join(fastadir, species + ".acronymheaders.txt")
        if not os.path.isfile(acr_path):
            print("Processing fasta database files..")
            fastadb(
                fastadir,
                files,
                species,
                source="ensembl",
                keep_nm_nr_only=False,
            )

    # One list entry per file line, trailing newline stripped.
    with open(os.path.join(fastadir, species + ".acronymheaders.txt"), "r") as f:
        Acronyms = [row.rstrip("\n") for row in f]
    with open(os.path.join(fastadir, species + ".selectedheaders.txt"), "r") as f:
        Headers = [row.rstrip("\n") for row in f]
    with open(os.path.join(fastadir, species + ".selectedseqs.txt"), "r") as f:
        Seq = [row.rstrip("\n") for row in f]
    Gencode = retrieveseq.load_gencode()

except FileNotFoundError:
    print("Cannot load fasta database due to mismatch in species.")


def _write_variants_fasta(path, indices):
    """Write the module-level Headers/Seq entries at `indices` to `path`.

    Each record is the header line, the sequence line and a blank separator line.
    """
    with open(path, "w") as f:
        for idx in indices:
            f.write("%s\n" % Headers[idx])
            f.write("%s\n\n" % Seq[idx])


def _msa_consensus(dirname, gene, tempheader, allowed_only):
    """Align one gene's variants, dropping one outgroup per round until a consensus is found.

    Args:
        dirname: directory holding the gene's FASTA / guide-tree files.
        gene: gene name used in the file names.
        tempheader: header lines of the variants being aligned.
        allowed_only: True when round 1 reads ``<gene>_allowed_variants.fasta``
            (GENCODE-filtered, Ensembl reference) instead of ``<gene>_variants.fasta``.

    Returns:
        ``(header, basepos, sequence)`` for the consensus reported by
        ``parmsa.continuemsa``, or None when no consensus was found and at most
        two variants remain.
    """

    def fasta_for_round(r):
        if r > 1:
            return f"{dirname}/{gene}_variants_round{r}.fasta"
        if allowed_only:
            return f"{dirname}/{gene}_allowed_variants.fasta"
        return f"{dirname}/{gene}_variants.fasta"

    def dnd_for_round(r):
        # NOTE(review): round 1 uses *_variants.dnd even when the allowed-variants
        # FASTA is the one being counted; confirm this matches the files that
        # parmsa.continuemsa writes.
        if r == 1:
            return f"{dirname}/{gene}_variants.dnd"
        return f"{dirname}/{gene}_variants_round{r}.dnd"

    def count_records(path):
        return sum(1 for _ in SeqIO.parse(path, "fasta"))

    round_no = 1  # round 1 uses the files written by findseq
    n_left = count_records(fasta_for_round(round_no))

    while True:
        # Only [gene] is passed, so each part of the (Names, BasePos, Seqs)
        # output has exactly one element.
        out = parmsa.continuemsa(
            dirname,
            [gene],
            round=None if round_no == 1 else round_no,
            reset=True,
        )
        name_c, basepos_c, seqs_c = out[0][0], out[1][0], out[2][0]

        if len(basepos_c):  # consensus found
            # Map the reported name back to a header by substring; fall back to
            # the first header if none matches.
            chosen = [h for h in tempheader if name_c in h]
            return (chosen[0] if chosen else tempheader[0]), basepos_c, seqs_c

        # no consensus: stop if ≤ 2 variants remain
        if n_left <= 2:
            return None

        # drop one outgroup and try again
        print(
            f"[{gene}] No consensus, removing outgroup (round {round_no} → {round_no+1})"
        )
        outgroup = parmsa.find_outgroup(dnd_for_round(round_no))
        in_fa = fasta_for_round(round_no)
        round_no += 1
        parmsa.remove_outgroup(in_fa, outgroup, round=round_no)
        n_left = count_records(fasta_for_round(round_no))


def findseq(genes, hits, dirname):
    """Find the target sequence of each gene; run an MSA when it has several variants.

    For every gene, all variants found in the reference are written to
    ``<dirname>/<gene>_variants.fasta``. A gene with exactly one variant uses that
    whole sequence. For the "ensembl" reference, multi-variant genes are first
    reduced to the variants passing the GENCODE mask (written to
    ``<gene>_allowed_variants.fasta``). Genes still having several variants are
    aligned with ``parmsa.continuemsa``. When no consensus is found, the guide-tree
    outgroup is dropped and the alignment retried, until at most two variants remain.

    Args:
        genes: gene names, in the same order as the per-gene entries of ``hits``.
        hits: one list of indices into the module-level ``Headers``/``Seq`` per
            gene. For ``config.reference_transcriptome == "ensembl"`` this is a
            ``(hits, hitmask)`` pair, with ``hitmask`` holding one boolean per variant.
        dirname: existing directory for the FASTA and alignment files.

    Returns:
        ``(headers, basepos, targets, msa, nocommon, variants)``:

        - ``headers``, ``basepos``, ``targets``: one entry per gene with a usable
          target. Each entry is the header line, the target positions, and the
          target sequence. Positions are ``[0, len - 1]`` when the whole sequence
          is used, otherwise as returned by ``parmsa.continuemsa``. Genes
          resolved without alignment come first, followed by genes resolved by MSA.
        - ``msa``: genes that were sent to the multiple sequence alignment.
        - ``nocommon``: genes for which no target could be derived. This covers
          genes with no variants to align and genes where no consensus was found.
        - ``variants``: per gene, in input order, the variant ID (single variant)
          or the list of variant IDs considered for it.
    """
    targets = []
    headers = []
    basepos = []
    msa = []
    nocommon = []
    variants = []
    headersMSA = []  # variant headers of each gene in `msa`, index-aligned with it

    ensembl = config.reference_transcriptome == "ensembl"
    if ensembl:
        hits, hitmask = hits[0], hits[1]

    def use_whole_sequence(idx):
        """Record the complete reference sequence at `idx` as a target."""
        targets.append(Seq[idx])
        headers.append(Headers[idx])
        basepos.append([0, len(Seq[idx]) - 1])

    for c, hit in enumerate(hits):
        gene = genes[c]
        # Always write the full set of variants: this makes it easier to map the
        # chosen target back onto all variants later.
        _write_variants_fasta(f"{dirname}/{gene}_variants.fasta", hit)

        if len(hit) == 1:  # only one variant: use it as-is
            use_whole_sequence(hit[0])
            variants.append(Headers[hit[0]][1:].split(".", 1)[0])
            continue

        if ensembl:
            # Only keep variants that pass the GENCODE mask.
            # NOTE(review): hitmask is indexed by the variant's position in this
            # gene's hit list, i.e. it is assumed to describe this gene only;
            # confirm for multi-gene queries.
            hit = [h for i, h in enumerate(hit) if hitmask[i]]
            allowed_file = f"{dirname}/{gene}_allowed_variants.fasta"
            print(f"Writing allowed variants to {allowed_file}")
            _write_variants_fasta(allowed_file, hit)
            for h in hit:
                print(Headers[h])

        tempheader = [Headers[h] for h in hit]
        # Variant ID: header without ">" and without anything from the first ".".
        variants.append([h[1:].split(".", 1)[0] for h in tempheader])

        if not hit:
            # Nothing to align (e.g. no variant passed the GENCODE mask).
            nocommon.append(gene)
        elif len(hit) == 1:
            # The mask left one variant: use it directly. Only this gene stays out
            # of the MSA queue; genes queued earlier are kept.
            use_whole_sequence(hit[0])
        else:
            msa.append(gene)
            headersMSA.append(tempheader)

    # run multiple sequence alignment for every gene that still has several variants
    if msa:
        for gene, tempheader in zip(msa, headersMSA):
            result = _msa_consensus(dirname, gene, tempheader, allowed_only=ensembl)
            if result is None:
                nocommon.append(gene)
            else:
                header, pos, seq = result
                headers.append(header)
                basepos.append(pos)
                targets.append(seq)

        print("MSA finished across genes.")

    return headers, basepos, targets, msa, nocommon, variants


# Derive target sequences (running the MSA where needed) for the queried gene(s).
headers, basepos, targets, msa, nocommon, variants = findseq(
    genes=genes, hits=hits, dirname=dirname
)