In [None]:
###functions
import pandas as pd
import pysam
import os
import subprocess
import numpy as np
from scipy.cluster import hierarchy
from scipy.spatial.distance import pdist
import matplotlib.pyplot as plt
import math
import tempfile
from typing import Dict, List, Tuple, Optional
from Bio import SeqIO
from Bio.Seq import Seq
from intervaltree import Interval, IntervalTree
import re
from collections import defaultdict

def filter_te_by_coverage(gff3_file, bam_file, output_file, coverage_threshold=10):
    """Write the subset of a TE annotation whose elements are sufficiently covered.

    Every ``transposable_element`` row of ``gff3_file`` gets a mean read depth
    computed from ``bam_file`` over its own span. Elements whose mean depth is at
    least ``coverage_threshold`` are kept, together with the
    ``long_terminal_repeat`` rows whose ``Parent`` attribute names a kept element.
    All other feature types are dropped.

    Parameters
    ----------
    gff3_file : str
        Tab-separated GFF3 annotation (nine columns, ``#`` lines are comments).
    bam_file : str
        BAM file opened with ``pysam.AlignmentFile``; ``count_coverage`` needs it
        to be indexed.
    output_file : str
        Destination for the filtered, header-less nine-column table.
    coverage_threshold : float, default 10
        Minimum mean depth (inclusive) for an element to be kept.

    Notes
    -----
    The attribute column is rebuilt from the parsed ``key=value`` pairs, so
    attribute items without ``=`` are not carried over to the output.
    """
    gff = pd.read_csv(gff3_file, sep="\t", comment="#", header=None,
                      names=["seqid","source","type","start","end","score","strand","phase","attributes"],
                      # Keep purely numeric sequence names ("1", "2", ...) as strings for pysam.
                      dtype={"seqid": str})

    def parse_attributes(attr_str):
        # An empty attribute column is read as NaN (a float), not as a string.
        if not isinstance(attr_str, str):
            return {}
        d = {}
        for item in attr_str.split(";"):
            if "=" in item:
                k, v = item.split("=", 1)
                d[k] = v
        return d

    def dict_to_attr(d):
        return ";".join([f"{k}={v}" for k, v in d.items()])

    gff['attr_dict'] = gff['attributes'].apply(parse_attributes)

    te_df = gff[gff['type'] == "transposable_element"].copy()
    ltr_df = gff[gff['type'] == "long_terminal_repeat"].copy()

    bam = pysam.AlignmentFile(bam_file, "rb")
    try:
        def avg_coverage(seqid, start, end):
            # GFF3 is 1-based inclusive; pysam uses 0-based, half-open intervals.
            cov = bam.count_coverage(seqid, int(start) - 1, int(end))
            # cov holds one per-position count array per base (A, C, G, T);
            # the total depth is the sum over all four arrays.
            n_positions = len(cov[0])
            if n_positions == 0:
                return 0
            return sum(sum(track) for track in cov) / n_positions

        # A plain comprehension also works when there are no TE rows at all,
        # where DataFrame.apply(axis=1) would not return a Series.
        coverages = [
            avg_coverage(seqid, start, end)
            for seqid, start, end in zip(te_df['seqid'], te_df['start'], te_df['end'])
        ]
    finally:
        # Release the BAM handle even if a coverage query fails.
        bam.close()

    te_df['avg_cov'] = pd.Series(coverages, index=te_df.index, dtype=float)

    # Filter by coverage threshold
    te_pass = te_df[te_df['avg_cov'] >= coverage_threshold]

    te_ids = {attrs.get("ID") for attrs in te_pass['attr_dict']}
    # Without this, an LTR lacking a Parent (None) would match a TE lacking an ID (None).
    te_ids.discard(None)

    has_kept_parent = ltr_df['attr_dict'].map(lambda x: x.get("Parent") in te_ids).astype(bool)
    ltr_pass = ltr_df[has_kept_parent]

    filtered = pd.concat([te_pass, ltr_pass]).sort_values(['seqid','start'])
    filtered['attributes'] = filtered['attr_dict'].apply(dict_to_attr)

    filtered[['seqid','source','type','start','end','score','strand','phase','attributes']].to_csv(
        output_file, sep="\t", header=False, index=False
    )

def compute_te_methylation(gff3_file, bedmethyl_file, output_file,
                           flank_size=2000, bins=50, strand_oriented=True):
    """Average methylation in fixed numbers of bins over features and their flanks.

    Every non-comment row of ``gff3_file`` is treated as one element. For each
    element three regions are binned into ``bins`` bins each: the upstream flank
    (bin numbers ``1..bins``), the element itself (``bins+1..2*bins``) and the
    downstream flank (``2*bins+1..3*bins``). The methylation value of a bin is the
    mean of column 11 (index 10) of the bedMethyl records whose start lies in it,
    or ``0.0`` when the bin has no records.

    Parameters
    ----------
    gff3_file : str
        Nine-column GFF3 file; ``ID`` and ``Lineage`` are read from the attributes.
    bedmethyl_file : str
        bedMethyl table. If the name does not end in ``.gz`` a bgzip copy
        ``<name>.gz`` is created when missing; a tabix index is created when missing.
    output_file : str
        Output table; comma-separated if the name ends in ``.csv``, otherwise
        tab-separated. Columns: ``TE_ID``, ``bin_number``, ``average_methylation``,
        ``Lineage``, ``position``.
    flank_size : int, default 2000
        Length of each flank in bases.
    bins : int, default 50
        Number of bins per region.
    strand_oriented : bool, default True
        When true, elements on the ``-`` strand are laid out in their own 5'->3'
        direction: the flanks swap roles and bin numbers run right to left.
        Otherwise (and for any strand other than ``-``) the layout follows genomic
        coordinates.

    Notes
    -----
    GFF3 coordinates are 1-based inclusive while tabix/BED queries are 0-based
    half-open, so the element occupies ``[start - 1, end)`` in query coordinates.
    A region that is empty (for example an upstream flank at the start of a
    sequence) produces no rows, but its bin numbers stay reserved so that the
    numbering of the other regions does not shift.
    """
    if not bedmethyl_file.endswith(".gz"):
        gz_file = bedmethyl_file + ".gz"
    else:
        gz_file = bedmethyl_file
    tbi_file = gz_file + ".tbi"

    if not os.path.exists(gz_file):
        print(f"[INFO] Compressing {bedmethyl_file} with bgzip...")
        try:
            with open(gz_file, "wb") as gz_out:
                subprocess.run(["bgzip", "-c", bedmethyl_file], stdout=gz_out, check=True)
        except BaseException:
            # A truncated archive would be taken for a finished one on the next run.
            if os.path.exists(gz_file):
                os.remove(gz_file)
            raise

    if not os.path.exists(tbi_file):
        print(f"[INFO] Indexing {gz_file} with tabix...")
        subprocess.run(["tabix", "-p", "bed", gz_file], check=True)

    tes = []
    with open(gff3_file) as f:
        for line in f:
            if line.startswith("#") or not line.strip():
                continue
            cols = line.strip().split("\t")
            if len(cols) < 9:
                continue
            chrom, _, _, start, end, _, strand, _, attrs = cols[:9]
            # Split on the first "=" only so values containing "=" stay intact.
            attr_dict = dict(kv.split("=", 1) for kv in attrs.split(";") if "=" in kv)
            tes.append((chrom, int(start), int(end), strand,
                        attr_dict.get("ID"), attr_dict.get("Lineage", "NA")))

    results = []

    tbx = pysam.TabixFile(gz_file)
    try:
        for chrom, te_start, te_end, strand, te_id, lineage in tes:
            # Convert the 1-based inclusive GFF3 start to a 0-based half-open start.
            te_start0 = max(0, te_start - 1)
            left_flank = (max(0, te_start0 - flank_size), te_start0)
            right_flank = (te_end, te_end + flank_size)

            # One switch drives both the flank roles and the bin direction,
            # so the two can never disagree.
            flip = strand_oriented and strand == "-"
            if flip:
                upstream, downstream = right_flank, left_flank
            else:
                upstream, downstream = left_flank, right_flank

            regions = [
                ("upstream", upstream),
                ("TE", (te_start0, te_end)),
                ("downstream", downstream),
            ]

            global_bin = 1

            for pos_label, (rstart, rend) in regions:
                bin_numbers = list(range(global_bin, global_bin + bins))
                # Advance before any early exit so later regions keep their numbers.
                global_bin += bins

                if rend <= rstart:
                    continue

                bin_edges = np.linspace(rstart, rend, bins + 1, dtype=int)

                if flip:
                    bin_numbers = bin_numbers[::-1]

                for (bstart, bend, bin_num) in zip(bin_edges[:-1], bin_edges[1:], bin_numbers):
                    try:
                        records = tbx.fetch(chrom, int(bstart), int(bend))
                    except ValueError:
                        # Raised by pysam for regions it cannot query
                        # (e.g. a sequence absent from the index); skip the bin.
                        continue

                    meth_vals = []
                    for rec in records:
                        fields = rec.split("\t")
                        percent = float(fields[10])
                        pos = int(fields[1])
                        if bstart <= pos < bend:
                            meth_vals.append(percent)

                    avg_meth = np.mean(meth_vals) if meth_vals else 0.0
                    results.append([te_id, bin_num, avg_meth, lineage, pos_label])
    finally:
        tbx.close()

    df = pd.DataFrame(results, columns=["TE_ID","bin_number","average_methylation","Lineage","position"])
    if output_file.endswith(".csv"):
        df.to_csv(output_file, index=False)
    else:
        df.to_csv(output_file, index=False, sep="\t")

def merge_methylation_files(cg_file, chg_file, chh_file, output_file):
    """Stack the CG, CHG and CHH tables into one tab-separated file.

    Each input is read as a tab-separated table, tagged with a ``context``
    column holding ``"CG"``, ``"CHG"`` or ``"CHH"`` respectively, and the three
    tables are concatenated in that order and written to ``output_file``.
    """
    labelled_inputs = (("CG", cg_file), ("CHG", chg_file), ("CHH", chh_file))

    tables = []
    for label, path in labelled_inputs:
        table = pd.read_csv(path, sep="\t")
        table["context"] = label
        tables.append(table)

    merged = pd.concat(tables, axis=0, ignore_index=True)
    merged.to_csv(output_file, sep="\t", index=False)

def cluster_te_methylation_table(
        infile, lineage, contexts=None,
        metric='euclidean', cluster_method='average',
        n_clusters=4, max_na=10, eval_range=(2,10),
        save_clustered_file="clustered_TE_table.tsv"):
    """Hierarchically cluster the TE-body methylation profiles of one lineage.

    Reads the tab-separated ``infile`` (columns used: ``TE_ID``, ``Bin_num``,
    ``Average_methylation_per_bin``, ``Lineage``, ``Position``, ``Context``),
    keeps rows of ``lineage`` with ``Position == "TE"`` and builds a TE x bin
    matrix from the first context only. TEs with more than ``max_na`` missing
    bins are dropped, remaining gaps become 0, and TEs whose bins all equal their
    first bin are dropped as well. The matrix is clustered with
    ``scipy.cluster.hierarchy.linkage`` and cut into at most ``n_clusters``
    clusters.

    An elbow plot of the within-cluster sum of squares for every k in
    ``eval_range`` (inclusive) is shown, and the TE rows of all ``contexts`` for
    the retained TEs are written to ``save_clustered_file`` with an added
    ``cluster`` column.

    Raises
    ------
    ValueError
        If the lineage has no TE rows, or no TE survives the filters.
    """
    table = pd.read_csv(infile, sep="\t")
    lineage_rows = table['Lineage'] == lineage
    body_rows = table['Position'] == "TE"
    te_rows = table[lineage_rows & body_rows]
    if te_rows.empty:
        raise ValueError(f"No data for lineage {lineage} and Position==TE.")

    if contexts is None:
        contexts = sorted(te_rows['Context'].unique())
    reference_ctx = contexts[0]

    # Wide TE x bin matrix for the reference context only.
    profile = te_rows[te_rows['Context'] == reference_ctx].pivot(
        index='TE_ID', columns='Bin_num', values='Average_methylation_per_bin')

    missing_per_te = profile.isna().sum(axis=1)
    profile = profile[missing_per_te <= max_na].fillna(0)

    # Discard rows in which every bin equals the first bin (flat profiles).
    leading_bin = profile.iloc[:, 0].to_numpy()
    has_variation = (profile.to_numpy() != leading_bin[:, None]).any(axis=1)
    profile = profile.loc[has_variation]
    if profile.empty:
        raise ValueError("No TE_IDs left after filtering.")

    matrix = profile.values
    tree = hierarchy.linkage(pdist(matrix, metric=metric), method=cluster_method)

    assignments = hierarchy.fcluster(tree, t=n_clusters, criterion='maxclust')
    te_to_cluster = dict(zip(profile.index, assignments))

    # Within-cluster sum of squares for each candidate number of clusters.
    candidate_ks = list(range(eval_range[0], eval_range[1] + 1))
    wss_values = []
    for k in candidate_ks:
        k_labels = hierarchy.fcluster(tree, t=k, criterion='maxclust')
        wss = 0
        for cluster_id in np.unique(k_labels):
            members = matrix[np.where(k_labels == cluster_id)[0], :]
            wss += ((members - members.mean(axis=0)) ** 2).sum()
        wss_values.append(wss)

    plt.figure(figsize=(6,4))
    plt.plot(candidate_ks, wss_values, 'o-', color='blue')
    plt.xlabel("Number of clusters (k)")
    plt.ylabel("Within-cluster sum of squares (WSS)")
    plt.title(f"Elbow plot for TE methylation clustering ({lineage})")
    plt.show()

    # Propagate the labels found on the reference context to every context.
    labelled_parts = []
    for ctx in contexts:
        part = te_rows[te_rows['Context'] == ctx].copy()
        part = part[part['TE_ID'].isin(profile.index)]
        part['cluster'] = part['TE_ID'].map(te_to_cluster)
        labelled_parts.append(part)

    labelled = pd.concat(labelled_parts, ignore_index=True)
    labelled.to_csv(save_clustered_file, sep="\t", index=False)

def adjast_table_full(clustered_TE_txt, full_unclust_txt, out_txt):
    """Copy cluster labels from a clustered table onto the full methylation table.

    ``clustered_TE_txt`` is read first: for every data row the first field is
    taken as the TE identifier and the last field as its cluster label (a later
    row for the same identifier overrides an earlier one). Rows of
    ``full_unclust_txt`` whose first field is a known identifier are then written
    to ``out_txt`` with the label appended as a final column, below a fixed
    header line. Rows for unknown identifiers are dropped.

    ``clustered_TE_txt`` is deleted once the output has been written.

    Parameters
    ----------
    clustered_TE_txt : str
        Tab-separated table whose last column is the cluster label.
    full_unclust_txt : str
        Tab-separated table to annotate; its first column is the TE identifier.
    out_txt : str
        Destination file (overwritten).
    """
    def data_fields(handle):
        # Yield the fields of every data row, skipping blank lines and the header.
        # Only the first field is compared, so an identifier that merely
        # contains "TE_ID" is still treated as data.
        for raw in handle:
            if not raw.strip():
                continue
            fields = raw.rstrip('\r\n').split('\t')
            if fields[0] == 'TE_ID':
                continue
            yield fields

    dict_clusters = {}
    with open(clustered_TE_txt, 'r') as clustered:
        for fields in data_fields(clustered):
            dict_clusters[fields[0]] = fields[-1]

    with open(full_unclust_txt, 'r') as full, open(out_txt, 'w') as new:
        new.write('TE_ID\tBin_num\tAverage_methylation_per_bin\tLineage\tPosition\tContext\tcluster\n')
        for fields in data_fields(full):
            apply_cluster = dict_clusters.get(fields[0])
            if apply_cluster is None:
                continue
            tmp = '\t'.join(fields)
            new.write(f'{tmp}\t{apply_cluster}\n')

    # Plain Python instead of the IPython-only `!rm`, which also passed the
    # path through a shell unquoted.
    os.remove(clustered_TE_txt)

def parse_gff3_features(gff3_path: str):
    """Read a GFF3 file into a list of feature dictionaries.

    Comment lines (starting with ``#``), blank lines and lines with fewer than
    nine tab-separated columns are ignored. Each remaining line becomes a dict
    with the keys ``seqid``, ``source``, ``type``, ``start``, ``end``, ``score``,
    ``strand``, ``phase`` and ``attrs``. ``start`` and ``end`` are integers, any
    strand other than ``+`` or ``-`` is reported as ``+``, and ``attrs`` maps the
    ``key=value`` items of column nine (split on the first ``=``).
    """
    parsed = []
    with open(gff3_path) as handle:
        for raw in handle:
            if raw.startswith("#") or not raw.strip():
                continue
            fields = raw.rstrip("\n").split("\t")
            if len(fields) < 9:
                continue
            seqid, source, ftype, start_s, end_s, score, strand, phase, attr_text = fields[:9]
            start_i = int(start_s)
            end_i = int(end_s)
            attributes = dict(
                chunk.split("=", 1) for chunk in attr_text.split(";") if "=" in chunk
            )
            if strand not in ("+", "-"):
                strand = "+"
            parsed.append(dict(
                seqid=seqid,
                source=source,
                type=ftype,
                start=start_i,
                end=end_i,
                score=score,
                strand=strand,
                phase=phase,
                attrs=attributes,
            ))
    return parsed

def mafft_align_two_seqs(seq_a: Tuple[str,str], seq_b: Tuple[str,str], mafft_exec: str = "mafft", mafft_options: Optional[List[str]] = None) -> Dict:
    """Align two sequences with MAFFT and return the aligned strings.

    Parameters
    ----------
    seq_a, seq_b : tuple of (header, sequence)
        The two sequences to align.
    mafft_exec : str, default "mafft"
        MAFFT executable name or path.
    mafft_options : list of str, optional
        Command-line options placed before the input file; defaults to
        ``["--auto", "--quiet"]``.

    Returns
    -------
    dict
        ``{"a_aln": ..., "b_aln": ...}`` holding the first and second record of
        the MAFFT output, upper-cased.

    Raises
    ------
    RuntimeError
        If MAFFT exits with a non-zero status (the message carries the last line
        MAFFT wrote to stderr, when there is one) or if its output holds fewer
        than two records.
    """
    if mafft_options is None:
        mafft_options = ["--auto", "--quiet"]
    with tempfile.TemporaryDirectory() as td:
        in_fa = os.path.join(td, "input.fa")
        out_fa = os.path.join(td, "aligned.fa")
        with open(in_fa, "w") as fh:
            fh.write(f">{seq_a[0]}\n")
            fh.write(seq_a[1] + "\n")
            fh.write(f">{seq_b[0]}\n")
            fh.write(seq_b[1] + "\n")
        cmd = [mafft_exec] + mafft_options + [in_fa]
        try:
            with open(out_fa, "w") as outf:
                # Capture stderr instead of discarding it so a failure can be diagnosed.
                subprocess.run(cmd, check=True, stdout=outf, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError as e:
            stderr_lines = (e.stderr or b"").decode(errors="replace").strip().splitlines()
            # Only the last line is kept so the message stays on a single line.
            detail = f" ({stderr_lines[-1].strip()})" if stderr_lines else ""
            raise RuntimeError(f"MAFFT failed: {e}{detail}") from e
        records = list(SeqIO.parse(out_fa, "fasta"))
        if len(records) < 2:
            raise RuntimeError("MAFFT output did not contain two sequences.")
        return {"a_aln": str(records[0].seq).upper(), "b_aln": str(records[1].seq).upper()}

_TRANSITIONS = {("A","G"), ("G","A"), ("C","T"), ("T","C")}

def count_transitions_transversions(aln1: str, aln2: str) -> Tuple[int,int,int]:
    """Count substitutions between two aligned sequences of equal length.

    Columns containing a gap (``-``) or any character other than A, C, G or T
    (case-insensitive) in either sequence are ignored.

    Returns
    -------
    tuple of int
        ``(transitions, transversions, compared_columns)`` where the last value
        counts every column that was not ignored, identical or not.
    """
    assert len(aln1) == len(aln2)
    nucleotides = ("A", "C", "G", "T")
    transitions = 0
    transversions = 0
    compared = 0
    for col_a, col_b in zip(aln1, aln2):
        if "-" in (col_a, col_b):
            continue
        base_a, base_b = col_a.upper(), col_b.upper()
        if not (base_a in nucleotides and base_b in nucleotides):
            continue
        compared += 1
        if base_a == base_b:
            continue
        if (base_a, base_b) in _TRANSITIONS:
            transitions += 1
        else:
            transversions += 1
    return transitions, transversions, compared

def compute_k2p_from_counts(n_ts: int, n_tv: int, n_sites: int) -> Optional[float]:
    """Kimura two-parameter distance from substitution counts.

    With P = n_ts / n_sites and Q = n_tv / n_sites the distance is
    ``-1/2 * ln(1 - 2P - Q) - 1/4 * ln(1 - 2Q)``. When either logarithm argument
    is not positive, the Jukes-Cantor distance ``-3/4 * ln(1 - 4/3 * (P + Q))``
    is returned instead, provided ``0 <= P + Q < 0.75``.

    Returns
    -------
    float or None
        The distance, or ``None`` when ``n_sites`` is not positive or neither
        formula can be evaluated.
    """
    if n_sites <= 0:
        return None
    ts_fraction = n_ts / n_sites
    tv_fraction = n_tv / n_sites
    # Arguments of the two logarithms in the K2P formula.
    ts_term = 1.0 - 2.0*ts_fraction - tv_fraction
    tv_term = 1.0 - 2.0*tv_fraction
    try:
        if ts_term > 0 and tv_term > 0:
            return -0.5 * math.log(ts_term) - 0.25 * math.log(tv_term)
        # K2P is undefined here; fall back to Jukes-Cantor on the raw p-distance.
        p_distance = ts_fraction + tv_fraction
        if 0 <= p_distance < 0.75:
            return -3.0/4.0 * math.log(1 - (4.0/3.0)*p_distance)
        return None
    except ValueError:
        return None

def compute_ltr_k2p_and_age(genome_fasta: str,
                            gff3_path: str,
                            mutation_rate_per_year: float,
                            mafft_exec: str = "mafft",
                            mafft_options: Optional[List[str]] = None,
                            lineage_filter: Optional[List[str]] = None) -> pd.DataFrame:
    """Estimate LTR divergence (K2P) and insertion age for annotated elements.

    ``transposable_element`` features of ``gff3_path`` are keyed by their ``ID``
    (or ``seqid:start-end`` when absent); ``long_terminal_repeat`` features are
    attached to the element named by their ``Parent`` (or ``Name``) attribute.
    For each element, optionally restricted to ``lineage_filter`` via the
    ``Lineage`` attribute, the LTR with the smallest start and the one with the
    largest start are extracted from ``genome_fasta``, reverse-complemented for
    ``-`` strand elements, aligned with MAFFT and compared.

    The age is ``K / (2 * mutation_rate_per_year)`` expressed in millions of
    years.

    Returns
    -------
    pandas.DataFrame
        One row per element with columns ``TE_ID``, ``seqid``, ``LTR1_coords``,
        ``LTR2_coords``, ``K2P``, ``Age_Mya``, ``n_sites``, ``n_ts``, ``n_tv``
        and ``note``. ``note`` is ``ok``, ``k2p_unestimable``,
        ``less_than_two_ltrs``, ``seqid_not_in_fasta`` or ``mafft_failed:<error>``.
    """
    if mafft_options is None:
        mafft_options = ["--auto", "--quiet"]
    genome = SeqIO.to_dict(SeqIO.parse(genome_fasta, "fasta"))

    elements = {}
    ltrs_of = {}
    for feat in parse_gff3_features(gff3_path):
        kind = feat["type"]
        attributes = feat["attrs"]
        if kind == "transposable_element":
            key = attributes.get("ID") or f"{feat['seqid']}:{feat['start']}-{feat['end']}"
            elements[key] = feat
        elif kind == "long_terminal_repeat":
            owner = attributes.get("Parent") or attributes.get("Name")
            if owner:
                ltrs_of.setdefault(owner, []).append(feat)

    def coords(ltr):
        return f"{ltr['start']}-{ltr['end']}"

    def make_row(te_id, seqid, note, ltr_pair=None, k2p=None, age_mya=None,
                 n_sites=0, n_ts=0, n_tv=0):
        # Single place that defines the shape of an output row.
        first, last = ltr_pair if ltr_pair else (None, None)
        return {
            "TE_ID": te_id,
            "seqid": seqid,
            "LTR1_coords": coords(first) if first else None,
            "LTR2_coords": coords(last) if last else None,
            "K2P": k2p,
            "Age_Mya": age_mya,
            "n_sites": n_sites,
            "n_ts": n_ts,
            "n_tv": n_tv,
            "note": note,
        }

    rows = []
    for te_id, element in elements.items():
        if lineage_filter and element["attrs"].get("Lineage") not in lineage_filter:
            continue

        candidates = ltrs_of.get(te_id, [])
        if len(candidates) < 2:
            rows.append(make_row(te_id, element["seqid"], "less_than_two_ltrs"))
            continue

        # Outermost pair: smallest and largest start coordinate.
        ordered = sorted(candidates, key=lambda x: x["start"])
        pair = (ordered[0], ordered[-1])
        seqid = element["seqid"]
        if seqid not in genome:
            rows.append(make_row(te_id, seqid, "seqid_not_in_fasta", ltr_pair=pair))
            continue

        contig = genome[seqid].seq
        # GFF3 coordinates are 1-based inclusive; Python slices are 0-based half-open.
        first_seq = contig[pair[0]["start"] - 1 : pair[0]["end"]]
        last_seq = contig[pair[1]["start"] - 1 : pair[1]["end"]]
        if element.get("strand", "+") == "-":
            first_seq = first_seq.reverse_complement()
            last_seq = last_seq.reverse_complement()
        s1 = str(first_seq).upper()
        s2 = str(last_seq).upper()

        try:
            aln = mafft_align_two_seqs((f"{te_id}_LTR1", s1), (f"{te_id}_LTR2", s2), mafft_exec=mafft_exec, mafft_options=mafft_options)
            a_aln = aln["a_aln"]
            b_aln = aln["b_aln"]
        except Exception as e:
            rows.append(make_row(te_id, seqid, f"mafft_failed:{e}", ltr_pair=pair))
            continue

        n_ts, n_tv, n_sites = count_transitions_transversions(a_aln, b_aln)
        K = compute_k2p_from_counts(n_ts, n_tv, n_sites)
        if K is None:
            age_mya, note = None, "k2p_unestimable"
        else:
            age_mya = K / (2.0 * float(mutation_rate_per_year)) / 1e6
            note = "ok"
        rows.append(make_row(te_id, seqid, note, ltr_pair=pair, k2p=K, age_mya=age_mya,
                             n_sites=n_sites, n_ts=n_ts, n_tv=n_tv))

    return pd.DataFrame(rows, columns=["TE_ID","seqid","LTR1_coords","LTR2_coords","K2P","Age_Mya","n_sites","n_ts","n_tv","note"])

def intersect_genes_tes(gene_gff, te_gff, output_file):
    """Label each transposable element as genic or non-genic.

    ``gene`` features are read from ``gene_gff`` and ``transposable_element``
    features from ``te_gff``. An element is ``genic`` when its span overlaps at
    least one gene on the same sequence (both spans inclusive); one row is
    written per overlapping gene. Elements without any overlap get a single
    ``non_genic`` row with gene ``NA``. Features lacking an ``ID`` attribute are
    reported as ``NA``.

    The result is written to ``output_file`` as a tab-separated table with the
    columns ``TE_ID``, ``type`` and ``gene``; rows with a repeated
    (``TE_ID``, ``gene``) pair are written once. Genes overlapping one element
    are listed in coordinate order so that the output is reproducible.
    """
    def parse_gff(file, feature_type):
        # Return (chrom, start, end, ID) tuples for rows of the requested type.
        records = []
        with open(file) as f:
            for line in f:
                if line.startswith("#"):
                    continue
                parts = line.strip().split("\t")
                if len(parts) < 9 or parts[2] != feature_type:
                    continue
                # Split on the first "=" only so values containing "=" stay intact.
                attrs = dict(kv.split("=", 1) for kv in parts[8].split(";") if "=" in kv)
                records.append((parts[0], int(parts[3]), int(parts[4]), attrs.get("ID", "NA")))
        return records

    # One interval tree per sequence, filled in a single pass over the genes.
    chrom_trees = defaultdict(IntervalTree)
    for chrom, start, end, gene_id in parse_gff(gene_gff, "gene"):
        # IntervalTree intervals are half-open, hence end + 1 for inclusive GFF ends.
        chrom_trees[chrom].addi(start, end + 1, gene_id)

    output = []
    for chrom, te_start, te_end, te_id in parse_gff(te_gff, "transposable_element"):
        tree = chrom_trees.get(chrom)
        overlaps = tree.overlap(te_start, te_end + 1) if tree is not None else ()
        if overlaps:
            # overlap() returns a set; sort it so the row order is deterministic.
            for gene in sorted(overlaps, key=lambda iv: (iv.begin, iv.end, iv.data)):
                output.append({"TE_ID": te_id, "type": "genic", "gene": gene.data})
        else:
            output.append({"TE_ID": te_id, "type": "non_genic", "gene": "NA"})

    # Explicit columns keep the table (and its header) valid when there are no rows.
    out_df = pd.DataFrame(output, columns=["TE_ID", "type", "gene"])
    out_df = out_df.drop_duplicates(subset=["TE_ID", "gene"])
    out_df.to_csv(output_file, sep="\t", index=False)

def calculate_te_gene_distance_from_files(gene_gff: str, te_gff: str):
    """Distance from each transposable element to the nearest gene on its sequence.

    ``gene`` rows are read from ``gene_gff`` and ``transposable_element`` rows
    from ``te_gff`` (both nine-column, tab-separated GFF3). The distance is 0
    when the element overlaps a gene, otherwise the difference between the two
    facing end coordinates of the element and the closest gene.

    Returns
    -------
    pandas.DataFrame
        Columns ``TE_ID`` (value of the element's ``ID`` attribute) and
        ``distance_to_gene`` (missing when the element's sequence carries no
        gene).
    """
    colnames = ["seqid", "source", "type", "start", "end", "score", "strand", "phase", "attributes"]

    genes = pd.read_csv(gene_gff, sep="\t", names=colnames, comment="#", dtype={"seqid": str})
    tes = pd.read_csv(te_gff, sep="\t", names=colnames, comment="#", dtype={"seqid": str})

    genes = genes[genes["type"] == "gene"]
    tes = tes[tes["type"] == "transposable_element"].copy()

    # Anchor on the start of the column or a ";" so that keys merely ending in
    # "ID" (e.g. "old_ID=") are not mistaken for the ID attribute.
    tes["te_id"] = tes["attributes"].str.extract(r'(?:^|;)\s*ID=([^;]+)', expand=False)

    # Gene coordinates per sequence as arrays, extracted once rather than per element.
    gene_bounds = {
        seqid: (group["start"].to_numpy(), group["end"].to_numpy())
        for seqid, group in genes.groupby("seqid", sort=False)
    }

    results = []
    for te_id, seqid, te_start, te_end in zip(tes["te_id"], tes["seqid"], tes["start"], tes["end"]):
        bounds = gene_bounds.get(seqid)
        if bounds is None:
            results.append((te_id, None))
            continue

        gene_starts, gene_ends = bounds
        # For a gene to the right the gap is gene_start - te_end, for a gene to the
        # left it is te_start - gene_end; both are <= 0 exactly when the two overlap.
        gaps = np.maximum(np.maximum(gene_starts - te_end, te_start - gene_ends), 0)
        results.append((te_id, int(gaps.min())))

    result_df = pd.DataFrame(results, columns=["TE_ID", "distance_to_gene"])
    return result_df

def calc_gc_percent(seq: str) -> float:
    """Return the percentage of G and C characters in ``seq`` (case-insensitive).

    The denominator is the full sequence length, so every other character
    (including ambiguity codes) counts as non-GC. An empty sequence gives 0.0.
    """
    if not seq:
        return 0.0
    upper = seq.upper()
    gc_total = sum(upper.count(base) for base in "GC")
    return 100.0 * gc_total / len(upper)

def te_cg_content(gff3_path, reference_fasta, outfile):
    """Write the G+C percentage of each element, its LTRs and its internal body.

    ``transposable_element`` rows of ``gff3_path`` are keyed by their ``ID``
    attribute and ``long_terminal_repeat`` rows are attached to the element named
    by their ``Parent`` attribute. For every element the sequence is taken from
    ``reference_fasta`` and three values are computed with ``calc_gc_percent``:

    * ``CG_whole`` - the whole element,
    * ``CG_LTR``   - mean of the two LTRs with the smallest coordinates,
    * ``CG_body``  - the element with the positions of those two LTRs removed.

    Elements with fewer than two LTRs get 0.0 for ``CG_body`` and ``CG_LTR``.
    Values are rounded to two decimals and written to ``outfile`` as a
    tab-separated table with the header ``TE_ID  CG_whole  CG_body  CG_LTR``.

    Raises
    ------
    KeyError
        If an element's sequence name is absent from ``reference_fasta``.
    """
    genome = SeqIO.to_dict(SeqIO.parse(reference_fasta, "fasta"))
    # Anchored on the start of the column or a ";" so that longer keys ending in
    # "ID" or "Parent" are not matched by accident.
    pat_id = re.compile(r'(?:^|;)\s*ID=([^;\n]+)')
    pat_parent = re.compile(r'(?:^|;)\s*Parent=([^;\n]+)')

    tes = {}
    ltrs = defaultdict(list)

    with open(gff3_path) as fh:
        for ln in fh:
            if ln.startswith("#") or not ln.strip():
                continue
            parts = ln.strip().split("\t")
            if len(parts) < 9:
                continue
            # Slice to nine columns: unpacking the full list fails on rows with extras.
            seqid, _, ftype, start_s, end_s, _, _, _, attrs = parts[:9]
            start, end = int(start_s), int(end_s)
            if ftype == "transposable_element":
                m_id = pat_id.search(attrs)
                if m_id:
                    tes[m_id.group(1)] = (seqid, start, end)
            elif ftype == "long_terminal_repeat":
                m_parent = pat_parent.search(attrs)
                if m_parent:
                    ltrs[m_parent.group(1)].append((start, end))

    with open(outfile, "w") as outf:
        outf.write("TE_ID\tCG_whole\tCG_body\tCG_LTR\n")

        for te_id, (seqid, start, end) in tes.items():
            contig = genome[seqid].seq
            # GFF3 is 1-based inclusive; Python slices are 0-based half-open.
            seq = str(contig[start-1:end])
            cg_whole = calc_gc_percent(seq)

            # .get avoids creating empty defaultdict entries for LTR-less elements.
            te_ltrs = ltrs.get(te_id, [])
            if len(te_ltrs) >= 2:
                ltr1, ltr2 = sorted(te_ltrs)[:2]
                ltr_seq1 = str(contig[ltr1[0]-1:ltr1[1]])
                ltr_seq2 = str(contig[ltr2[0]-1:ltr2[1]])
                cg_ltr = (calc_gc_percent(ltr_seq1) + calc_gc_percent(ltr_seq2)) / 2.0

                # 1 = keep the base, 0 = base belongs to an LTR.
                keep = bytearray(b"\x01") * len(seq)
                for s, e in (ltr1, ltr2):
                    lo = max(s - start, 0)
                    hi = min(e - start + 1, len(seq))
                    if lo < hi:
                        keep[lo:hi] = bytes(hi - lo)
                body_seq = ''.join(base for base, flag in zip(seq, keep) if flag)
                cg_body = calc_gc_percent(body_seq)
            else:
                cg_ltr, cg_body = 0.0, 0.0

            outf.write(f"{te_id}\t{round(cg_whole,2)}\t{round(cg_body,2)}\t{round(cg_ltr,2)}\n")

def compute_locus_methylation(bed_file, bedmethyl_file, output_prefix, bin_size=100):
    """Average methylation in fixed-width bins along each locus of a BED file.

    Each locus of ``bed_file`` (columns read as chrom, start, end, name) is cut
    into consecutive bins of ``bin_size`` bases; the last bin is truncated at the
    locus end. The value of a bin is the mean of column 11 (index 10) of the
    bedMethyl records starting inside it. Bins without usable records, or that
    cannot be queried, are reported as 0.0.

    Parameters
    ----------
    bed_file : str
        Tab-separated, header-less BED file of loci.
    bedmethyl_file : str
        bedMethyl table. If the name does not end in ``.gz`` a bgzip copy
        ``<name>.gz`` is created when missing; a tabix index is created when missing.
    output_prefix : str
        The result is written to ``<output_prefix>_binned_methylation.txt``
        (tab-separated; columns ``locus_name``, ``chrom``, ``start``, ``end``,
        ``bin``, ``avg_methylation``).
    bin_size : int, default 100
        Bin width in bases.
    """
    if not bedmethyl_file.endswith(".gz"):
        gz_file = bedmethyl_file + ".gz"
    else:
        gz_file = bedmethyl_file
    tbi_file = gz_file + ".tbi"

    if not os.path.exists(gz_file):
        print(f"[INFO] Compressing {bedmethyl_file}...")
        try:
            with open(gz_file, "wb") as gz_out:
                subprocess.run(["bgzip", "-c", bedmethyl_file], stdout=gz_out, check=True)
        except BaseException:
            # A truncated archive would be taken for a finished one on the next run.
            if os.path.exists(gz_file):
                os.remove(gz_file)
            raise

    if not os.path.exists(tbi_file):
        print(f"[INFO] Indexing {gz_file}...")
        subprocess.run(["tabix", "-p", "bed", gz_file], check=True)

    # --- Load loci BED ---
    loci = pd.read_csv(bed_file, sep="\t", header=None, names=["chrom", "start", "end", "name"])

    results = []

    tbx = pysam.TabixFile(gz_file)
    try:
        def region_methylation(chrom, start, end):
            meth_vals = []
            try:
                for rec in tbx.fetch(chrom, int(start), int(end)):
                    fields = rec.split("\t")
                    try:
                        percent = float(fields[10])  # methylation %
                    except ValueError:
                        continue
                    pos = int(fields[1])
                    if start <= pos < end:
                        meth_vals.append(percent)
            except ValueError:
                # Region could not be queried or a record was malformed.
                return np.nan
            return np.mean(meth_vals) if meth_vals else np.nan

        for _, row in loci.iterrows():
            chrom, start, end, name = row["chrom"], row["start"], row["end"], row["name"]
            bin_starts = np.arange(start, end, bin_size)
            for i, bstart in enumerate(bin_starts, 1):
                bend = min(bstart + bin_size, end)
                avg_meth = region_methylation(chrom, bstart, bend)
                results.append([name, chrom, bstart, bend, i, avg_meth])
    finally:
        # Release the index handle even when a locus fails.
        tbx.close()

    out_df = pd.DataFrame(results, columns=["locus_name", "chrom", "start", "end", "bin", "avg_methylation"])
    out_df["avg_methylation"] = out_df["avg_methylation"].fillna(0.0)
    outfile = f"{output_prefix}_binned_methylation.txt"
    out_df.to_csv(outfile, sep="\t", index=False)

In [None]:
#filtering out supplementary, secondary alignments and unmapped reads, sorting and indexing
# -F 0x904 excludes reads carrying any of the SAM flags 0x4 (unmapped),
# 0x100 (secondary) or 0x800 (supplementary), leaving mapped primary alignments.
!samtools view -b -@ 100 -F 0x904 aligned.bam > \
aligned.primary.bam
# Sort the primary alignments; the sorted BAM is the one indexed and used downstream.
!samtools sort -@ 100 -o aligned.primary.sort.bam \
aligned.primary.bam
!samtools index -@ 100 aligned.primary.sort.bam

In [None]:
#methylation calling

In [None]:
#reference genome context calling
# One BED file of reference motif positions is written per cytosine context.
# The trailing 0 is the offset argument of `modkit motif bed`
# (presumably the position of the cytosine within the motif -- confirm against the modkit docs).
#CG
!modkit motif bed GCF_002127325.2_HanXRQr2.0-SUNRISE_genomic.fna CG 0 > \
SUNRISE_genomic.CG.bed
#CHG
!modkit motif bed GCF_002127325.2_HanXRQr2.0-SUNRISE_genomic.fna CHG 0 > \
SUNRISE_genomic.CHG.bed
#CHH
!modkit motif bed GCF_002127325.2_HanXRQr2.0-SUNRISE_genomic.fna CHH 0 > \
SUNRISE_genomic.CHH.bed

In [None]:
#mC calling
# Writes a bedMethyl table for the sorted primary alignments.
!modkit pileup -t 100 --filter-threshold C:0.7 --ignore h \
aligned.primary.sort.bam \
aligned.primary.meth.bed
#filtering out sites whose 10th bedMethyl column is below 3
# NOTE(review): in modkit's bedMethyl layout column 10 is Nvalid_cov (valid coverage);
# Nmod is column 12. This comment previously said "Nmod" -- confirm which was intended.
!awk '{if ($10 >= 3) print $0}' aligned.primary.meth.bed > \
aligned.primary.meth.filtered.bed

In [None]:
#methylation in contexts calling
# Keep the filtered methylation records that overlap reference positions of each
# context (CG, CHG, CHH), giving one bedMethyl table per context.
# -sorted requires both inputs to be position-sorted in the same sequence order.
!bedtools intersect -sorted -a aligned.primary.meth.filtered.bed \
-b SUNRISE_genomic.CG.bed > \
SUNRISE_genomic.CG.meth.bed
!bedtools intersect -sorted -a aligned.primary.meth.filtered.bed \
-b SUNRISE_genomic.CHG.bed > \
SUNRISE_genomic.CHG.meth.bed
!bedtools intersect -sorted -a aligned.primary.meth.filtered.bed \
-b SUNRISE_genomic.CHH.bed > \
SUNRISE_genomic.CHH.meth.bed

In [None]:
# Keep TEs (and their LTRs) whose mean read depth reaches the function's
# default threshold of 10.
filter_te_by_coverage(
    gff3_file='final_annotations/sunflower_LTR_RT.gff3',
    bam_file='aligned.primary.sort.bam',
    output_file='sunflower_LTR_RT.filtered.gff3',
)

In [None]:
###methylation matrix calculation

In [None]:
# Keep only the transposable_element rows: compute_te_methylation treats every
# non-comment line of its GFF3 input as one element. The pattern contains a
# literal tab between the source and type columns.
!grep 'dante_ltr	transposable_element' \
sunflower_LTR_RT.filtered.gff3 > \
sunflower_LTR_RT.filtered.transposable_element.gff3

In [None]:
# (bedMethyl input, output table) for the CG, CHG and CHH contexts, in that order.
te_methylation_jobs = [
    ("SUNRISE_genomic.CG.meth.bed", "sunflower_LTR_RT.CG.meth.transposable_element.txt"),
    ("SUNRISE_genomic.CHG.meth.bed", "sunflower_LTR_RT.CHG.meth.transposable_element.txt"),
    ("SUNRISE_genomic.CHH.meth.bed", "sunflower_LTR_RT.CHH.meth.transposable_element.txt"),
]

for bedmethyl_path, table_path in te_methylation_jobs:
    compute_te_methylation(
        gff3_file="sunflower_LTR_RT.filtered.transposable_element.gff3",
        bedmethyl_file=bedmethyl_path,
        output_file=table_path,
        flank_size=2000,
        bins=50
    )

In [None]:
###calculate clusters per lineage

In [None]:
# Combine the three per-context tables into one table with a context column.
merge_methylation_files(
    cg_file='sunflower_LTR_RT.CG.meth.transposable_element.txt',
    chg_file='sunflower_LTR_RT.CHG.meth.transposable_element.txt',
    chh_file='sunflower_LTR_RT.CHH.meth.transposable_element.txt',
    output_file='sunflower_LTR_RT.all_contexts.meth.transposable_element.txt',
)

In [None]:
# Number of clusters requested for each sunflower LTR-RT lineage.
dict_n_clust = dict(
    Ale=2,
    Angela=2,
    Athila=3,
    Bianca=1,
    Ikeros=1,
    Ivana=2,
    Retand=3,
    SIRE=1,
    TAR=3,
    Tekay=4,
    Tork=2,
)

In [None]:
# NOTE(review): clustering reads the "meth1" table under docker_sandbox/..., while
# adjast_table_full annotates the "meth" table in the working directory -- confirm
# that these are meant to be different files.
for lineage in dict_n_clust:
    # One variable for the temporary table so it is read from where it was written.
    clustered_tmp = f"docker_sandbox/sources/sunflower_TE_annot/meth_binned_tables/sunflower_LTR_RT.all_contexts.chrom_type.{lineage}.txt.tmp"
    cluster_te_methylation_table(
        'docker_sandbox/sources/sunflower_TE_annot/meth_binned_tables/sunflower_LTR_RT.all_contexts.meth1.transposable_element.txt',
        lineage,
        metric='correlation', cluster_method='average',
        n_clusters=dict_n_clust[lineage], max_na=10, eval_range=(2,10),
        save_clustered_file=clustered_tmp)
    # Same indentation as the call above: both statements belong to the loop body.
    adjast_table_full(clustered_tmp,
                      'sunflower_LTR_RT.all_contexts.meth.transposable_element.txt',
                      f'sunflower_LTR_RT.all_contexts.chrom_type.{lineage}.txt')

In [None]:
# Number of clusters requested for each lineage of the second data set (AT_WT).
dict_n_clust_ara = {
    'Ale' : 3,
    'Tork' : 3
}

In [None]:
for lineage in dict_n_clust_ara:
    # Temporary clustered table; adjast_table_full deletes it after use.
    clustered_tmp = f"docker_sandbox/sources/sunflower_TE_annot/AT_WT/SQK-NBD114-96_barcode24.all.{lineage}.txt.tmp"
    cluster_te_methylation_table(
        'docker_sandbox/sources/sunflower_TE_annot/AT_WT/SQK-NBD114-96_barcode24.all.fix.txt',
        lineage,
        metric='euclidean', cluster_method='ward',
        n_clusters=dict_n_clust_ara[lineage], max_na=10, eval_range=(2,10),
        save_clustered_file=clustered_tmp)
    # Same indentation as the call above: both statements belong to the loop body.
    adjast_table_full(clustered_tmp,
                      'docker_sandbox/sources/sunflower_TE_annot/AT_WT/SQK-NBD114-96_barcode24.all.fix.txt',
                      f'docker_sandbox/sources/sunflower_TE_annot/AT_WT/SQK-NBD114-96_barcode24.all.{lineage}.txt')

In [None]:
#compute methylation levels per specific region

In [None]:
# (bedMethyl input, output prefix) for the CG, CHG and CHH contexts, in that order.
# The third argument is a prefix: each result is written to
# "<prefix>_binned_methylation.txt".
locus_methylation_jobs = [
    ('SUNRISE_genomic.CG.meth1.bed', "selected.CG.100bp.txt"),
    ('SUNRISE_genomic.CHG.meth1.bed', "selected.CHG.100bp.txt"),
    ('SUNRISE_genomic.CHH.meth1.bed', "selected.CHH.100bp.txt"),
]

for bedmethyl_path, prefix in locus_methylation_jobs:
    compute_locus_methylation('selected.bed', bedmethyl_path, prefix)