<a href="https://colab.research.google.com/github/yongchanzzz/bioinformatics/blob/main/ortholog_fetcher.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🧬 Ortholog Fetcher (Colab-ready, MMseqs2 → UniProt → BLAST)

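Pipeline order (first pass, as implemented in the Orchestrators cell):
1. UniProt API gene-name search
2. MMseqs2 local search (proteome DB per species)
3. Foldseek structural search (optional; needs a local Foldseek DB)

Note: a BLAST step (EBI REST API, stype=protein) is not wired into the orchestrators of this notebook; the `USE_BLAST` parameter is currently unused.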

Output: CSV with columns species_name,taxid,chain_A,chain_B,...


In [None]:
#@title 🔧 Install MMseqs2
# Fetch the prebuilt Linux AVX2 tarball, unpack it and move the binary onto PATH.
!wget -q https://mmseqs.com/latest/mmseqs-linux-avx2.tar.gz -O mmseqs2.tar.gz
!tar xfz mmseqs2.tar.gz
!sudo mv mmseqs/bin/mmseqs /usr/local/bin/

import subprocess, sys
# Smoke test: run `mmseqs -h` and show the first lines of its output.
try:
    # check=False: a non-zero exit status does not raise here.
    out = subprocess.run(["mmseqs","-h"], capture_output=True, text=True, check=False)
    banner = out.stdout.splitlines()
    print("\n".join(banner[:5]))  # print just the top lines with version
except Exception as e:
    # Reached when the subprocess could not be started or its output not read.
    print("MMseqs2 installed, but couldn't read banner:", e)

In [None]:
#@title 🔧 Install Foldseek
# Fetch the prebuilt Linux AVX2 tarball, unpack it and move the binary onto PATH.
!wget -q https://mmseqs.com/foldseek/foldseek-linux-avx2.tar.gz -O foldseek.tar.gz
!tar xfz foldseek.tar.gz
!sudo mv foldseek/bin/foldseek /usr/local/bin/

import subprocess
# Smoke test: print the first 8 lines of `foldseek -h`.
# Unlike the MMseqs2 cell there is no try/except here, so a missing binary raises.
out = subprocess.run(["foldseek","-h"], capture_output=True, text=True)
print("\n".join(out.stdout.splitlines()[:8]))


In [None]:
#@title ⚙️ Parameters
# Every assignment carrying a `#@param` annotation is a user-editable setting of this notebook.
# Seeds: UniProt accessions (comma-separated)
SEED_IDS = "P25705,P06576,P36542,P30049,P56381"  #@param {type:"string"}

# Column aliases (same length as seeds; blanks auto-filled: chain_A, chain_B, …)
CHAIN_NAMES = "ATP5F1A,ATP5F1B,ATP5F1G,ATP5F1D,ATP5F1E"  #@param {type:"string"}

# Species: scientific names (optionally 'Name (taxid)')
# When SPECIES_INPUT_TAXIDS is non-empty, the resolver uses it and ignores the names.
SPECIES_INPUT_NAMES = "Mus musculus, Danio rerio, Ciona intestinalis, Drosophila melanogaster, Caenorhabditis elegans"  #@param {type:"string"}
SPECIES_INPUT_TAXIDS = ""  #@param {type:"string"}  # e.g. "10090,7955" to force

# Optional interactive check before building DBs
INTERACTIVE_REVIEW = True  #@param {type:"boolean"}

# Proteome preference
# True restricts the proteome download query to reviewed entries (reviewed:true).
USE_REVIEWED_ONLY = False   #@param {type:"boolean"}

# MMseqs2 (sequence homology)
# SENS / MAX_SEQS / THREADS are passed to `mmseqs easy-search`; the MAX_/MIN_ values
# are applied afterwards as filters on the parsed hits.
USE_MMSEQS2 = True          #@param {type:"boolean"}
THREADS = 2                 #@param {type:"integer"}
MMSEQS2_SENS = 7.5          #@param {type:"number"}
MMSEQS2_MAX_SEQS = 200      #@param {type:"integer"}
MMSEQS2_MAX_EVALUE = 1e-3    #@param {type:"number"}
MMSEQS2_MIN_BITS   = 50.0    #@param {type:"number"}
MMSEQS2_MIN_PIDENT = 5.0   #@param {type:"number"}    # %
MMSEQS2_MIN_QCOV   = 0.30   #@param {type:"number"}    # 0..1
MMSEQS2_MIN_TCOV   = 0.50   #@param {type:"number"}    # 0..1

# Foldseek (structural; last resort)
USE_FOLDSEEK = False        #@param {type:"boolean"}
# NOTE(review): USE_BLAST is not read by any code visible in this notebook — confirm
# whether a BLAST step exists elsewhere before relying on it.
USE_BLAST = False        #@param {type:"boolean"}
FOLDSEEK_DB = ""            #@param {type:"string"}    # local DB path if available
FOLDSEEK_SENS = 9.5         #@param {type:"number"}
FOLDSEEK_MAX_SEQS = 1000    #@param {type:"integer"}
FOLDSEEK_ALIGNMENT_TYPE = 2 #@param {type:"integer"}   # 2 local, 1 global
FOLDSEEK_MAX_EVALUE = 1e-2  #@param {type:"number"}
FOLDSEEK_MIN_ALNTMS = 0.45  #@param {type:"number"}
FOLDSEEK_MIN_LDDT   = 0.00  #@param {type:"number"}
FOLDSEEK_MIN_QCOV   = 0.30  #@param {type:"number"}
FOLDSEEK_MIN_TCOV   = 0.30  #@param {type:"number"}
FOLDSEEK_COV_MODE   = 0     #@param {type:"integer"}

# Prefix of the timestamped result / methods CSV files written by the run cells.
OUTPUT_BASENAME = "ortholog_fetcher"
# Gates the per-hit [OK]/[MISS] prints of the first-pass orchestrator.
VERBOSE = True              #@param {type:"boolean"}

import re, os, csv, math, time, json, tempfile, subprocess, unicodedata, requests
import pandas as pd

def _excel_col(n: int) -> str:
    """Return the spreadsheet-style column label for a zero-based index.

    Examples: 0 -> "A", 25 -> "Z", 26 -> "AA", 701 -> "ZZ", 702 -> "AAA".

    Args:
        n: Zero-based column index.

    Returns:
        The label made of upper-case letters.

    Raises:
        ValueError: If ``n`` is negative. Without this guard the conversion
            loop never reaches its ``n == 0`` exit and runs forever.
    """
    if n < 0:
        raise ValueError(f"column index must be >= 0, got {n}")
    letters = []
    while True:
        n, r = divmod(n, 26)
        letters.append(chr(65 + r))
        if n == 0:
            break
        n -= 1  # bijective base 26: there is no "zero" letter
    return "".join(reversed(letters))

# Seed accessions: split on commas, trim whitespace, drop blanks.
SEED_LIST = [tok.strip() for tok in SEED_IDS.split(",") if tok.strip()]
# User aliases keep their positions (blanks included) so they stay aligned with the seeds.
_raw_names = [tok.strip() for tok in CHAIN_NAMES.split(",")] if CHAIN_NAMES else []
# One alias per seed; a missing or blank alias falls back to chain_A, chain_B, ...
CHAIN_ALIASES = [
    _raw_names[idx] if idx < len(_raw_names) and _raw_names[idx] else f"chain_{_excel_col(idx)}"
    for idx, _seed in enumerate(SEED_LIST)
]

print("Seeds:", SEED_LIST)
print("Aliases:", CHAIN_ALIASES)


In [None]:
#@title 🧭 Species resolver (NCBI → UniProt; adds seed-origin species)
import re, time, json, requests
from IPython.display import display
try:
    import ipywidgets as widgets
except Exception:
    widgets = None

# Identification sent as the `tool` / `email` parameters of NCBI E-utilities requests.
NCBI_TOOL = "OrthologFetcherColab"
NCBI_EMAIL = "your.email@example.com"  # placeholder value
# REST endpoints used by the species resolver.
UNI_BASE = "https://rest.uniprot.org"
UNI_UNIPROTKB = f"{UNI_BASE}/uniprotkb"
UNI_TAXONOMY = f"{UNI_BASE}/taxonomy"
NCBI_EUTILS  = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

def _parse_name_taxid_token(tok: str):
    """Split one species token into ``(name, taxid)``.

    A token ending in a parenthesised run of digits, e.g. ``"Mus musculus (10090)"``,
    yields the trimmed name and the integer taxid. Any other token yields the
    trimmed token and ``None``.
    """
    cleaned = tok.strip()
    match = re.match(r"^(.*)\((\d+)\)\s*$", cleaned)
    if match is None:
        return cleaned, None
    name_part, taxid_part = match.groups()
    return name_part.strip(), int(taxid_part)

def _ncbi_taxid_from_name(name: str) -> int | None:
    """Look up a taxonomy id for ``name`` via NCBI E-utilities ``esearch``.

    The name is searched as scientific name (``[SCIN]``) or common name
    (``[CN]``) and the first id of the returned ``idlist`` is used.

    Args:
        name: Species name to search for.

    Returns:
        The taxid as ``int``, or ``None`` when the request fails, the reply is
        not HTTP 200 / not valid JSON, or no id is returned. Failures return
        ``None`` instead of raising so that the caller's
        ``_ncbi_taxid_from_name(nm) or _uniprot_taxid_from_name(nm)`` fallback
        can still run.
    """
    params = {
        "db":"taxonomy",
        "term": f'{name}[SCIN] OR {name}[CN]',
        "retmode":"json",
        "tool": NCBI_TOOL,
        "email": NCBI_EMAIL,
    }
    ids = []
    try:
        r = requests.get(f"{NCBI_EUTILS}/esearch.fcgi", params=params, timeout=30)
        if r.status_code == 200:
            ids = r.json().get("esearchresult", {}).get("idlist", [])
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a body that is not valid JSON.
        print(f"[WARN] NCBI lookup failed for '{name}': {e}")
    finally:
        time.sleep(0.34)  # be polite to NCBI after every request, not only after hits
    if not ids:
        return None
    try:
        return int(ids[0])
    except (TypeError, ValueError):
        return None

def _uniprot_taxid_from_name(name: str) -> int | None:
    """Look up a taxonomy id for ``name`` through the UniProt taxonomy search.

    Up to 25 results are requested. The first record matching the highest
    applicable preference wins:

    1. scientific name equals ``name`` (case-insensitive) and rank is "species",
    2. scientific name equals ``name``,
    3. rank is "species",
    4. the first record returned.

    Returns ``None`` for a non-200 reply or an empty result list.
    """
    query = {"query":name,"format":"json",
             "fields":"scientificName,commonName,taxonId,rank","size":25}
    resp = requests.get(f"{UNI_TAXONOMY}/search", params=query, timeout=30)
    if resp.status_code != 200:
        return None
    records = resp.json().get("results", [])
    if not records:
        return None

    def _same_name(rec):
        return rec.get("scientificName","").lower()==name.lower()

    def _is_species(rec):
        return (rec.get("rank","") or "").lower()=="species"

    # Ordered from most to least specific; each rule scans all records in turn.
    preferences = (
        lambda rec: _same_name(rec) and _is_species(rec),
        _same_name,
        _is_species,
    )
    for accept in preferences:
        for rec in records:
            if accept(rec):
                return int(rec["taxonId"])
    return int(records[0]["taxonId"])

def resolve_species_inputs(names_csv: str, taxids_csv: str):
    """Turn the user's species input into a list of ``(name, taxid)`` tuples.

    Args:
        names_csv: Comma-separated species names, each optionally written as
            ``"Name (taxid)"``.
        taxids_csv: Comma-separated numeric taxids. When non-blank it takes
            precedence and ``names_csv`` is ignored.

    Returns:
        A list of ``(name, taxid)``. In taxid mode the name comes from
        ``species_name_from_taxid`` (``"taxid:<id>"`` when the lookup fails).
        In name mode ``taxid`` is ``None`` for names that could not be
        resolved; a warning is printed for each of them.
    """
    out = []
    if taxids_csv.strip():
        for tok in taxids_csv.split(","):
            tok = tok.strip()
            if not tok:
                continue
            try:
                tid = int(tok)
            except ValueError:
                # One mistyped token should not abort the whole cell.
                print(f"[WARN] Ignoring non-numeric taxid '{tok}' in SPECIES_INPUT_TAXIDS.")
                continue
            # Shared helper performs the same UniProt taxonomy lookup and fallback.
            out.append((species_name_from_taxid(tid), tid))
        return out

    for name_tok in names_csv.split(","):
        if not name_tok.strip():
            continue
        nm, tid = _parse_name_taxid_token(name_tok)
        if tid:
            out.append((nm, tid))
            continue
        # NCBI first, UniProt taxonomy as fallback.
        tid = _ncbi_taxid_from_name(nm) or _uniprot_taxid_from_name(nm)
        if tid is None:
            print(f"[WARN] Could not resolve taxid for '{nm}'. "
                  f"Consider entering as 'Name (taxid)' or via SPECIES_INPUT_TAXIDS.")
        out.append((nm, tid))
    return out

def up_get_entry(acc):
    """Fetch the UniProtKB entry ``acc`` as parsed JSON; ``None`` unless the reply is HTTP 200."""
    resp = requests.get(f"{UNI_UNIPROTKB}/{acc}", params={"format":"json"}, timeout=30)
    return resp.json() if resp.status_code==200 else None

def up_taxid_of_acc(acc):
    """Return the ``organism.taxonId`` of UniProt entry ``acc``, or ``None`` if unavailable."""
    entry = up_get_entry(acc)
    if not entry:
        return None
    return entry.get("organism",{}).get("taxonId")

def species_name_from_taxid(taxid: int) -> str:
    """Return the UniProt scientific name for ``taxid``.

    Falls back to the literal ``"taxid:<id>"`` when the request raises, the
    reply is not HTTP 200, or the record has no ``scientificName``.
    """
    fallback = f"taxid:{taxid}"
    try:
        resp = requests.get(f"{UNI_TAXONOMY}/{taxid}", params={"format":"json"}, timeout=30)
        if resp.status_code==200:
            return resp.json().get("scientificName", fallback)
    except Exception:
        # Best effort: any failure degrades to the placeholder name.
        pass
    return fallback

# 1) Resolve the user-provided list
SPECIES_RESOLVED = resolve_species_inputs(SPECIES_INPUT_NAMES, SPECIES_INPUT_TAXIDS)
print("Resolved species (pre-review):", SPECIES_RESOLVED)

# 2) Add seed-origin species (seeds can be from several organisms)
# SEED_LIST is re-derived here so this cell does not depend on the parameters
# cell having populated it.
SEED_LIST = [s.strip() for s in SEED_IDS.split(",") if s.strip()]
SEED_ORIGINS = {acc: up_taxid_of_acc(acc) for acc in SEED_LIST}
SEED_ORIGIN_SPECIES = {
    acc: (species_name_from_taxid(tid) if tid else None)
    for acc, tid in SEED_ORIGINS.items()
}

# 3) Build union set: user species + seed-origin species
_species_by_taxid = {}
_unresolved_species = []  # names without a taxid, in input order, no duplicates
for nm, tid in SPECIES_RESOLVED:
    if tid:
        _species_by_taxid[tid] = nm
    elif nm not in _unresolved_species:
        _unresolved_species.append(nm)

for acc, tid in SEED_ORIGINS.items():
    if tid and tid not in _species_by_taxid:
        _species_by_taxid[tid] = SEED_ORIGIN_SPECIES.get(acc) or species_name_from_taxid(tid)

# Resolved species sorted by taxid, followed by the unresolved names with
# taxid None. Keeping the unresolved names (instead of silently dropping them)
# lets the review widget show an empty taxid box for them and lets the later
# cells report them through their existing "[SKIP] ... no taxid" branches.
ALL_SPECIES = [(name, tid) for tid, name in sorted(_species_by_taxid.items())]
ALL_SPECIES += [(nm, None) for nm in _unresolved_species]

print("Seed origin taxids:", SEED_ORIGINS)
print("Augmented species set:", ALL_SPECIES)

# Optional interactive review
# Shows one editable (name, taxid) row per entry of ALL_SPECIES and a button
# that rebuilds ALL_SPECIES from the edited boxes.
if INTERACTIVE_REVIEW and widgets is not None:
    name_boxes = []
    taxid_boxes = []
    rows = []
    for nm, tid in ALL_SPECIES:
        nb = widgets.Text(value=str(nm), layout=widgets.Layout(width="45%"))
        # A missing taxid is shown as an empty box.
        tb = widgets.Text(value="" if tid is None else str(tid), layout=widgets.Layout(width="20%"))
        name_boxes.append(nb); taxid_boxes.append(tb)
        rows.append(widgets.HBox([nb, tb]))
    apply_btn = widgets.Button(description="Use these", button_style="success")
    out_lbl = widgets.HTML()

    def _apply(_):
        # Button callback: read the boxes back and replace the global ALL_SPECIES.
        global ALL_SPECIES
        new_list=[]
        for nb, tb in zip(name_boxes, taxid_boxes):
            nm = nb.value.strip()
            t  = tb.value.strip()
            if not nm: continue  # a blanked-out name removes that row
            tax = int(t) if t.isdigit() else None  # non-digit taxid text becomes None
            new_list.append((nm, tax))
        # Deduplicate by taxid (keep last name entered)
        # Rows whose taxid is None (or 0) are dropped by the `if tid` test.
        dedup = {}
        for nm, tid in new_list:
            if tid: dedup[tid] = nm
        ALL_SPECIES = [(nm, tid) for tid, nm in dedup.items()]
        out_lbl.value = f"<b>Updated species:</b> {ALL_SPECIES}"

    apply_btn.on_click(_apply)
    display(widgets.HTML("<b>Review/override species and taxids:</b>"))
    for r in rows: display(r)
    display(apply_btn, out_lbl)
else:
    print("Interactive review disabled or ipywidgets unavailable.")


In [None]:
#@title 🧱 Download proteomes & build MMseqs2 DBs
import os, io, gzip, time, subprocess, requests

# Working directories: dbs/ holds the per-taxid FASTA files and MMseqs2 DBs,
# mmseqs_tmp/ is handed to mmseqs as its temporary directory.
os.makedirs("dbs", exist_ok=True)
os.makedirs("mmseqs_tmp", exist_ok=True)

# UniProt streaming endpoint and the HTTP headers sent with download requests.
UNI_STREAM = "https://rest.uniprot.org/uniprotkb/stream"
_HEADERS = {"User-Agent": "OrthologFetcher/1.0", "Accept": "*/*"}

def _sh(cmd, check=True):
    """Run ``cmd`` (an argument list) and return its combined stdout/stderr text.

    With ``check`` true, a non-zero exit status prints the captured output and
    raises ``subprocess.CalledProcessError``; with ``check`` false the output
    is returned regardless of the exit status.
    """
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    failed = proc.returncode != 0
    if failed and check:
        print(proc.stdout)
        raise subprocess.CalledProcessError(proc.returncode, cmd)
    return proc.stdout

def _count_headers(fa):
    """Count FASTA records in ``fa``, i.e. lines starting with ``>``.

    Returns 0 when the path does not exist. Undecodable bytes are ignored
    while reading.
    """
    if not os.path.exists(fa):
        return 0
    with open(fa, 'r', errors='ignore') as handle:
        return sum(1 for line in handle if line.startswith(">"))

def _download_uniprot_fasta(taxid: int, out_fa: str, reviewed_only: bool, retries: int = 4):
    """
    Download all UniProtKB sequences of one organism as FASTA via the stream API.

    The query goes through ``params`` so ``requests`` URL-encodes it. The
    server is asked for a gzip-compressed body, which is saved as
    ``out_fa + ".gz"``, decompressed in a streaming fashion into
    ``out_fa + ".part"`` and only then moved onto ``out_fa``. A failed attempt
    therefore never leaves a truncated FASTA under the final name.

    Args:
        taxid: Organism taxonomy id used in the ``organism_id`` query.
        out_fa: Destination path of the uncompressed FASTA.
        reviewed_only: Restrict the query to ``reviewed:true`` entries.
        retries: Number of download attempts; transfer/decompression errors
            are retried with exponential backoff (capped at 10 s).

    Raises:
        RuntimeError: If the downloaded FASTA contains no header line (raised
            immediately, without retrying, since an empty result is not a
            transient error), or if ``retries`` is smaller than 1.
        Exception: The error of the last attempt once all attempts failed.
    """
    import shutil  # local import: only needed for the streaming decompression

    if retries < 1:
        raise RuntimeError(f"retries must be >= 1, got {retries}")

    query = f"(organism_id:{taxid})"
    if reviewed_only:
        query += " AND (reviewed:true)"  # contains a space → MUST be URL-encoded

    params = {
        "format": "fasta",
        "compressed": "true",   # server returns gzipped stream
        "query": query,
    }
    gz_path = out_fa + ".gz"
    part_path = out_fa + ".part"

    def _discard(path):
        # Best-effort removal of an intermediate file.
        try:
            os.remove(path)
        except OSError:
            pass

    # Retry politely with backoff
    for attempt in range(retries):
        try:
            print(f"  [DL] {taxid} {'Swiss-Prot' if reviewed_only else 'Swiss+TrEMBL'}")
            # Context manager closes the streamed response even on errors.
            with requests.get(UNI_STREAM, params=params, headers=_HEADERS,
                              timeout=120, stream=True) as r:
                r.raise_for_status()
                with open(gz_path, "wb") as fh:
                    for chunk in r.iter_content(chunk_size=1 << 20):
                        if chunk:
                            fh.write(chunk)
            # Decompress chunk-wise instead of loading the whole proteome into memory.
            with gzip.open(gz_path, "rb") as gzh, open(part_path, "wb") as out:
                shutil.copyfileobj(gzh, out)
            os.replace(part_path, out_fa)
            break
        except Exception as e:
            _discard(part_path)
            if attempt == retries - 1:
                print(f"  [ERR] download failed: {e}")
                raise
            sleep_s = min(2 ** attempt, 10)
            print(f"  [WARN] download failed (attempt {attempt+1}/{retries}): {e}. Retrying in {sleep_s}s…")
            time.sleep(sleep_s)
        finally:
            _discard(gz_path)

    n = _count_headers(out_fa)
    print(f"  [OK] FASTA headers: {n}")
    if n == 0:
        # Empty file: treat as failure so caller can fallback
        _discard(out_fa)
        print("  [ERR] download failed: Downloaded FASTA is empty")
        raise RuntimeError("Downloaded FASTA is empty")

def _ensure_mmseqs_db(taxid: int):
    """Make sure an MMseqs2 database (and search index) exists for ``taxid``.

    Expects the FASTA at ``dbs/<taxid>.fasta`` and builds the database at
    ``dbs/<taxid>``. ``createdb`` is skipped when the database already exists
    and ``createindex`` is skipped when its index is already present.

    Raises:
        RuntimeError: If the FASTA is missing or empty.
        subprocess.CalledProcessError: If ``mmseqs createdb`` fails.
    """
    fa = f"dbs/{taxid}.fasta"
    if not os.path.exists(fa) or os.path.getsize(fa) == 0:
        raise RuntimeError(f"FASTA missing/empty: {fa}")
    dbdir = f"dbs/{taxid}"
    # detect if DB looks already built+indexed
    has_db = os.path.exists(dbdir) or os.path.exists(dbdir + ".dbtype")
    # `createdb` itself writes <db>.index, so that file does not prove that
    # `createindex` ran; the precomputed search index is stored as <db>.idx.
    has_idx = os.path.exists(dbdir + ".idx")
    if has_db and has_idx:
        print(f"  [=] DB OK for {taxid}")
        return
    if not has_db:
        print(f"  [DB] Creating DB for {taxid}")
        _sh(["mmseqs", "createdb", fa, dbdir])
    print("  [IX] Indexing")
    # Some environments return non-zero if index already exists → don't hard-fail
    _sh(["mmseqs", "createindex", dbdir, "mmseqs_tmp", "--threads", str(THREADS)], check=False)

# For every species with a taxid: fetch the proteome FASTA if it is not
# already usable on disk, then make sure its MMseqs2 DB exists.
for sp, tid in ALL_SPECIES:
    if not tid:
        print(f"[SKIP] {sp}: no taxid")
        continue
    print(f"\n[{sp}] taxid={tid}")
    fa = f"dbs/{tid}.fasta"
    have_fasta = os.path.exists(fa) and _count_headers(fa) != 0
    if not have_fasta:
        try:
            _download_uniprot_fasta(tid, fa, reviewed_only=USE_REVIEWED_ONLY)
        except Exception:
            # Only a reviewed-only download has a fallback; anything else propagates.
            if not USE_REVIEWED_ONLY:
                raise
            print("  [WARN] Swiss-Prot empty → fallback to full UniProt")
            _download_uniprot_fasta(tid, fa, reviewed_only=False)
    _ensure_mmseqs_db(tid)

print("\n[Build] done.")


In [None]:
#@title 🔎 UniProt helpers & methods
# UniProt REST base URL and the UniProtKB endpoint used by the helpers in this cell.
UNI_BASE = "https://rest.uniprot.org"
UNI_KB   = f"{UNI_BASE}/uniprotkb"

def _get_json(url, params=None, timeout=45):
    """GET ``url`` and return the decoded JSON body, retrying on request errors.

    Makes up to four attempts, sleeping 1 s, 2 s and 4 s between them. The
    ``requests.RequestException`` of the fourth attempt is re-raised.
    """
    max_attempts = 4
    attempt = 0
    while attempt < max_attempts:
        try:
            resp = requests.get(url, params=params, timeout=timeout)
            resp.raise_for_status()
            return resp.json()
        except requests.RequestException:
            if attempt == max_attempts - 1:
                raise
            time.sleep(2**attempt)
        attempt += 1

def get_uniprot_entry_json(acc: str):
    """Return the UniProtKB entry ``acc`` as parsed JSON (retries via ``_get_json``)."""
    entry_url = f"{UNI_KB}/{acc}"
    return _get_json(entry_url, params={"format":"json"})

def get_uniprot_sequence(acc: str):
    """Return the amino-acid sequence of UniProt entry ``acc`` as one string.

    Downloads ``<acc>.fasta``, drops header lines and joins the remaining
    lines. Returns ``None`` for a non-200 reply or when no sequence line exists.
    """
    resp = requests.get(f"{UNI_KB}/{acc}.fasta", timeout=45)
    if resp.status_code != 200:
        return None
    body = [line.strip() for line in resp.text.splitlines() if not line.startswith(">")]
    if not body:
        return None
    return "".join(body)

def _seed_gene_and_names(seed_js: dict):
    """Extract ``(gene_name, recommended_full_name)`` from a UniProt entry JSON.

    The gene name is the ``geneName.value`` of the first ``genes`` element.
    The protein name is ``proteinDescription.recommendedName.fullName.value``.
    Each part is ``None`` when it is missing or cannot be read.
    """
    def _first_gene_name():
        genes = seed_js.get("genes", [])
        if genes and "geneName" in genes[0]:
            return genes[0]["geneName"].get("value")
        return None

    def _recommended_name():
        return seed_js["proteinDescription"]["recommendedName"]["fullName"]["value"]

    extracted = []
    for getter in (_first_gene_name, _recommended_name):
        try:
            extracted.append(getter())
        except Exception:
            # Best effort: any structural surprise yields None for this part.
            extracted.append(None)
    gene, pref = extracted
    return gene, pref

def method_uniprot_quick_name(seed_acc: str, taxid: int, prefer_reviewed=True):
    """Find an ortholog candidate by searching UniProtKB for the seed's gene name.

    The primary gene name of ``seed_acc`` is looked up and searched within
    organism ``taxid``. Among the hits a reviewed entry is preferred,
    otherwise the first hit is used.

    Args:
        seed_acc: UniProt accession of the seed protein.
        taxid: Taxonomy id of the target organism.
        prefer_reviewed: When true, the search query is restricted to
            ``reviewed:true`` entries.

    Returns:
        ``(accession | None, info)``. ``info["status"]`` is one of
        ``seed-fetch-failed``, ``no-gene``, ``search-failed``, ``no-hit``,
        ``no-accession`` (best record carries no accession), ``self-hit``
        (best record is the seed itself) or ``hit``. Only ``hit`` comes with
        an accession.
    """
    try:
        seed_js = get_uniprot_entry_json(seed_acc)
    except Exception as e:
        return None, {"method":"uniprot-name","status":"seed-fetch-failed","detail":str(e)}
    gene, _prefname = _seed_gene_and_names(seed_js)
    if not gene:
        return None, {"method":"uniprot-name","status":"no-gene"}

    q = f"(gene:{gene}) AND (organism_id:{taxid})"
    if prefer_reviewed: q += " AND (reviewed:true)"
    params = {"query": q, "format": "json", "fields": "accession,reviewed,organism_id,protein_name"}
    try:
        js = _get_json(f"{UNI_KB}/search", params=params)
    except Exception as e:
        return None, {"method":"uniprot-name","status":"search-failed","detail":str(e)}

    hits = js.get("results", [])
    if not hits:
        return None, {"method":"uniprot-name","status":"no-hit"}

    def _is_reviewed(rec):
        # UniProt JSON reports curation status as
        # entryType = "UniProtKB reviewed (Swiss-Prot)" / "UniProtKB unreviewed (TrEMBL)";
        # a boolean "reviewed" key is still honoured if a record carries one.
        if rec.get("reviewed") is True:
            return True
        return str(rec.get("entryType") or "").startswith("UniProtKB reviewed")

    best = next((rec for rec in hits if _is_reviewed(rec)), hits[0])
    acc = best.get("primaryAccession") or best.get("accession")
    if not acc:
        return None, {"method":"uniprot-name","status":"no-accession","gene":gene}
    if acc == seed_acc:
        # Do not report the seed as its own ortholog, and do not log it as a hit.
        return None, {"method":"uniprot-name","status":"self-hit","gene":gene}
    return acc, {"method":"uniprot-name","status":"hit","gene":gene}

# Verbose, defensive MMseqs2 wrapper used by both firstpass and deep-pass
import os, subprocess, requests, gzip, io, time, re

# UniProt streaming endpoint and request headers (same values as in the build cell).
UNI_STREAM = "https://rest.uniprot.org/uniprotkb/stream"
_HEADERS = {"User-Agent": "OrthologFetcher/1.0", "Accept": "*/*"}
# Captures 6-10 upper-case letters/digits at the start of a target id,
# optionally preceded by one character of [sptr] plus "|".
# NOTE(review): `[sptr]` is a single-character class, so the optional prefix
# matches "s|", "p|", "t|" or "r|" but not the two-letter tags "sp|" / "tr|" —
# confirm whether `(?:sp|tr)\|` was intended.
_ACC_RE = re.compile(r'^(?:[sptr]\|)?([A-Z0-9]{6,10})')

def _sh(cmd, check=True):
    """Run ``cmd`` (an argument list) and return its combined stdout/stderr text.

    With ``check`` true, a non-zero exit status prints the captured output and
    raises ``subprocess.CalledProcessError``; with ``check`` false the output
    is returned regardless of the exit status.
    """
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    failed = proc.returncode != 0
    if failed and check:
        print(proc.stdout)
        raise subprocess.CalledProcessError(proc.returncode, cmd)
    return proc.stdout

def _ensure_query_fasta(acc: str) -> tuple[str, dict]:
    """Make sure ``mmseqs_tmp/<acc>.fa`` holds the FASTA of UniProt entry ``acc``.

    A non-empty existing file is reused. Otherwise the entry is fetched
    (gzip-compressed) from the UniProt stream endpoint, with up to four
    attempts and 1 s / 2 s / 4 s pauses between them.

    Returns:
        ``(path, info)`` on success, where ``info["status"]`` is
        ``query-cached`` or ``query-fetched``; ``("", info)`` with status
        ``query-fetch-failed`` and the last error text when all attempts failed.
    """
    os.makedirs("mmseqs_tmp", exist_ok=True)
    qf = f"mmseqs_tmp/{acc}.fa"
    cached = os.path.exists(qf) and os.path.getsize(qf) > 0
    if cached:
        return qf, {"method":"mmseqs2","status":"query-cached"}

    params = {"format": "fasta", "compressed": "true", "query": f"(accession:{acc})"}
    last_attempt = 3
    for attempt in range(last_attempt + 1):
        try:
            resp = requests.get(UNI_STREAM, params=params, headers=_HEADERS, timeout=60)
            resp.raise_for_status()
            with gzip.GzipFile(fileobj=io.BytesIO(resp.content)) as gz:
                payload = gz.read()
            with open(qf, "wb") as out:
                out.write(payload)
            # sanity: a FASTA without any header line is discarded and counts as a failure
            with open(qf, "rt", errors="ignore") as fh:
                if not any(ln.startswith(">") for ln in fh):
                    os.remove(qf)
                    raise RuntimeError("FASTA has no header")
        except Exception as e:
            if attempt == last_attempt:
                return "", {"method":"mmseqs2","status":"query-fetch-failed","error":str(e)}
            time.sleep(2**attempt)
        else:
            return qf, {"method":"mmseqs2","status":"query-fetched"}
    return "", {"method":"mmseqs2","status":"query-fetch-failed","error":"unknown"}

def _extract_acc(target_field: str) -> str | None:
    """
    Accepts 'sp|Q9W0V0|PROT_NAME ...' or 'Q9W0V0' and returns 'Q9W0V0'.
    """
    if not target_field:
        return None
    t = target_field.split()[0]  # mmseqs uses ID up to first whitespace
    m = _ACC_RE.match(t)
    return m.group(1) if m else None

def _mmseqs_float(text: str, default: float) -> float:
    """Parse one numeric TSV field; return ``default`` when it is not a number."""
    try:
        return float(text)
    except ValueError:
        return default


def _mmseqs_hit_from_line(ln: str, seed_acc: str) -> dict | None:
    """Turn one easy-search TSV line into a candidate dict.

    Returns ``None`` for comment/short lines, for targets without a parsable
    accession, for the seed itself, and for hits rejected by any MMSEQS2_*
    threshold (a threshold set to ``None`` is not applied).
    """
    if not ln or ln.startswith("#"):
        return None
    parts = ln.rstrip("\n").split("\t")
    if len(parts) < 10:
        return None
    # Column order follows the --format-output string used by method_mmseqs2.
    _, target_raw, ev_s, bits_s, qcov_s, tcov_s, alnlen_s, qlen_s, tlen_s, pident_s = parts[:10]

    acc = _extract_acc(target_raw)
    if not acc or acc == seed_acc:
        return None

    # Unparsable numbers get the value that makes the hit look worst.
    ev     = _mmseqs_float(ev_s, float("inf"))
    bits   = _mmseqs_float(bits_s, 0.0)
    qcov   = _mmseqs_float(qcov_s, 0.0)
    tcov   = _mmseqs_float(tcov_s, 0.0)
    pident = _mmseqs_float(pident_s, 0.0)

    # thresholds
    if MMSEQS2_MAX_EVALUE is not None and ev > MMSEQS2_MAX_EVALUE: return None
    if MMSEQS2_MIN_BITS   is not None and bits   < MMSEQS2_MIN_BITS:   return None
    if MMSEQS2_MIN_PIDENT is not None and pident < MMSEQS2_MIN_PIDENT: return None
    if MMSEQS2_MIN_QCOV   is not None and qcov   < MMSEQS2_MIN_QCOV:   return None
    if MMSEQS2_MIN_TCOV   is not None and tcov   < MMSEQS2_MIN_TCOV:   return None

    return {
        "acc": acc,
        "evalue": ev, "bits": bits,
        "pident": pident, "qcov": qcov, "tcov": tcov,
        "alnlen": alnlen_s, "qlen": qlen_s, "tlen": tlen_s
    }


def method_mmseqs2(seed_acc: str, taxid: int):
    """Return (best_target_acc | None, info_dict) with real metrics in info.

    Runs ``mmseqs easy-search`` with the seed's FASTA against the database
    ``dbs/<taxid>``, filters the hits with the MMSEQS2_* thresholds and keeps
    the one with the lowest e-value (ties broken by the higher bit score; on
    a full tie the earlier line wins).

    ``info["status"]`` is one of ``no-target-db``, ``query-fetch-failed``,
    ``search-failed``, ``no-tsv``, ``no-hit-after-filter`` or ``hit``. For a
    hit, ``info`` also carries the candidate's metrics and ``time_s``.
    """
    t0 = time.time()
    dbdir = f"dbs/{int(taxid)}"
    if not (os.path.exists(dbdir) or os.path.exists(dbdir + ".dbtype")):
        return None, {"method":"mmseqs2","status":"no-target-db","db":dbdir}

    # Query FASTA
    qf, qinfo = _ensure_query_fasta(seed_acc)
    if not qf:
        return None, qinfo

    # Run mmseqs easy-search
    outf = f"mmseqs_tmp/{seed_acc}_{int(taxid)}.tsv"
    fmt = "query,target,evalue,bits,qcov,tcov,alnlen,qlen,tlen,pident"
    cmd = [
        "mmseqs","easy-search", qf, dbdir, outf, "mmseqs_tmp",
        "-s", str(MMSEQS2_SENS),
        "--max-seqs", str(MMSEQS2_MAX_SEQS),
        "--format-output", fmt,
        "--threads", str(THREADS)
    ]
    try:
        _sh(cmd, check=True)
    except subprocess.CalledProcessError:
        return None, {"method":"mmseqs2","status":"search-failed","cmd":" ".join(cmd)}

    # Parse, filter, keep best (min evalue, tie-break by max bits)
    best = None
    try:
        with open(outf, "rt", errors="ignore") as fh:
            for ln in fh:
                cand = _mmseqs_hit_from_line(ln, seed_acc)
                if cand is None:
                    continue
                if best is None or (cand["evalue"] < best["evalue"]) or (
                    cand["evalue"] == best["evalue"] and cand["bits"] > best["bits"]
                ):
                    best = cand
    except FileNotFoundError:
        return None, {"method":"mmseqs2","status":"no-tsv"}

    dur = round(time.time()-t0, 3)
    if best:
        info = {"method":"mmseqs2","status":"hit","time_s":dur}
        info.update(best)  # <- attach metrics so your stdout can show them
        return best["acc"], info

    return None, {"method":"mmseqs2","status":"no-hit-after-filter","time_s":dur}

def _afdb_model_to_file(acc: str, dest_pdb: str) -> bool:
    """Download the AlphaFold DB model of ``acc`` into ``dest_pdb``.

    Tries, in this order, the PDB files of model versions v4 and v3, then the
    mmCIF files of v4 and v3. The first reply with HTTP 200 whose text
    contains the expected marker ("ATOM" for PDB, "_atom_site" for mmCIF) is
    written to ``dest_pdb`` — the same path is used for mmCIF content.

    Returns:
        True when a model was written, False when no candidate URL delivered one.
    """
    attempts = [
        (ext, marker, ver)
        for ext, marker in (("pdb", "ATOM"), ("cif", "_atom_site"))
        for ver in ("v4","v3")
    ]
    for ext, marker, ver in attempts:
        url = f"https://alphafold.ebi.ac.uk/files/AF-{acc}-F1-model_{ver}.{ext}"
        resp = requests.get(url, timeout=45)
        if resp.status_code != 200 or marker not in resp.text:
            continue
        with open(dest_pdb, "w") as fh:
            fh.write(resp.text)
        return True
    return False

def method_foldseek(seed_acc, taxid):
    """Structural search of the seed's AlphaFold model against ``FOLDSEEK_DB``.

    The seed model is downloaded to a temporary file, ``foldseek easy-search``
    is run, and the hits are filtered with the FOLDSEEK_* thresholds. The hit
    with the highest alignment TM-score wins.

    Note: ``taxid`` is only used to name the output TSV; hits are not
    filtered by organism in this function.

    Returns:
        ``(accession | None, info)`` where ``info["status"]`` is one of
        ``disabled-or-missing-db``, ``no-structure``, ``error``, ``no-tsv``,
        ``no-hit-after-filter`` or ``hit`` (then including the hit's metrics
        as reported in the TSV).
    """
    if not USE_FOLDSEEK or not FOLDSEEK_DB or not os.path.exists(FOLDSEEK_DB):
        return None, {"method":"foldseek","status":"disabled-or-missing-db"}

    def _num(text, default):
        # Numeric TSV field; "", "NA" and unparsable text fall back to `default`.
        if text in ("", "NA"):
            return default
        try:
            return float(text)
        except ValueError:
            return default

    qtmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdb"); qtmp.close()
    try:
        if not _afdb_model_to_file(seed_acc, qtmp.name):
            return None, {"method":"foldseek","status":"no-structure"}

        out_tsv = f"mmseqs_tmp/{seed_acc}_{taxid}.fold.tsv"
        fmt = "query,target,evalue,bits,qcov,tcov,alntmscore,qtmscore,ttmscore,lddt"
        cmd = [
            "foldseek","easy-search",
            qtmp.name, FOLDSEEK_DB, out_tsv, "mmseqs_tmp",
            "-s", str(FOLDSEEK_SENS),
            "--max-seqs", str(FOLDSEEK_MAX_SEQS),
            "--alignment-type", str(FOLDSEEK_ALIGNMENT_TYPE),
            "--format-output", fmt,
            "-e", str(FOLDSEEK_MAX_EVALUE),
            "-c", str(FOLDSEEK_MIN_QCOV),
            "--cov-mode", str(FOLDSEEK_COV_MODE)
        ]
        proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
        if proc.returncode != 0:
            return None, {"method":"foldseek","status":"error","detail":proc.stdout.splitlines()[-1:]}

        try:
            fh = open(out_tsv, newline="")
        except FileNotFoundError:
            # Same status the MMseqs2 method reports for a missing result file.
            return None, {"method":"foldseek","status":"no-tsv"}

        best = None
        best_tms = None  # numeric alignment TM-score of `best`
        with fh:
            for row in csv.reader(fh, delimiter="\t"):
                if len(row) < 10: continue
                # Slice so that extra trailing columns cannot break the unpacking.
                _, target, evalue, bits, qcov, tcov, alntms, qtms, ttms, lddt = row[:10]
                m = re.search(r"(?:\b[sptr]\|)?([A-Z0-9]{6,10})", target)
                acc = m.group(1) if m else None
                if not acc or acc == seed_acc: continue

                ev   = _num(evalue, math.inf)
                qc   = _num(qcov, 0.0)
                tc   = _num(tcov, 0.0)
                atms = _num(alntms, 0.0)
                ldt  = _num(lddt, 0.0)

                if FOLDSEEK_MAX_EVALUE   is not None and ev > FOLDSEEK_MAX_EVALUE:   continue
                if FOLDSEEK_MIN_ALNTMS   is not None and atms < FOLDSEEK_MIN_ALNTMS: continue
                if FOLDSEEK_MIN_LDDT     is not None and ldt  < FOLDSEEK_MIN_LDDT:   continue
                if FOLDSEEK_MIN_QCOV     is not None and qc   < FOLDSEEK_MIN_QCOV:   continue
                if FOLDSEEK_MIN_TCOV     is not None and tc   < FOLDSEEK_MIN_TCOV:   continue

                # Compare the parsed number: float() on a raw "NA"/"" field would raise.
                if best is None or atms > best_tms:
                    best = {"acc":acc,"evalue":evalue,"bits":bits,
                            "alntmscore":alntms,"qtmscore":qtms,"ttmscore":ttms,
                            "lddt":lddt,"qcov":qcov,"tcov":tcov}
                    best_tms = atms
        if best:
            return best["acc"], {"method":"foldseek","status":"hit", **best}
        return None, {"method":"foldseek","status":"no-hit-after-filter"}
    finally:
        # Single cleanup point for the temporary query structure.
        try:
            os.remove(qtmp.name)
        except OSError:
            pass


In [None]:
#@title 🚦 Orchestrators
def find_ortholog_firstpass(seed_acc, taxid, species_name):
    """Try the search methods in order and stop at the first one that finds a hit.

    Order: UniProt gene-name search → MMseqs2 (only when ``USE_MMSEQS2`` is
    true, matching the deep pass) → Foldseek (which reports
    ``disabled-or-missing-db`` itself when it is not usable).

    Args:
        seed_acc: UniProt accession of the seed protein.
        taxid: Taxonomy id of the target species.
        species_name: Only used in the progress messages.

    Returns:
        ``(accession | None, info, tried)``: ``info`` is the info dict of the
        successful method (or ``{"method":"none","status":"miss"}``), and
        ``tried`` holds copies of the info dicts of every method attempted.
    """
    tried = []

    acc, info = method_uniprot_quick_name(seed_acc, taxid, prefer_reviewed=True)
    tried.append(info.copy())
    if acc:
        if VERBOSE: print(f"[OK] {species_name}({taxid}) {seed_acc} ← UniProt {acc}")
        return acc, info, tried

    if USE_MMSEQS2:
        acc, info = method_mmseqs2(seed_acc, taxid)
        tried.append(info.copy())
        if acc:
            if VERBOSE:
                print(f"[OK] {species_name}({taxid}) {seed_acc} ← MMseqs2 {acc} "
                      f"e={info.get('evalue')} bits={info.get('bits')} "
                      f"pident={info.get('pident')} qcov={info.get('qcov')} tcov={info.get('tcov')}")
            return acc, info, tried

    acc, info = method_foldseek(seed_acc, taxid)
    tried.append(info.copy())
    if acc:
        if VERBOSE:
            print(f"[OK] {species_name}({taxid}) {seed_acc} ← Foldseek {acc} "
                  f"alnTM={info.get('alntmscore')} lddt={info.get('lddt')}")
        return acc, info, tried

    if VERBOSE: print(f"[MISS] {species_name}({taxid}) {seed_acc}")
    return None, {"method":"none","status":"miss"}, tried


# Deep pass: omit UniProt text; try MMseqs2 → Foldseek (if enabled)
def find_ortholog_deeppass(seed_acc: str, taxid: int, species_name: str):
    """Homology-only search: MMseqs2, then Foldseek, each only when enabled.

    Unlike the first pass, the UniProt gene-name search is skipped.

    Args:
        seed_acc: Accession used as the query.
        taxid: Taxonomy id of the target species.
        species_name: Accepted for interface parity with the first pass; not
            used by the search methods.

    Returns:
        ``(accession | None, info, tried)``; on a miss ``info`` is
        ``{"method":"deeppass","status":"miss"}``.
    """
    tried = []
    if USE_MMSEQS2:
        acc, info = method_mmseqs2(seed_acc, taxid)  # 2-arg
        tried.append(info)
        if acc:
            return acc, info, tried
    if USE_FOLDSEEK:
        # method_foldseek takes (seed_acc, taxid). The former `except TypeError`
        # retry with a third argument could only fail again and would have hidden
        # a genuine TypeError raised inside the method.
        acc, info = method_foldseek(seed_acc, taxid)
        tried.append(info)
        if acc:
            return acc, info, tried
    return None, {"method":"deeppass","status":"miss"}, tried

In [None]:
#@title ▶️ Run (first pass) — with seed-origin shortcut + result cache
from collections import defaultdict

# One result column (chain alias) per seed accession.
seed_accs = SEED_LIST
chain_names = CHAIN_ALIASES
assert len(seed_accs) == len(chain_names)

# Ensure we have seed origins (seed accession -> taxid)
try:
    SEED_ORIGINS
except NameError:
    # Fallback: compute if not present in this runtime (uses your helper)
    SEED_ORIGINS = {acc: up_taxid_of_acc(acc) for acc in seed_accs}

methods_log = []  # one dict per logged event; written to the *_methods.csv file
rows = []         # one dict per species; becomes the results table

# Cache results of firstpass to avoid re-running identical (seed, taxid)
_firstpass_cache = {}  # (seed, taxid) -> (acc, info, tried)

for sp, taxid in ALL_SPECIES:
    if not taxid:
        print(f"[SKIP] No taxid for {sp}")
        continue

    row = {"species_name": sp, "taxid": taxid}

    for seed, chain in zip(seed_accs, chain_names):
        # 1) Seed-origin shortcut
        # A seed that itself belongs to this species fills its own cell without any search.
        origin_tid = SEED_ORIGINS.get(seed)
        if origin_tid and int(origin_tid) == int(taxid):
            row[chain] = seed
            methods_log.append({
                "species_name": sp, "taxid": taxid, "seed": seed, "chain": chain,
                "found_acc": seed, "method": "seed-origin", "status": "origin"
            })
            continue

        # 2) Cached firstpass result?
        key = (seed, int(taxid))
        if key in _firstpass_cache:
            acc, info, tried = _firstpass_cache[key]
        else:
            acc, info, tried = find_ortholog_firstpass(seed, taxid, sp)
            _firstpass_cache[key] = (acc, info, tried)

        # 3) Write cell + logging
        # The final outcome is logged first, followed by one row per attempted
        # method, so the successful method appears twice in the log.
        row[chain] = acc or ""
        methods_log.append({
            "species_name": sp, "taxid": taxid, "seed": seed, "chain": chain,
            "found_acc": acc or "", "method": info.get("method","none"),
            "status": info.get("status","miss"),
            **{k:v for k,v in info.items() if k not in ("method","status")}
        })
        for step in tried:
            methods_log.append({
                "species_name": sp, "taxid": taxid, "seed": seed, "chain": chain,
                "found_acc": acc or "", "method": step.get("method",""),
                "status": step.get("status",""),
                **{k:v for k,v in step.items() if k not in ("method","status")}
            })

    rows.append(row)

# Build and save outputs
# File names carry a timestamp, so each run writes a new pair of CSV files.
df = pd.DataFrame(rows, columns=["species_name","taxid"] + chain_names)
ts = pd.Timestamp.now().strftime('%Y%m%d-%H%M%S')
out_csv = f"{OUTPUT_BASENAME}_{ts}.csv"
out_methods = f"{OUTPUT_BASENAME}_{ts}_methods.csv"
df.to_csv(out_csv, index=False)
pd.DataFrame(methods_log).to_csv(out_methods, index=False)

print("\nSaved:")
print(" - Results:", out_csv)
print(" - Methods:", out_methods)
# Trailing expression: the notebook displays the results table.
df


In [None]:
#@title 🔁 Deep Search (Repeatable)
# Upper bound on the number of deep rounds; the loop stops early after a round without fills.
DEEP_SEARCH_ROUNDS = 1     #@param {type:"integer"}
# Maximum number of same-column accessions tried as alternative queries (values <= 0 disable the cap).
DEEP_MAX_COL_CAND = 32     #@param {type:"integer"}
DEEP_VERBOSE = True        #@param {type:"boolean"}
DEEP_RESET_CACHE = True   #@param {type:"boolean"}

# Persist cache across runs unless reset
# With DEEP_RESET_CACHE left at True the cache is rebuilt on every execution of this cell.
if DEEP_RESET_CACHE or ("_deep_cache" not in globals()):
    _deep_cache = {}  # (query_acc, taxid) -> (hit_acc | None, info_dict)

def _nonempty_vals(series):
    """Return the distinct non-blank strings of ``series`` in first-seen order.

    Values are read through ``series.tolist()``; non-string values are
    skipped and strings are stripped before comparison.
    """
    stripped = (value.strip() for value in series.tolist() if isinstance(value, str))
    # dict.fromkeys keeps insertion order and drops repeats.
    return list(dict.fromkeys(text for text in stripped if text))

def _gather_same_chain_candidates(df, chain_col, seed_acc_for_chain):
    """Collect the query accessions to try for one chain column.

    Takes the distinct non-blank values of ``df[chain_col]``, keeps at most
    ``DEEP_MAX_COL_CAND`` of them when that cap is positive, and appends the
    chain's seed accession if it is not already in the list.
    """
    candidates = _nonempty_vals(df[chain_col])
    if DEEP_MAX_COL_CAND > 0:
        del candidates[DEEP_MAX_COL_CAND:]
    if seed_acc_for_chain and seed_acc_for_chain not in candidates:
        candidates.append(seed_acc_for_chain)
    return candidates

# -------- pretty-print helpers --------
def _fmt_metrics(info: dict) -> str:
    """Format the alignment metrics found in ``info`` as one space-separated string.

    Considers ``evalue``, ``bits``, ``pident``, ``qcov`` and ``tcov`` in that
    order. Keys that are missing, ``None``, empty or not convertible to float
    are left out. Returns an empty string when nothing is printable.
    """
    layout = (
        ("evalue", "e={:.3e}"),
        ("bits",   "bits={:.1f}"),
        ("pident", "pident={:.1f}"),
        ("qcov",   "qcov={:.3f}"),
        ("tcov",   "tcov={:.3f}"),
    )
    pieces = []
    for key, template in layout:
        raw = info.get(key)
        if raw is None or raw == "":
            continue
        try:
            pieces.append(template.format(float(raw)))
        except Exception:
            continue
    return " ".join(pieces)

def _label_method(m):
    """Return the display label for a method key.

    "mmseqs2" and "foldseek" (any letter case) map to their branded spelling,
    a missing/empty key maps to "MMseqs2", and any other key is returned in
    lower case.
    """
    key = (m or "").lower()
    pretty = {"mmseqs2": "MMseqs2", "foldseek": "Foldseek"}
    if key in pretty:
        return pretty[key]
    return key if key else "MMseqs2"

# ----------------------------------------------------

def _deep_try_fill_cell(df, r, chain_col, species_name, taxid, seed_acc_for_chain):
    """Try to fill the empty cell ``df.at[r, chain_col]`` with a deep-pass hit.

    Each candidate query (accessions already present in the same column, plus
    the chain's seed) is run through ``find_ortholog_deeppass`` against
    ``taxid``. Results are memoised in ``_deep_cache`` per (query, taxid).
    The first hit is written into ``df``, logged to ``methods_log`` and ends
    the search.

    Returns:
        True when the cell was filled by this call, False when it already held
        a value or no candidate produced a hit.
    """
    current = df.at[r, chain_col]
    if isinstance(current, str) and current.strip():
        return False  # already filled

    queries = _gather_same_chain_candidates(df, chain_col, seed_acc_for_chain)
    if DEEP_VERBOSE:
        print(f"[DEEP] {species_name}({taxid}) {chain_col} will try {len(queries)} candidate(s)")

    metric_keys = ("evalue","bits","pident","qcov","tcov","alnlen","qlen","tlen","time_s")
    for q in queries:
        cache_key = (q, int(taxid))
        if cache_key not in _deep_cache:
            hit, details, _ = find_ortholog_deeppass(q, int(taxid), species_name)
            _deep_cache[cache_key] = (hit, details)
        acc, info = _deep_cache[cache_key]

        metrics = _fmt_metrics(info)
        method_label = _label_method(info.get("method"))

        if not acc:
            print(f"[MISS] {species_name}({taxid}) {seed_acc_for_chain} via {q}")
            continue

        df.at[r, chain_col] = acc
        print(f"[OK] {species_name}({taxid}) {seed_acc_for_chain} via {q} ← {method_label} {acc}"
              + (f" {metrics}" if metrics else ""))
        log_row = {
            "species_name": species_name, "taxid": int(taxid),
            "seed": seed_acc_for_chain, "chain": chain_col,
            "found_acc": acc, "method": info.get("method",""),
            "status": info.get("status",""), "deep_round": True,
            "via": q
        }
        # Copy over whichever alignment metrics the method reported.
        log_row.update({k: info[k] for k in metric_keys if k in info})
        methods_log.append(log_row)
        return True

    return False

# ---- Run deep rounds ----
# Each round visits every (species row, chain) cell and tries to fill the empty ones.
# Cells filled in one round become extra query candidates for later cells and rounds.
fills_total = 0
for rnd in range(1, int(DEEP_SEARCH_ROUNDS) + 1):
    fills_this = 0
    if DEEP_VERBOSE:
        print(f"\n=== Deep round {rnd} ===")
    # r is the position from enumerate and is passed to df.at as the row label.
    for r, (sp, tid) in enumerate(zip(df["species_name"], df["taxid"])):
        if not tid:
            continue
        for seed, chain in zip(SEED_LIST, CHAIN_ALIASES):
            if _deep_try_fill_cell(df, r, chain, sp, tid, seed):
                fills_this += 1

    fills_total += fills_this
    if DEEP_VERBOSE:
        print(f"[Deep] round {rnd} filled {fills_this}")
    if fills_this == 0:
        # A round without any fill means further rounds cannot change anything.
        if DEEP_VERBOSE:
            print("[Deep] No more fills.")
        break

# ---- Save merged results ----
# New timestamped files are written; out_csv / out_methods now point at them.
ts = pd.Timestamp.now().strftime('%Y%m%d-%H%M%S')
out_csv = f"{OUTPUT_BASENAME}_{ts}.csv"
out_methods = f"{OUTPUT_BASENAME}_{ts}_methods.csv"
df.to_csv(out_csv, index=False)
pd.DataFrame(methods_log).to_csv(out_methods, index=False)
print("\nSaved (first + deep):")
print(" - Results:", out_csv)
print(" - Methods:", out_methods)
# Trailing expression: the notebook displays the merged table.
df


In [None]:
#@title 📥 Download results
from google.colab import files, runtime
import os
# Offer both output files for browser download; a name that was never set,
# or a file that does not exist, is skipped.
for _var in ("out_csv", "out_methods"):
    if _var in globals() and os.path.exists(globals()[_var]):
        files.download(globals()[_var])
