# Peak-to-Gene and TSS Annotation  
This notebook loads peak tables (`*AllPeaks.tsv`), a genome annotation (`.gff3`),  
and strand-specific TSS maps (`TSS_plus.bed`, `TSS_minus.bed`) to assign each peak  
to the gene(s) overlapping its center and to its nearest transcription start site (TSS).  

Using `IntervalTree` overlap queries and signed TSS distances (negative = upstream, positive = downstream of the TSS in the direction of transcription), the notebook classifies peaks into:  

- **promoter**  
- **genic**  
- **downstream**  
- **distal**  
- **genic_noTSS**  
- **unknown**  

Outputs include:  

- An Excel file with:  
  - `Merged` — peaks mapped to a unique gene  
  - `Ambiguous_Unmapped` — multi-gene or unassigned peaks  
  - `QC` — summary statistics  
  - `QC_region_class` — counts per region category  

This notebook is designed to run in **Google Colab** and to be **reproducible**: set `ROOT_DIR` in section 1  
to point at the local repository files, or at a small mock dataset for testing.


In [None]:
#@title 0. Install dependencies and import libraries

!pip install intervaltree --quiet

from pathlib import Path
import pandas as pd
import numpy as np
from intervaltree import Interval, IntervalTree
from bisect import bisect_left
import re


In [None]:
#@title 1. Configure paths and basic parameters

# Project layout. Point ROOT_DIR at your own folder (for example a mounted
# Google Drive directory); every other path below is derived from it.
ROOT_DIR = Path("/content/project_demo")  # <-- CHANGE THIS

# Inputs
DATA_DIR = ROOT_DIR / "data"
ALLPEAKS_DIR = DATA_DIR / "allpeaks"
GFF_FILE = DATA_DIR.joinpath("genome", "annotation.gff3")
TSS_PLUS_BED = DATA_DIR.joinpath("tss", "TSS_plus.bed")
TSS_MINUS_BED = DATA_DIR.joinpath("tss", "TSS_minus.bed")

# Outputs: the folder is created up front so later cells can write into it.
OUTPUT_DIR = ROOT_DIR / "results"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Classification windows in bp, relative to the nearest TSS (tunable).
PROMOTER_UPSTREAM = 200    # upstream distance still counted as "promoter"
PROMOTER_DOWNSTREAM = 50   # downstream distance still counted as "promoter"
DOWNSTREAM_MAX = 1000      # farthest downstream distance still called "downstream"


In [None]:
#@title 2. Utility functions for GFF, IntervalTrees, and TSS index

def parse_gff_attributes(attr_str: str) -> dict:
    """
    Convert a GFF attribute column (column 9) into a simple dict.

    Two field syntaxes are accepted, separated by ';':
      - GFF3:      ``key=value``
      - GTF/GFF2:  ``key "value"`` (space separator, quoted value)

    Surrounding double quotes are removed from values. GFF3
    percent-encoding (e.g. ``%3B`` for ';', ``%2C`` for ',') is decoded;
    decoding happens after splitting on ';', so encoded separators survive.
    Fields with neither '=' nor a space are skipped. If a key repeats, the
    last occurrence wins.

    Parameters
    ----------
    attr_str : str
        Raw attribute string.

    Returns
    -------
    dict
        Mapping of attribute key -> decoded value (both str).
    """
    from urllib.parse import unquote

    attrs = {}
    for field in attr_str.split(";"):
        field = field.strip()
        if not field:
            continue
        if "=" in field:
            k, v = field.split("=", 1)
        elif " " in field:
            # GTF / GFF2 use a space instead of '=' and quote the value
            k, v = field.split(" ", 1)
        else:
            continue
        v = v.strip()
        # Drop GTF-style surrounding quotes: gene_id "abc" -> abc
        if len(v) >= 2 and v[0] == v[-1] == '"':
            v = v[1:-1]
        # GFF3 escapes reserved characters (; = & , tab) as %XX
        attrs[k.strip()] = unquote(v)
    return attrs


def choose_gene_id(attrs: dict) -> str | None:
    """
    Choose a gene identifier from GFF attributes in a priority order.
    """
    for key in ["locus_tag", "gene", "ID", "Name"]:
        if key in attrs:
            return attrs[key]
    return None


def build_gene_intervals(gff_path: Path):
    """
    Read a GFF/GFF3 file and aggregate its ``gene`` and ``CDS`` features
    into one interval per (seqid, gene_id, strand).

    The gene identifier comes from ``choose_gene_id`` applied to the parsed
    attributes. The following lines are skipped:
      - blank lines and '#' lines;
      - lines without exactly 9 tab-separated columns;
      - features without a usable identifier;
      - records whose start is greater than their end.
    Each gene spans min(start)..max(end) over all of its features.
    ``attributes`` keeps the raw attribute string of the first feature seen.

    Parameters
    ----------
    gff_path : Path
        Path to the (uncompressed) annotation file.

    Returns
    -------
    genes_df : pd.DataFrame
        Columns [chrom, gene_id, strand, start, end, attributes] in GFF
        coordinates (1-based, inclusive). If no feature qualified, it is
        empty but still has these columns.
    trees : dict
        chrom -> IntervalTree of half-open intervals [start, end + 1), each
        carrying ``{"gene_id": ..., "strand": ...}`` as data.
    """
    columns = ["chrom", "gene_id", "strand", "start", "end", "attributes"]
    gene_records = {}

    with gff_path.open() as fh:
        for line in fh:
            if not line.strip() or line.startswith("#"):
                continue
            parts = line.rstrip("\n").split("\t")
            if len(parts) != 9:
                continue

            seqid, _source, ftype, start, end, _score, strand, _phase, attrs_str = parts
            if ftype not in ("CDS", "gene"):
                continue

            start = int(start)
            end = int(end)
            if start > end:
                # Malformed record: it would become an empty interval, which
                # IntervalTree rejects with ValueError.
                continue

            gene_id = choose_gene_id(parse_gff_attributes(attrs_str))
            if gene_id is None:
                continue

            key = (seqid, gene_id, strand)
            rec = gene_records.get(key)
            if rec is None:
                gene_records[key] = {
                    "chrom": seqid,
                    "gene_id": gene_id,
                    "strand": strand,
                    "start": start,
                    "end": end,
                    "attributes": attrs_str,
                }
            else:
                rec["start"] = min(rec["start"], start)
                rec["end"] = max(rec["end"], end)

    # Passing `columns` keeps the schema even when no gene was found, so the
    # groupby below (and downstream cells) do not fail with KeyError.
    genes_df = pd.DataFrame(list(gene_records.values()), columns=columns)

    trees = {}
    for chrom, sub in genes_df.groupby("chrom"):
        # GFF ends are inclusive; IntervalTree intervals are half-open.
        trees[chrom] = IntervalTree(
            Interval(int(s), int(e) + 1, {"gene_id": g, "strand": st})
            for s, e, g, st in zip(sub["start"], sub["end"], sub["gene_id"], sub["strand"])
        )

    return genes_df, trees


def load_tss_bed(bed_path: Path) -> pd.DataFrame:
    """
    Read a BED file of TSS features and add a 1-based ``tss_pos`` column.

    Only the first six BED columns (chrom, start, end, name, score, strand)
    are read. BED6+ / BED12 files therefore load correctly instead of having
    their leading columns silently turned into the index.

    Everything after a '#' is treated as a comment. This skips '#' header
    lines, but would also truncate a field that contains '#'.

    ``chrom`` and ``name`` are read as strings, so numeric chromosome names
    (e.g. "1") stay comparable with the string seqids parsed from the GFF.

    BED is 0-based and half-open, so:
      + strand:         tss_pos = start + 1
      any other strand: tss_pos = end

    Returns a DataFrame with columns
    [chrom, start, end, name, score, strand, tss_pos].
    """
    df = pd.read_csv(
        bed_path,
        sep="\t",
        header=None,
        comment="#",
        usecols=[0, 1, 2, 3, 4, 5],
        names=["chrom", "start", "end", "name", "score", "strand"],
        dtype={"chrom": str, "name": str},
    )

    df["tss_pos"] = np.where(df["strand"] == "+", df["start"] + 1, df["end"])
    return df


def build_tss_index(tss_plus_df: pd.DataFrame, tss_minus_df: pd.DataFrame):
    """
    Build a per-(chrom, strand) lookup of TSS positions sorted ascending:

      tss_index[(chrom, strand)] = {
          "pos": np.array([...]),   # tss_pos, ascending
          "name": np.array([...]),  # names aligned with "pos"
      }

    Both tables are concatenated before grouping, and rows are grouped by
    their own ``strand`` value (not by which input they came from). A
    (chrom, strand) pair present in both inputs is therefore merged, instead
    of the minus-table group silently replacing the plus-table group. Rows
    with a missing chrom or strand are dropped by ``groupby``.
    """
    combined = pd.concat([tss_plus_df, tss_minus_df], ignore_index=True)

    tss_index = {}
    for key, sub in combined.groupby(["chrom", "strand"]):
        # Stable sort: equal positions keep input order, so nearest_tss
        # resolves ties deterministically.
        sub_sorted = sub.sort_values("tss_pos", kind="mergesort")
        tss_index[key] = {
            "pos": sub_sorted["tss_pos"].to_numpy(),
            "name": sub_sorted["name"].to_numpy(),
        }

    return tss_index


def nearest_tss(chrom: str, strand: str, pos: int, tss_index: dict):
    """
    Find the TSS closest to ``pos`` on one chromosome and strand.

    Returns ``(tss_name, tss_pos, distance)``. It returns
    ``(None, None, None)`` when ``tss_index`` has no entry for
    (chrom, strand), or when that entry is empty.

    ``distance`` is signed so that positive means downstream of the TSS:
      + strand:         pos - tss_pos
      any other strand: tss_pos - pos
    When two TSSs are equally close, the lower-coordinate one wins.
    """
    key = (chrom, strand)
    if key not in tss_index:
        return None, None, None

    entry = tss_index[key]
    positions, labels = entry["pos"], entry["name"]
    n = len(positions)
    if n == 0:
        return None, None, None

    # `positions` is sorted ascending, so the closest TSS is one of the
    # neighbours around the insertion point of `pos`.
    insert_at = bisect_left(positions, pos)
    window = range(max(insert_at - 1, 0), min(insert_at + 2, n))
    best = min(window, key=lambda j: abs(positions[j] - pos))

    tss_pos = int(positions[best])
    distance = pos - tss_pos if strand == "+" else tss_pos - pos
    return str(labels[best]), tss_pos, distance


In [None]:
#@title 3. Load AllPeaks files and compute PeakCenter

def load_allpeaks_from_dir(allpeaks_dir: Path) -> pd.DataFrame:
    """
    Concatenate every ``*AllPeaks.tsv`` table found directly in a directory.

    Files are read as tab-separated tables in sorted filename order. Each
    table gains two columns:
      - ``PeakCenter``: ``AvgMaxPos`` if present, else ``RegionMiddle``,
        else the midpoint of ``RegionStart`` and ``RegionEnd`` rounded to
        an int;
      - ``source``: the name of the file the row came from.

    Raises FileNotFoundError if no matching file exists.
    """
    peak_files = sorted(allpeaks_dir.glob("*AllPeaks.tsv"))
    if len(peak_files) == 0:
        raise FileNotFoundError(f"No *AllPeaks.tsv files found in {allpeaks_dir}")

    def _read_one(path: Path) -> pd.DataFrame:
        table = pd.read_csv(path, sep="\t")
        for center_col in ("AvgMaxPos", "RegionMiddle"):
            if center_col in table.columns:
                table["PeakCenter"] = table[center_col]
                break
        else:
            # Neither summit nor middle column: fall back to the region midpoint.
            midpoint = (table["RegionStart"] + table["RegionEnd"]) / 2.0
            table["PeakCenter"] = midpoint.round().astype(int)
        table["source"] = path.name
        return table

    return pd.concat([_read_one(p) for p in peak_files], ignore_index=True)


# Stack every AllPeaks table from ALLPEAKS_DIR into one DataFrame and preview it.
allpeaks_df = load_allpeaks_from_dir(allpeaks_dir=ALLPEAKS_DIR)
print("Loaded peaks shape:", allpeaks_df.shape)
allpeaks_df.head()


In [None]:
#@title 4. Load genes (GFF) and TSS (BED) and build data structures

# Aggregated gene intervals from the GFF, plus per-chromosome IntervalTrees.
genes_df, gene_trees = build_gene_intervals(gff_path=GFF_FILE)
print("Aggregated genes:", genes_df.shape)

# Strand-specific TSS tables, then the (chrom, strand) -> sorted TSS index.
tss_plus_df, tss_minus_df = (load_tss_bed(p) for p in (TSS_PLUS_BED, TSS_MINUS_BED))
tss_index = build_tss_index(tss_plus_df, tss_minus_df)

# Notebook preview of the three loaded tables.
genes_df.head(), tss_plus_df.head(), tss_minus_df.head()


In [None]:
#@title 5. Region classification and peak annotation

def classify_region(distance_to_tss: float | None,
                    overlaps_gene: bool,
                    downstream_max: int,
                    prom_upstream: int,
                    prom_downstream: int) -> str:
    """
    Classify the peak region based on distance to the nearest TSS and
    whether the peak overlaps a gene.

    Rules (minimal and generic):
      - "promoter": within [−prom_upstream, 0] or [0, +prom_downstream]
      - "genic": overlaps a gene but not classified as promoter
      - "downstream": (0, downstream_max] if not genic
      - "distal": anything else with a TSS
      - "genic_noTSS": overlaps a gene but no TSS info
      - "unknown": no gene overlap and no TSS info
    """
    if distance_to_tss is None:
        if overlaps_gene:
            return "genic_noTSS"
        else:
            return "unknown"

    d = distance_to_tss

    # upstream promoter (before TSS)
    if -prom_upstream <= d <= 0:
        return "promoter"

    # immediate downstream promoter region
    if 0 <= d <= prom_downstream:
        return "promoter"

    # genic (within gene body, but not promoter)
    if overlaps_gene:
        return "genic"

    # downstream region (further away, but still close)
    if 0 < d <= downstream_max:
        return "downstream"

    return "distal"


def _genes_at(chrom, center: int, gene_trees: dict) -> list:
    """Return (gene_id, strand) for every gene interval containing ``center``,
    ordered by interval start, end, then gene_id so output is reproducible."""
    tree = gene_trees.get(chrom)
    if tree is None:
        return []
    hits = sorted(
        tree.overlap(center, center + 1),
        key=lambda iv: (iv.begin, iv.end, str(iv.data["gene_id"])),
    )
    return [(iv.data["gene_id"], iv.data["strand"]) for iv in hits]


def _nearest_tss_any_strand(chrom, center: int, tss_index: dict):
    """Return the closer of the '+' and '-' nearest-TSS hits as
    (name, pos, signed distance); on equal |distance| the '+' hit is kept.
    Returns (None, None, None) if neither strand has a TSS."""
    best = (None, None, None)
    for strand in ("+", "-"):
        hit = nearest_tss(chrom, strand, center, tss_index)
        if hit[0] is None:
            continue
        if best[0] is None or abs(hit[2]) < abs(best[2]):
            best = hit
    return best


def annotate_peaks(allpeaks_df: pd.DataFrame,
                   genes_df: pd.DataFrame,
                   gene_trees: dict,
                   tss_index: dict,
                   prom_upstream: int = PROMOTER_UPSTREAM,
                   prom_downstream: int = PROMOTER_DOWNSTREAM,
                   downstream_max: int = DOWNSTREAM_MAX) -> pd.DataFrame:
    """
    Annotate each AllPeaks row with overlapping genes, its nearest TSS and a
    region class.

    Parameters
    ----------
    allpeaks_df : pd.DataFrame
        Must have Chromosome, PeakCenter, RegionNumber, RegionStart and
        RegionEnd. All other columns are copied through unchanged.
    genes_df : pd.DataFrame
        Not used here (gene lookups go through ``gene_trees``). Kept so that
        existing calls keep working.
    gene_trees : dict
        chrom -> IntervalTree, as returned by ``build_gene_intervals``.
    tss_index : dict
        (chrom, strand) -> sorted TSS arrays, as returned by
        ``build_tss_index``.
    prom_upstream, prom_downstream, downstream_max : int
        Window sizes forwarded to ``classify_region``.

    Returns
    -------
    pd.DataFrame
        One row per peak with:
          - genes_overlapping: overlapping gene ids joined by ';' ('' if none)
          - n_overlapping_genes
          - primary_gene_id / primary_gene_strand: set only when exactly one
            gene overlaps the peak center
          - nearest_tss_name / nearest_tss_pos / nearest_tss_distance:
            searched on the primary gene's strand when it is '+' or '-',
            otherwise the closer hit over both strands
          - region_class
          - is_ambiguous: more than one gene overlaps
          - is_unmapped: no overlapping gene and no TSS found
          - source, plus all remaining original AllPeaks columns
        A peak with a missing (NaN) PeakCenter cannot be placed. It gets
        PeakCenter None, no gene or TSS, region_class "unknown" and
        is_unmapped True, instead of aborting the whole run.
    """
    output_keys = (
        "RegionNumber", "Chromosome", "RegionStart", "RegionEnd", "PeakCenter",
        "primary_gene_id", "primary_gene_strand", "genes_overlapping",
        "n_overlapping_genes", "nearest_tss_name", "nearest_tss_pos",
        "nearest_tss_distance", "region_class", "is_ambiguous", "is_unmapped",
        "source",
    )
    # The set of pass-through columns is the same for every row; pick it once.
    passthrough_cols = [c for c in allpeaks_df.columns if c not in output_keys]

    records = []
    for _, row in allpeaks_df.iterrows():
        chrom = row["Chromosome"]
        raw_center = row["PeakCenter"]
        center = None if pd.isna(raw_center) else int(raw_center)

        # 1) Genes whose interval contains the peak center
        overlapping_genes = [] if center is None else _genes_at(chrom, center, gene_trees)
        n_overlapping_genes = len(overlapping_genes)
        if n_overlapping_genes == 1:
            primary_gene_id, primary_gene_strand = overlapping_genes[0]
        else:
            primary_gene_id, primary_gene_strand = None, None

        # 2) Nearest TSS: on the unique gene's strand if known, else both strands
        if center is None:
            nearest_name, nearest_pos, nearest_dist = None, None, None
        elif primary_gene_strand in ("+", "-"):
            nearest_name, nearest_pos, nearest_dist = nearest_tss(
                chrom, primary_gene_strand, center, tss_index)
        else:
            nearest_name, nearest_pos, nearest_dist = _nearest_tss_any_strand(
                chrom, center, tss_index)

        # 3) Region classification
        region_class = classify_region(
            nearest_dist,
            n_overlapping_genes > 0,
            downstream_max=downstream_max,
            prom_upstream=prom_upstream,
            prom_downstream=prom_downstream,
        )

        rec = {
            "RegionNumber": row["RegionNumber"],
            "Chromosome": chrom,
            "RegionStart": row["RegionStart"],
            "RegionEnd": row["RegionEnd"],
            "PeakCenter": center,
            "primary_gene_id": primary_gene_id,
            "primary_gene_strand": primary_gene_strand,
            "genes_overlapping": ";".join(g for g, _ in overlapping_genes),
            "n_overlapping_genes": n_overlapping_genes,
            "nearest_tss_name": nearest_name,
            "nearest_tss_pos": nearest_pos,
            "nearest_tss_distance": nearest_dist,
            "region_class": region_class,
            "is_ambiguous": n_overlapping_genes > 1,
            "is_unmapped": n_overlapping_genes == 0 and nearest_name is None,
            "source": row.get("source", ""),
        }
        # 4) Keep every extra column from the original AllPeaks table
        for col in passthrough_cols:
            rec[col] = row[col]

        records.append(rec)

    return pd.DataFrame.from_records(records)


In [None]:
#@title 6. Run annotation and preview annotated peaks

# Annotate every peak, using the window sizes configured in section 1.
annotated_df = annotate_peaks(
    allpeaks_df, genes_df, gene_trees, tss_index,
    prom_upstream=PROMOTER_UPSTREAM,
    prom_downstream=PROMOTER_DOWNSTREAM,
    downstream_max=DOWNSTREAM_MAX,
)

print("Annotated peaks:", annotated_df.shape)
annotated_df.head()


In [None]:
#@title 7. Split into Merged / Ambiguous_Unmapped and build QC tables

# Merged: peaks whose center overlaps exactly one gene (so a primary gene is
# assigned) and that are neither ambiguous nor unmapped.
is_merged = (
    ~annotated_df["is_ambiguous"]
    & ~annotated_df["is_unmapped"]
    & annotated_df["primary_gene_id"].notna()
)
merged_df = annotated_df[is_merged].copy()

# Ambiguous_Unmapped: every peak NOT in Merged. This covers multi-gene peaks,
# fully unmapped peaks, and peaks that have a nearest TSS but no overlapping
# gene. Taking the complement guarantees every peak lands in exactly one of
# the two sheets; selecting only is_ambiguous | is_unmapped would drop the
# TSS-only peaks from the output entirely.
ambiguous_df = annotated_df[~is_merged].copy()

# Basic QC metrics (plain ints so they export cleanly)
total_peaks = len(annotated_df)
n_merged = len(merged_df)
n_ambiguous = int(annotated_df["is_ambiguous"].sum())
n_unmapped = int(annotated_df["is_unmapped"].sum())
n_no_gene_with_tss = int(
    ((annotated_df["n_overlapping_genes"] == 0) & ~annotated_df["is_unmapped"]).sum()
)

# rename_axis + reset_index(name=...) yields [region_class, count] on both
# pandas 1.x and 2.x, whose value_counts naming differs.
region_counts = (
    annotated_df["region_class"]
    .value_counts(dropna=False)
    .rename_axis("region_class")
    .reset_index(name="count")
)

qc_rows = [
    {"metric": "total_peaks", "value": total_peaks},
    {"metric": "merged_unique_gene", "value": n_merged},
    {"metric": "ambiguous_peaks", "value": n_ambiguous},
    {"metric": "unmapped_peaks", "value": n_unmapped},
    {"metric": "no_gene_overlap_with_tss", "value": n_no_gene_with_tss},
]

qc_df = pd.DataFrame(qc_rows)

print("QC summary:")
display(qc_df)
print("\nRegion class counts:")
display(region_counts)


In [None]:
#@title 8. Save results to Excel (Merged, Ambiguous_Unmapped, QC)

# Write every result table into one workbook, one sheet per table, in order.
output_path = OUTPUT_DIR / "peaks_gene_tss_annotation.xlsx"

with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
    for sheet_name, table in (
        ("Merged", merged_df),
        ("Ambiguous_Unmapped", ambiguous_df),
        ("QC", qc_df),
        ("QC_region_class", region_counts),
    ):
        table.to_excel(writer, sheet_name=sheet_name, index=False)

print("Excel file saved to:", output_path)
