In [11]:
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
import re
import pandas as pd
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord


@dataclass
class FastaAnnotator:
    """
    Reusable annotator for aligned FASTA files that:
      1) Normalizes record IDs (uppercase; canonical accession tokens).
      2) Splits sequences into two groups:
         - "NCBI side" (non-Whitman): merged to a TSV of NCBI metadata by accession.
         - "Whitman side" (SNVT-tagged): matched to Excel metadata by a canonical sample ID.
      3) Normalizes and attaches dates (keeps sequences even if date missing).
      4) Composes final FASTA headers with a user-defined template and writes output.

    Parameters
    ----------
    fasta_aligned_path : str
        Path to the aligned FASTA.
    ncbi_metadata_path : str
        Path to a TSV with an 'accession' column (case-insensitive), plus any other fields.
    submitter_col : str
        Column name in the NCBI TSV used for state mapping (column lookup is
        case-insensitive; values are matched by case-insensitive literal substring).
    state_map : Dict[str, str]
        Mapping from submitter name substrings to USA 2-letter state codes (e.g., {"HJELLE":"NM"}).
        Substrings are matched literally, not as regular expressions; when several match,
        the later entry wins.
    whitman_sources : Sequence[Tuple[str, str, str, str]]
        Sequence of (excel_path, sheet_name, sample_col, date_col) for Whitman metadata inputs.
    whitman_us_state : str
        Two-letter state code to assign to Whitman sequences (default: "WA").
    header_template : str
        str.format template for the final FASTA header. Available fields:
        {sample}, {usa}, {collection_date}, {accession}.
    sequence_colname : str
        Column name for sequences in the intermediate alignment DataFrame. The annotated
        output (``combined_df``) always exposes sequences as 'sequence'.

    Notes
    -----
    • Only the FASTA record ID is used (header text before the first whitespace, as parsed
      by Biopython); any description after it is ignored.
    • Whitman detection is case-insensitive for the token 'SNVT' anywhere in the record ID.
    • Canonical NCBI accession token is derived from the record ID:
        - Take ID → UPPERCASE
        - Split on space or pipe → take first segment
        - Strip trailing version suffix: .<digits>
      The TSV 'accession' is matched with its version suffix stripped too; the TSV's own
      accession value is what appears in the output.
    • Whitman sample IDs are extracted from headers matching:
        ..._SNVT_<PREFIX>_<NUM>[_FILLED]  →  PREFIX<NUM>, with PEMA→PESO normalization.
      Fallbacks also support WHIT#, EP#, or last-token heuristics (PEMA→PESO also applied).
    • Whitman metadata uses the same sample-extraction logic as the FASTA. Integral float IDs
      (e.g. 230.0) are read as "230", blank IDs are dropped, and duplicate samples are
      collapsed to one row carrying the first non-missing date.
    """

    # Inputs
    fasta_aligned_path: str
    ncbi_metadata_path: str
    submitter_col: str = "submitters"
    state_map: Dict[str, str] = field(default_factory=dict)

    # Whitman metadata sources as (file_path, sheet_name, sample_col, date_col)
    whitman_sources: Sequence[Tuple[str, str, str, str]] = field(default_factory=list)

    # Configuration
    whitman_us_state: str = "WA"
    header_template: str = "{sample}|{usa}|{collection_date}"
    sequence_colname: str = "sequence"

    # Internals (populated by annotate())
    # aln_dict keys: upper-cased FASTA record IDs; rec.name holds the canonical accession token.
    # Annotation quoted so the class can be defined even before SeqRecord is importable.
    aln_dict: "Dict[str, SeqRecord]" = field(init=False, default_factory=dict)
    ncbi_df: pd.DataFrame = field(init=False, default_factory=pd.DataFrame)         # loaded NCBI TSV
    merged_ncbi_df: pd.DataFrame = field(init=False, default_factory=pd.DataFrame)  # NCBI-side rows
    whitman_df: pd.DataFrame = field(init=False, default_factory=pd.DataFrame)      # Whitman-side rows
    combined_df: pd.DataFrame = field(init=False, default_factory=pd.DataFrame)     # final rows

    # Output schema shared by both builders (unannotated, so not a dataclass field).
    _OUTPUT_COLUMNS = ("accession", "sequence", "usa", "collection_date", "sample")

    # ----------------------------
    # Public API
    # ----------------------------

    def annotate(self) -> pd.DataFrame:
        """
        Run the full pipeline and return the combined annotated DataFrame.

        Side effects: sets ``aln_dict``, ``ncbi_df``, ``merged_ncbi_df`` (NCBI-side rows),
        ``whitman_df`` (Whitman-side rows) and ``combined_df`` (returned frame, which also
        carries the composed 'header' column).
        """
        self._load_alignment()
        self._load_ncbi_metadata()

        ncbi_side = self._build_ncbi_side()
        whitman_side = self._build_whitman_side()
        self.merged_ncbi_df = ncbi_side
        self.whitman_df = whitman_side

        # Concatenating empty frames is deprecated in pandas >= 2.1 and can degrade dtypes.
        parts = [part for part in (ncbi_side, whitman_side) if not part.empty]
        if parts:
            combined = pd.concat(parts, ignore_index=True, sort=False)
        else:
            combined = pd.DataFrame(columns=list(self._OUTPUT_COLUMNS))

        # Normalize dates (idempotent) and force datetime64 so header rendering can use `.dt`.
        combined["collection_date"] = pd.to_datetime(
            combined["collection_date"].apply(self._normalize_collection_date), errors="coerce"
        )
        combined = self._compose_headers(combined)

        self.combined_df = combined
        return combined

    def write_fasta(self, output_path: str) -> None:
        """
        Write annotated sequences to FASTA, one record per row of ``combined_df``.

        The composed 'header' becomes the record ID (empty description). Sequences are read
        from the 'sequence' column, which both builders always produce regardless of
        ``sequence_colname``.

        Raises
        ------
        RuntimeError
            If ``combined_df`` is empty (annotate() not run, or it matched nothing).
        """
        if self.combined_df.empty:
            raise RuntimeError("No data to write. Run annotate() first.")
        records = [
            SeqRecord(Seq(str(seq)), id=hdr, description="")
            for hdr, seq in zip(self.combined_df["header"], self.combined_df["sequence"])
        ]
        with open(output_path, "w") as handle:
            SeqIO.write(records, handle, "fasta")

    # ----------------------------
    # Stage 0: I/O + Normalization
    # ----------------------------

    def _load_alignment(self) -> None:
        """
        Load the aligned FASTA into ``self.aln_dict``.

          • KEY      : record ID (header text before the first whitespace), UPPERCASED
          • rec.id   : same as the key
          • rec.name : canonical accession token for NCBI merges

        Raises
        ------
        ValueError
            If two records share an ID after upper-casing (exact duplicates were already
            rejected by ``SeqIO.to_dict``; case-only duplicates used to overwrite silently).
        """
        cleaned: Dict[str, SeqRecord] = {}
        for rec in SeqIO.parse(self.fasta_aligned_path, "fasta"):
            key = str(rec.id).upper()  # global case normalization
            if key in cleaned:
                raise ValueError(f"Duplicate FASTA record ID after upper-casing: {key!r}")
            # Canonical accession token: first segment before space/pipe, minus '.<digits>'.
            first = key.split(" ", 1)[0].split("|", 1)[0]
            token = re.sub(r"\.\d+$", "", first)  # e.g., OQ999106.1 → OQ999106
            cleaned[key] = SeqRecord(rec.seq, id=key, name=token, description="")

        self.aln_dict = cleaned

    def _load_ncbi_metadata(self) -> None:
        """
        Load the NCBI metadata TSV into ``self.ncbi_df``.

        Column names are stripped and lower-cased; 'accession' values are stringified,
        stripped and upper-cased.

        Raises
        ------
        ValueError
            If no 'accession' column is present.
        """
        df = pd.read_csv(self.ncbi_metadata_path, sep="\t")
        df.columns = [str(col).strip().lower() for col in df.columns]
        if "accession" not in df.columns:
            raise ValueError("NCBI metadata must contain an 'accession' column.")
        df["accession"] = df["accession"].astype(str).str.strip().str.upper()
        self.ncbi_df = df

    # ----------------------------
    # Stage 1: Frame builders
    # ----------------------------

    def _df_from_alignment(self, keep_if: Callable[[str], bool]) -> pd.DataFrame:
        """
        Build a frame of alignment records whose key passes ``keep_if(key)``.

        Columns: 'header' (upper-cased record ID), 'accession_token' (canonical accession)
        and ``sequence_colname`` (sequence as str). The columns are always present, even
        when no record passes, so downstream merges do not fail on a missing key column.
        """
        columns = ["header", "accession_token", self.sequence_colname]
        rows = [
            {"header": key, "accession_token": rec.name, self.sequence_colname: str(rec.seq)}
            for key, rec in self.aln_dict.items()
            if keep_if(key)
        ]
        return pd.DataFrame(rows, columns=columns)

    def _build_ncbi_side(self) -> pd.DataFrame:
        """
        Non-Whitman sequences: inner-merge with NCBI metadata by canonical accession.

        The TSV accession is matched with any trailing '.<digits>' version removed, so
        'OQ999106.1' in the TSV matches the 'OQ999106' token from the FASTA; the TSV's own
        accession value is reported. Sequences without a TSV row are dropped.
        Adds 'usa' via submitter mapping (if provided). 'sample' defaults to the accession
        when the TSV has no 'sample' column or leaves it blank.
        Produces columns: ['accession', 'sequence', 'usa', 'collection_date', 'sample'].
        """
        aln_ncbi = self._df_from_alignment(lambda k: not self._is_whitman(k))

        meta = self.ncbi_df.assign(
            _accession_key=self.ncbi_df["accession"].astype(str).str.replace(r"\.\d+$", "", regex=True)
        )
        merged = pd.merge(meta, aln_ncbi, left_on="_accession_key", right_on="accession_token", how="inner")
        merged = merged.drop(columns="_accession_key")

        # Column names were lower-cased on load, so look the submitter column up the same way.
        out = self._map_submitters_to_state(merged, self.submitter_col.strip().lower(), self.state_map)

        if "collection_date" in out.columns:
            out["collection_date"] = out["collection_date"].apply(self._normalize_collection_date)
        else:
            out["collection_date"] = pd.NaT

        if "sample" in out.columns:
            out["sample"] = out["sample"].fillna(out["accession"])
        else:
            out["sample"] = out["accession"]

        if "usa" not in out.columns:
            out["usa"] = pd.NA

        out = out.rename(columns={self.sequence_colname: "sequence"})
        return out[list(self._OUTPUT_COLUMNS)].copy()

    def _build_whitman_side(self) -> pd.DataFrame:
        """
        Whitman sequences (record IDs containing 'SNVT', case-insensitive):
          • Derive 'sample' from the ID via ``_sample_extractor`` (PEMA→PESO normalized).
          • LEFT join to Whitman metadata (rows are kept even when date/metadata is missing).
          • Assign usa = whitman_us_state.
        Produces columns: ['accession', 'sequence', 'usa', 'collection_date', 'sample'].
        """
        aln_whit = self._df_from_alignment(self._is_whitman)
        if aln_whit.empty:
            return pd.DataFrame(columns=list(self._OUTPUT_COLUMNS))

        whit = aln_whit.assign(
            sample=aln_whit["header"].apply(self._sample_extractor),
            usa=self.whitman_us_state,
        )

        # Metadata has one row per sample, so this LEFT join cannot duplicate sequences.
        meta_whit = self._load_whitman_metadata()
        whit = pd.merge(whit, meta_whit, on="sample", how="left")

        whit = whit.rename(columns={"accession_token": "accession", self.sequence_colname: "sequence"})
        return whit[list(self._OUTPUT_COLUMNS)].copy()

    # ----------------------------
    # Stage 2: Metadata loaders and utilities
    # ----------------------------

    def _load_whitman_metadata(self) -> pd.DataFrame:
        """
        Load and harmonize Whitman metadata to one row per sample: ['sample', 'collection_date'].

        • Each (path, sheet, sample_col, date_col) source contributes its two columns.
        • Sample IDs are stringified by ``_sample_raw_to_str``. Integral floats such as
          230.0 become "230", because pandas upcasts integer columns with blanks to float.
          Blank IDs are dropped. The SAME extractor used for FASTA headers is then applied;
          it upper-cases internally.
        • Dates go through ``_normalize_collection_date``; missing dates are kept.
        • Duplicate samples (e.g. the same animal listed in several sheets) collapse to one
          row holding the first non-missing date.
        """
        raw_pairs: List[Tuple[object, object]] = []
        for path, sheet, sample_col, date_col in self.whitman_sources:
            sheet_df = pd.read_excel(path, sheet_name=sheet)
            raw_pairs.extend(zip(sheet_df[sample_col], sheet_df[date_col]))

        records: List[Dict[str, object]] = []
        for raw_id, raw_date in raw_pairs:
            sample_str = self._sample_raw_to_str(raw_id)
            if sample_str is None:
                continue  # a blank ID cell cannot match any sequence
            records.append({
                "sample": self._sample_extractor(sample_str),
                "collection_date": self._normalize_collection_date(raw_date),
            })

        meta = pd.DataFrame(records, columns=["sample", "collection_date"])
        meta["collection_date"] = pd.to_datetime(meta["collection_date"], errors="coerce")
        # GroupBy.first() skips missing values, so a dated duplicate wins over an undated one.
        return meta.groupby("sample", sort=False, as_index=False)["collection_date"].first()

    @staticmethod
    def _sample_raw_to_str(value) -> Optional[str]:
        """
        Stringify a raw metadata sample ID, or return None when it is blank/missing.

        Integral floats lose their '.0' (230.0 -> "230"). Otherwise the extractor's
        non-alphanumeric stripping would turn them into "2300".
        """
        if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
            return None
        if isinstance(value, float) and value.is_integer():
            return str(int(value))
        text = str(value).strip()
        return text or None

    @staticmethod
    def _map_submitters_to_state(df: pd.DataFrame,
                                 submitter_col: str,
                                 state_map: Dict[str, str]) -> pd.DataFrame:
        """
        Return ``df`` with a 'usa' column filled from case-insensitive substring matches
        of ``state_map`` keys in ``submitter_col``.

        ``df`` is returned unchanged when ``state_map`` is empty or the column is absent;
        otherwise a copy is modified, never the caller's frame. Later map entries overwrite
        earlier ones for rows matching several names.
        """
        if not state_map or submitter_col not in df.columns:
            return df
        df = df.copy()
        if "usa" not in df.columns:
            df["usa"] = pd.NA

        submitters = df[submitter_col].astype(str).str.upper()
        for name, state in state_map.items():
            # regex=False: names like "SMITH (CDC)" or "J.DOE" must match literally.
            mask = submitters.str.contains(str(name).upper(), regex=False, na=False)
            df.loc[mask, "usa"] = state
        return df

    # ----------------------------
    # Stage 3: Parsing / Normalization helpers
    # ----------------------------

    @staticmethod
    def _is_whitman(header: str) -> bool:
        """Detect Whitman entries by presence of 'SNVT' (case-insensitive)."""
        return "SNVT" in str(header).upper()

    @staticmethod
    def _normalize_collection_date(val: Union[str, pd.Timestamp, None]) -> Optional[pd.Timestamp]:
        """
        Normalize a collection date to a tz-naive midnight Timestamp, or NaT.

          - missing (None, NaN, NaT, pd.NA, '', 'nan')  -> NaT
          - Timestamp / datetime-like                    -> its calendar date
          - 'YYYY' (or integral float 1999.0)            -> YYYY-06-01
          - 'YYYY-MM'                                    -> YYYY-MM-01
          - 'YYYY-MM-DD'                                 -> as-is
          - otherwise: pandas-parsed date, NaT if unparseable
        """
        if val is None or (pd.api.types.is_scalar(val) and pd.isna(val)):
            return pd.NaT
        if isinstance(val, pd.Timestamp):
            return pd.Timestamp(val.date())
        # Year-only columns with blanks are upcast to float by pandas: 1999.0 -> "1999".
        if isinstance(val, float) and val.is_integer():
            val = int(val)

        s = str(val).strip()
        if not s or s.lower() in {"nan", "nat", "<na>", "none"}:
            return pd.NaT

        if re.fullmatch(r"\d{4}", s):
            s = f"{s}-06-01"
        elif re.fullmatch(r"\d{4}-\d{2}", s):
            s = f"{s}-01"
        ts = pd.to_datetime(s, errors="coerce")
        # Drop time-of-day/timezone so every branch yields the same kind of value.
        return pd.NaT if pd.isna(ts) else pd.Timestamp(ts.date())

    @staticmethod
    def _safe_str(x) -> str:
        """Render a scalar for header templates; missing values become ''."""
        return "" if pd.isna(x) else str(x)

    def _compose_headers(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Return a copy of ``df`` with a 'header' column rendered from ``header_template``.

        Template fields: {sample}, {usa}, {collection_date} (ISO 'YYYY-MM-DD'), {accession}.
        Missing values render as empty strings.
        """
        df = df.copy()
        # Coerce first: `.dt` raises on object columns (e.g. an empty frame).
        date_str = pd.to_datetime(df["collection_date"], errors="coerce").dt.strftime("%Y-%m-%d")
        accessions = df["accession"] if "accession" in df.columns else [pd.NA] * len(df)
        df["header"] = [
            self.header_template.format(
                sample=self._safe_str(s),
                usa=self._safe_str(u),
                collection_date=self._safe_str(d),
                accession=self._safe_str(a),
            )
            for s, u, d, a in zip(df["sample"], df["usa"], date_str, accessions)
        ]
        return df

    # ----------------------------
    # Sample extraction logic
    # ----------------------------

    @staticmethod
    def _sample_extractor(header_or_sample: str) -> str:
        """
        Derive a canonical sample ID from either a FASTA header or a raw metadata sample string.
        Priority patterns (all case-insensitive; input is uppercased internally):
          1) ..._SNVT_<PREFIX>_<NUM>[_FILLED]  →  PREFIX<NUM>  (with PEMA→PESO normalization)
          2) WHIT<NUM>                          →  WHIT<NUM>   (no leading zeros)
          3) EP<NUM>                            →  <NUM>       (no leading zeros)
          4) Fallback: last '_' token (alnum only), with a leading PEMA rewritten to PESO
        """
        h = str(header_or_sample).upper().strip()

        # Prefer SNVT pattern (allow optional trailing _FILLED)
        m = re.search(r"_SNVT_([A-Z]+)_(\d+)(?:_FILLED)?$", h)
        if m:
            prefix, num = m.group(1), int(m.group(2))
            if prefix == "PEMA":  # normalize at the sample level (safe)
                prefix = "PESO"
            return f"{prefix}{num}"

        # WHIT<NUM>
        # NOTE(review): \d{1,4} caps the number at four digits (WHIT12345 -> WHIT1234); confirm intended.
        m = re.search(r"(WHIT)0?(\d{1,4})", h)
        if m:
            return f"WHIT{int(m.group(2))}"

        # EP<NUM> → numeric only, leading zeros dropped like the other patterns
        m = re.search(r"(EP)0?(\d+)", h)
        if m:
            return str(int(m.group(2)))

        # Conservative fallback: last '_' token stripped to [A-Z0-9]
        token = re.sub(r"[^A-Z0-9]+", "", h.split("_")[-1])
        # Keep metadata IDs such as 'PEMA261' consistent with SNVT headers (→ 'PESO261').
        if token.startswith("PEMA"):
            token = "PESO" + token[len("PEMA"):]
        return token

In [12]:
# Run configuration: aligned input FASTA, NCBI metadata table, and annotated output FASTA.
FASTA_ALIGNED = "../Avail_seqs/M/SNV_M_NCBI_MEZAP.fasta"
NCBI_METADATA_TSV = "../Avail_seqs/ncbi_data.tsv"
OUTPUT_FASTA = "../Avail_seqs/ncbi_whitman_aligned.fasta"

# Submitter-name substring -> two-letter state code for NCBI records.
SUBMITTER_STATE_MAP = {
    "GOODFELLOW": "NM",
    "HECHT": "AZ",
    "HJELLE": "NM",
    "BOTTEN": "NM",
    "SPIROPOULOU": "NM",
}

# (excel_path, sheet_name, sample_col, date_col) for each Whitman metadata sheet.
WHITMAN_SOURCES = [
    ("../Avail_seqs/rodent_sero_ELISAdata.xlsx", "RNAextractionLung", "Individual", "Date Sampled"),
    ("../Avail_seqs/WHIT.Samples.EEIDP.xlsx", "3.3.25 - RNA Extraction", "Sample ID Number", "Collection Date"),
]

annot = FastaAnnotator(
    fasta_aligned_path=FASTA_ALIGNED,
    ncbi_metadata_path=NCBI_METADATA_TSV,
    submitter_col="submitters",
    state_map=SUBMITTER_STATE_MAP,
    whitman_sources=WHITMAN_SOURCES,
    whitman_us_state="WA",
    header_template="{sample}|{usa}|{collection_date}",
)
df = annot.annotate()
annot.write_fasta(OUTPUT_FASTA)

df

Unnamed: 0,accession,sequence,usa,collection_date,sample,header
0,OQ999106,ATGGTAGGGTGGGTTTGCATCTTCCTCGTGGTCCTTACTACTGCAA...,CA,1999-04-29,OQ999106,OQ999106|CA|1999-04-29
1,OQ999109,ATGGTAGGGTGGGTTTGCATCTTCCTCGTGGTCCTTACCACTGCAA...,CA,1999-06-21,OQ999109,OQ999109|CA|1999-06-21
2,OQ999111,----------------------------------------------...,CA,1999-04-14,OQ999111,OQ999111|CA|1999-04-14
3,OQ999114,ATGGTAGGGTGGGTTTGCATCTTCCTCGTGGTCCTTACTACTGCAA...,CA,2000-03-04,OQ999114,OQ999114|CA|2000-03-04
4,OQ999119,ATGGTAGGGTGGGTTTGCATCTTCCTCGTGGTCCTTACTACTGCAA...,CO,2011-03-28,OQ999119,OQ999119|CO|2011-03-28
...,...,...,...,...,...,...
66,2DZNTT_6_6_SNVT_PEMA_230_FILLED,ATGGTAGGGTGGGTTTGCATCTTCCTCGTGGTCCTTACTACTGCAA...,WA,NaT,PESO230,PESO230|WA|
67,YMD9RR_2_SNVT_PEMA_261_FILLED,ATGGTAGGGTGGGTTTGCATCTTCCTCGTGGTCCTTACTACTGCAA...,WA,NaT,PESO261,PESO261|WA|
68,YMD9RR_3_SNVT_PEMA_287_FILLED,ATGGTAGGGTGGGTTTGCATCTTCCTCGTGGTCCTTACTACTGCAA...,WA,NaT,PESO287,PESO287|WA|
69,YMD9RR_4_SNVT_PEMA_295_FILLED,ATGGTAGGGTGGGTTTGCATCTTCCTCGTGGTCCTTACTACTGCAA...,WA,NaT,PESO295,PESO295|WA|
