In [None]:
import requests, time, subprocess
import numpy as np
import xml.etree.ElementTree as ET
from Bio import AlignIO

# Wild-type protein sequence (one-letter amino-acid codes). It is submitted as
# the blastp query and is always written as the first record of the MSA input.
WT_SEQUENCE = "MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETEIFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNYPEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQLSLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEG"
# NCBI BLAST endpoint; the same URL is used to submit the search ("Put"),
# poll its status, and download the results ("Get").
BLAST_URL = "https://blast.ncbi.nlm.nih.gov/Blast.cgi"

# ======= Step 1: BLAST API for Homologs =======
def run_blast_search(
    wt_sequence,
    identity_threshold=90.0,
    max_retries=30,
    sleep_time=10,
    min_length=100,
    hitlist_size=300,
    request_timeout=60,
):
    """Run a remote blastp search and write the query plus its homologs to FASTA.

    The query is submitted to the NCBI BLAST URL API against the ``nr``
    database, the job is polled until it is ready, and the aligned subject
    segment (``Hsp_hseq``) of every HSP that passes the filters is collected.

    Args:
        wt_sequence: Query protein sequence; always written as record ``seq0``.
        identity_threshold: Keep only HSPs whose percent identity
            (identities / alignment length * 100) is strictly below this value.
        max_retries: Maximum number of status polls before giving up.
        sleep_time: Seconds to wait between status polls.
        min_length: Keep only subject segments with strictly more residues
            than this (gap characters are not counted).
        hitlist_size: Value sent as BLAST ``HITLIST_SIZE``.
        request_timeout: Timeout in seconds applied to every HTTP request, so
            a stalled connection cannot hang the workflow indefinitely.

    Returns:
        The path of the FASTA file written (``"msa_input.fasta"``).

    Raises:
        requests.HTTPError: If submitting the job or downloading results fails.
        Exception: If no RID is returned, the search fails or expires on the
            server, or it is not ready after ``max_retries`` polls.
    """
    params = {
        "CMD": "Put",
        "PROGRAM": "blastp",
        "DATABASE": "nr",
        "QUERY": wt_sequence,
        "FORMAT_TYPE": "XML",
        "EXPECT": "1e-2",
        "HITLIST_SIZE": str(hitlist_size)
    }
    response = requests.post(BLAST_URL, data=params, timeout=request_timeout)
    response.raise_for_status()
    response_text = response.text
    if "RID = " not in response_text:
        raise Exception("No RID found in BLAST response.")
    rid = response_text.split("RID = ")[-1].split("\n")[0].strip()
    print(f"BLAST RID: {rid}")

    # Wait for completion. A non-READY reply (including a transient HTTP
    # error page) just counts as one more wait, but an explicit FAILED or
    # UNKNOWN status is terminal, so stop instead of burning every retry.
    for attempt in range(max_retries):
        status = requests.get(
            BLAST_URL,
            params={"CMD": "Get", "FORMAT_OBJECT": "SearchInfo", "RID": rid},
            timeout=request_timeout,
        )
        if "Status=READY" in status.text:
            print("BLAST complete.")
            break
        if "Status=FAILED" in status.text:
            raise Exception(f"BLAST search {rid} failed.")
        if "Status=UNKNOWN" in status.text:
            raise Exception(f"BLAST search {rid} expired or is unknown.")
        print(f"Waiting... {attempt+1}/{max_retries}")
        time.sleep(sleep_time)
    else:
        raise Exception("BLAST timed out")

    # Download results
    result = requests.get(
        BLAST_URL,
        params={"CMD": "Get", "FORMAT_TYPE": "XML", "RID": rid},
        timeout=request_timeout,
    )
    result.raise_for_status()
    root = ET.fromstring(result.text)

    # A dict keeps first-seen order, so the FASTA record order is reproducible
    # between runs (a set would reorder with string hash randomisation).
    homologs = {}
    for hit in root.findall(".//Hit"):
        for hsp in hit.findall(".//Hsp"):
            hseq_text = hsp.findtext("Hsp_hseq")
            identity_text = hsp.findtext("Hsp_identity")
            align_len_text = hsp.findtext("Hsp_align-len")
            # Skip HSPs with a missing or empty field rather than crashing.
            if not (hseq_text and identity_text and align_len_text):
                continue
            align_len = int(align_len_text)
            if align_len <= 0:
                continue
            identity_pct = 100 * int(identity_text) / align_len
            # Hsp_hseq is an aligned string; drop its gap characters so the
            # length filter, de-duplication and FASTA all see raw residues.
            hseq = hseq_text.strip().replace("-", "")
            if (
                identity_pct < identity_threshold
                and len(hseq) > min_length
                and hseq != wt_sequence
            ):
                homologs.setdefault(hseq, None)

    seqs = [wt_sequence, *homologs]  # unique, WT first
    print(f"Total homologs: {len(seqs)}")
    # Save to FASTA
    with open("msa_input.fasta", "w") as f:
        f.writelines(f">seq{i}\n{s}\n" for i, s in enumerate(seqs))
    return "msa_input.fasta"

# ======= Step 2: Align with MAFFT =======
def run_mafft(input_fasta, output_fasta="msa_aligned.fasta"):
    """Align the sequences in ``input_fasta`` with ``mafft --auto``.

    MAFFT is invoked with an argument list (no shell), and its standard output
    is redirected into ``output_fasta`` from Python. File names containing
    spaces or shell metacharacters are therefore passed through verbatim
    instead of being interpreted by a shell.

    Args:
        input_fasta: Path of the unaligned FASTA file.
        output_fasta: Path the alignment is written to (overwritten if present).

    Returns:
        ``output_fasta``.

    Raises:
        subprocess.CalledProcessError: If MAFFT exits with a non-zero status.
        FileNotFoundError: If the ``mafft`` executable is not on ``PATH``.
    """
    print("Running MAFFT alignment...")
    with open(output_fasta, "w") as aligned_handle:
        subprocess.run(
            ["mafft", "--auto", str(input_fasta)],
            stdout=aligned_handle,
            check=True,
        )
    print(f"Alignment written: {output_fasta}")
    return output_fasta

# ======= Step 3: Calculate Henikoff Weights =======
def henikoff_weights(msa_file, format="fasta"):
    """Compute normalised position-based sequence weights for an alignment.

    For every alignment column, a sequence receives ``1 / (k * n)``, where
    ``k`` is the number of distinct symbols in that column and ``n`` is how
    many sequences share this sequence's symbol there. Every character in the
    column is tallied, so a gap character counts like any other symbol.
    The per-column contributions are summed per sequence and the totals are
    scaled to sum to 1.

    Args:
        msa_file: Alignment file, read with ``AlignIO.read``.
        format: Alignment format name passed to ``AlignIO.read``.

    Returns:
        A NumPy array with one weight per sequence, in alignment order.
    """
    msa = AlignIO.read(msa_file, format)
    rows = [entry.seq for entry in msa]
    totals = np.zeros(len(msa))

    for col in range(msa.get_alignment_length()):
        column = [row[col] for row in rows]

        # Tally how often each symbol occurs in this column.
        tally = {}
        for symbol in column:
            tally[symbol] = tally.get(symbol, 0) + 1
        distinct = len(tally)

        # Add this column's contribution to every sequence at once.
        totals += np.array([1.0 / (distinct * tally[symbol]) for symbol in column])

    return totals / totals.sum()

# ======= (Optional) Jackhmmer/MMseqs2 integration (not changed here) =======

# ==== MAIN WORKFLOW ====
# Homolog-search backend to use. "jackhmmer" and "mmseqs2" are only usable if
# their search functions are implemented.
method = "blast"

# Each backend is wrapped in a lambda so its search function is looked up only
# when that backend is actually selected.
_homolog_searches = {
    "blast": lambda: run_blast_search(WT_SEQUENCE, identity_threshold=90.0, hitlist_size=500),
    "jackhmmer": lambda: run_jackhmmer_search(WT_SEQUENCE),
    "mmseqs2": lambda: run_mmseqs2_search(WT_SEQUENCE),
}

_search = _homolog_searches.get(method)
if _search is None:
    raise ValueError("Invalid method chosen. Please select 'blast', 'jackhmmer', or 'mmseqs2'.")
msa_input = _search()

# Align the collected sequences, then weight them and persist the weights.
msa_aligned = run_mafft(msa_input)

weights = henikoff_weights(msa_aligned, "fasta")
print("Sequence weights:", weights)
np.save("msa_weights.npy", weights)
