# Notebook for compiling and analyzing fiber-seq peaks and chip-seq peaks together

In [1]:
import importlib
import pandas as pd
from multiprocessing import Pool, cpu_count # used for parallel processing
import subprocess
import os
import multiprocessing
import tempfile
from plotly.subplots import make_subplots
from qnorm import quantile_normalize
import numpy as np
# import go library
import plotly.graph_objects as go
from pathlib import Path
import nanotools
importlib.reload(nanotools) # reload nanotools module
# import numpy as np
# import plotly.io as pio
# import plotly
# import plotly.express as px # Used for plotting
# import plotly.graph_objects as go # Used for plotting
# from plotly.subplots import make_subplots
# import pywt # for wavelet transform
# import matplotlib.pyplot as plt # Use for plotting m6A frac and coverage plot
# from matplotlib import cm # Use for plotting m6A frac and coverage plot
# from qnorm import quantile_normalize
# #import tqdm
# #import pysam
# #import pyBigWig

# import other standard libraries

# print date and time



<module 'nanotools' from '/Data1/git/meyer-nanopore/scripts/analysis/nanotools.py'>

In [2]:
 ### BAM Configurations
# Per-chemistry modification-probability cut-offs, expressed as fractions (0-1).
R9_m6A_thresh_percent = 0.7
R10_m6A_thresh_percent = 0.8
R10_5mC_thresh_percent = 0.8  # Note: 0.7 in R9 ~ 0.9 in R10

# The same cut-offs converted to integers by multiplying by 258 and rounding.
# Original reference table: 129 = 50%; 181 = 70%; 194 = 75%; 207 = 80%; 232 = 90%
# (0.8 * 258 = 206.4, so the value computed here for 80% is 206).
R9_m6A_thresh, m6A_thresh, mC_thresh = (
    int(round(fraction * 258, 0))
    for fraction in (
        R9_m6A_thresh_percent,
        R10_m6A_thresh_percent,
        R10_5mC_thresh_percent,
    )
)

# modkit is used for aggregating methylation data from .bam files
# https://nanoporetech.github.io/modkit/quick_start.html
modkit_path = "/Data1/software/modkit_v0.3/modkit"
bedgraphtobigwig_path = "/Data1/software/ucsc_genome_browser/bedGraphToBigWig"
danpos_path = "/Data1/software/DANPOS3/danpos.py"
chrom_sizes = "/Data1/reference/chrom.sizes.ce11.txt"

# analysis_cond = [
#     "N2_mixed_endogenous_R10",
#     "N2_old_fiber_R10",
#     "SDC2degron_old_R10",
#     "DPY27_degron_old_R10",
#     "DPY21null_old_fiber_R10",
# ]

# Selects which metadata sheet is loaded further down: "COND" reads the
# per-condition sheet, anything else reads the per-batch sheet.
compare_type = "COND" # "COND" or "BATCH"

# Condition labels to analyse. Rows of the metadata sheet whose "conditions"
# column matches one of these labels are kept. Commented-out entries are
# alternative sample sets; uncomment to include them.
analysis_cond = [
    "N2_old_R10",
    "N2_young_R10",
    "N2_mid_R10",
    "SDC2deg_comb_R10",
    "SDC3deg_old_R10",
    "DPY27deg_comb_R10",
    "DPY21null_comb_R10",

    #"N2_rep1",
    #"N2_rep2",
    #"N2_rep3",
    #"SDC2deg_rep1",
    #"SDC2deg_rep2",
    #"SDC2deg_rep3",

    # "N2_young_SMACseq_R10",
    #"rex1_MEXIICtoG_R10",
    #"rex1_MEXIIscramble_R10"
    # "rex1_MEXIIscramble_R10",
    # "rex1_MEXIICtoG_R10"
    #"DPY27deg_rep1",
    #"DPY27deg_rep2",
    #"DPY27deg_rep3",
    #"rex1_MEXIIscramble_biorep0_fiber_old_R10_04_2025",
    # "rex1_MEXIIscramble_biorep1_fiber_old_R10_04_2025",
    #"rex1_MEXIIscramble_biorep2_fiber_old_R10_04_2025",
    #"rex1_MEXIIscramble_biorep3_fiber_old_R10_04_2025",
    #"rex1_4thCtoG_biorep0_fiber_old_R10_04_2025",
    # "rex1_4thCtoG_biorep1_fiber_old_R10_04_2025"]
#     "N2_rep1",
#     "SDC2deg_rep1",
#     "SDC2_degron_mid_T0_rep2_R10",
#     "SDC2_degron_mid_T0_rep3_R10",
#     "SDC2_degron_mid_T0p5_rep3_R10",
#     "SDC2_degron_mid_T1_rep3_R10",
#     "SDC2_degron_mid_T1p5_rep3_R10",
#     "SDC2_degron_mid_T2_rep2_R10",
#     "SDC2_degron_mid_T3_rep2_R10",
#     "SDC2_degron_mid_T4_rep3_R10"
]



#
# analysis_cond = [
#     "N2_mixed_endogenous_R10",
#     "N2_old_R10",
#     "SDC2deg_old_R10",
#     "DPY27deg_old_R10",
#     "DPY21null_old_R10",
# ]

# analysis_cond = [
#     #"N2_mixed_endogenous_R10",
#     "N2_fiber_old_R10_04_2025",
#     "SDC2deg_fiber_old_R10_04_2025",
#     "SDC3deg_fiber_old_R10_04_2025",
#     "SDC2_3deg_fiber_old_R10_04_2025",
#     "DPY27deg_fiber_old_R10_04_2025",
#     "DPY21null_fiber_old_R10_04_2025",
#     "96_DPY27degron_rep1_fiber_old_R10_052424",
# ]

### IMPORT BAM FILES AND METADATA FROM CSV FILE
# Pick the metadata sheet that matches the comparison mode.
metadata_path = (
    "/Data1/git/meyer-nanopore/scripts/bam_input_metadata_8_18_2025_COND_estim.txt"
    if compare_type == "COND"
    else "/Data1/git/meyer-nanopore/scripts/bam_input_metadata_4_28_2025_BATCH.txt"
)
input_metadata = pd.read_csv(metadata_path, sep="\t", header=0)

# Keep only the rows whose condition is in analysis_cond; every per-sample
# list below is taken from this one selection so they stay index-aligned.
selected_rows = input_metadata[input_metadata["conditions"].isin(analysis_cond)]
bam_files = selected_rows["bam_files"].tolist()
conditions = selected_rows["conditions"].tolist()
exp_ids = selected_rows["exp_id_date"].tolist()
flowcells = selected_rows["flowcell"].tolist()
ft_files = [bam.replace(".bam", "_ftools0p8.bed") for bam in bam_files]

n_samples = len(bam_files)
bam_fracs = [1] * n_samples  # For full .bam set to = 1
sample_indices = list(range(n_samples))

# Per-sample m6A threshold as a fraction: the R9 threshold when the flowcell
# label contains "R9", otherwise the R10 threshold.
thresh_list = [
    R9_m6A_thresh / 258 if "R9" in flowcell else m6A_thresh / 258
    for flowcell in flowcells
]

file_prefix = "08192025"
output_stem = "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/08192025/"

# for backwards compatibility
ext_exp_ids = ["None"] * len(exp_ids)

# Subsample bam based on bam_frac, used to accelerate testing
# if bam_frac = 1 will use original bam files, otherwise will save new subsampled bam files to output_stem.
new_bam_files = nanotools.parallel_subsample_bam(bam_files, conditions, bam_fracs, sample_indices, output_stem)


# args_list = [(bam_file, condition, bam_frac, sample_index, output_stem, exp_id) for
#              bam_file, condition, bam_frac, sample_index, exp_id in
#              zip(bam_files, conditions, bam_fracs, sample_indices, ext_exp_ids)]

# One argument tuple per sample; note the first element is the ft_files path,
# not the BAM path.
args_list = [
    (ft_file, condition, bam_frac, sample_index, output_stem, exp_id)
    for ft_file, condition, bam_frac, sample_index, exp_id in zip(
        ft_files, conditions, bam_fracs, sample_indices, exp_ids
    )
]

print("Program finished!")
print("new_bam_files: ", new_bam_files)
print("exp_ids: ", exp_ids)
print("conditions: ", conditions)

Program finished!
new_bam_files:  ['/Data1/seq_data/AG2_51dpy21null_fiberseq_1_16_24/no_sample/20240116_2226_X1_FAX30165_d8fc11b9/basecalls/mod_mappings.sorted.bam', '/Data1/seq_data/BN_96DPY27Deg_Fiber_Hia5_MCviPI_05_24_24/no_sample/basecalls/mod_mappings.sorted.bam', '/Data1/seq_data/BM_N2_old_Fiber_Hia5_MCVIPI_05_30_24/no_sample/20240530_1930_X2_FAX32001_56dbbc37/basecalls/mod_mappings.sorted.bam', '/Data1/seq_data/BS_N2_old_Hia5_MCVIPI_10_17_2024/no_sample/20241018_1815_X2_FAX31996_a75b7b27/basecalls/mod_mappings.sorted.bam', '/Data1/seq_data/BT_96DPY27Deg_old_Hia5_MCviPI_10_17_2024/no_sample/20241018_1752_X1_FAX20286_da16b0d3/basecalls/mod_mappings.sorted.bam', '/Data1/seq_data/CA_107_SDC2_degron_OLD_20250213/no_sample_id/basecalls/mod_mappings.sorted.bam', '/Data1/seq_data/CE1_1-14_19-22_fiber_2025_04_09/basecalls/SQK-NBD114-24_barcode02.sorted.bam', '/Data1/seq_data/CE1_1-14_19-22_fiber_2025_04_09/basecalls/SQK-NBD114-24_barcode03.sorted.bam', '/Data1/seq_data/CE1_1-14_19-22_fib

In [3]:
### Bed file configurations:
# Which of the three selection lists below drives the BED filtering.
sample_source = "chr_type" # "chr_type" or "type" or "chromosome"
chr_type_selected = ["X", "Autosome"]
# type_selected = "rex1" all the way through "rex48"

#type_selected = ["univ_nuc","rex48","rex35","rex40","rex34","rex33","rex8","rex43","rex45","rex23","rex32","rex16","rex41","rex14","rex2","Prex7","Prex30","rex47"] #"dnase_nfr_intergenic",

#

type_selected = ["univ_nuc"]#,'nDCC1', 'nDCC10', 'nDCC11', 'nDCC12', 'nDCC13', 'nDCC14', 'nDCC15', 'nDCC16', 'nDCC17', 'nDCC18', 'nDCC19', 'nDCC2', 'nDCC20', 'nDCC21', 'nDCC22', 'nDCC23', 'nDCC3', 'nDCC4', 'nDCC5', 'nDCC6', 'nDCC7', 'nDCC8', 'nDCC9', 'Prex1', 'Prex11', 'Prex14', 'Prex15', 'Prex16', 'Prex2', 'Prex20', 'Prex21', 'Prex22', 'Prex23', 'Prex25', 'Prex26', 'Prex27', 'Prex3', 'Prex30', 'Prex31', 'Prex6', 'Prex7', 'rex1', 'rex14', 'rex16', 'rex17', 'rex18', 'rex19', 'rex2', 'rex20', 'rex21', 'rex23', 'rex24', 'rex25', 'rex26', 'rex27', 'rex28', 'rex29', 'rex3', 'rex31', 'rex32', 'rex33', 'rex34', 'rex35', 'rex36', 'rex37', 'rex38', 'rex39', 'rex4', 'rex40', 'rex41', 'rex42', 'rex43', 'rex44', 'rex45', 'rex46', 'rex47', 'rex48', 'rex5', 'rex6', 'rex7', 'rex8']

#"LMN1_only_damID","EMR1_only_damID","her-1_FULL","sex-1_FULL"
#,"DPY27_chip_q2","DPY27_chip_q3","DPY27_chip_q4","intergenic_control"
#,"her-1_TSS","fem-1_TSS","fem-2_TSS","fem-3_TSS","sex-1_TSS"
# her-1_TSS/TES/FULL | TES_q1-4 | #TSS_q1-4 | strong/weak rex | whole_chr | 200kb_region | 50kb_region | center_DPY27_chip_albretton | gene_q1-q4 | MEX_motif | center_SDC2_chip_albretton | center_SDC3_chip_albretton |
# ATAC_seq_EXCL_dpy27_ol100 | ATAC_seq;DPY27_ol100
# "motif_weak_dcc_1","motif_weak_dcc_2","motif_weak_dcc_3","motif_weak_dcc_4","motif_weak_dcc_5","motif_weak_dcc_6","motif_weak_dcc_7","motif_weak_dcc_8","motif_weak_dcc_9","motif_weak_dcc_10"

max_regions = 100 # max regions to consider; 0 = full set;
chromosome_selected = ["CHROMOSOME_X","CHROMOSOME_V"]#,"CHROMOSOME_I", "CHROMOSOME_II", "CHROMOSOME_III", "CHROMOSOME_IV"]
strand_selected = ["+","-"] #+ and/or -
select_opp_strand = True #If you want to select both + and - strands for all regions set to True
down_sample_autosome = False # If you want to downsample autosome genes to match number of X genes set to True
if chr_type_selected == ["X"]:
    # Nothing to downsample against when only X is selected.
    down_sample_autosome = False
bed_file = "/Data1/reference/tss_tes_rex_combined_v30_WS235.bed"
bed_window = 2000   # +/- around bed elements.
intergenic_window = 2000 # +/- around intergenic regions
num_bins = 1000 #bins for metagene plot
mods = "a" # {A,CG,A+CG}

# Resolve `selection` from sample_source. The mapping holds the list objects
# themselves (no copies), so `selection` stays an alias of the chosen list.
# An unknown sample_source fails here instead of leaving `selection` undefined
# (or, in a notebook, silently stale from an earlier run).
_selection_by_source = {
    "chr_type": chr_type_selected,
    "type": type_selected,
    "chromosome": chromosome_selected,
}
if sample_source not in _selection_by_source:
    raise ValueError(
        f"sample_source must be one of {list(_selection_by_source)}, got {sample_source!r}"
    )
selection = _selection_by_source[sample_source]

# Filter input bed_file based on input parameters (e.g. chromosome, type, strand, etc.)
# Function saves a new filtered bed file to the same folder as the original bed file
# called temp_do_not_use_"type".bed
# NOTE(review): the captured output of this cell reports files named
# X.bed.gz / Autosome.bed.gz, not temp_do_not_use_*.bed — confirm the naming
# against nanotools.filter_bed_file.
importlib.reload(nanotools)  # pick up edits to nanotools.py without restarting the kernel
new_bed_files=nanotools.filter_bed_file(
    bed_file,
    sample_source,
    selection,
    chromosome_selected,
    chr_type_selected,
    type_selected,
    strand_selected,
    max_regions,
    bed_window,
    intergenic_window
)

# Build the modkit BED table from the filtered files and show a few rows.
modkit_bed_name_ext = "modkit_temp_ext.bed"
modkit_bed_df_ext = nanotools.generate_modkit_bed(new_bed_files, down_sample_autosome, select_opp_strand,modkit_bed_name_ext)
nanotools.display_sample_rows(modkit_bed_df_ext, 5)

print("Program finished!")
print("exp_ids: ", exp_ids)

# Lookup table built from the same filtered files; later code reads its
# "type", "chrom", "bed_start" and "bed_end" columns.
combined_bed_df_ext = nanotools.create_lookup_bed(new_bed_files)

nanotools.display_sample_rows(combined_bed_df_ext, 5)

# print count by type
print("Count by type: ", combined_bed_df_ext["type"].value_counts())


# Intergenic control regions, used for background methylation by condition plots.
# Create the output folder first so the BED can be written into it.
output_dir = Path(output_stem)
output_dir.mkdir(parents=True, exist_ok=True)
bed_path = output_dir / "intergenic_control.bed"

# ────────────────────  BUILD include-bed FROM DataFrame  ───────────────────
# An existing file is reused as-is; delete it to force a rebuild.
if not bed_path.exists():
    if "combined_bed_df_ext" not in globals():
        # `abort` is only defined in a later cell, so raise directly here
        # rather than failing with an unrelated NameError on `abort`.
        raise NameError("combined_bed_df_ext DataFrame not found in workspace.")
    bed_df = (
        combined_bed_df_ext
        .loc[combined_bed_df_ext["type"] == "intergenic_control",
             ["chrom", "bed_start", "bed_end"]]
        .copy()
    )
    if bed_df.empty:
        # Writing an empty file would be cached by the exists() check above
        # and reused on every later run, so skip the write and say so.
        print(f"[WARN] no 'intergenic_control' rows in combined_bed_df_ext; not writing {bed_path}")
    else:
        bed_df["bed_start"] = bed_df["bed_start"].astype(int)
        bed_df["bed_end"]   = bed_df["bed_end"].astype(int)
        bed_df.to_csv(bed_path, sep="\t", header=False, index=False)

# print all unique types
print("Unique types in combined_bed_df_ext: ", combined_bed_df_ext["type"].unique().tolist())


Filtering bed file...
Configs: chr_type ['X', 'Autosome'] ['CHROMOSOME_X', 'CHROMOSOME_V'] ['X', 'Autosome'] ['univ_nuc'] ['+', '-'] 100 2000 2000
Chromosome ends:             chromosome  start         end strand       type  chr-type
2         CHROMOSOME_I    0.0  15072434.0      +  whole_chr  Autosome
41396    CHROMOSOME_II    0.0  15279421.0      +  whole_chr  Autosome
86685   CHROMOSOME_III    0.0  13783801.0      +  whole_chr  Autosome
125206   CHROMOSOME_IV    0.0  17493829.0      +  whole_chr  Autosome
169772    CHROMOSOME_V    0.0  20924180.0      +  whole_chr  Autosome
227276    CHROMOSOME_X    0.0  17718942.0      +  whole_chr         X
Saved the following bedfiles: ['/Data1/reference/X.bed.gz', '/Data1/reference/Autosome.bed.gz']
Combined BED written to modkit_temp_ext.bed (400 rows)
| modkit_bed_df_ext | first 5, random 5 and last 5 out of total 400 rows.


Unnamed: 0,0,1,2,3,4,5
0,CHROMOSOME_X,171175.0,175175.0,+,univ_nuc,X
1,CHROMOSOME_X,201725.0,205725.0,+,univ_nuc,X
2,CHROMOSOME_X,223135.0,227135.0,+,univ_nuc,X
3,CHROMOSOME_X,319075.0,323075.0,+,univ_nuc,X
4,CHROMOSOME_X,507375.0,511375.0,+,univ_nuc,X
5,CHROMOSOME_V,20487375.0,20491375.0,+,univ_nuc,Autosome
6,CHROMOSOME_X,4525935.0,4529935.0,+,univ_nuc,X
7,CHROMOSOME_V,5620515.0,5624515.0,+,univ_nuc,Autosome
8,CHROMOSOME_V,7054675.0,7058675.0,+,univ_nuc,Autosome
9,CHROMOSOME_X,11703295.0,11707295.0,+,univ_nuc,X


Program finished!
exp_ids:  ['AG-23_11_30_23', 'BN_05_24_24_SMACseq_R10_rep1', 'BM_05_30_24_SMACseq_R10_rep1', 'BS_10_17_24_SMACseq_R10_rep2', 'BT_10_17_24_SMACseq_R10_rep2', 'CA_02_15_25', 'N2_biorep2_fiber_old_R10_04_2025_B', 'DPY21null_biorep1_fiber_old_R10_04_2025', 'DPY21null_biorep2_fiber_old_R10_04_2025', 'DPY27deg_biorep1_fiber_old_R10_04_2025_A', 'DPY27deg_biorep2_fiber_old_R10_04_2025_A', 'SDC2deg_biorep1_fiber_old_R10_04_2025_B', 'SDC2deg_biorep2_fiber_old_R10_04_2025_B', 'N2_biorep2_fiber_old_R10_04_2025_D', 'DPY27deg_biorep1_fiber_old_R10_04_2025_B', 'DPY27deg_biorep2_fiber_old_R10_04_2025_B', 'SDC2deg_biorep1_fiber_old_R10_04_2025_C', 'SDC2deg_biorep2_fiber_old_R10_04_2025_C', 'DPY21null_biorep1_fiber_old_R10_04_2025', 'N2_biorep1_fiber_young_R10_08_2025', 'N2_biorep1_fiber_young_R10_08_2025', 'N2_biorep1_fiber_mid_R10_08_2025', 'N2_biorep1_fiber_mid_R10_08_2025', 'SDC3deg_biorep1_fiber_old_R10_08_2025', 'SDC3deg_biorep1_fiber_old_R10_08_2025', 'N2_biorep1_fiber_young_R10

Unnamed: 0,chrom,bed_start,bed_end,bed_strand,type,chr_type
0,CHROMOSOME_V,91145.0,95145.0,+,univ_nuc,Autosome
1,CHROMOSOME_V,456325.0,460325.0,+,univ_nuc,Autosome
2,CHROMOSOME_V,795635.0,799635.0,+,univ_nuc,Autosome
3,CHROMOSOME_V,931065.0,935065.0,+,univ_nuc,Autosome
4,CHROMOSOME_V,969215.0,973215.0,+,univ_nuc,Autosome
5,CHROMOSOME_V,12519195.0,12523195.0,+,univ_nuc,Autosome
6,CHROMOSOME_V,7693645.0,7697645.0,+,univ_nuc,Autosome
7,CHROMOSOME_V,6160815.0,6164815.0,+,univ_nuc,Autosome
8,CHROMOSOME_X,223135.0,227135.0,+,univ_nuc,X
9,CHROMOSOME_X,14208735.0,14212735.0,+,univ_nuc,X


Count by type:  univ_nuc    200
Name: type, dtype: int64
Unique types in combined_bed_df_ext:  ['univ_nuc']


In [4]:
# ─────────────── Cell 2: full parity motif injection + exhaustive debug ───────────────
# Effects:
#   • Build motif rows from FIMO TSVs using LN(P) thresholds.
#   • Add both strands.
#   • Update in-memory: combined_bed_df_ext, combined_bed_df, modkit_bed_df_ext, modkit_bed_df.
#   • Overwrite filtered BEDs on disk (entries in `new_bed_files`) without calling filter_bed_file.
#   • Persist modkit *_ext and core BEDs to existing filenames.
#   • DOES NOT touch `modkit_path` (the modkit executable).

# ── Config ─────────────────────────────────────────────────────────────────────────────
# Motif types to inject into the BED tables in this cell.
add_motif_types = ["MEX_motif","MEXII_motif"]          # any subset of: "MEX_motif", "MEXII_motif", "motifC"; or None to skip
overwrite_existing_motif_types = True   # drop any existing rows of these motif types before insert
apply_bed_window_expansion = True       # expand motif intervals by ±bed_window prior to merge/write
DEBUG = True   # when False, _dbg() prints nothing
SHOW = 8       # number of DataFrame rows _dbg() prints (df.head(SHOW))

import os, subprocess, importlib, time
from pathlib import Path
import numpy as np
import pandas as pd

# Expected from Cell 1:
#   combined_bed_df_ext, new_bed_files, bed_window, output_stem (optional),
#   save_modkit_bed_to_temp (optional), create_modkit_bed_df (optional)
#   modkit_bed_name_ext (filename for *_ext), modkit_bed_name (filename for core)
#   nanotools with create_lookup_bed()

# FIMO result table for each motif type, all in one folder.
_FIMO_DIR = "/Data1/ext_data/motifs"
MOTIF_TSV = dict(
    MEX_motif=f"{_FIMO_DIR}/fimo_MEX_0.01.tsv",
    MEXII_motif=f"{_FIMO_DIR}/fimo_MEXII_0.01.tsv",
    motifC=f"{_FIMO_DIR}/fimo_motifc_0.01.tsv",
)
# Hits are kept when ln(p-value) is at or below this value for their motif type.
LN_P_THRESH = dict(MEX_motif=-14, MEXII_motif=-14, motifC=-9)

# On-disk filtered BED schema (what new_bed_files contain)
REQ = ["chromosome", "start", "end", "strand", "type", "chr-type"]

# ── Debug helpers ──────────────────────────────────────────────────────────────────────
def _dbg(msg, df=None):
    """Print a "[DBG]" line for *msg*; does nothing unless DEBUG is truthy.

    When *df* is a pandas DataFrame, its shape and its first SHOW rows
    (rendered without the index) are printed as well.
    """
    if DEBUG:
        print(f"[DBG] {msg}")
        if isinstance(df, pd.DataFrame):
            print(f"      shape={df.shape}")
            preview = df.head(SHOW).to_string(index=False)
            print(preview)

def _file_stat(p):
    """Return "mtime=<ctime> size=<bytes>" for path *p*.

    Any exception raised while stat-ing or formatting is reported as the
    string "(stat err: <exception>)" instead of propagating.
    """
    try:
        info = Path(p).stat()
        modified = time.ctime(info.st_mtime)
        return f"mtime={modified} size={info.st_size}"
    except Exception as err:
        return f"(stat err: {err})"

def _infer_chr_type(chrom):
    """Return "X" when *chrom* equals "CHROMOSOME_X", otherwise "Autosome"."""
    if chrom == "CHROMOSOME_X":
        return "X"
    return "Autosome"

# ── I/O helpers ────────────────────────────────────────────────────────────────────────
def _load_and_filter_fimo(tsv_path, ln_p_thresh):
    """Load a FIMO TSV and keep hits whose ln(p-value) is <= *ln_p_thresh*.

    Returns a DataFrame with columns chrom, bed_start, bed_end, strand.
    An empty frame with those columns is returned, after printing a message,
    when the file does not exist or lacks one of the required columns
    (sequence_name, start, stop, p-value, strand).
    """
    out_cols = ["chrom", "bed_start", "bed_end", "strand"]
    if not os.path.exists(tsv_path):
        print(f"[WARN] FIMO not found: {tsv_path}")
        return pd.DataFrame(columns=out_cols)
    raw = pd.read_csv(tsv_path, sep="\t", comment="#", skip_blank_lines=True)
    _dbg(f"FIMO raw: {os.path.basename(tsv_path)}", raw)
    need = ["sequence_name", "start", "stop", "p-value", "strand"]
    miss = [c for c in need if c not in raw.columns]
    if miss:
        print(f"[ERR] Missing columns in {tsv_path}: {miss}")
        return pd.DataFrame(columns=out_cols)

    # Threshold on the natural log of the p-value.
    raw["ln_p"] = np.log(raw["p-value"])
    before = len(raw)
    raw = raw[raw["ln_p"] <= ln_p_thresh].copy()
    print(f"[INFO] {os.path.basename(tsv_path)} rows: before={before} after_ln(p)={len(raw)}")

    # "chrX" style names become "CHROMOSOME_X" style names.
    raw["chrom"]     = raw["sequence_name"].astype(str).str.replace("chr", "CHROMOSOME_", regex=False)
    # NOTE(review): start/stop are copied unchanged; if the FIMO table is
    # 1-based inclusive, bed_start is one base off from BED convention — confirm.
    raw["bed_start"] = pd.to_numeric(raw["start"], errors="coerce")
    raw["bed_end"]   = pd.to_numeric(raw["stop"],  errors="coerce")
    # Test the *stripped* strand values, so " +" is kept as "+" rather than
    # being replaced by "." because the unstripped text failed the check.
    strand = raw["strand"].astype(str).str.strip()
    raw["strand"]    = strand.where(strand.isin(["+", "-"]), ".")
    raw = raw.dropna(subset=["chrom", "bed_start", "bed_end"])
    out = raw[out_cols].copy()
    _dbg("FIMO filtered→combined columns", out)
    return out

def _expand_windows(df, w):
    """Return a copy of *df* with bed_start/bed_end widened by *w* on each side.

    *df* itself is returned unchanged when apply_bed_window_expansion is
    falsy or *w* is None. Widened starts are floored at 0 and both
    coordinate columns come back as int64.
    """
    if not apply_bed_window_expansion or w is None:
        return df
    widened = df.copy()
    pad = int(w)
    widened["bed_start"] = (widened["bed_start"] - pad).astype(np.int64)
    widened["bed_end"] = (widened["bed_end"] + pad).astype(np.int64)
    # A window reaching past the chromosome start is cut at position 0.
    widened["bed_start"] = widened["bed_start"].clip(lower=0)
    return widened

def _combined_to_filtered_schema(df):
    """Map combined_bed_df_ext rows → filtered BED schema (REQ).

    Columns are renamed to the on-disk names, a missing "chr-type" column is
    derived from the chromosome name, start/end are cast to int64 and the
    text columns to str. Only the REQ columns are returned, in REQ order.
    """
    column_map = {
        "chrom": "chromosome",
        "bed_start": "start",
        "bed_end": "end",
        "bed_strand": "strand",
        "chr_type": "chr-type",
    }
    mapped = df.rename(columns=column_map).copy()
    if "chr-type" not in mapped.columns:
        mapped["chr-type"] = mapped["chromosome"].map(_infer_chr_type)
    for coord in ("start", "end"):
        mapped[coord] = pd.to_numeric(mapped[coord], errors="coerce").astype(np.int64)
    for text_col in ("chromosome", "strand", "type", "chr-type"):
        if text_col not in mapped:
            continue
        mapped[text_col] = mapped[text_col].astype(str)
    return mapped[REQ].copy()

def _read_filtered_bed(gz_path):
    """Read a header-less, tab-separated filtered BED into a DataFrame.

    Columns are named after REQ; start/end are cast to int64 and the
    remaining four columns to str.
    """
    table = pd.read_csv(gz_path, sep="\t", header=None, names=REQ)
    for coord in ("start", "end"):
        table[coord] = pd.to_numeric(table[coord], errors="coerce").astype(np.int64)
    text_cols = ("chromosome", "strand", "type", "chr-type")
    return table.astype({col: str for col in text_cols})

def _write_bgzip_tabix(df, gz_path):
    """Write *df* as a header-less TSV, bgzip it to *gz_path*, then tabix-index it.

    The plain-text BED is written next to the target (the path without its
    ".gz" suffix, or "<gz_path>.bed" when there is no such suffix).
    Compression goes to a temporary file that replaces *gz_path* only after
    bgzip exits successfully, so a failed run leaves the previous file intact.

    Raises:
        RuntimeError: bgzip exited with a non-zero status.
        OSError: bgzip could not be started or the files could not be written.
    """
    gz_path = str(gz_path)
    bed_path = gz_path[:-3] if gz_path.endswith(".gz") else gz_path + ".bed"
    df.to_csv(bed_path, sep="\t", header=False, index=False)

    tmp_gz = gz_path + ".tmp"
    try:
        with open(tmp_gz, "wb") as out_fh:
            bgzip = subprocess.run(
                ["bgzip", "-c", bed_path],
                stdout=out_fh, stderr=subprocess.PIPE, check=False,
            )
        if bgzip.returncode != 0:
            detail = bgzip.stderr.decode(errors="replace").strip()
            raise RuntimeError(f"bgzip failed for {bed_path} (exit {bgzip.returncode}): {detail}")
        os.replace(tmp_gz, gz_path)
    finally:
        # Only present here if compression failed before the replace.
        if os.path.exists(tmp_gz):
            os.remove(tmp_gz)

    # Indexing stays best-effort, but a failure is now reported, not hidden.
    tabix = subprocess.run(
        ["tabix", "-f", "-p", "bed", gz_path],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False,
    )
    if tabix.returncode != 0:
        detail = tabix.stderr.decode(errors="replace").strip()
        print(f"[WARN] tabix failed for {gz_path} (exit {tabix.returncode}): {detail}")

def _fallback_create_modkit_bed_df(filtered_df):
    """Build a six-column modkit-style table covering both strands.

    Takes chrom / bed_start / bed_end from *filtered_df* (coordinates cast to
    int64) and emits every interval once with "+" and once with "-" in
    column 5; columns 3 and 4 are filled with ".". Columns are labelled 0-5.
    Duplicate rows are dropped and the surviving rows keep the index labels
    they had after stacking.
    """
    coords = filtered_df[["chrom", "bed_start", "bed_end"]].copy()
    for col in ("bed_start", "bed_end"):
        coords[col] = pd.to_numeric(coords[col], errors="coerce").astype(np.int64)

    n_rows = len(coords)
    per_strand = []
    for strand_symbol in ("+", "-"):
        per_strand.append(pd.DataFrame({
            0: coords["chrom"].values,
            1: coords["bed_start"].values,
            2: coords["bed_end"].values,
            3: np.full(n_rows, ".", dtype=object),
            4: np.full(n_rows, ".", dtype=object),
            5: np.full(n_rows, strand_symbol, dtype=object),
        }))

    stacked = pd.concat(per_strand, ignore_index=True)
    stacked.columns = [0, 1, 2, 3, 4, 5]
    return stacked.drop_duplicates()

# ── Input snapshot ─────────────────────────────────────────────────────────────────────
# Fail early if the earlier cell that builds the lookup table has not been run.
if "combined_bed_df_ext" not in globals():
    raise NameError("combined_bed_df_ext not found from Cell 1")

# Echo the inputs this cell will work from. Optional names are read through
# globals().get so a missing one prints "(missing)" instead of raising.
print("[IN] add_motif_types:", add_motif_types)
print("[IN] overwrite_existing_motif_types:", overwrite_existing_motif_types)
print("[IN] apply_bed_window_expansion:", apply_bed_window_expansion)
print("[IN] bed_window:", globals().get("bed_window","(missing)"))
print("[IN] combined_bed_df_ext rows:", len(combined_bed_df_ext))
print("[IN] new_bed_files:", new_bed_files if "new_bed_files" in globals() else "(missing)")
if "new_bed_files" in globals():
    for p in new_bed_files:
        print("     ", p, "→", _file_stat(p))
print("[IN] modkit_bed_name_ext:", globals().get("modkit_bed_name_ext","(missing)"))
print("[IN] modkit_bed_name:", globals().get("modkit_bed_name","(missing)"))
if isinstance(globals().get("modkit_bed_name_ext",None), str) and Path(globals()["modkit_bed_name_ext"]).exists():
    print("     existing modkit_bed_name_ext:", _file_stat(globals()["modkit_bed_name_ext"]))
if isinstance(globals().get("modkit_bed_name",None), str) and Path(globals()["modkit_bed_name"]).exists():
    print("     existing modkit_bed_name:    ", _file_stat(globals()["modkit_bed_name"]))

# Pre counts by type for motifs (in-memory)
# add_motif_types may be None ("skip"); isin() rejects None, so substitute an
# empty list here and let the skip branch that follows handle None itself.
pre_counts = combined_bed_df_ext[combined_bed_df_ext["type"].isin(add_motif_types or [])]["type"].value_counts()
print("[IN] existing motif rows (memory):")
print(pre_counts.to_string() if not pre_counts.empty else "(none)")

# ── Build candidate motif rows in combined schema ──────────────────────────────────────
# Main motif-injection step. When add_motif_types is None nothing is changed.
# Otherwise: load and filter the FIMO hits for each requested motif type, add
# the new rows to combined_bed_df_ext, append them to the filtered BED files on
# disk, reload the lookup table from disk, then rebuild and save the modkit BEDs.
if add_motif_types is None:
    print("add_motif_types=None → skipping.")
else:
    # ensure type_selected exists
    if 'type_selected' not in globals():
        type_selected = []
        print("[INFO] created missing type_selected list.")

    # snapshot
    before_len = len(type_selected)
    # NOTE(review): before_sample is assigned but not read again in the
    # visible part of this cell.
    before_sample = selection if 'selection' in globals() else None
    print(f"[BEFORE] type_selected len={before_len} sample_source={globals().get('sample_source','(missing)')}")
    if before_len <= 20:
        print("         types:", type_selected)

    # add motifs if missing
    # (type_selected is extended in place, so the change is visible to any
    # other name bound to the same list)
    added = []
    for t in add_motif_types:
        if t not in type_selected:
            type_selected.append(t)
            added.append(t)
    if added:
        print(f"[UPDATED] appended to type_selected: {added}")
    else:
        print("[NOCHANGE] type_selected already contained motif types.")

    # One DataFrame of candidate rows per motif type that yields any hits.
    motif_blocks = []
    for key in add_motif_types:
        if key not in MOTIF_TSV:
            print(f"[WARN] Unknown motif key: {key} → skip")
            continue
        src = MOTIF_TSV[key]
        print(f"[STEP] Load FIMO for {key}: {src}  ({_file_stat(src) if os.path.exists(src) else 'missing'})  ln(p)≤{LN_P_THRESH[key]}")
        core = _load_and_filter_fimo(src, LN_P_THRESH[key])
        if core.empty:
            print(f"[INFO] {key}: 0 rows after filtering")
            continue
        core = _expand_windows(core, bed_window if "bed_window" in globals() else None)
        cand = core.rename(columns={"strand":"bed_strand"}).copy()  # keep original motif strand only
        cand["type"] = key

        # Only add chr_type when the existing table carries that column.
        if "chr_type" in combined_bed_df_ext.columns:
            cand["chr_type"] = cand["chrom"].map(_infer_chr_type)
        _dbg(f"{key} candidates (pre-merge)", cand)
        motif_blocks.append(cand)

    if not motif_blocks:
        print("[WRITE] No motif rows to add (after LN(P)/window).")
    else:
        motif_combined = pd.concat(motif_blocks, ignore_index=True)

        # ── In-memory update ───────────────────────────────────────────────────────────
        base_rows = len(combined_bed_df_ext)
        if overwrite_existing_motif_types:
            drop_n = int((combined_bed_df_ext["type"].isin(add_motif_types)).sum())
            print(f"[MEM] dropping existing motif rows: {drop_n}")
            combined_bed_df_ext = combined_bed_df_ext[~combined_bed_df_ext["type"].isin(add_motif_types)].reset_index(drop=True)

        # Anti-join with dtype harmony
        # Both sides get float64 coordinates and str text columns so the merge
        # keys compare equal regardless of how each table was loaded.
        join_cols = ["chrom","bed_start","bed_end","bed_strand","type"] + (["chr_type"] if "chr_type" in combined_bed_df_ext.columns else [])
        existing_keys = combined_bed_df_ext[join_cols].copy()
        for c in ("bed_start","bed_end"):
            if c in existing_keys.columns: existing_keys[c] = pd.to_numeric(existing_keys[c], errors="coerce").astype("float64")
        for c in ("chrom","bed_strand","type","chr_type"):
            if c in existing_keys.columns: existing_keys[c] = existing_keys[c].astype(str)

        mleft = motif_combined[join_cols].copy()
        # __rid__ records each candidate's row position so the rows that
        # survive the anti-join can be pulled back out of motif_combined.
        mleft["__rid__"] = np.arange(len(mleft))
        for c in ("bed_start","bed_end"):
            mleft[c] = pd.to_numeric(mleft[c], errors="coerce").astype("float64")
        for c in ("chrom","bed_strand","type","chr_type"):
            if c in mleft.columns: mleft[c] = mleft[c].astype(str)

        # "left_only" marks candidates with no matching row in the existing table.
        merged = mleft.merge(existing_keys, on=join_cols, how="left", indicator=True)
        keep_ids = merged.loc[merged["_merge"]=="left_only","__rid__"].tolist()
        motif_unique = motif_combined.iloc[keep_ids].copy()
        _dbg("motif_unique (memory-add)", motif_unique)

        if len(motif_unique):
            combined_bed_df_ext = pd.concat([combined_bed_df_ext, motif_unique], ignore_index=True)
            combined_bed_df_ext.drop_duplicates(subset=["chrom","bed_start","bed_end","bed_strand","type"], inplace=True)
            combined_bed_df_ext.reset_index(drop=True, inplace=True)

        # Report memory adds
        mem_delta = len(combined_bed_df_ext) - base_rows
        add_counts_mem = motif_unique["type"].value_counts() if len(motif_unique) else pd.Series(dtype=int)
        print(f"[MEM] combined_bed_df_ext rows(after)={len(combined_bed_df_ext)}  Δ={mem_delta}")
        for t in add_motif_types:
            print(f"[MEM] added {t}: {int(add_counts_mem.get(t, 0))}")

        # ── Disk parity: update each filtered BED in new_bed_files ─────────────────────
        # NOTE(review): overwrite_existing_motif_types only drops rows from the
        # in-memory table above; the disk update below only appends. Motif rows
        # already in these files stay there and come back when the table is
        # reloaded from disk further down — confirm this is intended.
        if "new_bed_files" not in globals() or not new_bed_files:
            bed_dir = str(Path(globals().get("bed_file","/Data1/reference")).parent)
            new_bed_files = [f"{bed_dir}/X.bed.gz", f"{bed_dir}/Autosome.bed.gz"]
            print(f"[WARN] new_bed_files not found; defaulting to: {new_bed_files}")

        # Prepare motif rows in filtered schema for disk write
        motif_filtered = _combined_to_filtered_schema(
            motif_unique.rename(columns={"chr_type":"chr-type"}) if "chr_type" in motif_unique.columns else motif_unique
        )

        disk_add_counts = {t:0 for t in add_motif_types}
        for gz in new_bed_files:
            if not os.path.exists(gz):
                print(f"[DISK] missing, skip: {gz}")
                continue
            before_stat = _file_stat(gz)
            existing = _read_filtered_bed(gz)
            # counts before
            print(f"[DISK] BEFORE {gz}: rows={len(existing)}  {before_stat}")
            for t in add_motif_types:
                c = int((existing["type"]==t).sum())
                print(f"       motif {t}: {c}")

            # choose chr-type bucket for this file
            # The bucket is guessed from the file name first ("X" is tested
            # before "Autosome"); otherwise the most common chr-type already in
            # the file is used, or "Autosome" when the file is empty.
            guess = "X" if "X" in Path(gz).stem else ("Autosome" if "Autosome" in Path(gz).stem else (existing["chr-type"].mode().iat[0] if not existing.empty else "Autosome"))
            to_add = motif_filtered[motif_filtered["chr-type"] == guess].copy()
            if to_add.empty:
                print(f"[DISK] {gz}: +0 (no rows for chr-type={guess})")
                continue
            # anti-join to avoid re-adding
            m = to_add.merge(existing, on=REQ, how="left", indicator=True)
            uniq = m[m["_merge"]=="left_only"][REQ].copy()
            _dbg(f"uniq rows destined for {gz}", uniq)
            if uniq.empty:
                print(f"[DISK] {gz}: +0 (all present)")
                continue

            added_per_type = uniq["type"].value_counts()
            for t,c in added_per_type.items():
                disk_add_counts[t] = disk_add_counts.get(t,0) + int(c)

            # Rows are sorted by chromosome and position before the file is
            # rewritten and re-indexed.
            updated = pd.concat([existing, uniq], ignore_index=True)
            updated.sort_values(["chromosome","start","end","type","strand"], inplace=True, kind="mergesort")
            _write_bgzip_tabix(updated, gz)
            after_stat = _file_stat(gz)
            print(f"[DISK] AFTER  {gz}: rows={len(updated)}  {after_stat}")
            for t in add_motif_types:
                c = int((updated["type"]==t).sum())
                print(f"       motif {t}: {c}")

        # Per-type totals written to disk
        for t in add_motif_types:
            print(f"[DISK] added {t}: +{int(disk_add_counts.get(t,0))} rows")

        # ── Reload combined_bed_df_ext from disk to guarantee parity ───────────────────
        # On any failure the in-memory table built above is kept instead.
        try:
            importlib.reload(nanotools)
            combined_bed_df_ext = nanotools.create_lookup_bed(new_bed_files)
            print(f"[SYNC] Reloaded combined_bed_df_ext from disk. rows={len(combined_bed_df_ext)}")
            post_counts = combined_bed_df_ext[combined_bed_df_ext["type"].isin(add_motif_types)]["type"].value_counts()
            print("[SYNC] motif counts in combined_bed_df_ext (from disk):")
            print(post_counts.to_string() if not post_counts.empty else "(none)")
        except Exception as e:
            print(f"[WARN] Could not rebuild combined_bed_df_ext from disk: {e}\nKeeping in-memory version.")

        # ── Mirror to combined_bed_df so downstream users of that name see same rows ───
        combined_bed_df = combined_bed_df_ext.copy()
        print(f"[SYNC] combined_bed_df set from combined_bed_df_ext. rows={len(combined_bed_df)}")

        # ── Rebuild modkit_bed_df_ext and modkit_bed_df, then persist ─────────────────
        # A notebook-level create_modkit_bed_df is used when one exists;
        # otherwise the local fallback builds the table.
        if "create_modkit_bed_df" in globals():
            modkit_bed_df_ext = create_modkit_bed_df(combined_bed_df_ext)
            modkit_bed_df     = create_modkit_bed_df(combined_bed_df)
        else:
            modkit_bed_df_ext = _fallback_create_modkit_bed_df(combined_bed_df_ext)
            modkit_bed_df     = _fallback_create_modkit_bed_df(combined_bed_df)

        print(f"[OUT] modkit_bed_df_ext rows={len(modkit_bed_df_ext)}")
        print(f"[OUT] modkit_bed_df     rows={len(modkit_bed_df)}")

        # Save *_ext
        # Without save_modkit_bed_to_temp, the file is written under
        # output_stem (or the current directory when that is unset).
        if "modkit_bed_name_ext" not in globals() or not isinstance(modkit_bed_name_ext, str):
            modkit_bed_name_ext = "modkit_temp_ext.bed"
        if "save_modkit_bed_to_temp" in globals():
            path_ext = save_modkit_bed_to_temp(modkit_bed_df_ext, modkit_bed_name_ext)
        else:
            out_dir = Path(globals().get("output_stem",".")) if globals().get("output_stem") else Path(".")
            out_dir.mkdir(parents=True, exist_ok=True)
            path_ext = str(out_dir / modkit_bed_name_ext)
            modkit_bed_df_ext.to_csv(path_ext, sep="\t", header=False, index=False)
            print(f"[OUT] wrote modkit_bed_df_ext → {path_ext}")
        print("      modkit_bed_df_ext file:", _file_stat(path_ext))

        # Save core
        if "modkit_bed_name" not in globals() or not isinstance(modkit_bed_name, str):
            modkit_bed_name = "modkit_temp.bed"
        if "save_modkit_bed_to_temp" in globals():
            path_core = save_modkit_bed_to_temp(modkit_bed_df, modkit_bed_name)
        else:
            out_dir = Path(globals().get("output_stem",".")) if globals().get("output_stem") else Path(".")
            out_dir.mkdir(parents=True, exist_ok=True)
            path_core = str(out_dir / modkit_bed_name)
            modkit_bed_df.to_csv(path_core, sep="\t", header=False, index=False)
            print(f"[OUT] wrote modkit_bed_df → {path_core}")
        print("      modkit_bed_df file:    ", _file_stat(path_core))

        # ── Final confirmation summary ────────────────────────────────────────────────
        print("\n[CONFIRM] Updated objects and files:")
        print("  • combined_bed_df_ext  →", len(combined_bed_df_ext), "rows")
        print("  • combined_bed_df      →", len(combined_bed_df), "rows")
        print("  • modkit_bed_df_ext    →", len(modkit_bed_df_ext), "rows  file:", path_ext)
        print("  • modkit_bed_df        →", len(modkit_bed_df), "rows  file:", path_core)
        for p in new_bed_files:
            print("  • filtered bed:", p, "→", _file_stat(p))

# print sample rows
# modkit_bed_df is only assigned when motif rows were actually processed, so
# look each table up by name and report a missing one instead of raising
# NameError at the very end of the cell.
for _table_name in ("modkit_bed_df", "combined_bed_df_ext"):
    print(f"{_table_name}.head()")
    if _table_name in globals():
        print(globals()[_table_name].head())
    else:
        print(f"({_table_name} is not defined in this session)")

[IN] add_motif_types: ['MEX_motif', 'MEXII_motif']
[IN] overwrite_existing_motif_types: True
[IN] apply_bed_window_expansion: True
[IN] bed_window: 2000
[IN] combined_bed_df_ext rows: 200
[IN] new_bed_files: ['/Data1/reference/X.bed.gz', '/Data1/reference/Autosome.bed.gz']
      /Data1/reference/X.bed.gz → mtime=Fri Aug 29 17:51:12 2025 size=901
      /Data1/reference/Autosome.bed.gz → mtime=Fri Aug 29 17:51:12 2025 size=926
[IN] modkit_bed_name_ext: modkit_temp_ext.bed
[IN] modkit_bed_name: (missing)
     existing modkit_bed_name_ext: mtime=Fri Aug 29 17:51:12 2025 size=20064
[IN] existing motif rows (memory):
(none)
[BEFORE] type_selected len=1 sample_source=chr_type
         types: ['univ_nuc']
[UPDATED] appended to type_selected: ['MEX_motif', 'MEXII_motif']
[STEP] Load FIMO for MEX_motif: /Data1/ext_data/motifs/fimo_MEX_0.01.tsv  (mtime=Wed Sep 25 22:13:21 2024 size=6353377)  ln(p)≤-14
[DBG] FIMO raw: fimo_MEX_0.01.tsv
      shape=(72919, 10)
motif_id          motif_alt_id sequenc

In [5]:
# ╔══════════════════════════════════════════════════════════════════════════╗
# ║  modkit sample-probs  ▸  per-sample cut-offs using include-bed filter    ║
# ║                                                                          ║
# ║  Outputs two dicts:                                                      ║
# ║    • m6A_thresh_dict  = { bam_path: threshold, … }                       ║
# ║    • m5mC_thresh_dict = { bam_path: threshold, … }                       ║
# ║                                                                          ║
# ║  All other logic (skip-if-exists, ascending search, verbose DEBUG)       ║
# ║  remains unchanged.                                                      ║
# ╚══════════════════════════════════════════════════════════════════════════╝
import os, subprocess, logging, sys
from multiprocessing import Pool, cpu_count
import pandas as pd
from tqdm import tqdm

force_replace = False  # False → reuse an existing *_probabilities.tsv instead of re-running modkit

# ─────────────────────────  LOGGING CONFIG  ────────────────────────────────
logging.basicConfig(
    level=logging.DEBUG,
    datefmt="%H:%M:%S",
    format="%(asctime)s [%(levelname)7s] %(message)s",
)


def abort(msg):
    """Log *msg* at CRITICAL level, then exit the interpreter with status 1."""
    logging.critical(msg)
    sys.exit(1)


# ──────────────────────────  INPUT CHECKS  ────────────────────────────────
# modkit must be set, exist as a regular file and be executable.
modkit_usable = (
    bool(modkit_path)
    and os.path.isfile(modkit_path)
    and os.access(modkit_path, os.X_OK)
)
if not modkit_usable:
    abort(f"modkit executable not found: {modkit_path}")

if not new_bam_files:
    abort("new_bam_files is empty.")

# Stop at the first BAM path that is not an existing file.
for p in new_bam_files:
    if os.path.isfile(p):
        continue
    abort(f"BAM not found: {p}")

# ─────────────────────  CONSTANTS & PATHS  ────────────────────────────────
# NOTE(review): MAX_CPUS is a fixed budget, not the machine's core count;
# POOL_SIZE * THREADS_PER_JOB may exceed cpu_count() — confirm this is intended.
MAX_CPUS = 500
THREADS_PER_JOB = min(64, MAX_CPUS, cpu_count())
POOL_SIZE = max(1, min(len(new_bam_files), MAX_CPUS // THREADS_PER_JOB))
SAMPLE_FRAC = 0.999
PROB_FLOOR_m6A = R10_m6A_thresh_percent
PROB_FLOOR_5mC = R10_5mC_thresh_percent

# All sample-probs output lives under <output_stem>/sample_probs
sample_probs_root = Path(output_stem).joinpath("sample_probs")
sample_probs_root.mkdir(parents=True, exist_ok=True)


# ────────────────────────  PREFIX GENERATION  ─────────────────────────────
def make_prefix(bam_path: str, idx: int) -> str:
    """Return ``"<idx>_<stem>"``: the index zero-padded to two digits, then the BAM file stem."""
    stem = Path(bam_path).stem
    return "{:02d}_{}".format(idx, stem)
# The index is each BAM's position in new_bam_files. run_sample_probs names its
# output/cache directory after this prefix, so cached TSVs are tied to list order.
prefixes = [make_prefix(p, i) for i, p in enumerate(new_bam_files)]

# ─────────────────────────  RUN sample-probs  ─────────────────────────────
def run_sample_probs(bam, prefix):
    """Run ``modkit sample-probs`` on one BAM and return the path of its probabilities TSV.

    Output is written to ``sample_probs_root / prefix``. If that directory already
    holds a ``*_probabilities.tsv`` and ``force_replace`` is false, the cached file
    is returned and modkit is not invoked.

    Args:
        bam: Path of the BAM file handed to modkit.
        prefix: Name of the output sub-directory; also passed to modkit as ``--prefix``.

    Returns:
        str: Path of the ``*_probabilities.tsv`` file.

    Raises:
        subprocess.CalledProcessError: modkit exited non-zero (its stderr is logged first).
        RuntimeError: modkit finished but the output directory does not contain
            exactly one ``*_probabilities.tsv``.
    """
    out_dir = sample_probs_root / prefix
    out_dir.mkdir(parents=True, exist_ok=True)

    # sorted() makes the choice deterministic should several TSVs be present.
    cached = sorted(out_dir.glob("*_probabilities.tsv"))
    if cached and not force_replace:
        logging.debug("%s: using cached TSV %s", prefix, cached[0].name)
        return str(cached[0])

    cmd = [
        modkit_path, "sample-probs", bam,
        "-f", str(SAMPLE_FRAC),
        "-t", str(THREADS_PER_JOB),
        "--include-bed", str(bed_path),
        "-o", str(out_dir),
        "--prefix", prefix,
        "--hist", "--force", "--suppress-progress",
    ]
    logging.debug("CMD: %s", " ".join(cmd))

    try:
        subprocess.run(cmd, check=True, stderr=subprocess.PIPE, text=True)
    except subprocess.CalledProcessError as e:
        logging.error("modkit failed with exit code %d", e.returncode)
        logging.error("modkit stderr:\n%s", (e.stderr or "").strip())
        # re-raise so the pool still sees the failure
        raise

    tsvs = sorted(out_dir.glob("*_probabilities.tsv"))
    if len(tsvs) != 1:
        # Raise rather than abort(): this function runs inside a multiprocessing
        # Pool worker, where sys.exit() kills the worker without delivering a
        # result and leaves the parent's starmap() waiting forever.
        message = f"Expected one *_probabilities.tsv in {out_dir}, found {len(tsvs)}"
        logging.critical(message)
        raise RuntimeError(message)
    return str(tsvs[0])


# One (bam, prefix) job per BAM. starmap() only returns once every job has
# finished, so the tqdm bar merely walks the completed result list.
sample_prob_jobs = list(zip(new_bam_files, prefixes))
with Pool(POOL_SIZE) as pool:
    finished_paths = pool.starmap(run_sample_probs, sample_prob_jobs)
    tsv_paths = [
        path
        for path in tqdm(finished_paths, total=len(new_bam_files), desc="sample-probs")
    ]

# ─────────────────────────  LOAD TABLES  ───────────────────────────────────
# One whitespace-delimited table per BAM, in new_bam_files order.
tables = [pd.read_csv(tsv, sep=r"\s+") for tsv in tsv_paths]

# ──────────  CANONICAL COUNTS AT/ABOVE base_prob_floor (default 0.70)  ──────────
def canonical_count(df, base, base_prob_floor=0.7):
    """Sum ``count`` over rows with ``code == base`` and ``range_start >= base_prob_floor``.

    Returns the total as a plain ``int`` (0 when no row matches).
    """
    is_base = df["code"] == base
    is_confident = df["range_start"] >= base_prob_floor
    selected_counts = df.loc[is_base & is_confident, "count"]
    return int(selected_counts.sum())
denom_A = [canonical_count(df, "A") for df in tables]
denom_C = [canonical_count(df, "C") for df in tables]

# ────────────────────  MODIFIED COUNTS & RATIOS @ 0.80  ────────────────────
def mod_above(df, mod_code, thresh):
    """Sum ``count`` over rows with ``code == mod_code`` and ``range_start >= thresh``.

    Returns the total as a plain ``int`` (0 when no row matches).
    """
    matches_code = df["code"] == mod_code
    passes_thresh = df["range_start"] >= thresh
    kept_counts = df.loc[matches_code & passes_thresh, "count"]
    return int(kept_counts.sum())
def _floor_ratio(table, mod_code, floor, denom):
    """Modified-call count at/above *floor* divided by *denom*; 0.0 when *denom* is falsy."""
    return mod_above(table, mod_code, floor) / denom if denom else 0.0


# Per-BAM modified/canonical ratio at the probability floor.
ratio_floor_A = [
    _floor_ratio(table, "a", PROB_FLOOR_m6A, denom)
    for table, denom in zip(tables, denom_A)
]
ratio_floor_C = [
    _floor_ratio(table, "m", PROB_FLOOR_5mC, denom)
    for table, denom in zip(tables, denom_C)
]

# The lowest ratio across BAMs is the target every sample is brought down to.
baseline_ratio_m6A = min(ratio_floor_A)
baseline_ratio_5mC = min(ratio_floor_C)
logging.info(
    "Baseline ratios  m6A=%.5f  5mC=%.5f",
    baseline_ratio_m6A,
    baseline_ratio_5mC,
)

# ───────────────────  ASCENDING THRESHOLD SEARCH  ──────────────────────────
def ascending_thresh(df, mod_code, denom, baseline, mod_prob_floor):
    """Return the lowest cut-off >= *mod_prob_floor* whose modified/denominator ratio is <= *baseline*.

    Counting convention (the same one ``mod_above`` uses): a cut-off ``t`` keeps
    every histogram row with ``range_start >= t``. The search walks the bins in
    ascending ``range_start`` order, dropping one bin at a time, and returns the
    start of the first bin that is still kept once the ratio has fallen to the
    baseline.

    Args:
        df: Histogram table with ``code``, ``range_start`` and ``count`` columns.
        mod_code: Value of ``code`` selecting the modification rows.
        denom: Denominator of the ratio; ``0`` short-circuits to the floor.
        baseline: Target ratio.
        mod_prob_floor: Lowest cut-off that may be returned.

    Returns:
        The selected cut-off. *mod_prob_floor* is returned when *denom* is 0, when
        no bin lies at/above the floor, or when the floor already meets the
        baseline. If the baseline is only met after dropping every bin, the start
        of the highest bin is returned (the strictest cut-off that still keeps
        any call).
    """
    if denom == 0:
        logging.debug("  denom=0 → %.5f", mod_prob_floor)
        return mod_prob_floor

    selected = df[(df["code"] == mod_code) & (df["range_start"] >= mod_prob_floor)]
    # Collapse rows sharing a range_start so each cut-off maps to exactly one bin.
    histogram = selected.groupby("range_start", sort=True)["count"].sum()
    starts = histogram.index.tolist()
    counts = histogram.tolist()

    remaining = sum(counts)
    if remaining / denom <= baseline:
        return mod_prob_floor

    for position, bin_count in enumerate(counts):
        remaining -= bin_count
        if remaining / denom <= baseline:
            # `remaining` now covers only bins above this one, so the matching
            # cut-off is the NEXT bin's start. Returning this bin's own start
            # would keep the bin under the `>=` convention and overshoot.
            if position + 1 < len(starts):
                return starts[position + 1]
            return starts[position]
    return mod_prob_floor

# NOTE: this rebinds m6A_thresh, which the BAM-configuration cell defines as a
# single integer cut-off; from here on it is a list with one value per BAM.
m6A_thresh, m5mC_thresh = [], []
for df, denA, denC in zip(tables, denom_A, denom_C):
    thr_m6A = ascending_thresh(df, "a", denA, baseline_ratio_m6A, PROB_FLOOR_m6A)
    thr_5mC = ascending_thresh(df, "m", denC, baseline_ratio_5mC, PROB_FLOOR_5mC)
    m6A_thresh.append(round(thr_m6A, 5))
    m5mC_thresh.append(round(thr_5mC, 5))

# ─────────────────────────  BUILD OUTPUT DICTS ─────────────────────────────
# Keyed by BAM path; both lists follow the order of new_bam_files.
m6A_thresh_dict = {bam_path: cutoff for bam_path, cutoff in zip(new_bam_files, m6A_thresh)}
m5mC_thresh_dict = {bam_path: cutoff for bam_path, cutoff in zip(new_bam_files, m5mC_thresh)}

# ─────────────────────────────  LOG RESULTS  ──────────────────────────────
logging.info("m6A thresholds per BAM:")
for bam, thr in m6A_thresh_dict.items():
    logging.info("  %s → %.5f", bam, thr)

logging.info("5mC thresholds per BAM:")
for bam, thr in m5mC_thresh_dict.items():
    logging.info("  %s → %.5f", bam, thr)

# m6A_thresh_dict and m5mC_thresh_dict are now available to later cells.


17:51:14 [  DEBUG] 00_mod_mappings.sorted: using cached TSV 00_mod_mappings.sorted_probabilities.tsv
17:51:14 [  DEBUG] 02_mod_mappings.sorted: using cached TSV 02_mod_mappings.sorted_probabilities.tsv
17:51:14 [  DEBUG] 01_mod_mappings.sorted: using cached TSV 01_mod_mappings.sorted_probabilities.tsv
17:51:14 [  DEBUG] 03_mod_mappings.sorted: using cached TSV 03_mod_mappings.sorted_probabilities.tsv
17:51:14 [  DEBUG] 04_mod_mappings.sorted: using cached TSV 04_mod_mappings.sorted_probabilities.tsv
17:51:14 [  DEBUG] 08_SQK-NBD114-24_barcode04.sorted: using cached TSV 08_SQK-NBD114-24_barcode04.sorted_probabilities.tsv
17:51:14 [  DEBUG] 07_SQK-NBD114-24_barcode03.sorted: using cached TSV 07_SQK-NBD114-24_barcode03.sorted_probabilities.tsv
17:51:14 [  DEBUG] 06_SQK-NBD114-24_barcode02.sorted: using cached TSV 06_SQK-NBD114-24_barcode02.sorted_probabilities.tsv
17:51:14 [  DEBUG] 05_mod_mappings.sorted: using cached TSV 05_mod_mappings.sorted_probabilities.tsv
17:51:14 [  DEBUG] 09_SQK

In [6]:
# Flatten the per-BAM threshold dicts into lists that follow the order of
# new_bam_files. Nothing is hard-coded here: the values are looked up from
# m6A_thresh_dict / m5mC_thresh_dict.
print(new_bam_files)
# m6A thresholds, one per BAM, in new_bam_files order
thresh_list = [m6A_thresh_dict[bam] for bam in new_bam_files]
print(thresh_list)
print(m6A_thresh_dict)

# 5mC thresholds, same ordering
m5mC_thresh_list = [m5mC_thresh_dict[bam] for bam in new_bam_files]
print(m5mC_thresh_list)

['/Data1/seq_data/AG2_51dpy21null_fiberseq_1_16_24/no_sample/20240116_2226_X1_FAX30165_d8fc11b9/basecalls/mod_mappings.sorted.bam', '/Data1/seq_data/BN_96DPY27Deg_Fiber_Hia5_MCviPI_05_24_24/no_sample/basecalls/mod_mappings.sorted.bam', '/Data1/seq_data/BM_N2_old_Fiber_Hia5_MCVIPI_05_30_24/no_sample/20240530_1930_X2_FAX32001_56dbbc37/basecalls/mod_mappings.sorted.bam', '/Data1/seq_data/BS_N2_old_Hia5_MCVIPI_10_17_2024/no_sample/20241018_1815_X2_FAX31996_a75b7b27/basecalls/mod_mappings.sorted.bam', '/Data1/seq_data/BT_96DPY27Deg_old_Hia5_MCviPI_10_17_2024/no_sample/20241018_1752_X1_FAX20286_da16b0d3/basecalls/mod_mappings.sorted.bam', '/Data1/seq_data/CA_107_SDC2_degron_OLD_20250213/no_sample_id/basecalls/mod_mappings.sorted.bam', '/Data1/seq_data/CE1_1-14_19-22_fiber_2025_04_09/basecalls/SQK-NBD114-24_barcode02.sorted.bam', '/Data1/seq_data/CE1_1-14_19-22_fiber_2025_04_09/basecalls/SQK-NBD114-24_barcode03.sorted.bam', '/Data1/seq_data/CE1_1-14_19-22_fiber_2025_04_09/basecalls/SQK-NBD114

In [None]:
### Raw data for chrom plotting
import importlib, os
import pandas as pd
import nanotools
importlib.reload(nanotools)

sampling_frac = 1
force_replace = True
regions_only = True  # keep name, but we WON'T pass bed_file

# Summary by region; bed_file is deliberately not passed to the helper.
region_summary_kwargs = dict(
    sampling_frac=sampling_frac,
    type_selected=type_selected,
    new_bam_files=new_bam_files,
    conditions=conditions,
    thresh_list=thresh_list,
    exp_ids=exp_ids,
    modkit_bed_df=modkit_bed_df_ext,   # per-interval loop
    force_replace=force_replace,
    output_directory="temp_files/",
)
summary_bam_df = nanotools.process_and_export_summary_by_region(**region_summary_kwargs)

# Normalize and inspect
summary_bam_df = nanotools.ensure_tidy_summary(summary_bam_df)
print("cols:", summary_bam_df.columns.tolist())
print(summary_bam_df.head())

# Attach the BED annotation columns, matching on (chromosome, start, end).
bed_annotation_cols = ['chrom', 'bed_start', 'bed_end', 'type', 'chr_type', 'bed_strand']
bed_to_summary_names = {'chrom': 'chromosome', 'bed_start': 'start', 'bed_end': 'end'}
region_annotations = combined_bed_df_ext[bed_annotation_cols].rename(columns=bed_to_summary_names)
summary_bam_df = summary_bam_df.merge(
    region_annotations,
    how='left',
    on=['chromosome', 'start', 'end'],
)

# Coverage tables for 'a' and 'm', then the per-chromosome plotting table for 'a'.
coverage_df_a = nanotools.create_coverage_df('a', summary_bam_df, coverage_df_name_a)
coverage_df_m = nanotools.create_coverage_df('m', summary_bam_df, coverage_df_name_m)
grouped_df_a = nanotools.prepare_chr_plotting_data(coverage_df_a)

In [None]:
### PLOT m6A by CHROM, region and condition

importlib.reload(nanotools)
print(summary_bam_df)
coverage_df_a = nanotools.create_coverage_df('a', summary_bam_df, coverage_df_name_a)
print(coverage_df_a.head())

# ════════════════════════════════════════════════════════════════════
#  Prepare data – 'a' only ('m' is ignored here)
# ════════════════════════════════════════════════════════════════════
grouped_df_a = nanotools.prepare_chr_plotting_data(coverage_df_a)

# ─── 1. total_reads = mod_pass + canon_pass ─────────────────────────
grouped_df_a['total_reads'] = grouped_df_a['mod_pass'].add(grouped_df_a['canon_pass'])


# ─── 2. per-condition m6A_frac, weighted by total_reads ─────────────
def _read_weighted_m6a(group):
    """Mean ``m6A_frac`` of one condition's rows, weighted by ``total_reads``."""
    weighted_sum = (group['m6A_frac'] * group['total_reads']).sum()
    return weighted_sum / group['total_reads'].sum()


global_m6a_by_cond = (
    grouped_df_a
    .groupby('condition')
    .apply(_read_weighted_m6a)
    .to_dict()
)

# expected columns: condition · exp_id · chromosome · start · m6A_frac · all_count …
# ────────────────────────────────────────────────────────────────────
#  HELPER AGGREGATES
# ────────────────────────────────────────────────────────────────────
from scipy.stats import norm
import pandas as pd, numpy as np, plotly.graph_objects as go
from plotly.colors import qualitative

PALETTE = qualitative.Plotly

# Split X vs autosomes for convenience
is_X  = grouped_df_a['chromosome'] == 'CHROMOSOME_X'
auto  = grouped_df_a.loc[~is_X]
x_chr = grouped_df_a.loc[is_X]



# ────────────────────────────────────────────────────────────────────
#  PLOT 1  – X / Autosome ratio per experiment (unchanged, width 800)
# ────────────────────────────────────────────────────────────────────
def plot_ratio_per_exp(x_df, auto_df):
    """Show a box plot of the X / autosome ratio of mean ``m6A_frac`` per experiment.

    Both frames are averaged per (condition, exp_id); the two means are divided
    index-wise and drawn as one box per condition on a log-scaled y axis.
    """
    group_keys = ['condition', 'exp_id']
    x_means = x_df.groupby(group_keys)['m6A_frac'].mean()
    auto_means = auto_df.groupby(group_keys)['m6A_frac'].mean()
    ratio = x_means.div(auto_means).reset_index(name='ratio')

    fig = go.Figure()
    for color_idx, cond in enumerate(ratio['condition'].unique()):
        rows = ratio.loc[ratio['condition'] == cond]
        box = go.Box(
            name=cond,
            x=[cond] * len(rows),
            y=rows['ratio'],
            text=rows['exp_id'],
            marker_color=PALETTE[color_idx % len(PALETTE)],
            boxpoints='all',
            jitter=0.4,
            pointpos=0,
            boxmean=False,
            hovertemplate="exp_id: %{text}<br>X/Auto: %{y:.2f}<extra></extra>",
        )
        fig.add_trace(box)

    layout_options = dict(
        template='plotly_white',
        width=800,
        title='X‑to‑Autosome %m6A ratio (per experiment)',
        xaxis_title='Condition',
        yaxis_title='Ratio (log)',
        yaxis_type='log',
    )
    fig.update_layout(**layout_options)
    fig.show()

plot_ratio_per_exp(x_chr, auto)

# ────────────────────────────────────────────────────────────────────
#  PLOT 2 & 3  – Smoothed tracks along X (raw & exp‑normalised)
# ────────────────────────────────────────────────────────────────────
def build_track(df, value_col):
    """Return a per-condition smoothed track of *value_col* along ``start``.

    Rows are sorted by ``start``; within each (condition, exp_id) group a centred
    rolling mean over ``SMOOTH_WINDOW`` rows (``min_periods=1``) is stored as
    ``smooth``. The smoothed values are then averaged per (condition, start) and
    returned with columns ``condition``, ``start`` and ``mean``.
    """
    def _with_smooth(group):
        # Centred rolling mean within one (condition, exp_id) group.
        window = group[value_col].rolling(SMOOTH_WINDOW, center=True, min_periods=1)
        return group.assign(smooth=window.mean())

    ordered = df.sort_values('start')
    smoothed = (
        ordered.groupby(['condition', 'exp_id'])
               .apply(_with_smooth)
               .reset_index(drop=True)
    )
    per_position = smoothed.groupby(['condition', 'start']).agg(mean=('smooth', 'mean'))
    return per_position.reset_index()

# Raw X-chromosome track.
track_raw  = build_track(x_chr, 'm6A_frac')

# Same track with every value divided by its condition's global m6A level.
# Vectorised Series.map instead of a row-wise DataFrame.apply: this also works
# on an empty frame (apply(axis=1) returns a DataFrame there) and matches the
# form used for the filtered data further down in this cell.
tmp        = x_chr.copy()
tmp['m6A_norm_exp'] = (
    tmp['m6A_frac'] / tmp['condition'].map(global_m6a_by_cond).astype(float)
)
track_norm = build_track(tmp, 'm6A_norm_exp')

def plot_tracks(track_raw, track_norm):
    """Show raw (solid) and normalised (dashed) smoothed tracks, one colour per condition.

    Both inputs need ``condition``, ``start`` and ``mean`` columns; the
    conditions drawn are those present in *track_raw*.
    """
    fig = go.Figure()
    for color_idx, cond in enumerate(track_raw['condition'].unique()):
        colour = PALETTE[color_idx % len(PALETTE)]
        raw_rows = track_raw.loc[track_raw['condition'] == cond]
        norm_rows = track_norm.loc[track_norm['condition'] == cond]

        raw_trace = go.Scatter(
            x=raw_rows['start'],
            y=raw_rows['mean'],
            name=f'{cond} raw',
            mode='lines',
            line=dict(color=colour),
        )
        fig.add_trace(raw_trace)

        norm_trace = go.Scatter(
            x=norm_rows['start'],
            y=norm_rows['mean'],
            name=f'{cond} norm',
            mode='lines',
            line=dict(color=colour, dash='dash'),
        )
        fig.add_trace(norm_trace)

    layout_options = dict(
        template='plotly_white',
        width=800,
        title='Smoothed m6A track along X (raw & exp‑normalised)',
        xaxis_title='Genomic start (bp)',
        yaxis_title='Mean %m6A',
    )
    fig.update_layout(**layout_options)
    fig.update_yaxes(tickformat='.1%')
    fig.show()

plot_tracks(track_raw, track_norm)

# ────────────────────────────────────────────────────────────────────
#  NEW PLOT 4  – Region boxes: X vs Autosome, faceted by exp_id
# ────────────────────────────────────────────────────────────────────
import plotly.express as px

def plot_region_box_by_condition(df):
    """Show one faceted box plot per condition comparing ``m6A_frac`` on X vs autosomes.

    Rows are labelled 'X' when ``chromosome == 'CHROMOSOME_X'`` and 'Autosome'
    otherwise; each figure is faceted by ``exp_id`` (four facets per row). The
    caller's frame is not modified.
    """
    chr_class = np.where(df['chromosome'] == 'CHROMOSOME_X', 'X', 'Autosome')
    annotated = df.assign(chr_class=chr_class)

    # Options shared by every per-condition figure.
    box_options = dict(
        x='chr_class',
        y='m6A_frac',
        facet_col='exp_id',
        facet_col_wrap=4,
        color='chr_class',
        color_discrete_map={'X': '#636EFA', 'Autosome': '#EF553B'},
        category_orders={'chr_class': ['Autosome', 'X']},
        points='all',
        boxmode='group',
        template='plotly_white',
        width=800,
        height=800,
    )

    for cond in annotated['condition'].unique():
        cond_rows = annotated[annotated['condition'] == cond]
        fig = px.box(cond_rows, **box_options)
        fig.update_yaxes(tickformat='.1%')
        fig.update_layout(
            title=f'Region %m6A: X vs Autosome – Condition: {cond}',
            showlegend=False,
        )
        fig.show()

# Drop low-coverage experiments, then redraw every plot on the filtered data.
# 1) Per-exp_id totals of mod_pass and canon_pass
exp_cov = (
    grouped_df_a
    .groupby('exp_id')
    .agg(mod_reads  = ('mod_pass','sum'),
         canon_reads= ('canon_pass','sum'))
)
exp_cov['total_reads'] = exp_cov['mod_reads'] + exp_cov['canon_reads']

# 2) Define a cutoff: half of the mean total_reads (mean - mean/2).
#    NOTE(review): std_tr is computed but never used; an earlier version of this
#    comment described a "mean – 2 × std" cutoff, which is NOT what the code does.
mean_tr = exp_cov['total_reads'].mean()
std_tr  = exp_cov['total_reads'].std(ddof=1)
cutoff  = mean_tr - 1/2*mean_tr
print(f"Filtering out any exp_id with total_reads < {cutoff:.0f}")

# 3) Keep only experiments at or above the cutoff
good_exps = exp_cov.query("total_reads >= @cutoff").index
filtered_df = grouped_df_a[grouped_df_a['exp_id'].isin(good_exps)].copy()

# 4) Report which exp_ids were dropped
dropped = set(grouped_df_a['exp_id']) - set(good_exps)
print(f"Dropped exp_ids due to low coverage: {dropped}")

# 5) Redraw the plots from filtered_df instead of grouped_df_a.
#    The normalised track still divides by global_m6a_by_cond, which was
#    computed from the unfiltered grouped_df_a.
plot_ratio_per_exp(filtered_df[filtered_df['chromosome']=='CHROMOSOME_X'],
                   filtered_df[filtered_df['chromosome']!='CHROMOSOME_X'])
plot_tracks(
    build_track(filtered_df[filtered_df['chromosome']=='CHROMOSOME_X'], 'm6A_frac'),
    build_track(
        filtered_df.assign(
            m6A_norm_exp=lambda df: df['m6A_frac']/df['condition'].map(global_m6a_by_cond)
        )
        .loc[lambda df: df['chromosome']=='CHROMOSOME_X'],
        'm6A_norm_exp'
    )
)
plot_region_box_by_condition(filtered_df)



In [None]:
# ────────────────── One point per BAM (per exp_id) + robust summary ──────────────────
import os
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
import importlib
import nanotools

importlib.reload(nanotools)

# Config
sampling_frac = 0.5
POINT_SIZE      = 10
LABEL_EXPID     = True   # True → show exp_id labels
EXP_LABEL_SIZE  = 8
force_replace   = False     # True → recompute and overwrite the cached summary table; False → reuse it if present
# assumes these already exist: new_bam_files, conditions, thresh_list, exp_ids, modkit_path, bed_path

# Cache file for the summary table. The name only encodes the first and last
# condition, sampling_frac and the first threshold, so different BAM sets that
# share those values map to the same file.
summary_table_name = (
    f"temp_files/_{conditions[0]}_{conditions[-1]}_"
    f"{sampling_frac}_thresh{thresh_list[0]}_summary_table.csv"
)

def _process_bam(args):
    """Summarise one BAM; never raises, returning an empty DataFrame on failure.

    *args* is a ``(bam, condition, threshold, exp_id)`` tuple. The summary comes
    from ``nanotools.get_summary_from_bam`` and is passed through
    ``nanotools.ensure_tidy_summary``. A ``None``/empty summary, or any exception
    (reported via ``print``), yields an empty DataFrame.
    """
    bam_path, condition, threshold, exp_id = args
    try:
        raw_summary = nanotools.get_summary_from_bam(
            sampling_frac, threshold, modkit_path,
            bam_path, condition, exp_id,
            thread_ct=min(24, cpu_count()),
            bed_file=bed_path,
        )
        has_rows = raw_summary is not None and not raw_summary.empty
        result = nanotools.ensure_tidy_summary(raw_summary) if has_rows else pd.DataFrame()
    except Exception as err:
        # Best effort: one failing BAM must not abort the whole batch.
        print(f"[WARN] Skipping {exp_id} ({condition}) due to error: {err}")
        result = pd.DataFrame()
    return result

# Reuse the cached summary table unless force_replace is set; otherwise build it
# from the BAMs in parallel and cache the result.
if not force_replace and os.path.exists(summary_table_name):
    summary_bam_df = pd.read_csv(summary_table_name, sep="\t")
else:
    args_iter = list(zip(new_bam_files, conditions, thresh_list, exp_ids))
    parts = []
    with Pool(processes=min(5, cpu_count())) as pool:
        for res in tqdm(pool.imap_unordered(_process_bam, args_iter),
                        total=len(args_iter), desc="Summarizing BAMs"):
            if res is not None and not res.empty:
                parts.append(res)
    if not parts:
        raise RuntimeError("All summaries empty. Check modkit stdout/stderr and bed filter.")
    if len(parts) < len(args_iter):
        # _process_bam swallows per-BAM failures; flag that the table written
        # below (and reused by later runs) is missing some BAMs.
        print(f"[WARN] {len(args_iter) - len(parts)} of {len(args_iter)} BAM summaries "
              f"were empty or failed; the cached table will be incomplete.")
    summary_bam_df = pd.concat(parts, ignore_index=True)
    # to_csv does not create missing directories; make sure the target exists so
    # the finished work is not lost on a fresh checkout.
    os.makedirs(os.path.dirname(summary_table_name) or ".", exist_ok=True)
    summary_bam_df.to_csv(summary_table_name, sep="\t", index=False)

# ───────── Compute per-BAM fractions (one row per exp_id × condition) ─────────
# m6A
_FRACTION_KEYS = ['exp_id', 'condition']


def _pass_count_total(selector, out_name):
    """Sum ``pass_count`` per (exp_id, condition) over the rows matching *selector*."""
    return (summary_bam_df.query(selector)
            .groupby(_FRACTION_KEYS, dropna=False)['pass_count']
            .sum()
            .reset_index(name=out_name))


# m6A: modified calls and canonical A calls
m6a = _pass_count_total("code == 'a'", 'total_m6a')
A   = _pass_count_total("base == 'A' and code == '-'", 'total_A')
# 5mC: modified calls and canonical C calls
mc5 = _pass_count_total("code == 'm'", 'total_5mc')
C   = _pass_count_total("base == 'C' and code == '-'", 'total_C')

# Outer-join the four totals; a combination missing from any table becomes 0.
df = m6a
for _totals in (A, mc5, C):
    df = df.merge(_totals, on=_FRACTION_KEYS, how='outer')
df = df.fillna(0)

# A zero denominator becomes NaN so the fraction is NaN instead of dividing by zero.
den_A = (df['total_A'] + df['total_m6a']).replace(0, np.nan)
den_C = (df['total_C'] + df['total_5mc']).replace(0, np.nan)
df['m6A_frac'] = df['total_m6a'].div(den_A)
df['5mC_frac'] = df['total_5mc'].div(den_C)

# ───────── Order conditions for x-axis ─────────
order_priority = ["endogenous", "N2", "SDC2", "SDC3", "DPY27", "DPY21"]
def _cond_key(cond):
    cl = cond.lower()
    for idx, kw in enumerate(order_priority):
        if kw.lower() in cl:
            return (idx, cond)
    return (len(order_priority), cond)

sorted_conds = sorted(df["condition"].dropna().unique(), key=_cond_key)

# ───────── Helpers: draw one marker per BAM; optional overlay of box summary ─────────
def _scatter_points(fig, ycol, title, yrange=None, dtick=None):
    """Add one marker per row of the module-level ``df`` plus an unfilled box per condition.

    Conditions are drawn in ``sorted_conds`` order; conditions without rows are
    skipped. *ycol* is the ``df`` column plotted on y; *yrange* and *dtick* are
    forwarded to the y axis. *fig* is modified in place.
    """
    marker_mode = "markers+text" if LABEL_EXPID else "markers"

    for cond in sorted_conds:
        cond_rows = df[df["condition"] == cond].copy()
        if cond_rows.empty:
            continue

        colour = nanotools.get_color(cond)
        x_values = [cond] * len(cond_rows)
        labels = cond_rows["exp_id"] if LABEL_EXPID else None

        # One marker per exp_id
        points = go.Scatter(
            x=x_values,
            y=cond_rows[ycol],
            mode=marker_mode,
            marker=dict(color=colour, size=POINT_SIZE, line=dict(width=0)),
            text=labels,
            textposition="top center",
            textfont=dict(size=EXP_LABEL_SIZE, color="black"),
            name=cond,
            showlegend=False,
        )
        fig.add_trace(points)

        # Unfilled box as a visual summary of the same values
        summary_box = go.Box(
            y=cond_rows[ycol],
            x=[cond] * len(cond_rows),
            name=cond,
            boxpoints=False,
            fillcolor="rgba(0,0,0,0)",
            line_color=colour,
            marker_color=colour,
            showlegend=False,
        )
        fig.add_trace(summary_box)

    fig.update_xaxes(categoryorder="array", categoryarray=sorted_conds)
    layout_options = dict(
        template="plotly_white",
        title=title,
        xaxis_title="Condition",
        yaxis_title="Fraction",
        width=800,
        height=700,
    )
    fig.update_layout(**layout_options)
    fig.update_yaxes(tickformat=".1%", dtick=dtick, range=yrange)

# ───────── Plots: one datapoint per BAM ─────────
fig_m6A = go.Figure()
_scatter_points(fig_m6A, "m6A_frac", "Intergenic m6A Fraction by Condition", yrange=[0, 0.12], dtick=0.02)
fig_m6A.show()

fig_5mC = go.Figure()
_scatter_points(fig_5mC, "5mC_frac", "Intergenic 5mC Fraction by Condition")
fig_5mC.show()


In [None]:
"""
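Compute physical genome coverage per exp_id_date by summing the
"total length" base count that samtools stats reports for each BAM
(bases of all processed reads, not only mapped ones), dividing by the
reference genome size (a 1.5× multiplier is applied to experiments whose
id contains "biorep"), and plotting results.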
Skips any experiment if its coverage file already exists,
unless FORCE_REPLACE is True.
Bars can be plotted either per-condition (summing multiple
experiments) or per-exp_id_date, selected via PLOT_BY.
"""

import os
import subprocess
import logging
from multiprocessing import Pool
from tqdm import tqdm
import pandas as pd
import plotly.graph_objects as go
from collections import defaultdict

import nanotools  # <-- for get_colors()

# ───────────────────────── Configuration ─────────────────────────
METADATA_TSV    = "/Data1/git/meyer-nanopore/scripts/bam_input_metadata_8_18_2025_COND_estim.txt"
REFERENCE_FA    = "/Data1/reference/c_elegans.WS235.genomic.fa"
NUM_PROCESSES   = 32
DEBUG_PROGRESS  = True
DEBUG_SUMMARY   = True

# If True, overwrite existing coverage files
FORCE_REPLACE   = False

# Plotting mode: choose "condition" (sum coverage per condition)
# or "exp_id_date" (one bar per experiment)
PLOT_BY = "condition"   # or "exp_id_date"
assert PLOT_BY in ("condition", "exp_id_date"), "PLOT_BY must be 'condition' or 'exp_id_date'"

# Directory to store outputs
COVERAGE_DIR = os.path.join(os.getcwd(), "coverage_calc")
os.makedirs(COVERAGE_DIR, exist_ok=True)

# ─────────────────── Load & Filter Metadata ──────────────────────
# `analysis_cond` must be defined elsewhere as your list of conditions
# Keep only the metadata rows whose `conditions` value is in analysis_cond.
input_metadata = pd.read_csv(METADATA_TSV, sep="\t", header=0)
filtered = input_metadata[input_metadata["conditions"].isin(analysis_cond)].copy()

if DEBUG_SUMMARY:
    print(f"Total BAMs: {len(input_metadata)}  Filtered: {len(filtered)}")
    print(f"Conditions: {filtered['conditions'].unique().tolist()}")
    print(f"Experiments: {filtered['exp_id_date'].unique().tolist()}")

# Parallel lists: one BAM path and one exp_id_date per filtered metadata row.
# NOTE(review): this rebinds the notebook-level name `exp_ids`, which earlier
# cells pass alongside new_bam_files — confirm the two lists stay compatible
# if those cells are re-run after this one.
bam_files = filtered["bam_files"].tolist()
exp_ids   = filtered["exp_id_date"].tolist()

# ─────────────────── Compute Genome Size ──────────────────────
# Genome size = sum of the second whitespace-separated field of every line in
# the reference's .fai index.
genome_size = 0
with open(REFERENCE_FA + ".fai") as fai:
    for line in fai:
        genome_size += int(line.split()[1])
if DEBUG_SUMMARY:
    print(f"Genome size: {genome_size:,} bases")

# ─────────────────── Helper Functions ──────────────────────
def group_bams_by_exp(bam_list, exp_list):
    """Group BAM paths by experiment id.

    *bam_list* and *exp_list* are walked in lockstep; the result is a
    ``defaultdict(list)`` mapping each experiment id to its BAMs in input order.
    """
    bams_by_exp = defaultdict(list)
    for exp_id, bam_path in zip(exp_list, bam_list):
        bams_by_exp[exp_id].append(bam_path)
    return bams_by_exp

def run_coverage(exp_id, bam_group):
    """Compute the coverage of one experiment and write it to ``<exp_id>.coverage.txt``.

    For every BAM in *bam_group* the value on the ``SN ... total length:`` line
    of ``samtools stats`` is added up; the sum is divided by ``genome_size`` and
    written under ``COVERAGE_DIR`` as a ``coverage`` header line followed by the
    value. An existing output file is kept unless ``FORCE_REPLACE`` is true.

    Args:
        exp_id: Experiment identifier; used in the output file name.
        bam_group: Iterable of BAM paths belonging to the experiment.

    Raises:
        subprocess.CalledProcessError: ``samtools stats`` exited non-zero.
    """
    out_txt = os.path.join(COVERAGE_DIR, f"{exp_id}.coverage.txt")
    if os.path.exists(out_txt):
        if not FORCE_REPLACE:
            logging.info(f"[{exp_id}] exists – skipping")
            return
        logging.info(f"[{exp_id}] exists but FORCE_REPLACE=True – overwriting")

    total_bases = 0
    for bam in bam_group:
        stats = subprocess.check_output(["samtools", "stats", bam], text=True)
        for line in stats.splitlines():
            if line.startswith("SN") and "total length:" in line:
                # Whitespace-split fields: "SN", "total", "length:", <value>, ...
                total_bases += int(line.split()[3])
                break
        else:
            # No summary line found: this BAM adds nothing. Say so instead of
            # silently under-reporting the experiment's coverage.
            logging.warning(
                "[%s] no 'total length:' line in samtools stats for %s; counted as 0 bases",
                exp_id, bam,
            )

    # only boost coverage for bioreps
    # NOTE(review): the 1.5x factor scales the reported coverage beyond what was
    # measured — confirm it is intended and document its rationale.
    multiplier = 1.5 if "biorep" in exp_id else 1.0
    coverage = total_bases / genome_size * multiplier

    with open(out_txt, "w") as out:
        out.write("coverage\n")
        out.write(f"{coverage:.6f}\n")

    logging.info(f"[{exp_id}] coverage={coverage:.3f}× (multiplier={multiplier})")


def compute_all_coverages():
    """Run ``run_coverage`` for every experiment in a pool of ``NUM_PROCESSES`` workers.

    BAMs are grouped by experiment id first. ``starmap`` returns only after all
    jobs are done, so the progress bar iterates the finished result list.
    """
    jobs = list(group_bams_by_exp(bam_files, exp_ids).items())
    if DEBUG_SUMMARY:
        logging.info(f"Processing {len(jobs)} experiments")

    with Pool(NUM_PROCESSES) as pool:
        finished = pool.starmap(run_coverage, jobs)
        progress = tqdm(
            finished,
            total=len(jobs),
            disable=not DEBUG_PROGRESS,
            desc="Coverage Jobs",
        )
        for _ in progress:
            pass

def plot_coverages():
    """Read every ``*.coverage.txt`` in ``COVERAGE_DIR`` and show a bar chart.

    With ``PLOT_BY == "exp_id_date"`` there is one bar per experiment, ordered by
    condition and then experiment id; otherwise coverage is summed per condition.
    Conditions are ordered as in ``analysis_cond``. When the directory holds no
    coverage file, a warning is logged and nothing is plotted.
    """
    # Map each exp_id_date back to its condition
    cond_map = filtered.set_index("exp_id_date")["conditions"].to_dict()

    suffix = ".coverage.txt"
    records = []
    for fn in os.listdir(COVERAGE_DIR):
        if not fn.endswith(suffix):
            continue
        # Strip only the trailing suffix (str.replace would also remove an
        # occurrence in the middle of the name).
        exp = fn[:-len(suffix)]
        # The file holds a "coverage" header line followed by the value.
        with open(os.path.join(COVERAGE_DIR, fn)) as handle:
            cov = float(handle.read().split()[1])
        records.append({
            "exp_id_date": exp,
            "coverage": cov,
            "condition": cond_map.get(exp)
        })

    if not records:
        logging.warning("No *.coverage.txt files found in %s – nothing to plot", COVERAGE_DIR)
        return

    df = pd.DataFrame(records)
    # Ordered categorical so sorting follows the order of analysis_cond.
    df["condition"] = pd.Categorical(df["condition"],
                                     categories=analysis_cond,
                                     ordered=True)

    if PLOT_BY == "exp_id_date":
        # One bar per experiment, ordered by condition then exp_id_date
        df.sort_values(["condition", "exp_id_date"], inplace=True)
        x = df["exp_id_date"]
        y = df["coverage"]
        text = df["coverage"].round(1).astype(str)
        xaxis_title = "Experiment ID Date"
    else:
        # Sum coverage per condition. observed=False is spelled out so every
        # category of analysis_cond keeps a bar regardless of the pandas default.
        summed = (
            df
            .groupby("condition", as_index=False, observed=False)["coverage"]
            .sum()
            .dropna(subset=["condition"])
        )
        summed.sort_values("condition", inplace=True)
        x = summed["condition"]
        y = summed["coverage"]
        text = summed["coverage"].round(1).astype(str)
        xaxis_title = "Condition"

    # get a color for each bar
    colors = nanotools.get_colors(x.tolist())

    fig = go.Figure(go.Bar(
        x=x,
        y=y,
        text=text,
        textposition="outside",
        marker_color=colors
    ))
    fig.update_layout(
        template="plotly_white",
        width=350,
        height=500,
        title="Physical Genome Coverage",
        xaxis_title=xaxis_title,
        yaxis_title="Coverage (×)",
        showlegend=False
    )
    # fixed y-axis range
    fig.update_yaxes(
        range=[0, 62]
    )

    fig.show()

# ─────────────────────── Main Entry ───────────────────────
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO if DEBUG_PROGRESS else logging.WARNING,
        format="%(asctime)s %(levelname)s: %(message)s"
    )
    compute_all_coverages()
    plot_coverages()

In [None]:
# ──────────────────────────────────────────────────────────────────────────────
# Motif integration with configurable "intergenic" types and readable structure
# ──────────────────────────────────────────────────────────────────────────────
import os
import re
from collections import defaultdict

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from intervaltree import IntervalTree

# ──────────────────────────────────────────────────────────────────────────────
# Configuration knobs
#   • INTERGENIC_CONTAINS lists substrings; rows whose type contains one are NOT re-centered.
#     Include already-centered motif rows like "MEX_motif" / "MEXII_motif" here.
#   • combine_motifs keeps your original “MOTIFS_” prefix behavior.
#   • include_nomotif_regions controls inclusion of no-overlap regions later.
# Notes:
#   • Variables like `bed_window` and `combined_bed_df_ext` are assumed defined earlier.
# ──────────────────────────────────────────────────────────────────────────────
# Rows whose type should NOT be re-centered if any of these substrings appear (case-insensitive)
INTERGENIC_CONTAINS = [
    "intergenic", "univ_nuc", "tss", "nfr", "mnase",
    "MEX_motif", "MEXII_motif"  # already centered → skip
]

center_on_motifs = True      # kept for readability; logic handled by intergenic mask below
combine_motifs = True        # keep original behavior for labeling motif-centered categories
include_noregion_motifs = False
include_nomotif_regions = True

# motif p-value ln thresholds
LN_P_THRESHOLDS = {'MEX': -13, 'MEXII': -12, 'motifC': -9}

# cluster neighborhood
window_size = 500  # uses your prior convention

# ──────────────────────────────────────────────────────────────────────────────
# Helper functions
# ──────────────────────────────────────────────────────────────────────────────

def build_interval_trees(df, chrom_col='chrom', start_col='bed_start', end_col='bed_end', payload_cols=('type','bed_start','bed_end')):
    """Build one ``IntervalTree`` per chromosome from an interval table.

    Every row of *df* is inserted into the tree keyed by its *chrom_col* value,
    spanning ``int(start)`` to ``int(end)``, with the row's *payload_cols* values
    stored as a tuple. Returns a ``defaultdict`` mapping chromosome to tree.
    """
    trees_by_chrom = defaultdict(IntervalTree)
    for _, record in df.iterrows():
        payload = tuple(record[col] for col in payload_cols)
        tree = trees_by_chrom[record[chrom_col]]
        begin = int(record[start_col])
        end = int(record[end_col])
        tree[begin:end] = payload
    return trees_by_chrom

def load_and_filter_fimo(fimo_paths):
    """Load FIMO TSVs, filter by motif-specific ln(p), and de-duplicate overlapping hits.

    Parameters
    ----------
    fimo_paths : iterable of str
        Tab-separated FIMO outputs. The columns read here are ``motif_id``,
        ``sequence_name``, ``start``, ``stop``, ``strand``, ``score`` and
        ``p-value``.

    Returns
    -------
    pandas.DataFrame
        Columns ``chr, start, stop, strand, score, p_value, motif_id,
        natural_log_p_value, id``, ordered by (chromosome, start). ``id`` is a
        fresh 0..n-1 row id.

    Notes
    -----
    Progress counts are printed before/after each filtering stage.
    """
    dfs = [pd.read_csv(p, sep='\t', comment='#', skip_blank_lines=True) for p in fimo_paths]
    fimo = pd.concat(dfs, ignore_index=True)
    print("fimo_df before filter:", len(fimo))

    # motif lengths (kept for transparency; dropped again by the column selection below)
    fimo['motif_length'] = fimo['stop'] - fimo['start']

    # ln(p) and per-motif thresholds; motifs without a threshold map to NaN,
    # and a comparison against NaN is False, so they are filtered out.
    fimo['natural_log_p_value'] = np.log(fimo['p-value'])
    ln_p_cutoff = fimo['motif_id'].map(LN_P_THRESHOLDS)
    fimo = fimo[fimo['natural_log_p_value'] <= ln_p_cutoff].copy()
    print("fimo_df after motif-specific ln(p) filter:", len(fimo))

    # Deduplicate overlaps within chrom by motif priority (smaller wins) then score
    motif_priority = {'MEXII': 2, 'MEX': 1, 'motifC': 3}
    fimo['motif_priority'] = fimo['motif_id'].map(motif_priority)
    fimo = fimo.dropna(subset=['motif_priority']).copy()
    fimo['motif_priority'] = fimo['motif_priority'].astype(int)
    fimo = fimo.sort_values(
        by=['sequence_name', 'start', 'motif_priority', 'score'],
        ascending=[True, True, True, False]
    )

    chrom_trees = defaultdict(IntervalTree)
    keep_idx = set()
    for idx, row in fimo.iterrows():
        chrom, st, en, pri = row['sequence_name'], int(row['start']), int(row['stop']), int(row['motif_priority'])
        tree = chrom_trees[chrom]
        overlaps = tree[st:en]
        # Already-kept hits that the current hit outranks.
        beaten = [iv for iv in overlaps if pri < iv.data[0]]
        if overlaps and not beaten:
            # Every overlapping kept hit has equal or better priority: drop this one.
            continue
        # NOTE(review): when the current hit outranks only SOME of its overlaps it
        # is still kept, so it can coexist with an overlapping higher-priority
        # hit. This mirrors the original behaviour — confirm it is intended.
        for iv in beaten:
            tree.remove(iv)
            keep_idx.discard(iv.data[1])
        tree[st:en] = (pri, idx)
        keep_idx.add(idx)

    # A set is not a valid .loc indexer in pandas >= 2.0; a boolean mask also
    # preserves the (chromosome, start) ordering established by the sort above.
    fimo = fimo.loc[fimo.index.isin(keep_idx)].copy()

    # Rename cols and harmonize chromosome naming
    fimo = fimo[['sequence_name', 'start', 'stop', 'strand', 'score', 'p-value', 'motif_id', 'natural_log_p_value']]
    fimo = fimo.rename(columns={'sequence_name': 'chr', 'p-value': 'p_value'})
    fimo['chr'] = fimo['chr'].str.replace('chr', 'CHROMOSOME_', regex=False)

    # Stable row id
    fimo.reset_index(drop=True, inplace=True)
    fimo['id'] = fimo.index
    print("fimo_df after dedup:", len(fimo))
    return fimo

def expand_fimo_to_regions(fimo_df, trees):
    """Attach every overlapping region to each motif hit (one output row per pair).

    ``trees`` maps chromosome -> IntervalTree whose interval data is a
    ``(type, bed_start, bed_end)`` tuple. For each motif, one copy of the row
    is emitted per overlapping region with ``chip_category``, ``bed_start`` and
    ``bed_end`` taken from that region. A motif with no overlap is emitted once
    with ``chip_category == 'none'`` and its own start/stop as the bed
    coordinates.
    """
    records = []
    for _, motif in fimo_df.iterrows():
        chrom = motif['chr']
        motif_start, motif_stop = int(motif['start']), int(motif['stop'])
        hits = trees.get(chrom, IntervalTree())[motif_start:motif_stop]

        if hits:
            placements = [
                (category, int(region_start), int(region_end))
                for category, region_start, region_end in (iv.data for iv in hits)
            ]
        else:
            placements = [('none', motif_start, motif_stop)]

        for category, region_start, region_end in placements:
            record = motif.copy()
            record['chip_category'] = category
            record['bed_start'] = region_start
            record['bed_end'] = region_end
            records.append(record)
    return pd.DataFrame(records)

def collapse_best_motif_per_region(expanded_df):
    """Reduce region-assigned motifs to the most significant hit per region.

    Rows whose ``chip_category`` is not ``'none'`` are ranked by ascending
    ``p_value`` and only the first row per ``(chr, bed_start, bed_end)`` is
    kept. Rows with ``chip_category == 'none'`` are all retained and appended
    after the winners. The result has a fresh 0..n-1 index.
    """
    unassigned = expanded_df['chip_category'] == 'none'

    ranked = expanded_df.loc[~unassigned].sort_values('p_value', ascending=True)
    winners = ranked.drop_duplicates(subset=['chr', 'bed_start', 'bed_end'], keep='first')

    return pd.concat([winners, expanded_df.loc[unassigned]], ignore_index=True)

def compute_cluster_counts(fimo_df, win):
    """Return a copy of ``fimo_df`` with a ``cluster_count`` column.

    ``cluster_count`` is the number of motif intervals on the same chromosome
    that overlap ``[start - win, stop + win)``. The motif's own interval is in
    the tree as well, so it is included in the count.
    """
    motif_trees = defaultdict(IntervalTree)
    for _, motif in fimo_df.iterrows():
        motif_trees[motif['chr']][int(motif['start']):int(motif['stop'])] = motif['id']

    def neighbours_in_window(motif):
        lower = int(motif['start']) - win
        upper = int(motif['stop']) + win
        chrom_tree = motif_trees.get(motif['chr'], IntervalTree())
        return len(chrom_tree[lower:upper])

    counted = fimo_df.copy()
    counted['cluster_count'] = counted.apply(neighbours_in_window, axis=1)
    return counted

def select_nonredundant_positions(fimo_df, win):
    """Greedily pick motifs so that kept starts on a chromosome are >= ``win`` apart.

    Within each chromosome the motifs are visited in order of ascending
    ``natural_log_p_value`` (most significant first). A motif is kept only if
    its ``start`` is at least ``win`` away from every start already kept on
    that chromosome.

    Parameters
    ----------
    fimo_df : pandas.DataFrame
        Needs ``chr``, ``start`` and ``natural_log_p_value`` columns.
    win : int
        Minimum allowed distance between two kept start coordinates.

    Returns
    -------
    pandas.DataFrame
        The kept rows (original index labels), sorted by ``chr`` then ``start``.
    """
    import bisect  # local import: keeps this cell self-contained

    ranked = fimo_df.sort_values(by=['chr', 'natural_log_p_value'], ascending=[True, True])
    keep = []
    for _, grp in ranked.groupby('chr', sort=False):
        chosen_starts = []  # kept starts for this chromosome, maintained sorted
        for label, start in zip(grp.index, grp['start']):
            st = int(start)
            # Only the nearest kept start on each side can be closer than `win`.
            pos = bisect.bisect_left(chosen_starts, st)
            too_close = (
                (pos < len(chosen_starts) and chosen_starts[pos] - st < win)
                or (pos > 0 and st - chosen_starts[pos - 1] < win)
            )
            if not too_close:
                chosen_starts.insert(pos, st)
                keep.append(label)
    return ranked.loc[keep].sort_values(by=['chr', 'start']).copy()

def add_missing_nomotif_regions(expanded_df, combined_bed_df_ext):
    """Append a placeholder row for every region that no motif row covers.

    A region of ``combined_bed_df_ext`` counts as covered when ``expanded_df``
    has a row with the same chromosome, ``bed_start`` and
    ``chip_category == type``. The chromosome is part of the key so that
    regions on different chromosomes sharing a start coordinate and type are
    not mistaken for each other.

    Parameters
    ----------
    expanded_df : pandas.DataFrame
        Motif rows; needs ``chr``, ``bed_start``, ``chip_category`` and ``id``.
    combined_bed_df_ext : pandas.DataFrame
        Regions; needs ``chrom``, ``bed_start``, ``bed_end`` and ``type``.

    Returns
    -------
    pandas.DataFrame
        ``expanded_df`` itself when nothing is missing; otherwise a new frame
        with the placeholder rows appended. A placeholder has
        ``start == stop == bed_start == bed_end ==`` the region midpoint,
        ``motif_id='noregion'``, ``p_value=1.0``, ``score=0``, ``strand='.'``,
        ``cluster_count=0`` and ids continuing after ``expanded_df['id'].max()``.
    """
    # Same chromosome-name normalisation that is applied to the FIMO rows.
    region_chr = combined_bed_df_ext['chrom'].str.replace('chr', 'CHROMOSOME_', regex=False)

    covered = set(zip(expanded_df['chr'], expanded_df['bed_start'], expanded_df['chip_category']))
    region_keys = zip(region_chr, combined_bed_df_ext['bed_start'], combined_bed_df_ext['type'])
    is_missing = np.array([key not in covered for key in region_keys], dtype=bool)

    miss = combined_bed_df_ext.loc[is_missing]
    if miss.empty:
        return expanded_df

    midpoint = (miss['bed_start'] + miss['bed_end']) // 2
    first_id = 0 if expanded_df.empty else int(expanded_df['id'].max()) + 1

    placeholders = pd.DataFrame(
        {
            'chr': region_chr[is_missing],
            'start': midpoint,
            'stop': midpoint,
            'strand': '.',
            'score': 0,
            'p_value': 1.0,
            'motif_id': 'noregion',
            'id': np.arange(first_id, first_id + len(miss)),
            'chip_category': miss['type'],
            # bed_start/bed_end collapse to the midpoint so that later
            # re-centering on (bed_start + bed_end) // 2 recovers it.
            'bed_start': midpoint,
            'bed_end': midpoint,
            'cluster_count': 0,
        },
        index=miss.index,
    )
    return pd.concat([expanded_df, placeholders], ignore_index=True)

def recenter_motif_rows(expanded_df, bed_window, combine_motifs=True):
    """
    Rebuild fixed-width windows for every row and derive the output tables.

    Centre selection (rows are classified by a case-insensitive substring
    match of ``chip_category`` against the module-level INTERGENIC_CONTAINS):
      • non-matching ("motif") rows → centre = ``start``
      • matching rows               → centre = floor((bed_start + bed_end) / 2)
    Window: bed_start = max(centre - bed_window, 0);
            bed_end   = bed_start + 2*bed_window + 1.
    Labels for motif rows:
      • combine_motifs      → 'MOTIFS_' + chip_category
      • not combine_motifs  → '<motif_id>_<chip_category>'
    When the module-level ``include_noregion_motifs`` is falsy, rows whose
    (relabelled) chip_category contains 'none' are dropped.

    Returns ``(df, combined_cat, clust)``: the updated frame, a BED-like table
    typed by chip_category, and a BED-like table typed by 'clust_<cluster_count>'.
    """
    df = expanded_df.copy()

    skip_pattern = "|".join(re.escape(token) for token in INTERGENIC_CONTAINS)
    region_rows = df['chip_category'].astype(str).str.contains(skip_pattern, case=False, na=False)
    motif_rows = ~region_rows

    # Window centres, held as float64 until the final integer cast.
    center = pd.Series(np.nan, index=df.index, dtype='float64')
    center.loc[motif_rows] = df.loc[motif_rows, 'start'].astype('float64')
    region_mid = ((df['bed_start'].astype('float64') + df['bed_end'].astype('float64')) / 2.0).floordiv(1.0)
    center.loc[region_rows] = region_mid.loc[region_rows].astype('float64')

    # Left edge is clamped at 0; the right edge follows from the fixed width.
    window_start = np.clip(center - bed_window, 0, None)
    window_end = window_start + (2 * bed_window) + 1
    df['bed_start'] = window_start.astype('int64')
    df['bed_end'] = window_end.astype('int64')

    # Relabel motif rows.
    base_label = df.loc[motif_rows, 'chip_category'].astype(str)
    if combine_motifs:
        df.loc[motif_rows, 'chip_category'] = 'MOTIFS_' + base_label
    else:
        df.loc[motif_rows, 'chip_category'] = df.loc[motif_rows, 'motif_id'].astype(str) + "_" + base_label

    if not include_noregion_motifs:
        has_none = df['chip_category'].str.contains('none', case=False, na=False)
        df = df[~has_none].copy()

    def _chr_type(chrom):
        return 'X' if chrom == 'CHROMOSOME_X' else 'Autosome'

    # Category-typed windows.
    combined_cat = df[['chr', 'bed_start', 'bed_end', 'strand', 'chip_category']].copy().rename(
        columns={'chr': 'chrom', 'strand': 'bed_strand', 'chip_category': 'type'}
    )
    combined_cat['chr_type'] = combined_cat['chrom'].apply(_chr_type)

    # Cluster-count-typed windows (same coordinates as combined_cat).
    clust = df[['chr', 'strand', 'cluster_count']].copy().rename(
        columns={'chr': 'chrom', 'strand': 'bed_strand'}
    )
    clust['bed_start'] = combined_cat['bed_start'].values
    clust['bed_end'] = combined_cat['bed_end'].values
    clust['type'] = 'clust_' + df['cluster_count'].astype(str).values
    clust['chr_type'] = clust['chrom'].apply(_chr_type)
    clust = clust.drop(columns='cluster_count')

    return df, combined_cat, clust

def finalize_mex_bins(cat_df):
    """Return a copy of ``cat_df`` with the MEX decile labels in ``type`` merged.

    'MEX_D1'..'MEX_D5' become 'MEX_D1to5' and 'MEX_D6'..'MEX_D9' become
    'MEX_D6to9'; every other value is left as is. The input is not modified.
    """
    merged_label = {}
    for label, deciles in (('MEX_D1to5', range(1, 6)), ('MEX_D6to9', range(6, 10))):
        for decile in deciles:
            merged_label[f'MEX_D{decile}'] = label

    result = cat_df.copy()
    result['type'] = result['type'].replace(merged_label)
    return result

def create_modkit_bed_df(filtered_df):
    """
    Build a six-column BED table (both strands) from chrom/bed_start/bed_end.

    Every input row is emitted once with strand '+' and once with strand '-';
    columns 3 and 4 are the placeholder '.'. Exact duplicate rows are removed.
    Columns are labelled 0..5. The input row count is printed.
    """
    print("rows in combined_bed_df:", len(filtered_df))

    def _bed_for_strand(sign):
        return pd.DataFrame({
            0: filtered_df['chrom'],
            1: filtered_df['bed_start'].astype(int),
            2: filtered_df['bed_end'].astype(int),
            3: '.',
            4: '.',
            5: sign,
        })

    both_strands = pd.concat([_bed_for_strand('+'), _bed_for_strand('-')], ignore_index=True)
    both_strands.columns = range(both_strands.shape[1])
    return both_strands.drop_duplicates()

def save_modkit_bed_to_temp(modkit_bed_df, filename,
                            temp_dir="/Data1/git/meyer-nanopore/scripts/analysis/temp_files/"):
    """Write ``modkit_bed_df`` as a headerless, tab-separated BED file.

    Parameters
    ----------
    modkit_bed_df : pandas.DataFrame
        Table to write; the index and the header are not written.
    filename : str
        File name created inside ``temp_dir``.
    temp_dir : str, optional
        Output directory, created if it does not exist. Defaults to the
        project's ``temp_files`` directory.

    Returns
    -------
    str
        Path of the written file.
    """
    os.makedirs(temp_dir, exist_ok=True)
    path = os.path.join(temp_dir, filename)
    modkit_bed_df.to_csv(path, sep='\t', header=False, index=False)
    print(f"Modkit BED file saved to: {path}")
    return path

# ──────────────────────────────────────────────────────────────────────────────
# 1) Load + filter FIMO
# ──────────────────────────────────────────────────────────────────────────────
fimo_files = [
    "/Data1/ext_data/motifs/fimo_MEX_0.01.tsv",
    "/Data1/ext_data/motifs/fimo_MEXII_0.01.tsv",
    "/Data1/ext_data/motifs/fimo_motifc_0.01.tsv"
]
fimo_df = load_and_filter_fimo(fimo_files)

# ──────────────────────────────────────────────────────────────────────────────
# 2) Expand motifs onto regions and collapse to best per region
#    (`combined_bed_df_ext` is assumed to be defined by an earlier cell.)
# ──────────────────────────────────────────────────────────────────────────────
region_trees = build_interval_trees(
    combined_bed_df_ext,
    chrom_col='chrom', start_col='bed_start', end_col='bed_end',
    payload_cols=('type','bed_start','bed_end')
)
fimo_expanded_df = expand_fimo_to_regions(fimo_df, region_trees)
fimo_expanded_df = collapse_best_motif_per_region(fimo_expanded_df)

# ──────────────────────────────────────────────────────────────────────────────
# 3) Cluster counts and non-redundant selection
# ──────────────────────────────────────────────────────────────────────────────
fimo_df_cc = compute_cluster_counts(fimo_df[['chr','start','stop','id','natural_log_p_value']].copy(), window_size)
fimo_df_cc = select_nonredundant_positions(fimo_df_cc, window_size)

# merge cluster_count back to expanded rows by id
# (inner join: expanded rows whose motif was not selected as non-redundant are dropped)
fimo_expanded_df = fimo_expanded_df.merge(
    fimo_df_cc[['id','cluster_count']],
    on='id', how='inner'
)

# ──────────────────────────────────────────────────────────────────────────────
# 4) Optionally add missing region types that had no motif overlaps
# ──────────────────────────────────────────────────────────────────────────────
if include_nomotif_regions:
    fimo_expanded_df = add_missing_nomotif_regions(fimo_expanded_df, combined_bed_df_ext)

# ──────────────────────────────────────────────────────────────────────────────
# 5) Rebuild fixed-width windows for ALL rows: motif rows are centred on the
#    motif start, rows matching INTERGENIC_CONTAINS on their region midpoint.
#    (`bed_window` is assumed to be defined by an earlier cell.)
# ──────────────────────────────────────────────────────────────────────────────
fimo_expanded_df, combined_bed_df_mex_cat, combined_bed_df_mex_clust = recenter_motif_rows(
    fimo_expanded_df, bed_window=bed_window, combine_motifs=combine_motifs
)

# collapse MEX bins as in your original code
combined_bed_df_mex_cat = finalize_mex_bins(combined_bed_df_mex_cat)

print("\ncombined_bed_df_mex_clust:")
print(combined_bed_df_mex_clust.head())

# ──────────────────────────────────────────────────────────────────────────────
# 6) Final combined_bed_df and optional MEX_none handling
#    When include_noregion_motifs is set, 'MEX_none' rows are down-sampled to at
#    most 100 (fixed seed); otherwise they are removed.
#    NOTE(review): recenter_motif_rows labels no-region motifs 'MOTIFS_none'
#    (combine_motifs=True) or '<motif_id>_none', and already drops every label
#    containing 'none' when include_noregion_motifs is False, so the exact
#    'MEX_none' match below looks like it rarely fires — confirm.
# ──────────────────────────────────────────────────────────────────────────────
combined_bed_df = combined_bed_df_mex_cat.copy()

if include_noregion_motifs:
    non_mex_none = combined_bed_df[combined_bed_df['type'] != 'MEX_none']
    mex_none = combined_bed_df[combined_bed_df['type'] == 'MEX_none'].sample(
        n=min(100, len(combined_bed_df[combined_bed_df['type'] == 'MEX_none'])), random_state=42
    )
    combined_bed_df = pd.concat([non_mex_none, mex_none], ignore_index=True)
else:
    combined_bed_df = combined_bed_df[combined_bed_df['type'] != 'MEX_none']

print("combined_bed_df:")
print(combined_bed_df.head())
print("Count by type:\n", combined_bed_df['type'].value_counts())

# ──────────────────────────────────────────────────────────────────────────────
# 7) Cluster bar plot (unchanged behavior)
#    Mean cluster_count per chip_category, sorted descending; the tick label
#    carries the number of rows in each category.
# ──────────────────────────────────────────────────────────────────────────────
category_stats = (
    fimo_expanded_df.groupby('chip_category')['cluster_count']
    .agg(['mean','count']).reset_index()
    .sort_values(by='mean', ascending=False)
)
category_stats['label'] = category_stats.apply(
    lambda x: f"{x['chip_category']}\n(n={int(x['count'])})", axis=1
)

plt.figure(figsize=(12, 6))
sns.barplot(x='label', y='mean', data=category_stats, palette='viridis')
plt.title('Average Cluster Count by Chip Category')
plt.xlabel('Chip Category')
plt.ylabel('Average Cluster Count')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# ──────────────────────────────────────────────────────────────────────────────
# 8) Build and save modkit BED
# ──────────────────────────────────────────────────────────────────────────────
modkit_bed_df = create_modkit_bed_df(combined_bed_df)
modkit_bed_name = "modkit_temp.bed"
temp_file_path = save_modkit_bed_to_temp(modkit_bed_df, modkit_bed_name)

print(modkit_bed_df)
print("len(modkit_bed_df):", len(modkit_bed_df))


# Sanity check: every rebuilt window should be exactly 2*bed_window + 1 wide.
expected = 2 * bed_window + 1
for name, t in [("combined_bed_df_mex_cat", combined_bed_df_mex_cat),
                ("combined_bed_df_mex_clust", combined_bed_df_mex_clust)]:
    widths = (t['bed_end'] - t['bed_start']).unique()
    bad = widths[widths != expected] if hasattr(widths, "__iter__") else []
    if len(bad):
        print(f"[WARN] {name} has non-standard window widths: {widths}")
    else:
        print(f"[OK] {name} window width == {expected} everywhere.")


fimo_df before filter: 203047
fimo_df after motif-specific ln(p) filter: 36853
fimo_df after dedup: 34828


In [None]:
### START CODE FOR PLOTTING PILEUP

regenerate_bit = True # Set to True to force regeneration; otherwise existing output files are reused.

### Generate modkit pileup file, used for plotting m6A/A in a given region.
# One output path per (condition, threshold, sample index, bam fraction), built from
# output_stem, the last 5 characters of the last three selected types, and bed_window.
# NOTE(review): the second join iterates over `type_selected[-3]`, i.e. over the
# CHARACTERS of a single type string (assuming type_selected holds strings), so it
# yields that string's characters separated by "_". A slice such as
# `type_selected[:3]` may have been intended — confirm before changing, since it
# alters the file names of cached outputs.
out_file_names = [output_stem + "modkit-pileup-" + each_condition +"_"+ str(round(each_thresh,2))+"_"+str(each_index)+"_"+str(each_bamfrac)+ "_".join([each_type[-5:] for each_type in type_selected[-3:]]) + "_".join([each_type[0:5] for each_type in type_selected[-3]]) + str(bed_window)+".bed" for each_condition,each_thresh,each_index, each_bamfrac in zip(conditions,thresh_list,sample_indices,bam_fracs)]

# Function to run a single command
def modkit_pileup_extract(args, index):
    """Run ``modkit pileup`` for one BAM file, restricted to the regions BED.

    Parameters
    ----------
    args : tuple
        ``(each_bam, each_thresh, each_condition, each_index, each_bamfrac,
        each_type, modkit_path, output_stem, modkit_bed_name)``.
        ``each_bamfrac`` and ``each_type`` are unpacked but not used here.
    index : int
        Position in the module-level ``out_file_names`` list that gives the
        output path for this job.

    Behaviour
    ---------
    If the module-level ``regenerate_bit`` is falsy and the output file already
    exists, the job is skipped (a message is printed when the file is empty or
    unreadable). Otherwise modkit is executed; a non-zero exit status is
    reported with a warning instead of raising, so sibling jobs keep running.
    Returns ``None``.
    """
    (each_bam, each_thresh, each_condition, each_index, each_bamfrac,
     each_type, modkit_path, output_stem, modkit_bed_name) = args

    # Use the index to get the correct file name from out_file_names
    each_output = out_file_names[index]

    # Reuse an existing output unless regeneration is forced.
    if not regenerate_bit and os.path.exists(each_output):
        print(f"File already exists: {each_output}")
        try:
            pd.read_csv(each_output, sep="\t", header=None, nrows=10)
        except pd.errors.EmptyDataError:
            print(f"File is empty: {each_output}")
        except Exception as exc:  # best-effort probe; never abort the batch
            print(f"Could not read {each_output}: {exc}")
        return

    print(f"Starting on: {each_output}", "with bam file: ", each_bam,"and bedfile:", modkit_bed_name)
    command = [
        modkit_path,
        "pileup",
        "--only-tabs",
        "--threads",
        "10",
        # Per-modification call thresholds (m6A code 'a', 5mC code 'm').
        "--mod-thresholds",
        f"a:{each_thresh}",
        "--mod-thresholds",
        f"m:{each_thresh}",
        "--ref",
        "/Data1/reference/c_elegans.WS235.genomic.fa",
        # Fixed per-base filter thresholds.
        "--filter-threshold",
        f"A:{0.8}",
        "--filter-threshold",
        f"C:{0.8}",
        "--motif",
        "GC",
        "1",
        "--motif",
        "A",
        "0",
        "--log-filepath",
        output_stem + each_condition + str(each_index) + "_modkit-pileup.log",
        "--include-bed",
        modkit_bed_name,
        each_bam,
        each_output
    ]
    completed = subprocess.run(command, text=True)
    if completed.returncode != 0:
        print(f"[WARN] modkit pileup exited with code {completed.returncode} for: {each_output}")


print("modkit_bed_name",modkit_bed_name)
print("temp_file_path",temp_file_path)
# Build one (args, index) pair per BAM file. `index` selects the matching entry of
# out_file_names inside modkit_pileup_extract; the args tuple follows the order that
# function unpacks. Constant arguments are repeated once per BAM file.
task_args_with_index = [(args, index) for index, args in enumerate(zip(
    new_bam_files,
    thresh_list,
    conditions,
    sample_indices,
    bam_fracs,
    [type_selected]*len(new_bam_files),
    [modkit_path]*len(new_bam_files), # path to the modkit executable
    [output_stem]*len(new_bam_files),
    [temp_file_path]*len(new_bam_files), # full path of the saved BED (fills the `modkit_bed_name` slot)
))]

# Execute commands in parallel; starmap unpacks each (args, index) pair into the
# two positional parameters of modkit_pileup_extract.
with Pool(
    processes=16
) as pool:
    pool.starmap(modkit_pileup_extract, task_args_with_index)


In [None]:
### PARALLELIZE BED INFO ADDING

import pandas as pd
import numpy as np
from multiprocessing import Pool, cpu_count
from tqdm import tqdm  # for progress monitoring

def add_bed_columns_no_loops(bedmethyl_df_loc, combined_bed_df):
    """
    Annotate bedmethyl rows with the region whose midpoint is nearest.

    Each bedmethyl row is matched (per ``chrom``) to the region of
    ``combined_bed_df`` with the nearest midpoint, and is kept only if its
    ``start_position`` lies within ``[bed_start, bed_end]`` of that region
    (both bounds inclusive). Neither input DataFrame is modified.

    Parameters
    ----------
    bedmethyl_df_loc : pandas.DataFrame
        Needs ``chrom`` and ``start_position``.
    combined_bed_df : pandas.DataFrame
        Needs ``chrom, bed_start, bed_end, bed_strand, type, chr_type``.

    Returns
    -------
    pandas.DataFrame
        The matching bedmethyl rows (input order) with ``bed_start, bed_end,
        bed_strand, type, chr_type`` appended; ``start_position`` is integer.
    """
    # Work on copies so the caller's frames are left untouched.
    regions = combined_bed_df.copy()
    regions['midpoint'] = ((regions['bed_start'] + regions['bed_end']) / 2).astype(int)
    regions = regions.sort_values(by='midpoint')

    calls = bedmethyl_df_loc.copy()
    calls['start_position'] = calls['start_position'].astype(int)

    # merge_asof needs both sides sorted on their join key.
    nearest = pd.merge_asof(
        calls.sort_values('start_position'),
        regions,
        by='chrom',
        left_on='start_position',
        right_on='midpoint',
        direction='nearest'
    )

    # Discard positions that fall outside their nearest region.
    inside = nearest['start_position'].between(nearest['bed_start'], nearest['bed_end'])
    nearest = nearest.loc[inside]

    # One annotation row per (chrom, start_position): without de-duplication,
    # repeated positions would multiply rows in the merge below.
    annotation_cols = ['chrom', 'start_position', 'bed_start', 'bed_end', 'bed_strand', 'type', 'chr_type']
    annotations = nearest[annotation_cols].drop_duplicates()

    final_df = pd.merge(
        calls,
        annotations,
        on=['chrom', 'start_position'],
        how='left'
    )
    return final_df[final_df['type'].notna()]

def process_group(group_tuple, num_bins, edge_window_size, sum_columns):
    """
    Add a 'rel_start' column to one region's rows and return the group.

    ``group_tuple`` is a ``(key, DataFrame)`` pair as yielded by ``groupby``;
    the DataFrame is modified in place. ``sum_columns`` is accepted but unused.

    Rules (bed_start / bed_end are taken from the group's first row):
      • position < bed_start + edge  → position - bed_start - edge
      • position > bed_end - edge    → num_bins + edge - (bed_end - position)
      • otherwise                    → sentinel 100000
    If the span of observed positions exceeds num_bins + 2*edge, positions
    inside [bed_start + edge, bed_end - edge] are instead assigned their bin
    number from ``num_bins`` equal-width bins over that interval.
    """
    _, group = group_tuple
    positions = group['start_position']
    observed_span = positions.max() - positions.min()
    bed_start = group['bed_start'].iloc[0]
    bed_end = group['bed_end'].iloc[0]

    body_lo = bed_start + edge_window_size
    body_hi = bed_end - edge_window_size

    # Flanks get base-pair offsets; the body gets a placeholder for now.
    group['rel_start'] = np.select(
        [positions < body_lo, positions > body_hi],
        [
            positions - bed_start - edge_window_size,
            num_bins + edge_window_size - (bed_end - positions),
        ],
        default=100000,  # sentinel placeholder
    )

    # Only regions wider than the metagene layout have their body binned.
    if observed_span > (num_bins + 2 * edge_window_size):
        in_body = (positions >= body_lo) & (positions <= body_hi)
        edges = np.linspace(body_lo, body_hi, num_bins + 1)
        group.loc[in_body, 'rel_start'] = np.digitize(
            group.loc[in_body, 'start_position'],
            bins=edges,
            right=True
        )

    return group

def map_to_metagene_bins_and_sum(df, num_bins=1000, edge_window_size=500):
    """
    Assign metagene 'rel_start' values per region and sum counts per bin.

    Rows are grouped by (bed_start, chrom, modified_base_code); each group is
    passed through ``process_group`` serially (no nested Pool). The count
    columns are then summed per (group key, rel_start) and joined back onto
    the distinct combinations of key and descriptive columns
    (bed_strand, chr_type, strand, bed_end, type).
    """
    count_cols = ['Nmod', 'Ncanonical', 'Nother_mod', 'Ndelete', 'Nfail', 'Ndiff', 'Nnocall','Nvalid_cov']
    descriptive_cols = ['bed_strand', 'chr_type', 'strand', 'bed_end','type']
    region_keys = ['bed_start', 'chrom', 'modified_base_code']

    # Serial per-region pass that adds rel_start.
    shifted = pd.concat(
        [
            process_group(item, num_bins, edge_window_size, count_cols)
            for item in df.groupby(region_keys)
        ],
        ignore_index=True,
    )

    bin_keys = region_keys + ['rel_start']
    totals = shifted.groupby(bin_keys)[count_cols].sum().reset_index()
    labels = shifted[bin_keys + descriptive_cols].drop_duplicates()

    return pd.merge(labels, totals, on=bin_keys, how='left')

def process_one_file(args):
    """
    Load one bedmethyl file and annotate it with region information.

    ``args`` is the tuple ``(each_output, each_condition, each_exp_id,
    combined_bed_df, type_selected, num_bins, bed_window)``.

    Steps:
      1. Read the headerless, tab-separated bedmethyl file.
      2. Keep only 'a,A,0' and 'm,GC,1' rows; sort, drop NaNs and duplicates.
      3. Attach region columns via ``add_bed_columns_no_loops``.
      4. If 'gene' or 'damID' occurs in ``type_selected[0]``, bin with
         ``map_to_metagene_bins_and_sum``; otherwise set
         rel_start = start_position - bed_start - bed_window + 1.
      5. Cast rel_start to int and tag rows with condition / exp_id.

    Returns the resulting DataFrame, or an empty DataFrame when step 2 or
    step 3 leaves no rows (a message is printed in that case).
    """
    each_output, each_condition, each_exp_id, combined_bed_df, type_selected, num_bins, bed_window = args

    column_names = [
        'chrom','start_position','end_position','modified_base_code','score','strand',
        'start_position_compat','end_position_compat','color','Nvalid_cov','fraction_modified',
        'Nmod','Ncanonical','Nother_mod','Ndelete','Nfail','Ndiff','Nnocall'
    ]
    # The file is assumed to exist already.
    raw = pd.read_csv(each_output, sep="\t", header=None, names=column_names)

    calls = raw[raw['modified_base_code'].isin(['a,A,0','m,GC,1'])]
    if calls.empty:
        print(f"Empty CSV or no relevant rows in: {each_output}")
        return pd.DataFrame()

    calls = (
        calls.sort_values(['start_position'])
        .dropna()
        .drop_duplicates()
        .reset_index(drop=True)
    )

    annotated = add_bed_columns_no_loops(calls, combined_bed_df)
    if annotated.empty:
        print(f"No overlapping intervals in: {each_output}")
        return pd.DataFrame()

    first_type = type_selected[0]
    if 'gene' in first_type or 'damID' in first_type:
        # Metagene layout: flanks in bp, body in bins.
        annotated = map_to_metagene_bins_and_sum(
            annotated,
            num_bins=num_bins,
            edge_window_size=bed_window
        )
    else:
        # Fixed-width windows: offset relative to the window centre.
        annotated['rel_start'] = (
            annotated['start_position'] - annotated['bed_start'] - bed_window + 1
        )

    annotated['rel_start'] = annotated['rel_start'].astype(int)
    annotated['condition'] = each_condition
    annotated['exp_id'] = each_exp_id

    return annotated

def parallel_build_combined_df(
    out_file_names,
    conditions,
    exp_ids,
    combined_bed_df,
    type_selected,
    num_bins=1000,
    bed_window=500
):
    """
    Process all bedmethyl files in a worker pool and stack the results.

    One ``process_one_file`` job is created per (file, condition, exp_id)
    triple; the remaining arguments are shared by every job. Jobs run in a
    ``Pool`` with ``cpu_count()`` workers and results are consumed in
    submission order under a tqdm progress bar. Empty results are skipped.

    Returns the concatenation of the non-empty results (fresh index), or an
    empty DataFrame if there are none.
    """
    jobs = []
    for path, cond, exp in zip(out_file_names, conditions, exp_ids):
        jobs.append((path, cond, exp, combined_bed_df, type_selected, num_bins, bed_window))

    with Pool(cpu_count()) as pool:
        progress = tqdm(pool.imap(process_one_file, jobs), total=len(jobs))
        non_empty = [frame for frame in progress if not frame.empty]

    if not non_empty:
        return pd.DataFrame()
    return pd.concat(non_empty, ignore_index=True)

# ------------------------------------------
# Inputs expected from earlier cells:
#
# out_file_names = [...]
# conditions = [...]
# exp_ids = [...]
# combined_bed_df = [...]  # Large DataFrame
# type_selected = [...]
# num_bins = 1000
# bed_window = 500

# Build the combined, region-annotated bedmethyl table for all pileup outputs.
comb_bedmethyl_df = parallel_build_combined_df(
    out_file_names,
    conditions,
    exp_ids,
    combined_bed_df,
    type_selected,
    num_bins,
    bed_window
)

display(comb_bedmethyl_df.head())


In [None]:
import numpy as np
import pandas as pd



# Define normalization type: 'global' for exp_id_m6A_frac or 'local' for intergenic mod_frac
normalization_type = 'local'  # 'local' = normalize to intergenic mod_frac; 'global' = normalize to the per-exp_id fraction
combine_replicates = False  # True → relabel replicate conditions/exp_ids via condition_map before grouping

### SHIFT AND TRANSFORM (OPTIONAL) FOR PILEUP PLOT
ext_target = []  # when non-empty, rel_start of '-' strand rows in bw_df is mirrored as well

def compute_lag_for_maximum_alignment(series1, bed_start1):
    """
    Return ``(lag, flip)`` that would move the maximum of ``series1`` to its centre.

    ``lag`` is ``round(len(series1) / 2)`` minus the index of the first
    maximum. ``flip`` is always 0: no flipping decision is made here.
    ``bed_start1`` is accepted for call compatibility but not used.
    """
    centre_index = round(len(series1) / 2)
    peak_index = np.argmax(series1)
    return (centre_index - peak_index, 0)

def get_continuous_series(df_subset):
    """Return 'weighted_norm_mod_frac' as a gap-free array indexed by rel_start.

    Missing values are forward- then backward-filled, after which the series
    is re-indexed onto every integer from the smallest to the largest
    ``rel_start``; positions absent from the input are filled with 0.

    If re-indexing fails (for example because ``rel_start`` contains
    duplicates), the error is printed and the filled but un-reindexed values
    are returned.
    """
    series_filled = df_subset.set_index('rel_start')['weighted_norm_mod_frac']

    # .ffill()/.bfill() replace fillna(method=...), which newer pandas removed.
    series_filled = series_filled.ffill().bfill()

    # Ensure it's a continuous series by filling any gaps in rel_start
    try:
        series_filled = series_filled.reindex(
            range(int(series_filled.index.min()), int(series_filled.index.max()) + 1),
            fill_value=0
        )
    except Exception as e:  # deliberate best-effort: report and fall through
        print("Failed to reindex series_filled:", e)
        print("Duplicate indexes:", series_filled.index[series_filled.index.duplicated()])

    return series_filled.values

def align_profiles(df):
    """Annotate each region (bed_start) with the lag that would centre its peak.

    Works on a copy sorted by (bed_start, rel_start). For every distinct
    ``bed_start`` the continuous profile is built with
    ``get_continuous_series`` and passed to
    ``compute_lag_for_maximum_alignment``; the resulting lag and flip flag are
    written to the 'shift' and 'flipped' columns of that region's rows.
    Summary statistics are printed and the annotated copy is returned.
    """
    df = df.sort_values(['bed_start', 'rel_start']).copy()
    bed_starts = df['bed_start'].unique()

    # Reference region = the one with the highest total coverage.
    coverage_by_region = df.groupby('bed_start')['Nvalid_cov'].sum()
    reference_bed_start = coverage_by_region.idxmax()
    series_reference = get_continuous_series(df[df['bed_start'] == reference_bed_start])

    # Zero-pad the reference so its peak sits at the centre. The padded
    # reference is not used further below.
    shift_positions = int(round(len(series_reference)/2)) - np.argmax(series_reference)
    if shift_positions > 0:
        series_reference = np.concatenate(([0]*shift_positions, series_reference))
    else:
        series_reference = np.concatenate((series_reference, [0]*abs(shift_positions)))

    df["flipped"] = 0

    for region_start in bed_starts:
        in_region = df['bed_start'] == region_start
        profile = get_continuous_series(df[in_region])
        lag, flip = compute_lag_for_maximum_alignment(profile, region_start)

        df.loc[in_region, 'shift'] = lag
        df.loc[in_region, 'flipped'] = 1 if flip else 0

    # Summary of the per-region results.
    total_flipped = df.loc[df['flipped'] == 1, 'bed_start'].nunique()
    lag_distribution = df['shift'].describe()

    print(f"Total bed_starts flipped: {total_flipped} out of {len(bed_starts) - 1}")
    print("Lag Distribution:")
    print(lag_distribution)

    return df

print("Copying and dropping rows...")
comb_bedmethyl_plot_df = comb_bedmethyl_df.copy()

# find unique values in combined_bedmethyl_df type
unique_types = comb_bedmethyl_plot_df['type'].unique()

# When centering on motifs, clean the motif-length lookup used further below.
# NOTE(review): `motif_lengths` is not defined in this part of the notebook —
# presumably a dict {motif name: length} created by an earlier cell; confirm.

if center_on_motifs == True:
    # drop any nan keys from motif_lengths
    motif_lengths = {k: v for k, v in motif_lengths.items() if pd.notnull(k)}

# Adjust rel_start based on strand and type
for each_type in unique_types:
    print(f"Adjusting rel_start for {each_type}...")
    if any(x in each_type for x in ["TSS", "TES", "gene","MEX","motif"]): #
        print(f"Strand orientation sensitive {each_type} type selected, multiplying rel_start by -1 for '-' strand genes...")
        if 'gene' in each_type:
            # Subtract half of num_bins from rel_start for metagene profiles to center them.
            # This is applied to every row (not only rows of each_type) and undone at the
            # end of this iteration, so only the mirrored rows end up changed.
            comb_bedmethyl_plot_df['rel_start'] -= num_bins / 2
            
        # Mask for types and negative strands
        mask = (comb_bedmethyl_plot_df['type'] == each_type) & (comb_bedmethyl_plot_df['bed_strand'] == '-')
        comb_bedmethyl_plot_df.loc[mask, 'rel_start'] *= -1
        
        if center_on_motifs == True:
            # for each unique key in motif_length   
            for each_key in motif_lengths.keys():
                # if each_key is a substring of each_type, add its motif length to rel_start
                # for all rows of that type (both strands)
                if each_key in each_type:
                    comb_bedmethyl_plot_df.loc[comb_bedmethyl_plot_df['type'] == each_type, 'rel_start'] += motif_lengths[each_key]

        # Adjust bigwig lines if ext_target is not empty
        # (`bw_df` is assumed to be defined elsewhere when ext_target is used)
        if ext_target:
            mask_bw = (bw_df['type'] == each_type) & (bw_df['bed_strand'] == '-')
            bw_df.loc[mask_bw, 'rel_start'] *= -1
        if 'gene' in each_type:
            # Add half of num_bins back to rel_start
            comb_bedmethyl_plot_df['rel_start'] += num_bins / 2

if combine_replicates:
    print("Combining replicates based on conditions...")

    # Map a substring of the replicate's condition name to the merged label
    # (`analysis_cond` is assumed to be defined by an earlier cell).
    condition_map = {
        "N2_mixed": analysis_cond[0],
        "SDC2_degron_mixed": analysis_cond[1]#,
        #"N2_old_SMAC":analysis_cond[6],
        #"DPY27": analysis_cond[8]
    }

    for key, value in condition_map.items():
        # Rows whose condition contains `key` get both condition and exp_id
        # replaced by the merged label, so replicates are summed together below.
        mask = comb_bedmethyl_plot_df['condition'].str.contains(key, na=False)
        comb_bedmethyl_plot_df.loc[mask, 'condition'] = value
        comb_bedmethyl_plot_df.loc[mask, 'exp_id'] = value

# # Group the DataFrame
# grouped = comb_bedmethyl_plot_df.groupby(
#     ['chrom', 'rel_start', 'exp_id', 'modified_base_code', 'condition', 'type', 'chr_type', 'bed_start']
# ).agg({
#     'Nvalid_cov': 'sum'
# }).reset_index()
#
# print("Grouped DataFrame created.")

# NOTE: the message below mentions bed_start, but the active groupby keys end
# with 'strand' and do not include bed_start, so regions are pooled per rel_start.
print("Grouping by chrom, rel_start, exp_id, modified_base_code, condition, type, chr_type, bed_start...")
# Group and aggregate necessary columns
grouped_df = comb_bedmethyl_plot_df.groupby(
    ['chrom', 'rel_start', 'exp_id', 'modified_base_code', 'condition', 'type', 'chr_type','strand']
).agg({
    'Nvalid_cov': 'sum',
    'Nmod': 'sum',
    'Ncanonical': 'sum',
    'Nother_mod': 'sum'
}).reset_index()

print("Calculating normalized m6A...")
### Calculate the raw modified fraction: Nmod / (Nmod + Ncanonical)
grouped_df['raw_mod_frac'] = grouped_df['Nmod'] / (grouped_df['Nmod'] + grouped_df['Ncanonical'])

# Ensure no stray whitespace in exp_id
grouped_df['exp_id'] = grouped_df['exp_id'].str.strip()

# Helper to compute fraction safely
def _frac(g):
    nmod = g['Nmod'].sum()
    ncan = g['Ncanonical'].sum()
    total = nmod + ncan
    return np.nan if total == 0 else nmod / total

# Compute global mod fractions per exp_id from grouped_df itself
m6a_global = (
    grouped_df[grouped_df['modified_base_code'] == 'a,A,0']
    .groupby('exp_id', as_index=False)
    .apply(_frac)
    .rename(columns={None: 'exp_id_m6A_frac'})
)

m5mc_global = (
    grouped_df[grouped_df['modified_base_code'] == 'm,GC,1']
    .groupby('exp_id', as_index=False)
    .apply(_frac)
    .rename(columns={None: 'exp_id_5mC_frac'})
)

global_fracs = pd.merge(m6a_global, m5mc_global, on='exp_id', how='outer')

# Merge back to per-position data
merged_df = pd.merge(grouped_df, global_fracs, on='exp_id', how='left')

# Optional: if a mod type is missing for an exp_id, avoid NaNs in normalization
merged_df[['exp_id_m6A_frac','exp_id_5mC_frac']] = (
    merged_df[['exp_id_m6A_frac','exp_id_5mC_frac']].fillna(1.0)
)

### Normalization based on selected type
# Reject unknown modes up front; the two supported modes follow.
if normalization_type not in ('global', 'local'):
    raise ValueError("Invalid normalization_type. Choose 'global' or 'local'.")

if normalization_type == 'global':
    # Each row gets the experiment-wide fraction that matches its own
    # modification code; any other code is left as NaN.
    merged_df['norm_mod_frac_init'] = np.select(
        [
            (merged_df['modified_base_code'] == 'a,A,0').to_numpy(),
            (merged_df['modified_base_code'] == 'm,GC,1').to_numpy(),
        ],
        [
            merged_df['exp_id_m6A_frac'].to_numpy(),
            merged_df['exp_id_5mC_frac'].to_numpy(),
        ],
        default=np.nan,
    )
    normalization_label = 'global_normalization'
else:
    # Local mode: the denominator is the pooled fraction over rows whose
    # `type` contains "intergenic" (case-insensitive), one value per exp_id.
    # NOTE(review): counts of all modification codes are pooled together
    # here -- confirm that is intended when both m6A and 5mC are present.
    intergenic_df = grouped_df[grouped_df['type'].str.contains("intergenic", case=False)]

    local_mod_frac = (
        intergenic_df.groupby('exp_id')[['Nmod', 'Ncanonical']].sum().reset_index()
    )
    local_total = local_mod_frac['Nmod'] + local_mod_frac['Ncanonical']
    local_mod_frac['mod_frac_local'] = local_mod_frac['Nmod'] / local_total

    # Attach the per-experiment value to every row
    merged_df = merged_df.merge(
        local_mod_frac[['exp_id', 'mod_frac_local']],
        on='exp_id',
        how='left',
    )

    # Experiments without a usable intergenic fraction are divided by 1,
    # i.e. left un-normalized.
    merged_df['mod_frac_local'] = merged_df['mod_frac_local'].fillna(1)

    merged_df['norm_mod_frac_init'] = merged_df['mod_frac_local']
    normalization_label = 'mod_frac_local'

print(f"Normalization method: {normalization_type} ({normalization_label})")

# Keep the aggregated columns plus the normalization denominator
plot_df = merged_df[list(grouped_df.columns) + ['norm_mod_frac_init']]

### Since multiple samples have same condition:
# Sum the counts over rows that share position, modification, condition, type,
# chromosome class, strand and normalization denominator.
plot_df = (
    plot_df
    .groupby(['rel_start', 'modified_base_code', 'condition', 'type', 'chr_type', 'norm_mod_frac_init','strand'])
    [['Nvalid_cov', 'Ncanonical', 'Nmod']]
    .sum()
    .reset_index()
)

# Mean bigwig signal per position when an external target is configured,
# otherwise an empty placeholder frame.
plot_comb_bigwig_df = (
    bw_df.groupby(['rel_start', 'chrom', 'condition', 'type', 'chr_type'])['value']
    .mean()
    .reset_index()
    if ext_target
    else pd.DataFrame()
)

# Fractions are recomputed from the pooled counts, then scaled by the denominator
pooled_total = plot_df['Nmod'] + plot_df['Ncanonical']
plot_df['raw_mod_frac'] = plot_df['Nmod'] / pooled_total
plot_df['weighted_norm_mod_frac'] = plot_df['raw_mod_frac'] / plot_df['norm_mod_frac_init']

# Order rows by position and renumber the index
plot_df = plot_df.sort_values(['rel_start']).reset_index(drop=True)

print("plot_df:")
# Display sample rows for debugging
nanotools.display_sample_rows(plot_df, 10)
if ext_target:
    nanotools.display_sample_rows(plot_comb_bigwig_df, 10)

In [None]:
# When True, the in-memory frames are always written out, overwriting any
# cached CSV; when False, an existing CSV is loaded instead.
force_replace = True

# Cache file names encode the last three selected types, the first threshold,
# the first BAM fraction and the bed window. The pieces are concatenated
# exactly as before so files written by earlier runs are still found.
_type_tag = "_".join(type_selected[-3:])
_config_tag = str(round(thresh_list[0],2)) + "_"+str(bam_fracs[0])+str(bed_window)
final_fn = "temp_files/" + "final_df_" + _type_tag + _config_tag + ".csv"
final_fn_chip = "temp_files/" + "final_df_chip" + _type_tag + _config_tag + ".csv"

# to_csv does not create missing parent directories
os.makedirs("temp_files", exist_ok=True)

if not force_replace and os.path.exists(final_fn):
    print("final_df already exists, importing it...")
    plot_df = pd.read_csv(final_fn)
    nanotools.display_sample_rows(plot_df,5)
else:
    print("final_df does not exist, saving it...")
    plot_df.to_csv(final_fn, index=False)

# Same load-or-save logic for the bigwig/ChIP frame. That frame may never have
# been created in this session, so only that case (NameError) and an empty
# cache file are tolerated; any other error is allowed to surface.
try:
    if not force_replace and os.path.exists(final_fn_chip):
        print("final_df_chip already exists, importing it...")
        plot_comb_bigwig_df = pd.read_csv(final_fn_chip)
        nanotools.display_sample_rows(plot_comb_bigwig_df,5)
    else:
        print("final_df_chip does not exist, saving it...")
        plot_comb_bigwig_df.to_csv(final_fn_chip, index=False)
except NameError:
    print("plot_comb_bigwig_df does not exist, skipping...")
except pd.errors.EmptyDataError:
    # An empty frame was cached earlier; there is nothing to load.
    print("final_df_chip cache is empty, skipping...")

In [None]:
importlib.reload(nanotools)
import scipy.ndimage
from scipy import signal          # find_peaks / peak_widths used by plot_bedmethyl
# Window functions can no longer be imported from the top-level scipy.signal
# namespace in SciPy >= 1.13; scipy.signal.windows works on old and new versions.
from scipy.signal.windows import gaussian


def plot_bedmethyl(
    comb_bedmethyl_df,
    conditions_input,
    chr_types=None,
    types=None,
    strands=["all"],
    window_size=50,
    metagene_bins=1000,
    smoothing_type="weighted",
    selection_indices=None,
    bed_window=[-500,500],
    mod_types=['m6A'],
    ignore_selec=[],
    bigwig_df=None,
    bw_selections=None,
    plot_type="raw",
    plot_x_over_a=True,  # <--- The new boolean parameter
    y_range=None
):
    """Plot smoothed modification fraction and coverage around bed features.

    One line is drawn per combination of condition, modification, chromosome
    type, feature type and strand. Both figures are displayed, and when
    ``window_size == 101`` and ``bed_window`` is ``[-500, 500]`` peak width and
    area statistics are upserted into ``f"{output_stem}_fwhm.csv"``.

    The function reads the module-level names ``CHIP_RANK_CUTOFF``,
    ``ABOVE_FLAG``, ``TYPES_TO_INCLUDE`` and ``output_stem`` and calls
    ``apply_chiprank_filter`` on the input first.

    Parameters
    ----------
    comb_bedmethyl_df : pd.DataFrame
        Needs the columns ``type``, ``rel_start``, ``condition``,
        ``raw_mod_frac`` and ``Nvalid_cov``; ``chr_type``, ``strand`` and
        ``modified_base_code`` are needed when the matching filter is not "all".
    conditions_input : list
        Condition names; ``selection_indices`` (if given) picks a subset.
    chr_types, types, strands : list or None
        Values to iterate over; an empty/None list means a single "all" pass.
        When ``types`` is empty the list returned by the filter is used.
    window_size : int
        Rolling window length (used as sigma for "gaussian").
    metagene_bins : int
        x position of the TES marker.
    smoothing_type : str
        "none", "weighted", "gaussian", "exponential", "lowess"; anything
        else gives an unweighted rolling mean.
    bed_window : sequence of two numbers
        Inclusive ``rel_start`` limits.
    mod_types : list
        "m6A" maps to code ``a,A,0``, "5mC" to ``m,GC,1``, anything else to
        no modification filter.
    ignore_selec : list
        Pairs ``(index into conditions_input, modification name)`` to skip.
    bigwig_df, bw_selections
        Optional signal frame and the conditions of it to draw on the
        secondary y axis.
    plot_type : str
        Only "raw" is supported; any other value prints a notice and also
        plots ``raw_mod_frac``.
    plot_x_over_a : bool
        When True and ``chr_types`` contains both 'X' and 'Autosome', only
        X/Autosome ratio traces are kept in the main figure.
    y_range : list or None
        Explicit y-axis range; otherwise derived from the plotted values.

    Returns
    -------
    (fig, plot_title)
        The main plotly figure and its title string.
    """
    # NOTE: the list defaults in the signature are only ever read, never
    # mutated, so sharing them between calls is harmless.

    # --- chip-rank-based motif filtering ------------------------------
    comb_bedmethyl_df, auto_types = apply_chiprank_filter(
        comb_bedmethyl_df,
        chip_rank_cutoff=CHIP_RANK_CUTOFF,
        above_flag=ABOVE_FLAG,
        types_to_include=TYPES_TO_INCLUDE,
    )

    # An explicit `types` argument is respected verbatim; otherwise use the
    # list returned by the filter.
    if not types:
        types = auto_types

    # ------------------------------------------------------------------
    # Smoothing helpers
    def _weighted_rolling_average(values, weights, wsize):
        """Centered rolling mean of *values* weighted by *weights*."""
        def calc_wavg(window):
            w = weights[window.index]
            tot = w.sum()
            if tot == 0:
                return 0.0           # or choose np.nan if you prefer
            return (window * w).sum() / tot

        return values.rolling(window=wsize, center=True) \
                     .apply(calc_wavg, raw=False)

    def _exponential_pass(arr, alpha):
        """One-directional exponential smoothing of a 1-D float array."""
        out = np.zeros(len(arr), dtype=float)
        if len(arr) == 0:
            return out
        out[0] = arr[0]
        for t in range(1, len(arr)):
            out[t] = alpha * arr[t] + (1 - alpha) * out[t - 1]
        return out

    def _smooth(values, positions, weights):
        """Return *values* smoothed according to ``smoothing_type``."""
        if smoothing_type == "none":
            return values
        if smoothing_type == "weighted":
            smoothed = _weighted_rolling_average(values, weights, window_size)
            # Quick debug to make sure we see non-NaNs inside the array:
            print("[DEBUG] weighted branch → some smoothed_data values:",
                  smoothed.dropna().head(5))

            print(f"[DEBUG] first few smoothed_data =\n{smoothed.head(10)}")
            return smoothed
        if smoothing_type == "gaussian":
            return pd.Series(
                scipy.ndimage.gaussian_filter1d(values, sigma=window_size),
                index=values.index,
            )
        if smoothing_type == "exponential":
            # Average of a forward and a backward pass, so the result is
            # symmetric around each position.
            alpha = 0.05
            arr = values.to_numpy(dtype=float)
            fwd = _exponential_pass(arr, alpha)
            bwd = _exponential_pass(arr[::-1], alpha)[::-1]
            return pd.Series((fwd + bwd) / 2, index=values.index)
        if smoothing_type == "lowess":
            from statsmodels.nonparametric.smoothers_lowess import lowess
            fitted = lowess(values, positions, frac=0.05, it=0)
            # column 0 holds the x-values, column 1 the smoothed data
            return pd.Series(fitted[:, 1], index=fitted[:, 0])
        return values.rolling(window=window_size, center=True).mean()

    # ------
    # 1. Prepare figure containers, global variables, etc.
    fig = make_subplots(specs=[[{"secondary_y": True}]])
    cov_fig = make_subplots(specs=[[{"secondary_y": True}]])
    y_min = float('inf')
    y_max = float('-inf')

    # For storing final x and y (smoothed) data so we can later compute X_over_A
    results_dict = {}

    # Reduce DataFrame to bed_window
    comb_bedmethyl_df = comb_bedmethyl_df[
        (comb_bedmethyl_df['rel_start'] >= bed_window[0]) &
        (comb_bedmethyl_df['rel_start'] <= bed_window[1])
    ]

    if selection_indices is not None:
        conditions = [conditions_input[i] for i in selection_indices]
    else:
        conditions = conditions_input

    # ------
    # 2. Main loop over conditions, modifications, chr_types, types, strands
    for selected_condition in conditions:
        print("Starting on condition:", selected_condition)

        for selected_modification in (mod_types or ["all"]):
            skip_line = False
            if selected_modification == '5mC':
                selected_mod = 'm,GC,1'
            elif selected_modification == 'm6A':
                selected_mod = 'a,A,0'
            else:
                selected_mod = 'all'

            # Optionally skip certain (condition, modification) combos
            for each in ignore_selec:
                if (selected_condition == conditions_input[each[0]]
                        and selected_modification == each[1]):
                    print("Skipping:", each[0], "with meth:", each[1])
                    skip_line = True
                    break
            if skip_line:
                continue

            for selected_chr_type in (chr_types or ["all"]):
                for selected_type in (types or ["all"]):
                    for selected_strand in (strands or ["all"]):
                        # Apply filters
                        filters = [comb_bedmethyl_df['condition'] == selected_condition]
                        if selected_chr_type != "all":
                            filters.append(comb_bedmethyl_df['chr_type'] == selected_chr_type)
                        if selected_type != "all":
                            filters.append(comb_bedmethyl_df['type'] == selected_type)
                        if selected_strand != "all":
                            filters.append(comb_bedmethyl_df['strand'] == selected_strand)
                        if selected_mod != "all":
                            filters.append(comb_bedmethyl_df['modified_base_code'] == selected_mod)

                        base_filter = np.logical_and.reduce(filters)

                        if not base_filter.any():
                            print(f"No data for {selected_condition}, {selected_modification}, "
                                  f"{selected_chr_type}, {selected_type}, {selected_strand}")
                            continue

                        # Several input rows can share a rel_start. Pool them by
                        # rebuilding a modified count (fraction * coverage),
                        # summing, and recomputing the fraction per position.
                        tmp = comb_bedmethyl_df.loc[
                            base_filter,
                            ['rel_start', 'raw_mod_frac', 'Nvalid_cov']
                        ].copy()
                        tmp['Nmod'] = tmp['raw_mod_frac'] * tmp['Nvalid_cov']

                        agg = (
                            tmp
                            .groupby('rel_start', as_index=False)
                            .agg({'Nmod': 'sum', 'Nvalid_cov': 'sum'})
                        )
                        agg['raw_mod_frac'] = agg['Nmod'] / agg['Nvalid_cov']
                        data_filtered = agg[['rel_start', 'raw_mod_frac', 'Nvalid_cov']].copy()

                        # Create a full rel_start range and merge
                        full_range_df = pd.DataFrame({
                            'rel_start': range(
                                int(data_filtered['rel_start'].min()),
                                int(data_filtered['rel_start'].max() + 1)
                            )
                        })
                        merged_df = pd.merge(
                            full_range_df, data_filtered,
                            on='rel_start', how='left'
                        )

                        # Fill or drop missing positions
                        if smoothing_type != "weighted":
                            merged_df.fillna({
                                'raw_mod_frac': 0,
                                'Nvalid_cov': 0,
                            }, inplace=True)
                        else:
                            # Weighted smoothing requires valid coverage
                            merged_df.dropna(
                                subset=['raw_mod_frac', 'Nvalid_cov'],
                                inplace=True
                            )
                            merged_df.reset_index(drop=True, inplace=True)

                        # Only the raw fraction is available. Assign it for
                        # every plot_type so the series is never undefined or
                        # carried over from a previous iteration.
                        if plot_type != "raw":
                            print("This option has been deprecated, using raw_mod_frac instead.")
                        m6A_data = merged_df['raw_mod_frac']

                        m6A_data_xaxis = merged_df['rel_start']
                        Nvalid_cov_data = merged_df['Nvalid_cov']
                        smoothed_cov_data = Nvalid_cov_data.rolling(
                            window=window_size, center=True
                        ).sum()

                        print(f"[DEBUG] >>> Entering plot_bedmethyl with smoothing_type={smoothing_type!r}, window_size={window_size}")
                        print(f"[DEBUG] merged_df rows = {len(merged_df)}; first few raw_mod_frac =\n{m6A_data.head()}")

                        smoothed_data = _smooth(m6A_data, m6A_data_xaxis, Nvalid_cov_data)

                        y_min = min(y_min, smoothed_data.min())
                        y_max = max(y_max, smoothed_data.max())

                        label = (f"{selected_condition}_{selected_chr_type}_"
                                 f"{selected_type}_{selected_strand}_{selected_modification}")

                        # Same colour for the fraction and the coverage trace
                        key = f"{selected_condition}_{selected_type}"
                        color = nanotools.get_colors(key)

                        # Add the main mod fraction trace
                        fig.add_trace(
                            go.Scatter(
                                x=m6A_data_xaxis.values,
                                y=smoothed_data.values,
                                mode='lines',
                                name=label,
                                opacity=0.9,
                                line=dict(width=3, color=color)
                            ),
                            secondary_y=False
                        )

                        # Also add coverage trace
                        cov_fig.add_trace(
                            go.Scatter(
                                x=m6A_data_xaxis.values,
                                y=smoothed_cov_data.values,
                                mode='lines',
                                name=label + "_Nvalid_cov",
                                opacity=0.9,
                                line=dict(width=3, color=color)
                            ),
                            secondary_y=False
                        )

                        # Store final (x, y) so we can compute X_over_A later
                        results_dict[
                            (
                                selected_condition,
                                selected_modification,
                                selected_type,
                                selected_strand,
                                plot_type,
                                smoothing_type,
                                selected_chr_type
                            )
                        ] = (
                            m6A_data_xaxis.values,
                            smoothed_data.values
                        )

    # ------
    # 3. If bigwig data is provided, plot that as well on secondary_y=True
    if bigwig_df is not None:
        print("Plotting bigwig_df...")
        bigwig_df = bigwig_df[
            (bigwig_df['rel_start'] >= bed_window[0]) &
            (bigwig_df['rel_start'] <= bed_window[1])
        ]

        if bw_selections is not None:
            for bw_selection in bw_selections:
                for selected_chr_type in (chr_types or ["all"]):
                    for selected_type in (types or ["all"]):
                        for selected_strand in (strands or ["all"]):
                            filters = [bigwig_df['condition'] == bw_selection]
                            if selected_chr_type != "all":
                                filters.append(bigwig_df['chr_type'] == selected_chr_type)
                            if selected_type != "all":
                                filters.append(bigwig_df['type'] == selected_type)
                            if selected_strand != "all":
                                filters.append(bigwig_df['strand'] == selected_strand)

                            base_filter = np.logical_and.reduce(filters)

                            value_data = bigwig_df.loc[base_filter, 'value']
                            value_data_xaxis = bigwig_df.loc[base_filter, 'rel_start']
                            # NOTE(review): the smoothed signal only feeds the
                            # axis limits; the unsmoothed values are what is
                            # drawn -- confirm this is intended.
                            smoothed_data = value_data.rolling(window=window_size, center=True).mean()
                            y_min = min(y_min, smoothed_data.min())
                            y_max = max(y_max, smoothed_data.max())

                            label = f"{bw_selection}_{selected_chr_type}_{selected_type}_{selected_strand}"
                            fig.add_trace(
                                go.Scatter(
                                    x=value_data_xaxis.values,
                                    y=value_data.values,
                                    mode='lines',
                                    name=label,
                                    opacity=0.9,
                                    line=dict(width=3)
                                ),
                                secondary_y=True
                            )

    # ------
    # 4. Compute and plot X_over_A if both conditions are met
    if (
        plot_x_over_a
        and chr_types
        and 'X' in chr_types
        and 'Autosome' in chr_types
    ):
        ratio_min = float('inf')
        ratio_max = float('-inf')

        # Collect all combos ignoring the actual chr_type. dict.fromkeys keeps
        # first-seen order, so the legend order is reproducible.
        unique_combos = dict.fromkeys(full_key[:6] for full_key in results_dict)
        for combo in unique_combos:
            cond, mod, typ, st, ptype, smth = combo
            # Must have both X and Autosome stored
            keyX = (cond, mod, typ, st, ptype, smth, 'X')
            keyA = (cond, mod, typ, st, ptype, smth, 'Autosome')
            if keyX in results_dict and keyA in results_dict:
                xX, yX = results_dict[keyX]
                xA, yA = results_dict[keyA]
                dfX = pd.DataFrame({'x': xX, 'valX': yX})
                dfA = pd.DataFrame({'x': xA, 'valA': yA})
                merged = pd.merge(dfX, dfA, on='x', how='inner')
                # Avoid dividing by zero
                merged['ratio'] = merged['valX'] / merged['valA'].replace(0, np.nan)
                ratio_min = min(ratio_min, merged['ratio'].min())
                ratio_max = max(ratio_max, merged['ratio'].max())

                label = f"{cond}_X_over_A_{typ}_{st}_{mod}"
                fig.add_trace(
                    go.Scatter(
                        x=merged['x'],
                        y=merged['ratio'],
                        mode='lines',
                        name=label,
                        opacity=0.9,
                        line=dict(width=3),
                    ),
                    secondary_y=False
                )

        # remove all traces that are not X_over_A
        fig.data = [trace for trace in fig.data if "X_over_A" in trace.name]
        # The raw traces are gone, so the automatic axis limits must follow
        # the ratio values rather than the removed fractions.
        y_min, y_max = ratio_min, ratio_max

    # ------
    # 5. Final layout and show
    border_shape = dict(
        type="rect",
        x0=0, y0=0, x1=0.95, y1=1,
        xref="paper", yref="paper",
        line=dict(color="white", width=2),
        fillcolor='rgba(0,0,0,0)',
    )
    # Shared tick/axis-line styling for both figures
    axis_style = dict(
        tickfont=dict(size=20, color="black"),
        ticks='outside',
        ticklen=10,
        tickwidth=2,
        tickcolor='black',
        showline=True,
        linecolor='black',
        linewidth=2,
        zeroline=False,
    )
    type_names = list(types or [])

    print("Adjusting plot formatting...")

    # Title logic (also returned to the caller, so the format is kept as is)
    if types:
        plot_title = "m6A Fraction" + "_".join(type_names)
    else:
        plot_title = "m6A Fraction"

    # Turn off background grids and zero lines
    fig.update_xaxes(showgrid=False, zeroline=False)
    fig.update_yaxes(showgrid=False, zeroline=False)

    # determine dynamic x-axis label based on types
    axis_title = "Genomic Position"
    if any("motif" in t.lower() for t in type_names):
        axis_title = "<i>rex</i> motif"
    elif any("rex" in t.lower() for t in type_names):
        axis_title = "<i>rex</i>"

    fig.update_layout(
        title=plot_title,
        xaxis_title=axis_title,
        template="plotly_white",
        width=800,
        height=800,
        title_font=dict(size=24, color="black"),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        shapes=[border_shape],
        font=dict(color='black'),  # all text black
        xaxis_title_font=dict(size=20, color="black"),
        yaxis_title_font=dict(size=20, color="black"),
    )
    fig.update_yaxes(title_text="ChIP enrichment", secondary_y=True)

    # Update legend
    fig.update_layout(
        legend=dict(
            traceorder="normal",
            y=-0.2,
            x=0.25,
            yanchor="top",
            orientation='h',
            font=dict(size=18, color="black"),  # legend text black
        )
    )

    fig.update_xaxes(**axis_style)
    fig.update_yaxes(tickformat=".2f", **axis_style)

    if y_range is not None:
        fig.update_yaxes(range=y_range)
    elif np.isfinite(y_min) and np.isfinite(y_max):
        fig.update_yaxes(range=[y_min-0.03, y_max+0.03])
    # otherwise nothing finite was plotted: leave plotly's autorange in place

    # Values that never exceed 1 are shown as percentages on the primary axis
    percent_axis = (y_max <= 1)
    yaxis_args = dict(
        title_text="m6A/A",
        secondary_y=False
    )
    if percent_axis:
        yaxis_args.update(
            tickformat=".0%",
        )
    fig.update_yaxes(**yaxis_args)

    # Now do cov_fig layout
    cov_fig.update_xaxes(showgrid=False, zeroline=False)
    cov_fig.update_yaxes(showgrid=False, zeroline=False)

    if types:
        plot_title_cov = "Motif Count" + "_".join(type_names)
    else:
        plot_title_cov = "Motif Count"

    cov_fig.update_layout(
        title=plot_title_cov,
        xaxis_title='Genomic Position',
        template="plotly_white",
        width=800,
        height=800,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        shapes=[border_shape],
        font=dict(color='black')  # all text black
    )
    cov_fig.update_yaxes(title_text="Nvalid_cov", secondary_y=False)
    cov_fig.update_layout(
        legend=dict(
            traceorder="normal",
            y=-1,
            x=0.25,
            yanchor="top",
            orientation='h',
            font=dict(size=18, color="black"),  # legend text black
        )
    )
    cov_fig.update_xaxes(range=[bed_window[0], bed_window[1]])
    cov_fig.update_xaxes(**axis_style)
    cov_fig.update_yaxes(**axis_style)

    def _mark_position(x_pos, text):
        """Draw a vertical line at *x_pos* in both figures and label it in the main one."""
        line_spec = dict(
            type="line", x0=x_pos, y0=0, x1=x_pos, y1=1,
            line=dict(color="black", width=1.5),
            xref="x", yref="paper"
        )
        fig.add_shape(**line_spec)
        fig.add_annotation(
            x=x_pos, y=1, yref="paper", text=text,
            showarrow=False, yanchor="bottom", xanchor="center",
            font=dict(color="black")
        )
        cov_fig.add_shape(**line_spec)

    # Optionally draw TSS and TES lines. 'gene' and 'damID' types get both
    # markers; each marker is drawn only if bed_window spans its position.
    gene_like = any(('gene' in t) or ('damID' in t) for t in type_names)
    if ((gene_like or any('TSS' in t for t in type_names))
            and (bed_window[0] <= 0 <= bed_window[1])):
        _mark_position(0, "TSS")
    if ((gene_like or any('TES' in t for t in type_names))
            and (bed_window[0] <= metagene_bins <= bed_window[1])):
        _mark_position(metagene_bins, "TES")

    fig.show(renderer='plotly_mimetype+notebook')
    cov_fig.show(renderer='plotly_mimetype+notebook')

    # ------------------------------------------------------------------ helpers
    def fwhm_scipy(x, y, prominence=0.05):
        """Width of the most prominent peak (scipy peak_widths default height), in x units."""
        peaks, props = signal.find_peaks(y, prominence=prominence)
        if peaks.size == 0:
            return np.nan
        idx = np.argmax(props["prominences"])
        w, *_ = signal.peak_widths(y, [peaks[idx]])
        # peak_widths reports widths in samples; scale by the mean x spacing
        return w[0] * np.diff(x).mean()

    def auc_valley_to_valley(x, y, prominence=0.05):
        """
        Integrate peak area above the straight-line baseline connecting the
        two base points scipy reports for the most prominent peak.
        """
        peaks, props = signal.find_peaks(y, prominence=prominence)
        if peaks.size == 0:
            return np.nan
        idx   = np.argmax(props["prominences"])
        left  = props["left_bases"][idx]
        right = props["right_bases"][idx]

        # linear baseline through the two bases
        m = (y[right] - y[left]) / (x[right] - x[left])
        b = y[left] - m * x[left]
        baseline = m * x[left:right + 1] + b

        # np.trapz is deprecated in NumPy 2 in favour of np.trapezoid
        integrate = getattr(np, "trapezoid", None) or np.trapz
        return integrate(y[left:right + 1] - baseline, x[left:right + 1])

    # ------------------------------------------------------------------ peak stats
    # list() so a tuple bed_window=(-500, 500) is recognised as the default too
    if window_size == 101 and list(bed_window) == [-500, 500]:
        records = []
        for key, (xarr, yarr) in results_dict.items():
            cond, mod, typ, strand, ptype, smth, chr_type = key
            if mod != 'm6A' or 'intergenic' in typ.lower():
                continue

            width   = fwhm_scipy(xarr, yarr, prominence=0.05)
            peak_auc = auc_valley_to_valley(xarr, yarr, prominence=0.05)

            if not np.isnan(width):
                records.append({
                    "condition":     cond,
                    "type":          typ,
                    "window_size":   window_size,
                    "fwhm":          width,
                    "peak_auc":      peak_auc,
                })

        # ---------- upsert CSV: new rows replace rows with the same key ------
        fwhm_file   = f"{output_stem}_fwhm.csv"
        # Explicit columns keep the frame (and the CSV header) well-formed even
        # when no peak was found.
        fwhm_df_new = pd.DataFrame(
            records,
            columns=["condition", "type", "window_size", "fwhm", "peak_auc"],
        )

        if os.path.exists(fwhm_file):
            existing  = pd.read_csv(fwhm_file)
            if fwhm_df_new.empty:
                fwhm_out = existing
            else:
                key_cols  = ["condition", "type", "window_size"]
                keep_mask = ~existing.set_index(key_cols).index.isin(
                    fwhm_df_new.set_index(key_cols).index
                )
                fwhm_out = pd.concat([existing.loc[keep_mask], fwhm_df_new],
                                     ignore_index=True)
        else:
            fwhm_out = fwhm_df_new

        fwhm_out.to_csv(fwhm_file, index=False)
        print("FWHM / Peak‑AUC DataFrame:")
        print(fwhm_out)
    else:
        print("FWHM/AUC calculation skipped for non‑default window/bounds")

    return fig, plot_title

# ───────────────── CHIP‑RANK TYPE FILTER ───────────────── #
def apply_chiprank_filter(
    df: pd.DataFrame,
    chip_rank_cutoff: int      = None,
    above_flag: bool           = None,
    types_to_include: list     = None,
    chip_rank_path: str        = "/Data1/reference/rex_chiprank.bed",
):
    """
    Restrict *df* to a set of motif types, chosen explicitly or by ChIP rank.

    Two selection modes:

    * Explicit list: when ``types_to_include`` is non-empty, rows whose
      ``type`` is in that list are kept and their labels are left untouched.
      The chip-rank table is not read in this mode.
    * Rank-based: otherwise the table at ``chip_rank_path`` is read (whitespace
      separated, with ``type`` and ``chip_rank`` columns).  Each type is
      prefixed with ``"MOTIFS_"`` and its rank is scaled by 100.  Types whose
      scaled rank is ``>= chip_rank_cutoff`` (``above_flag`` truthy) or
      ``< chip_rank_cutoff`` (``above_flag`` falsy) are kept, and all surviving
      rows are relabelled ``chip_gt_<cutoff>`` / ``chip_lt_<cutoff>``.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain a ``type`` column.  The input is not modified.
    chip_rank_cutoff, above_flag, types_to_include
        When ``None`` these fall back to the module-level ``CHIP_RANK_CUTOFF``,
        ``ABOVE_FLAG`` and ``TYPES_TO_INCLUDE``.
    chip_rank_path : str
        Location of the chip-rank table (rank-based mode only).

    Returns
    -------
    filtered_df : pd.DataFrame
        Copy of the selected rows.
    selected_types : list[str]
        ``[synthetic_label]`` when rows were relabelled; otherwise the sorted
        list of kept type names (this is also what is returned when the
        rank-based selection matched no rows).
    """
    # If the caller did not pass these, grab the module-level defaults.
    if chip_rank_cutoff is None:
        chip_rank_cutoff = CHIP_RANK_CUTOFF
    if above_flag is None:
        above_flag = ABOVE_FLAG
    if types_to_include is None:
        types_to_include = TYPES_TO_INCLUDE

    # 1. Decide which motif types to keep
    if types_to_include:                              # explicit list wins
        keep_types = set(types_to_include)
        synthetic_label = None                        # keep original labels
    else:                                             # rank-based selection
        # The table is only loaded here, so an explicit list never depends on
        # the file being present.
        chiprank_df = (
            pd.read_csv(chip_rank_path, sep=r"\s+")
              .assign(type=lambda d: "MOTIFS_" + d["type"].astype(str))
        )
        chip_rank_lookup = {
            t: round(float(rk) * 100, 3)
            for t, rk in zip(chiprank_df["type"], chiprank_df["chip_rank"])
        }
        # Coerce once so the comparison and the label use the same truth value.
        want_above = bool(above_flag)
        keep_types = {
            t for t, r in chip_rank_lookup.items()
            if (r >= chip_rank_cutoff) == want_above
        }
        synthetic_label = f"chip_{'g' if want_above else 'l'}t_{chip_rank_cutoff}"

    # 2. Slice the input DF (copy so the relabelling below cannot touch *df*)
    tmp_df = df[df["type"].isin(keep_types)].copy()

    # 3. Optional renaming
    if synthetic_label is not None and not tmp_df.empty:
        tmp_df["type"] = synthetic_label
        selected_types = [synthetic_label]
    else:
        selected_types = sorted(keep_types)

    return tmp_df, selected_types



# Driver for the aggregate m6A region plot: prints a quick look at the inputs,
# sets the plotting parameters, calls plot_bedmethyl() and writes the figure to disk.
# Relies on notebook globals defined in earlier cells (type_selected, plot_df,
# comb_bedmethyl_df, analysis_cond, plot_bedmethyl).

# Show a sample of plot_df (20 is passed as the second argument; presumably the row count).
print("type_selected:",type_selected)
print("plot_df:")
nanotools.display_sample_rows(plot_df,20)

# Plot parameters forwarded to plot_bedmethyl() below.
window_s = 25
smoothing_type = "weighted" #gaussian # exponential # none # weighted
bed_w = 500 #bed_window
num_bins = 1000
plot_type = "raw" # or norm or raw
plot_x_over_a = False

# Module-level defaults that apply_chiprank_filter() falls back to when its
# corresponding arguments are left as None.
CHIP_RANK_CUTOFF = 80
ABOVE_FLAG = False
TYPES_TO_INCLUDE = []

# Sanity check on the strand column of the combined bedmethyl table.
print("Unique strands in DataFrame:", comb_bedmethyl_df['strand'].unique())
print(comb_bedmethyl_df.groupby('strand').size())


# Condition names used in previous runs:
# "N2_old_fiber_R10","96_DPY27_degron_old","107_SDC2_degron_old_R10","52_old_dpy21jmjc_fiber_R10","51_old_dpy21null_fiber_R10","N2_mixed_R9","SDC2_degron_mixed_R9","SDC3_degron_old_R10"

# region_fig is indexed below as [0] = figure and [1] = title string used in the file name.
# The symmetric window [-bed_w, bed_w] is passed as bed_window.
region_fig = plot_bedmethyl(plot_df, analysis_cond, chr_types=["X"], types=[], strands=[], window_size=window_s, metagene_bins=num_bins, smoothing_type=smoothing_type,selection_indices=[0,1,2], bed_window=[-bed_w,bed_w], mod_types=["m6A"],ignore_selec=[], plot_type = plot_type, plot_x_over_a = plot_x_over_a, y_range=[0.02,0.18])#,bw_selections=["sdc2_chip_albritton","sdc3_chip_albritton","sdc3_chip_anderson","dpy27_chip_anderson"],bigwig_df=plot_comb_bigwig_df)#[1,'5mC'],[3,'5mC']])# #
# smoothing types: "gaussian", "weighted", "rolling"

# Selection-index sets tried previously:
#8,9,10,11,14
#7,12,13

#7,8,9,10,11,12,13,14



#analysis_cond = ["N2_mixed_DPY27_dimelo_pAHia5_R10","50_mixed_dpy27-3xGNB_GFP-Hia5_mcvipi_R10","66_old_sdc2_3xGNB_GFPHia5_mChMCVIPI","N2_mixed_endogenous_R10","54_mixed_sdc2_3xmCNB_mChMCVIPI_GFPHia5"]

# Optional diagnostic (disabled): unique bed_start counts per chr_type/type/condition.
#print("Unique count of bed_start values for each combination of chr_type, type and condition in comb_bedmethyl_df:")
#print(plot_df.groupby(['chr_type','type','condition'])['bed_start'].nunique())

# Output file name = <prefix><smoothing_type>_<plot title>, written as both SVG and PNG.
# NOTE(review): the output directory is hard-coded and is not created here — confirm it exists.
prefix = "all_reps_N2_DPY27_"
#rand_suffix = nanotools.random_alpha_numeric(8)
region_fig[0].write_image("/Data1/git/meyer-nanopore/scripts/analysis/images_20250604/"+prefix+smoothing_type+"_"+region_fig[1]+".svg")
region_fig[0].write_image("/Data1/git/meyer-nanopore/scripts/analysis/images_20250604/"+prefix+smoothing_type+"_"+region_fig[1]+".png")

# Region type names used in previous runs:
#"center_DPY27_chip_albretton_ONLY","center_DPY27_chip_albretton;gene_ol2000;TSS_ol2000","strong_rex;DPY27_ol2000;SDC_ol2000","center_DPY27_chip_albretton;SDC_ol2000"
#"center_DPY27_chip_albretton","intergenic_control","strong_rex","weak_rex","TSS_q4","TSS_1"

In [None]:
import pandas as pd
import plotly.graph_objects as go
import nanotools            # for get_colors()

# ───────────────────────── Configuration ─────────────────────────
# CSV holding one fwhm / peak_auc row per (condition, type, window_size).
fwhm_file = f"{output_stem}_fwhm.csv"

# Condition groups in the left-to-right order they are drawn.
PLOT_ORDER = ["N2", "SDC2deg", "SDC3deg","SDC2_3deg","DPY27deg", "DPY21null"]

# ────────────────── Load & tag matching rows ────────────────────
# Keep only the window_size == 101 measurements.
fwhm_table = pd.read_csv(fwhm_file)
df100 = fwhm_table[fwhm_table["window_size"] == 101].copy()

# A condition belongs to the first PLOT_ORDER entry that occurs as a substring
# of its name; conditions matching none of them get pd.NA and are dropped.
df100["plot_group"] = [
    next((group for group in PLOT_ORDER if group in condition), pd.NA)
    for condition in df100["condition"]
]
df100 = df100.dropna(subset=["plot_group"])

# Ordered categorical so the groups sort in PLOT_ORDER.
df100["plot_group"] = pd.Categorical(
    df100["plot_group"], categories=PLOT_ORDER, ordered=True
)

# ─────────────────── Helper to plot one metric ──────────────────
def plot_metric(df, metric, y_title, y_range=None, typ=None):
    """
    Draw and show a box-and-points figure of *metric*, one box per plot group.

    Parameters
    ----------
    df : pd.DataFrame
        Rows to plot; must contain ``plot_group`` and the *metric* column.
    metric : str
        Column to plot (e.g. ``"fwhm"`` or ``"peak_auc"``).
    y_title : str
        Y-axis title.
    y_range : list | None
        Fixed ``[min, max]`` y-axis range; ``None`` lets Plotly pick one.
    typ : str | None
        Region type shown in the figure title.  When omitted it is taken from
        the distinct values of ``df["type"]``, so the function no longer
        depends on a global loop variable named ``typ`` being defined.

    Groups are drawn in ``PLOT_ORDER``; groups without rows are skipped.  Each
    box is annotated with its mean, formatted to one decimal place.
    """
    if typ is None:
        if "type" in df.columns:
            typ = ", ".join(str(t) for t in df["type"].unique())
        else:
            typ = "all types"

    fig = go.Figure()

    for cond in PLOT_ORDER:
        y = df.loc[df["plot_group"] == cond, metric]
        if y.empty:
            continue

        color = nanotools.get_colors(cond)

        # Transparent fill so the individual points stay visible inside the box.
        fig.add_trace(go.Box(
            y=y, name=cond,
            marker_color=color, line_color=color, fillcolor="rgba(0,0,0,0)",
            boxmean=True, boxpoints="all", jitter=0.3, pointpos=0,
            showlegend=False
        ))

        mean_val = y.mean()
        fig.add_annotation(
            x=cond, y=mean_val,
            text=f"{mean_val:.1f}", showarrow=False,
            font=dict(color="black", size=14), yanchor="bottom"
        )

    # Only list categories that are present, keeping the PLOT_ORDER ordering.
    fig.update_xaxes(
        categoryorder="array",
        categoryarray=[c for c in PLOT_ORDER if c in df["plot_group"].values],
        tickfont=dict(size=14, color="black")
    )
    fig.update_yaxes(
        range=y_range, showgrid=False,
        tickfont=dict(size=14, color="black")
    )
    fig.update_layout(
        title=f"{metric.upper()} Distribution for {typ}",
        title_font=dict(size=20, color="black"),
        xaxis_title="Condition", yaxis_title=y_title,
        xaxis_title_font=dict(size=16, color="black"),
        yaxis_title_font=dict(size=16, color="black"),
        paper_bgcolor="white", plot_bgcolor="white",
        font=dict(color="black"), width=450, height=450,
    )
    fig.show()


# ───────────────────── Plot per “type” column ───────────────────
# One (metric column, y-axis title, y-axis range) entry per figure drawn for
# each type: FWHM keeps its fixed range, peak AUC is auto-ranged by Plotly.
metric_panels = (
    ("fwhm", "FWHM (bp)", [100, 170]),
    ("peak_auc", "Peak AUC (signal × bp)", None),
)

for typ in df100["type"].unique():
    df_typ = df100[df100["type"] == typ]
    for panel_metric, panel_title, panel_range in metric_panels:
        plot_metric(
            df_typ,
            metric=panel_metric,
            y_title=panel_title,
            y_range=panel_range,
        )


In [None]:
### START OF SINGLE FIBER PLOTTING

### OPTIONAL TO CENTER ON MEX MOTIFS
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from intervaltree import Interval, IntervalTree
from collections import defaultdict
import numpy as np  # Added import for numpy

# Behaviour switches for this cell
center_intergenic = False
include_noregion_motifs = False
include_nomotif_regions = True
combine_motifs = True

# Distance (bp) used for cluster counting and motif thinning further down
window_size = 2 * 500

# Step 1: Read every FIMO result table and stack them into one dataframe
fimo_files = [
    "/Data1/ext_data/motifs/fimo_MEX_0.01.tsv",
    "/Data1/ext_data/motifs/fimo_MEXII_0.01.tsv",
    "/Data1/ext_data/motifs/fimo_motifc_0.01.tsv"
]
fimo_dfs = [
    pd.read_csv(fimo_path, sep='\t', comment='#', skip_blank_lines=True)
    for fimo_path in fimo_files
]
fimo_df = pd.concat(fimo_dfs, ignore_index=True)

print("fimo_df before filter:",len(fimo_df))

# Length of each hit; the first hit of every motif_id supplies that motif's
# entry in motif_lengths (rows with a missing motif_id are left out).
fimo_df['motif_length'] = fimo_df['stop'] - fimo_df['start']
first_hit_per_motif = fimo_df.drop_duplicates('motif_id')
motif_lengths = {
    motif: length
    for motif, length in zip(first_hit_per_motif['motif_id'],
                             first_hit_per_motif['motif_length'])
    if pd.notna(motif)
}
print("motif_lengths:", motif_lengths)

# Keep only hits whose ln(p-value) is at most -11 (adjust as needed)
fimo_df['natural_log_p_value'] = np.log(fimo_df['p-value'])
passes_p_cutoff = fimo_df['natural_log_p_value'] <= -11
fimo_df = fimo_df[passes_p_cutoff]

print("fimo_df after filter:",len(fimo_df))

# Step 3: Deduplicate overlapping hits on the same sequence_name (chromosome).
# A lower motif_priority number wins. Hits whose motif_id is not in the
# priority table are dropped.
motif_priority = {'MEXII': 2, 'MEX': 1, 'motifC': 3}
# assign() returns a new frame, so nothing is written into the filtered slice
# produced by the p-value filter above.
fimo_df = (
    fimo_df.assign(motif_priority=fimo_df['motif_id'].map(motif_priority))
           .dropna(subset=['motif_priority'])
           .astype({'motif_priority': int})
           .sort_values(by=['sequence_name', 'start', 'motif_priority', 'score'],
                        ascending=[True, True, True, False])
)

# One interval tree per chromosome holding the hits kept so far; each interval
# carries (priority, row index) so a displaced hit can be un-kept.
chrom_trees = defaultdict(IntervalTree)
deduped_indices = set()

for idx, chrom, start, end, priority in zip(
        fimo_df.index, fimo_df['sequence_name'], fimo_df['start'],
        fimo_df['stop'], fimo_df['motif_priority']):
    tree = chrom_trees[chrom]
    overlaps = tree[start:end]

    # The new hit is kept only if it outranks EVERY kept hit it overlaps.
    # Previously it was inserted as soon as it beat any one of them, which
    # could leave it overlapping a better-ranked hit in the final set.
    if any(priority >= kept.data[0] for kept in overlaps):
        continue

    for kept in overlaps:
        tree.remove(kept)
        deduped_indices.discard(kept.data[1])
    tree[start:end] = (priority, idx)
    deduped_indices.add(idx)

# Boolean mask instead of .loc[set]: pandas 2.x rejects a set as an indexer,
# and the mask keeps the (chromosome, start) sort order.
fimo_df = fimo_df[fimo_df.index.isin(deduped_indices)]

print("fimo_df after dedup:",len(fimo_df))

# Steps 4-8: normalise the FIMO table, annotate every motif with the regions of
# combined_bed_df_ext it overlaps, and index the motifs for cluster counting.

# Step 4: Keep only the columns used downstream and give them BED-like names
fimo_df = fimo_df[['sequence_name', 'start', 'stop', 'strand', 'score', 'p-value', 'motif_id','natural_log_p_value']]
fimo_df = fimo_df.rename(columns={'sequence_name': 'chr', 'p-value': 'p_value'})

# Step 5: Replace "chr" with "CHROMOSOME_" in the 'chr' column
# (plain-text replace: every occurrence of "chr" in a value is substituted)
fimo_df['chr'] = fimo_df['chr'].str.replace('chr', 'CHROMOSOME_', regex=False)

# Step 6: Build IntervalTrees for each chromosome in combined_bed_df_ext
# NOTE(review): trees are keyed by combined_bed_df_ext['chrom'] as-is and looked up
# with the "CHROMOSOME_"-style names from step 5 — confirm both use the same naming.
interval_trees = defaultdict(IntervalTree)
for idx, row in combined_bed_df_ext.iterrows():
    chrom = row['chrom']
    start = row['bed_start']
    end = row['bed_end']
    category = row['type']
    # Store a tuple with the category and bed interval
    interval_trees[chrom][start:end] = (category, start, end)

# Add an 'id' column to fimo_df to keep track of original rows
# (positional id after the index reset; used as the merge key later)
fimo_df.reset_index(drop=True, inplace=True)
fimo_df['id'] = fimo_df.index

# Step 7: Expand fimo_df to account for multiple overlapping categories and add bed_start, bed_end
# A motif overlapping k regions yields k rows (one per region); a motif overlapping
# none yields a single row with chip_category 'none'.
expanded_rows = []
for idx, row in fimo_df.iterrows():
    chrom = row['chr']
    start = row['start']
    end = row['stop']
    # The membership test avoids creating an empty tree in the defaultdict.
    overlaps = interval_trees[chrom][start:end] if chrom in interval_trees else []
    if overlaps:
        for interval in overlaps:
            category, bed_st, bed_en = interval.data
            new_row = row.copy()
            new_row['chip_category'] = category
            new_row['bed_start'] = bed_st
            new_row['bed_end'] = bed_en
            expanded_rows.append(new_row)
    else:
        new_row = row.copy()
        new_row['chip_category'] = 'none'
        # No overlapping region: bed_start/bed_end fall back to the motif's own coordinates
        new_row['bed_start'] = new_row['start']
        new_row['bed_end'] = new_row['stop']
        expanded_rows.append(new_row)

# cluster_count is not part of these rows yet; it is merged in later via 'id'.
fimo_expanded_df = pd.DataFrame(expanded_rows)

# Step 8: Build IntervalTrees for fimo_df for cluster counting
# (one tree per chromosome; each interval's payload is the motif's 'id')
fimo_trees = defaultdict(IntervalTree)
for idx, row in fimo_df.iterrows():
    chrom = row['chr']
    start = row['start']
    end = row['stop']
    fimo_trees[chrom][start:end] = row['id']

# Step 9: Calculate 'cluster_count'
def get_cluster_count(row):
    """
    Count the motif intervals stored in ``fimo_trees`` near one motif.

    The query span is the row's ``start``..``stop`` widened by the global
    ``window_size`` on both sides.  Any interval in the chromosome's tree that
    overlaps the span is counted, including the row's own interval if the
    tree contains it.  Chromosomes without a tree give 0.
    """
    chrom = row['chr']
    span_lo = row['start'] - window_size
    span_hi = row['stop'] + window_size

    # Guard first so the defaultdict is not grown for unknown chromosomes.
    if chrom not in fimo_trees:
        return 0

    neighbours = fimo_trees[chrom][span_lo:span_hi]
    return max(len(neighbours), 0)

# Number of motifs within window_size of each motif (see get_cluster_count)
fimo_df['cluster_count'] = fimo_df.apply(get_cluster_count, axis=1)

# Thin the motifs so that no two retained motifs on a chromosome start within
# window_size of each other, keeping the strongest motif of each neighbourhood.
# natural_log_p_value is negative and MORE negative means more significant, so
# an ascending sort visits the strongest motifs first. The previous descending
# sort made the greedy pass keep the weakest motif of each neighbourhood.
fimo_df = fimo_df.sort_values(by=['chr', 'natural_log_p_value'], ascending=[True, True])

print("fimo_df:", fimo_df.head())
print("fimo_df:", len(fimo_df))

# Greedy selection per chromosome: accept a motif only if its start is at least
# window_size away from every motif already accepted on that chromosome.
final_indices = []
for chr_name, group in fimo_df.groupby('chr', sort=False):
    chosen_starts = []
    for i, start in zip(group.index, group['start']):
        if all(abs(start - cs) >= window_size for cs in chosen_starts):
            chosen_starts.append(start)
            final_indices.append(i)

# Update fimo_df to retain only the selected motifs
fimo_df = fimo_df.loc[final_indices].sort_values(by=['chr', 'start'])

# Merge 'cluster_count' back into the expanded DataFrame. The inner join also
# drops expanded rows that belong to motifs removed by the thinning above.
fimo_expanded_df = fimo_expanded_df.merge(
    fimo_df[['id', 'cluster_count']], on='id', how='inner'
)

# Step 10: Mean cluster_count and number of rows for every chip_category,
# ordered from the highest mean to the lowest
category_stats = (
    fimo_expanded_df
    .groupby('chip_category')['cluster_count']
    .agg(['mean', 'count'])
    .reset_index()
    .sort_values(by='mean', ascending=False)
)
# Tick label: category name with its sample size on a second line
category_stats['label'] = category_stats.apply(
    lambda rec: f"{rec['chip_category']}\n(n={int(rec['count'])})",
    axis=1,
)

# Bar plot of the per-category means
plt.figure(figsize=(12, 6))
sns.barplot(data=category_stats, x='label', y='mean', palette='viridis')
plt.title('Average Cluster Count by Chip Category')
plt.xlabel('Chip Category')
plt.ylabel('Average Cluster Count')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# Add one placeholder row for every region of combined_bed_df_ext that ended up
# without any motif row in fimo_expanded_df (only when include_nomotif_regions).
if include_nomotif_regions:
    # Regions are identified by (chromosome, bed_start, type). The chromosome is
    # part of the key so that two regions of the same type that share a
    # bed_start on different chromosomes are not mistaken for one another.
    region_chr = combined_bed_df_ext['chrom'].str.replace('chr', 'CHROMOSOME_', regex=False)
    covered_keys = set(zip(
        fimo_expanded_df['chr'],
        fimo_expanded_df['bed_start'],
        fimo_expanded_df['chip_category'],
    ))
    # Set lookup per region instead of a row-wise DataFrame.apply.
    missing_mask = np.array(
        [key not in covered_keys
         for key in zip(region_chr,
                        combined_bed_df_ext['bed_start'],
                        combined_bed_df_ext['type'])],
        dtype=bool,
    )
    missing_combinations_df = combined_bed_df_ext.loc[missing_mask].copy()

    # Represent each motif-less region by its midpoint (start == stop)
    missing_combinations_df['midpoint'] = (missing_combinations_df['bed_start'] + missing_combinations_df['bed_end']) // 2
    missing_combinations_df['chr'] = missing_combinations_df['chrom'].str.replace('chr', 'CHROMOSOME_', regex=False)
    missing_combinations_df['start'] = missing_combinations_df['midpoint']
    missing_combinations_df['stop'] = missing_combinations_df['midpoint']
    missing_combinations_df['chip_category'] = missing_combinations_df['type']

    # Assign defaults for fields not derived from fimo.
    # The 'noregion' label is what these motif-less regions carry as motif_id.
    missing_combinations_df['strand'] = '.'
    missing_combinations_df['score'] = 0
    missing_combinations_df['p_value'] = 1.0
    missing_combinations_df['motif_id'] = 'noregion'
    # Continue the id sequence after the largest existing motif id
    new_id_start = fimo_expanded_df['id'].max() + 1 if not fimo_expanded_df.empty else 0
    missing_combinations_df['id'] = range(new_id_start, new_id_start + len(missing_combinations_df))
    missing_combinations_df['cluster_count'] = 0  # Default since not computed

    # Select the columns shared with fimo_expanded_df; bed_start / bed_end are
    # not carried over, so they are NaN for the appended rows.
    new_rows = missing_combinations_df[['chr', 'start', 'stop', 'strand', 'score', 'p_value',
                                        'motif_id', 'id', 'chip_category', 'cluster_count']]

    # Append to fimo_expanded_df
    fimo_expanded_df = pd.concat([fimo_expanded_df, new_rows], ignore_index=True)


# Shift motif coordinates by bed_window and prefix the chip_category labels,
# then assemble combined_bed_df_mex_cat in BED-style column names.
# NOTE(review): 'stop' is derived from the already shifted (and zero-clipped)
# 'start', so each interval becomes [max(0, start - bed_window), that + bed_window],
# i.e. it ends at the original start rather than being centred on it.
# Confirm this is what downstream code expects.
if center_intergenic:
    # center_intergenic = True: shift every row
    fimo_expanded_df['start'] = fimo_expanded_df['start'] - bed_window
    fimo_expanded_df['start'] = fimo_expanded_df['start'].apply(lambda x: max(0, x))
    fimo_expanded_df['stop'] = fimo_expanded_df['start'] + bed_window
    # Add "MEX_" to all chip_categories
    fimo_expanded_df['chip_category'] = 'MEX_' + fimo_expanded_df['chip_category']

    # Create combined_bed_df_mex_cat from all processed motifs
    combined_bed_df_mex_cat = fimo_expanded_df[['chr', 'start', 'stop', 'strand', 'chip_category']].copy()
else:
    # center_intergenic = False: only rows whose category does not mention "intergenic"
    non_intergenic_mask = ~fimo_expanded_df['chip_category'].str.contains('intergenic', case=False, na=False)

    # Shift only non-intergenic rows
    fimo_expanded_df.loc[non_intergenic_mask, 'start'] = fimo_expanded_df.loc[non_intergenic_mask, 'start'] - bed_window
    fimo_expanded_df.loc[non_intergenic_mask, 'start'] = fimo_expanded_df.loc[non_intergenic_mask, 'start'].apply(lambda x: max(0, x))
    fimo_expanded_df.loc[non_intergenic_mask, 'stop'] = fimo_expanded_df.loc[non_intergenic_mask, 'start'] + bed_window
    # Prefix non-intergenic categories: "MOTIFS_" when motifs are pooled,
    # otherwise "<motif_id>_" so each motif family stays separate
    if(combine_motifs):
        fimo_expanded_df.loc[non_intergenic_mask, 'chip_category'] = 'MOTIFS_' + fimo_expanded_df.loc[non_intergenic_mask, 'chip_category']
    else:
        fimo_expanded_df.loc[non_intergenic_mask, 'chip_category'] = fimo_expanded_df.loc[non_intergenic_mask, 'motif_id'] +"_"+ fimo_expanded_df.loc[non_intergenic_mask, 'chip_category']

    # Unless include_noregion_motifs is set, drop every row whose chip_category contains "none"
    if not include_noregion_motifs:
        fimo_expanded_df = fimo_expanded_df[~fimo_expanded_df['chip_category'].str.contains('none', case=False, na=False)]
        # Re-align the mask with the rows that survived the filter instead of
        # relying on implicit index alignment inside .loc below.
        non_intergenic_mask = non_intergenic_mask.loc[fimo_expanded_df.index]

    # Create combined_bed_df_mex_cat from non-intergenic processed motifs
    combined_bed_df_mex_cat = fimo_expanded_df.loc[non_intergenic_mask, ['chr', 'start', 'stop', 'strand', 'chip_category']].copy()

    # Retrieve all intergenic rows from combined_bed_df_ext (original coordinates)
    intergenic_df = combined_bed_df_ext[combined_bed_df_ext['type'].str.contains('intergenic', case=False, na=False)].copy()

    # If both 'strand' and 'bed_strand' exist, drop 'strand' so the rename below
    # yields a single 'strand' column
    if 'strand' in intergenic_df.columns and 'bed_strand' in intergenic_df.columns:
        intergenic_df = intergenic_df.drop(columns=['strand'])

    # Rename to the motif-table column names (done once; the original repeated
    # this identical statement twice)
    intergenic_df_renamed = intergenic_df.rename(columns={
        'chrom': 'chr',
        'bed_start': 'start',
        'bed_end': 'stop',
        'bed_strand': 'strand',
        'type': 'chip_category'
    })[['chr', 'start', 'stop', 'strand', 'chip_category']]

    # Combine processed non-intergenic motifs with original intergenic rows
    combined_bed_df_mex_cat = pd.concat([combined_bed_df_mex_cat, intergenic_df_renamed], ignore_index=True)

# Drop NaN rows in 'chr' column
combined_bed_df_mex_cat.dropna(subset=['chr'], inplace=True)

# Switch back to the BED-style column names used by the rest of the notebook
combined_bed_df_mex_cat = combined_bed_df_mex_cat.rename(columns={
    'chr': 'chrom',
    'start': 'bed_start',
    'stop': 'bed_end',
    'strand': 'bed_strand',
    'chip_category': 'type'
})

def _chr_type_label(chrom_name):
    """Return 'X' for CHROMOSOME_X and 'Autosome' for anything else."""
    return 'X' if chrom_name == 'CHROMOSOME_X' else 'Autosome'


# Tag every category row as X-linked or autosomal
combined_bed_df_mex_cat['chr_type'] = combined_bed_df_mex_cat['chrom'].apply(_chr_type_label)

# combined_bed_df_mex_clust: same intervals, labelled by cluster size
# ('clust_<cluster_count>') instead of by chip category
combined_bed_df_mex_clust = (
    fimo_expanded_df[['chr', 'start', 'stop', 'strand', 'cluster_count']]
    .copy()
    .rename(columns={
        'chr': 'chrom',
        'start': 'bed_start',
        'stop': 'bed_end',
        'strand': 'bed_strand',
    })
)
combined_bed_df_mex_clust['type'] = 'clust_' + combined_bed_df_mex_clust['cluster_count'].astype(str)
combined_bed_df_mex_clust['chr_type'] = combined_bed_df_mex_clust['chrom'].apply(_chr_type_label)
combined_bed_df_mex_clust = combined_bed_df_mex_clust.drop(columns='cluster_count')

# Pool the MEX_D1..MEX_D9 labels into two bins: D1-D5 and D6-D9
mex_bin_labels = {
    f'MEX_D{rank}': ('MEX_D1to5' if rank <= 5 else 'MEX_D6to9')
    for rank in range(1, 10)
}
combined_bed_df_mex_cat['type'] = combined_bed_df_mex_cat['type'].replace(mex_bin_labels)

print("\ncombined_bed_df_mex_clust:")
print(combined_bed_df_mex_clust.head())

# combined_bed_df starts as a copy of the category table
combined_bed_df = combined_bed_df_mex_cat.copy()

# 'MEX_none' rows: keep a reproducible sample of at most 100 when
# include_noregion_motifs is set, otherwise remove them all
if include_noregion_motifs:
    is_mex_none = combined_bed_df['type'] == 'MEX_none'
    non_mex_none_df = combined_bed_df[~is_mex_none]
    mex_none_df = combined_bed_df[is_mex_none].sample(
        n=min(100, int(is_mex_none.sum())),
        random_state=42,  # For reproducibility
    )
    combined_bed_df = pd.concat([non_mex_none_df, mex_none_df], ignore_index=True)
else:
    combined_bed_df = combined_bed_df[combined_bed_df['type'] != 'MEX_none']

# Show a few rows once the coordinates are plain integers
print("combined_bed_df:")
#combined_bed_df["strand"] = "."
combined_bed_df = combined_bed_df.astype({'bed_start': int, 'bed_end': int})
nanotools.display_sample_rows(combined_bed_df, 5)


# Print count by type
print("Count by type:\n", combined_bed_df['type'].value_counts())

# Keep only type that contain 'MEX_D10' (optional)
# combined_bed_df = combined_bed_df[combined_bed_df['type'].str.contains('MEX_D10')]

# Test to keep only strand == "+"
# combined_bed_df = combined_bed_df[combined_bed_df['bed_strand'] == '+']

import pandas as pd
import tempfile
import os

def create_modkit_bed_df(filtered_df):
    """
    Build a six-column BED table listing every region once per strand.

    Columns are named 0..5: chrom, bed_start, bed_end (taken from
    *filtered_df* unchanged), two '.' placeholders for the BED name and score
    fields, and the strand.  All '+' rows come first, followed by all '-'
    rows, under a fresh 0..n-1 index.  Start and end are cast to int.
    """
    # Strand-independent part, shared by both copies.
    base_bed = pd.DataFrame({
        0: filtered_df['chrom'],
        1: filtered_df['bed_start'],
        2: filtered_df['bed_end'],
        3: '.',
        4: '.',
    })

    per_strand = []
    for strand_symbol in ('+', '-'):
        stranded = base_bed.copy()
        stranded[5] = strand_symbol
        per_strand.append(stranded)

    modkit_output_df = pd.concat(per_strand, ignore_index=True)

    # BED coordinates must be integers.
    return modkit_output_df.astype({1: int, 2: int})

def save_modkit_bed_to_temp(modkit_bed_df, filename,
                            temp_dir="/Data1/git/meyer-nanopore/scripts/analysis/temp_files/"):
    """
    Write *modkit_bed_df* as a headerless, tab-separated BED file.

    Parameters
    ----------
    modkit_bed_df : pd.DataFrame
        Table to write; the index is not written.
    filename : str
        File name created inside *temp_dir*.
    temp_dir : str
        Output directory, created if missing.  Defaults to the project's fixed
        ``temp_files`` folder (a persistent directory, not a system temp dir),
        so existing two-argument calls behave as before.

    Returns
    -------
    str
        Full path of the written file.
    """
    # Ensure the directory exists
    os.makedirs(temp_dir, exist_ok=True)

    temp_file_path = os.path.join(temp_dir, filename)

    # BED has no header row and no index column
    modkit_bed_df.to_csv(temp_file_path, sep='\t', header=False, index=False)

    print(f"Modkit BED file saved to: {temp_file_path}")
    return temp_file_path

# Build the +/- strand BED table from combined_bed_df. Identical rows are then
# collapsed, so a region that appears several times in combined_bed_df still
# yields exactly one '+' and one '-' line.
modkit_bed_df = (
    create_modkit_bed_df(combined_bed_df)
    .drop_duplicates()
    .reset_index(drop=True)
)

# Write the table to disk and report where it went
modkit_bed_name = "modkit_temp.bed"
temp_file_path = save_modkit_bed_to_temp(modkit_bed_df, modkit_bed_name)

print(f"Modkit BED file saved to: {temp_file_path}")

# Show the de-duplicated table and its size
print("De-duplicated modkit_bed_df content:")
print(modkit_bed_df)

print("len(modkit_bed_df) after de-duplication:",len(modkit_bed_df))

In [None]:
# ft_files: one path per entry of bam_files, with ".bam" replaced by "_ftools0p8.bed".
# NOTE(review): str.replace substitutes every occurrence of ".bam" in the path, not just the suffix.
ft_files = [x.replace(".bam", "_ftools0p8.bed") for x in bam_files]

import pandas as pd
import numpy as np

# Helper: parse a comma-separated string into a list of integers, ignoring -1 values.
def parse_int_list(s):
    """
    Parse ``"1,2,3,"``-style text into a list of ints.

    Blank tokens (such as the one produced by a trailing comma) are skipped
    and every ``-1`` value is removed.  If *s* is not a string, or any
    non-blank token is not a valid integer, the whole result is ``[]``.
    """
    try:
        values = [int(token) for token in s.split(',') if token.strip()]
    except (AttributeError, TypeError, ValueError):
        # Only the failures that parsing can legitimately produce are absorbed;
        # anything else propagates instead of being hidden.
        return []
    return [value for value in values if value != -1]

def _window_positions(read, starts_col, region_mid, bed_window, lengths_col=None):
    """
    Return one read's feature positions relative to *region_mid*, limited to ±bed_window.

    ``read[starts_col]`` is parsed with parse_int_list().  Returns
    ``(positions, lengths)``:

    * ``positions`` is ``None`` when the column is absent or null, otherwise
      the relative positions that lie within ``[-bed_window, bed_window]``.
    * ``lengths`` is ``None`` unless *lengths_col* is given, present, non-null
      and parses to as many values as there are positions; it then holds the
      lengths of the positions that were kept.
    """
    if starts_col not in read or not pd.notnull(read[starts_col]):
        return None, None

    positions = parse_int_list(str(read[starts_col]))
    rel_positions = [pos - region_mid for pos in positions]
    kept_positions = [pos for pos in rel_positions if -bed_window <= pos <= bed_window]

    kept_lengths = None
    if lengths_col is not None and lengths_col in read and pd.notnull(read[lengths_col]):
        lengths = parse_int_list(str(read[lengths_col]))
        # Only pair lengths with positions when the two lists line up one-to-one.
        if len(lengths) == len(positions):
            kept_lengths = [
                length for pos, length in zip(rel_positions, lengths)
                if -bed_window <= pos <= bed_window
            ]
    return kept_positions, kept_lengths


def ingest_and_adjust_bed_file(file_path, file_id, selected_type, combined_bed_df_int, bed_window, exp_id):
    """
    Load one tab-separated read table and re-centre its features on each region.

    Regions are the rows of *combined_bed_df_int* whose ``type`` equals
    *selected_type*.  A read overlaps a region when it is on the same
    chromosome (``ct``) and ``en > bed_start`` and ``st < bed_end``; if the
    region's ``bed_strand`` is not ``'.'`` the read's ``strand`` must match it.
    For every overlapping read/region pair, the read's nucleosome, MSP, m6A
    and 5mC positions are expressed relative to the region midpoint and
    limited to ``[-bed_window, bed_window]``.

    Parameters
    ----------
    file_path : str
        Path of the read table (header row; a leading '#' on column names is stripped).
    file_id, exp_id
        Copied into every output row.
    selected_type : str
        Region type to process.
    combined_bed_df_int : pd.DataFrame
        Region table with ``chrom``, ``bed_start``, ``bed_end``,
        ``bed_strand`` and ``type`` columns.
    bed_window : int
        Half-width of the window kept around each region midpoint.

    Returns
    -------
    pd.DataFrame | None
        One row per overlapping read/region pair (an empty frame when there
        is none), or ``None`` if the file could not be read.
    """
    try:
        mod_df = pd.read_csv(file_path, sep="\t", header=0)
    except (OSError, ValueError) as e:
        # OSError: missing or unreadable file. ValueError: pandas parser,
        # empty-data and decoding errors are all ValueError subclasses.
        print(f"Error reading file {file_path}: {e}")
        return None

    # Remove any leading '#' from column names.
    mod_df.columns = [col.lstrip('#') for col in mod_df.columns]

    # Filter regions for the selected type.
    regions = combined_bed_df_int[combined_bed_df_int['type'] == selected_type]
    if regions.empty:
        return pd.DataFrame([])

    # Split the reads by chromosome once, so each region only scans the reads
    # of its own chromosome instead of the whole table.
    reads_by_chrom = {chrom: reads for chrom, reads in mod_df.groupby('ct', sort=False)}

    rows = []
    read_counter = 0

    for r_idx, region in regions.iterrows():
        region_chrom  = region['chrom']
        region_start  = region['bed_start']
        region_end    = region['bed_end']
        region_strand = region['bed_strand']
        region_mid    = (region_start + region_end) // 2

        chrom_reads = reads_by_chrom.get(region_chrom)
        if chrom_reads is None:
            continue

        # Reads overlapping the region
        region_mod = chrom_reads[
            (chrom_reads['en'] > region_start) &
            (chrom_reads['st'] < region_end)
        ]
        if region_strand != '.':
            region_mod = region_mod[region_mod['strand'] == region_strand]
        if region_mod.empty:
            continue

        for _, read in region_mod.iterrows():
            read_st  = int(read['st'])
            read_end = int(read['en'])

            adjusted_ref_nuc, adjusted_ref_nuc_lengths = _window_positions(
                read, 'ref_nuc_starts', region_mid, bed_window, 'ref_nuc_lengths')
            adjusted_ref_msp, adjusted_ref_msp_lengths = _window_positions(
                read, 'ref_msp_starts', region_mid, bed_window, 'ref_msp_lengths')
            adjusted_ref_m6a, _ = _window_positions(read, 'ref_m6a', region_mid, bed_window)
            adjusted_ref_5mC, _ = _window_positions(read, 'ref_5mC', region_mid, bed_window)

            rows.append({
                'file_id': file_id,
                'exp_id': exp_id,
                'region_index': r_idx,
                'region_chrom': region_chrom,
                'region_start': region_start,
                'region_end': region_end,
                'region_strand': region_strand,
                'region_mid': region_mid,
                'read_index': read_counter,
                'read_st': read_st,
                'read_end': read_end,
                'read_st_adj': read_st - region_mid,
                'read_end_adj': read_end - region_mid,
                'fiber': read.get('fiber', None),
                'adjusted_ref_nuc': adjusted_ref_nuc,
                'adjusted_ref_nuc_lengths': adjusted_ref_nuc_lengths,
                'adjusted_ref_msp': adjusted_ref_msp,
                'adjusted_ref_msp_lengths': adjusted_ref_msp_lengths,
                'adjusted_ref_m6a': adjusted_ref_m6a,
                'adjusted_ref_5mC': adjusted_ref_5mC,
                'fiber_sequence': read.get('fiber_sequence', None)
            })
            read_counter += 1

    return pd.DataFrame(rows)

# --- Parallel wrapper ---
def process_single_file(args_tuple):
    """
    Pool worker: ingest one file and return its adjusted read table.

    Parameters:
      args_tuple : 4-tuple ``(args, selected_type, combined_bed_df_int2, bed_window)``.
                   ``args`` is the per-file record whose first five items are
                   ``(file_path, file_id, flag1, flag2, output_dir)``; an optional
                   sixth item is ``exp_id`` ("unknown" is used when it is absent).

    Returns:
      The DataFrame returned by ingest_and_adjust_bed_file(), or None when that
      call yields None or an empty frame.
    """
    file_args, selected_type, combined_bed_df_int2, bed_window = args_tuple

    # The two flags and the output directory are unpacked but not used here.
    exp_id = "unknown"
    if len(file_args) < 6:
        file_path, file_id, _flag1, _flag2, _output_dir = file_args[:5]
    else:
        file_path, file_id, _flag1, _flag2, _output_dir, exp_id = file_args[:6]

    print(f"Processing {file_id} (exp {exp_id})")
    result = ingest_and_adjust_bed_file(
        file_path, file_id, selected_type, combined_bed_df_int2, bed_window, exp_id
    )

    if result is None or result.empty:
        print("  → no overlaps")
        return None

    print(f"  → {len(result)} rows")
    return result

def process_all_files(args_list, selected_type, combined_bed_df_int3, bed_window, num_threads=None):
    """
    Run process_single_file() over every entry of ``args_list`` in a process
    pool and concatenate the per-file DataFrames.

    Parameters:
      args_list            : sized sequence of per-file records (see process_single_file).
      selected_type        : forwarded unchanged to each worker.
      combined_bed_df_int3 : forwarded unchanged to each worker.
      bed_window           : forwarded unchanged to each worker.
      num_threads          : requested number of worker *processes*; falsy means
                             "use every CPU". The value is clamped to
                             [1, min(cpu_count, number of files)].

    Returns:
      One DataFrame with all non-empty worker results (index reset), or an
      empty, column-less DataFrame when there was nothing to process.
    """
    import multiprocessing as mp
    from multiprocessing import Pool

    # Nothing to do: avoid spawning a pool just to map over an empty list.
    if len(args_list) == 0:
        print("No data processed.")
        return pd.DataFrame()

    requested = int(num_threads) if num_threads else mp.cpu_count()
    # Never start more workers than files, and never fewer than one
    # (Pool raises ValueError for processes < 1).
    num_threads = max(1, min(requested, mp.cpu_count(), len(args_list)))
    print(f"Using {num_threads} threads on {len(args_list)} files")

    pool_args = [
        (args, selected_type, combined_bed_df_int3, bed_window)
        for args in args_list
    ]

    with Pool(processes=num_threads) as pool:
        results = pool.map(process_single_file, pool_args)

    dfs = [df for df in results if df is not None and not df.empty]
    if not dfs:
        print("No data processed.")
        return pd.DataFrame()

    combined_df = pd.concat(dfs, ignore_index=True)
    print(f"Combined {len(dfs)} files into {len(combined_df)} total rows.")
    return combined_df

# --- Example usage ---
# Ingest every file in args_list for one region type and summarise the result.
selected_type = 'MOTIFS_strong_rex'
num_threads   = 15

combined_df = process_all_files(
    args_list,
    selected_type,
    combined_bed_df_ext_mex,
    bed_window,
    num_threads
)

print(combined_df.head())
print(f"Total rows: {len(combined_df)}")
# process_all_files() returns a column-less DataFrame when nothing was
# processed; indexing 'file_id' / 'exp_id' on it would raise KeyError.
if {'file_id', 'exp_id'}.issubset(combined_df.columns):
    print(f"Unique files: {combined_df['file_id'].nunique()}")
    print(f"Unique exps : {combined_df['exp_id'].nunique()}")
else:
    print("Unique files: 0")
    print("Unique exps : 0")

# Optionally save:
# combined_df.to_csv("combined_results_with_end.csv", index=False)


In [None]:
import pandas as pd
import pyranges as pr

# ---------------------------
# MEX MOTIF LOOKUP FOR combined_df WITH STRAND INFO
# ---------------------------

# Load FIMO files
# Purpose of this section: read the three FIMO motif-hit tables, keep hits with
# score > 5, and collapse overlapping hits to a single representative per
# overlap cluster (result: fimo_dedup_df).
fimo_files = [
    "/Data1/ext_data/motifs/fimo_MEX_0.01.tsv",
    "/Data1/ext_data/motifs/fimo_MEXII_0.01.tsv",
    "/Data1/ext_data/motifs/fimo_motifc_0.01.tsv"
]

# Read each table as tab-separated text and stack them into one frame.
dfs = []
for file in fimo_files:
    print(f"Loading {file}")
    df = pd.read_csv(file, sep='\t')
    print(f"Loaded {len(df)} rows from {file}")
    dfs.append(df)

fimo_df = pd.concat(dfs, ignore_index=True)
print(f"Combined FIMO dataframe has {len(fimo_df)} rows")

# Convert sequence_name from 'chr' to 'CHROMOSOME_*' format
# NOTE(review): this replaces every occurrence of the substring 'chr', not
# only a leading prefix — presumably names look like 'chrI'; confirm.
print("Converting 'sequence_name' to 'CHROMOSOME_*' format")
fimo_df['sequence_name'] = fimo_df['sequence_name'].str.replace('chr', 'CHROMOSOME_')

# Filter rows with score > 5
# (rows whose score is NaN fail the comparison and are dropped as well)
print("Filtering rows with score > 5")
fimo_df = fimo_df[fimo_df['score'] > 5]
print(f"Filtered FIMO dataframe has {len(fimo_df)} rows")

# Select and rename relevant columns
print("Selecting relevant columns")
fimo_df = fimo_df[['sequence_name', 'start', 'stop', 'strand', 'score', 'p-value', 'motif_id']]
fimo_df = fimo_df.rename(columns={'sequence_name': 'chr', 'p-value': 'p_value'})

# Assign motif priorities
# Lower number = preferred motif when several hits overlap. motif_id values
# not listed in the dict map to NaN.
print("Assigning motif priorities")
motif_priority = {'MEXII': 1, 'MEX': 2, 'motifC': 3}
fimo_df['motif_priority'] = fimo_df['motif_id'].map(motif_priority)

# Sort by chromosome, start, and motif_priority
print("Sorting FIMO dataframe")
fimo_df = fimo_df.sort_values(by=['chr', 'start', 'motif_priority'])

# Deduplicate overlapping intervals using PyRanges
# PyRanges expects the interval columns to be named Chromosome/Start/End.
# NOTE(review): FIMO start/stop are presumably 1-based inclusive while
# PyRanges presumably treats Start/End as 0-based half-open — TODO confirm
# whether an offset is needed.
print("Deduplicating overlapping intervals using PyRanges")
fimo_df = fimo_df.rename(columns={'chr': 'Chromosome', 'start': 'Start', 'stop': 'End'})
fimo_df['Start'] = fimo_df['Start'].astype(int)
fimo_df['End'] = fimo_df['End'].astype(int)
pr_df = pr.PyRanges(fimo_df)
# cluster() output carries a 'Cluster' column, used below as the group key.
clusters = pr_df.cluster()
clusters_df = clusters.df
# Within each cluster put the best (lowest) priority first, then keep only it.
clusters_df = clusters_df.sort_values(['Cluster', 'motif_priority'])
fimo_dedup_df = clusters_df.drop_duplicates(subset=['Cluster'], keep='first')
# Rename back for clarity
fimo_dedup_df = fimo_dedup_df.rename(columns={'Chromosome': 'chr', 'Start': 'start', 'End': 'stop'})
print(f"Deduplicated FIMO dataframe has {len(fimo_dedup_df)} rows")

# Prepare DataFrames for the join
# This section annotates every row of combined_df with the motifs that overlap
# its region: relative start, motif id and strand, each stored as a tuple.
region_key_cols = ['region_chrom', 'region_start', 'region_end']
motif_out_cols = ['motif_rel_start', 'motif_id', 'motif_strand']

# Make the cell safe to re-run: motif columns left over from a previous run
# would collide with the FIMO 'motif_id' column in the join and with the
# aggregated columns in the merge (pandas would suffix them _x/_y).
combined_df = combined_df.drop(columns=motif_out_cols, errors='ignore')

# Rename FIMO columns for PyRanges compatibility
fimo_df_renamed = fimo_dedup_df.rename(columns={'chr': 'Chromosome', 'start': 'Start', 'stop': 'End'})

# For combined_df, use the region columns as the interval
combined_df_renamed = combined_df.rename(columns={
    'region_chrom': 'Chromosome',
    'region_start': 'Start',
    'region_end': 'End'
})

print("Converting DataFrames to PyRanges objects")
fimo_pr = pr.PyRanges(fimo_df_renamed)
combined_df_pr = pr.PyRanges(combined_df_renamed)

print("Performing join to find overlapping intervals")
overlap_result = combined_df_pr.join(fimo_pr, suffix="_fimo")

# Aggregate overlapping motif intervals along with their strands.
print("Aggregating overlapping intervals with strand info")
overlap_df = overlap_result.df
# observed=True only matters when a grouping key is categorical; it then
# restricts the result to key combinations that actually occur.
agg_df = overlap_df.groupby(['Chromosome', 'Start', 'End'], observed=True).apply(
    lambda x: list(zip(x['Start_fimo'], x['motif_id'], x['strand']))
).reset_index(name='motif_info')

# The join repeats each motif once per read of the region; collapse the
# repeats and sort so the motif order is reproducible between runs
# (plain set iteration order is not).
print("Removing duplicate motif info")
agg_df['motif_info'] = agg_df['motif_info'].apply(lambda x: sorted(set(x)))

# Separate aggregated motif info into motif_start, motif_id, and motif_strand lists
print("Separating aggregated motif info into motif_start, motif_id, and motif_strand")
agg_df['motif_start'] = agg_df['motif_info'].apply(lambda x: [item[0] for item in x])
agg_df['motif_id'] = agg_df['motif_info'].apply(lambda x: [item[1] for item in x])
agg_df['motif_strand'] = agg_df['motif_info'].apply(lambda x: [item[2] for item in x])

# Compute motif_rel_start using the formula: motif_start - region_start - bed_window
print("Computing motif_rel_start")
agg_df['motif_rel_start'] = agg_df.apply(
    lambda row: [start - row['Start'] - bed_window for start in row['motif_start']], axis=1
)

# Rename the interval columns back to match combined_df's keys
agg_df = agg_df.rename(columns={'Chromosome': 'region_chrom', 'Start': 'region_start', 'End': 'region_end'})

print("Merging aggregated motif info back into combined_df")
combined_df = combined_df.merge(
    agg_df[region_key_cols + motif_out_cols],
    how='left',
    on=region_key_cols
)


def _as_motif_tuple(value, cast=None):
    """Normalise one merged cell: list/tuple -> tuple (items optionally cast), NaN -> ()."""
    if isinstance(value, (list, tuple)):
        return tuple(value) if cast is None else tuple(cast(v) for v in value)
    return tuple() if pd.isna(value) else value


# Regions without any motif come out of the left merge as NaN; store an empty
# tuple instead so every cell has the same type.
print("Converting lists to tuples")
print("Replacing NaN with empty tuples")
combined_df['motif_id'] = combined_df['motif_id'].apply(_as_motif_tuple)
combined_df['motif_strand'] = combined_df['motif_strand'].apply(_as_motif_tuple)

# Ensure all elements in motif_rel_start are integers
combined_df['motif_rel_start'] = combined_df['motif_rel_start'].apply(
    lambda x: _as_motif_tuple(x, cast=int)
)

# Show every column and do not wrap wide frames when printing.
pd.set_option('display.max_columns', None)
pd.set_option('display.expand_frame_repr', False)
pd.set_option('display.width', None)

# ───────────────────────────────────────────────────
# DEBUG: optional read-length filter
# (drop in immediately after combined_df is defined)
#
# Previously the banner announced a 1000 bp cut-off while the (commented-out)
# filter line used 500 and nothing was filtered at all. The threshold is now a
# single setting that is both applied and reported; 0 keeps every read, which
# matches the behaviour of the disabled filter.
# ───────────────────────────────────────────────────
MIN_READ_LENGTH = 0  # bp; set e.g. to 500 or 1000 to enable the filter
print(f"Filtering out reads shorter than {MIN_READ_LENGTH} bp…")

# 1) compute read lengths (rows without a string fiber_sequence count as 0)
combined_df['read_length'] = combined_df['fiber_sequence'] \
    .apply(lambda x: len(x) if isinstance(x, str) else 0)

# 2) snapshot before
before_rows  = combined_df.shape[0]
before_reads = combined_df['read_index'].nunique()
print(f"→ Before filter: {before_rows} rows, {before_reads} unique reads")

# 3) apply the filter
combined_df = combined_df.loc[combined_df['read_length'] >= MIN_READ_LENGTH]

# 4) snapshot after
after_rows  = combined_df.shape[0]
after_reads = combined_df['read_index'].nunique()
print(f"→ After filter : {after_rows} rows, {after_reads} unique reads")

# 5) clean up (non-inplace: the frame above is the result of a row selection)
combined_df = combined_df.drop(columns=['read_length'])
print("Done filtering.\n")
# ───────────────────────────────────────────────────

# ───────────────────────────────────────────────────
# DEBUG: drop all reads that have no adjusted_ref_nuc
# ───────────────────────────────────────────────────
print("Filtering out reads with zero nucleotides (no adjusted_ref_nuc)…")
# Row-wise flag: True when adjusted_ref_nuc is a non-empty list/tuple.
# Kept as a separate Series so no helper column has to be added and removed.
has_nuc = combined_df['adjusted_ref_nuc'].apply(
    lambda x: isinstance(x, (list, tuple)) and len(x) > 0
).astype(bool)

before2_rows  = combined_df.shape[0]
before2_reads = combined_df['read_index'].nunique()
print(f"→ Before nuc-filter  : {before2_rows} rows, {before2_reads} unique reads")

# Keep only rows with at least one entry. The explicit copy makes the result an
# independent frame, so later column assignments on combined_df do not trigger
# pandas' SettingWithCopyWarning.
combined_df = combined_df.loc[has_nuc].copy()

after2_rows  = combined_df.shape[0]
after2_reads = combined_df['read_index'].nunique()
print(f"→ After nuc-filter   : {after2_rows} rows, {after2_reads} unique reads")

print("Done nuc-filter.\n")

# # ───────────────────────────────────────────────────
# # DEBUG: drop duplicate fiber_sequence
# # ───────────────────────────────────────────────────
# print("Filtering out duplicate fiber_sequence rows…")
# before3_rows     = combined_df.shape[0]
# before3_sequences = combined_df['fiber_sequence'].nunique()
# print(f"→ Before dedup    : {before3_rows} rows, {before3_sequences} unique sequences")
#
# combined_df = combined_df.drop_duplicates(subset=['fiber_sequence'], keep='first')
#
# after3_rows      = combined_df.shape[0]
# after3_sequences = combined_df['fiber_sequence'].nunique()
# print(f"→ After dedup     : {after3_rows} rows, {after3_sequences} unique sequences")
# print("Done dedup.\n")


# Preview: print the first two rows of combined_df as it stands at the end of this cell
print(combined_df.head(2))


In [None]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import nanotools  # for get_color / get_colors

# assume combined_df is already loaded

# 1) trim file_id to everything before the first "_"
#    `n` is passed by keyword: pandas >= 2.0 accepts only `pat` positionally
#    in Series.str.split, so the old `.str.split('_', 1)` raises TypeError there.
combined_df['file_base'] = combined_df['file_id'].str.split('_', n=1).str[0]

# 2) compute read length and N50 per (file_base, exp_id)
#    (rows whose fiber_sequence is missing get NaN here)
combined_df['read_length'] = combined_df['fiber_sequence'].str.len()

def compute_n50(lengths):
    """
    Return the N50 of a collection of read lengths.

    Lengths are sorted from longest to shortest; the N50 is the first length at
    which the running total reaches half of the overall total.

    Parameters:
      lengths : iterable of numbers (list, Series, ...). NaN / None entries,
                e.g. from reads without a sequence, are ignored.

    Returns:
      The N50 as an int, or 0 when no valid length is supplied.
    """
    values = np.asarray(list(lengths), dtype=float)
    values = values[~np.isnan(values)]
    if values.size == 0:
        return 0

    values = np.sort(values)[::-1]          # longest first
    cum = values.cumsum()
    half = values.sum() / 2
    # First position whose cumulative sum is >= half; the clamp is purely
    # defensive against floating-point round-off at the last element.
    idx = min(int(np.searchsorted(cum, half)), values.size - 1)
    return int(values[idx])

# One N50 value per (file_base, exp_id) pair.
per_experiment_n50 = combined_df.groupby(['file_base', 'exp_id'])['read_length'].apply(compute_n50)
n50_df = per_experiment_n50.reset_index(name='N50')

# 3) build the plot: per file_base an unfilled box, labelled points and the mean
fig = go.Figure()
for fb, grp in n50_df.groupby('file_base'):
    color = nanotools.get_color(fb)
    n50_values = grp['N50']

    # Outline-only box with every point drawn on top of it.
    box_trace = go.Box(
        y=n50_values,
        name=fb,
        fillcolor='rgba(0,0,0,0)',
        line_color=color,
        boxpoints='all',
        jitter=0.3,
        pointpos=0,
        marker=dict(size=6, color=color),
        showlegend=False
    )
    fig.add_trace(box_trace)

    # One marker per experiment, labelled with its exp_id.
    label_trace = go.Scatter(
        x=[fb] * len(grp),
        y=n50_values,
        mode='markers+text',
        marker=dict(size=6, color=color),
        text=grp['exp_id'],
        textposition='top center',
        textfont=dict(size=10, color='black'),
        showlegend=False
    )
    fig.add_trace(label_trace)

    # Mean N50 of the group as plain text (no arrow).
    mean_n50 = n50_values.mean()
    fig.add_annotation(
        x=fb,
        y=mean_n50,
        text=f"{mean_n50:.0f}",
        showarrow=False
    )

# Fixed y range: 0 - 15000
fig.update_yaxes(range=[0, 15000])

layout_options = dict(
    template='plotly_white',
    xaxis_title='File',
    yaxis_title='N50 Read Length',
    width=550,
    height=400
)
fig.update_layout(**layout_options)

fig.show()


In [None]:
# def build_percentage_df(combined_df, bed_window, smooth_window=25, group_by=None):
#     """
#     Processes the combined DataFrame to build a new DataFrame with counts and percentages
#     for m6a and 5mC marks at each relative genomic position (from -bed_window to bed_window).
# 
#     Can group results by file_id, exp_id, or both if specified.
# 
#     Denominator (for both m6a and 5mC) is the total number of reads overlapping that position.
#     Numerators come from the adjusted marks (with a +1 offset).
# 
#     A centered moving average smoothing is applied to the percentages with the given window size.
# 
#     Parameters:
#       combined_df  : Combined DataFrame containing the overlapping read-region pairs from multiple files.
#       bed_window   : integer window (e.g. 1000) defining the range [-bed_window, bed_window].
#       smooth_window: integer, size of the smoothing window in bp (e.g. 25).
#       group_by     : None, 'file_id', 'exp_id', or ['file_id', 'exp_id'] to determine grouping.
# 
#     Returns:
#       A DataFrame with columns:
#         - file_id: identifier of the source file (if grouped by file_id)
#         - exp_id: identifier of the experiment (if grouped by exp_id)
#         - position: relative genomic positions.
#         - denom: number of reads covering that position.
#         - m6a_num: count of m6a marks at that position.
#         - m6a_perc: raw percentage (100 * m6a_num / denom).
#         - m6a_perc_smooth: smoothed percentage.
#         - 5mC_num: count of 5mC marks at that position.
#         - 5mC_perc: raw percentage (100 * 5mC_num / denom).
#         - 5mC_perc_smooth: smoothed percentage.
#     """
#     import numpy as np
#     import pandas as pd
#     from functools import partial
# 
#     # Define a simple centered moving average function
#     def moving_average(x, window):
#         return np.convolve(x, np.ones(window) / window, mode='same')
# 
#     # Process a single group of data and return a DataFrame
#     def process_group(group_df, bed_window, smooth_window):
#         pos_range = np.arange(-bed_window, bed_window + 1)
#         denominator = np.zeros(len(pos_range), dtype=int)
#         m6a_num = np.zeros(len(pos_range), dtype=int)
#         m5c_num = np.zeros(len(pos_range), dtype=int)
# 
#         # For each overlapping read-region pair, add 1 to denominator for each base the read covers.
#         for _, row in group_df.iterrows():
#             read_st_adj = row['read_st_adj']
#             fiber_seq = row['fiber_sequence']
#             if pd.notnull(fiber_seq) and isinstance(fiber_seq, str):
#                 L = len(fiber_seq)
#                 # Read covers positions from read_st_adj + 1 to read_st_adj + L (inclusive)
#                 for i in range(L):
#                     pos = read_st_adj + 1 + i  # +1 offset applied here
#                     if -bed_window <= pos <= bed_window:
#                         bin_idx = pos + bed_window
#                         denominator[bin_idx] += 1
# 
#             # Process m6a numerator: For each m6a mark, add 1 (with a +1 offset)
#             adjusted_ref_m6a = row['adjusted_ref_m6a']
#             if isinstance(adjusted_ref_m6a, list):
#                 for pos in adjusted_ref_m6a:
#                     pos_new = pos + 1  # apply +1 offset
#                     if -bed_window <= pos_new <= bed_window:
#                         bin_idx = pos_new + bed_window
#                         m6a_num[bin_idx] += 1
# 
#             # Process 5mC numerator: For each 5mC mark, add 1 (with a +1 offset)
#             adjusted_ref_5mC = row['adjusted_ref_5mC']
#             if isinstance(adjusted_ref_5mC, list):
#                 for pos in adjusted_ref_5mC:
#                     pos_new = pos + 1  # apply +1 offset
#                     if -bed_window <= pos_new <= bed_window:
#                         bin_idx = pos_new + bed_window
#                         m5c_num[bin_idx] += 1
# 
#         # Compute raw percentages
#         m6a_perc = np.where(denominator > 0, 100 * m6a_num / denominator, np.nan)
#         m5c_perc = np.where(denominator > 0, 100 * m5c_num / denominator, np.nan)
# 
#         # Apply smoothing if smooth_window is greater than 1
#         if smooth_window and smooth_window > 1:
#             m6a_perc_smooth = moving_average(m6a_perc, smooth_window)
#             m5c_perc_smooth = moving_average(m5c_perc, smooth_window)
#         else:
#             m6a_perc_smooth = m6a_perc
#             m5c_perc_smooth = m5c_perc
# 
#         result_df = pd.DataFrame({
#             'position': pos_range,
#             'denom': denominator,
#             'm6a_num': m6a_num,
#             'm6a_perc': m6a_perc,
#             'm6a_perc_smooth': m6a_perc_smooth,
#             '5mC_num': m5c_num,
#             '5mC_perc': m5c_perc,
#             '5mC_perc_smooth': m5c_perc_smooth
#         })
# 
#         return result_df
# 
#     # If no grouping is specified, process the entire dataframe at once
#     if group_by is None:
#         return process_group(combined_df, bed_window, smooth_window)
# 
#     # Validate the group_by parameter
#     valid_group_by = ['file_id', 'exp_id', ['file_id', 'exp_id']]
#     if group_by not in valid_group_by:
#         raise ValueError(f"group_by must be one of {valid_group_by}")
# 
#     # Group by the specified column(s) and process each group
#     grouped = combined_df.groupby(group_by)
#     result_dfs = []
# 
#     for group_name, group_df in grouped:
#         # Process this group
#         group_result = process_group(group_df, bed_window, smooth_window)
# 
#         # Add group identifiers to the result
#         if isinstance(group_by, list):
#             # If grouped by multiple columns, add each as a separate column
#             for i, col_name in enumerate(group_by):
#                 group_result[col_name] = group_name[i]
#         else:
#             # If grouped by a single column, add it as a column
#             group_result[group_by] = group_name
# 
#         result_dfs.append(group_result)
# 
#     # Combine all group results
#     final_df = pd.concat(result_dfs, ignore_index=True)
#     return final_df
# 
# # --- Example usage in a notebook cell ---
# # Process all data combined (no grouping)
# df_perc_all = build_percentage_df(combined_df, bed_window=2000, smooth_window=25)
# 
# # Group by experiment ID
# df_perc_by_exp = build_percentage_df(combined_df, bed_window=2000, smooth_window=25, group_by='exp_id')
# 
# # Group by file ID
# df_perc_by_file = build_percentage_df(combined_df, bed_window=2000, smooth_window=25, group_by='file_id')
# 
# # Group by both file ID and experiment ID
# df_perc_by_both = build_percentage_df(combined_df, bed_window=2000, smooth_window=25, group_by=['file_id', 'exp_id'])
# 
# # Display the first few rows of each result
# print("All data combined:")
# display(df_perc_all.head())

In [None]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage, fcluster
import re

# Debug switches for this cell.
# NOTE(review): neither flag is referenced by the helper functions defined
# directly below — confirm where they are consumed.
DEBUG_HIST  = True   # presumably enables histogram-level diagnostics — verify against users of the flag
DEBUG_READS = True    # read-level diagnostics

def filter_by_file_and_exp_id(df, file_id_input):
    """
    Return a copy of the rows of ``df`` that belong to one file, optionally
    narrowed to a single experiment.

    ``file_id_input`` may be
      - a plain value, e.g. "example_file_id": all rows with that file_id;
      - a list/tuple ``[file_id]``: same as above;
      - a list/tuple ``[file_id, k]``: only rows whose exp_id equals the k-th
        entry of the file's sorted unique exp_id values (``k`` of None means
        "all experiments").

    Raises:
      ValueError when ``k`` is not smaller than the number of unique exp_id
      values for that file.
    """
    # Plain value: match the file_id column directly.
    if not isinstance(file_id_input, (list, tuple)):
        return df[df['file_id'] == file_id_input].copy()

    actual_file_id = file_id_input[0]
    file_rows = df[df['file_id'] == actual_file_id]

    wants_single_exp = len(file_id_input) > 1 and file_id_input[1] is not None
    if not wants_single_exp:
        return file_rows.copy()

    exp_index = file_id_input[1]
    unique_vals_sorted = np.sort(file_rows['exp_id'].unique())
    if exp_index < len(unique_vals_sorted):
        selected_exp = unique_vals_sorted[exp_index]
        return file_rows[file_rows['exp_id'] == selected_exp].copy()

    raise ValueError(
        f"Requested exp_id index {exp_index} is out of range for file_id={actual_file_id}. "
        f"Available exp_id values are: {unique_vals_sorted}"
    )

# def get_allowed_reads(df, max_reads):
#     """
#     If max_reads == 0, include every read_index.
#     Otherwise, pick a random subset of up to max_reads read_index values
#     (uniformly, no length bias) using pandas’ .sample().
#     """
#     # only reads with a valid fiber_sequence
#     valid_reads = df.loc[
#         df['fiber_sequence'].apply(lambda x: isinstance(x, str)),
#         'read_index'
#     ].unique()
# 
#     if max_reads == 0 or len(valid_reads) <= max_reads:
#         return set(valid_reads)
# 
#     # random sample via pandas
#     return set(
#         pd.Series(valid_reads)
#           .sample(n=max_reads, replace=False)
#           .tolist()
#     )

# ───────────────────────────────────────────────
# 0)  put these two lines near the top once
RNG = np.random.default_rng(42)        # seeded generator so read sampling is repeatable
OVERLAP_ONLY = True                    # drop fibres that never touch the analysis window
# ───────────────────────────────────────────────

def get_allowed_reads(df, max_reads, *, bed_window, rng):
    """
    Pick the read_index values that downstream plots are allowed to use.

    Only rows whose fiber_sequence is a string are considered. When the
    module-level OVERLAP_ONLY switch is on, a read must additionally cover at
    least one base of [-bed_window, +bed_window].

    Parameters:
      df         : DataFrame with read_index, read_st_adj and fiber_sequence columns.
      max_reads  : upper bound on the number of reads returned; 0 means no limit.
      bed_window : half-width of the analysis window (keyword-only).
      rng        : numpy random Generator used for the draw (keyword-only).

    Returns:
      A set of read_index values: every candidate when there are at most
      ``max_reads`` of them (or ``max_reads`` is 0), otherwise ``max_reads``
      values drawn without replacement.

    Notes
    -----
    Call this ONCE per file_id and hand the same set to both
    compute_histogram_metrics() and compute_scatter_data(); every call
    consumes state from ``rng``.
    """
    has_sequence = df['fiber_sequence'].apply(lambda x: isinstance(x, str))
    candidates = df.loc[has_sequence]

    if OVERLAP_ONLY:
        # Closed interval occupied by the read, in window-relative coordinates.
        first_base = candidates['read_st_adj'] + 1
        last_base = first_base + candidates['fiber_sequence'].str.len() - 1
        touches_window = (last_base >= -bed_window) & (first_base <= bed_window)
        candidates = candidates.loc[touches_window]

    candidate_ids = candidates['read_index'].unique()

    take_everything = max_reads == 0 or len(candidate_ids) <= max_reads
    if take_everything:
        return set(candidate_ids)

    drawn = rng.choice(candidate_ids, size=max_reads, replace=False)
    return set(drawn)




import numpy as np
import pandas as pd

def compute_histogram_metrics(df, bed_window, allowed_reads, metric_list,
                              smooth_window=20, agg_bin_width=10):
    """
    Build per-position percentage profiles over [-bed_window, +bed_window].

    Denominator: number of allowed reads covering each position (a read counts
    at most once per position).
    Numerators:
      - 'nuc' : reads with at least one nucleosome interval on the position;
      - 'msp' : +1 for every MSP interval on the position;
      - 'm6a' / '5mC' : +1 for every mark at the position. These two are then
        pooled into consecutive groups of ``agg_bin_width`` positions
        (numerator and denominator summed per group); positions left over at
        the right edge that do not fill a whole group are dropped.

    Parameters:
      df            : DataFrame with read_index, read_st_adj, fiber_sequence and
                      the adjusted_ref_* columns for the requested metrics.
      bed_window    : half-width of the window.
      allowed_reads : collection of read_index values to include.
      metric_list   : metrics to compute ('nuc', 'msp', 'm6a', '5mC').
      smooth_window : centered rolling-mean window applied to the percentages
                      (in output bins); values <= 1 disable smoothing.
      agg_bin_width : pooling width for 'm6a' / '5mC'.

    Returns:
      (bins_dict, metric_percentages): two dicts keyed by metric holding the
      x positions and the percentages (NaN where the denominator is 0). For
      every metric both arrays have the same length.
    """
    # retain blank_reads tracking if needed
    compute_histogram_metrics.blank_reads = {}

    # setup bins: one bin per base, index = position + offset
    window_bins = np.arange(-bed_window, bed_window + 1)
    n_bins = window_bins.size
    offset = bed_window

    denom_counts = np.zeros(n_bins, dtype=np.int32)
    metric_counts = {m: np.zeros(n_bins, dtype=np.int32) for m in metric_list}

    def _flatten(cells):
        """Concatenate the list-valued cells of a column; non-list cells are skipped."""
        return np.concatenate([
            np.asarray(v, dtype=int) if isinstance(v, list) else np.empty(0, int)
            for v in cells
        ])

    def _percent(numer, denom):
        """100 * numer / denom, NaN where denom is 0 (no divide-by-zero warning)."""
        out = np.full(denom.shape, np.nan, dtype=float)
        np.divide(100.0 * numer, denom, out=out, where=denom > 0)
        return out

    def _smooth(values):
        if smooth_window > 1:
            return pd.Series(values).rolling(
                window=smooth_window, center=True, min_periods=1
            ).mean().to_numpy()
        return values

    # filter to allowed reads
    df_allowed = df[df['read_index'].isin(allowed_reads)]

    # process each read once
    for _, read_df in df_allowed.groupby('read_index'):
        # determine which bins this read covers
        covered = np.zeros(n_bins, dtype=bool)
        for row in read_df.itertuples(index=False):
            if not isinstance(row.fiber_sequence, str):
                continue
            r_start = row.read_st_adj + 1
            r_end = r_start + len(row.fiber_sequence) - 1
            lo = max(r_start, -bed_window)
            hi = min(r_end, bed_window)
            if lo <= hi:
                covered[lo + offset:hi + offset + 1] = True
        # update denominator: each read contributes at most 1 per bin
        denom_counts += covered.astype(np.int32)

        # numerator for 'nuc': unique reads with any nuc overlap
        if 'nuc' in metric_list:
            nuc_mask = np.zeros(n_bins, dtype=bool)
            for row in read_df.itertuples(index=False):
                starts = getattr(row, 'adjusted_ref_nuc', None)
                lengths = getattr(row, 'adjusted_ref_nuc_lengths', None)
                if not isinstance(starts, list):
                    continue
                for i, s in enumerate(starts):
                    # a missing length is treated as a 1 bp interval
                    ln = lengths[i] if (isinstance(lengths, list) and i < len(lengths)) else 1
                    e = s + ln - 1
                    lo = max(s, -bed_window)
                    hi = min(e, bed_window)
                    if lo <= hi:
                        nuc_mask[lo + offset:hi + offset + 1] = True
            metric_counts['nuc'] += nuc_mask.astype(np.int32)

        for metric in metric_list:
            if metric == 'nuc':
                continue
            if metric == 'msp':
                col = 'adjusted_ref_msp'
                length_col = f'{col}_lengths'
                starts = _flatten(read_df[col])
                lens = _flatten(read_df[length_col]) if length_col in read_df else np.empty(0, int)
                # without a matching length per start, intervals collapse to 1 bp
                if lens.size and lens.size == starts.size:
                    ends = starts + lens - 1
                else:
                    ends = starts
                for s, e in zip(starts, ends):
                    lo = max(s, -bed_window)
                    hi = min(e, bed_window)
                    if lo <= hi:
                        metric_counts['msp'][lo + offset:hi + offset + 1] += 1
            else:  # m6a / 5mC
                pos = _flatten(read_df[f'adjusted_ref_{metric}'])
                if pos.size:
                    mask = (pos >= -bed_window) & (pos <= bed_window)
                    # np.add.at accumulates correctly when a position repeats
                    np.add.at(metric_counts[metric], pos[mask] + offset, 1)

    # convert to percentages
    bins_dict = {}
    metric_percentages = {}
    for metric, counts in metric_counts.items():
        if metric in ('m6a', '5mC'):
            # Largest prefix that splits into whole groups of agg_bin_width.
            # (The previous `(-n_bins) % agg_bin_width` is the amount needed to
            # pad UP to a multiple, so trimming by it left a length that
            # reshape() rejected, e.g. 4001 bins with width 10.)
            n_groups = n_bins // agg_bin_width
            usable = n_groups * agg_bin_width
            agg_counts = counts[:usable].reshape(-1, agg_bin_width).sum(1)
            agg_denom = denom_counts[:usable].reshape(-1, agg_bin_width).sum(1)
            # x position = centre of each pooled group; one value per group
            agg_bins = window_bins[:usable:agg_bin_width] + (agg_bin_width - 1) / 2.0
            bins_dict[metric] = agg_bins
            metric_percentages[metric] = _smooth(_percent(agg_counts, agg_denom))
        else:
            bins_dict[metric] = window_bins
            metric_percentages[metric] = _smooth(_percent(counts, denom_counts))

    return bins_dict, metric_percentages



def compute_scatter_data(df, bed_window, clustering_window, max_reads, metric_list,
                         nuc_axis=0, msp_axis=0, m6a_axis=0, c5m_axis=0,
                         allowed_reads=None):
    """
    Build the read-level (one row per read) plotting data.

    Parameters:
      df                : DataFrame with read_index, read_st_adj, fiber_sequence
                          and the adjusted_ref_* columns of the requested metrics.
      bed_window        : half-width of the plotted window; features are clipped to it.
      clustering_window : half-width of the window used to split reads into
                          "overlapping" (cluster 1) and "not overlapping" (cluster 0).
      max_reads         : forwarded to get_allowed_reads() when ``allowed_reads``
                          is not supplied.
      metric_list       : metrics to emit ('nuc', 'msp', 'm6a', '5mC').
      nuc_axis, msp_axis, m6a_axis, c5m_axis : accepted for signature
                          compatibility; not used by this function.
      allowed_reads     : optional pre-computed collection of read_index values.
                          Pass the same set that was given to
                          compute_histogram_metrics() so both views show the same
                          reads. When omitted, a fresh sample is drawn from the
                          module-level RNG (previous behaviour), which advances
                          that generator.

    Returns:
      (scatter_data, y_pos, sorted_read_clusters, all_clusters)
        scatter_data : dict per metric. 'nuc'/'msp' -> (x_list, y_list) of line
                       segments separated by None; 'm6a'/'5mC' -> list of
                       (y, position) tuples.
        y_pos        : dict read_index -> row number.
        sorted_read_clusters : list of (read_index, cluster) in row order.
        all_clusters : dict read_index -> cluster (1 = overlaps clustering window).

    Note: an MSP occupancy matrix used to be filled here base by base but was
    never read or returned; that dead work has been removed.
    """
    if allowed_reads is None:
        allowed_reads = get_allowed_reads(df, max_reads, bed_window=bed_window, rng=RNG)
    else:
        allowed_reads = set(allowed_reads)
    df_allowed = df[df['read_index'].isin(allowed_reads)]
    grouped = df_allowed.groupby('read_index')

    def _flatten(cells):
        """Concatenate the list-valued cells of a column; non-list cells are skipped."""
        return np.concatenate([
            np.asarray(v, dtype=int) if isinstance(v, list) else np.empty(0, int)
            for v in cells
        ])

    # -------------------------------------------------------
    # which reads overlap the clustering window?
    # -------------------------------------------------------
    overlapping_reads = set()
    for r, group in grouped:
        for row in group.itertuples(index=False):
            if not isinstance(row.fiber_sequence, str):
                continue
            r_start = row.read_st_adj + 1
            r_end = r_start + len(row.fiber_sequence) - 1
            if (r_start <= clustering_window) and (r_end >= -clustering_window):
                overlapping_reads.add(r)
                break

    # simple clustering: cluster=1 for overlapping reads, 0 otherwise.
    # Rows are ordered by (cluster, read_index), so cluster 0 comes first.
    all_clusters = {r: (1 if r in overlapping_reads else 0) for r in allowed_reads}
    sorted_reads = sorted(all_clusters, key=lambda x: (all_clusters[x], x))
    y_pos = {r: i for i, r in enumerate(sorted_reads)}

    # -------------------------------------------------------
    # Build scatter_data dict
    # -------------------------------------------------------
    scatter_data = {m: ([], []) if m in ('nuc', 'msp') else [] for m in metric_list}
    metric_cols = {'nuc': 'adjusted_ref_nuc',
                   'msp': 'adjusted_ref_msp',
                   'm6a': 'adjusted_ref_m6a',
                   '5mC': 'adjusted_ref_5mC'}

    for r, group in grouped:
        y = y_pos[r]

        # segments for nuc / msp
        for metric in ('nuc', 'msp'):
            if metric not in metric_list:
                continue
            col = metric_cols[metric]
            len_col = f"{col}_lengths"
            starts = _flatten(group[col])
            lens = _flatten(group[len_col]) if (len_col in group) else np.empty(0, int)
            # without a matching length per start, intervals collapse to 1 bp
            ends = starts + lens - 1 if (lens.size and lens.size == starts.size) else starts

            seg_x, seg_y = scatter_data[metric]
            for s, e in zip(starts, ends):
                lo = max(-bed_window, s)
                hi = min(bed_window, e)
                if lo > hi:
                    continue
                # None breaks the line between consecutive segments
                seg_x.extend([lo, hi, None])
                seg_y.extend([y, y, None])

        # points for m6a / 5mC
        for metric in ('m6a', '5mC'):
            if metric not in metric_list:
                continue
            pos = _flatten(group[metric_cols[metric]])
            pos = pos[(pos >= -bed_window) & (pos <= bed_window)]
            scatter_data[metric].extend([(y, p) for p in pos])

    return scatter_data, y_pos, [(r, all_clusters[r]) for r in sorted_reads], all_clusters

def compute_nuc_center_histogram(df, bed_window, allowed_reads, agg_bin_width, smooth_window):
    """
    Histogram nucleosome midpoints inside ``[-bed_window, bed_window]``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``read_index``. ``adjusted_ref_nuc`` (list of starts) and
        ``adjusted_ref_nuc_lengths`` (list of lengths) are read per row when
        present; rows whose starts are not a list contribute nothing.
    bed_window : number
        Half-width of the window; midpoints outside it are discarded.
    allowed_reads : collection
        Read indices to include. Its length is the denominator of the
        percentage.
    agg_bin_width : number
        Width of each histogram bin.
    smooth_window : int
        Centered rolling-mean window (in bins); values <= 1 or falsy disable
        smoothing.

    Returns
    -------
    (bin_centers, percentage)
        Bin midpoints and, per bin, ``100 * count / len(allowed_reads)``
        (all zeros when ``allowed_reads`` is empty).
    """
    lower, upper = -bed_window, bed_window

    midpoints = []
    for _, record in df.iterrows():
        if record['read_index'] not in allowed_reads:
            continue
        nuc_starts = record.get('adjusted_ref_nuc')
        if not isinstance(nuc_starts, list):
            continue
        nuc_lengths = record.get('adjusted_ref_nuc_lengths')
        if isinstance(nuc_lengths, list) and len(nuc_lengths) == len(nuc_starts):
            candidates = [start + length / 2.0
                          for start, length in zip(nuc_starts, nuc_lengths)]
        else:
            # Lengths missing or mismatched: fall back to the start itself.
            candidates = nuc_starts
        midpoints.extend(c for c in candidates if lower <= c <= upper)

    edges = np.arange(lower, upper + agg_bin_width, agg_bin_width)
    tallies, _ = np.histogram(midpoints, bins=edges)

    n_reads = len(allowed_reads)
    if n_reads:
        pct = 100.0 * tallies / n_reads
    else:
        pct = np.zeros_like(tallies)

    mids = (edges[:-1] + edges[1:]) / 2.0
    if smooth_window and smooth_window > 1:
        pct = (
            pd.Series(pct)
            .rolling(window=smooth_window, center=True, min_periods=1)
            .mean()
            .to_numpy()
        )

    return mids, pct


def _add_read_level_traces(fig, scatter, file_id, row, base_colors):
    """Add one trace per metric of per-read data to subplot ``row`` of ``fig``.

    ``scatter`` maps metric -> data: for 'nuc'/'msp' a ``(seg_x, seg_y)`` pair
    of None-separated line segments, for any other metric a list of
    ``(y, x)`` points.
    """
    for metric, data in scatter.items():
        if not data:
            continue
        if metric in ('nuc', 'msp'):
            seg_x, seg_y = data
            trace = go.Scatter(
                x=seg_x,
                y=seg_y,
                mode='lines',
                name=f"{file_id} {metric}",
                line=dict(color=base_colors[metric])
            )
        else:
            # m6a or 5mC => points
            points_y, points_x = zip(*data)
            trace = go.Scatter(
                x=points_x,
                y=points_y,
                mode='markers',
                name=f"{file_id} {metric}",
                marker=dict(color=base_colors[metric], size=1.5, symbol='square')
            )
        fig.add_trace(trace, row=row, col=1)


def plot_metrics_multi_file(
    df,
    file_id1, file_id2,
    bed_window, clustering_window, max_reads, max_hist,
    # 0=skip, 1=primary axis, 2=secondary axis, 3=read-level only
    nuc_axis=0,
    msp_axis=0,
    m6a_axis=0,
    c5m_axis=0,
    # 0=skip, 1=primary axis, 2=secondary axis
    nuc_center_axis=0,
    match_min=False,
    smooth_window=20,
    agg_bin_width=10,
    plot_motifs=False,
    smooth_window_nuc_center=5,
    agg_bin_width_nuc_center=10,
    selected_type=None,
    sub_plot=False,
    subtraction_bin_size=3  # ← NEW parameter: bin width for the difference histogram
):
    """
    Build, show and return a stacked comparison figure for two files.

    Layout (single column, shared x axis):
      row 1      read-level view of ``file_id1``
      row 2      aggregate line plots for both files (primary + secondary y)
      row 3      (only if ``sub_plot``) binned difference file1 - file2 of the
                 primary-axis metric
      last row   read-level view of ``file_id2``

    Axis meaning:
      - nuc_axis, msp_axis: 0=skip, 1=primary axis, 2=secondary axis, 3=read-level only
      - m6a_axis, c5m_axis: 0=skip, 1=primary axis, 2=secondary axis, 3=read-level only
      - nuc_center_axis: 0=skip, 1=primary, 2=secondary

    Other parameters:
      - df: combined table, split per file with ``filter_by_file_and_exp_id``.
      - bed_window: half-width of the plotted x range.
      - clustering_window: half-width of the shaded band; also forwarded to
        ``compute_scatter_data``.
      - max_reads / max_hist: read caps forwarded to ``compute_scatter_data``
        and ``get_allowed_reads`` respectively.
      - match_min: if True, randomly subsample both files to the same number
        of reads (the smaller of the two) for histograms and read-level rows.
      - smooth_window, agg_bin_width: forwarded to ``compute_histogram_metrics``.
      - smooth_window_nuc_center, agg_bin_width_nuc_center: forwarded to
        ``compute_nuc_center_histogram``.
      - plot_motifs / selected_type: motif lines are drawn only if
        ``plot_motifs`` is True and ``selected_type`` is None or matches
        ``MOTIFS_rex<digits>``.
      - sub_plot (bool): if True, insert the difference subplot.
      - subtraction_bin_size: bin width (x units) of the difference bars.

    Relies on module-level ``RNG``, ``DEBUG_HIST`` and ``DEBUG_READS``.

    Returns:
        (fig, (clusters1, clusters2)) where ``clustersN`` is the fourth value
        returned by ``compute_scatter_data`` for each file.
    """
    # Function-scope import so the regex check below does not depend on a
    # file-level ``import re`` being present.
    import re

    # 1) Filter the big df based on file_id & exp_id index if present
    df1 = filter_by_file_and_exp_id(df, file_id1)
    df2 = filter_by_file_and_exp_id(df, file_id2)

    # Get total read counts before filtering
    total_reads_file1 = df1['read_index'].nunique()
    total_reads_file2 = df2['read_index'].nunique()

    # 2) Decide if we do motifs
    do_plot_motifs = plot_motifs
    if selected_type is not None:
        # e.g. MOTIFS_rex5
        if not re.match(r"^MOTIFS_rex\d+$", selected_type):
            do_plot_motifs = False

    # Axis code per metric; insertion order fixes both the trace order and
    # the precedence used to pick the primary metric.
    axis_codes = {
        'nuc': nuc_axis,
        'msp': msp_axis,
        'm6a': m6a_axis,
        '5mC': c5m_axis,
    }

    # ---------- Build histogram metrics (codes 1, 2) ----------
    hist_metric_list = [m for m, code in axis_codes.items() if code in (1, 2)]

    # ---------- Build scatter metrics (codes 1, 2, 3) ----------
    scatter_metric_list = [m for m, code in axis_codes.items() if code in (1, 2, 3)]

    # Colors for each metric
    base_colors = {
        'nuc': 'rgba(128,128,128,0.8)',
        'msp': 'rgba(240,70,40,0.7)',
        'm6a': 'rgba(200,70,30,0.7)',
        '5mC': 'purple'
    }

    # Possibly restrict the # of reads
    allowed_hist1 = get_allowed_reads(df1, max_hist, bed_window=bed_window, rng=RNG)
    allowed_hist2 = get_allowed_reads(df2, max_hist, bed_window=bed_window, rng=RNG)

    if DEBUG_HIST:
        print(f"Allowed reads (file1): {len(allowed_hist1)}  "
              f"(min_read_len={min(df1[df1['read_index'].isin(allowed_hist1)].fiber_sequence.map(len))})")
        print(f"Allowed reads (file2): {len(allowed_hist2)}  "
              f"(min_read_len={min(df2[df2['read_index'].isin(allowed_hist2)].fiber_sequence.map(len))})")

    if match_min:
        n_target = min(len(allowed_hist1), len(allowed_hist2))

        # pick the SAME random subset for histograms and scatter
        np.random.seed(None)           # or pass a seed if you want determinism
        allowed_hist1 = set(np.random.choice(list(allowed_hist1),
                                             size=n_target, replace=False))
        allowed_hist2 = set(np.random.choice(list(allowed_hist2),
                                             size=n_target, replace=False))

        df1_scatter = df1[df1['read_index'].isin(allowed_hist1)]
        df2_scatter = df2[df2['read_index'].isin(allowed_hist2)]
    else:
        df1_scatter = df1
        df2_scatter = df2

    # Defined on both paths: the 3-row layout reads these for its titles, and
    # they used to be missing (NameError) whenever match_min was True.
    n_reads_file1 = len(allowed_hist1)
    n_reads_file2 = len(allowed_hist2)

    if DEBUG_READS:
        def read_level_ratio(df_sub, allowed):
            ratios = []
            for r, g in df_sub[df_sub['read_index'].isin(allowed)].groupby('read_index'):
                covered_bp = 0
                nuc_bp     = 0
                for row in g.itertuples(index=False):
                    if not isinstance(row.fiber_sequence, str):
                        continue
                    r_start = row.read_st_adj + 1
                    r_end   = r_start + len(row.fiber_sequence) - 1
                    lo = max(r_start, -bed_window)
                    hi = min(r_end,   bed_window)
                    covered_bp += max(0, hi-lo+1)
                    if isinstance(row.adjusted_ref_nuc, list):
                        nuc_bp += sum(
                            max(0, min(e, bed_window) - max(s, -bed_window) + 1)
                            for s,e in (
                                (s, s+l-1) if row.adjusted_ref_nuc_lengths and i < len(row.adjusted_ref_nuc_lengths)
                                else (s, s)
                                for i,s in enumerate(row.adjusted_ref_nuc)
                                for l in [row.adjusted_ref_nuc_lengths[i] if row.adjusted_ref_nuc_lengths else 1]
                            )
                        )
                if covered_bp:
                    ratios.append(nuc_bp / covered_bp)
            return np.array(ratios)

        print("\n=== READ-LEVEL %NUC ===")
        for tag, sub, allowed in [
            (file_id1, df1, allowed_hist1),
            (file_id2, df2, allowed_hist2)]:
            rl = read_level_ratio(sub, allowed)*100
            print(f"{tag} :  median={np.median(rl):.1f}%   "
                  f"10th={np.percentile(rl,10):.1f}%   "
                  f"90th={np.percentile(rl,90):.1f}%   "
                  f"n={rl.size}")
        print("========================\n")

    # Compute the line-plot histograms
    bins1, histo1 = compute_histogram_metrics(
        df1, bed_window, allowed_hist1, hist_metric_list,
        smooth_window=smooth_window, agg_bin_width=agg_bin_width
    )
    bins2, histo2 = compute_histogram_metrics(
        df2, bed_window, allowed_hist2, hist_metric_list,
        smooth_window=smooth_window, agg_bin_width=agg_bin_width
    )

    if DEBUG_READS:
        print("\n=== BLANK-READ SUMMARY ===")
        for tag in (file_id1, file_id2):
            blanks = compute_histogram_metrics.blank_reads.get(tag, 0)
            print(f"{tag} : reads_without_nuc = {blanks}")
        compute_histogram_metrics.blank_reads.clear()
        print("==========================\n")

    # Compute the scatter (read-level) data; only the per-metric data and the
    # cluster assignment are used here.
    scatter1, _, _, clusters1 = compute_scatter_data(
        df1_scatter, bed_window, clustering_window, max_reads, scatter_metric_list,
        nuc_axis=nuc_axis, msp_axis=msp_axis, m6a_axis=m6a_axis, c5m_axis=c5m_axis
    )
    scatter2, _, _, clusters2 = compute_scatter_data(
        df2_scatter, bed_window, clustering_window, max_reads, scatter_metric_list,
        nuc_axis=nuc_axis, msp_axis=msp_axis, m6a_axis=m6a_axis, c5m_axis=c5m_axis
    )

    # Identify which metric is primary: first metric (nuc, msp, m6a, 5mC
    # order) whose axis code is 1, else None.
    primary_metric = next(
        (m for m, code in axis_codes.items() if code == 1), None
    )

    # Compute nuc-center lines if requested
    nuc_center_bins1, nuc_center_pct1 = [], []
    nuc_center_bins2, nuc_center_pct2 = [], []
    if nuc_center_axis != 0:
        nuc_center_bins1, nuc_center_pct1 = compute_nuc_center_histogram(
            df1, bed_window, allowed_hist1, agg_bin_width_nuc_center, smooth_window_nuc_center
        )
        nuc_center_bins2, nuc_center_pct2 = compute_nuc_center_histogram(
            df2, bed_window, allowed_hist2, agg_bin_width_nuc_center, smooth_window_nuc_center
        )

    # Decide row structure
    if sub_plot:
        # 4 subplots total
        specs = [
            [{}],                   # Row1: top read-level (file1)
            [{"secondary_y": True}],# Row2: line plots
            [{}],                   # Row3: difference histogram
            [{}]                    # Row4: bottom read-level (file2)
        ]
        row_heights = [0.25, 0.15, 0.35, 0.25]
        subplot_titles = (
            f"{file_id1} (n={total_reads_file1} reads)",
            "",
            "",
            f"{file_id2} (n={total_reads_file2} reads)"
        )
        bottom_row = 4
        diff_row = 3
    else:
        # 3 subplots
        specs = [
            [{}],
            [{"secondary_y": True}],
            [{}]
        ]
        row_heights = [0.25, 0.5, 0.25]
        subplot_titles = (
            f"{file_id1} (n={n_reads_file1} reads)",
            "",
            f"{file_id2} (n={n_reads_file2} reads)"
        )
        bottom_row = 3
        diff_row = None

    fig = make_subplots(
        rows=len(specs),
        cols=1,
        shared_xaxes=True,
        specs=specs,
        row_heights=row_heights,
        vertical_spacing=0.05,
        subplot_titles=subplot_titles
    )

    # ---- TOP SUBPLOT (file_id1), row=1 ----
    _add_read_level_traces(fig, scatter1, file_id1, 1, base_colors)

    # line styles for histograms
    histogram_styles = {
        str(file_id1): dict(dash='solid', width=3),
        str(file_id2): dict(dash='dot', width=3)
    }

    # ---- MIDDLE SUBPLOT (line plots), row=2 ----
    for metric in hist_metric_list:
        # Axis code 2 puts the line on the secondary y-axis
        secondary = (axis_codes[metric] == 2)

        for file_id, bins, histo in (
            (file_id1, bins1, histo1),
            (file_id2, bins2, histo2),
        ):
            if metric in bins and metric in histo:
                fig.add_trace(
                    go.Scatter(
                        x=bins[metric],
                        y=histo[metric],
                        mode='lines',
                        line=dict(color=base_colors[metric], **histogram_styles[str(file_id)]),
                        name=f"{file_id} {metric}"
                    ),
                    row=2, col=1, secondary_y=secondary
                )

    # Nucleosome center lines (file1 solid, file2 dotted)
    if nuc_center_axis != 0:
        is_secondary = (nuc_center_axis == 2)
        for file_id, nc_bins, nc_pct, dash in (
            (file_id1, nuc_center_bins1, nuc_center_pct1, 'solid'),
            (file_id2, nuc_center_bins2, nuc_center_pct2, 'dot'),
        ):
            if len(nc_bins) > 0:
                fig.add_trace(
                    go.Scatter(
                        x=nc_bins,
                        y=nc_pct,
                        mode='lines',
                        name=f"{file_id} nuc center",
                        line=dict(color=base_colors['nuc'], dash=dash, width=3)
                    ),
                    row=2, col=1, secondary_y=is_secondary
                )

    # ---- OPTIONAL DIFFERENCE SUBPLOT (row=3 if sub_plot==True) ----
    if sub_plot and primary_metric is not None and primary_metric in histo1 and primary_metric in histo2:
        diff_x = bins1[primary_metric]
        diff_y = histo1[primary_metric] - histo2[primary_metric]

        if DEBUG_HIST and primary_metric:
            print("\n=== DEBUG_DIFF ================================")
            print(f"Primary metric      : {primary_metric}")
            print(f"Allowed reads file1 : {len(allowed_hist1)}")
            print(f"Allowed reads file2 : {len(allowed_hist2)}")
            print("Diff_y (first 15 bins):", diff_y[:15].round(2))
            print("min / max / mean      :", diff_y.min().round(2),
                  diff_y.max().round(2), diff_y.mean().round(2))
            print("===============================================\n")

        # Define a larger bin width for the subtraction histogram
        bins = np.arange(diff_x.min(), diff_x.max() + subtraction_bin_size, subtraction_bin_size)

        # Compute weighted sum and counts per bin
        weighted_sum, bin_edges = np.histogram(diff_x, bins=bins, weights=diff_y)
        counts, _ = np.histogram(diff_x, bins=bins)

        # Compute the average; for bins with 0 counts, set the average to 0
        averages = np.divide(weighted_sum, counts, out=np.zeros_like(weighted_sum), where=counts != 0)

        # Calculate bin centers for plotting
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

        # Create a color array based on the sign of the average values
        color_array = [
            'rgba(40,190,0,0.7)' if v >= 0 else 'rgba(255,177,0,0.7)'
            for v in averages
        ]

        fig.add_trace(
            go.Bar(
                x=bin_centers,
                y=averages,
                width=subtraction_bin_size,    # ← also set bar width in the trace
                marker_color=color_array,
                name="Displacement"
            ),
            row=diff_row, col=1
        )

        label_map = {
            'nuc': "WT nuc disp. (% reads)",
            'msp': "WT msp disp. (% reads)",
            'm6a': "WT m6a disp. (% reads)",
            '5mC': "WT 5mC disp. (% reads)"
        }
        y_label = label_map.get(primary_metric, "WT nuc disp. (% reads)")
        fig.update_yaxes(title_text=y_label, row=diff_row, col=1)

    # ---- BOTTOM SUBPLOT (file_id2) => row=3 or row=4 ----
    _add_read_level_traces(fig, scatter2, file_id2, bottom_row, base_colors)

    # Gray rectangle behind overlap region in top and bottom read-level subplots
    # NOTE(review): with row/col given, plotly may re-target yref to the
    # subplot's own y axis, which would make y0/y1 data coordinates rather
    # than the full subplot height — confirm the band renders as intended.
    for shade_row in (1, bottom_row):
        fig.add_shape(
            type="rect",
            x0=-clustering_window,
            x1=clustering_window,
            y0=0,
            y1=1,
            yref="paper",
            fillcolor="rgba(200,200,200,0.2)",
            line=dict(width=0),
            layer="below",
            row=shade_row,
            col=1
        )

    # Motif lines if requested (drawn across row=2 domain)
    if do_plot_motifs and (('motif_rel_start' in df1.columns) or ('motif_rel_start' in df2.columns)):
        motif_positions = []
        for d in [df1, df2]:
            if all(x in d.columns for x in ['motif_rel_start','motif_id','motif_strand']):
                for _, row in d.iterrows():
                    mr = row.get('motif_rel_start')
                    mid = row.get('motif_id')
                    mstr = row.get('motif_strand')
                    if (isinstance(mr, (list, tuple))
                        and isinstance(mid, (list, tuple))
                        and isinstance(mstr, (list, tuple))):
                        for pos, m, s in zip(mr, mid, mstr):
                            if -bed_window <= pos <= bed_window:
                                motif_positions.append((pos, f"{m}{s}"))
        # De-duplicate (position, label) pairs seen on several rows/files
        motif_positions = list(set(motif_positions))
        if hasattr(fig.layout, 'yaxis2') and fig.layout.yaxis2.domain:
            y_domain = fig.layout.yaxis2.domain
        else:
            y_domain = [0, 1]

        for pos, label in motif_positions:
            fig.add_shape(
                type="line",
                x0=pos, x1=pos,
                yref="paper",
                y0=y_domain[0],
                y1=y_domain[1],
                line=dict(color="black", dash="dash", width=1)
            )
            fig.add_annotation(
                x=pos,
                y=y_domain[1],
                xref="x",
                yref="paper",
                text=label,
                showarrow=False,
                yanchor="bottom",
                font=dict(color="black", size=10)
            )

    # Cosmetics
    fig.update_xaxes(showgrid=False)
    fig.update_yaxes(showgrid=False)

    # Hide y-tick labels for top and bottom read-level subplots
    fig.update_yaxes(showticklabels=False, row=1, col=1)
    fig.update_yaxes(showticklabels=False, row=bottom_row, col=1)

    # X-axis labels
    fig.update_xaxes(title_text="", row=1, col=1)
    fig.update_xaxes(title_text="genomic position (bp)", row=bottom_row, col=1)

    # Overall figure layout
    fig.update_layout(
        template='plotly_white',
        title="Combined Metrics Plot",
        height=800 if not sub_plot else 700,
        showlegend=False
    )

    # Middle subplot axes labels (row=2 => line plot). The primary-axis title
    # follows whichever metric is on the primary axis; "% nuc" is kept as the
    # fallback when no metric has axis code 1.
    primary_title = f"% {primary_metric}" if primary_metric is not None else "% nuc"
    fig.update_yaxes(title_text=primary_title, row=2, col=1, secondary_y=False, dtick=25)
    fig.update_yaxes(title_text="% (secondary)", row=2, col=1, secondary_y=True)

    fig.show()
    return fig, (clusters1, clusters2)


# Notebook cell: pick the two files to compare, draw the combined figure and
# export it. Requires combined_df, selected_type and output_stem to be defined
# by earlier cells. The commented-out assignments below are alternative
# file_id1 / file_id2 pairings kept for quick switching.
#(["N2_old_fiber_R10","96_DPY27_degron_old","107_SDC2_degron_old_R10","113_114_rex1scr_old_R10","52_old_dpy21jmjc_fiber_R10","51_old_dpy21null_fiber_R10"])
#"107_SDC2_degron_old_R10","96_DPY27_degron "N2_mixed_R9","SDC2_degron_mixed_R9"])

#file_id1 = 'N2_old_fiber_37C_R10' # for DPY21 null
#file_id2 = '51_old_dpy21null_fiber_R10'
#file_id2 = '52_old_dpy21jmjc_fiber_R10'


#file_id1 = 'N2_old_fiber_R10' # For SDC2, SDC3 degron
#file_id2 = '107_SDC2_degron_old_R10'
#file_id2 = "SDC3_degron_old_R10"

#file_id2 = '96_DPY27_degron_old'
#file_id1 = 'N2_biorep1_fiber_old_R10_04_2025' # for DPY27 degron
file_id1 = 'N2_rep3' # for DPY27 degron

#file_id2= 'SDC2deg_rep3'
file_id2= 'DPY27deg_rep3'
#file_id2= 'N2_biorep2_fiber_old_R10_04_2025'

#file_id1 = '113_114_rex1scr_old_R10'
#file_id1 = '96_DPY27_degron_old'



#file_id1 = "N2_mixed_R9"
#file_id2 = 'SDC2_degron_mixed_R9'
#file_id2 = '96_DPY27_degron_old'
#file_id2 = "SDC2_degron_mixed_R9"


# --- Example usage (assuming combined_df, selected_type and output_stem are defined) ---
fig, (clusters1, clusters2) = plot_metrics_multi_file(
    combined_df,
    file_id1=file_id1,
    file_id2=file_id2,
    bed_window=5000,
    clustering_window=150,
    max_reads=100,
    max_hist=200,
    # integer axis codes: 0=skip, 1=primary axis, 2=secondary axis, 3=read-level only
    nuc_axis=1,  # nucleosomes: aggregate line on the primary axis + read-level rows
    msp_axis=3,  # msp: read-level rows only (no aggregate line)
    m6a_axis=3,  # m6A: read-level rows only (no aggregate line)
    c5m_axis=0,  # skip 5mC
    nuc_center_axis=0, # 0 = skip nuc-center lines (1 = primary axis, 2 = secondary axis)
    match_min=False,
    smooth_window=10,
    agg_bin_width=10,
    plot_motifs=False,
    smooth_window_nuc_center=0, # 0 disables smoothing (previously 5)
    agg_bin_width_nuc_center=10,
    selected_type=selected_type,
    sub_plot=True,
    subtraction_bin_size=50  # <— width of bars in subtraction plot
)




# Save the figure as PNG and SVG. The file name is built from output_stem
# (defined earlier), both file ids and selected_type; height / width set the
# exported image size.
height = 600
width = 800
output_name = f"{output_stem}01_{file_id1}_{file_id2}_{selected_type}_plot"
print(output_name)
fig.write_image(f"{output_name}.png", height=height, width=width)
fig.write_image(f"{output_name}.svg", height=height, width=width)



In [None]:

# def compute_histogram_metrics(df, bed_window, allowed_reads, metric_list, smooth_window=20, agg_bin_width=10):
#     """
#     Computes per-base or aggregated histograms for each metric in `metric_list`.
#     Returns:
#         bins_dict: {metric -> array of bin positions}
#         metric_percentages: {metric -> array of % coverage}
#     """
#     per_base_bins = np.arange(-bed_window, bed_window + 1)
#     n_bins = len(per_base_bins)
# 
#     coverage_sets = [set() for _ in range(n_bins)]
#     for _, row in df.iterrows():
#         r_index = row['read_index']
#         if r_index not in allowed_reads:
#             continue
#         fiber_seq = row.get('fiber_sequence')
#         if isinstance(fiber_seq, str):
#             L = len(fiber_seq)
#             for i in range(L):
#                 pos = row['read_st_adj'] + 1 + i
#                 if -bed_window <= pos <= bed_window:
#                     bin_idx = pos + bed_window
#                     coverage_sets[bin_idx].add(r_index)
#     denom_counts = np.array([len(s) for s in coverage_sets])
# 
#     metric_coverage_sets = {metric: [set() for _ in range(n_bins)] for metric in metric_list}
#     for _, row in df.iterrows():
#         r_index = row['read_index']
#         if r_index not in allowed_reads:
#             continue
#         fiber_seq = row.get('fiber_sequence')
#         if not isinstance(fiber_seq, str):
#             continue
#         for metric in metric_list:
#             if metric in ['nuc', 'msp']:
#                 col = f"adjusted_ref_{metric}"
#                 length_col = f"adjusted_ref_{metric}_lengths"
#                 starts = row.get(col)
#                 if isinstance(starts, list):
#                     if (row.get(length_col) is not None and
#                         isinstance(row.get(length_col), list) and
#                         len(row.get(length_col)) == len(starts)):
#                         lengths = row.get(length_col)
#                         for s, l in zip(starts, lengths):
#                             for pos in range(s, s + l):
#                                 if -bed_window <= pos <= bed_window:
#                                     bin_idx = pos + bed_window
#                                     metric_coverage_sets[metric][bin_idx].add(r_index)
#                     else:
#                         for s in starts:
#                             if -bed_window <= s <= bed_window:
#                                 bin_idx = s + bed_window
#                                 metric_coverage_sets[metric][bin_idx].add(r_index)
#             else:
#                 # m6a or 5mC
#                 col = f"adjusted_ref_{metric}"
#                 positions = row.get(col)
#                 if isinstance(positions, list):
#                     for pos in positions:
#                         if -bed_window <= pos <= bed_window:
#                             bin_idx = pos + bed_window
#                             metric_coverage_sets[metric][bin_idx].add(r_index)
# 
#     bins_dict = {}
#     metric_percentages = {}
#     N_reads = len(allowed_reads)
#     for metric, sets in metric_coverage_sets.items():
#         counts = np.array([len(s) for s in sets])
#         if metric in ['m6a', '5mC']:
#             # Use aggregated bins for m6a / 5mC
#             agg_bins = []
#             agg_percentages = []
#             for i in range(0, len(counts), agg_bin_width):
#                 current_slice = counts[i:i+agg_bin_width]
#                 agg_count = np.sum(current_slice)
#                 current_bin_length = len(current_slice)
#                 denom_agg = current_bin_length * N_reads
#                 agg_pct = 100 * agg_count / denom_agg if denom_agg > 0 else np.nan
#                 agg_percentages.append(agg_pct)
#                 agg_bins.append(np.mean(per_base_bins[i:i+current_bin_length]))
#             agg_bins = np.array(agg_bins)
#             agg_percentages = np.array(agg_percentages)
#             if smooth_window and smooth_window > 1:
#                 series = pd.Series(agg_percentages)
#                 agg_percentages = series.rolling(window=smooth_window, center=True, min_periods=1).mean().to_numpy()
#             bins_dict[metric] = agg_bins
#             metric_percentages[metric] = agg_percentages
#         else:
#             # nuc or msp => direct per-base
#             percentages = np.where(denom_counts > 0, 100 * counts / denom_counts, np.nan)
#             if smooth_window and smooth_window > 1:
#                 series = pd.Series(percentages)
#                 percentages = series.rolling(window=smooth_window, center=True, min_periods=1).mean().to_numpy()
#             bins_dict[metric] = per_base_bins
#             metric_percentages[metric] = percentages
#     return bins_dict, metric_percentages


# def compute_scatter_data(df, bed_window, clustering_window, max_reads, metric_list,
#                          nuc_axis=0, msp_axis=0, m6a_axis=0, c5m_axis=0):
#     """
#     Builds per-read scatter data for each metric in metric_list (top and bottom subplots).
#     """
#     allowed_reads = get_allowed_reads(df, max_reads)
# 
#     clustering_bins = np.arange(-clustering_window, clustering_window + 1)
#     n_clustering_bins = len(clustering_bins)
#     pos_to_clustering_idx = {pos: idx for idx, pos in enumerate(clustering_bins)}
# 
#     # Identify which reads overlap the region of interest
#     reads_overlapping = set()
#     for _, row in df.iterrows():
#         r_index = row['read_index']
#         if r_index not in allowed_reads:
#             continue
#         fiber_seq = row.get('fiber_sequence')
#         if not isinstance(fiber_seq, str):
#             continue
#         L = len(fiber_seq)
#         for i in range(L):
#             pos = row['read_st_adj'] + 1 + i
#             if -clustering_window <= pos <= clustering_window:
#                 reads_overlapping.add(r_index)
#                 break
# 
#     read_indices = sorted(list(allowed_reads))
#     overlapping_read_indices = [r for r in read_indices if r in reads_overlapping]
#     non_overlapping_read_indices = [r for r in read_indices if r not in reads_overlapping]
# 
#     # Prepare an array for msp-based clustering
#     msp_matrix = np.full((len(overlapping_read_indices), n_clustering_bins), np.nan)
#     read_idx_to_row = {r: i for i, r in enumerate(overlapping_read_indices)}
# 
#     # First fill with zeros for any overlapping region
#     for _, row in df.iterrows():
#         r_index = row['read_index']
#         if r_index not in allowed_reads:
#             continue
#         fiber_seq = row.get('fiber_sequence')
#         if not isinstance(fiber_seq, str):
#             continue
#         if r_index in read_idx_to_row:
#             r_row = read_idx_to_row[r_index]
#             L = len(fiber_seq)
#             for i in range(L):
#                 pos = row['read_st_adj'] + 1 + i
#                 if -clustering_window <= pos <= clustering_window:
#                     col_idx = pos_to_clustering_idx[pos]
#                     msp_matrix[r_row, col_idx] = 0
# 
#     # Fill positions of msp intervals with 1
#     if 'msp' in metric_list:
#         for _, row in df.iterrows():
#             r_index = row['read_index']
#             if r_index not in allowed_reads:
#                 continue
#             if r_index in read_idx_to_row:
#                 r_row = read_idx_to_row[r_index]
#                 col = "adjusted_ref_msp"
#                 length_col = "adjusted_ref_msp_lengths"
#                 starts = row.get(col)
#                 if isinstance(starts, list):
#                     if (row.get(length_col) is not None and
#                         isinstance(row.get(length_col), list) and
#                         len(row.get(length_col)) == len(starts)):
#                         lengths = row.get(length_col)
#                         for s, l in zip(starts, lengths):
#                             for pos in range(s, s + l):
#                                 if -clustering_window <= pos <= clustering_window:
#                                     col_idx = pos_to_clustering_idx[pos]
#                                     msp_matrix[r_row, col_idx] = 1
#                     else:
#                         for s in starts:
#                             if -clustering_window <= s <= clustering_window:
#                                 col_idx = pos_to_clustering_idx[s]
#                                 msp_matrix[r_row, col_idx] = 1
# 
#     # Cluster reads by msp pattern
#     all_clusters = {}
#     if len(overlapping_read_indices) > 0:
#         if len(overlapping_read_indices) >= 2:
#             msp_for_dist = np.where(np.isnan(msp_matrix), -1, msp_matrix)
#             distances = pdist(msp_for_dist, metric='euclidean')
#             Z = linkage(distances, method='ward')
#             max_d = 0.5 * np.max(Z[:, 2]) if len(Z) > 0 else 0
#             clusters = fcluster(Z, max_d, criterion='distance')
#             clusters = clusters + 1
#             for i, r in enumerate(overlapping_read_indices):
#                 all_clusters[r] = int(clusters[i])
#         else:
#             # Only 1 read => single cluster
#             all_clusters[overlapping_read_indices[0]] = 1
#     for r in non_overlapping_read_indices:
#         all_clusters[r] = 0
# 
#     # Sort reads first by cluster number, then by read index
#     sorted_indices_with_clusters = [(r, all_clusters[r]) for r in read_indices]
#     sorted_indices_with_clusters.sort(key=lambda x: (x[1], x[0]))
#     sorted_indices = [r for r, clust in sorted_indices_with_clusters]
#     new_y_positions = {r: i for i, r in enumerate(sorted_indices)}
# 
#     scatter_data = {}
#     metric_columns = {
#         'nuc': 'adjusted_ref_nuc',
#         'msp': 'adjusted_ref_msp',
#         'm6a': 'adjusted_ref_m6a',
#         '5mC': 'adjusted_ref_5mC'
#     }
#     for metric in metric_list:
#         if metric in ['nuc', 'msp']:
#             col = metric_columns[metric]
#             length_col = f"{col}_lengths"
#             seg_scatter_x = []
#             seg_scatter_y = []
#             for _, row in df.iterrows():
#                 r_index = row['read_index']
#                 if r_index not in new_y_positions:
#                     continue
#                 y_pos = new_y_positions[r_index]
#                 starts = row.get(col)
#                 if isinstance(starts, list):
#                     if (row.get(length_col) is not None and
#                         isinstance(row.get(length_col), list) and
#                         len(row.get(length_col)) == len(starts)):
#                         lengths = row.get(length_col)
#                         for s, l in zip(starts, lengths):
#                             # Clip to bed_window
#                             if s + l < -bed_window or s > bed_window:
#                                 continue
#                             start_pos = max(s, -bed_window)
#                             end_pos = min(s + l, bed_window)
#                             seg_scatter_x.extend([start_pos, end_pos, None])
#                             seg_scatter_y.extend([y_pos, y_pos, None])
#                     else:
#                         for s in starts:
#                             if -bed_window <= s <= bed_window:
#                                 seg_scatter_x.append(s)
#                                 seg_scatter_y.append(y_pos)
#             scatter_data[metric] = (seg_scatter_x, seg_scatter_y)
#         else:
#             # m6a or 5mC => points
#             col = metric_columns[metric]
#             points = []
#             for _, row in df.iterrows():
#                 r_index = row['read_index']
#                 if r_index not in new_y_positions:
#                     continue
#                 y_pos = new_y_positions[r_index]
#                 pos_list = row.get(col)
#                 if isinstance(pos_list, list):
#                     for pos in pos_list:
#                         if -bed_window <= pos <= bed_window:
#                             points.append((y_pos, pos))
#             scatter_data[metric] = points
# 
#     return scatter_data, new_y_positions, sorted_indices_with_clusters, all_clusters

In [None]:
### OLD CODE BELOW

In [None]:
### Single fiber plotting
#modkit_bed_name = modkit_bed_name_ext
# ─────────────────── Configuration flags ────────────────────
FORCE_REPLACE = True   # True → always rerun modkit extract even if the
                        #        target .bed already exists


# For centered on MEX:
# Region BED passed to `modkit extract --include-bed`; `temp_file_path` is not
# defined in this cell and must already exist in the notebook namespace.
modkit_bed_name = temp_file_path

### Extracting per read modifications
# One expected output path per sample. The name is assembled from the same
# pieces, in the same order, as `each_output` inside modkit_extract(), so the
# two expressions must be kept in sync for later cells to find the files.
out_file_names = [output_stem +
    "modkit-extract-" +
    each_condition +
    str(round(each_thresh,2)) +
    str(each_index) +
    str(each_bamfrac) +
    str(bed_window) +
    "-".join([str(x)[0] for x in type_selected[-3:]]) +  # first character of each of the LAST 3 selected types
    "-".join([str(x)[0]+str(x)[-3:] for x in chromosome_selected]) +  # first char + last 3 chars of each chromosome name
    ".bed" for each_condition,each_thresh,each_index, each_bamfrac in zip(conditions,thresh_list,sample_indices,bam_fracs)]

# Region BED loaded without a header row (columns are numbered 0..n-1).
modkit_bed_df = pd.read_csv(modkit_bed_name,sep='\t',header=None)
### Define bed file for modkit

# Function to run a single extract command
def modkit_extract(args):
    """Run ``modkit extract`` for one BAM file, restricted to a regions BED.

    Takes a single tuple so it can be mapped over a ``multiprocessing.Pool``.
    Besides ``args`` it reads the notebook globals ``type_selected``,
    ``chromosome_selected`` and ``FORCE_REPLACE``.

    Args:
        args: 9-tuple ``(each_bam, each_thresh, each_condition, each_index,
            each_bamfrac, modkit_path, output_stem, modkit_bed_name,
            bed_window)``. All but ``each_bam``, ``modkit_path`` and
            ``modkit_bed_name`` are only used to build the output file name.

    Returns:
        None. The result is the file written by modkit at the computed output
        path (plus a log file next to it). A non-zero modkit exit status is
        reported on stdout instead of being silently ignored.
    """
    (each_bam, each_thresh, each_condition, each_index, each_bamfrac,
     modkit_path, output_stem, modkit_bed_name, bed_window) = args

    # Must stay in sync with the `out_file_names` expression built at cell level.
    each_output = (
        output_stem +
        "modkit-extract-" +
        each_condition +
        str(round(each_thresh, 2)) +
        str(each_index) +
        str(each_bamfrac) +
        str(bed_window) +
        "-".join([str(x)[0] for x in type_selected[-3:]]) +  # first char of the LAST 3 selected types
        "-".join([str(x)[0] + str(x)[-3:] for x in chromosome_selected]) +
        ".bed"
    )

    ### NOTE: Name of pileup file is not based on configurations
    ### TODO: Name of output file should be based on configs so that we aren't recomputing pileups with identical conditions.

    # ─── Skip logic ───────────────────────────────────────────
    if os.path.exists(each_output) and not FORCE_REPLACE:
        print(f"Skipping (exists & FORCE_REPLACE is False): {each_output}")
        return
    # ---------------------------------------------------------

    print(f"Starting on: {each_bam}")
    command = [
        modkit_path,
        "extract",
        "--threads",
        "16",
        "--force",
        "--mapped",
        "--ignore",
        "m",
        "--include-bed",
        modkit_bed_name,
        "--log-filepath",
        each_output + each_condition + "_modkit-extract.log",
        each_bam,
        each_output
    ]
    completed = subprocess.run(command, text=True)
    # Surface failures: downstream cells read `each_output` and would otherwise
    # hit a missing or truncated file with no hint of the cause. We report
    # rather than raise so one bad sample does not abort the whole Pool.map.
    if completed.returncode != 0:
        print(
            f"modkit extract FAILED (exit {completed.returncode}) for {each_bam}; "
            f"output may be missing or incomplete: {each_output}"
        )

    # Create a list of arguments for each task
# One argument tuple per BAM. Field order must match the unpacking on the
# first line of modkit_extract(); the last four values are constants repeated
# once per BAM so that each task is self-contained.
task_args = list(zip(
    new_bam_files,
    thresh_list,
    conditions,
    sample_indices,
    bam_fracs,
    [modkit_path]*len(new_bam_files),
    [output_stem]*len(new_bam_files),
    [modkit_bed_name]*len(new_bam_files),
    [bed_window]*len(new_bam_files)
))

# Execute commands in parallel
# NOTE(review): up to 10 concurrent workers, each launching modkit with
# "--threads 16" — confirm the host has enough cores for that.
with Pool(processes=10) as pool:
    pool.map(modkit_extract, task_args)

print("finished with:")
print(out_file_names)

In [None]:
### CELL A One row per read, use this

import os
import shutil
import logging
from math import floor
from multiprocessing import Pool, cpu_count

import numpy as np
import pandas as pd
from tqdm import tqdm

# ──────────────────────────────────────────────────────────────────────────────
# Configuration: where to write per-worker Feather files
# ──────────────────────────────────────────────────────────────────────────────
# Scratch directory (relative to the current working directory). It is wiped
# and recreated by parallel_build_combined_df() on every run.
FEATHER_DIR = "./_feather_tmp"

# Features for which minus-strand rows should be flipped to feature 5′→3′
# Tokens are matched case-insensitively as substrings of the `type` column.
FLIP_FEATURES_CONTAINS = ["TSS", "TES", "MEX", "gene"]

###############################################################################
# Values that vary per‑base → roll into Arrow/Pandas list‑columns
# (aggregated with `list` in _process_one_file and re-ordered together in
# _flip_lists_row, so all of them stay index-aligned within a row)
POSITION_LIST_COLS = [
    "mod_qual",
    "base_qual",
    "forward_read_position",
    "ref_position",
    "rel_pos",
    "mod_qual_bin",
]

# Single‑value (read‑ / region‑level) columns we keep as scalars
# NOTE(review): not referenced by the code visible in this cell — confirm it is
# still used further down before relying on it.
READ_LEVEL_COLS = [
    "read_id", "chrom", "mod_code", "ref_mod_strand", "mod_strand",
    "ref_strand", "canonical_base", "modified_primary_base", "inferred",
    "fw_soft_clipped_start", "fw_soft_clipped_end", "read_length",
    "bed_start", "bed_end", "bed_strand", "chr_type", "type",
    "m6A_thresh", "m5mC_thresh", "condition", "exp_id",
    "rel_read_start", "rel_read_end",
]

# ──────────────────────────────────────────────────────────────────────────────
# Logging setup
# ──────────────────────────────────────────────────────────────────────────────
# INFO level, timestamps as HH:MM:SS; `log` is the logger used by the
# functions in this cell.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger(__name__)

# ──────────────────────────────────────────────────────────────────────────────
# Helper: safe DataFrame append (future‑proof for pandas 3.0)
# ──────────────────────────────────────────────────────────────────────────────
def _append_df(target: pd.DataFrame, other: pd.DataFrame) -> pd.DataFrame:
    """Append *other* to *target* (handles empty target)."""
    return other if target.empty else pd.concat([target, other], ignore_index=True)

import pyranges as pr             #  ← NEW import (install via `pip install pyranges`)
# Prevent NumPy/SciPy from spawning OpenMP/MKL threads
# NOTE(review): numpy and pandas are imported earlier in this cell; thread-pool
# sizes are usually read when the native library loads, so these settings may
# only take effect in child processes or libraries loaded later — confirm.
os.environ["OMP_NUM_THREADS"]         = "1"
os.environ["MKL_NUM_THREADS"]         = "1"
# Prevent pandas/NumExpr from spawning its own pool
os.environ["NUMEXPR_NUM_THREADS"]     = "1"

# Logging setup
# NOTE(review): this repeats the logging configuration a few lines above with
# identical arguments; basicConfig() is documented to do nothing once the root
# logger has handlers, so this second call looks redundant — confirm and drop.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger(__name__)

# Safe DataFrame append
def _append_df(target: pd.DataFrame, other: pd.DataFrame) -> pd.DataFrame:
    return other if target.empty else pd.concat([target, other], ignore_index=True)

def add_bed_columns_no_loops(
    bedmethyl_df: pd.DataFrame, bed_df: pd.DataFrame
) -> pd.DataFrame:
    """Attach BED interval metadata to every per-base call that falls in an interval.

    Each call is treated as the 1-bp range [ref_position, ref_position + 1)
    and joined against the BED intervals with PyRanges.

    Args:
        bedmethyl_df: per-base calls; needs ``chrom``, ``ref_position`` and
            ``mod_strand``. Rows whose ``mod_strand`` is not "+" or "-" are
            discarded.
        bed_df: intervals with ``chrom``, ``bed_start``, ``bed_end``,
            ``bed_strand``, ``chr_type`` and ``type``.

    Returns:
        The original bedmethyl columns followed by ``bed_start``, ``bed_end``,
        ``bed_strand``, ``chr_type`` and ``type``, sorted by
        (chrom, bed_start, ref_position). An empty DataFrame when either input
        is empty (after the strand filter) or nothing overlaps.
    """
    has_strand = bedmethyl_df["mod_strand"].isin(["+", "-"])
    bedmethyl_df = bedmethyl_df.loc[has_strand].copy()
    if bedmethyl_df.empty or bed_df.empty:
        return pd.DataFrame()

    # Per-base calls as 1-bp ranges
    calls = bedmethyl_df.rename(
        columns={"chrom": "Chromosome", "ref_position": "Start"}
    )
    calls = calls.assign(End=lambda df: df["Start"] + 1)

    # BED intervals in PyRanges naming
    intervals = bed_df.rename(
        columns={
            "chrom": "Chromosome",
            "bed_start": "Start",
            "bed_end": "End",
        }
    )

    overlap_df = pr.PyRanges(calls).join(pr.PyRanges(intervals), nb_cpu=1).df
    if overlap_df.empty:
        return pd.DataFrame()

    # Back to the caller's column names; the synthetic 1-bp End is dropped
    back_to_original = {
        "Chromosome": "chrom",
        "Start": "ref_position",
        "End": "tmp_End",
        "Start_b": "bed_start",
        "End_b": "bed_end",
    }
    restored = overlap_df.rename(columns=back_to_original).drop(columns=["tmp_End"])

    wanted = [
        *bedmethyl_df.columns,
        "bed_start", "bed_end", "bed_strand", "chr_type", "type",
    ]
    restored = restored[wanted].copy()
    restored.sort_values(
        ["chrom", "bed_start", "ref_position"],
        inplace=True, ignore_index=True
    )
    return restored



# ──────────────────────────────────────────────────────────────────────────────
# ❷  Worker to process one bedmethyl file & spill to Feather
# ──────────────────────────────────────────────────────────────────────────────
def _needs_flip_row(row) -> bool:
    """Minus-strand and type contains any flip token."""
    if row.get("bed_strand") != "-":
        return False
    t = str(row.get("type", "")).lower()
    return any(tok.lower() in t for tok in FLIP_FEATURES_CONTAINS)

def _flip_lists_row(row):
    """
    Re-orient one aggregated row to feature 5′→3′ when it needs flipping.

    Rows for which ``_needs_flip_row`` is False, or whose ``rel_pos`` is not a
    non-empty list/tuple/ndarray, are returned untouched. Otherwise:
      • rel_pos is negated and sorted ascending (stable sort)
      • every list column in POSITION_LIST_COLS is permuted the same way, so
        genomic coordinates (ref_position, forward_read_position) keep their
        values and only change order
      • rel_read_start / rel_read_end become the first / last flipped rel_pos
    """
    if not _needs_flip_row(row):
        return row

    rel_values = row["rel_pos"]
    list_like = (list, tuple, np.ndarray)
    if not isinstance(rel_values, list_like) or len(rel_values) == 0:
        return row

    negated = -np.asarray(rel_values, dtype=np.int64)
    perm = np.argsort(negated, kind="mergesort")  # stable

    # apply the same permutation to each per-base list column
    for col in POSITION_LIST_COLS:
        values = row.get(col, None)
        if not isinstance(values, list_like):
            continue
        row[col] = np.asarray(values, dtype=object)[perm].tolist()

    # rel_pos gets the negated values (not just the permuted originals)
    flipped_sorted = negated[perm].tolist()
    row["rel_pos"] = flipped_sorted
    row["rel_read_start"] = int(flipped_sorted[0])
    row["rel_read_end"] = int(flipped_sorted[-1])
    return row
###############################################################################
# PER‑WORKER PIPELINE  – completely rewritten aggregation logic
###############################################################################
def _process_one_file(args):
    """
    Worker: turn one ``modkit extract`` table into per-read rows and spill to Feather.

    Steps:
      1. Read the TSV at ``each_output``, keeping only the columns in
         ``cols_to_keep``.
      2. Join every call to the BED intervals (``add_bed_columns_no_loops``).
      3. Add ``rel_pos``, the per-BAM thresholds, ``mod_qual_bin``,
         ``condition`` and ``exp_id``.
      4. Collapse to one row per (read_id, bed_start, type); the per-base
         columns in ``POSITION_LIST_COLS`` become lists.
      5. Re-orient minus-strand feature rows with ``_flip_lists_row``.
      6. Suffix ``read_id`` with ``_1``, ``_2``, … for each row of the same read.
      7. Write ``{temp_dir}/{idx}.feather``.

    Args:
        args: 9-tuple ``(idx, each_output, each_condition, each_exp_id,
            each_bam, combined_bed_df, type_selected, bed_window, temp_dir)``.
            ``type_selected`` is accepted for call compatibility but not used.
            Thresholds are looked up by ``each_bam`` in the globals
            ``m6A_thresh_dict`` / ``m5mC_thresh_dict``.

    Returns:
        ``idx`` on success. ``None`` when the input file is missing, empty or
        unparseable, when no call overlaps a BED interval, or when the Feather
        write fails (each case is logged).
    """
    (
        idx,
        each_output,
        each_condition,
        each_exp_id,
        each_bam,
        combined_bed_df,
        type_selected,
        bed_window,
        temp_dir,
    ) = args

    sample_label = os.path.basename(each_output)
    log.info("▶  Reading %s …", sample_label)

    ###########################################################################
    # 1. Load & trim to just the columns we still need
    ###########################################################################
    cols_to_keep = {
        "read_id",
        "chrom",
        "mod_qual",
        "mod_code",
        "ref_mod_strand",
        "fw_soft_clipped_end",
        "base_qual",
        "mod_strand",
        "ref_strand",
        "canonical_base",
        "modified_primary_base",
        "inferred",
        "fw_soft_clipped_start",
        "forward_read_position",
        "ref_position",
        "read_length",
    }
    try:
        bm = pd.read_csv(
            each_output,
            sep="\t",
            comment="#",
            usecols=lambda c: c in cols_to_keep,   # ← drop ref_/query_kmer, etc.
            dtype={"chrom": str},
        )
    except pd.errors.ParserError as e:
        log.error("ParserError in %s: %s", sample_label, e)
        return None
    except (pd.errors.EmptyDataError, FileNotFoundError) as e:
        # A zero-byte or absent extract (e.g. a failed/skipped modkit run) is
        # one bad sample, not a reason to take down the whole worker pool.
        log.error("Cannot read %s: %s", sample_label, e)
        return None

    if bm.empty:
        return None

    ###########################################################################
    # 2. Attach BED metadata
    ###########################################################################
    bm.sort_values(["chrom", "ref_position"], inplace=True, ignore_index=True)
    bm.drop_duplicates(inplace=True, ignore_index=True)
    bm = add_bed_columns_no_loops(bm, combined_bed_df)
    if bm.empty:
        log.warning("  %s → no overlapping BED intervals", sample_label)
        return None

    ###########################################################################
    # 3. Per‑base calculations
    ###########################################################################
    bm["rel_pos"] = bm["ref_position"] - bm["bed_start"] - bed_window + 1

    m6a_cutoff = m6A_thresh_dict.get(each_bam, np.nan)
    if pd.isna(m6a_cutoff):
        # Comparisons against NaN are always False, so every call would be
        # binned as 0 without any other sign that the lookup missed.
        log.warning(
            "  %s → no m6A threshold for %s; mod_qual_bin will be all 0",
            sample_label, each_bam,
        )
    bm["m6A_thresh"] = m6a_cutoff
    bm["m5mC_thresh"] = m5mC_thresh_dict.get(each_bam, np.nan)
    # NOTE(review): the m6A threshold is applied to every row regardless of
    # mod_code — confirm that is intended if non-m6A calls can reach this point.
    bm["mod_qual_bin"] = (bm["mod_qual"] > bm["m6A_thresh"]).astype(int)

    bm["condition"] = each_condition
    bm["exp_id"]     = each_exp_id

    ###########################################################################
    # 4. Collapse to one row per (read_id, bed_start, type)
    ###########################################################################
    agg = (
        bm.groupby(["read_id", "bed_start", "type"], sort=False)
          .agg(
              chrom                 = ("chrom",                 "first"),
              mod_code              = ("mod_code",              "first"),
              ref_mod_strand        = ("ref_mod_strand",        "first"),
              mod_strand            = ("mod_strand",            "first"),
              ref_strand            = ("ref_strand",            "first"),
              canonical_base        = ("canonical_base",        "first"),
              modified_primary_base = ("modified_primary_base", "first"),
              inferred              = ("inferred",              "first"),
              fw_soft_clipped_start = ("fw_soft_clipped_start", "first"),
              fw_soft_clipped_end   = ("fw_soft_clipped_end",   "first"),
              read_length           = ("read_length",           "first"),
              bed_end               = ("bed_end",               "first"),
              bed_strand            = ("bed_strand",            "first"),
              chr_type              = ("chr_type",              "first"),
              m6A_thresh            = ("m6A_thresh",            "first"),
              m5mC_thresh           = ("m5mC_thresh",           "first"),
              condition             = ("condition",             "first"),
              exp_id                = ("exp_id",                "first"),
              # list‑columns
              **{c: (c, list) for c in POSITION_LIST_COLS},
              # per‑read span
              rel_read_start        = ("rel_pos",               "min"),
              rel_read_end          = ("rel_pos",               "max"),
          )
          .reset_index()   # read_id, bed_start, type now back as columns
    )

    ###########################################################################
    # 5. Flip minus-strand features to feature 5′→3′ orientation
    ###########################################################################
    # The mask is only used for the log line; _flip_lists_row re-checks each row.
    flip_mask = (agg["bed_strand"] == "-") & agg["type"].str.contains("|".join(FLIP_FEATURES_CONTAINS), case=False, na=False)
    log.info("  Candidates to flip: %d / %d", int(flip_mask.sum()), len(agg))

    # Row-wise apply so we can reorder all list columns consistently
    agg = agg.apply(_flip_lists_row, axis=1)

    ###########################################################################
    # 6. Append “…_1/…_2” for duplicate regions of the same read_id
    ###########################################################################
    agg["dup_idx"] = agg.groupby("read_id").cumcount() + 1
    agg["read_id"] = agg["read_id"] + "_" + agg["dup_idx"].astype(str)
    agg.drop(columns="dup_idx", inplace=True)

    ###########################################################################
    # 7. Feather spill
    ###########################################################################
    feather_path = os.path.join(temp_dir, f"{idx}.feather")
    try:
        agg.to_feather(feather_path)
        log.debug("Wrote %d aggregated rows to %s", len(agg), feather_path)
        return idx
    except Exception as e:
        log.error("Failed to write Feather for %s: %s", sample_label, e)
        return None


# ──────────────────────────────────────────────────────────────────────────────
# ❸  Parallel orchestrator, now using Feather spill
# ──────────────────────────────────────────────────────────────────────────────
def parallel_build_combined_df(
    out_file_names,
    conditions,
    exp_ids,
    new_bam_files,
    combined_bed_df,
    type_selected,
    bed_window=500,
):
    """
    Process every extract file in a worker pool and return one combined frame.

    Each worker (``_process_one_file``) writes its result to
    ``FEATHER_DIR/{idx}.feather``; the files are read back in input order,
    concatenated, and the directory is removed again — also when a worker
    raises or nothing was produced.

    Args:
        out_file_names, conditions, exp_ids, new_bam_files: parallel sequences,
            one entry per sample (zipped together, so the shortest one wins).
        combined_bed_df: BED intervals handed to every worker.
        type_selected: forwarded to the worker unchanged.
        bed_window: forwarded to the worker for the ``rel_pos`` computation.

    Returns:
        Concatenated per-read DataFrame, or an empty DataFrame when no worker
        produced output.
    """
    args_list = [
        (
            idx,
            f,
            cond,
            exp_id,
            bam,
            combined_bed_df,
            type_selected,
            bed_window,
            FEATHER_DIR,
        )
        for idx, (f, cond, exp_id, bam) in enumerate(
            zip(out_file_names, conditions, exp_ids, new_bam_files)
        )
    ]

    # ≤ 90 % of CPUs, but never more processes than there are files to handle
    # (each extra worker would only fork and sit idle) and never fewer than one.
    cpu_cap = max(1, floor(cpu_count() * 0.9))
    max_workers = max(1, min(cpu_cap, len(args_list)))
    log.info("Using %d worker processes", max_workers)

    # prepare temp directory (start from a clean slate)
    if os.path.exists(FEATHER_DIR):
        shutil.rmtree(FEATHER_DIR)
    os.makedirs(FEATHER_DIR, exist_ok=True)

    try:
        results = []
        with Pool(processes=max_workers) as pool:
            for idx in tqdm(
                pool.imap_unordered(_process_one_file, args_list),
                total=len(args_list),
                desc="Processing bedmethyl files",
            ):
                if idx is not None:
                    results.append(idx)

        if not results:
            return pd.DataFrame()

        # read back all Feather files, in input order
        dfs = []
        for idx in sorted(results):
            path = os.path.join(FEATHER_DIR, f"{idx}.feather")
            try:
                dfs.append(pd.read_feather(path))
            except Exception as e:
                log.error("Failed reading feather %s: %s", path, e)

        return pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()
    finally:
        # cleanup on every exit path; ignore_errors so a cleanup hiccup cannot
        # mask the exception (or result) that is already on its way out
        shutil.rmtree(FEATHER_DIR, ignore_errors=True)


# ──────────────────────────────────────────────────────────────────────────────
# ❹  Build combined_bed_df  (keeps original logic/vars)
# ──────────────────────────────────────────────────────────────────────────────
# combined_bed_df = pd.DataFrame()
# for each_bed in new_bed_files:
#     bed_path = each_bed[:-3]  # strip .gz
#     log.info("Reading BED: %s", bed_path)
#     combined_bed_df = _append_df(
#         combined_bed_df, pd.read_csv(bed_path, sep="\t", header=None)
#     )
#
# combined_bed_df.columns = [
#     "chrom",
#     "bed_start",
#     "bed_end",
#     "bed_strand",
#     "type",
#     "chr_type",
# ]
# combined_bed_df.sort_values(["chrom", "bed_start"], inplace=True, ignore_index=True)
# log.info("combined_bed_df → %d rows", len(combined_bed_df))
# ──────────────────────────────────────────────────────────────────────────────
# ❹  Use motif‑centered, MOTIFS_* regions
# ──────────────────────────────────────────────────────────────────────────────
# (assumes you've already run the centering + prefix logic into combined_bed_df_mex_cat)
# Work on a copy so the chr_type column added below does not leak back into
# combined_bed_df_mex_cat.
combined_bed_df = combined_bed_df_mex_cat.copy()
# ensure chr_type still exists (if not already set):
# 'X' for chrom == 'CHROMOSOME_X', 'Autosome' for every other chromosome name.
if 'chr_type' not in combined_bed_df:
    combined_bed_df['chr_type'] = combined_bed_df['chrom'].apply(
        lambda x: 'X' if x == 'CHROMOSOME_X' else 'Autosome'
    )
log.info("combined_bed_df → %d rows (centered MOTIFS_)", len(combined_bed_df))


# ──────────────────────────────────────────────────────────────────────────────
# ❺  Kick off parallel bedmethyl processing
# ──────────────────────────────────────────────────────────────────────────────
# One row per (read_id with suffix, bed_start, type); per-base columns are lists.
# out_file_names / conditions / exp_ids / new_bam_files come from earlier cells
# and are consumed in lockstep (one entry per sample).
comb_bedmethyl_plot_df = parallel_build_combined_df(
    out_file_names=out_file_names,
    conditions=conditions,
    exp_ids=exp_ids,
    new_bam_files=new_bam_files,
    combined_bed_df=combined_bed_df,
    type_selected=type_selected,
    bed_window=bed_window,
)

# Nothing to analyse: stop the cell here rather than fail later on an empty frame.
if comb_bedmethyl_plot_df.empty:
    log.error("No data returned – terminating early.")
    raise SystemExit(1)

log.info("Combined bedmethyl rows: %d", len(comb_bedmethyl_plot_df))

# # ──────────────────────────────────────────────────────────────────────────────
# # ❻  Strand‑specific rel_pos flip  (TSS / TES / MEX / gene)
# # ──────────────────────────────────────────────────────────────────────────────
# flip_features = ("TSS", "TES", "MEX", "gene")
# flip_mask = (
#     comb_bedmethyl_plot_df["type"].str.contains("|".join(flip_features), na=False)
#     & (comb_bedmethyl_plot_df["bed_strand"] == "-")
# )
# flipped = flip_mask.sum()
# if flipped:
#     log.info("Flipping rel_pos sign for %d rows (minus‑strand features)", flipped)
#     comb_bedmethyl_plot_df.loc[flip_mask, "rel_pos"] *= -1
#
# # ──────────────────────────────────────────────────────────────────────────────
# # ❼  Per‑read start / end aggregation (unchanged)
# # ──────────────────────────────────────────────────────────────────────────────
# log.info("Computing rel_read_start / rel_read_end per read_id …")
# _small = comb_bedmethyl_plot_df[["read_id", "bed_start", "rel_pos"]]
#
# _grouped = (
#     _small.groupby(["read_id", "bed_start"])["rel_pos"]
#     .agg(rel_read_start="min", rel_read_end="max")
#     .reset_index()
# )
#
# comb_bedmethyl_plot_df = comb_bedmethyl_plot_df.merge(
#     _grouped, on=["read_id", "bed_start"], how="left"
# )
#
# del _small, _grouped

# ──────────────────────────────────────────────────────────────────────────────
# ❽  Debug summaries
# ──────────────────────────────────────────────────────────────────────────────
# Number of distinct (suffixed) read_ids per condition.
log.info(
    "Unique read_id count by condition:\n%s",
    comb_bedmethyl_plot_df.groupby("condition")["read_id"].nunique(),
)

# ──────────────────────────────────────────────────────────────────────────────
# comb_bedmethyl_plot_df is ready for downstream plotting / analysis
# ──────────────────────────────────────────────────────────────────────────────
# Show the first row as a quick look at the resulting columns and list values.
print(comb_bedmethyl_plot_df.iloc[0])


In [None]:
###############################################################################
# Parallel motif ↔ read‑interval join  – NOW COMPATIBLE WITH PER‑READ DATAFRAME
#
# You already created `comb_bedmethyl_plot_df` so that each row is
#  (read_id_with_suffix, bed_start) and all per‑base columns are Python lists.
#
# This pipeline:
#   1. Loads & filters MEX / MEXII / motifC motifs (optional cluster collapse)
#   2. Splits the per‑read BED intervals into N chunks
#   3. Runs bedtools intersect in parallel (chunk × motif)
#   4. Streams the worker results back into two list‑columns:
#        • motif_rel_start    → list[int]
#        • motif_attributes   → list[tuple(motif_id, ln_p, strand)]
#
# Result: two new list‑columns are **added** to comb_bedmethyl_plot_df
#         (row alignment via comb_idx).
###############################################################################

import os, math, gc, json, subprocess, shutil, tempfile, datetime
from pathlib import Path
from collections import defaultdict
import multiprocessing as mp

import pandas as pd, numpy as np, pyranges as pr, psutil
from tqdm import tqdm
import pyarrow.feather as feather
from concurrent.futures import ThreadPoolExecutor

# ─────────────────────────── Config & flags ────────────────────────────────
BEDTOOLS_BIN           = "/Data1/software/bedtools2/bin/bedtools"
MOTIF_WINDOW           = bed_window       # half-window used when centring hit offsets
# Leave two cores free and cap at 16, but never drop below one worker:
# os.cpu_count() may return None, and on hosts with ≤ 2 cores the plain
# subtraction gives ≤ 0, which breaks the chunk-size division further down.
NUM_WORKERS            = max(1, min(16, (os.cpu_count() or 1) - 2))
STEP_LINES             = 500_000          # progress print interval

COLLAPSE_MOTIF_CLUSTERS = True            # True → deduplicate motif clusters
DEBUG_PROGRESS          = True            # RSS + progress prints
DEBUG_SUMMARY           = True            # .info() / JSON summaries

# Fresh scratch directory for the motif BED, the per-read chunk BEDs and the
# per-worker Feather shards (RESULT_DIR).
TMPDIR      = Path(tempfile.mkdtemp(prefix="motif_intersect_"))
RESULT_DIR  = TMPDIR / "worker_results"
RESULT_DIR.mkdir()

# ───────────────────────────── utilities ───────────────────────────────────
def _rss_gb() -> float:
    """Return the resident set size of the current process in GiB (bytes / 2**30)."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / 1_073_741_824

def log_mem(label):
    """Print a timestamped RSS reading tagged with *label* (only if DEBUG_PROGRESS).

    A garbage collection is forced first so the figure reflects live objects.
    """
    if not DEBUG_PROGRESS:
        return
    gc.collect()
    stamp = datetime.datetime.now().strftime("%H:%M:%S")
    print(f"[{stamp}] {label:<40}: {_rss_gb():6.2f} GB RSS")

def log_df(df: pd.DataFrame, name: str):
    """Print a titled summary of *df* — ``info()`` then ``head()`` — when DEBUG_SUMMARY is on.

    Args:
        df: frame to summarise.
        name: title shown above the summary.
    """
    if not DEBUG_SUMMARY:
        return
    print(f"\n── {name} ──")
    # DataFrame.info() prints its report itself and returns None; wrapping it
    # in print() only added a stray "None" line after the report.
    df.info(memory_usage='deep')
    print(df.head())

def log_json(obj, name: str):
    """Print *obj* as indented JSON under a *name* title (only if DEBUG_SUMMARY)."""
    if not DEBUG_SUMMARY:
        return
    rendered = json.dumps(obj, indent=2)
    print(f"\n── {name} ──")
    print(rendered)

# ───────────────────── Stage 0: environment snapshot ───────────────────────
# Report the bedtools version (this also fails fast if BEDTOOLS_BIN cannot be
# executed), the CPU / worker counts and total system RAM in GiB.
if DEBUG_PROGRESS:
    print(f"bedtools: {subprocess.check_output([BEDTOOLS_BIN,'--version']).decode().strip()}")
    print(f"CPUs     : {os.cpu_count()}  (workers = {NUM_WORKERS})")
    print(f"System RAM: {round(psutil.virtual_memory().total/1_073_741_824)} GB\n")

# ───────────────────── Stage 1: load + prepare FIMO motifs ─────────────────
# FIMO result tables, one per motif (MEX, MEXII, motifC); all are concatenated
# into a single frame before filtering.
fimo_files = [
    "/Data1/ext_data/motifs/fimo_MEX_0.01.tsv",
    "/Data1/ext_data/motifs/fimo_MEXII_0.01.tsv",
    "/Data1/ext_data/motifs/fimo_motifc_0.01.tsv",
]
dfs = []
for f in fimo_files:
    if DEBUG_PROGRESS: print(f"Loading {f}")
    dfs.append(pd.read_csv(f, sep='\t'))
fimo_df = pd.concat(dfs, ignore_index=True)
log_df(fimo_df, "raw FIMO combined")

# basic harmonisation / filtering
#   • Chromosome: 'chr…' names rewritten to the 'CHROMOSOME_…' form used by
#     comb_bedmethyl_plot_df
#   • motif_priority: lower number wins when overlapping hits are collapsed
#     (MEXII = 1, MEX = 2, motifC = 3); motif_ids outside this map become NaN
#   • ln_p: natural log of the FIMO p-value, with a per-motif upper cut-off
#     (MEX ≤ -13, MEXII ≤ -12, motifC ≤ -9)
fimo_df = (
    fimo_df
      .assign(
          Chromosome     = lambda d: d['sequence_name'].str.replace('chr','CHROMOSOME_'),
          Start          = lambda d: d['start'],
          End            = lambda d: d['stop'],
          motif_priority = lambda d: d['motif_id'].map({'MEXII':1,'MEX':2,'motifC':3}),
          ln_p           = lambda d: np.log(d['p-value']),
      )
      .query("(motif_id=='MEX'    and ln_p<=-13) or "
             "(motif_id=='MEXII'  and ln_p<=-12) or "
             "(motif_id=='motifC' and ln_p<= -9 )")
      # ▼ keep matched_sequence and rename p‑value → p_value
      .loc[:, ['Chromosome','Start','End','strand','score','p-value',
               'motif_id','motif_priority','ln_p','matched_sequence']]
      .rename(columns={'p-value':'p_value'})
)
# Keep only motifs on chromosomes that actually occur in the read table.
valid_chroms = set(comb_bedmethyl_plot_df['chrom'].unique())
fimo_df = fimo_df[fimo_df['Chromosome'].isin(valid_chroms)].copy()
log_df(fimo_df, "after motif filters")

# optional motif‑cluster collapse
# Collapse overlapping hits: cluster ignoring strand, then keep one hit per
# cluster — lowest motif_priority first, ties broken by smallest Start.
# Without collapsing, every hit becomes its own cluster.
if COLLAPSE_MOTIF_CLUSTERS:
    pr_df = pr.PyRanges(fimo_df.astype({'Start':int,'End':int}))
    clusters_df = (pr_df.cluster(strand=False).df
                      .sort_values(['Cluster','motif_priority','Start'])
                      .drop_duplicates('Cluster', keep='first'))
else:
    clusters_df = fimo_df.copy(); clusters_df['Cluster'] = np.arange(len(clusters_df))
motif_bed = TMPDIR/"fimo.bed"
# after (optional) cluster‑collapse …
# Column ORDER matters: _run_intersect() reads these fields by position from
# the bedtools output (after the 4 columns of the per-read chunk BED).
# NOTE(review): FIMO start/stop are written unchanged as BED start/end —
# confirm the coordinate convention (0- vs 1-based) is what bedtools should see.
bed_cols = ['chrom','motif_start','motif_end','strand','score','p_value',
            'motif_id','motif_priority','ln_p','matched_sequence']   # ◀︎ NEW
# Sorted by (chrom, start, end) because the workers run bedtools with -sorted.
(clusters_df
     .rename(columns={'Chromosome':'chrom','Start':'motif_start','End':'motif_end'})
     .sort_values(['chrom','motif_start','motif_end'], kind='mergesort')
     [bed_cols]
     .to_csv(motif_bed, sep='\t', header=False, index=False))

log_mem("after writing motif BED")

# ───────────────── Stage 1 b: load + prepare TSS / TES sites ───────────────
# Feature types of interest: TSS_q1..TSS_q4 and TES_q1..TES_q4.
TSS_TYPES = [f"TSS_q{i}" for i in range(1, 5)]
TES_TYPES = [f"TES_q{i}" for i in range(1, 5)]
QTYPES    = TSS_TYPES + TES_TYPES

tss_bed = TMPDIR / "tss_q.bed"
tes_bed = TMPDIR / "tes_q.bed"

# create empty files upfront
# (so both paths exist even when no TSS/TES rows are written below)
for p in (tss_bed, tes_bed):
    p.parent.mkdir(parents=True, exist_ok=True)
    open(p, "wb").close()

n_tss = n_tes = 0
# There is no `except`: an error while loading still propagates. The `finally`
# only guarantees RUN_TSS / RUN_TES are defined (False if nothing was written).
try:
    # `bed_file` comes from an earlier cell; it is read WITH a header row and
    # must provide chromosome/start/end/strand/type columns.
    tt_df = (
        pd.read_csv(bed_file, sep="\t")
          .rename(columns={"chromosome":"chrom", "start":"Site", "end":"Site_end"})
          .query("type in @QTYPES")
          .loc[:, ["chrom","Site","Site_end","strand","type"]]
    )
    # keep only valid chroms
    tt_df = tt_df[tt_df["chrom"].isin(valid_chroms)].copy()

    # force integer coordinates, drop bad rows
    # (non-numeric values become <NA> via the nullable Int64 dtype, are dropped,
    # and the remainder is cast to plain int)
    tt_df["Site"]     = pd.to_numeric(tt_df["Site"],     errors="coerce").astype("Int64")
    tt_df["Site_end"] = pd.to_numeric(tt_df["Site_end"], errors="coerce").astype("Int64")
    tt_df = tt_df.dropna(subset=["Site","Site_end"]).astype({"Site":int,"Site_end":int})

    if not tt_df.empty:
        # Sorted by (chrom, start, end) because the workers run bedtools with -sorted.
        ss = tt_df.query("type in @TSS_TYPES").sort_values(["chrom","Site","Site_end"], kind="mergesort")
        ts = tt_df.query("type in @TES_TYPES").sort_values(["chrom","Site","Site_end"], kind="mergesort")
        n_tss, n_tes = len(ss), len(ts)

        # Column order written here (chrom, Site, Site_end, strand, type) is
        # what the TSS/TES workers index into by position.
        if n_tss:
            ss.to_csv(tss_bed, sep="\t", header=False, index=False)
        if n_tes:
            ts.to_csv(tes_bed, sep="\t", header=False, index=False)

    log_df(tt_df, "TSS/TES filtered")
finally:
    # Flags read by _run_intersect_tss / _run_intersect_tes to skip bedtools
    # when the corresponding site file is empty.
    RUN_TSS = n_tss > 0
    RUN_TES = n_tes > 0
    if DEBUG_PROGRESS:
        print(f"RUN_TSS={RUN_TSS} (n={n_tss})  RUN_TES={RUN_TES} (n={n_tes})")
    log_mem("after TSS/TES setup")

# ───────────────── Stage 2: prepare per‑read BED chunks ────────────────────
# Integer coordinates so the BED chunks are written without decimals.
comb_bedmethyl_plot_df['bed_start'] = comb_bedmethyl_plot_df['bed_start'].astype(int)
comb_bedmethyl_plot_df['bed_end']   = comb_bedmethyl_plot_df['bed_end'].astype(int)

# 4-column BED (chrom, bed_start, bed_end, comb_idx); the workers read
# bed_start and comb_idx back by position from the bedtools output.
comb_bed = comb_bedmethyl_plot_df[['chrom','bed_start','bed_end']].copy()
comb_bed['comb_idx'] = comb_bed.index               # preserve row → index mapping

# Split into at most NUM_WORKERS contiguous row ranges; each range is sorted
# on its own (bedtools -sorted) and written to its own file.
chunk_size  = math.ceil(len(comb_bed)/NUM_WORKERS)
chunk_paths = []
for i in range(NUM_WORKERS):
    s,e = i*chunk_size, min((i+1)*chunk_size,len(comb_bed))
    if s>=e: break      # fewer rows than workers: remaining chunks would be empty
    path = TMPDIR/f"comb_chunk_{i}.bed"
    (comb_bed.iloc[s:e]
        .sort_values(['chrom','bed_start','bed_end'],kind='mergesort')
        .to_csv(path, sep='\t', header=False, index=False))
    chunk_paths.append(path)

log_json({"n_chunks":len(chunk_paths), "rows_per_chunk":chunk_size}, "chunking stats")
log_mem("after writing chunk BEDs")

# ───────────────── Stage 3: worker – bedtools intersect ────────────────────
def _run_intersect(chunk_path: Path):
    """Intersect one per-read BED chunk with the motif BED and spill the hits to Feather.

    Streams ``bedtools intersect -wa -wb -sorted`` line by line. Every output
    line holds the 4 chunk columns (chrom, bed_start, bed_end, comb_idx)
    followed by the motif columns in ``bed_cols`` order, so fields are read by
    position.

    Args:
        chunk_path: sorted 4-column chunk BED written in Stage 2.

    Returns:
        ``chunk_path.name``. The result is written to
        ``RESULT_DIR/<chunk stem>.feather`` with one row per ``comb_idx`` that
        had hits: ``motif_rel_start`` (list[int]) and ``motif_attributes``
        (JSON list of ``[motif_id, ln_p, strand, matched_sequence]``).

    Raises:
        RuntimeError: bedtools exited with a non-zero status.
    """
    import math
    import json
    import datetime
    import psutil
    import pandas as pd
    from collections import defaultdict
    out_file = RESULT_DIR / f"{chunk_path.stem}.feather"

    agg_rel   = defaultdict(list)   # comb_idx → list[int]
    agg_attrs = defaultdict(list)   # comb_idx → list[(motif_id, ln_p, strand, matched_sequence)]

    cmd = [BEDTOOLS_BIN,"intersect","-a",str(chunk_path),"-b",str(motif_bed.resolve()),
       "-wa","-wb","-sorted"]
    # Context manager: the pipe is closed and the child reaped even if parsing
    # a line raises, instead of leaving bedtools running behind the error.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True, bufsize=1) as proc:
        for n, line in enumerate(proc.stdout, 1):
            c   = line.rstrip("\n").split("\t")
            idx = int(c[3]); bed_start = int(c[1])
            # motif start relative to bed_start + MOTIF_WINDOW
            # (presumably the interval centre — TODO confirm interval width)
            rel = int(c[5]) - bed_start - MOTIF_WINDOW
            agg_rel[idx].append(rel)
            ln_p = round(math.log(float(c[9])),1)            # natural‑log p, 1 dp
            # (motif_id, ln_p, strand, matched_sequence)
            agg_attrs[idx].append((c[10], ln_p, c[7], c[13]))
            if DEBUG_PROGRESS and n%STEP_LINES==0:
                print(f"[{datetime.datetime.now():%H:%M:%S}] {chunk_path.name} "
                      f"{n/1e6:5.1f} M  {psutil.Process().memory_info().rss/1e9:4.1f} GB")

    if proc.returncode:
        raise RuntimeError(f"{chunk_path.name} bedtools exit {proc.returncode}")

    df_keys = list(agg_rel.keys())
    pd.DataFrame({
        "comb_idx":         df_keys,
        "motif_rel_start":  [agg_rel[k]   for k in df_keys],
        "motif_attributes": [json.dumps(agg_attrs[k]) for k in df_keys],
    }).to_feather(out_file)
    return chunk_path.name

# ───────────────── Stage 3 a: worker – TSS intersect ───────────────────────
def _run_intersect_tss(chunk_path: Path):
    """Intersect one per-read BED chunk with the TSS site BED and spill the hits to Feather.

    Every bedtools output line holds the 4 chunk columns (chrom, bed_start,
    bed_end, comb_idx) followed by the site columns (chrom, Site, Site_end,
    strand, type); fields are read by position.

    Args:
        chunk_path: sorted 4-column chunk BED written in Stage 2.

    Returns:
        ``chunk_path.name``. The result is written to
        ``RESULT_DIR/<chunk stem>_tss.feather`` with ``tss_rel_start``
        (list[int]) and ``tss_attributes`` (JSON list of
        ``[site_start, strand, type]``) per ``comb_idx``. When RUN_TSS is
        False an empty shard is written and bedtools is not run.

    Raises:
        RuntimeError: bedtools exited with a non-zero status.
    """
    import datetime
    import json
    import psutil
    import pandas as pd
    from collections import defaultdict
    out_file = RESULT_DIR / f"{chunk_path.stem}_tss.feather"

    # no TSS sites → emit empty shard and exit
    if not RUN_TSS:
        pd.DataFrame({"comb_idx":[], "tss_rel_start":[], "tss_attributes":[]}).to_feather(out_file)
        return chunk_path.name

    agg_rel, agg_attrs = defaultdict(list), defaultdict(list)
    cmd = [BEDTOOLS_BIN,"intersect","-a",str(chunk_path),"-b",str(tss_bed.resolve()),
       "-wa","-wb","-sorted"]
    # Context manager: the pipe is closed and the child reaped even if parsing
    # a line raises.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True, bufsize=1) as proc:
        for n, line in enumerate(proc.stdout, 1):
            c = line.rstrip("\n").split("\t")
            idx, bed_start = int(c[3]), int(c[1])
            # site start relative to bed_start + MOTIF_WINDOW (same shift as
            # the motif worker)
            rel = int(c[5]) - bed_start - MOTIF_WINDOW
            if abs(rel) > MOTIF_WINDOW:          # keep only ±MOTIF_WINDOW hits
                continue
            agg_rel[idx].append(rel)
            agg_attrs[idx].append((int(c[5]), c[7], c[8]))  # (start, strand, type)
            if DEBUG_PROGRESS and n%STEP_LINES==0:
                print(f"[{datetime.datetime.now():%H:%M:%S}] {chunk_path.name} "
                      f"TSS {n/1e6:5.1f} M  {psutil.Process().memory_info().rss/1e9:4.1f} GB")

    if proc.returncode:
        raise RuntimeError(f"{chunk_path.name} TSS bedtools exit {proc.returncode}")

    # One key list drives all three columns so they cannot drift out of step.
    df_keys = list(agg_rel.keys())
    pd.DataFrame({
        "comb_idx":        df_keys,
        "tss_rel_start":   [agg_rel[k]   for k in df_keys],
        "tss_attributes":  [json.dumps(agg_attrs[k]) for k in df_keys],
    }).to_feather(out_file)
    return chunk_path.name

# ───────────────── Stage 3 b: worker – TES intersect ───────────────────────
def _run_intersect_tes(chunk_path: Path):
    """
    Intersect one BED chunk with the TES BED and write a per-chunk feather shard.

    Streams the output of ``bedtools intersect -wa -wb -sorted`` line by line,
    groups the hits by the row index stored in the 4th column of the chunk
    BED, and writes ``<chunk stem>_tes.feather`` into RESULT_DIR with columns
    ``comb_idx``, ``tes_rel_start`` (list of ints) and ``tes_attributes``
    (JSON-encoded list of 3-element entries).

    Returns
    -------
    str
        The chunk file name, so the caller can tick a progress bar.

    Raises
    ------
    RuntimeError
        If bedtools exits with a non-zero status.
    """
    import datetime, json, psutil, pandas as pd
    from collections import defaultdict
    out_file = RESULT_DIR / f"{chunk_path.stem}_tes.feather"

    # no TES sites → emit empty shard and exit
    if not RUN_TES:
        pd.DataFrame({"comb_idx":[], "tes_rel_start":[], "tes_attributes":[]}).to_feather(out_file)
        return chunk_path.name

    agg_rel, agg_attrs = defaultdict(list), defaultdict(list)
    cmd = [BEDTOOLS_BIN,"intersect","-a",str(chunk_path),"-b",str(tes_bed.resolve()),
       "-wa","-wb","-sorted"]

    # The context manager closes the pipe and reaps the child even when a
    # malformed line makes the parsing loop raise.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True, bufsize=1) as proc:
        for n,line in enumerate(proc.stdout,1):
            c = line.rstrip("\n").split("\t")
            idx, bed_start = int(c[3]), int(c[1])
            # NOTE(review): the original inline note said "no window shift",
            # yet MOTIF_WINDOW is subtracted here – confirm which is intended.
            rel = int(c[5]) - bed_start - MOTIF_WINDOW
            if abs(rel) > MOTIF_WINDOW:          # keep only ±MOTIF_WINDOW hits
                continue
            agg_rel[idx].append(rel)
            agg_attrs[idx].append((int(c[5]), c[7], c[8]))  # (start, strand, type)
            if DEBUG_PROGRESS and n%STEP_LINES==0:
                print(f"[{datetime.datetime.now():%H:%M:%S}] {chunk_path.name} "
                      f"TES {n/1e6:5.1f} M  {psutil.Process().memory_info().rss/1e9:4.1f} GB")
    if proc.returncode:
        raise RuntimeError(f"{chunk_path.name} TES bedtools exit {proc.returncode}")

    # Use one key list for all three columns so the rows cannot drift apart.
    keys = list(agg_rel)
    pd.DataFrame({
        "comb_idx":        keys,
        "tes_rel_start":   [agg_rel[k] for k in keys],
        "tes_attributes":  [json.dumps(agg_attrs[k]) for k in keys],
    }).to_feather(out_file)
    return chunk_path.name


# ───────────────── Stage 4: parallel intersection ──────────────────────────
# Motif intersects: one worker process per chunk file; the workers' return
# values are only used to advance the progress bar (results live on disk).
with mp.Pool(len(chunk_paths)) as pool, \
     tqdm(total=len(chunk_paths), desc="Chunks", unit="chunk") as pbar:
    for _ in pool.imap_unordered(_run_intersect, chunk_paths):
        pbar.update()
log_mem("all intersects done (worker files on disk)")

# ───────────────── Stage 4 a: parallel TSS intersection ────────────────────
# Only run when RUN_TSS is set; otherwise no "_tss" shard is produced here.
if RUN_TSS:
    with mp.Pool(len(chunk_paths)) as pool, \
         tqdm(total=len(chunk_paths), desc="Chunks-TSS", unit="chunk") as pbar:
        for _ in pool.imap_unordered(_run_intersect_tss, chunk_paths):
            pbar.update()
    log_mem("TSS intersects done")
else:
    if DEBUG_PROGRESS: print("Skipping TSS intersects (no TSS sites).")

# ───────────────── Stage 4 b: parallel TES intersection ────────────────────
# Only run when RUN_TES is set; otherwise no "_tes" shard is produced here.
if RUN_TES:
    with mp.Pool(len(chunk_paths)) as pool, \
         tqdm(total=len(chunk_paths), desc="Chunks-TES", unit="chunk") as pbar:
        for _ in pool.imap_unordered(_run_intersect_tes, chunk_paths):
            pbar.update()
    log_mem("TES intersects done")
else:
    if DEBUG_PROGRESS: print("Skipping TES intersects (no TES sites).")

# ───────────────── Stage 5: consolidate worker shards  ─────────────────────
# (motifs only – exclude the “_tss” and “_tes” worker outputs)
# Motif shards are every feather file in RESULT_DIR that is not a TSS/TES shard.
worker_files = [
    p for p in RESULT_DIR.glob("*.feather")
    if not p.stem.endswith(("_tss", "_tes"))
]
# Master dicts keyed by comb_idx, filled by _merge_one.
all_rel, all_attrs = defaultdict(list), defaultdict(list)

def _merge_one(path: Path):
    """
    Fold one motif shard into the module-level `all_rel` / `all_attrs` dicts.

    Every shard row contributes its relative starts and its JSON-decoded
    attribute entries (converted to tuples) under the row's `comb_idx`.
    Returns the number of rows read from the shard.
    """
    df = feather.read_table(path).to_pandas()
    for idx, rels, attrs_json in zip(
            df['comb_idx'], df['motif_rel_start'], df['motif_attributes']):
        all_rel[idx].extend(rels)
        all_attrs[idx].extend(map(tuple, json.loads(attrs_json)))
    return len(df)

# ThreadPoolExecutor rejects max_workers=0, so keep at least one worker even
# when no shard was found; pool.map over an empty list is then a no-op.
with ThreadPoolExecutor(max_workers=max(1, min(8, len(worker_files)))) as pool, \
     tqdm(total=len(worker_files), desc="merge‑motif", unit="file") as pbar:
    for _ in pool.map(_merge_one, worker_files):
        pbar.update()
log_mem("motif master dicts built")

# ───────────────── merge TSS files ─────────────────────────────────────────
worker_files_tss = list(RESULT_DIR.glob("*_tss.feather"))
# Master dicts keyed by comb_idx, filled by _merge_tss.
all_tss_rel, all_tss_attrs = defaultdict(list), defaultdict(list)

def _merge_tss(path: Path):
    """
    Fold one TSS shard into the module-level `all_tss_rel` / `all_tss_attrs` dicts.

    Every shard row contributes its relative starts and its JSON-decoded
    attribute entries (converted to tuples) under the row's `comb_idx`.
    Returns the number of rows read from the shard.
    """
    df = feather.read_table(path).to_pandas()
    for idx, rels, attrs_json in zip(df['comb_idx'], df['tss_rel_start'], df['tss_attributes']):
        all_tss_rel[idx].extend(rels)
        all_tss_attrs[idx].extend(map(tuple, json.loads(attrs_json)))
    return len(df)

# When the TSS stage was skipped there are no "*_tss.feather" shards, and
# ThreadPoolExecutor(max_workers=0) raises ValueError – keep at least one
# worker so the merge degrades to a no-op instead of aborting the pipeline.
with ThreadPoolExecutor(max_workers=max(1, min(8, len(worker_files_tss)))) as pool, \
     tqdm(total=len(worker_files_tss), desc="merge‑TSS", unit="file") as pbar:
    for _ in pool.map(_merge_tss, worker_files_tss):
        pbar.update()

# ───────────────── merge TES files ─────────────────────────────────────────
worker_files_tes = list(RESULT_DIR.glob("*_tes.feather"))
# Master dicts keyed by comb_idx, filled by _merge_tes.
all_tes_rel, all_tes_attrs = defaultdict(list), defaultdict(list)

def _merge_tes(path: Path):
    """
    Fold one TES shard into the module-level `all_tes_rel` / `all_tes_attrs` dicts.

    Every shard row contributes its relative starts and its JSON-decoded
    attribute entries (converted to tuples) under the row's `comb_idx`.
    Returns the number of rows read from the shard.
    """
    df = feather.read_table(path).to_pandas()
    for idx, rels, attrs_json in zip(df['comb_idx'], df['tes_rel_start'], df['tes_attributes']):
        all_tes_rel[idx].extend(rels)
        all_tes_attrs[idx].extend(map(tuple, json.loads(attrs_json)))
    return len(df)

# When the TES stage was skipped there are no "*_tes.feather" shards, and
# ThreadPoolExecutor(max_workers=0) raises ValueError – keep at least one
# worker so the merge degrades to a no-op instead of aborting the pipeline.
with ThreadPoolExecutor(max_workers=max(1, min(8, len(worker_files_tes)))) as pool, \
     tqdm(total=len(worker_files_tes), desc="merge‑TES", unit="file") as pbar:
    for _ in pool.map(_merge_tes, worker_files_tes):
        pbar.update()
log_mem("TSS/TES master dicts built")

# ───────────────── Stage 6: attach list‑columns to main DataFrame ──────────
# The master dicts are looked up by positional row number; rows without any
# hit receive an empty tuple.
motif_rel_list = [
    tuple(all_rel.get(i, [])) for i in range(len(comb_bedmethyl_plot_df))
]
motif_attr_list = [
    tuple(all_attrs.get(i, [])) for i in range(len(comb_bedmethyl_plot_df))
]

comb_bedmethyl_plot_df['motif_rel_start']  = motif_rel_list
comb_bedmethyl_plot_df['motif_attributes'] = motif_attr_list

# Sanity check: every row must carry as many attribute entries as positions.
assert all(
    len(r) == len(a)
    for r, a in zip(comb_bedmethyl_plot_df['motif_rel_start'],
                    comb_bedmethyl_plot_df['motif_attributes'])
), "row alignment broken – lengths differ after 4‑tuple upgrade"


# ───────────────── Stage 6: attach TSS / TES columns ──────────────────────
tss_rel_list = [
    tuple(all_tss_rel.get(i, [])) for i in range(len(comb_bedmethyl_plot_df))
]
tss_attr_list = [
    tuple(all_tss_attrs.get(i, [])) for i in range(len(comb_bedmethyl_plot_df))
]
tes_rel_list = [
    tuple(all_tes_rel.get(i, [])) for i in range(len(comb_bedmethyl_plot_df))
]
tes_attr_list = [
    tuple(all_tes_attrs.get(i, [])) for i in range(len(comb_bedmethyl_plot_df))
]

comb_bedmethyl_plot_df['tss_rel_start']   = tss_rel_list
comb_bedmethyl_plot_df['tss_attributes']  = tss_attr_list
comb_bedmethyl_plot_df['tes_rel_start']   = tes_rel_list
comb_bedmethyl_plot_df['tes_attributes']  = tes_attr_list


# ───────────────── Stage 7: report, then remove the scratch dir ────────────
log_mem("after attaching motif columns")

log_df(comb_bedmethyl_plot_df.head(), "final merged preview")
shutil.rmtree(TMPDIR, ignore_errors=True)
print("\n✅  motif ↔ read pipeline complete\n")


In [None]:
# Report whether comb_bedmethyl_plot_df holds any rows of type "MOTIFS_rex32"
# (skipped silently when the frame has no 'type' column).
if 'type' in comb_bedmethyl_plot_df.columns:
    rex32_rows = comb_bedmethyl_plot_df[comb_bedmethyl_plot_df['type'] == 'MOTIFS_rex32']
    if rex32_rows.empty:
        print("No rows found with type 'MOTIFS_rex32'.")
    else:
        print("Found rows with type 'MOTIFS_rex32':")
        print(rex32_rows.head())

# List every column of the merged frame
print("\nAll column names in comb_bedmethyl_plot_df:")
print(comb_bedmethyl_plot_df.columns.tolist())


In [None]:
### Required
import os
import numpy as np
import pandas as pd
from glob import glob
from multiprocessing import Pool, cpu_count
from tqdm import tqdm

# Assume nanotools is a custom module available in your environment
# For example:
# class Nanotools:
#     def display_sample_rows(self, df, n=5):
#         if df is not None and not df.empty:
#             print(f"\nDisplaying first {n} sample rows:")
#             print(df.head(n))
#         elif df is not None and df.empty:
#             print("\nDataFrame is empty.")
#         else:
#             print("\nDataFrame is None.")
# nanotools = Nanotools()
# Ensure this (or your actual nanotools) is defined if you run this standalone.


# -----------------------------------------------------------------------------
# Configuration (Assumed to be defined elsewhere as per your request)
# -----------------------------------------------------------------------------
# These would be defined in your actual script:
# type_selected = ["your_type1", "your_type2", "your_type3"] # Example
# thresh_list = [0.1] # Example
# bam_fracs = [0.2] # Example
# bed_window = 150 # Example
# comb_bedmethyl_plot_df = pd.DataFrame() # Example: This would be your large DataFrame to be chunked

FORCE_OVERWRITE = True  # True → ignore chunk files already on disk and rewrite them

# Chunk files are named
#   temp_files/plot_df_<first 3 chars of the last three types, "-"-joined>_<thresh>_<bam_frac><bed_window>
# Note there is deliberately no separator between bam_frac and bed_window.
base_fn = "".join([
    "temp_files/",
    "plot_df_",
    "-".join(str(x)[:3] for x in type_selected[-3:]),
    "_",
    str(thresh_list[0]),
    "_",
    str(bam_fracs[0]),
    str(bed_window),
])
# glob pattern matching every chunk written for this configuration
chunk_pattern = f"{base_fn}_part*.pkl"


# Make sure the output directory is there before anything is written.
if not os.path.exists("temp_files"):
    os.makedirs("temp_files", exist_ok=True)

chunks_found = sorted(glob(chunk_pattern))

# -----------------------------------------------------------------------------
# Helper function for parallel reading
def _read_pickle_part(filepath):
    """Load one pickled chunk from `filepath` and hand back its DataFrame."""
    frame = pd.read_pickle(filepath)
    return frame


# -----------------------------------------------------------------------------
# Load or write in parallel depending on existing files and FORCE_OVERWRITE
# -----------------------------------------------------------------------------
plot_df = None  # only filled by the load branch; stays None when chunks are (re)written

if chunks_found and not FORCE_OVERWRITE:
    print(f"Found {len(chunks_found)} chunk file(s), loading in parallel…")
    # At most one reader per chunk, never more than the available cores,
    # hard-capped at 64 (the previous max() always spawned at least 64).
    num_read_workers = max(1, min(len(chunks_found), cpu_count(), 64))

    if len(chunks_found) == 1:
        print("Only one chunk found, reading directly.")
        plot_df = pd.read_pickle(chunks_found[0])
    else:
        try:
            with Pool(processes=num_read_workers) as pool:
                # imap (ordered) keeps the parts in sorted file-name order
                df_parts = list(tqdm(
                    pool.imap(_read_pickle_part, chunks_found),
                    total=len(chunks_found),
                    desc="Reading pickle chunks"
                ))
            plot_df = pd.concat(df_parts, ignore_index=True)
        except Exception as e:
            # Best effort: any pool failure falls back to a plain sequential read.
            print(f"Error during parallel read: {e}. Falling back to sequential.")
            plot_df = pd.concat(
                (pd.read_pickle(fn) for fn in chunks_found),
                ignore_index=True
            )

elif chunks_found and FORCE_OVERWRITE:
    print(f"Found {len(chunks_found)} chunk file(s), but FORCE_OVERWRITE=True → rewriting…")
    # Delete the old parts first: a rewrite that yields fewer parts would
    # otherwise leave stale higher-numbered files behind, and anything that
    # later globs `chunk_pattern` would mix old and new rows.
    for stale_fn in chunks_found:
        os.remove(stale_fn)
    chunks_found = []  # triggers the write block below
else:
    print("No chunk files found — writing in parallel…")
    # chunks_found is already empty, so the write block below runs

if not chunks_found:  # nothing on disk, or FORCE_OVERWRITE requested a rewrite
    def _write_chunk(args):
        """Pickle rows [start, stop) of comb_bedmethyl_plot_df as part `i`; return `i`."""
        i, start, stop = args
        # comb_bedmethyl_plot_df and base_fn are read from the global scope
        part = comb_bedmethyl_plot_df.iloc[start:stop]
        fn = f"{base_fn}_part{i}.pkl"
        part.to_pickle(fn)
        return i  # so tqdm knows one task completed

    # split the frame into at most `num_splits` equally sized row ranges
    num_splits = 54
    n = len(comb_bedmethyl_plot_df)
    step = int(np.ceil(n / num_splits))

    # (chunk_index, start, stop) tuples, skipping empty slices
    tasks = [
        (i, i * step, min((i + 1) * step, n))
        for i in range(num_splits)
        if i * step < n
    ]

    # leave 20 cores free, but always use at least one worker
    n_workers = max(cpu_count() - 20, 1)
    with Pool(processes=n_workers) as pool:
        for _ in tqdm(
            pool.imap_unordered(_write_chunk, tasks),
            total=len(tasks),
            desc="Writing pickle chunks"
        ):
            pass

    print(f"Finished writing {len(tasks)} chunk(s).")


In [None]:
### Process to add per-read statistics (placeholders for NRL, MAD, NUC) (REQUIRED)
import os, glob, math, json
# ---------------------------------------------------------------------------
# 1.  Stream the pickle shards (per-read rows) in parallel
# ---------------------------------------------------------------------------
print("DEBUG: Parallel streaming of pickle chunks …")
chunk_files = sorted(glob.glob(chunk_pattern))

# Leave 10 cores free for other work, but always use at least one worker
n_workers = max(cpu_count() - 10, 1)

def _process_chunk(pkl_path):
    """Read one pickled chunk; return its DataFrame, or None when it has no rows."""
    df = pd.read_pickle(pkl_path)
    return df if not df.empty else None

merged_parts = []
with Pool(n_workers) as pool:
    # Ordered imap (instead of imap_unordered) makes the concatenation order –
    # and with it every later "first occurrence" choice such as
    # drop_duplicates('read_id') – reproducible between runs. The order is the
    # lexicographic order of the chunk file names.
    for part in tqdm(pool.imap(_process_chunk, chunk_files),
                     total=len(chunk_files), unit="file"):
        if part is not None:
            merged_parts.append(part)

# No usable chunk → fall back to an empty frame
merged_df = pd.concat(merged_parts, ignore_index=True) if merged_parts else pd.DataFrame()
print(f"DEBUG: merged_df assembled → shape={merged_df.shape}")

# ---------------------------------------------------------------------------
# 2-4.  Scalar placeholder columns (NUC midpoints, MAD summary, percentages)
#       – added in a fixed order so the column layout stays stable
# ---------------------------------------------------------------------------
for _col in ('smallest_positive_nuc_midpoint',
             'greatest_negative_nuc_midpoint',
             'closest_nuc',
             'inter_nuc_dist'):
    merged_df[_col] = np.nan

merged_df['closest_MAD_region'] = None
for _col in ('MAD_size',
             'closest_MAD_midpoint',
             'percent_MAD',
             'percent_NUC',
             'percent_OTHER'):
    merged_df[_col] = np.nan

# ---------------------------------------------------------------------------
# 5.  Placeholder columns for Fiber-NRL lists (one fresh empty list per row)
# ---------------------------------------------------------------------------
n = len(merged_df)
for _col in ('fiber_NRL_list', 'fiber_NRL_list_pos', 'fiber_NRL_list_neg'):
    merged_df[_col] = [[] for _ in range(n)]

# ---------------------------------------------------------------------------
# 6.  Final tidy-up & sorting
# ---------------------------------------------------------------------------
merged_df = merged_df.sort_values(by='percent_MAD', ascending=True)
merged_df = merged_df.reset_index(drop=True)
print("DEBUG: merged_df sorted ✓")

# ---------------------------------------------------------------------------
# 7.  Helper table with a single row per read_id (first occurrence wins)
# ---------------------------------------------------------------------------
grouped = merged_df.drop_duplicates('read_id')
grouped = grouped.sort_values('read_id').reset_index(drop=True)

# Empty-list placeholders for the nucleosome-related columns
for _col in ('nucs_list',
             'inter_nuc_sub',
             'nuc_list_internuc_aligned',
             'nucs_list_closest_aligned'):
    grouped[_col] = [[] for _ in range(len(grouped))]

# Translate each row's condition into its exp_id via the paired lists
grouped['exp_id'] = grouped['condition'].map(dict(zip(conditions, exp_ids)))

nanotools.display_sample_rows(grouped, 5)
nanotools.display_sample_rows(merged_df, 5)

# ─────────────────────────────────────────────────────────────────
# 8.  Optional down-sampling for plotting
# ─────────────────────────────────────────────────────────────────

# Number of reads per condition for plotting (0 = disable)
n_read_ids = 0

def downsample_group(group):
    """
    Limit one groupby group to at most `n_read_ids` distinct read IDs.

    The group is returned unchanged when down-sampling is disabled
    (`n_read_ids == 0`) or when it already has no more than `n_read_ids`
    distinct reads.  Otherwise `n_read_ids` read IDs are drawn without
    replacement and every row belonging to a drawn read is kept.

    NOTE(review): the original sampling logic was elided from this cell and
    the function fell through returning None, which made groupby.apply drop
    every over-sized group.  Uniform sampling with a fixed seed is used here
    as a reproducible stand-in – confirm it matches the intended scheme.
    """
    if n_read_ids == 0:
        return group
    ids = group['read_id'].unique()
    if len(ids) <= n_read_ids:
        return group
    rng = np.random.default_rng(0)  # fixed seed → same selection on every run
    chosen = rng.choice(np.asarray(ids, dtype=object), size=n_read_ids, replace=False)
    return group[group['read_id'].isin(chosen)]

# Down-sample within every (condition, chr_type, type) group, then order the
# surviving rows by their nucleosome-midpoint placeholder columns.
down_sampled_plot_df = (
    merged_df
    .groupby(['condition','chr_type','type'], group_keys=False)
    .apply(downsample_group)
)
down_sampled_plot_df = down_sampled_plot_df.reset_index(drop=True)
down_sampled_plot_df = down_sampled_plot_df.sort_values(
    by=['smallest_positive_nuc_midpoint', 'greatest_negative_nuc_midpoint']
)

# Keep only the per-read helper rows whose read survived the down-sampling.
down_sampled_group_df = grouped[
    grouped['read_id'].isin(down_sampled_plot_df['read_id'])
]

print("DEBUG: Down-sampling complete →",
      down_sampled_plot_df.shape, down_sampled_group_df.shape)


In [None]:
# Notebook preview: show the first rows of the (optionally down-sampled) table.
down_sampled_plot_df.head()

In [None]:

"""
Refactored plotting pipeline
────────────────────────────
All external variables referenced here (e.g. `type_selected`, `analysis_cond`,
`down_sampled_plot_df`, …) are *assumed* to be defined elsewhere exactly as in
your current workflow.  Variables and constants that were originally declared
in this file remain top‑level and unmodified.

The code is organized as:

1. Imports & global constants
2. Generic helpers
3. Motif‑table builder
4. Main plotting routine (`create_plot`)
5. Worker wrapper (`_save_plot`)
6. CLI / script entry‑point
"""
# ───────────────────────────── Imports ────────────────────────────── #
import importlib
import io
import os
from itertools import product
from multiprocessing import Pool, cpu_count
from typing import List, Sequence, Tuple, Optional

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from PIL import Image as PILImage           # disambiguate from IPython.display.Image
from plotly.subplots import make_subplots
from scipy.ndimage import gaussian_filter1d

# still imported (for compatibility with earlier notebooks)
from scipy.signal import find_peaks
from IPython.display import Image, display

# reload local dev library for hot‑reload loops in notebooks
import nanotools
importlib.reload(nanotools)

# ──────────────────────────── Constants ───────────────────────────── #
# (unchanged ‑ feel free to edit values directly)
# Colours / opacity for the plot elements (hex strings; read and m6A share a colour)
READ_PIXEL_CLR    = "#b12537"
M6A_DOT_CLR       = "#b12537"
NUC_LINE_CLR      = "#35b779"
NUC_HISTO_OPACITY = 0.60

# Plotly template name (installed as the global default at the end of this section)
TEMPLATE          = "plotly_white"
PNG_DPI           = 300   # presumably the export resolution – not used in the code shown here
TABLE_CELL_HEIGHT = 22    # height of one table cell row (create_motif_table)
TABLE_HEADER_H    = 28    # added once to the table figure height for the header row
TABLE_MARGIN_PX   = 12    # white gap between stacked images (stitch_vertical)

# Figure size (new top‑level “config” values)
FIG_WIDTH  = 700
FIG_HEIGHT = 500

# Apply global Plotly template immediately
pio.templates.default = TEMPLATE

# ─────────────────────────── Global BED load ───────────────────────── #
# Read once when this cell runs and shared by the table builders below.
# Only the five named columns are kept, so the whitespace-separated file must
# carry a header row with these names.  Rows are filtered by their "type"
# value at the point of use (create_gene_table keeps types containing "gene_q").
tss_tes_df = pd.read_csv(
    "/Data1/reference/tss_tes_rex_combined_v27_WS235.bed",
    sep=r"\s+",
    usecols=["chromosome", "start", "end", "strand", "type"]
)

# ───────────────────── Sorting config ───────────────────── #
# NOTE(review): the code that dispatches on SORT_METHOD is not in this part of
# the file; the method numbers below are taken from the existing inline note.
SORT_METHOD    = 0        # 1=binned‐HC, 3=PCA, 4=DTW (else=ACF)
BIN_COUNT      = 2000        # number of bins for methods 1, 3, 4 (default n_bins of _build_binned_matrix)
HC_LINKAGE     = "ward"    # for hierarchical clustering (_order_by_binned, _order_by_dtw)
PCA_COMPONENTS = 1         # PCs to extract (_order_by_pca sorts on the first one)
DTW_DISTANCE   = "euclidean"  # placeholder, uses fastdtw under the hood; not referenced by _order_by_dtw
# Default lag window for _order_reads_by_ac (lag_lo / lag_hi)
AC_SORT_LAG_MIN = 155
AC_SORT_LAG_MAX = 195

# ────────────────────────────────────────────────────────── #

### SORTING HELPERS

import scipy.cluster.hierarchy as sch
from sklearn.decomposition import PCA
from fastdtw import fastdtw
from scipy.spatial.distance import squareform

def _order_reads_by_ac(
    read_df: pd.DataFrame,
    plot_window: int,
    lag_lo: int = AC_SORT_LAG_MIN,
    lag_hi: int = AC_SORT_LAG_MAX,
    debug: bool = False
) -> pd.DataFrame:
    """
    Rank reads by their peak autocorrelation between `lag_lo` and `lag_hi`
    and return a *deduplicated* DataFrame that carries:

        read_id | smallest_positive_nuc_midpoint | greatest_negative_nuc_midpoint
                | ac_mean | read_count

    `ac_mean` keeps its historical name but holds the NaN-ignoring *maximum*
    of the autocorrelation inside the lag window (sign kept).  Reads without
    a usable value (empty or all-NaN window) get `-inf` and therefore sort
    last.  `read_count` is 1‑based, descending by `ac_mean`.

    Parameters
    ----------
    read_df : rows must expose `read_id`, `rel_pos`, `mod_qual_bin` and the
        two nucleosome-midpoint columns.
    plot_window : half-width of the position vector built for each read.
    lag_lo, lag_hi : lag window bounds (swapped if given in reverse order;
        `lag_hi` is clamped to `2 * plot_window`).
    debug : print a one-line summary of the ordering.
    """
    # ensure sensible limits
    lag_lo, lag_hi = sorted((lag_lo, lag_hi))
    lag_hi = min(lag_hi, 2 * plot_window)          # never exceed vector length

    means = {}
    vec_len = 2 * plot_window + 1                  # index 0 == –plot_window

    for r in read_df.itertuples():
        rel_pos  = np.asarray(r.rel_pos, dtype=int)
        mod_bin  = np.asarray(r.mod_qual_bin, dtype=float)   # 0/1; NaNs will stay NaN

        # build full vector (NaN‑initialised), dropping positions outside the window
        vec = np.full(vec_len, np.nan, dtype=float)
        idx = rel_pos + plot_window
        keep = (idx >= 0) & (idx < vec_len)
        vec[idx[keep]] = mod_bin[keep]

        ac = nanotools._autocorr_vec(vec, lag_hi)            # already NaN‑aware
        win = np.asarray(ac[lag_lo : lag_hi + 1], dtype=float)

        # After clamping, lag_hi can fall below lag_lo (small plot_window),
        # leaving an empty window on which np.nanmax raises ValueError; an
        # all-NaN window only warns.  Both cases mean "no score" → NaN, which
        # is mapped to -inf below.
        if win.size == 0 or np.isnan(win).all():
            means[r.read_id] = np.nan
        else:
            means[r.read_id] = np.nanmax(win)

    # build ordering table
    nodups = (
        read_df[["read_id",
                  "smallest_positive_nuc_midpoint",
                  "greatest_negative_nuc_midpoint"]]
        .drop_duplicates("read_id")
        .assign(ac_mean=lambda d: d["read_id"].map(means).fillna(-np.inf))
        .sort_values("ac_mean", ascending=False)
        .reset_index(drop=True)
    )
    nodups["read_count"] = np.arange(1, len(nodups) + 1, dtype=int)

    if debug:
        print(f"[acf‑sort] lag window {lag_lo}–{lag_hi} bp | "
              f"{len(nodups)} reads | top‑5 means:",
              nodups["ac_mean"].head().round(3).tolist())

    return nodups

def _build_binned_matrix(read_df, plot_window, n_bins=BIN_COUNT):
    """
    Bin the modified positions of every read over [-plot_window, plot_window].

    Only positions whose `mod_qual_bin` equals 1 are counted; each bin count
    is divided by the bin width, i.e. expressed as a density per bp.

    Returns:
      M        – reads×bins matrix (one row per row of `read_df`)
      read_ids – list of read_id in the same order as the rows of M
    """
    edges = np.linspace(-plot_window, plot_window, n_bins + 1)
    bin_widths = np.diff(edges)

    rows, read_ids = [], []
    for rec in read_df.itertuples():
        is_hit = np.asarray(rec.mod_qual_bin) == 1
        hit_positions = np.asarray(rec.rel_pos)[is_hit]
        hist = np.histogram(hit_positions, bins=edges)[0]
        rows.append(hist / bin_widths)
        read_ids.append(rec.read_id)

    return np.vstack(rows), read_ids

def _order_by_binned(read_df, plot_window, debug=False):
    """Order reads by the dendrogram leaf order of a hierarchical clustering
    (HC_LINKAGE, euclidean) of their binned profiles."""
    matrix, row_ids = _build_binned_matrix(read_df, plot_window)
    tree = sch.linkage(matrix, method=HC_LINKAGE, metric="euclidean")
    leaf_order = sch.dendrogram(tree, no_plot=True)["leaves"]
    ordered_ids = [row_ids[leaf] for leaf in leaf_order]
    return _make_nodups(read_df, ordered_ids, debug=debug)

def _order_by_pca(read_df, plot_window, debug=False):
    """Order reads by ascending score on the first principal component of
    their binned profiles."""
    matrix, row_ids = _build_binned_matrix(read_df, plot_window)
    scores = PCA(n_components=PCA_COMPONENTS).fit_transform(matrix)
    first_pc = scores[:, 0]
    # ascending order; reverse the argsort result for a descending ranking
    ordered_ids = [row_ids[k] for k in np.argsort(first_pc)]
    return _make_nodups(read_df, ordered_ids, debug=debug)

def _order_by_dtw(read_df, plot_window, debug=False):
    """Order reads by the dendrogram leaf order of a hierarchical clustering
    (HC_LINKAGE) built on pairwise fastdtw distances between binned profiles."""
    matrix, row_ids = _build_binned_matrix(read_df, plot_window)
    n_reads = matrix.shape[0]

    # symmetric distance matrix; only the lower triangle is computed
    dist = np.zeros((n_reads, n_reads))
    for a in range(1, n_reads):
        for b in range(a):
            d_ab, _path = fastdtw(matrix[a], matrix[b])
            dist[a, b] = d_ab
            dist[b, a] = d_ab

    # linkage expects the condensed (vector) form of the distance matrix
    tree = sch.linkage(squareform(dist), method=HC_LINKAGE)
    leaf_order = sch.dendrogram(tree, no_plot=True)["leaves"]
    ordered_ids = [row_ids[leaf] for leaf in leaf_order]
    return _make_nodups(read_df, ordered_ids, debug=debug)

def _make_nodups(read_df, ordered_ids, debug=False):
    """
    Build the one-row-per-read ordering table
    (read_id, smallest_positive_nuc_midpoint, greatest_negative_nuc_midpoint,
    read_count) from an ordered list of read IDs.

    `ordered_ids` may name the same read more than once (it carries one entry
    per row of `read_df`); only the first occurrence of each ID is honoured,
    so the result never contains duplicate reads.  `read_count` is the
    1-based rank in that order.

    Raises KeyError if an ID in `ordered_ids` is absent from `read_df`.
    """
    tmp = (
        read_df[["read_id",
                 "smallest_positive_nuc_midpoint",
                 "greatest_negative_nuc_midpoint"]]
        .drop_duplicates("read_id")
    )
    # Repeated labels passed to .loc would emit repeated rows and inflate
    # read_count; dict.fromkeys de-duplicates while preserving first-seen order.
    unique_ids = list(dict.fromkeys(ordered_ids))
    # keep only those reads, in that order
    tmp = tmp.set_index("read_id").loc[unique_ids].reset_index()
    tmp["read_count"] = np.arange(1, len(tmp)+1)
    if debug:
        print(f"[sort] {len(tmp)} reads ordered")
    return tmp



# ─────────────────────────── Helpers ──────────────────────────────── #
def run_in_pool(
    func,
    iterable: Sequence,
    n_workers: int,
    desc: str,
    unit: str = "item"
):
    """
    Map `func` over `iterable` in a `multiprocessing.Pool` of `n_workers`
    processes while showing a tqdm progress bar.

    Results are collected with `imap_unordered`, so the returned list is in
    completion order, not input order.  `iterable` must support `len()`.
    """
    from tqdm.auto import tqdm

    with Pool(processes=n_workers) as pool:
        finished = pool.imap_unordered(func, iterable)
        progress = tqdm(finished, total=len(iterable), desc=desc, unit=unit)
        return list(progress)


def stitch_vertical(png_top: bytes, png_bottom: bytes) -> PILImage.Image:
    """
    Stack two PNG byte-streams on a white RGBA canvas, `png_top` above
    `png_bottom`, separated by TABLE_MARGIN_PX pixels.  Both images are
    pasted at the left edge; the canvas is as wide as the wider image.
    """
    upper = PILImage.open(io.BytesIO(png_top)).convert("RGBA")
    lower = PILImage.open(io.BytesIO(png_bottom)).convert("RGBA")

    lower_y = upper.height + TABLE_MARGIN_PX
    canvas_size = (max(upper.width, lower.width), lower_y + lower.height)

    canvas = PILImage.new("RGBA", canvas_size, "white")
    canvas.paste(upper, (0, 0))
    canvas.paste(lower, (0, lower_y))
    return canvas

# ─────────────────── Chip‑rank lookup ─────────────────── #
# loads a small table mapping each motif “type” to its chip‑rank score
# (you may need to adjust the path to your actual .bed or .tsv file)
chiprank_df = pd.read_csv(
    "/Data1/reference/rex_chiprank.bed",
    sep=r"\s+",
    usecols=["type", "chip_rank"],
)

# the lookup keys carry the "MOTIFS_" prefix (the source table stores bare names)
chiprank_df["type"] = "MOTIFS_" + chiprank_df["type"].astype(str)

# lookup: prefixed motif type → chip_rank scaled by 100 and rounded to 0 decimals
chip_rank_lookup = dict(zip(
    chiprank_df["type"],
    (round(float(rk) * 100, 0) for rk in chiprank_df["chip_rank"]),
))


# ────────────────────── 1) Motif‐DF builder ──────────────────────── #
def _build_motif_df(read_df: pd.DataFrame, plot_window: int, debug: bool=False) -> pd.DataFrame:
    """
    Returns a DataFrame with one row per unique motif
    (rel_start, motif_id, ln_p, strand, seq, num), filtered to ±plot_window.
    'num' is the 1-based row index after sorting by rel_start.

    `read_df` must provide the columns `motif_rel_start` (sequence of ints)
    and `motif_attributes` (sequence of 4-element entries, positionally
    aligned with the starts).  A cell may be a tuple, list or array, or a
    missing value (None / NaN), which counts as "no motifs".  An empty
    `read_df`, or one without any motif inside the window, yields an empty
    frame with the six output columns.
    """
    out_columns = ["rel_start", "motif_id", "ln_p", "strand", "seq", "num"]
    if debug: print(f"[motifs] building motif_df from {len(read_df)} reads…")

    def _as_list(value):
        # Sequences are handled before pd.isna(): on a list/array pd.isna
        # returns an element-wise array whose truth value is ambiguous.
        if isinstance(value, (list, tuple, np.ndarray)):
            return list(value)
        return [] if pd.isna(value) else list(value)

    # Collect (rel_start, attributes) pairs inside the window across all reads.
    # Guarding on len() avoids the column-count mismatch the row-wise apply
    # produced for an empty frame.
    pairs = []
    if len(read_df):
        for raw_starts, raw_attrs in zip(read_df["motif_rel_start"],
                                         read_df["motif_attributes"]):
            starts = _as_list(raw_starts)
            attrs  = _as_list(raw_attrs)
            pairs.extend(
                (p, tuple(attrs[i]))          # tuple → hashable for drop_duplicates
                for i, p in enumerate(starts)
                if -plot_window <= p <= plot_window
            )

    if not pairs:
        if debug: print("[motifs] no motifs after explode → returning empty df")
        return pd.DataFrame(columns=out_columns)

    motif_df = pd.DataFrame(pairs, columns=["rel_start", "attributes"])
    motif_df = motif_df.drop_duplicates().sort_values("rel_start").reset_index(drop=True)
    motif_df["num"] = motif_df.index + 1

    # unpack attributes tuple into columns
    motif_df[["motif_id","ln_p","strand","seq"]] = pd.DataFrame(
        motif_df["attributes"].tolist(),
        index=motif_df.index
    )
    motif_df = motif_df[out_columns]

    if debug:
        print(f"[motifs] found {len(motif_df)} unique motifs:")
        print(motif_df.head())
    return motif_df


# ────────────────────── 2) Motif‐table builder ────────────────────── #
def create_motif_table(read_df: pd.DataFrame, plot_window: int, debug: bool=False) -> Optional[go.Figure]:
    """
    Render the unique motifs found within ±plot_window as a Plotly table
    (columns: #, Motif, ln(p-val), Strand, Sequence).

    Returns None when `_build_motif_df` finds no motif; otherwise a Figure
    whose height grows with the number of motif rows.
    """
    if debug:
        print(f"[motif_table] starting; incoming read_df rows = {len(read_df)}")
        # Show a snippet of the raw motif columns:
        print("    → motif_rel_start sample:", read_df["motif_rel_start"].head().tolist())
        print("    → motif_attributes sample:", read_df["motif_attributes"].head().tolist())

    motif_df = _build_motif_df(read_df, plot_window, debug)
    if motif_df.empty:
        if debug:
            print(f"[motif_table] _build_motif_df returned empty (no motifs within ±{plot_window}); returning None")
        return None

    if debug:
        print(f"[motif_table] actually found {len(motif_df)} motif rows; here are the first few:\n{motif_df.head()}")

    shown_cols = ["num", "motif_id", "ln_p", "strand", "seq"]

    # column width = longest rendered cell value × px-per-char + padding
    px_per_char, pad = 7, 24
    widths = [
        max(len(str(value)) for value in motif_df[col]) * px_per_char + pad
        for col in shown_cols
    ]

    header_spec = dict(
        values=["#", "Motif", "ln(p‑val)", "Strand", "Sequence"],
        fill_color="white", font=dict(color="black",size=14),
        line_color="lightgrey", align="center"
    )
    cells_spec = dict(
        values=[motif_df[col] for col in shown_cols],
        fill_color="white", font=dict(color="black",size=12),
        line_color="lightgrey", align="center",
        height=TABLE_CELL_HEIGHT
    )
    table = go.Table(columnwidth=widths, header=header_spec, cells=cells_spec)

    fig = go.Figure(data=[table])
    fig.update_layout(
        margin=dict(l=0,r=0,t=0,b=0),
        paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(0,0,0,0)",
        height=TABLE_HEADER_H + len(motif_df)*TABLE_CELL_HEIGHT
    )
    if debug: print(f"[motif_table] rendered {len(motif_df)} rows")
    return fig

# ─────────────────── 2.5) Gene‐table builder (NEW) ─────────────────── #
# ─────────────────── 2.5) Gene‐table builder (UPDATED) ─────────────────── #
def create_gene_table(read_df: pd.DataFrame, debug: bool=False) -> Optional[go.Figure]:
    """
    Build a Plotly table listing the genes that flank the plot's 0 position.

    The reference position (relative coordinate 0) is ``bed_start`` of the
    first row of ``read_df``; the chromosome is read from that same row.
    Candidate genes are taken from the module-level ``tss_tes_df`` and are
    restricted to rows whose ``type`` contains "gene_q".  At most two rows are
    shown: the gene whose end lies closest before position 0 and the gene
    whose start lies closest after it.  A gene that spans or touches
    position 0 belongs to neither group.

    Columns:
      ["chromosome", "abs_start", "abs_end", "strand", "rel_start", "rel_end", "exp_quart"]

    Returns None when ``read_df`` is empty, when the chromosome has no
    "gene_q" rows, or when no gene lies strictly on either side.
    """
    if read_df.empty:
        return None

    # Every read is expected to share one chromosome / bed_start, so the
    # first row is used as the reference.
    first_read = read_df.iloc[0]
    chrom = first_read["chrom"]
    ref_pos = int(first_read["bed_start"])

    # Candidate genes: same chromosome, expression-quartile annotated rows.
    on_chrom = tss_tes_df["chromosome"] == chrom
    is_gene_q = tss_tes_df["type"].str.contains("gene_q")
    genes = tss_tes_df.loc[on_chrom & is_gene_q].copy()
    if genes.empty:
        if debug:
            print(f"[gene_table] no 'gene_q' entries on {chrom}")
        return None

    # Absolute coordinates, coordinates relative to ref_pos, and the
    # quartile label (e.g. "q4" out of "gene_q4").
    genes["abs_start"] = genes["start"]
    genes["abs_end"] = genes["end"]
    genes["rel_start"] = genes["abs_start"] - ref_pos
    genes["rel_end"] = genes["abs_end"] - ref_pos
    genes["exp_quart"] = genes["type"].str.extract(r"_(q\d+)")[0]

    # Nearest neighbour on each side of position 0.
    picked = []
    before = genes.loc[genes["rel_end"] < 0]
    if not before.empty:
        picked.append(before.loc[before["rel_end"].idxmax()])
    after = genes.loc[genes["rel_start"] > 0]
    if not after.empty:
        picked.append(after.loc[after["rel_start"].idxmin()])

    if not picked:
        if debug:
            print(f"[gene_table] no genes upstream or downstream on {chrom}")
        return None

    shown_cols = ["chromosome", "abs_start", "abs_end", "strand",
                  "rel_start", "rel_end", "exp_quart"]
    gene_df = pd.DataFrame(picked)[shown_cols].reset_index(drop=True)

    # Column width = longest rendered cell (in characters) * px + padding.
    px_per_char, pad = 7, 24

    def _col_width(values):
        longest = max(len(str(v)) for v in values)
        return longest * px_per_char + pad

    header_spec = dict(
        values=["Chromosome", "Abs Start", "Abs End", "Strand",
                "Rel Start", "Rel End", "Exp Quart"],
        fill_color="white",
        font=dict(color="black", size=14),
        line_color="lightgrey",
        align="center"
    )
    cell_spec = dict(
        values=[gene_df[name] for name in shown_cols],
        fill_color="white",
        font=dict(color="black", size=12),
        line_color="lightgrey",
        align="center",
        height=TABLE_CELL_HEIGHT
    )
    table = go.Table(
        columnwidth=[_col_width(gene_df[name]) for name in shown_cols],
        header=header_spec,
        cells=cell_spec
    )

    fig = go.Figure(data=[table])
    fig.update_layout(
        margin=dict(l=0, r=0, t=0, b=0),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        height=TABLE_HEADER_H + len(gene_df) * TABLE_CELL_HEIGHT
    )
    if debug:
        print(f"[gene_table] rendered {len(gene_df)} row(s):\n{gene_df}")
    return fig


# ─────────────────────── Core plotting logic ───────────────────────── #
def _filter_reads(
    df: pd.DataFrame,
    condition: str,
    chr_type: str,
    data_type: str,
    plot_window: int,
    require_full_span: bool,
    debug: bool = False
) -> pd.DataFrame:
    """
    Select the reads to plot for one (condition, chr_type, type) combination.

    Filters applied, in order:
      1. metadata match on the ``condition``, ``chr_type`` and ``type`` columns;
      2. if ``require_full_span``: the read must start at or before
         -0.3 * plot_window and end at or after +0.3 * plot_window;
      3. ``mod_qual_bin`` must contain at least one 1.

    Returns a filtered copy of ``df``; the input is not modified.
    """
    matches_meta = (
        df["condition"].eq(condition)
        & df["chr_type"].eq(chr_type)
        & df["type"].eq(data_type)
    )
    selected = df.loc[matches_meta].copy()

    if require_full_span:
        # Minimum distance the read must reach on both sides of position 0.
        reach = round(plot_window * 0.3, 0)
        covers_left = selected["rel_read_start"] <= -reach
        covers_right = selected["rel_read_end"] >= reach
        selected = selected.loc[covers_left & covers_right]

    # Drop reads without a single positive call in mod_qual_bin.
    has_hit = selected["mod_qual_bin"].apply(lambda calls: 1 in calls)
    selected = selected.loc[has_hit]

    if debug:
        print(f"[filter] remaining rows: {len(selected)}")
    return selected


def _long_format(read_df: pd.DataFrame, plot_window: int) -> pd.DataFrame:
    """
    Flatten per-read list columns into one row per base inside the window.

    For every read, each entry of ``rel_pos`` that lies within
    [-plot_window, plot_window] yields a row carrying the read's ``read_id``,
    that position, the ``mod_qual_bin`` value at the same list index, and the
    read's ``read_count``.
    """
    flat_rows = [
        (read.read_id, pos, read.mod_qual_bin[i], read.read_count)
        for read in read_df.itertuples()
        for i, pos in enumerate(read.rel_pos)
        if -plot_window <= pos <= plot_window
    ]
    out_columns = ["read_id", "rel_pos", "mod_qual_bin", "read_count"]
    return pd.DataFrame(flat_rows, columns=out_columns)


def _add_read_occupancy(fig, read_df, plot_window):
    """
    Draw one thin horizontal line per row of ``read_df`` and count coverage.

    Each line runs from the read's start to its end, clipped to
    [-plot_window, plot_window], at y = ``read_count``; all lines go into a
    single Scatter trace on row 1.

    Returns an int array of length 2 * plot_window in which index i holds the
    number of rows covering relative position (i - plot_window).
    """
    span = 2 * plot_window
    coverage = np.zeros(span, dtype=int)
    xs, ys = [], []

    for read in read_df.itertuples():
        left = max(-plot_window, read.rel_read_start)
        right = min(plot_window, read.rel_read_end)

        # Shift into array coordinates; the end position is inclusive, and
        # anything falling outside [0, span) is clipped away.
        first = max(int(left + plot_window), 0)
        stop = min(int(right + plot_window) + 1, span)
        if first < stop:
            coverage[first:stop] += 1

        # The None entry breaks the polyline between consecutive reads.
        xs.extend([left, right, None])
        ys.extend([read.read_count] * 3)

    occupancy_trace = go.Scatter(
        x=xs, y=ys, mode="lines",
        line=dict(color=READ_PIXEL_CLR, width=0.1),
        showlegend=False
    )
    fig.add_trace(occupancy_trace, row=1, col=1)
    return coverage


def _add_nuc_segments(fig, down_df, plot_window):
    """
    Overlay nucleosome footprints on the read panel (row 1).

    For every row of ``down_df``, each midpoint in ``nucs_list`` that lies
    within [-plot_window, plot_window] is drawn as a 147 bp wide horizontal
    segment centred on that midpoint, at y = ``read_count``.

    Returns:
        list: the in-window midpoints that were drawn.  An empty list is
        returned when ``down_df`` is empty, so the return type is the same on
        every path (the empty case used to return the tuple ``([], None)``,
        which is truthy and is not a list of midpoints).
    """
    if down_df.empty:
        return []

    nuc_x, nuc_y, mids = [], [], []
    half = 147 / 2  # half-width of the drawn segment, in bp
    for r in down_df.itertuples():
        for n in r.nucs_list:
            if -plot_window <= n <= plot_window:
                mids.append(n)
                # None separates consecutive segments within the one trace.
                nuc_x += [n - half, n + half, None]
                nuc_y += [r.read_count] * 3

    fig.add_trace(go.Scatter(x=nuc_x, y=nuc_y, mode="lines",
                             line=dict(color=NUC_LINE_CLR, width=2),
                             opacity=0.75, showlegend=False),
                  row=1, col=1)
    return mids


def _add_m6a_layers(
    fig,
    long_df: pd.DataFrame,
    show_nuc_subplot: bool,
    smoothing_type: str,
    smoothing_window: int,
    gaussian_sigma: float
):
    """
    Add the per-read m6A markers (row 1) and the aggregate %m6A line (row 2).

    ``long_df`` is the one-row-per-base table produced by ``_long_format``.
    ``smoothing_type`` selects how the per-position ratio is smoothed:
    "none", "moving" (rolling sum of hits over rolling sum of counts, window
    of ``smoothing_window`` rows, centred) or "gaussian" (``gaussian_sigma``).

    Raises:
        ValueError: for any other ``smoothing_type``.  The marker trace has
        already been added to ``fig`` by the time this is raised.
    """
    # The %m6A panel is row 2 in both the 2-row and the 3-row layout.
    m6a_row = 2

    # One square marker per positive call.
    methylated = long_df.loc[long_df["mod_qual_bin"] == 1]
    marker_trace = go.Scatter(
        x=methylated["rel_pos"],
        y=methylated["read_count"],
        mode="markers",
        marker=dict(symbol="square", size=2.5, color=M6A_DOT_CLR),
        showlegend=False
    )
    fig.add_trace(marker_trace, row=1, col=1)

    # Fraction of positive calls at every relative position.
    per_pos = (
        long_df.groupby("rel_pos")["mod_qual_bin"]
        .agg(["sum", "count"])
        .reset_index()
    )
    per_pos["ratio"] = per_pos["sum"] / per_pos["count"]

    if smoothing_type == "none":
        smoothed = per_pos["ratio"]
    elif smoothing_type == "moving":
        hits_in_window = per_pos["sum"].rolling(smoothing_window, center=True).sum()
        calls_in_window = per_pos["count"].rolling(smoothing_window, center=True).sum()
        smoothed = hits_in_window / calls_in_window
    elif smoothing_type == "gaussian":
        smoothed = gaussian_filter1d(per_pos["ratio"], sigma=gaussian_sigma)
    else:
        raise ValueError("smoothing_type must be 'none'|'moving'|'gaussian'")

    # Rolling windows leave NaN at both edges; plot only the defined part.
    defined = ~np.isnan(smoothed)
    line_trace = go.Scatter(
        x=per_pos["rel_pos"][defined],
        y=smoothed[defined],
        mode="lines",
        line=dict(color=M6A_DOT_CLR, width=2),
        showlegend=False
    )
    fig.add_trace(line_trace, row=m6a_row, col=1)
    fig.update_yaxes(tickformat=".0%", row=m6a_row, col=1)


def _add_nuc_hist(fig, mids, read_counts_vec, plot_window):
    """
    Add the nucleosome-midpoint histogram and a smoothed density line (row 3).

    Args:
        fig: subplot figure whose row 3 was created with a secondary y-axis
            (both traces are added with an explicit ``secondary_y``).
        mids: nucleosome midpoints in relative coordinates; nothing is drawn
            when this is empty.
        read_counts_vec: per-position read coverage of length
            2 * plot_window, as returned by ``_add_read_occupancy``.
        plot_window: half-width of the plotted window.

    The density is midpoints-per-position divided by coverage, Gaussian
    smoothed with sigma=10.  Positions with zero coverage get density 0.
    """
    if not mids:
        return
    genome_size = 2 * plot_window
    bins = int(round(genome_size / 10) + 1)

    fig.add_trace(go.Histogram(x=mids, nbinsx=bins,
                               marker=dict(color=NUC_LINE_CLR,
                                           opacity=NUC_HISTO_OPACITY),
                               showlegend=False),
                  row=3, col=1, secondary_y=False)

    # Midpoint count per position (array index = position + plot_window).
    nuc_vec = np.zeros(genome_size)
    for m in mids:
        idx = int(m + plot_window)
        if 0 <= idx < genome_size:
            nuc_vec[idx] += 1

    # `where=` without `out=` leaves the masked-out entries uninitialised, so
    # an explicit zero-filled output is supplied: zero coverage -> density 0
    # instead of arbitrary memory contents leaking into the smoothing.
    coverage = np.asarray(read_counts_vec)
    density = np.divide(nuc_vec, coverage,
                        out=np.zeros_like(nuc_vec),
                        where=coverage != 0)
    smoothed = gaussian_filter1d(density, sigma=10)
    xs = np.arange(-plot_window, plot_window)

    fig.add_trace(go.Scatter(x=xs, y=smoothed, mode="lines",
                             line=dict(color=NUC_LINE_CLR, width=1),
                             name="Smoothed nuc density"),
                  row=3, col=1, secondary_y=True)


# ───────────────── 3) Annotation upgrader ──────────────────── #
def _add_motif_annotations(fig, read_df, plot_window, debug: bool=False):
    """
    Mark every motif with a dashed vertical line and a numeric label.

    Motifs come from ``_build_motif_df(read_df, plot_window, debug)``.  Each
    one gets a full-height dashed grey line at its ``rel_start`` and a text
    label showing its ``num``, placed at y=1.02 in paper coordinates.
    Nothing is added when the motif table is empty.
    """
    motifs = _build_motif_df(read_df, plot_window, debug)
    if motifs.empty:
        if debug:
            print("[annot] no motifs → skipping annotations")
        return

    label_y = 1.02
    if debug:
        print(f"[annot] adding annotations for {len(motifs)} motifs")

    for motif in motifs.itertuples():
        x_at = motif.rel_start
        label = motif.num

        # Vertical guide: x in data coordinates, y spanning the whole paper.
        fig.add_shape(
            type="line",
            x0=x_at, x1=x_at,
            y0=0, y1=1,
            xref="x", yref="paper",
            line=dict(color="grey", width=1, dash="dash")
        )
        # Text label carrying the motif's `num`.
        fig.add_annotation(
            x=x_at,
            y=label_y,
            yref="paper",
            text=str(label),
            showarrow=False,
            font=dict(size=10),
            bgcolor="rgba(255,255,255,0.5)"
        )
        if debug:
            print(f"[annot] motif#{label} at x={x_at}")



# ───────────────────────── create_plot (public) ─────────────────────── #
def create_plot(
    plot_df: pd.DataFrame,
    group_df: pd.DataFrame,
    condition: str,
    chr_type: str,
    data_type: str,
    plot_window: int,
    show_nuc_subplot: bool = True,
    plot_motifs: bool = True,
    require_full_span: bool = False,
    smoothing_type: str = "moving",
    smoothing_window: int = 50,
    gaussian_sigma: float = 10,
    debug: bool = False
):
    """
    Build the single-condition read figure.

    Layout: row 1 = reads with m6A markers (plus nucleosome segments when
    ``show_nuc_subplot``), row 2 = %m6A line, optional row 3 = nucleosome
    midpoint histogram with a smoothed density on a secondary y-axis.

    Args:
        plot_df: read table, filtered here with ``_filter_reads``.
        group_df: per-read table providing ``nucs_list``.
        condition, chr_type, data_type: metadata values selecting the reads.
        plot_window: half-width of the plotted window around position 0.
        show_nuc_subplot: draw nucleosome segments and the third row.
        plot_motifs: draw motif guide lines and labels.
        require_full_span: forwarded to ``_filter_reads``.
        smoothing_type, smoothing_window, gaussian_sigma: forwarded to
            ``_add_m6a_layers``.
        debug: print progress information.

    Returns:
        The assembled figure.  When no read survives filtering, an empty
        2-row subplot figure is returned instead.

    Read ordering is chosen by the module-level ``SORT_METHOD``:
    0 = order of first appearance, 1 = binned, 3 = PCA, 4 = DTW, anything
    else = ``_order_reads_by_ac``.
    """

    # 1) Filter reads
    read_df = _filter_reads(
        plot_df, condition, chr_type, data_type,
        plot_window, require_full_span, debug
    )
    if read_df.empty:
        if debug:
            print("[create_plot] no data – returning empty fig")
        return make_subplots(rows=2, cols=1)

    # 2) Order reads by the chosen method.  This is a single if/elif chain:
    #    with a separate `if` for method 0, its result was always overwritten
    #    by the trailing `else` branch.
    if SORT_METHOD == 0:
        nodups = (
            read_df[["read_id",
                     "smallest_positive_nuc_midpoint",
                     "greatest_negative_nuc_midpoint"]]
            .drop_duplicates("read_id")
        )
        nodups["read_count"] = np.arange(1, len(nodups) + 1, dtype=int)
    elif SORT_METHOD == 1:
        nodups = _order_by_binned(read_df, plot_window, debug)
    elif SORT_METHOD == 3:
        nodups = _order_by_pca(read_df, plot_window, debug)
    elif SORT_METHOD == 4:
        nodups = _order_by_dtw(read_df, plot_window, debug)
    else:
        nodups = _order_reads_by_ac(read_df, plot_window, debug=debug)

    # Graft the ordering back onto every row
    read_df = read_df.merge(
        nodups[["read_id", "read_count"]],
        on="read_id",
        how="left"
    )

    # 3) Nucleosome-midpoint rows for the ordered reads
    down_df = (
        group_df.loc[group_df["read_id"].isin(nodups["read_id"])]
                  .merge(nodups[["read_id", "read_count"]], on="read_id")
                  .dropna(subset=["nucs_list"])
    )

    # 4) Explode → long tidy format
    long_df = _long_format(read_df, plot_window)
    if debug:
        print(f"[create_plot] long_df rows: {len(long_df):,}")

    # 5) Figure scaffold
    n_rows, row_heights = ((3, [0.50, 0.25, 0.25])
                           if show_nuc_subplot else
                           (2, [0.33, 0.67]))
    specs = [[{"type": "xy"}] for _ in range(n_rows)]
    if show_nuc_subplot:
        # Row 3 receives traces with an explicit secondary_y argument, which
        # plotly only accepts on subplots created with secondary_y=True.
        specs[2] = [{"type": "xy", "secondary_y": True}]
    fig = make_subplots(rows=n_rows, cols=1,
                        shared_xaxes=True,
                        vertical_spacing=0.05,
                        specs=specs,
                        row_heights=row_heights)
    fig.update_xaxes(range=[-plot_window, plot_window])

    # 6) Layers
    read_counts_vec = _add_read_occupancy(fig, read_df, plot_window)
    mids = []
    if show_nuc_subplot and not down_df.empty:
        mids = _add_nuc_segments(fig, down_df, plot_window)
    _add_m6a_layers(fig, long_df, show_nuc_subplot,
                    smoothing_type, smoothing_window, gaussian_sigma)
    if show_nuc_subplot:
        _add_nuc_hist(fig, mids, read_counts_vec, plot_window)
    if plot_motifs:
        _add_motif_annotations(fig, read_df, plot_window, debug)

    # 7) Layout / labels
    fig.update_layout(
        template="simple_white",
        height=FIG_HEIGHT,
        width=FIG_WIDTH,
        title=f"Condition: {condition} | Type: {data_type}",
        font=dict(size=14),
        xaxis_title_font=dict(size=16),
        yaxis_title_font=dict(size=16)
    )

    # Shared axis styling: no grid, dark-grey axis line, outside ticks
    axis_style = dict(
        showgrid=False,
        showline=True,
        linecolor="darkgrey",
        linewidth=1,
        ticks="outside"
    )

    # Read track
    fig.update_yaxes(title_text="Read ID", row=1, col=1, **axis_style)
    fig.update_xaxes(row=1, col=1, **axis_style)

    # m6A track (always row 2), with percent ticks
    tgt_row = 2
    fig.update_yaxes(title_text="% m6A", tickformat=".0%",
                     row=tgt_row, col=1, **axis_style)
    fig.update_xaxes(title_text="Genomic position (bp)",
                     row=tgt_row, col=1, **axis_style)

    # Nucleosome subplot (if present) keeps plotly's default styling
    if show_nuc_subplot:
        fig.update_yaxes(title_text="Nuc count", row=3, col=1,
                         secondary_y=False)
        fig.update_yaxes(title_text="Nuc density", row=3, col=1,
                         secondary_y=True)
        fig.update_xaxes(title_text="Genomic position (bp)", row=3, col=1)

    return fig



# ────────────────────── Worker wrapper ───────────────────────── #
def _save_plot(args):
    """
    Worker: render one (condition, type) plot and write it as PNG and SVG.

    Args:
        args: ``(cond, typ)`` pair.

    Reads the module-level ``temp_bed_w`` (plot window),
    ``down_sampled_plot_df``, ``down_sampled_group_df``, ``debug_usage``,
    ``chip_rank_lookup``, ``TABLE_MARGIN_PX`` and ``PNG_DPI``.  The chromosome
    type is fixed to "X".

    The main figure is stacked vertically with the motif table and then the
    gene table (whichever of them could be built).  Files are written to
    ``images/<cond>/<plot_window>/`` as
    ``<rank>_<cond>_<typ>_b<window>_chr<chr_type>.{png,svg}``.

    Returns:
        Path of the written PNG, or None when no rows match the combination.
    """
    cond, typ = args
    chr_type    = "X"
    plot_window = temp_bed_w

    # ─── skip if no rows ─────────────────────────────────────────
    mask = (
        (down_sampled_plot_df["condition"] == cond) &
        (down_sampled_plot_df["chr_type"]  == chr_type) &
        (down_sampled_plot_df["type"]      == typ)
    )
    if not mask.any():
        if debug_usage:
            print(f"[Worker] skipping {cond}/{typ}: no read rows found")
        return None

    if debug_usage:
        print(f"[Worker] starting plot for condition={cond}, type={typ}")

    main_fig = create_plot(
        plot_df           = down_sampled_plot_df,
        group_df          = down_sampled_group_df,
        condition         = cond,
        chr_type          = chr_type,
        data_type         = typ,
        plot_window       = plot_window,
        show_nuc_subplot  = False,
        plot_motifs       = True,
        require_full_span = True,
        debug             = debug_usage,
        smoothing_type    = "moving",
        smoothing_window  = 50,
        gaussian_sigma    = 10
    )

    # ─── Tables, in stacking order (motif table, then gene table) ───
    read_subset = down_sampled_plot_df.loc[mask].copy()
    table_fig = create_motif_table(read_subset, plot_window)
    gene_table_fig = create_gene_table(read_subset)
    extra_figs = [f for f in (table_fig, gene_table_fig) if f is not None]

    # ─── Export + stitch PNG ─────────────────────────────────────
    main_png = main_fig.to_image(format="png", scale=2)
    if extra_figs:
        current_png = main_png
        last = len(extra_figs) - 1
        for pos, extra in enumerate(extra_figs):
            stitched = stitch_vertical(
                current_png, extra.to_image(format="png", scale=2)
            )
            if pos < last:
                # stitch_vertical takes PNG bytes, so the intermediate image
                # is re-encoded before the next table is appended.
                buf = io.BytesIO()
                stitched.save(buf, format="PNG")
                current_png = buf.getvalue()
    else:
        stitched = PILImage.open(io.BytesIO(main_png)).convert("RGBA")

    # ─── Build save paths ────────────────────────────────────────
    save_dir = os.path.join("images", cond, str(plot_window))
    os.makedirs(save_dir, exist_ok=True)
    rank_val = chip_rank_lookup.get(typ, None)
    if rank_val is None:
        rank_pref = "NA"
    else:
        rank_pref = str(int(round(rank_val)))

    base = f"{rank_pref}_{cond}_{typ}_b{plot_window}_chr{chr_type}"
    png_path = os.path.join(save_dir, base + ".png")
    svg_path = os.path.join(save_dir, base + ".svg")

    # ─── Write PNG ───────────────────────────────────────────────
    stitched.save(png_path, dpi=(PNG_DPI, PNG_DPI))

    # ─── Write SVG (main figure + translated table groups) ───────
    main_svg = main_fig.to_image(format="svg").decode()
    if extra_figs:
        # create_plot returns a bare placeholder figure (no layout height)
        # when filtering leaves no reads; without a fallback the offset
        # arithmetic below fails with None + int.
        # NOTE(review): 500 px is the documented default export height of
        # plotly.io.to_image - confirm for the installed plotly/kaleido.
        main_height = main_fig.layout.height
        if main_height is None:
            main_height = 500

        offset = main_height + TABLE_MARGIN_PX
        parts = ['<svg xmlns="http://www.w3.org/2000/svg">',
                 f'<g>{main_svg}</g>']
        for extra in extra_figs:
            extra_svg = extra.to_image(format="svg").decode()
            parts.append(
                f'<g transform="translate(0,{offset})">{extra_svg}</g>'
            )
            offset += extra.layout.height + TABLE_MARGIN_PX
        parts.append('</svg>')
        combined_svg = "".join(parts)
    else:
        combined_svg = main_svg

    # Explicit encoding: the SVG text may contain non-ASCII characters and
    # the platform default encoding is not guaranteed to be UTF-8.
    with open(svg_path, "w", encoding="utf-8") as fh:
        fh.write(combined_svg)

    if debug_usage:
        print(f"[Worker] saved → {png_path}")
    return png_path


# ───────────────────────────── Script main ─────────────────────────── #
if __name__ == "__main__":
    # Module-level settings read as globals by _save_plot.
    debug_usage = True
    num_workers = 50
    temp_bed_w  = 1000

    motifs_type_selected = [f"MOTIFS_{t}" for t in type_selected]

    test_combo = [[analysis_cond[0], "univ_nuc"]]
    all_combos = list(product(analysis_cond, motifs_type_selected))
    combos =  test_combo # quick test combo; switch to all_combos for a full run

    # Use helper for pool + progress bar
    saved_pngs = run_in_pool(
        func=_save_plot,
        iterable=combos,
        n_workers=num_workers,
        desc="Saving plots",
        unit="plot"
    )

    # Preview the result of a single-combo run.  _save_plot returns None when
    # it skips a combination, and Image(filename=None) would fail, so only a
    # real path is displayed.
    if len(combos) == 1:
        first_png = saved_pngs[0] if len(saved_pngs) > 0 else None
        if first_png is not None:
            display(Image(filename=first_png))
        elif debug_usage:
            print("[main] no plot was saved; nothing to display")


In [None]:
# ─────────────────────────── Dual‑track plotting ────────────────────────── #
from itertools import product, combinations
import plotly.graph_objects as go
import warnings


# Default per-condition palettes for create_dual_plot: index 0 is used for the
# first condition, index 1 for the second.  Override with the read_colors /
# m6a_dot_colors / m6a_line_colors kwargs.
DEFAULT_READ_COLORS     = ["#b12537", "#4974a5"]
DEFAULT_M6A_DOT_COLORS  = ["#b12537", "#4974a5"]
DEFAULT_M6A_LINE_COLORS = ["#b12537", "#4974a5"]

# ───────────────────── Sorting config ───────────────────── #
# SORT_METHOD selects how reads are ordered on the y-axis:
#   0 = order of first appearance (reads numbered 1..N, no clustering)
#   1 = binned‐HC, 3 = PCA, 4 = DTW
#   any other value (e.g. 2) = _order_reads_by_ac
SORT_METHOD    = 0        # 0=as‐is, 1=binned‐HC, 3=PCA, 4=DTW (else=ACF)
BIN_COUNT      = 200        # number of bins for methods 1, 3, 4
HC_LINKAGE     = "ward"    # for hierarchical clustering
PCA_COMPONENTS = 1         # PCs to extract
DTW_DISTANCE   = "euclidean"  # placeholder, uses fastdtw under the hood
AC_SORT_LAG_MIN = 155  # NOTE(review): presumably the lower lag bound for the ACF ordering — confirm in _order_reads_by_ac
AC_SORT_LAG_MAX = 195  # NOTE(review): presumably the upper lag bound for the ACF ordering — confirm in _order_reads_by_ac


# ──────────────────────── Occupancy helper (colour) ───────────────────── #
def _add_read_occupancy_multi(fig, read_df, plot_window, row, color, read_width):
    """
    Draw one horizontal line per row of ``read_df`` on subplot ``row``.

    Each line spans the read's start..end, clipped to
    [-plot_window, plot_window], at y = ``read_count``.  All lines share a
    single Scatter trace drawn in ``color`` with line width ``read_width``.
    """
    xs, ys = [], []
    for read in read_df.itertuples():
        left = max(-plot_window, read.rel_read_start)
        right = min(plot_window, read.rel_read_end)
        # The None entry breaks the polyline between consecutive reads.
        xs.extend([left, right, None])
        ys.extend([read.read_count] * 3)

    occupancy_trace = go.Scatter(
        x=xs,
        y=ys,
        mode="lines",
        line=dict(color=color, width=read_width),
        showlegend=False
    )
    fig.add_trace(occupancy_trace, row=row, col=1)

import math # Make sure this import is at the top with other imports if not already there.

def create_dual_plot(
    plot_df: pd.DataFrame,
    group_df: pd.DataFrame,
    condition,               # str | list[str]  (≤ 2)
    chr_type: str,
    data_type: str,
    plot_window: int,
    plot_motifs: bool = True,
    require_full_span: bool = False,
    smoothing_type: str = "moving",
    smoothing_window: int = 25,
    gaussian_sigma: float = 10,
    debug: bool = False,
    max_reads_per_condition: Optional[int] = None, # New parameter
    random_state: Optional[int] = 42,              # Added for subsampling reproducibility
    # colour overrides
    read_colors: list = None,
    m6a_dot_colors: list = None,
    m6a_line_colors: list = None,
):
    """
    Two‑track (or single‑track) read plot with a shared %m6A panel.

    Layout: one condition → rows [reads | %m6A]; two conditions →
    rows [reads A | %m6A (one line per condition) | reads B].

    Args:
        plot_df: read table, filtered per condition with ``_filter_reads``.
        group_df: accepted for signature parity with ``create_plot``; not
            used by this function.
        condition: one condition name, or a list/tuple of one or two names.
        chr_type, data_type: metadata values selecting the reads.
        plot_window: half-width of the plotted window around position 0.
        plot_motifs: draw motif guide lines and labels.
        require_full_span: forwarded to ``_filter_reads``.
        smoothing_type: "none", "moving" or "gaussian".  Any other value is
            treated as "gaussian" (kept for backward compatibility) and now
            triggers a warning.
        smoothing_window: row window for "moving".
        gaussian_sigma: sigma for "gaussian".
        debug: print progress information.
        max_reads_per_condition: if set, randomly keep at most this many
            distinct reads per condition.
        random_state: seed for that subsampling.
        read_colors, m6a_dot_colors, m6a_line_colors: per-condition colour
            lists; default to the module-level palettes.

    Returns:
        The figure, or None when no condition has any reads (a UserWarning
        is issued for every condition without reads).

    Raises:
        ValueError: if ``condition`` holds zero or more than two names.
    """
    # ─── 1. Standardise condition input ──────────────────────────────────
    conds = list(condition) if isinstance(condition, (list, tuple)) else [condition]
    if not (1 <= len(conds) <= 2):
        raise ValueError("condition must be a string or a list/tuple of length ≤ 2")
    single = (len(conds) == 1)

    if smoothing_type not in ("none", "moving", "gaussian"):
        warnings.warn(
            f"Unrecognised smoothing_type '{smoothing_type}'; "
            "falling back to 'gaussian'",
            UserWarning
        )

    # ─── 2. Filter each condition once; report missing ones ─────────────
    # The filtered frames are reused by the drawing loop below instead of
    # running _filter_reads a second time for every condition.
    filtered = [
        _filter_reads(
            plot_df, cond, chr_type, data_type,
            plot_window, require_full_span, debug
        )
        for cond in conds
    ]
    missing = [cond for cond, reads in zip(conds, filtered) if reads.empty]
    for cond in missing:
        warnings.warn(
            f"No reads found for condition '{cond}' and type '{data_type}'",
            UserWarning
        )
    # if *all* are missing, bail out
    if len(missing) == len(conds):
        warnings.warn(
            f"No reads to plot for conditions {conds} and type '{data_type}'; skipping plot.",
            UserWarning
        )
        return None

    # ─── 3. Colour palette ───────────────────────────────────────────────
    read_colors     = read_colors     or DEFAULT_READ_COLORS[:len(conds)]
    m6a_dot_colors  = m6a_dot_colors  or DEFAULT_M6A_DOT_COLORS[:len(conds)]
    m6a_line_colors = m6a_line_colors or DEFAULT_M6A_LINE_COLORS[:len(conds)]

    # ─── 4. Figure scaffold ─────────────────────────────────────────────
    if single:
        rows, row_heights = 2, [0.4, 0.6]            # read | m6A
        read_rows         = [1]
        m6a_row           = 2
    else:
        rows, row_heights = 3, [0.2, 0.4, 0.2]       # readA | m6A | readB
        read_rows         = [1, 3]
        m6a_row           = 2

    fig = make_subplots(
        rows=rows, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.02,
        row_heights=row_heights
    )
    fig.update_xaxes(range=[-plot_window, plot_window])

    # ─── 5. Per‑condition layers ─────────────────────────────────────────
    motif_collect = []
    for idx, (cond, df_reads) in enumerate(zip(conds, filtered)):
        if df_reads.empty:
            continue
        # Motifs are collected before subsampling, i.e. from all filtered reads.
        motif_collect.append(df_reads[["motif_rel_start", "motif_attributes"]])

        # 5‑A. Optional subsampling to at most max_reads_per_condition reads
        current_reads_count = df_reads["read_id"].nunique()
        if (max_reads_per_condition is not None
                and max_reads_per_condition < current_reads_count):
            read_ids_to_keep = df_reads["read_id"].drop_duplicates().sample(
                n=max_reads_per_condition, random_state=random_state
            )
            df_reads = df_reads[df_reads["read_id"].isin(read_ids_to_keep)].copy()
            if debug:
                print(f"[dual-plot] Condition '{cond}' subsampled to {len(read_ids_to_keep)} reads.")

        # 5‑B. Sort reads
        if SORT_METHOD == 0:
            nodups = (
                df_reads[["read_id",
                         "smallest_positive_nuc_midpoint",
                         "greatest_negative_nuc_midpoint"]]
                .drop_duplicates("read_id")
            )
            nodups["read_count"] = np.arange(1, len(nodups) + 1, dtype=int)
        elif SORT_METHOD == 1:
            nodups = _order_by_binned(df_reads, plot_window, debug)
        elif SORT_METHOD == 3:
            nodups = _order_by_pca(df_reads, plot_window, debug)
        elif SORT_METHOD == 4:
            nodups = _order_by_dtw(df_reads, plot_window, debug)
        else:
            nodups = _order_reads_by_ac(df_reads, plot_window, debug=debug)

        # merge sorted read_count back into the per-read DataFrame
        df_reads = df_reads.merge(
            nodups[["read_id", "read_count"]],
            on="read_id",
            how="left"
        )

        # 5‑C. Read occupancy lines
        _add_read_occupancy_multi(
            fig, df_reads, plot_window,
            row=read_rows[idx],
            color=read_colors[idx],
            read_width = 0.1
        )

        # 5‑D. m6A dots
        long_df = _long_format(df_reads, plot_window)
        hits = long_df[long_df["mod_qual_bin"] == 1]
        fig.add_trace(go.Scatter(
            x=hits["rel_pos"], y=hits["read_count"],
            mode="markers",
            marker=dict(symbol="square", size=2.5, color=m6a_dot_colors[idx]),
            showlegend=False),
            row=read_rows[idx], col=1
        )

        # 5‑E. Condition‑specific %m6A line (shared middle panel)
        agg = (long_df.groupby("rel_pos")["mod_qual_bin"]
                      .agg(["sum", "count"]).reset_index())
        agg["ratio"] = agg["sum"] / agg["count"]
        if smoothing_type == "none":
            y = agg["ratio"]
        elif smoothing_type == "moving":
            y = (agg["sum"].rolling(smoothing_window, center=True).sum() /
                 agg["count"].rolling(smoothing_window, center=True).sum())
        else:   # gaussian (also the fallback for unrecognised values)
            y = gaussian_filter1d(agg["ratio"], sigma=gaussian_sigma)

        # Rolling windows leave NaN at both edges; plot only the defined part.
        mask = ~np.isnan(y)
        fig.add_trace(go.Scatter(
            x=agg["rel_pos"][mask], y=y[mask],
            mode="lines",
            line=dict(color=m6a_line_colors[idx], width=2),
            name=cond),
            row=m6a_row, col=1
        )

    # y‑axis formatting for %m6A
    fig.update_yaxes(tickformat=".0%", row=m6a_row, col=1)

    # ─── 6. Motif dashed lines + labels (one numbering across conditions) ──
    if plot_motifs and motif_collect:
        raw_motifs = pd.concat(motif_collect, ignore_index=True)
        motif_df   = _build_motif_df(raw_motifs, plot_window, debug)
        if not motif_df.empty:
            ANNOT_Y = 1.02
            for row in motif_df.itertuples():
                pos, num = row.rel_start, row.num
                fig.add_shape(type="line", x0=pos, x1=pos, y0=0, y1=1,
                              xref="x", yref="paper",
                              line=dict(color="grey", width=1, dash="dash"))
                fig.add_annotation(
                    x=pos, y=ANNOT_Y, yref="paper",
                    text=str(num),
                    showarrow=False,
                    font=dict(size=10),
                    bgcolor="rgba(255,255,255,0.5)"
                )
                if debug:
                    print(f"[dual‑plot] motif#{num} at x={pos}")

    # ─── 7. Layout tweaks ────────────────────────────────────────────────
    ttl = " & ".join(conds) if len(conds) == 2 else conds[0]
    fig.update_layout(
        template="simple_white",
        height=700,
        width=900,
        title=f"Condition(s): {ttl} | Type: {data_type}",
        legend=dict(orientation="h", yanchor="bottom", y=1.02,
                    xanchor="right", x=1)
    )

    # axis labels
    for rid in read_rows:
        fig.update_yaxes(title_text="Read ID", row=rid, col=1)
    fig.update_yaxes(title_text="% m6A", row=m6a_row, col=1)
    fig.update_xaxes(title_text="Genomic position (bp)", row=rows, col=1)

    return fig

# ───────────────────────── Worker: save / stitch ──────────────────────── #
def _save_dual_plot(args):
    """
    Pool worker: render one dual-track plot and save it as PNG + SVG.

    Parameters
    ----------
    args : tuple
        ``(cond1, cond2, typ)`` – the two condition names and the ``type``
        value (e.g. ``"MOTIFS_rex48"``) used to select rows.

    Returns
    -------
    str | None
        Path of the PNG that was written, or ``None`` when
        ``create_dual_plot`` returned ``None`` (nothing is saved then).

    Notes
    -----
    Reads notebook-level globals: ``temp_bed_w``, ``debug_usage``,
    ``down_sampled_plot_df``, ``down_sampled_group_df``, ``chip_rank_lookup``,
    ``PNG_DPI`` and ``TABLE_MARGIN_PX``.  Output goes to
    ``images_pairwise/<cond1>_<cond2>_<win>bp/``.
    """
    cond1, cond2, typ = args
    chr_sel = "X"
    win     = temp_bed_w  # global constant

    # 1) Build the dual-plot figure
    fig = create_dual_plot(
        plot_df          = down_sampled_plot_df,
        group_df         = down_sampled_group_df,
        condition        = [cond1, cond2],
        chr_type         = chr_sel,
        data_type        = typ,
        plot_window      = win,
        plot_motifs      = True,
        require_full_span= True,
        debug            = debug_usage,
        smoothing_type   = "moving",
        smoothing_window = 50,
        gaussian_sigma   = 10,
        max_reads_per_condition = 300,
        read_colors      = ["#b12537", "#4974a5"],
        m6a_dot_colors   = ["#b12537", "#4974a5"],
        m6a_line_colors  = ["#b12537", "#4974a5"],
    )

    # No figure was produced (per the original note: no reads) → skip
    if fig is None:
        return None

    # 2) Rows backing the motif / gene tables
    read_subset = down_sampled_plot_df.loc[
        (down_sampled_plot_df["condition"].isin([cond1, cond2]))
        & (down_sampled_plot_df["chr_type"] == chr_sel)
        & (down_sampled_plot_df["type"] == typ)
    ].copy()

    # 3) Optional tables; a falsy result means "no table".  The list order
    #    fixes the vertical order: main plot → motif table → gene table.
    table_fig      = create_motif_table(read_subset, win)
    gene_table_fig = create_gene_table(read_subset)
    table_figs     = [t for t in (table_fig, gene_table_fig) if t]

    # 4) Export the main figure and stack every present table beneath it.
    main_png = fig.to_image(format="png", scale=2)
    stitched = None
    top_png  = main_png
    for tbl in table_figs:
        if stitched is not None:
            # stitch_vertical consumes PNG bytes, so re-encode the
            # intermediate image before stacking the next table.
            buf = io.BytesIO()
            stitched.save(buf, format="PNG")
            top_png = buf.getvalue()
        stitched = stitch_vertical(top_png, tbl.to_image(format="png", scale=2))
    if stitched is None:
        # Neither table present
        stitched = PILImage.open(io.BytesIO(main_png)).convert("RGBA")

    # 5) Output paths; the file name is prefixed with the ChIP rank if known
    save_dir = os.path.join("images_pairwise", f"{cond1}_{cond2}_{win}bp")
    os.makedirs(save_dir, exist_ok=True)

    rank_val  = chip_rank_lookup.get(typ, None)
    rank_pref = "NA" if rank_val is None else str(int(round(rank_val)))

    base_fn  = f"{rank_pref}_{typ}_b{win}_chr{chr_sel}"
    png_path = os.path.join(save_dir, base_fn + ".png")
    svg_path = os.path.join(save_dir, base_fn + ".svg")

    # Write PNG
    stitched.save(png_path, dpi=(PNG_DPI, PNG_DPI))

    # 6) Combined SVG: each table is wrapped in a <g> translated below the
    #    content above it (main figure height, then preceding table heights).
    main_svg = fig.to_image(format="svg").decode()
    if table_figs:
        parts    = ['<svg xmlns="http://www.w3.org/2000/svg">', f'<g>{main_svg}</g>']
        y_offset = fig.layout.height + TABLE_MARGIN_PX
        for i, tbl in enumerate(table_figs):
            if i > 0:
                y_offset = y_offset + table_figs[i - 1].layout.height + TABLE_MARGIN_PX
            tbl_svg = tbl.to_image(format="svg").decode()
            parts.append(f'<g transform="translate(0,{y_offset})">{tbl_svg}</g>')
        parts.append('</svg>')
        combined_svg = "".join(parts)
    else:
        combined_svg = main_svg

    # The SVG text was decoded as UTF-8 and can hold non-ASCII characters
    # (e.g. Unicode minus signs in tick labels), so write it back as UTF-8
    # instead of relying on the locale's default encoding.
    with open(svg_path, "w", encoding="utf-8") as fh:
        fh.write(combined_svg)

    if debug_usage:
        print(f"[Worker] saved → {png_path}")
    return png_path


# ───────────────────────────── CLI / main ─────────────────────────────── #
if __name__ == "__main__":
    # Globals read by _save_dual_plot in the worker processes
    debug_usage = False
    num_workers = 50
    temp_bed_w  = 1000

    # Build motif type list
    motifs_type_selected = [f"MOTIFS_{t}" for t in type_selected]

    # ── FULL run: every unordered pair × every motif type ───────────────
    #  NOTE: this list is built but overridden by the QUICK test assignment
    #  below; comment out the quick-test line to run the whole batch.
    cond_pairs = list(combinations(analysis_cond, 2))
    combos = [(c1, c2, typ) for (c1, c2) in cond_pairs for typ in motifs_type_selected]

    # ── QUICK test: one pair + one type ─────────────────────────────────
    combos = [[analysis_cond[0], analysis_cond[5], "MOTIFS_rex48"]]#"MOTIFS_rex32"]]

    # _save_dual_plot returns None when no figure could be built; drop those
    # entries so the display step below never receives filename=None.
    saved_pngs = [
        p for p in run_in_pool(
            func=_save_dual_plot,
            iterable=combos,
            n_workers=num_workers,
            desc="Saving dual‑track plots",
            unit="plot"
        )
        if p is not None
    ]

    # show the single figure inline when only one plot was saved
    if len(saved_pngs) == 1:
        display(Image(filename=saved_pngs[0]))


In [None]:
# ─────────────────────────── Three‑track plotting ────────────────────────── #
import plotly.graph_objects as go

# Default per-condition palette (one colour per condition, in order);
# create_triple_plot uses it when no colour list is passed via kwargs.
# Order: red, blue, green.
DEFAULT_READ_CLRS = ["#b12537", "#4974a5", "#47B562" ]   # colour key / alternatives: #47B562 green | #b12537 red | #808080 grey | #FF8559 coral

# grey color #

def create_triple_plot(
    plot_df: pd.DataFrame,
    condition,                             # str | list[str] (max 3)
    chr_type: str,
    data_type: str,
    plot_window: int,
    *,
    plot_motifs: bool          = True,
    require_full_span: bool    = False,
    smoothing_type: str        = "moving",
    smoothing_window: int      = 50,
    gaussian_sigma: float      = 10,
    debug: bool                = False,
    subsample_reads: bool      = False,
    max_reads_to_display: Optional[int] = 300,
    random_state: Optional[int] = None,
    read_colors: Optional[list]     = None,
    m6a_dot_colors: Optional[list]  = None,
    m6a_line_colors: Optional[list] = None,
):
    """
    Up to three read tracks with combined %m6A panels.

    Row layout by number of conditions:
      * 1 condition : row 1 reads, row 2 %m6A
      * 2 conditions: rows 1–2 reads (A, B), row 3 %m6A (A and B overlaid)
      * 3 conditions: (%m6A, %m6A, read, read, read)
          1. condA+condB %m6A
          2. condB+condC %m6A
          3. condA       reads
          4. condB       reads
          5. condC       reads

    Parameters
    ----------
    plot_df : DataFrame
        One row per read.  Columns used here: ``condition``, ``chr_type``,
        ``type``, ``read_id``, ``rel_pos`` and ``mod_qual_bin`` (per-row
        sequences iterated in lockstep), ``motif_rel_start``,
        ``motif_attributes``.
    condition : str | list[str] | tuple[str]
        One to three condition names, in display order.
    chr_type, data_type : str
        Values matched against the ``chr_type`` / ``type`` columns.
    plot_window : int
        Half-width of the x-axis; positions outside ±plot_window are dropped.
    plot_motifs : bool
        Draw dashed guide lines and numeric labels for motifs.
    require_full_span : bool
        Forwarded to ``_filter_reads``.
    smoothing_type : str
        ``"none"``, ``"moving"`` (rolling sum ratio over ``smoothing_window``
        positions); any other value applies ``gaussian_filter1d`` with
        ``gaussian_sigma``.
    subsample_reads : bool
        When True every condition is randomly reduced to the same number of
        reads: the smallest read count among conditions that have reads,
        capped at ``max_reads_to_display`` (no cap when that is ``None``).
    random_state : int | None
        Seed for ``np.random.default_rng`` used by the subsampling.
    read_colors, m6a_dot_colors, m6a_line_colors : list | None
        Exactly one colour per condition; ``DEFAULT_READ_CLRS`` when ``None``.

    Returns
    -------
    plotly.graph_objects.Figure

    Raises
    ------
    ValueError
        If the number of conditions is not 1–3 or a colour list has the
        wrong length.
    """
    # ─── normalize inputs & palettes ─────────────────────────────────
    conds = list(condition) if isinstance(condition, (list, tuple)) else [condition]
    if not (1 <= len(conds) <= 3):
        raise ValueError("condition must be 1–3 strings")
    n_cond = len(conds)

    def _fill(col, default_list):
        # Resolve one colour list: truncate the defaults, or validate the
        # caller-supplied list length.
        if col is None:
            if len(default_list) < n_cond:
                raise ValueError(f"Need at least {n_cond} default colours, got {len(default_list)}")
            return default_list[:n_cond]
        if len(col) != n_cond:
            raise ValueError(f"Need {n_cond} colours, got {len(col)}")
        return col

    read_colors     = _fill(read_colors,     DEFAULT_READ_CLRS)
    m6a_dot_colors  = _fill(m6a_dot_colors,  DEFAULT_READ_CLRS)
    m6a_line_colors = _fill(m6a_line_colors, DEFAULT_READ_CLRS)

    rng = np.random.default_rng(random_state)

    # ─── figure scaffold ────────────────────────────────────────────────
    # combined_mapping entries are (subplot row, [indices into conds]).
    if n_cond == 1:
        specs, row_h = [[{}], [{}]], [0.55, 0.45]
        read_rows        = [1]
        combined_mapping = [(2, [0])]
    elif n_cond == 2:
        specs, row_h = [[{}], [{}], [{}]], [0.4, 0.4, 0.2]
        read_rows        = [1, 2]
        combined_mapping = [(3, [0, 1])]
    else:
        specs  = [[{}], [{}], [{}], [{}], [{}]]
        row_h  = [0.35, 0.35, 0.10, 0.10, 0.10]
        combined_mapping = [(1, [0, 1]), (2, [1, 2])]
        read_rows = [3, 4, 5]

    fig = make_subplots(
        rows=len(specs), cols=1,
        shared_xaxes=True,
        vertical_spacing=0.02,
        specs=specs, row_heights=row_h
    )
    # Set x‑axis range and ticks every 300 bp centered on 0
    fig.update_xaxes(
        range=[-plot_window, plot_window],
        tickmode="linear",
        tick0=0,
        dtick=300
    )

    # ─── gather reads per condition (single pass) ───────────────────────
    cond_dfs  = []
    read_lens = []
    for cond in conds:
        # `debug` is passed by keyword: the unified `_filter_reads` takes
        # `strand_filter` as its 7th positional parameter, so a positional
        # debug flag would be misread as a strand filter.
        df = _filter_reads(plot_df, cond, chr_type, data_type,
                           plot_window, require_full_span, debug=debug)
        cond_dfs.append(df)
        read_lens.append(df["read_id"].nunique())

    # ─── optional subsample target ──────────────────────────────────────
    # Target = min(max_reads_to_display, fewest reads among conditions that
    # actually have reads).  Empty conditions are ignored here; otherwise a
    # single empty condition would force the target to 0 and blank every
    # other track.
    target_reads_per_condition = None
    populated_lens = [n for n in read_lens if n > 0]
    if subsample_reads and populated_lens:
        actual_min_reads = min(populated_lens)
        target_reads_per_condition = (
            min(max_reads_to_display, actual_min_reads)
            if max_reads_to_display is not None else actual_min_reads
        )
        if debug:
            print(f"[triple‑plot] Target reads per condition: {target_reads_per_condition} (min across conditions: {actual_min_reads})")

    # ─── read tracks ────────────────────────────────────────────────────
    motif_collect = []
    for idx, (cond, df_reads, n_reads) in enumerate(zip(conds, cond_dfs, read_lens)):
        if df_reads.empty:
            continue

        # Randomly keep `target_reads_per_condition` read_ids for this condition
        if target_reads_per_condition is not None and n_reads > target_reads_per_condition:
            unique_read_ids = df_reads["read_id"].drop_duplicates()
            ids_to_keep = rng.choice(unique_read_ids, size=target_reads_per_condition, replace=False)
            df_reads = df_reads[df_reads["read_id"].isin(ids_to_keep)].copy()
            if debug:
                print(f"[triple-plot] Condition '{cond}' subsampled to {len(ids_to_keep)} reads.")

        motif_collect.append(df_reads[["motif_rel_start", "motif_attributes"]])

        # Assign each read a 1-based y position in order of first appearance
        rc_map = (pd.DataFrame({"read_id": df_reads["read_id"].unique()})
                  .assign(read_count=lambda d: np.arange(1, len(d) + 1)))
        df_reads = df_reads.merge(rc_map, on="read_id")

        _add_read_occupancy_multi(
            fig, df_reads, plot_window,
            row=read_rows[idx],
            color=read_colors[idx],
            read_width=0.025 #0.1 for large format display | 0.025 for compressed figure
        )
        long_df = _long_format(df_reads, plot_window)
        hits    = long_df[long_df["mod_qual_bin"] == 1]
        fig.add_trace(
            go.Scatter(
                x=hits["rel_pos"], y=hits["read_count"],
                mode="markers",
                marker=dict(symbol="square", size=1.5, #2.5 for large format display | 1.5 for compressed figure
                            color=m6a_dot_colors[idx]),
                showlegend=False
            ),
            row=read_rows[idx], col=1
        )
        fig.update_yaxes(title_text="Read ID",
                         row=read_rows[idx], col=1)

    # ─── plot combined %m6A panels & track global max ───────────────────
    # NOTE: the aggregate is computed from every row of plot_df matching the
    # condition / chr_type / type, not from the filtered or subsampled reads
    # drawn in the tracks above.
    combined_max = 0.0
    for comb_row, cond_idxs in combined_mapping:
        for ci in cond_idxs:
            df_meta = plot_df.loc[
                (plot_df["condition"] == conds[ci]) &
                (plot_df["chr_type"]   == chr_type) &
                (plot_df["type"]       == data_type)
            ].copy()

            if df_meta.empty:
                continue
            # (position, call) pairs inside the plotting window
            recs = [
                (p, b)
                for r in df_meta.itertuples()
                for p, b in zip(r.rel_pos, r.mod_qual_bin)
                if -plot_window <= p <= plot_window
            ]
            agg = (pd.DataFrame(recs, columns=["pos", "bin"])
                   .groupby("pos")["bin"]
                   .agg(["sum", "count"])
                   .reset_index())
            agg["ratio"] = agg["sum"] / agg["count"]

            if smoothing_type == "none":
                y = agg["ratio"]
            elif smoothing_type == "moving":
                # ratio of rolling sums, so each window is coverage-weighted
                y = (agg["sum"]
                     .rolling(smoothing_window, center=True).sum() /
                     agg["count"]
                     .rolling(smoothing_window, center=True).sum())
            else:
                y = gaussian_filter1d(agg["ratio"], sigma=gaussian_sigma)

            if len(y):
                combined_max = max(combined_max, float(y.max()))

            # rolling windows leave NaN at the edges; do not plot those
            mask = ~np.isnan(y)
            fig.add_trace(
                go.Scatter(
                    x=agg["pos"][mask], y=y[mask],
                    mode="lines",
                    line=dict(color=m6a_line_colors[ci], width=2),
                    name=str(conds[ci])
                ),
                row=comb_row, col=1
            )
        fig.update_yaxes(title_text="% m6A", row=comb_row, col=1)

    # ─── equalize + style the two % m6A axes ─────────────────────────────
    if len(combined_mapping) > 1:
        # round up to the next 2%; keep at least one 2% step so the shared
        # range is never the degenerate [0, 0] when nothing was plotted
        max_pct = max(2, math.ceil((combined_max * 100) / 2) * 2)
        ymax    = max_pct / 100

        for comb_row, _ in combined_mapping:
            fig.update_xaxes(
                range=[-plot_window, plot_window],
                tickmode="linear",
                tick0=0,
                dtick=300,
                showgrid=False,
                showline=True,
                linecolor="darkgrey",
                linewidth=1.5,
                ticks="outside",
                row=comb_row, col=1
            )
            fig.update_yaxes(
                range=[0, ymax],
                tickformat=".0%",
                showgrid=False,
                showline=True,
                linecolor="darkgrey",
                linewidth=1.5,
                ticks="outside",
                row=comb_row, col=1
            )

    # ─── motif dashed lines + labels (table‑synced) ──────────────────────
    if plot_motifs and motif_collect:
        # collapse all conditions’ motif columns → build one authoritative df
        raw_motifs = pd.concat(motif_collect, ignore_index=True)
        motif_df   = _build_motif_df(raw_motifs, plot_window, debug)

        if not motif_df.empty:
            ANNOT_Y = 1.02
            if debug:
                print(f"[triple‑plot] annotating {len(motif_df)} motifs")

            for row in motif_df.itertuples():
                pos, num = row.rel_start, row.num

                # vertical dashed guide spanning the full figure height
                fig.add_shape(
                    type="line", x0=pos, x1=pos, y0=0, y1=1,
                    xref="x", yref="paper",
                    line=dict(color="grey", width=1, dash="dash")
                )

                # numeric label (matches the table row number)
                fig.add_annotation(
                    x=pos, y=ANNOT_Y, yref="paper",
                    text=str(num),
                    showarrow=False,
                    font=dict(size=10),
                    bgcolor="rgba(255,255,255,0.5)"
                )
                if debug:
                    print(f"[triple‑plot] motif#{num} at x={pos}")
    else:
        if debug:
            print("[triple‑plot] no motifs found – skipping annotation")

    # ─── final layout ────────────────────────────────────────────────────
    fig.update_layout(
        template="plotly_white",
        width=FIG_WIDTH,
        height=FIG_HEIGHT,
        title=f"Conditions: {', '.join(conds)} | Type: {data_type}",
        legend=dict(orientation="h", yanchor="bottom", y=1.02,
                    xanchor="right", x=1)
    )
    # Ensure bottom x‑axis also has ticks every 300 bp
    fig.update_xaxes(
        tickmode="linear",
        tick0=0,
        dtick=300,
        row=len(specs), col=1
    )
    fig.update_xaxes(title_text="Genomic position (bp)",
                     row=len(specs), col=1)

    # read tracks: no grid / axis lines
    for r in read_rows:
        fig.update_yaxes(showgrid=False,zeroline=False, row=r, col=1)
        fig.update_xaxes(showgrid=False,showline=False, zeroline=False, row=r, col=1)

    return fig

# ────────────────────────── Worker: save / stitch ───────────────────────── #
def _save_triple_plot(args):
    """
    Pool worker: render one three-track plot and save it as PNG + SVG.

    Parameters
    ----------
    args : tuple
        ``(conds, motif_type)`` – a sequence of condition names and the
        ``type`` value (e.g. ``"MOTIFS_rex48"``) used to select rows.

    Returns
    -------
    str | None
        Path of the PNG that was written, or ``None`` when
        ``down_sampled_plot_df`` holds no rows for these conditions,
        chromosome type and motif type.

    Notes
    -----
    Reads notebook-level globals: ``temp_bed_w``, ``debug_usage``,
    ``down_sampled_plot_df``, ``chip_rank_lookup``, ``DEFAULT_READ_CLRS``,
    ``FIG_WIDTH``, ``FIG_HEIGHT``, ``PNG_DPI`` and ``TABLE_MARGIN_PX``.
    Output goes to ``images_triple/<conds joined by _>_<win>bp/``.
    """
    conds, motif_type = args
    chr_sel = "Autosome"
    win     = temp_bed_w  # global constant

    # 1) Bail out early when no rows exist for these conditions + motif_type
    mask = (
        down_sampled_plot_df["condition"].isin(conds)
        & (down_sampled_plot_df["chr_type"] == chr_sel)
        & (down_sampled_plot_df["type"] == motif_type)
    )
    if not mask.any():
        if debug_usage:
            print(f"[Worker] skip {motif_type} – no reads for {conds}")
        return None

    # 2) Build the triple-plot figure
    fig = create_triple_plot(
        plot_df           = down_sampled_plot_df,
        condition         = list(conds),
        chr_type          = chr_sel,
        data_type         = motif_type,
        plot_window       = win,
        plot_motifs       = False,
        require_full_span = True,
        smoothing_type    = "moving",
        smoothing_window  = 50,
        gaussian_sigma    = 10,
        debug             = debug_usage,
        subsample_reads   = True,
        max_reads_to_display=300,
        random_state      = 42,
        read_colors       = DEFAULT_READ_CLRS,
        m6a_dot_colors    = DEFAULT_READ_CLRS,
        m6a_line_colors   = DEFAULT_READ_CLRS,
    )

    # 3) Rows backing the motif / gene tables
    read_subset = down_sampled_plot_df.loc[mask].copy()

    if debug_usage:
        print(f"[Worker] rows in read_subset: {len(read_subset)}")
        # How many rows have a non-null motif_rel_start?
        nonnull_motif = read_subset["motif_rel_start"].notna().sum()
        print(f"[Worker] of those, {nonnull_motif} rows have a non‐null motif_rel_start")
        # Show one example of the raw motif lists (safe: mask.any() held above)
        example = read_subset["motif_rel_start"].iloc[0]
        print(f"[Worker] example motif_rel_start (first row): {example}")

    # 4) Optional tables; a falsy result means "no table".  The list order
    #    fixes the vertical order: main plot → motif table → gene table.
    table_fig      = create_motif_table(read_subset, win)
    gene_table_fig = create_gene_table(read_subset)
    table_figs     = [t for t in (table_fig, gene_table_fig) if t]

    # 5) Export the main figure (no forced size) and stack every present
    #    table, rendered at its natural size, beneath it.
    main_png = fig.to_image(format="png", scale=1)
    stitched = None
    top_png  = main_png
    for tbl in table_figs:
        if stitched is not None:
            # stitch_vertical consumes PNG bytes, so re-encode the
            # intermediate image before stacking the next table.
            buf = io.BytesIO()
            stitched.save(buf, format="PNG")
            top_png = buf.getvalue()
        stitched = stitch_vertical(top_png, tbl.to_image(format="png", scale=1))
    if stitched is None:
        stitched = PILImage.open(io.BytesIO(main_png)).convert("RGBA")

    # 6) Output paths; the file name is prefixed with the ChIP rank if known
    save_dir = os.path.join("images_triple", f"{'_'.join(conds)}_{win}bp")
    os.makedirs(save_dir, exist_ok=True)

    rank_val  = chip_rank_lookup.get(motif_type, None)
    rank_pref = "NA" if rank_val is None else str(int(round(rank_val)))

    base     = f"{rank_pref}_{motif_type}_b{win}_chr{chr_sel}"
    png_path = os.path.join(save_dir, base + ".png")
    svg_path = os.path.join(save_dir, base + ".svg")

    stitched.save(png_path, dpi=(PNG_DPI, PNG_DPI))

    # 7) Combined SVG: each table is wrapped in a <g> translated below the
    #    content above it (main figure height, then preceding table heights).
    main_svg = fig.to_image(format="svg", width=FIG_WIDTH, height=FIG_HEIGHT).decode()
    if table_figs:
        parts    = ['<svg xmlns="http://www.w3.org/2000/svg">', f'<g>{main_svg}</g>']
        y_offset = fig.layout.height + TABLE_MARGIN_PX
        for i, tbl in enumerate(table_figs):
            if i > 0:
                y_offset = y_offset + table_figs[i - 1].layout.height + TABLE_MARGIN_PX
            tbl_svg = tbl.to_image(format="svg").decode()
            parts.append(f'<g transform="translate(0,{y_offset})">{tbl_svg}</g>')
        parts.append('</svg>')
        combined_svg = "".join(parts)
    else:
        combined_svg = main_svg

    # The SVG text was decoded as UTF-8 and can hold non-ASCII characters
    # (e.g. Unicode minus signs in tick labels), so write it back as UTF-8
    # instead of relying on the locale's default encoding.
    with open(svg_path, "w", encoding="utf-8") as fh:
        fh.write(combined_svg)

    if debug_usage:
        print(f"[Worker] saved → {png_path}")
    return png_path



# ─────────────────────────── CLI / main block ─────────────────────────── #
if __name__ == "__main__":
    # ────────────── CONFIGURATION ──────────────
    # These assignments create notebook-level globals that create_triple_plot
    # and _save_triple_plot read at call time.
    FIG_WIDTH     = 800    # figure width (px): layout width and SVG export width
    FIG_HEIGHT    = 800     # figure height (px): layout height and SVG export height
    DISPLAY_WIDTH = 1600     # intended in‑notebook display width (px); not referenced in this cell

    debug_usage = False   # enables the [Worker] / [triple‑plot] debug prints
    num_workers = 50      # process count handed to run_in_pool
    temp_bed_w  = 1000    # used by the worker as the plot window

    # three conditions in desired display order
    TRIPLE_COND = (analysis_cond[1], analysis_cond[0], analysis_cond[2])

    # prepend “MOTIFS_”
    motif_types = [f"MOTIFS_{t}" for t in type_selected]

    # ─── FULL RUN – every motif type for the fixed 3‑condition set ───────
    # NOTE: this list is built but overridden by the QUICK TEST assignment
    # below; comment out the quick-test line to generate the whole batch.
    combos = [(TRIPLE_COND, mt) for mt in motif_types]

    # ─── QUICK TEST – one type, same 3 conditions ────────────────────────
    combos = [(TRIPLE_COND, "MOTIFS_rex48")]

    # _save_triple_plot returns None for combos it skips, so filter those out
    saved_pngs = [
        p for p in run_in_pool(
            func=_save_triple_plot,
            iterable=combos,
            n_workers=num_workers,
            desc="Saving three‑track plots",
            unit="plot"
        )
        if p is not None          # keep only the paths that were actually saved
    ]


    # show the figure inline when exactly one plot was saved
    if len(saved_pngs) == 1:
        display(Image(filename=saved_pngs[0]))


In [None]:
# ╔══════════════════════════════════════════════════════════════════════════╗
# ║  Unified read-track plotting (1–4 conditions) × motif types              ║
# ║                                                                          ║
# ║  Usage in notebook (define these in a prior cell):                       ║
# ║    CONDITIONS  = (analysis_cond[1], analysis_cond[0], analysis_cond[2])  ║
# ║    MOTIF_TYPES = [f"MOTIFS_{t}" for t in type_selected]                  ║
# ║    run_batch_save(CONDITIONS, MOTIF_TYPES)                               ║
# ║                                                                          ║
# ║  Assumes your existing variables/dataframes exist:                       ║
# ║    down_sampled_plot_df, down_sampled_group_df, temp_bed_w, etc.         ║
# ╚══════════════════════════════════════════════════════════════════════════╝

# ───────────────────────────── Imports ────────────────────────────── #
import os, io, math, warnings
from itertools import product
from multiprocessing import Pool, cpu_count
from typing import List, Sequence, Tuple, Optional

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
from PIL import Image as PILImage
from scipy.ndimage import gaussian_filter1d

# optional (used by some sorters)
import scipy.cluster.hierarchy as sch
from sklearn.decomposition import PCA
from fastdtw import fastdtw
from scipy.spatial.distance import squareform

# notebook helpers
from IPython.display import Image, display

# your local lib is assumed to exist already
import importlib, nanotools
importlib.reload(nanotools)

# ──────────────────────────── Plot template ───────────────────────── #
# Sets the process-wide default Plotly template for figures created afterwards.
pio.templates.default = "plotly_white"  # per user preference

# ──────────────────────────── Global config ───────────────────────── #
# Debug toggles (as requested: progress + data summaries)
DEBUG_PROGRESS = False     # terse progress prints from workers
DEBUG_SUMMARY  = False     # one-shot data summaries (gates the [filter] print in _filter_reads)

# Sorting config (same semantics as earlier cells)
# NOTE(review): the dispatcher that interprets SORT_METHOD is outside this
# section; per the key below the current value 0 selects autocorrelation.
SORT_METHOD       = 0      # 1=binned-HC, 3=PCA, 4=DTW, else=autocorr
BIN_COUNT         = 200    # for 1,3,4 (default n_bins of _build_binned_matrix)
HC_LINKAGE        = "ward" # linkage method passed to scipy's sch.linkage
PCA_COMPONENTS    = 1      # n_components for the PCA sorter (only PC1 is used for ordering)
DTW_DISTANCE      = "euclidean"  # NOTE(review): _order_by_dtw calls fastdtw without a dist argument, so this looks unused there
AC_SORT_LAG_MIN   = 155    # default lag_lo of _order_reads_by_ac
AC_SORT_LAG_MAX   = 195    # default lag_hi of _order_reads_by_ac

# Figure + export
# NOTE: FIG_WIDTH / FIG_HEIGHT are also assigned (800 × 800) by the
# three-track cell's main block; whichever cell ran last wins.
FIG_WIDTH   = 900
FIG_HEIGHT  = 700
PNG_DPI     = 300          # dpi metadata passed when saving stitched PNGs
TABLE_CELL_HEIGHT = 22
TABLE_HEADER_H    = 28
TABLE_MARGIN_PX   = 12     # vertical gap used by stitch_vertical and the SVG table offsets

# Colors (up to 4 conditions)
NUC_LINE_CLR    = "#35b779"

# ───────────────────── Reference tables (load once) ────────────────── #
# Adjust paths if needed
# Whitespace-delimited table; only the five listed columns are kept
# (the file must have a header row naming them).
tss_tes_df = pd.read_csv(
    "/Data1/reference/tss_tes_rex_combined_v27_WS235.bed",
    sep=r"\s+",
    usecols=["chromosome", "start", "end", "strand", "type"]
)
# Per-type ChIP rank table (whitespace-delimited, with header).
_chiprank_df = pd.read_csv(
    "/Data1/reference/rex_chiprank.bed",
    sep=r"\s+",
    usecols=["type", "chip_rank"]
)
# Prefix the type names so the keys match the "MOTIFS_<type>" strings the
# save workers look up.
_chiprank_df["type"] = "MOTIFS_" + _chiprank_df["type"].astype(str)
# type → chip_rank × 100, rounded to a whole number (kept as float).
chip_rank_lookup = {t: round(float(rk) * 100, 0)
                    for t, rk in zip(_chiprank_df["type"], _chiprank_df["chip_rank"])}

# ─────────────────────────── Utilities ─────────────────────────────── #
def run_in_pool(func, iterable: Sequence, n_workers: int, desc: str, unit: str = "item"):
    """
    Map ``func`` over ``iterable`` in a process pool with a tqdm progress bar.

    Parameters
    ----------
    func : callable
        Worker applied to each element of ``iterable``.
    iterable : Sequence
        Work items; must support ``len()`` (used for the progress total).
    n_workers : int
        Upper bound on the number of worker processes.  The pool never gets
        more processes than there are work items, so a one-item run does not
        spawn dozens of idle workers.  ``None`` is passed through to ``Pool``.
    desc, unit : str
        Progress-bar label and unit.

    Returns
    -------
    list
        One result per item, in completion order (``imap_unordered``), which
        is not necessarily the input order.
    """
    n_tasks = len(iterable)
    if n_tasks == 0:
        # Nothing to do: skip creating a pool altogether.
        return []

    from tqdm.auto import tqdm

    n_procs = n_workers if n_workers is None else min(n_workers, n_tasks)
    with Pool(processes=n_procs) as pool:
        return list(tqdm(pool.imap_unordered(func, iterable),
                         total=n_tasks, desc=desc, unit=unit))

def stitch_vertical(png_top: bytes, png_bottom: bytes) -> PILImage.Image:
    """
    Stack two PNG byte streams vertically on a white RGBA canvas.

    The bottom image is placed ``TABLE_MARGIN_PX`` pixels below the top one;
    both are left-aligned and the canvas is as wide as the wider image.
    """
    upper, lower = (
        PILImage.open(io.BytesIO(raw)).convert("RGBA")
        for raw in (png_top, png_bottom)
    )
    lower_y = upper.height + TABLE_MARGIN_PX
    canvas_size = (max(upper.width, lower.width), lower_y + lower.height)

    stacked = PILImage.new("RGBA", canvas_size, "white")
    for tile, y_off in ((upper, 0), (lower, lower_y)):
        stacked.paste(tile, (0, y_off))
    return stacked

# ─────────────────────── Filtering / explode ───────────────────────── #
def _filter_reads(df: pd.DataFrame, condition: str, chr_type: str, data_type: str,
                  plot_window: int, require_full_span: bool,
                  strand_filter: Optional[Sequence[str]] = None,
                  debug: bool=False) -> pd.DataFrame:
    reads = df.loc[(df["condition"] == condition)
                   & (df["chr_type"] == chr_type)
                   & (df["type"] == data_type)].copy()
    if strand_filter is not None and "bed_strand" in reads.columns:
        reads = reads.loc[reads["bed_strand"].isin(list(strand_filter))]
    if require_full_span:
        span = ((reads["rel_read_start"] <= -round(plot_window * 0.3, 0)) &
                (reads["rel_read_end"]   >=  round(plot_window * 0.3, 0)))
        reads = reads.loc[span]
    reads = reads.loc[reads["mod_qual_bin"].apply(lambda lst: 1 in lst)]
    if debug and DEBUG_SUMMARY:
        print(f"[filter] {condition}/{data_type} → {len(reads)} rows"
              + (f" | strand∈{list(strand_filter)}" if strand_filter else ""))
    return reads

def _long_format(read_df: pd.DataFrame, plot_window: int) -> pd.DataFrame:
    recs: List[Tuple[str,int,int,int]] = []
    for r in read_df.itertuples():
        idx = [i for i, p in enumerate(r.rel_pos) if -plot_window <= p <= plot_window]
        for i in idx:
            recs.append((r.read_id, r.rel_pos[i], r.mod_qual_bin[i], r.read_count))
    return pd.DataFrame(recs, columns=["read_id", "rel_pos", "mod_qual_bin", "read_count"])

def _add_read_occupancy(fig, read_df, plot_window, row, color, read_width):
    """
    Draw each read as a horizontal segment on subplot ``row`` of ``fig``.

    Segments are clipped to ``[-plot_window, plot_window]`` and drawn at
    y = ``read_count``.  All reads go into one line trace; a ``None`` x value
    after each segment breaks the line between reads.
    """
    xs, ys = [], []
    for read in read_df.itertuples():
        left = max(-plot_window, read.rel_read_start)
        right = min(plot_window, read.rel_read_end)
        xs.extend((left, right, None))
        ys.extend((read.read_count,) * 3)

    segments = go.Scatter(
        x=xs,
        y=ys,
        mode="lines",
        line=dict(color=color, width=read_width),
        showlegend=False,
    )
    fig.add_trace(segments, row=row, col=1)

# ─────────────────────────── Sorting helpers ───────────────────────── #
def _build_binned_matrix(read_df, plot_window, n_bins=BIN_COUNT):
    """
    Bin each read's modified positions into a per-read density vector.

    The window ``[-plot_window, plot_window]`` is cut into ``n_bins`` equal
    bins; for every read the positions whose ``mod_qual_bin`` equals 1 are
    counted per bin and divided by the bin width.

    Returns ``(matrix, read_ids)`` with one matrix row per row of ``read_df``.
    """
    bin_edges = np.linspace(-plot_window, plot_window, n_bins + 1)
    bin_widths = np.diff(bin_edges)

    densities, read_ids = [], []
    for read in read_df.itertuples():
        positions = np.asarray(read.rel_pos)
        is_modified = np.asarray(read.mod_qual_bin) == 1
        per_bin = np.histogram(positions[is_modified], bins=bin_edges)[0]
        densities.append(per_bin / bin_widths)
        read_ids.append(read.read_id)
    return np.vstack(densities), read_ids

def _make_nodups(read_df, ordered_ids):
    tmp = (read_df[["read_id",
                    "smallest_positive_nuc_midpoint",
                    "greatest_negative_nuc_midpoint"]]
           .drop_duplicates("read_id")
           .set_index("read_id"))
    tmp = tmp.loc[[i for i in ordered_ids if i in tmp.index]].reset_index()
    tmp["read_count"] = np.arange(1, len(tmp) + 1)
    return tmp

def _order_by_binned(read_df, plot_window, debug=False):
    """Order reads by hierarchical clustering of their binned density profiles.

    Profiles come from ``_build_binned_matrix``; they are clustered with
    Euclidean distance and the ``HC_LINKAGE`` method, and reads are ranked in
    dendrogram leaf order. ``debug`` is accepted but not used here.
    """
    profiles, read_ids = _build_binned_matrix(read_df, plot_window)
    tree = sch.linkage(profiles, method=HC_LINKAGE, metric="euclidean")
    leaf_order = sch.dendrogram(tree, no_plot=True)["leaves"]
    ordered_ids = [read_ids[k] for k in leaf_order]
    return _make_nodups(read_df, ordered_ids)

def _order_by_pca(read_df, plot_window, debug=False):
    """Order reads by their score on the first principal component.

    A PCA with ``PCA_COMPONENTS`` components is fitted on the binned density
    profiles; reads are ranked by ascending value of the first component.
    ``debug`` is accepted but not used here.
    """
    profiles, read_ids = _build_binned_matrix(read_df, plot_window)
    projected = PCA(n_components=PCA_COMPONENTS).fit_transform(profiles)
    first_component = projected[:, 0]
    ordered_ids = [read_ids[k] for k in np.argsort(first_component)]
    return _make_nodups(read_df, ordered_ids)

def _order_by_dtw(read_df, plot_window, debug=False):
    """Order reads by hierarchical clustering on pairwise DTW distances.

    A symmetric distance matrix is filled with ``fastdtw`` distances between
    every pair of binned density profiles (quadratic in the number of reads),
    condensed, clustered with ``HC_LINKAGE``, and reads are ranked in
    dendrogram leaf order. ``debug`` is accepted but not used here.
    """
    profiles, read_ids = _build_binned_matrix(read_df, plot_window)
    n_reads = profiles.shape[0]
    pairwise = np.zeros((n_reads, n_reads))
    for a in range(n_reads):
        for b in range(a):
            # only the lower triangle is computed; mirror it to keep the matrix symmetric
            d_ab, _path = fastdtw(profiles[a], profiles[b])
            pairwise[a, b] = d_ab
            pairwise[b, a] = d_ab
    tree = sch.linkage(squareform(pairwise), method=HC_LINKAGE)
    leaf_order = sch.dendrogram(tree, no_plot=True)["leaves"]
    ordered_ids = [read_ids[k] for k in leaf_order]
    return _make_nodups(read_df, ordered_ids)

def _order_reads_by_ac(read_df: pd.DataFrame, plot_window: int,
                       lag_lo: int = AC_SORT_LAG_MIN, lag_hi: int = AC_SORT_LAG_MAX,
                       debug: bool=False) -> pd.DataFrame:
    """Rank reads by the strength of their autocorrelation in a lag window.

    Each read is laid out on a dense vector covering ``-plot_window`` to
    ``+plot_window`` (NaN where the read has no call), passed to
    ``nanotools._autocorr_vec``, and scored by the NaN-ignoring maximum of the
    result over lags ``lag_lo..lag_hi`` inclusive. The lag bounds are sorted
    and the upper bound is capped at ``2 * plot_window``.

    Returns:
        One row per ``read_id`` (first occurrence) with the two
        nucleosome-midpoint columns and a 1-based ``read_count`` rank, highest
        score first. Reads without a usable score sort last.
    """
    lag_lo, lag_hi = sorted((lag_lo, lag_hi))
    lag_hi = min(lag_hi, 2 * plot_window)
    n_slots = 2 * plot_window + 1
    best_ac = {}
    for read in read_df.itertuples():
        positions = np.asarray(read.rel_pos, dtype=int)
        calls = np.asarray(read.mod_qual_bin, dtype=float)
        signal = np.full(n_slots, np.nan)
        slots = positions + plot_window          # shift so -plot_window maps to index 0
        in_window = (slots >= 0) & (slots < n_slots)
        signal[slots[in_window]] = calls[in_window]
        autocorr = nanotools._autocorr_vec(signal, lag_hi)
        best_ac[read.read_id] = np.nanmax(autocorr[lag_lo:lag_hi + 1])
    per_read = (read_df[["read_id",
                         "smallest_positive_nuc_midpoint",
                         "greatest_negative_nuc_midpoint"]]
                .drop_duplicates("read_id"))
    # NaN scores become -inf so those reads end up at the bottom of the ranking
    per_read["ac_mean"] = per_read["read_id"].map(best_ac).fillna(-np.inf)
    ranked = per_read.sort_values("ac_mean", ascending=False).drop(columns="ac_mean")
    ranked["read_count"] = np.arange(1, len(ranked) + 1)
    return ranked

def _order_read_df(df_reads, plot_window):
    """Attach a ``read_count`` plotting rank to every row of ``df_reads``.

    The ranking strategy is chosen by the module-level ``SORT_METHOD``:
    1 = binned hierarchical clustering, 3 = PCA, 4 = DTW clustering,
    0 = order of first appearance, anything else = autocorrelation score.
    The per-read ranks are left-merged back onto ``df_reads`` by ``read_id``.
    """
    clustering_sorters = (
        (1, _order_by_binned),
        (3, _order_by_pca),
        (4, _order_by_dtw),
    )
    ranking = None
    for code, sorter in clustering_sorters:
        if SORT_METHOD == code:
            ranking = sorter(df_reads, plot_window)
            break
    if ranking is None:
        if SORT_METHOD == 0:
            # keep reads in the order they first appear
            ranking = (df_reads[["read_id",
                                 "smallest_positive_nuc_midpoint",
                                 "greatest_negative_nuc_midpoint"]]
                       .drop_duplicates("read_id"))
            ranking["read_count"] = np.arange(1, len(ranking) + 1)
        else:
            ranking = _order_reads_by_ac(df_reads, plot_window)
    return df_reads.merge(ranking[["read_id", "read_count"]], on="read_id", how="left")

# ───────────────────── Motifs: extract + tables + annotations ───────────── #
def _build_motif_df(read_df: pd.DataFrame, plot_window: int, debug: bool=False) -> pd.DataFrame:
    def _trim(row):
        starts = [] if pd.isna(row.motif_rel_start) else list(row.motif_rel_start)
        attrs  = [] if pd.isna(row.motif_attributes) else list(row.motif_attributes)
        keep   = [i for i,p in enumerate(starts) if -plot_window <= p <= plot_window]
        return (tuple(starts[i] for i in keep),
                tuple(attrs[i]  for i in keep))
    tmp = read_df.apply(_trim, axis=1, result_type="expand")
    tmp.columns = ["rel_start_list", "attributes_list"]
    exploded = (tmp.assign(pairs=tmp.apply(lambda r: list(zip(r.rel_start_list, r.attributes_list)), axis=1))
                   .explode("pairs").dropna(subset=["pairs"]))
    if exploded.empty:
        return pd.DataFrame(columns=["rel_start","motif_id","ln_p","strand","seq","num"])
    motif_df = pd.DataFrame(exploded["pairs"].tolist(), columns=["rel_start","attributes"])
    motif_df = motif_df.drop_duplicates().sort_values("rel_start").reset_index(drop=True)
    motif_df["num"] = motif_df.index + 1
    motif_df[["motif_id","ln_p","strand","seq"]] = pd.DataFrame(motif_df["attributes"].tolist(), index=motif_df.index)
    return motif_df[["rel_start","motif_id","ln_p","strand","seq","num"]]

def create_motif_table(read_df: pd.DataFrame, plot_window: int) -> Optional[go.Figure]:
    """Render the in-window motifs as a standalone Plotly table figure.

    Column widths are sized from the longest cell text in each column
    (7 px per character plus 24 px padding); the figure height grows with the
    number of motifs. Returns ``None`` when no motif is in the window.
    """
    motifs = _build_motif_df(read_df, plot_window)
    if motifs.empty:
        return None

    char_px, padding = 7, 24
    shown_cols = ["num", "motif_id", "ln_p", "strand", "seq"]
    col_widths = [
        max(len(str(cell)) for cell in motifs[name]) * char_px + padding
        for name in shown_cols
    ]

    header_spec = dict(
        values=["#", "Motif", "ln(p-val)", "Strand", "Sequence"],
        fill_color="white", font=dict(color="black", size=14),
        line_color="lightgrey", align="center",
    )
    cell_spec = dict(
        values=[motifs[name] for name in shown_cols],
        fill_color="white", font=dict(color="black", size=12),
        line_color="lightgrey", align="center",
        height=TABLE_CELL_HEIGHT,
    )
    table_fig = go.Figure(data=[go.Table(columnwidth=col_widths,
                                         header=header_spec,
                                         cells=cell_spec)])
    table_fig.update_layout(
        margin=dict(l=0, r=0, t=0, b=0),
        paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(0,0,0,0)",
        height=TABLE_HEADER_H + len(motifs) * TABLE_CELL_HEIGHT,
    )
    return table_fig

def create_gene_table(read_df: pd.DataFrame) -> Optional[go.Figure]:
    """Render the nearest flanking genes as a standalone Plotly table figure.

    The reference point is the ``chrom`` / ``bed_start`` of the first row of
    ``read_df``. Among rows of the global ``tss_tes_df`` on that chromosome
    whose ``type`` contains ``"gene_q"``, the closest gene ending before the
    reference point and the closest gene starting after it are listed (genes
    spanning the reference point are not considered).

    Returns ``None`` when ``read_df`` is empty, no gene rows exist for the
    chromosome, or no gene lies strictly on either side.
    """
    if read_df.empty:
        return None

    first_row_chrom = read_df["chrom"].iloc[0]
    anchor = int(read_df["bed_start"].iloc[0])

    on_chrom = tss_tes_df["chromosome"] == first_row_chrom
    is_gene = tss_tes_df["type"].str.contains("gene_q")
    genes = tss_tes_df.loc[on_chrom & is_gene]
    if genes.empty:
        return None

    genes = genes.assign(
        abs_start=genes["start"],
        abs_end=genes["end"],
        rel_start=genes["start"] - anchor,
        rel_end=genes["end"] - anchor,
        exp_quart=genes["type"].str.extract(r"_(q\d+)")[0],   # expression quartile tag, e.g. "q1"
    )

    before = genes.loc[genes["rel_end"] < 0]
    after = genes.loc[genes["rel_start"] > 0]
    flanking = []
    if not before.empty:
        flanking.append(before.loc[before["rel_end"].idxmax()])     # ends closest to the anchor
    if not after.empty:
        flanking.append(after.loc[after["rel_start"].idxmin()])     # starts closest to the anchor
    if not flanking:
        return None

    shown_cols = ["chromosome", "abs_start", "abs_end", "strand", "rel_start", "rel_end", "exp_quart"]
    gene_df = pd.DataFrame(flanking)[shown_cols]

    char_px, padding = 7, 24
    col_widths = [
        max(len(str(cell)) for cell in gene_df[name]) * char_px + padding
        for name in shown_cols
    ]

    header_spec = dict(
        values=["Chromosome", "Abs Start", "Abs End", "Strand", "Rel Start", "Rel End", "Exp Quart"],
        fill_color="white", font=dict(color="black", size=14),
        line_color="lightgrey", align="center",
    )
    cell_spec = dict(
        values=[gene_df[name] for name in shown_cols],
        fill_color="white", font=dict(color="black", size=12),
        line_color="lightgrey", align="center",
        height=TABLE_CELL_HEIGHT,
    )
    table_fig = go.Figure(data=[go.Table(columnwidth=col_widths,
                                         header=header_spec,
                                         cells=cell_spec)])
    table_fig.update_layout(
        margin=dict(l=0, r=0, t=0, b=0),
        paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(0,0,0,0)",
        height=TABLE_HEADER_H + len(gene_df) * TABLE_CELL_HEIGHT,
    )
    return table_fig

def _annotate_motifs(fig, motif_df: pd.DataFrame):
    """Mark every motif on ``fig`` with a dashed vertical line and its number.

    The line spans the full figure height (paper coordinates) at the motif's
    ``rel_start``; the motif's ``num`` is printed just above the plot area.
    Does nothing when ``motif_df`` is empty.
    """
    if motif_df.empty:
        return
    label_y = 1.02   # slightly above the top of the plotting area (paper coords)
    dashed_grey = dict(color="grey", width=1, dash="dash")
    for x_pos, number in zip(motif_df["rel_start"], motif_df["num"]):
        fig.add_shape(type="line", x0=x_pos, x1=x_pos, y0=0, y1=1,
                      xref="x", yref="paper",
                      line=dashed_grey)
        fig.add_annotation(x=x_pos, y=label_y, yref="paper",
                           text=str(number), showarrow=False,
                           font=dict(size=10),
                           bgcolor="rgba(255,255,255,0.5)")

# ───────────────────── %m6A series helper (smoothed) ─────────────────── #
def _m6a_series(plot_df: pd.DataFrame, cond: str, chr_type: str, data_type: str,
                plot_window: int, smoothing_type: str, smoothing_window: int,
                gaussian_sigma: float, strand_filter: Optional[Sequence[str]] = None):
    df = plot_df.loc[(plot_df["condition"] == cond)
                     & (plot_df["chr_type"] == chr_type)
                     & (plot_df["type"] == data_type)]
    if strand_filter is not None and "bed_strand" in df.columns:
        df = df.loc[df["bed_strand"].isin(list(strand_filter))]
    if df.empty:
        return pd.Series([], dtype=float), pd.Series([], dtype=float)
    recs = [(p, b)
            for r in df.itertuples()
            for p, b in zip(r.rel_pos, r.mod_qual_bin)
            if -plot_window <= p <= plot_window]
    if not recs:
        return pd.Series([], dtype=float), pd.Series([], dtype=float)
    agg = (pd.DataFrame(recs, columns=["pos","bin"])
             .groupby("pos")["bin"].agg(["sum","count"]).reset_index())
    agg["ratio"] = agg["sum"] / agg["count"]
    if smoothing_type == "none":
        y = agg["ratio"].to_numpy()
    elif smoothing_type == "moving":
        y = (agg["sum"].rolling(smoothing_window, center=True).sum() /
             agg["count"].rolling(smoothing_window, center=True).sum()).to_numpy()
    else:
        y = gaussian_filter1d(agg["ratio"].to_numpy(), sigma=gaussian_sigma)
    x = agg["pos"].to_numpy()
    mask = ~np.isnan(y)
    return pd.Series(x[mask]), pd.Series(y[mask])

# ─────────────────────────── Core multi-plot ─────────────────────────── #
def create_multi_plot(
    plot_df: pd.DataFrame,
    group_df: pd.DataFrame,         # kept for API parity
    conditions,                     # str | list[str] (1–4)
    chr_type: str,
    data_type: str,
    plot_window: int,
    *,
    require_full_span: bool = True,
    plot_motifs: bool = True,
    smoothing_type: str = "moving",
    smoothing_window: int = 50,
    gaussian_sigma: float = 10,
    subsample_reads: bool = True,
    max_reads_per_condition: Optional[int] = 300,
    random_state: Optional[int] = 42,
    strand_filter: Optional[Sequence[str]] = None,
    debug: bool = False
) -> Optional[go.Figure]:
    """Build read tracks plus %m6A panels for one to four conditions.

    Layout:
        * one condition  – row 1 is the read track, row 2 the %m6A curve;
        * n > 1          – rows 1..n-1 are %m6A panels comparing the first
          condition with each of the others, rows n..2n-1 are the read tracks
          in the order the conditions were given.

    Read tracks use the filtered (and optionally subsampled) reads, ordered
    per condition by ``_order_read_df``. The %m6A curves are computed by
    ``_m6a_series`` directly from ``plot_df``. All %m6A panels, including the
    single-condition one, share the same y-range (rounded up to the next even
    percent) and use percent tick labels.

    Args:
        plot_df: per-read table consumed by ``_filter_reads`` / ``_m6a_series``.
        group_df: unused; kept so existing call sites keep working.
        conditions: a single condition or a list/tuple of 1–4 conditions.
        chr_type, data_type: row selectors forwarded to the helpers.
        plot_window: half-width of the x-axis around the reference point.
        require_full_span: forwarded to ``_filter_reads``.
        plot_motifs: draw numbered motif markers when True.
        smoothing_type, smoothing_window, gaussian_sigma: forwarded to
            ``_m6a_series``.
        subsample_reads, max_reads_per_condition, random_state: when
            subsampling is on and a condition has more unique reads than the
            cap, a seeded random subset of that size is kept.
        strand_filter: optional ``bed_strand`` values to keep.
        debug: enables progress prints (together with ``DEBUG_SUMMARY``).

    Returns:
        The assembled figure, or ``None`` (after a warning) when no condition
        has any reads left after filtering.

    Raises:
        ValueError: if the number of conditions is not between 1 and 4.
    """
    conds = list(conditions) if isinstance(conditions, (list, tuple)) else [conditions]
    if not (1 <= len(conds) <= 4):
        raise ValueError("conditions must contain 1–4 entries")
    n = len(conds)

    # color map for all conditions (uses nanotools.get_color)
    color_for = {c: nanotools.get_color(c) for c in conds}

    # collect per-condition filtered reads
    cond_dfs = []
    for c in conds:
        df = _filter_reads(plot_df, c, chr_type, data_type, plot_window,
                           require_full_span, strand_filter=strand_filter, debug=debug)
        cond_dfs.append(df)

    if all(df.empty for df in cond_dfs):
        warnings.warn(f"No reads for {conds} / {data_type}.")
        return None

    # optional subsampling per condition
    rng = np.random.default_rng(random_state)
    if subsample_reads and max_reads_per_condition is not None:
        for i, df in enumerate(cond_dfs):
            if df.empty:
                continue
            uniq = df["read_id"].drop_duplicates()
            if len(uniq) > max_reads_per_condition:
                keep = rng.choice(uniq, size=max_reads_per_condition, replace=False)
                cond_dfs[i] = df[df["read_id"].isin(keep)].copy()
                if debug and DEBUG_SUMMARY:
                    print(f"[multi] {conds[i]} subsampled to {len(keep)} reads")

    # order reads independently per condition
    for i, df in enumerate(cond_dfs):
        if df.empty:
            continue
        cond_dfs[i] = _order_read_df(df, plot_window)

    # figure scaffold
    if n == 1:
        total_rows = 2                         # read, %m6A
        row_heights = [0.45, 0.55]
        read_rows   = [1]
        m6a_rows    = [2]
        m6a_pairs   = [(conds[0], None)]       # plot single series
    else:
        # rows = (n-1) combined panels (cond0 vs cond_i) + n read tracks
        total_rows  = (n - 1) + n
        # weights: combined panels heavier than read stripes → normalize
        weights = [3]*(n-1) + [2]*n
        s = float(sum(weights))
        row_heights = [w/s for w in weights]
        m6a_rows  = list(range(1, n))          # top panels
        read_rows = list(range(n, total_rows+1))
        m6a_pairs = [(conds[0], conds[i]) for i in range(1, n)]

    fig = make_subplots(rows=total_rows, cols=1, shared_xaxes=True, vertical_spacing=0.02,
                        row_heights=row_heights)
    fig.update_xaxes(range=[-plot_window, plot_window], tickmode="linear", tick0=0, dtick=300)

    # per-condition read layers + dots
    motif_collect = []
    for idx, (c, df_reads) in enumerate(zip(conds, cond_dfs)):
        if df_reads.empty:
            continue
        motif_collect.append(df_reads[["motif_rel_start", "motif_attributes"]])

        # read occupancy stripes
        _add_read_occupancy(
            fig, df_reads, plot_window,
            row=read_rows[idx], color=color_for[c], read_width=0.025
        )

        # m6A hit dots
        long_df = _long_format(df_reads, plot_window)
        hits = long_df[long_df["mod_qual_bin"] == 1]
        fig.add_trace(
            go.Scatter(
                x=hits["rel_pos"], y=hits["read_count"],
                mode="markers",
                marker=dict(symbol="square", size=1.5, color=color_for[c]),
                showlegend=False
            ),
            row=read_rows[idx], col=1
        )
        fig.update_yaxes(title_text="Read ID", row=read_rows[idx], col=1)

    # %m6A panels: compare cond0 vs each other (or cond0 alone when n == 1).
    # Each condition's series is computed once and reused across panels.
    series_cache = {}

    def _series_for(cond):
        if cond not in series_cache:
            series_cache[cond] = _m6a_series(
                plot_df, cond, chr_type, data_type,
                plot_window, smoothing_type, smoothing_window, gaussian_sigma,
                strand_filter=strand_filter)
        return series_cache[cond]

    global_max = 0.0
    for r, pair in zip(m6a_rows, m6a_pairs):
        for cond in pair:
            if cond is None:            # single-condition layout has no partner
                continue
            xs, ys = _series_for(cond)
            if not len(ys):
                continue
            global_max = max(global_max, float(ys.max()))
            fig.add_trace(
                go.Scatter(x=xs, y=ys, mode="lines",
                           line=dict(color=color_for[cond], width=2),
                           name=str(cond)),
                row=r, col=1
            )
        fig.update_yaxes(title_text="% m6A", row=r, col=1)

    # equalize y across all %m6A panels; values are fractions, shown as percent
    max_pct = math.ceil((max(global_max, 0.01) * 100) / 2) * 2
    ymax = max_pct / 100
    for r in m6a_rows:
        fig.update_yaxes(range=[0, ymax], tickformat=".0%", row=r, col=1)

    # motif annotations (single authoritative numbering across all conditions)
    if plot_motifs and motif_collect:
        merged = pd.concat(motif_collect, ignore_index=True)
        motif_df = _build_motif_df(merged, plot_window, debug=debug and DEBUG_SUMMARY)
        _annotate_motifs(fig, motif_df)

    # layout
    labels = [str(c) for c in conds]
    ttl = ", ".join(labels) if n == 1 else f"{labels[0]} vs " + ", ".join(labels[1:])
    fig.update_layout(template="plotly_white",
                      width=FIG_WIDTH, height=FIG_HEIGHT,
                      title=f"Conditions: {ttl} | Type: {data_type}",
                      legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1))
    fig.update_xaxes(title_text="Genomic position (bp)", row=total_rows, col=1)
    return fig

# ───────────────────────── Saving worker ───────────────────────────── #
def _save_multi_plot(args):
    """Render one multi-condition plot and save it as PNG and SVG.

    Worker for multiprocessing.
    args = (conditions_tuple, motif_type, chr_sel, win, out_root, strand_filter).

    Reads the notebook globals ``down_sampled_plot_df``,
    ``down_sampled_group_df`` and ``chip_rank_lookup``. When available, a motif
    table and a gene table are stacked (in that order) under the main figure
    in both output files. Files are written to
    ``<out_root>/<conditions>_<win>bp/<rank>_<motif_type>_b<win>_chr<chr_sel>.{png,svg}``
    where ``<rank>`` is the rounded chip rank or ``NA``.

    Returns:
        Path of the PNG written, or ``None`` when there is nothing to plot.
    """
    conds, motif_type, chr_sel, win, out_root, strand_filter = args

    mask = (down_sampled_plot_df["condition"].isin(conds)
            & (down_sampled_plot_df["chr_type"] == chr_sel)
            & (down_sampled_plot_df["type"] == motif_type))
    if strand_filter is not None and "bed_strand" in down_sampled_plot_df.columns:
        mask &= down_sampled_plot_df["bed_strand"].isin(list(strand_filter))
    if not mask.any():
        if DEBUG_PROGRESS:
            print(f"[skip] {conds} / {motif_type}: no rows after strand filter")
        return None

    fig = create_multi_plot(
        plot_df           = down_sampled_plot_df,
        group_df          = down_sampled_group_df,
        conditions        = list(conds),
        chr_type          = chr_sel,
        data_type         = motif_type,
        plot_window       = win,
        require_full_span = True,
        plot_motifs       = True,
        smoothing_type    = "moving",
        smoothing_window  = 50,
        gaussian_sigma    = 10,
        subsample_reads   = True,
        max_reads_per_condition = 300,
        random_state      = 42,
        strand_filter     = strand_filter,
        debug             = DEBUG_PROGRESS
    )
    if fig is None:
        return None

    read_subset = down_sampled_plot_df.loc[mask].copy()
    motif_tbl   = create_motif_table(read_subset, win)
    gene_tbl    = create_gene_table(read_subset)
    # Both helpers return None when they have nothing to show; test for that
    # explicitly rather than relying on the truthiness of a Figure object.
    tables = [tbl for tbl in (motif_tbl, gene_tbl) if tbl is not None]

    # directory + filenames
    cond_slug = "_".join(conds) if len(conds) == 1 else f"{conds[0]}_vs_" + "_".join(conds[1:])
    save_dir  = os.path.join(out_root, f"{cond_slug}_{win}bp")
    os.makedirs(save_dir, exist_ok=True)

    rank_val = chip_rank_lookup.get(motif_type, None)
    rank_pref = "NA" if rank_val is None else str(int(round(rank_val)))
    base      = f"{rank_pref}_{motif_type}_b{win}_chr{chr_sel}"
    png_path  = os.path.join(save_dir, base + ".png")
    svg_path  = os.path.join(save_dir, base + ".svg")

    # export main
    main_png = fig.to_image(format="png", scale=2)

    # stitching: append each table under what has been assembled so far
    if tables:
        stacked_png = main_png
        last = len(tables) - 1
        for k, tbl in enumerate(tables):
            stitched = stitch_vertical(stacked_png, tbl.to_image(format="png", scale=2))
            if k < last:
                # stitch_vertical is fed PNG bytes, so re-encode the intermediate image
                buf = io.BytesIO()
                stitched.save(buf, format="PNG")
                stacked_png = buf.getvalue()
    else:
        stitched = PILImage.open(io.BytesIO(main_png)).convert("RGBA")

    # write PNG
    stitched.save(png_path, dpi=(PNG_DPI, PNG_DPI))

    # assemble SVG stack: each table is shifted down by the heights above it
    # NOTE(review): the wrapping <svg> carries no width/height/viewBox — confirm
    # the target viewers do not clip the stacked tables.
    main_svg = fig.to_image(format="svg").decode("utf-8")
    if tables:
        parts = ['<svg xmlns="http://www.w3.org/2000/svg">', f'<g>{main_svg}</g>']
        y_offset = fig.layout.height + TABLE_MARGIN_PX
        for tbl in tables:
            tbl_svg = tbl.to_image(format="svg").decode("utf-8")
            parts.append(f'<g transform="translate(0,{y_offset})">{tbl_svg}</g>')
            y_offset = y_offset + tbl.layout.height + TABLE_MARGIN_PX
        parts.append('</svg>')
        combined_svg = "".join(parts)
    else:
        combined_svg = main_svg

    # the SVG text was decoded as UTF-8, so write it back with the same encoding
    # instead of the platform default
    with open(svg_path, "w", encoding="utf-8") as fh:
        fh.write(combined_svg)

    if DEBUG_PROGRESS:
        print(f"[saved] {png_path}")
    return png_path

# ─────────────────────────── Batch runner ───────────────────────────── #
def run_batch_save(
    CONDITIONS,
    MOTIF_TYPES,
    *,
    chr_sel: str = "X",
    win: Optional[int] = None,
    out_root: str = "images_multi",
    n_workers: Optional[int] = None,
    strand_filter: Optional[Sequence[str]] = None
):
    """Save one multi-condition plot per motif type, in parallel.

    Args:
        CONDITIONS: a single condition name or an iterable of them; every job
            plots the same set of conditions.
        MOTIF_TYPES: a single motif type or an iterable of them; one job each.
        chr_sel: chromosome group forwarded to the worker.
        win: plot half-window; falls back to the global ``temp_bed_w``.
        out_root: root output directory.
        n_workers: pool size; defaults to the CPU count capped at 50.
        strand_filter: optional ``bed_strand`` values to keep.

    Returns:
        List of PNG paths for the jobs that produced a plot.

    Raises:
        ValueError: if ``CONDITIONS`` or ``MOTIF_TYPES`` is ``None``, or if
            ``win`` is ``None`` and no global ``temp_bed_w`` exists.
    """
    if CONDITIONS is None or MOTIF_TYPES is None:
        raise ValueError("Define CONDITIONS and MOTIF_TYPES before calling run_batch_save().")
    if win is None:
        if "temp_bed_w" not in globals():
            raise ValueError("Provide 'win' or define global 'temp_bed_w'.")
        win = int(temp_bed_w)
    n_workers = n_workers or max(1, min(cpu_count(), 50))
    # A bare string is one name, not a sequence of characters.
    conds_tuple = (CONDITIONS,) if isinstance(CONDITIONS, str) else tuple(CONDITIONS)
    motif_types = [MOTIF_TYPES] if isinstance(MOTIF_TYPES, str) else list(MOTIF_TYPES)
    jobs = [(conds_tuple, mt, chr_sel, win, out_root, strand_filter) for mt in motif_types]
    results = run_in_pool(_save_multi_plot, jobs, n_workers=n_workers,
                          desc="Saving plots", unit="plot")
    return [p for p in results if p is not None]

# ───────────────────────── End unified cell ─────────────────────────── #
# 1) Choose 1–4 conditions in desired order (duplicates allowed)
#    Indices select entries of `analysis_cond`; the first selected condition is
#    the one every other condition is compared against in the %m6A panels.
COND_IDXS  = [0,3,4,5]   # e.g., [1], [1,0], [1,0,2], or [1,0,2,2]
CONDITIONS = tuple(analysis_cond[i] for i in COND_IDXS)

# 2) Choose motif types
#    The first assignment builds the full list from `type_selected`; the second
#    one overrides it, so only "MEX_motif" is actually plotted. Remove or edit
#    the override to plot other types.
MOTIF_TYPES = [f"MOTIFS_{t}" for t in type_selected]  # or specify a subset, e.g. ["MOTIFS_rex48"]
MOTIF_TYPES = ["MEX_motif"]
STRAND_FILTER = ["+","-"] # None or ["+"]
# 3) Parameters
CHR       = "X"                                         # chromosome group used in your filtered DF
WIN       = 1500 #temp_bed_w if "temp_bed_w" in globals() else 1000
OUT_ROOT  = "images_multi"                              # output root folder
N_WORKERS = min(50, cpu_count())                        # multiprocessing

# 4) Run batch save (one PNG + SVG per motif type; returns the PNG paths written)
saved_pngs = run_batch_save(CONDITIONS, MOTIF_TYPES, chr_sel=CHR, win=WIN,
                            out_root=OUT_ROOT, n_workers=N_WORKERS,
                            strand_filter=STRAND_FILTER)
print(f"Saved {len(saved_pngs)} PNG(s) under '{OUT_ROOT}'")

# 5) Optional inline preview
#    Shows the first motif type only. Note the preview differs from the saved
#    files: motif markers are off and the smoothing window is 25 instead of 50.
if MOTIF_TYPES:
    fig = create_multi_plot(
        plot_df=down_sampled_plot_df,
        group_df=down_sampled_group_df,
        conditions=CONDITIONS,
        chr_type=CHR,
        data_type=MOTIF_TYPES[0],
        plot_window=WIN,
        require_full_span=True,
        plot_motifs=False,
        smoothing_type="moving",
        smoothing_window=25,
        gaussian_sigma=10,
        subsample_reads=True,
        max_reads_per_condition=300,
        random_state=42,
        strand_filter=STRAND_FILTER,
        debug=False
    )
    if fig is not None:
        fig.show()

In [None]:
### CELL 1 · PREP FOR PLOTTING  ──────────────────────────────────────────────
import numpy as np
import pandas as pd
from scipy.signal import find_peaks
import multiprocessing as mp
from tqdm.auto import tqdm
import plotly.graph_objects as go

# ────────────────── OUTPUT LOCATION ──────────────────
# Directory for files produced by this cell; created if missing.
OUTPUT_DIR = "images_20250604"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ────────────────── CONFIGURATION ──────────────────
CHIP_RANK_CUTOFF        = 80                   # compared against chip_rank * 100 (see lookup below)
CHR_TYPE_INCLUDE        = ["X"]                # [] → include all chr_types
REL_POS_RANGE           = 500                  # reads/positions are kept within ±this many bp
N_BOOTSTRAP_GROUPS      = 2                    # number of group labels dealt out per (second_type, condition)
SMOOTH_WINDOW           = 100              # bp for moving-average smoothing
SMOOTH_WINDOW_METRICS   = 10 # for m6A peak and mean
RANDOM_STATE            = 36                   # seed for balancing and bootstrap-group shuffling
N_WORKERS               = max(2, mp.cpu_count() - 2)   # rebinds the N_WORKERS set by the batch-save cell
DEBUG                   = True                 # gates print_debug() output
PLOT_TEMPLATE           = "plotly_white"

DOWNSAMPLE_BALANCE      = True                       # ← set True to equalise reads per (type,condition)

PEAK_SELECTION_METHOD   = "closest_to_zero"  # "height" | "area" | "closest_to_zero"
CONDITIONS_TO_INCLUDE   = [analysis_cond[0], analysis_cond[1], analysis_cond[2]]

# NEW ▶ when True we keep “intergenic_control” reads, give them their own
#       second_type, and rename their type →  intergenic_control_<bed_start>
INCLUDE_INTERGENIC      = True             # ← set False to leave intergenic_control reads out

DEFAULT_READ_CLRS       = ["#4974a5", "#b12537", "#47B562", "#808080"]

def print_debug(msg: str):
    """Print ``msg`` with a ``[DEBUG]`` prefix when the global ``DEBUG`` flag is set."""
    if not DEBUG:
        return
    print(f"[DEBUG] {msg}")

# ───── 1) LOAD CHIP RANKS & BUILD LOOKUP ─────
# The rank file is whitespace-delimited and must provide `type` and `chip_rank`
# columns. Types are prefixed with "MOTIFS_" so they match the `type` values
# used for reads further down.
print_debug("Loading chip rank file and building lookup…")
chiprank_df = (
    pd.read_csv("/Data1/reference/rex_chiprank.bed", sep=r"\s+")
      .assign(type=lambda d: "MOTIFS_" + d["type"].astype(str))
)
# type → chip_rank scaled by 100 and rounded to 3 decimals
# (presumably chip_rank is a 0–1 fraction, giving a 0–100 scale — TODO confirm)
chip_rank_lookup = {
    t: round(float(rk) * 100, 3)
    for t, rk in zip(chiprank_df["type"], chiprank_df["chip_rank"])
}
print_debug(f"Loaded {len(chip_rank_lookup)} motif types from chip rank.")

# ───── 2) FILTER READS & ASSIGN SECOND_TYPE ─────
# Builds the working table `df` from `merged_df`: keep ranked motif types
# (plus intergenic controls if enabled), the selected chr_types and conditions,
# then label each read with a coarse `second_type` class.
print_debug("Filtering merged_df by type, chr_type, and selected conditions…")

# decide which read types are allowed
types_to_keep = set(chip_rank_lookup.keys())
if INCLUDE_INTERGENIC:
    types_to_keep.add("intergenic_control")

# an empty CHR_TYPE_INCLUDE means "no chr_type restriction"
keep_chr = set(CHR_TYPE_INCLUDE) if CHR_TYPE_INCLUDE else set(merged_df["chr_type"].unique())

df = merged_df[
    merged_df["type"].isin(types_to_keep)
    & merged_df["chr_type"].isin(keep_chr)
    & merged_df["condition"].isin(CONDITIONS_TO_INCLUDE)
].copy()
print_debug(f"After type/chr_type/condition filter: {len(df)} reads remain.")

# ── NEW second_type logic ────────────────────────────────────────────────
# 1) add numeric chip_rank (may be NaN for intergenic)
df["chip_rank"] = df["type"].map(chip_rank_lookup)

# 2) assign second_type:
#    intergenic controls keep their own label; everything else is split at
#    CHIP_RANK_CUTOFF (missing ranks count as below the cutoff).
#    NOTE: the "gt80"/"lt80" labels are fixed strings and do not follow
#    CHIP_RANK_CUTOFF if that value is changed.
df["second_type"] = np.where(
    df["type"] == "intergenic_control",
    "intergenic_control",
    np.where(df["chip_rank"].fillna(-np.inf) >= CHIP_RANK_CUTOFF, "gt80", "lt80")
)

# 3) rename type for intergenic reads ▸  intergenic_control_<bed_start>
#    (gives each control site its own `type`, like each motif has)
if INCLUDE_INTERGENIC:
    mask_intergenic = df["type"] == "intergenic_control"
    df.loc[mask_intergenic, "type"] = (
        df.loc[mask_intergenic, "type"]
        + "_"
        + df.loc[mask_intergenic, "bed_start"].astype(str)
    )

# chip_rank no longer needed
df.drop(columns="chip_rank", inplace=True)
print_debug("Assigned 'second_type'; renamed intergenic types if applicable.")

# Filter reads by overlap with ±REL_POS_RANGE
print_debug("Filtering reads by overlap with rel_pos window…")
# Keep a read only if `rel_pos` is a non-empty list/array with at least one
# position inside [-REL_POS_RANGE, REL_POS_RANGE]; any other cell value
# (e.g. NaN) drops the read.
mask_overlap = df["rel_pos"].apply(lambda arr: isinstance(arr, (list, np.ndarray)) and len(arr) > 0 and ((np.asarray(arr) >= -REL_POS_RANGE) & (np.asarray(arr) <= REL_POS_RANGE)).any())
df = df[mask_overlap].reset_index(drop=True)
print_debug(f"After overlap filter: {len(df)} reads remain.")

# Drop reads without any methylation data
# (requires a non-empty list/array whose NaN-ignoring sum is positive; the sum
#  is taken over the whole read, not just the ±REL_POS_RANGE window)
print_debug("Dropping reads with no methylation marks…")
mask_valid = df["mod_qual_bin"].apply(lambda x: isinstance(x, (list, np.ndarray)) and len(x) > 0 and np.nansum(x) > 0)
df = df[mask_valid].reset_index(drop=True)
print_debug(f"After methylation filter: {len(df)} reads remain.")

# ───────── OPTIONAL BALANCING OF READ COUNTS PER (type, condition) ─────────
# For every `type`, each condition is randomly down-sampled to the size of its
# smallest condition, so all conditions contribute the same number of reads.
# Types that have no reads at all in at least one condition are removed.
# A before/after count table is kept in `downsample_summary_df` and written to
# OUTPUT_DIR/downsample_summary.csv.
if DOWNSAMPLE_BALANCE and not df.empty:
    print_debug("Balancing read counts per (type, condition)…")
    rng = np.random.RandomState(RANDOM_STATE)

    # ── A) snapshot *initial* counts  ────────────────────────────────────────
    init_counts_df = (
        df.groupby(["type", "condition"])
          .size()
          .reset_index(name="initial_reads")
    )

    # 1) size table (rows = type, columns = condition; absent combos → 0)
    size_tbl = (
        df.groupby(["type", "condition"])
          .size()
          .unstack(fill_value=0)
    )

    # 2) drop types missing a condition
    types_to_drop = size_tbl[(size_tbl == 0).any(axis=1)].index
    if len(types_to_drop):
        print_debug(f"Dropping {len(types_to_drop)} type(s) with zero reads in ≥1 condition: "
                    f"{', '.join(map(str, types_to_drop))}")
        df = df[~df["type"].isin(types_to_drop)].reset_index(drop=True)
        size_tbl = size_tbl.loc[~size_tbl.index.isin(types_to_drop)]

    # 3) down-sample each (type,condition) to equal n
    balanced_parts = []
    for typ, sub_typ in df.groupby("type"):
        min_n = sub_typ.groupby("condition").size().min()
        for cond, sub_cond in sub_typ.groupby("condition"):
            if len(sub_cond) > min_n:
                keep_idx = rng.choice(sub_cond.index, size=min_n, replace=False)
                balanced_parts.append(sub_cond.loc[keep_idx])
            else:
                balanced_parts.append(sub_cond)
    # pd.concat() raises on an empty list, which happens when every type was
    # dropped in step 2; in that case df is already the (empty) result.
    if balanced_parts:
        df = (
            pd.concat(balanced_parts, ignore_index=True)
              .sample(frac=1.0, random_state=rng)    # optional shuffle
              .reset_index(drop=True)
        )
    print_debug(f"After balancing: {len(df)} reads remain.")

    # ── B) snapshot *balanced* counts & merge into summary table ─────────────
    bal_counts_df = (
        df.groupby(["type", "condition"])
          .size()
          .reset_index(name="balanced_reads")
    )

    downsample_summary_df = (
        pd.merge(init_counts_df, bal_counts_df,
                 on=["type", "condition"], how="outer")
          .fillna(0)                             # combos that were dropped → 0
          .astype({"initial_reads":"int", "balanced_reads":"int"})
          .sort_values(["type", "condition"])
          .reset_index(drop=True)
    )

    # Write to disk so later cells don’t need to recompute
    downsample_csv = os.path.join(OUTPUT_DIR, "downsample_summary.csv")
    downsample_summary_df.to_csv(downsample_csv, index=False)
    print_debug(f"Wrote down-sampling summary to {downsample_csv}")


# ────────────────── 3) ASSIGN BOOTSTRAP GROUPS ──────────────────
print_debug("Assigning bootstrap groups…")
# Fresh generator seeded with RANDOM_STATE, so group assignment does not depend
# on whether the balancing step above ran and consumed random numbers.
rng = np.random.RandomState(RANDOM_STATE)
# the labels dealt out to reads: 0 .. N_BOOTSTRAP_GROUPS-1
group_labels = np.arange(N_BOOTSTRAP_GROUPS)

def assign_bootstrap_groups(subdf):
    """
    Return a copy of ``subdf`` carrying a ``bootstrap_group`` column.

    The module-level ``group_labels`` are repeated until they cover every row,
    truncated to the row count, and shuffled with the shared ``rng``.
    An empty input is returned as a copy, with an empty integer
    ``bootstrap_group`` column added if it was missing.
    """
    n_rows = len(subdf)

    if n_rows > 0:
        repeats = int(np.ceil(n_rows / N_BOOTSTRAP_GROUPS))
        shuffled = np.tile(group_labels, repeats)[:n_rows]
        rng.shuffle(shuffled)  # in-place, consumes the shared generator
        labelled = subdf.copy()
        labelled["bootstrap_group"] = shuffled
        return labelled

    # Nothing to label: only guarantee that the column exists.
    if "bootstrap_group" in subdf.columns:
        return subdf.copy()
    return subdf.assign(bootstrap_group=pd.Series(dtype=int)).copy()

# Label reads independently within every (second_type, condition) stratum.
# group_keys=False keeps the group keys out of the result index, and the final
# reset_index gives a clean 0..n-1 index.
if not df.empty:
    df = (
        df
        .groupby(["second_type", "condition"], group_keys=False)
        .apply(assign_bootstrap_groups)
        .reset_index(drop=True)
    )
elif "bootstrap_group" not in df.columns : # Ensure column exists if df was empty initially
    df["bootstrap_group"] = pd.Series(dtype=int)
print_debug("Bootstrap groups assigned.")

# ────────────────── 4) FLATTEN READS INTO POSITION-LEVEL ROWS (PARALLEL) ──────────────────
print_debug("Flattening reads into position-level rows using multiprocessing…")

def flatten_single_read(rec_dict):
    """
    Expand one read (a plain dict) into per-position rows.

    A position is kept when ``|rel_pos| <= REL_POS_RANGE`` and its
    ``mod_qual_bin`` value is not NaN. Each returned dict holds the read's
    bootstrap_group, second_type, condition, type and chr_type, plus the
    integer ``rel_pos`` and the integer ``mod`` value of that position.
    Returns an empty list when no position qualifies.
    """
    positions = np.asarray(rec_dict["rel_pos"], dtype=int)
    calls = np.asarray(rec_dict["mod_qual_bin"], dtype=float)

    kept = [
        (int(p), int(c))
        for p, c in zip(positions, calls)
        if not (abs(p) > REL_POS_RANGE or np.isnan(c))
    ]
    if not kept:
        return []

    # Per-read metadata is identical for every emitted row; look it up once.
    meta_keys = ("bootstrap_group", "second_type", "condition", "type", "chr_type")
    shared = {key: rec_dict[key] for key in meta_keys}
    return [{**shared, "rel_pos": p, "mod": c} for p, c in kept]

# Convert DataFrame to list of plain dicts for picklable objects
records = df.to_dict(orient="records")

if records: # Only run multiprocessing if there are records to process
    with mp.Pool(N_WORKERS) as pool:
        # imap_unordered yields results in completion order, so the row order of
        # the flattened table is not tied to the order of `records`.
        all_lists = list(tqdm(
            pool.imap_unordered(flatten_single_read, records),
            total=len(records),
            desc="Flattening reads"
        ))
    # Each worker result is a list of row dicts; concatenate them into one list.
    flat_records = [item for sublist in all_lists for item in sublist]
else:
    flat_records = []

if flat_records:
    flat_df = pd.DataFrame(flat_records)
else: # Create empty DataFrame with correct columns if no records resulted from flattening
    flat_df = pd.DataFrame(columns=["bootstrap_group", "second_type", "condition", "type","chr_type", "rel_pos", "mod"])
print_debug(f"Flattened DataFrame has {len(flat_df)} rows.")


# ────────────────── 5) COMPUTE METHYLATION FRACTION BY POSITION ──────────────────
print_debug("Computing methylation fraction by position…")
if not flat_df.empty:
    # One row per (bootstrap_group, second_type, condition, type, chr_type, rel_pos):
    # total_reads = number of position-level rows, methylated = sum of `mod`.
    grouped = (
        flat_df
        .groupby(["bootstrap_group", "second_type", "condition", "type","chr_type","rel_pos"])
        .agg(total_reads  = ("mod", "size"),
             methylated   = ("mod", "sum"))
        .reset_index()
    )
    grouped["frac"] = grouped["methylated"] / grouped["total_reads"]
else: # Create empty grouped DataFrame with correct columns
    grouped = pd.DataFrame(columns=["bootstrap_group", "second_type", "condition", "type","chr_type", "rel_pos", "total_reads", "methylated", "frac"])
print_debug(f"Methylation grouping result has {len(grouped)} rows.")

# Prepare the full position index & all combos
all_positions = np.arange(-REL_POS_RANGE, REL_POS_RANGE + 1)
if not grouped.empty:
    # Distinct key tuples; each becomes one task for compute_peak_for_combo.
    combos = (
        grouped[["bootstrap_group", "second_type", "condition", "type","chr_type"]]
        .drop_duplicates()
        .reset_index(drop=True)
    )
else: # Create empty combos DataFrame
    combos = pd.DataFrame(columns=["bootstrap_group", "second_type", "condition", "type","chr_type"])
# NOTE: the combos also carry chr_type, although the message below does not list it.
print_debug(f"Total (bootstrap_group, second_type, condition, type) combos: {len(combos)}")


# ────────────────── 6) DEFINE PEAK COMPUTATION FUNCTION ──────────────────
def _half_max_bounds(profile, peak_idx):
    """Return (left_idx, right_idx) reached by walking out from peak_idx while the profile stays >= half the peak value.

    Each walk stops at the array edge or at the first sample that is below
    half-max; a NaN sample also stops it because ``NaN >= x`` is False.
    """
    half = profile[peak_idx] / 2.0
    left = right = peak_idx
    last = len(profile) - 1
    while left > 0 and profile[left] >= half:
        left -= 1
    while right < last and profile[right] >= half:
        right += 1
    return left, right


def _rolling_nanmean(series, window):
    """Return a centred rolling mean of ``series`` as an array; NaNs are skipped and all-NaN windows stay NaN."""
    if window > 1:
        return series.rolling(window=window, center=True, min_periods=1).mean().values
    return series.values


def compute_peak_for_combo(args):
    """
    Compute peak metrics for one
    (bootstrap_group, second_type, condition, type, chr_type) combo.

    Parameters
    ----------
    args : tuple
        ``(bootstrap_group, second_type, condition, type, chr_type)``.

    Returns
    -------
    dict
        The five key fields plus ``peak_loc``, ``peak_height``, ``peak_fwhm``,
        ``peak_area`` and ``peak_mean_m6A``. All five metrics are NaN when no
        peak is found.

    Raises
    ------
    ValueError
        If ``PEAK_SELECTION_METHOD`` is not "height", "area" or
        "closest_to_zero".

    Notes
    -----
    Reads the module-level ``grouped``, ``all_positions``, ``SMOOTH_WINDOW``,
    ``SMOOTH_WINDOW_METRICS`` and ``PEAK_SELECTION_METHOD``.
    Smoothing uses a rolling mean that ignores NaNs, leaving fully-NaN windows
    as gaps. Peak location and FWHM come from the SMOOTH_WINDOW profile;
    height, area and mean come from the SMOOTH_WINDOW_METRICS profile.
    """
    bg, st, cond, typ, chr_typ = args
    # Filter on every key of the combo, including chr_type, so rows belonging
    # to another chr_type can never be mixed into this profile.
    sub = grouped[
        (grouped["bootstrap_group"] == bg)
        & (grouped["second_type"]     == st)
        & (grouped["condition"]       == cond)
        & (grouped["type"]            == typ)
        & (grouped["chr_type"]        == chr_typ)
    ]

    # 1) raw fraction series over the full position range, NaN where no reads
    methyl_series = pd.Series(index=all_positions, dtype=float)
    if not sub.empty:
        methyl_series.loc[sub["rel_pos"].values] = sub["frac"].values

    # 2) two smoothed views: one for detection / FWHM, one for reported metrics
    methyl_smooth_detect = _rolling_nanmean(methyl_series, SMOOTH_WINDOW)
    methyl_smooth_metrics = _rolling_nanmean(methyl_series, SMOOTH_WINDOW_METRICS)

    # 3) find peaks on the *detection* profile
    # NOTE(review): scipy documents find_peaks as unreliable on data containing
    # NaNs; gaps can remain here when a whole window has no coverage.
    peaks, props = find_peaks(methyl_smooth_detect, height=0)
    if len(peaks) == 0:
        return {
            "bootstrap_group": bg, "second_type": st, "condition": cond,
            "type": typ, "chr_type": chr_typ,"peak_loc": np.nan, "peak_height": np.nan,
            "peak_fwhm": np.nan, "peak_area": np.nan,
            "peak_mean_m6A": np.nan
        }

    # Only the five tallest peaks are candidates for selection.
    peak_heights = props["peak_heights"]
    sorted_idxs  = np.argsort(peak_heights)[-5:]
    top_peaks    = peaks[sorted_idxs]

    # ── choose best peak among the candidates ──────────────────────────────
    if PEAK_SELECTION_METHOD == "height":
        chosen_idx_rel = sorted_idxs[np.argmax(peak_heights[sorted_idxs])]
    elif PEAK_SELECTION_METHOD == "area":
        # area = sum of the detection profile between the half-max bounds
        areas = []
        for rel_idx in sorted_idxs:
            lo, hi = _half_max_bounds(methyl_smooth_detect, peaks[rel_idx])
            areas.append(methyl_smooth_detect[lo:hi+1].sum())
        chosen_idx_rel = sorted_idxs[np.argmax(areas)]
    elif PEAK_SELECTION_METHOD == "closest_to_zero":
        locs = all_positions[top_peaks]
        chosen_idx_rel = sorted_idxs[np.argmin(np.abs(locs))]
    else:
        raise ValueError(f"Unknown method: {PEAK_SELECTION_METHOD}")

    # 4) ── metrics ---------------------------------------------------------
    peak_idx    = peaks[chosen_idx_rel]
    peak_loc    = all_positions[peak_idx]

    # use *metrics* smoothing for peak height
    peak_height = methyl_smooth_metrics[peak_idx]

    # FWHM boundaries based on detection smoothing
    li, ri = _half_max_bounds(methyl_smooth_detect, peak_idx)
    left_pos, right_pos = all_positions[li], all_positions[ri]

    # area & mean m6A use *metrics* smoothing within those boundaries
    window_vals   = methyl_smooth_metrics[li:ri+1]
    peak_area     = np.nansum(window_vals)
    peak_mean_m6A = np.nanmean(window_vals)

    return {
        "bootstrap_group": bg, "second_type": st, "condition": cond,
        "type": typ,
        "chr_type": chr_typ,
        "peak_loc":      peak_loc,
        "peak_height":   peak_height,
        "peak_fwhm":     right_pos - left_pos,
        "peak_area":     peak_area,
        "peak_mean_m6A": peak_mean_m6A
    }


# ────────────────── 7) PARALLEL PEAK METRICS COMPUTATION ──────────────────
print_debug("Starting parallel peak computation…")
# imap_unordered needs a plain list of tuples, one per combo
# (bootstrap_group, second_type, condition, type, chr_type).
if isinstance(combos, pd.DataFrame) and not combos.empty:
    combo_list = list(combos.itertuples(index=False, name=None))
elif isinstance(combos, list): # If it's already a list (e.g. from previous empty handling)
    combo_list = combos
else: # Handle case where combos might be an empty DataFrame or other
    combo_list = []


if combo_list: # Only proceed if there are combos to compute
    with mp.Pool(N_WORKERS) as pool:
        results = list(tqdm(
            pool.imap_unordered(compute_peak_for_combo, combo_list),
            total=len(combo_list),
            desc="Computing peaks"
        ))
    peaks_df = pd.DataFrame(results)
else:
    # Empty frame with the same columns a populated result has. "chr_type" is
    # included so downstream code that reads peaks_df["chr_type"] sees the same
    # schema whether or not any peaks were computed.
    peaks_df = pd.DataFrame(columns=["bootstrap_group", "second_type", "condition", "type",
        "chr_type",
        "peak_loc", "peak_height", "peak_fwhm",
        "peak_area", "peak_mean_m6A"])

print_debug(f"Computed peak metrics for {len(peaks_df)} combos.")

# Each row corresponds to exactly one (bootstrap_group, second_type, condition, type)
if not peaks_df.empty:
    assert peaks_df.groupby(["bootstrap_group", "second_type", "condition", "type"]).size().max() == 1, \
        "Peak computation resulted in duplicate entries for a combo."


# ────────────────── 8) FILTER PEAKS FOR SELECTED CONDITIONS ──────────────────
if not peaks_df.empty:
    peaks_df = peaks_df[peaks_df["condition"].isin(CONDITIONS_TO_INCLUDE)].reset_index(drop=True)
print_debug(f"Peaks_df now has {len(peaks_df)} rows after filtering")

# ────────────────── MERGE METRICS INTO DOWNSAMPLE SUMMARY ──────────────────
# ────────────────── BUILD BOOTSTRAP-LEVEL SUMMARY ──────────────────
# Only produced when balancing was requested and at least one peak row exists.
if DOWNSAMPLE_BALANCE and not peaks_df.empty:
    print_debug("Creating bootstrap-level down-sampling summary…")

    # A) balanced read counts *per bootstrap group*
    bal_counts_bg = (
        df.groupby(["type", "condition", "bootstrap_group"])
          .size()
          .reset_index(name="balanced_reads")
    )

    # B) peak-level metrics (already one row per bootstrap group)
    metrics_bg = peaks_df[[
        "type", "condition", "bootstrap_group",
        "peak_fwhm", "peak_loc",
        "peak_height", "peak_area", "peak_mean_m6A"
    ]]

    # C) merge → one row per (type, condition, bootstrap_group)
    # Left join: every read-count row is kept; groups without a peak row
    # (e.g. conditions filtered out of peaks_df) get NaN metrics.
    downsample_summary_bg = (
        bal_counts_bg.merge(metrics_bg,
                            on=["type", "condition", "bootstrap_group"],
                            how="left")
                     .sort_values(["type", "condition", "bootstrap_group"])
                     .reset_index(drop=True)
    )

    # D) write to disk
    downsample_csv = os.path.join(OUTPUT_DIR, "downsample_summary.csv")
    downsample_summary_bg.to_csv(downsample_csv, index=False)
    print_debug(f"Saved bootstrap-level summary to {downsample_csv}")


In [None]:
### CELL 2 · PLOTTING FWHM & MEAN-m6A  ───────────────────────────────────────
import os, itertools, numpy as np, pandas as pd
from scipy.stats import ttest_ind
import plotly.graph_objects as go

# peaks_df, CONDITIONS_TO_INCLUDE, DEFAULT_READ_CLRS, PLOT_TEMPLATE,
# print_debug are already in memory (created in CELL 1)

# ────────────────── REUSABLE BOXPLOT MAKER ──────────────────
def make_boxplot(
    y_col: str,
    title: str,
    yaxis_label: str,
    file_stub: str,
    y_range=None
):
    """Grouped box-plot split by (second_type × chr_type).

    Draws one box trace per entry of CONDITIONS_TO_INCLUDE from the notebook-level
    ``peaks_df``, annotates significant pairwise Welch t-tests per category,
    writes ``<file_stub>.png``, ``<file_stub>.svg`` and ``<file_stub>_stats.csv``
    into OUTPUT_DIR, and shows the figure.

    Parameters
    ----------
    y_col : column of ``peaks_df`` plotted on the y axis.
    title : figure title.
    yaxis_label : y-axis title.
    file_stub : base file name (no extension) for the saved outputs.
    y_range : optional ``[min, max]`` for the y axis; ``None`` leaves the range
        to Plotly.

    Side effects
    ------------
    Adds / overwrites the ``cat_x`` column on the global ``peaks_df``.
    """
    print_debug(f"Generating box plot for '{y_col}' …")
    fig = go.Figure()

    # ── 0) create a composite x-label in peaks_df just for plotting
    peaks_df["cat_x"] = (
        peaks_df["second_type"] + " | " + peaks_df["chr_type"]
    )

    # ── 1) traces
    for idx, cond_val in enumerate(CONDITIONS_TO_INCLUDE):
        # colours cycle when there are more conditions than palette entries
        color = DEFAULT_READ_CLRS[idx % len(DEFAULT_READ_CLRS)]
        sub   = peaks_df[peaks_df["condition"] == cond_val]
        if sub.empty:
            continue
        fig.add_trace(
            go.Box(
                x=sub["cat_x"], y=sub[y_col],
                name=cond_val,
                marker=dict(color=color, line=dict(width=1)),
                fillcolor="rgba(0,0,0,0)",
                boxmean=False, boxpoints="all",
                jitter=0.5, pointpos=0,
                marker_symbol="circle", marker_size=2,
                marker_color=color, marker_opacity=0.5,
                marker_line_width=0,
                legendgroup=cond_val,
                alignmentgroup="cat_x",
                offsetgroup=str(idx)
            )
        )

    # ── 2) layout
    fig.update_layout(
        template=PLOT_TEMPLATE,
        title=title,
        xaxis_title="Second_Type | Chr_Type",
        yaxis_title=yaxis_label,
        boxmode="group",
        width=700, height=650
    )
    if y_range is not None:
        fig.update_yaxes(range=y_range)

    # ── 3) ordering & labels ────────────────────────────────────────
    # order second_type first, then chr_type alphabetically
    # Only the second_type values listed in sec_order receive an explicit
    # position and tick label.
    sec_order = ["gt80", "lt80", "intergenic_control"]
    ordered_cats = []
    for sec in sec_order:
        chrs_here = sorted(
            peaks_df.loc[peaks_df["second_type"] == sec, "chr_type"].unique()
        )
        ordered_cats.extend([f"{sec} | {c}" for c in chrs_here])

    # compute counts per composite category
    # (n = number of distinct `type` values in each second_type × chr_type cell)
    comp_counts = (
        peaks_df
        .groupby(["second_type","chr_type"])["type"]
        .nunique()
        .to_dict()
    )

    # base labels for second_type
    base_labels = {
        "gt80":               "80th ChIP Percentile",
        "lt80":               "<80th ChIP Percentile",
        "intergenic_control": "Intergenic"
    }

    # build tickvals/ticktext
    tickvals = ordered_cats
    ticktext = []
    for cat in ordered_cats:
        # inverse of the " | " join used to build cat_x
        sec, chr_t = cat.split(" | ")
        n = comp_counts.get((sec, chr_t), 0)
        ticktext.append(f"{base_labels[sec]} | {chr_t} (n={n})")

    fig.update_xaxes(
        categoryorder="array",
        categoryarray=ordered_cats,
        tickmode="array",
        tickvals=tickvals,
        ticktext=ticktext
    )

    # mapping cat_x ⇢ position
    # Categories are addressed by their index in ordered_cats; each condition
    # is shifted by a multiple of OFFSET_STEP, centred on the category, to
    # place the ends of the significance bars.
    cat_to_num = {c:i for i,c in enumerate(ordered_cats)}
    total_conds = len(CONDITIONS_TO_INCLUDE)
    OFFSET_STEP = 0.2
    cond_offsets = {c:(i-(total_conds-1)/2)*OFFSET_STEP
                    for i,c in enumerate(CONDITIONS_TO_INCLUDE)}

    # ── 4) significance tests per combined category
    shapes, stats_rows = [], []
    for cat in ordered_cats:
        sub_cat = peaks_df[peaks_df["cat_x"] == cat]
        if sub_cat.empty: continue

        y_min, y_max = sub_cat[y_col].min(), sub_cat[y_col].max()
        # vertical spacing between stacked bars: 5 % of the data range, floored at 1e-3
        delta = max((y_max - y_min) * 0.05, 1e-3)
        comp_idx = 0  # number of significant comparisons already drawn for this category

        for cond1, cond2 in itertools.combinations(CONDITIONS_TO_INCLUDE, 2):
            v1 = sub_cat[sub_cat["condition"]==cond1][y_col].dropna()
            v2 = sub_cat[sub_cat["condition"]==cond2][y_col].dropna()
            # Welch's t-test (equal_var=False); skipped unless both groups have > 1 value
            stat, pval = (ttest_ind(v1, v2, equal_var=False)
                          if len(v1)>1 and len(v2)>1 else (np.nan, np.nan))

            lbl = ("***" if pval < 0.001 else
                   "**"  if pval < 0.01  else
                   "*"   if pval < 0.05  else "ns") if pd.notna(pval) else "ns"

            # only significant comparisons get a bar + star annotation
            if lbl in ("*","**","***"):
                y_pos = y_max + delta + comp_idx*delta
                bar_y = y_max + (delta/2) + comp_idx*delta
                cat_x_num = cat_to_num[cat]
                x1 = cat_x_num + cond_offsets[cond1]
                x2 = cat_x_num + cond_offsets[cond2]
                shapes.append(dict(type="line", x0=x1, x1=x2,
                                   y0=bar_y, y1=bar_y,
                                   xref="x", yref="y",
                                   line=dict(color="black", width=1)))
                fig.add_annotation(x=(x1+x2)/2, y=y_pos, text=lbl,
                                   showarrow=False, font=dict(size=14),
                                   xref="x", yref="y")
                comp_idx += 1

            # every comparison (significant or not) is recorded in the stats table
            stats_rows.append({
                "cat_x": cat, "condition_1": cond1, "condition_2": cond2,
                "n1": len(v1), "n2": len(v2),
                "mean1": v1.mean() if len(v1) else np.nan,
                "mean2": v2.mean() if len(v2) else np.nan,
                "t_statistic": stat, "p_value": pval, "label": lbl
            })

    if shapes:
        fig.update_layout(shapes=shapes)

    # ── 5) save + show
    fig.write_image(os.path.join(OUTPUT_DIR, f"{file_stub}.png"))
    fig.write_image(os.path.join(OUTPUT_DIR, f"{file_stub}.svg"))
    pd.DataFrame(stats_rows).to_csv(
        os.path.join(OUTPUT_DIR, f"{file_stub}_stats.csv"), index=False
    )
    print_debug(f"Saved {file_stub} plot & stats.")
    fig.show()


# ────────────────── PLOT 1: ORIGINAL FWHM ──────────────────
if not peaks_df.empty:
    make_boxplot(
        y_col="peak_fwhm",
        title="FWHM Distribution by Second_Type and Condition",
        yaxis_label="Full Width at Half-Max (bp)",
        file_stub="peaks_fwhm_boxplot",
        y_range=[50, 650]          # keep original explicit range
    )

    # ────────────────── PLOT 2: MEAN m6A WITHIN FWHM ──────────────────
    # (assumes 'peak_mean_m6A' added in CELL 1)
    make_boxplot(
        y_col="peak_mean_m6A",
        title="Mean m6A (within FWHM) by Second_Type and Condition",
        yaxis_label="Average m6A Fraction",
        file_stub="peaks_meanm6A_boxplot",
        y_range=None               # let Plotly auto-scale
    )

    # ────────────────── PLOT 3: PEAK HEIGHT ──────────────────
    # (uses the 'peak_height' column of peaks_df)
    make_boxplot(
        y_col="peak_height",
        title="m6A peak (within FWHM)",
        yaxis_label="m6A Peak",
        file_stub="peak_m6A_boxplot",
        y_range=[0,0.35],               # fixed explicit range (not auto-scaled)

    )

    # ────────────────── PLOT 4: INTEGRATED PEAK AREA ──────────────────
    # (uses the 'peak_area' column of peaks_df)
    make_boxplot(
        y_col="peak_area",
        title="Peak integrated area",
        yaxis_label="Peak integrated area",
        file_stub="integrated_area_boxplot",
        y_range=None               # let Plotly auto-scale
    )
else:
    print_debug("peaks_df is empty — skipping plot generation.")


In [None]:
import numpy as np
import pandas as pd
from scipy.signal import find_peaks
import plotly.graph_objects as go
from tqdm.auto import tqdm

# Assumes `grouped` and `peaks_df` already exist in the notebook’s namespace,
# computed by the previous cell.

# ────────────────── DEBUG PLOT: MULTI‐TRACE PER CONDITION/TYPE ──────────────────

# 1) Gather all unique (condition, type) combinations present in peaks_df
# NOTE: this rebinds the notebook-level names `combos` and `all_positions`.
combos = peaks_df[["condition", "type"]].drop_duplicates().reset_index(drop=True)
combo_tuples = [tuple(x) for x in combos.values]

all_positions = np.arange(-REL_POS_RANGE, REL_POS_RANGE + 1)

# We will build a single Figure with:
# - For each (cond, typ, bootstrap_group), one line for methyl_smooth
# - For each such combo, a vertical dashed line at peak_loc
# - For each such combo, a horizontal dotted line at half-max (approx. using FWHM)

# Keep track of which trace indices belong to each (cond, typ) pair
trace_indices_by_combo = {}

fig = go.Figure()
trace_counter = 0

for cond, typ in combo_tuples:
    # Filter peaks_df to get only rows matching this (cond, typ)
    sub_peaks = peaks_df[
        (peaks_df["condition"] == cond) &
        (peaks_df["type"] == typ)
    ]
    # If none, skip
    if sub_peaks.empty:
        continue

    indices = []  # will collect trace indices for this combo

    # For each bootstrap_group in this subset
    for bg in sorted(sub_peaks["bootstrap_group"].unique()):
        row = sub_peaks[sub_peaks["bootstrap_group"] == bg].iloc[0]
        peak_loc = row["peak_loc"]
        peak_height = row["peak_height"]
        peak_fwhm = row["peak_fwhm"]

        # Rebuild the *detection* profile exactly as compute_peak_for_combo does:
        # same key filter, NaN (not 0) where there is no coverage, and a centred
        # rolling mean that ignores NaNs. Zero-filling plus np.convolve would
        # drag the curve down near gaps and edges, so the plotted line would
        # not match the profile the peak was actually detected on.
        sub_group = grouped[
            (grouped["bootstrap_group"] == bg)
            & (grouped["second_type"] == row["second_type"])
            & (grouped["condition"] == cond)
            & (grouped["type"] == typ)
            & (grouped["chr_type"] == row["chr_type"])
        ]
        # reindex leaves uncovered positions as NaN and ignores positions that
        # fall outside the current REL_POS_RANGE.
        methyl_series = (
            sub_group.drop_duplicates("rel_pos", keep="last")
                     .set_index("rel_pos")["frac"]
                     .astype(float)
                     .reindex(all_positions)
        )

        if SMOOTH_WINDOW > 1:
            methyl_smooth = (
                methyl_series.rolling(window=SMOOTH_WINDOW,
                                      center=True, min_periods=1)
                             .mean()
                             .values
            )
        else:
            methyl_smooth = methyl_series.values

        # 1a) Plot methyl_smooth vs all_positions
        fig.add_trace(
            go.Scatter(
                x=all_positions,
                y=methyl_smooth,
                mode="lines",
                name=f"BG {bg}",
                line=dict(width=1),
                visible=False
            )
        )
        indices.append(trace_counter)
        trace_counter += 1

        # 1b) Vertical dashed line at peak_loc (from y=0 to y=peak_height)
        # (NaN coordinates, i.e. "no peak found", simply draw nothing)
        fig.add_trace(
            go.Scatter(
                x=[peak_loc, peak_loc],
                y=[0, peak_height],
                mode="lines",
                line=dict(color="black", dash="dash"),
                showlegend=False,
                visible=False
            )
        )
        indices.append(trace_counter)
        trace_counter += 1

        # 1c) Horizontal dotted line at half‐max over [left_pos, right_pos]
        # FWHM was measured on the detection profile, so half-max is taken from
        # that profile at the peak rather than from peak_height, which comes
        # from the separately smoothed metrics profile.
        if pd.notna(peak_loc) and abs(peak_loc) <= REL_POS_RANGE:
            half_val = methyl_smooth[int(peak_loc) + REL_POS_RANGE] / 2.0
        else:
            half_val = np.nan
        # Approximate left/right as peak_loc ± peak_fwhm/2
        left_pos = peak_loc - peak_fwhm / 2.0
        right_pos = peak_loc + peak_fwhm / 2.0
        fig.add_trace(
            go.Scatter(
                x=[left_pos, right_pos],
                y=[half_val, half_val],
                mode="lines",
                line=dict(color="black", dash="dot"),
                showlegend=False,
                visible=False
            )
        )
        indices.append(trace_counter)
        trace_counter += 1

    trace_indices_by_combo[(cond, typ)] = indices

# 2) Build dropdown buttons so user can pick (condition, type)
buttons = []
for (cond, typ), idx_list in trace_indices_by_combo.items():
    visibility = [False] * trace_counter
    for i in idx_list:
        visibility[i] = True
    buttons.append(
        dict(
            label=f"{cond}   |   {typ}",
            method="update",
            args=[
                {"visible": visibility},
                {"title": f"Condition: {cond} | Type: {typ}"}
            ]
        )
    )

# 3) Initialize the figure with the first combo visible by default
if buttons:
    first_combo = next(iter(trace_indices_by_combo))
    # Toggle visibility on the existing traces in place; there is no need to
    # reassign fig.data.
    for i in trace_indices_by_combo[first_combo]:
        fig.data[i].visible = True
    fig.update_layout(title=f"Condition: {first_combo[0]} | Type: {first_combo[1]}")

# 4) Add dropdown menu to layout
fig.update_layout(
    updatemenus=[
        dict(
            active=0,
            buttons=buttons,
            x=0.0,
            y=1.1,
            xanchor="left",
            yanchor="top",
            direction="down"
        )
    ],
    template="plotly_white",
    xaxis_title="Relative Position (bp)",
    yaxis_title="Methylation Fraction (smoothed)",
    width = 800
)

fig.show()


In [None]:
### Cross-correlation code
import numpy as np
import pandas as pd
import multiprocessing as mp
from functools import partial
from scipy.ndimage import gaussian_filter1d
from tqdm.auto import tqdm
import plotly.graph_objects as go
import plotly.express as px
from itertools import combinations
from scipy.stats import pearsonr
from plotly.subplots import make_subplots

# ────────────────── 1) CONFIGURATION ──────────────────
# Type selection: when TYPES_TO_INCLUDE is empty, types are chosen by comparing
# their chip rank (scaled ×100) with CHIP_RANK_CUTOFF; ABOVE_FLAG=True keeps
# ranks >= cutoff, False keeps ranks below it.
CHIP_RANK_CUTOFF       = 90
ABOVE_FLAG             = True
TYPES_TO_INCLUDE       = ["MOTIFS_rex48"]   # e.g. ["MOTIFS_rex14", ...]
CHR_TYPE_INCLUDE       = []                 # empty → keep every chr_type in merged_df

# Pre‐existing conditions list
CONDITIONS_TO_INCLUDE  = [
    analysis_cond[0],
    analysis_cond[1],
    analysis_cond[2],
    analysis_cond[3]
]

# Read filters: when REQUIRE_CENTRAL is True a read must reach at least
# MIN_READ_LENGTH // 2 on both sides of position 0; otherwise it only needs
# read_length >= MIN_READ_LENGTH.
MIN_READ_LENGTH        = 500
REQUIRE_CENTRAL        = False

# z-score each smoothed window before it is correlated
ZSCORE_BEFORE_AC       = False

# Fill the span between two consecutive methylated positions that are at most
# MET_DOMAIN_WIDTH samples apart.
PERFORM_FILLING        = True
MET_DOMAIN_WIDTH       = 9

# Replace NaN samples by a majority vote over ±INTERP_WINDOW samples.
PERFORM_INTERP         = True
INTERP_WINDOW          = 20

# Smoothing applied to every window (bound into `raw_smooth` further down).
RAW_SMOOTH_KIND        = "gaussian" # gaussian, moving
RAW_MOVING_WIN         = 50
RAW_GAUSS_SIGMA        = 5
RAW_CLAMP_01           = False

REL_POS_RANGE          = 2000

# ── REL_POS BIN WIDTH ──
# Window centres run from -REL_POS_RANGE to +REL_POS_RANGE in steps of
# REL_POS_BIN_WIDTH.
REL_POS_BIN_WIDTH      = 200
CENTER_BINS            = list(range(
    -REL_POS_RANGE,
     REL_POS_RANGE + 1,
     REL_POS_BIN_WIDTH
))

# Half-width of each window: a window holds 2 * WINDOW_SIZE + 1 samples.
WINDOW_SIZE            = 100

# 0 disables per-condition down-sampling.
MAX_READS_PER_COND     = 0
RANDOM_STATE           = 42

N_WORKERS              = max(2, mp.cpu_count() - 2)
CHUNK_SIZE             = 10    # chunksize handed to Pool.imap_unordered
DEBUG                  = True  # gates every print_debug() message in this cell

PLOT_TEMPLATE          = "plotly_white"
BOX_WIDTH              = 0.4    # width for violin traces


def print_debug(msg: str):
    """Print ``msg`` prefixed with ``[DEBUG]`` when the module-level DEBUG flag is truthy; otherwise do nothing."""
    if not DEBUG:
        return
    print(f"[DEBUG] {msg}")


# ────────────────── 2) FILTERING ──────────────────
# Whitespace-separated table; `type` values are prefixed with "MOTIFS_" before
# they are compared with the `type` column of merged_df in the query below.
chiprank_df = (
    pd.read_csv("/Data1/reference/rex_chiprank.bed", sep=r"\s+")
      .assign(type=lambda d: "MOTIFS_" + d["type"].astype(str))
)
# type → chip rank scaled by 100 and rounded to 3 decimals
chip_rank_lookup = {
    t: round(float(rk) * 100, 3)
    for t, rk in zip(chiprank_df["type"], chiprank_df["chip_rank"])
}

keep_conds = set(CONDITIONS_TO_INCLUDE)
# An explicit TYPES_TO_INCLUDE list wins; otherwise select by rank cutoff
# (ABOVE_FLAG picks which side of the cutoff is kept).
keep_types = (
    set(TYPES_TO_INCLUDE)
    if TYPES_TO_INCLUDE
    else {t for t, r in chip_rank_lookup.items()
          if (r >= CHIP_RANK_CUTOFF) == ABOVE_FLAG}
)
keep_chr = (
    set(CHR_TYPE_INCLUDE)
    if CHR_TYPE_INCLUDE
    else set(merged_df["chr_type"].unique())
)

df0 = merged_df.query(
    "condition in @keep_conds and type in @keep_types and chr_type in @keep_chr"
).copy()
print(f"[INFO] after metadata filter: {len(df0)} reads")

# Keep reads with at least one position inside ±REL_POS_RANGE.
# NOTE(review): the comparison assumes each `rel_pos` cell is a NumPy array
# (a plain list would not support `arr >= ...`) — confirm upstream.
mask_overlap = df0["rel_pos"].apply(
    lambda arr: ((arr >= -REL_POS_RANGE) & (arr <= REL_POS_RANGE)).any()
)
df0 = df0[mask_overlap].reset_index(drop=True)
print(f"[INFO] after overlap filter:  {len(df0)} reads")

if REQUIRE_CENTRAL:
    # The read must extend at least `half` bp on both sides of position 0.
    # In this branch the read_length column is not checked.
    half = MIN_READ_LENGTH // 2
    mask_central = df0["rel_pos"].apply(
        lambda arr: (arr.min() <= -half) and (arr.max() >= half)
    )
    filtered_reads_df = df0[mask_central].reset_index(drop=True)
else:
    filtered_reads_df = df0[df0["read_length"] >= MIN_READ_LENGTH].reset_index(drop=True)
print(f"[INFO] after length/central filter: {len(filtered_reads_df)} reads")

# Drop reads whose mod_qual_bin is not a list/array or whose NaN-ignoring sum
# is not positive.
mask_valid = filtered_reads_df["mod_qual_bin"].apply(
    lambda x: isinstance(x, (list, np.ndarray)) and np.nansum(x) > 0
)
filtered_reads_df = filtered_reads_df[mask_valid].reset_index(drop=True)
print(f"[INFO] after methylation filter: {len(filtered_reads_df)} reads")


# ────────────────── 3) OPTIONAL DOWNSAMPLING ──────────────────
# MAX_READS_PER_COND > 0 caps the number of reads per condition; conditions
# with fewer reads are kept whole. The same RANDOM_STATE seeds every
# per-condition sample.
if MAX_READS_PER_COND > 0:
    reads_to_process = (
        filtered_reads_df
          .groupby("condition", group_keys=False)
          .apply(lambda df: df.sample(
              n=min(len(df), MAX_READS_PER_COND),
              random_state=RANDOM_STATE
          ))
          .reset_index(drop=True)
    )
    print_debug(f"Down‑sampled: {len(reads_to_process)} / {len(filtered_reads_df)} reads")
else:
    reads_to_process = filtered_reads_df.copy()
    print_debug(f"Using all {len(reads_to_process)} reads")

# Plain dicts (one per read) are what gets shipped to the worker processes.
records = reads_to_process.to_dict(orient="records")


# ────────────────── 4) SMOOTHING HELPERS ──────────────────
def _fill_met_domains(arr, width):
    """
    Bridge short gaps between consecutive methylated samples.

    For every pair of neighbouring positions whose value equals 1 and whose
    index distance is at most ``width``, the whole span between them
    (inclusive) is set to 1 in a copy of ``arr``. When fewer than two samples
    equal 1, ``arr`` itself is returned unchanged (no copy).
    """
    marked = np.where(arr == 1)[0]
    if len(marked) < 2:
        return arr

    filled = arr.copy()
    gaps = np.diff(marked)
    for start, gap in zip(marked[:-1], gaps):
        if gap > width:
            continue
        filled[start : start + gap + 1] = 1
    return filled

def _mode_interpolate(arr, radius):
    """
    Fill NaN samples by a majority vote over the neighbouring valid samples.

    For each NaN at index i, the valid (non-NaN) samples within
    ``[i - radius, i + radius]`` (clipped to the array) are counted. The NaN
    becomes 1.0 when strictly more than half of them equal 1, otherwise 0.0
    (also when there is no valid sample in range). Returns ``arr`` itself when
    it contains no NaN, otherwise a filled copy.
    """
    missing = np.isnan(arr)
    if not missing.any():
        return arr

    # Prefix sums give the number of valid / methylated samples in any slice.
    valid_prefix = np.concatenate(([0], np.cumsum((~missing).astype(int))))
    ones_prefix = np.concatenate(([0], np.cumsum(((arr == 1) & ~missing).astype(int))))

    gaps = np.where(missing)[0]
    lower = np.maximum(gaps - radius, 0)
    upper = np.minimum(gaps + radius, len(arr) - 1) + 1
    n_valid = valid_prefix[upper] - valid_prefix[lower]
    n_ones = ones_prefix[upper] - ones_prefix[lower]

    filled = arr.copy()
    # n_valid == 0 implies n_ones == 0, so the vote yields 0.0 in that case too.
    filled[gaps] = np.where(n_ones > n_valid / 2, 1.0, 0.0)
    return filled

def _apply_smoothing(y, kind, moving_win, gauss_sigma, clamp):
    """
    Smooth a 1-D signal and optionally clip it to [0, 1].

    ``kind == "moving"`` with ``moving_win > 1`` applies a box filter of that
    width (``np.convolve`` in "same" mode). ``kind == "gaussian"`` with
    ``gauss_sigma > 0`` applies ``gaussian_filter1d`` with "nearest" edge
    handling. Any other combination leaves ``y`` unsmoothed. With ``clamp``
    truthy the result is clipped to [0, 1].
    """
    use_box = kind == "moving" and moving_win > 1
    use_gauss = kind == "gaussian" and gauss_sigma > 0

    if use_box:
        box = np.ones(moving_win) / moving_win
        y = np.convolve(y, box, mode="same")
    elif use_gauss:
        y = gaussian_filter1d(y, sigma=gauss_sigma, mode="nearest")

    if clamp:
        return np.clip(y, 0, 1)
    return y

# Single-argument smoother, raw_smooth(y), with the RAW_* configuration bound in.
raw_smooth = partial(
    _apply_smoothing,
    kind=RAW_SMOOTH_KIND,
    moving_win=RAW_MOVING_WIN,
    gauss_sigma=RAW_GAUSS_SIGMA,
    clamp=RAW_CLAMP_01
)


# ────────────────── 5) EXTRACT & SMOOTH WINDOWS ──────────────────
def extract_and_smooth(rec):
    """
    Cut one read into fixed-width windows around every centre in CENTER_BINS.

    Parameters
    ----------
    rec : dict
        One read with keys ``read_id``, ``condition``, ``type``, ``rel_pos``
        (integer positions) and ``mod_qual_bin`` (per-position values, NaN
        allowed), aligned element by element.

    Returns
    -------
    pandas.DataFrame
        One row per centre with columns ``read_id``, ``type``, ``condition``,
        ``center``, ``vec_smoothed`` and ``has_marks``. ``vec_smoothed`` is an
        array of length ``2 * WINDOW_SIZE + 1`` or ``None``; ``has_marks`` is
        False exactly when ``vec_smoothed`` is ``None``. That happens when the
        read does not span the whole window or has no non-NaN sample inside it.

    Notes
    -----
    Processing order for a usable window: optional gap filling
    (PERFORM_FILLING), optional NaN interpolation (PERFORM_INTERP),
    ``raw_smooth``, then optional z-scoring (ZSCORE_BEFORE_AC).
    A read without any position yields only ``has_marks=False`` rows instead
    of raising.
    """
    read_id = rec["read_id"]
    cond    = rec["condition"]
    rtype   = rec["type"]
    rel_pos = np.asarray(rec["rel_pos"], dtype=int)
    signal  = np.asarray(rec["mod_qual_bin"], dtype=float)

    def _row(center, vec):
        # Single place that defines the output schema.
        return {
            "read_id":     read_id,
            "type":        rtype,
            "condition":   cond,
            "center":      center,
            "vec_smoothed": vec,
            "has_marks":   vec is not None
        }

    # min()/max() of an empty array raise; an empty read covers no window.
    if rel_pos.size == 0:
        return pd.DataFrame([_row(center, None) for center in CENTER_BINS])

    # A window is usable only if the read extends WINDOW_SIZE to both sides.
    min_center = rel_pos.min() + WINDOW_SIZE
    max_center = rel_pos.max() - WINDOW_SIZE

    rows = []
    for center in CENTER_BINS:
        if center < min_center or center > max_center:
            rows.append(_row(center, None))
            continue

        # Vectorised placement of the read's samples into the window. With
        # repeated positions the last sample wins, as with a dict lookup.
        offsets = rel_pos - center
        in_window = np.abs(offsets) <= WINDOW_SIZE
        vec = np.full(2 * WINDOW_SIZE + 1, np.nan)
        vec[offsets[in_window] + WINDOW_SIZE] = signal[in_window]

        if np.isnan(vec).all():
            rows.append(_row(center, None))
            continue

        if PERFORM_FILLING:
            vec = _fill_met_domains(vec, MET_DOMAIN_WIDTH)
        if PERFORM_INTERP:
            vec = _mode_interpolate(vec, INTERP_WINDOW)
        vec = raw_smooth(vec)

        valid_pts = ~np.isnan(vec)
        if ZSCORE_BEFORE_AC and valid_pts.sum() > 1:
            m, s = np.nanmean(vec[valid_pts]), np.nanstd(vec[valid_pts])
            # a constant window has no spread: map it to all zeros
            vec[valid_pts] = (vec[valid_pts] - m)/s if s > 0 else 0.0

        rows.append(_row(center, vec))

    return pd.DataFrame(rows)


print_debug(f"Launching pool with {N_WORKERS} workers for smoothing")
with mp.Pool(N_WORKERS) as pool:
    # One small DataFrame per read, returned in completion order.
    it = pool.imap_unordered(extract_and_smooth, records, chunksize=CHUNK_SIZE)
    dfs = [df for df in tqdm(it, total=len(records), desc="smoothing")]
# NOTE: pd.concat raises on an empty list, so this step needs at least one read.
all_windows = pd.concat(dfs, ignore_index=True)
print(f"[INFO] all_windows shape: {all_windows.shape}")


# ────────────────── 6) PERCENT WITH ZERO‑METHYLATION PER BIN ──────────────────
# pct_zero = share of windows per (condition, center) flagged has_marks == False,
# i.e. windows for which extract_and_smooth produced no vector.
total_by_bin = all_windows.groupby(["condition", "center"]).size().reset_index(name="total")
no_mark_by_bin = all_windows[all_windows["has_marks"] == False] \
                  .groupby(["condition", "center"]).size().reset_index(name="no_marks")
pct_zero_df = pd.merge(total_by_bin, no_mark_by_bin, on=["condition", "center"], how="left")
# bins without any no-mark window come out of the left join as NaN → 0
pct_zero_df["no_marks"] = pct_zero_df["no_marks"].fillna(0).astype(int)
pct_zero_df["pct_zero"] = pct_zero_df["no_marks"] / pct_zero_df["total"]


# ────────────────── 7) FILTER OUT NO‑MARK WINDOWS FOR CORRELATION ──────────────────
windows_with_marks = all_windows[all_windows["has_marks"]].dropna(subset=["vec_smoothed"]).reset_index(drop=True)
print_debug(f"Windows with marks: {len(windows_with_marks)} (out of {len(all_windows)})")


# ────────────────── 8) PAIRWISE CROSS‑CORRELATION WITHIN (type, condition, center) ──────────────────
def compute_pairwise_corr(group_df):
    """
    Correlate every unordered pair of ``vec_smoothed`` arrays in one group.

    ``group_df`` is expected to hold the rows of a single
    (type, condition, center) group; the labels are read from its first row.
    For each pair, only positions that are non-NaN in *both* vectors are
    used, and pairs with fewer than two such positions are skipped.

    Returns a list of dicts with keys 'type', 'condition', 'center', 'corr'
    (empty when the group has fewer than two rows).
    """
    rtype = group_df["type"].iloc[0]
    cond = group_df["condition"].iloc[0]
    center = group_df["center"].iloc[0]
    labelled = list(zip(group_df["read_id"].tolist(),
                        group_df["vec_smoothed"].tolist()))
    results = []

    if len(labelled) < 2:
        return results

    for (rid1, v1), (rid2, v2) in combinations(labelled, 2):
        # positions observed in both reads
        mask = ~(np.isnan(v1) | np.isnan(v2))
        valid_n = mask.sum()
        if valid_n < 2:
            continue

        raw_r = pearsonr(v1[mask], v2[mask])[0]

        # NaN compares False on both sides, so a NaN r is passed through as is.
        out_of_range = raw_r > 1.0 or raw_r < -1.0
        if out_of_range:
            # Report the offending pair, then clamp into the valid range.
            summary_v1 = v1[mask][:5]
            summary_v2 = v2[mask][:5]
            print_debug(
                f"EXCESS CORR r={raw_r:.5f} for reads ({rid1}, {rid2}), "
                f"center={center}, overlap={valid_n} "
                f"first5(v1)={summary_v1}, first5(v2)={summary_v2}"
            )
            r = np.clip(raw_r, -1.0, 1.0)
            print_debug(f"  Clamped to r={r:.5f}")
        else:
            r = raw_r

        results.append({
            "type":      rtype,
            "condition": cond,
            "center":    center,
            "corr":      float(r)
        })

    return results


# Correlations are only computed between reads that share type, condition
# and window center.
grouped = windows_with_marks.groupby(["type", "condition", "center"])
pairwise_results = []
for _, group_df in tqdm(grouped, desc="pairwise groups"):
    pairwise_results += compute_pairwise_corr(group_df)

pair_corrs_df = pd.DataFrame(pairwise_results)
print(f"[INFO] Total pairwise correlations computed: {len(pair_corrs_df)}")


# ────────────────── 9) PREPARE FOR PLOTTING ──────────────────
# Ordered category labels for the x axis. bins[:-1] drops the final edge, so
# (when REL_POS_BIN_WIDTH divides the range evenly) there is no label for
# +REL_POS_RANGE itself.
# NOTE(review): if any window center equals +REL_POS_RANGE, its bin_str has no
# matching label and it is dropped by the reindex calls below — confirm.
bins = np.arange(-REL_POS_RANGE, REL_POS_RANGE + REL_POS_BIN_WIDTH, REL_POS_BIN_WIDTH)
bin_labels_str = [str(int(b)) for b in bins[:-1]]

# String version of the center, used as the categorical x value in both figures.
pair_corrs_df["bin_str"] = pair_corrs_df["center"].astype(int).astype(str)
pct_zero_df["bin_str"]  = pct_zero_df["center"].astype(int).astype(str)


# ────────────────── 10) COMPUTE MEDIAN & IQR FOR EACH (condition, bin) ──────────────────
# NaN-aware median and quartiles of the pairwise r values per (condition, bin).
summary_stats = (
    pair_corrs_df
      .groupby(["condition", "bin_str"])["corr"]
      .agg([
          ("median", lambda arr: float(np.nanmedian(arr))),
          ("q1",     lambda arr: float(np.nanpercentile(arr, 25))),
          ("q3",     lambda arr: float(np.nanpercentile(arr, 75)))
      ])
      .reset_index()
)

# Per condition: three lists aligned to bin_labels_str (NaN where a bin has
# no correlations), ready to be passed to plotly as y values.
stats_by_cond = {}
for cond in summary_stats["condition"].unique():
    df_sub = summary_stats[summary_stats["condition"] == cond].set_index("bin_str")
    stats_by_cond[cond] = {
        "median": df_sub["median"].reindex(bin_labels_str, fill_value=np.nan).tolist(),
        "q1":     df_sub["q1"].reindex(bin_labels_str, fill_value=np.nan).tolist(),
        "q3":     df_sub["q3"].reindex(bin_labels_str, fill_value=np.nan).tolist()
    }

# ────────────────── After computing pair_corrs_df ──────────────────
# Sanity check on the range of r. Note that the clamp below runs AFTER
# summary_stats was computed, so it only affects the violin traces.

print("\n>> CORR SUMMARY BEFORE PLOTTING:")
print("   min corr =", pair_corrs_df["corr"].min())
print("   max corr =", pair_corrs_df["corr"].max(), "\n")

# If you still see values outside [-1, 1], force‑clamp:
pair_corrs_df["corr"] = pair_corrs_df["corr"].clip(-1.0, 1.0)

print(">> CORR SUMMARY AFTER CLAMP:")
print("   min corr =", pair_corrs_df["corr"].min())
print("   max corr =", pair_corrs_df["corr"].max(), "\n")

# ────────────────── 11) VIOLIN + MEDIAN/IQR RIBBON PLOTS (STACKED) ──────────────────
# Two stacked panels sharing the x axis: violins of all pairwise r values on
# top (70 % of the height), median line with an IQR ribbon underneath.
fig = make_subplots(
    rows=2, cols=1,
    shared_xaxes=True,
    vertical_spacing=0.10,
    row_heights=[0.7, 0.3],
    subplot_titles=[
        "Pairwise Cross‑Correlation Distributions by Bin & Condition (Violin)",
        "Median & IQR of Pairwise r by Bin & Condition"
    ]
)

# Row 1: one violin trace per condition; offsetgroup places the conditions
# side by side within each bin.
for cond in sorted(pair_corrs_df["condition"].unique()):
    sub = pair_corrs_df[pair_corrs_df["condition"] == cond]
    fig.add_trace(
        go.Violin(
            x=sub["bin_str"],
            y=sub["corr"],
            name=cond,
            legendgroup=cond,
            scalegroup=cond,
            offsetgroup=cond,
            width=BOX_WIDTH,
            points=False,
            showlegend=True
        ),
        row=1, col=1
    )

# Row 2: colours are assigned by position in the sorted condition list and
# wrap around if there are more conditions than palette entries.
color_map = px.colors.qualitative.Plotly
cond_list = sorted(summary_stats["condition"].unique())
for idx, cond in enumerate(cond_list):
    color = color_map[idx % len(color_map)]
    stats = stats_by_cond[cond]
    med = stats["median"]
    q1  = stats["q1"]
    q3  = stats["q3"]

    # Closed polygon for the ribbon: q3 left-to-right, then q1 right-to-left.
    band_x = bin_labels_str + bin_labels_str[::-1]
    band_y = q3 + q1[::-1]
    fig.add_trace(
        go.Scatter(
            x=band_x,
            y=band_y,
            fill="toself",
            # parse the "#RRGGBB" palette entry into an rgba() string at 20 % opacity
            fillcolor=f"rgba({int(color[1:3],16)},{int(color[3:5],16)},{int(color[5:7],16)},0.2)",
            line=dict(color="rgba(0,0,0,0)"),
            showlegend=False
        ),
        row=2, col=1
    )

    # Median line drawn on top of the ribbon.
    fig.add_trace(
        go.Scatter(
            x=bin_labels_str,
            y=med,
            mode="lines",
            line=dict(color=color, width=2),
            name=cond,
            legendgroup=cond,
            showlegend=True
        ),
        row=2, col=1
    )

fig.update_layout(
    template=PLOT_TEMPLATE,
    height=900,
    width=1200
)
# Force a categorical x axis in the order given by bin_labels_str.
fig.update_xaxes(
    title_text="Relative Position Bin (left edge, bp)",
    type="category",
    categoryorder="array",
    categoryarray=bin_labels_str,
    row=2, col=1
)
# The top panel shares its x axis with the bottom one, so its own axis is hidden.
fig.update_xaxes(visible=False, row=1, col=1)
fig.update_yaxes(title_text="Pairwise Pearson r", row=1, col=1)
fig.update_yaxes(title_text="Median r (with IQR)", row=2, col=1)
fig.show()


# ────────────────── 12) LINE PLOT: % ZERO‑MARK WINDOWS PER BIN ──────────────────
# One line per condition showing the share of windows without marks per bin.
fig2 = go.Figure()
for idx, cond in enumerate(sorted(pct_zero_df["condition"].unique())):
    sub = pct_zero_df[pct_zero_df["condition"] == cond].copy()
    # Align to the full label list. Bins with no rows for this condition are
    # filled with 0 in every column, so they plot as 0 % rather than as a gap.
    sub = sub.set_index("bin_str").reindex(bin_labels_str, fill_value=0).reset_index()
    fig2.add_trace(
        go.Scatter(
            x=sub["bin_str"],
            # pct_zero is stored as a fraction; convert to percent for display
            y=sub["pct_zero"] * 100,
            mode="lines+markers",
            name=cond,
            line=dict(width=2),
            marker=dict(size=6)
        )
    )

fig2.update_layout(
    template=PLOT_TEMPLATE,
    title="% Windows with Zero Methylation by Bin & Condition",
    xaxis=dict(
        title="Relative Position Bin (left edge, bp)",
        type="category",
        categoryorder="array",
        categoryarray=bin_labels_str
    ),
    yaxis=dict(
        title="Fraction with Zero Marks (%)",
        ticksuffix="%"
    ),
    width=1200,
    height=500
)
fig2.show()


In [None]:
# SINGLE FIBER CROSS CORRELATION
#  PIPELINE  v3‑with‑mean/ideal PLOT (per‑type ideal patterns)
#  • NaN‑aware smoothing for *all* signals (reads + per‑type mean)
#  • automatic PAD_LEN check so every CENTER_BIN has a pattern segment
#  • extra debug prints to verify coverage per bin & condition
#  • UPDATED: compute & shift one “ideal pattern” per `type`, then correlate each read only to its own type’s pattern
#
# Assumptions: `merged_df`, `analysis_cond` and other upstream objects exist.
# Drop‑in replacement for the previous cell.  No mock data is generated here.

import numpy as np, pandas as pd, multiprocessing as mp
from scipy.ndimage import gaussian_filter1d
from scipy.stats import pearsonr
from functools import partial
from tqdm.auto import tqdm
import plotly.graph_objects as go, plotly.express as px
from plotly.subplots import make_subplots
import collections, itertools, warnings

# ─────────────────── CONFIG ─────────────────── #
# Read-type selection. chip_rank is scaled to 0–100 before being compared
# with CHIP_RANK_CUTOFF; ABOVE_FLAG=True keeps types at or above the cutoff,
# False keeps those below. A non-empty TYPES_TO_INCLUDE bypasses the rank
# filter; an empty CHR_TYPE_INCLUDE means "all chr_type values".
CHIP_RANK_CUTOFF = 80
ABOVE_FLAG       = True
TYPES_TO_INCLUDE = ["MOTIFS_rex48"]
CHR_TYPE_INCLUDE = []

# Conditions to analyse (analysis_cond is defined earlier in the notebook).
CONDITIONS_TO_INCLUDE = [
    analysis_cond[0], analysis_cond[1],
    analysis_cond[2], analysis_cond[3]
]

# Read filter: with REQUIRE_CENTRAL the read must reach MIN_READ_LENGTH // 2
# on both sides of position 0; otherwise read_length >= MIN_READ_LENGTH.
MIN_READ_LENGTH = 500
REQUIRE_CENTRAL = False

# Smoothing applied to per-read windows and to the per-type mean signal.
RAW_SMOOTH_KIND  = "gaussian"      # gaussian | moving
RAW_MOVING_WIN   = 50              # bp
RAW_GAUSS_SIGMA  = 5               # bp
RAW_CLAMP_01     = False           # clip result to [0,1]?

# Geometry of the analysis axis and of the idealised pattern.
REL_POS_RANGE          = 1000         # axis spans [-REL_POS_RANGE, +REL_POS_RANGE]
LAG_MIN, LAG_MAX       = 140, 200     # inclusive lag range searched for lag★
PATTERN_TYPE           = "cosine"     # boxcar | cosine | binary_mean
CORE_LEN               = 147          # width of the zeroed core in the boxcar pattern
PAD_LEN                = 1500         # auto‑expanded if too small
SHIFT_RANGE            = 100          # how much to shift ideal signal from center
WINDOW_SIZE            = 100          # half-width: each window holds 2*WINDOW_SIZE + 1 samples
REL_POS_BIN_WIDTH      = 100          # spacing between window centers

# Per-window pre-processing switches and their parameters.
PERFORM_FILLING        = True
MET_DOMAIN_WIDTH       = 9            # max distance between two marks that still gets filled
PERFORM_INTERP         = True
INTERP_WINDOW          = 15           # radius of the neighbourhood used to fill NaNs
ZSCORE_BEFORE_AC       = False        # per‑read vector z‑score?

# Sampling and parallelism.
MAX_READS_PER_COND     = 0            # 0 disables down-sampling
RANDOM_STATE           = 42
N_WORKERS              = max(2, mp.cpu_count() - 2)
CHUNK_SIZE             = 10           # chunksize passed to Pool.imap_unordered
DEBUG                  = True

# Plot styling.
PLOT_TEMPLATE          = "plotly_white"
BOX_WIDTH              = 0.4          # passed as the violin trace width

def dbg(msg):
    """Print *msg* with a ``[DEBUG]`` prefix; silent when DEBUG is falsy."""
    if not DEBUG:
        return
    print(f"[DEBUG] {msg}")

# ────────────────── 1. READ FILTERS ────────────────── #
# Load the ChIP-rank table (whitespace separated) and prefix the type names
# with "MOTIFS_" so they match the `type` values used in merged_df.
chiprank_df = (
    pd.read_csv("/Data1/reference/rex_chiprank.bed", sep=r"\s+")
      .assign(type=lambda d: "MOTIFS_" + d["type"].astype(str))
)
# type -> chip_rank scaled by 100 and rounded to 3 decimals.
chip_rank_lookup = {
    t: round(float(rk) * 100, 3)
    for t, rk in zip(chiprank_df["type"], chiprank_df["chip_rank"])
}

keep_conds = set(CONDITIONS_TO_INCLUDE)
# An explicit type list wins; otherwise select by rank relative to the cutoff
# (>= cutoff when ABOVE_FLAG is True, < cutoff when it is False).
keep_types = (set(TYPES_TO_INCLUDE)
              if TYPES_TO_INCLUDE else
              {t for t, r in chip_rank_lookup.items()
               if (r >= CHIP_RANK_CUTOFF) == ABOVE_FLAG})
# An empty CHR_TYPE_INCLUDE keeps every chr_type present in merged_df.
keep_chr   = (set(CHR_TYPE_INCLUDE)
              if CHR_TYPE_INCLUDE else
              set(merged_df["chr_type"].unique()))

# Metadata filter; .copy() so later edits do not touch merged_df.
df0 = merged_df.query(
    "condition in @keep_conds and type in @keep_types and chr_type in @keep_chr"
).copy()
dbg(f"after metadata filter: {len(df0)} reads")

# Keep reads with at least one position inside [-REL_POS_RANGE, +REL_POS_RANGE].
# assumes each rel_pos entry is a NumPy array (the lambda uses elementwise
# comparisons and .any()) — TODO confirm against the code that builds merged_df
df0 = df0[df0["rel_pos"].apply(
    lambda arr: ((arr >= -REL_POS_RANGE) & (arr <= REL_POS_RANGE)).any()
)].reset_index(drop=True)
dbg(f"after overlap filter:  {len(df0)} reads")

# Either require the read to extend at least MIN_READ_LENGTH // 2 on both
# sides of position 0, or simply require a minimum read_length.
if REQUIRE_CENTRAL:
    half = MIN_READ_LENGTH // 2
    df0 = df0[df0["rel_pos"].apply(
        lambda a: (a.min() <= -half) and (a.max() >= half)
    )].reset_index(drop=True)
else:
    df0 = df0[df0["read_length"] >= MIN_READ_LENGTH].reset_index(drop=True)
dbg(f"after length/central filter: {len(df0)} reads")

# Drop reads whose mod_qual_bin is not a list/array or sums to zero
# (NaNs ignored), i.e. reads without any positive methylation call.
df0 = df0[df0["mod_qual_bin"].apply(
    lambda x: isinstance(x, (list, np.ndarray)) and np.nansum(x) > 0
)].reset_index(drop=True)
dbg(f"after methylation filter: {len(df0)} reads")

# Optional down-sampling: at most MAX_READS_PER_COND reads per condition,
# reproducible through RANDOM_STATE. A value of 0 keeps all reads.
reads_df = (
    df0.groupby("condition", group_keys=False)
       .apply(lambda d: d.sample(min(len(d), MAX_READS_PER_COND),
                                 random_state=RANDOM_STATE))
       .reset_index(drop=True)
    if MAX_READS_PER_COND > 0 else df0.copy()
)
dbg(f"reads to process: {len(reads_df)}")

# ────────────────── 2. NAN‑AWARE SMOOTHING HELPERS ────────────────── #
def _nan_gauss(a, sigma):
    if sigma <= 0:
        return a
    mask   = ~np.isnan(a)
    if not mask.any():
        return a
    filled = np.where(mask, a, 0.0)
    sm_val = gaussian_filter1d(filled, sigma=sigma, mode="nearest")
    sm_wt  = gaussian_filter1d(mask.astype(float), sigma=sigma,
                               mode="nearest")
    out = sm_val / sm_wt
    out[sm_wt < 1e-6] = np.nan
    return out

def _nan_moving(a, win):
    if win <= 1:
        return a
    mask   = ~np.isnan(a)
    if not mask.any():
        return a
    filled = np.where(mask, a, 0.0)
    kernel = np.ones(win, dtype=float)
    sm_val = np.convolve(filled, kernel, "same")
    sm_wt  = np.convolve(mask.astype(float), kernel, "same")
    out = sm_val / sm_wt
    out[sm_wt == 0] = np.nan
    return out

def _apply_smoothing(arr, kind=RAW_SMOOTH_KIND, moving_win=RAW_MOVING_WIN,
                     gauss_sigma=RAW_GAUSS_SIGMA, clamp=RAW_CLAMP_01):
    """Smooth *arr* with the selected NaN-aware filter.

    ``kind`` selects ``_nan_gauss`` ("gaussian", using ``gauss_sigma``) or
    ``_nan_moving`` ("moving", using ``moving_win``); any other value leaves
    the data unsmoothed. With ``clamp`` the result is clipped to [0, 1].
    """
    if kind == "gaussian":
        smoothed = _nan_gauss(arr, gauss_sigma)
    elif kind == "moving":
        smoothed = _nan_moving(arr, moving_win)
    else:
        smoothed = arr
    return np.clip(smoothed, 0, 1) if clamp else smoothed


# Smoother with the notebook-level RAW_* settings bound; used for both the
# per-read windows and the per-type mean signal.
raw_smooth = partial(
    _apply_smoothing,
    kind=RAW_SMOOTH_KIND,
    moving_win=RAW_MOVING_WIN,
    gauss_sigma=RAW_GAUSS_SIGMA,
    clamp=RAW_CLAMP_01,
)

# ────────────────── 3. PER‑TYPE MEAN SIGNAL & lag★ ────────────────── #
def pearson_acf(x, L):
    """Pearson correlation between *x* and itself shifted by *L* samples.

    Only positions that are non-NaN in both the original and the shifted
    copy are used; returns NaN when fewer than two such positions exist.
    """
    a, b = x[:-L], x[L:]
    m    = ~np.isnan(a) & ~np.isnan(b)
    return np.nan if m.sum() < 2 else pearsonr(a[m], b[m])[0]


dbg("building per‑type mean %m6A …")
axis = np.arange(-REL_POS_RANGE, REL_POS_RANGE + 1)
types_list = sorted(reads_df["type"].unique())

# Containers for per‑type data
mean_sig_by_type    = {}
lag_star_by_type    = {}
axis_list           = axis  # same axis for all types

# Candidate lags (inclusive range) are the same for every type.
lags = np.arange(LAG_MIN, LAG_MAX + 1)

for rtype in types_list:
    sub_df = reads_df[reads_df["type"] == rtype]
    # accumulate methylation values at each position for this type
    acc = {p: [] for p in axis_list}
    for pos, met in zip(sub_df["rel_pos"], sub_df["mod_qual_bin"]):
        for p, m in zip(pos, met):
            if p in acc:
                acc[p].append(float(m))
    # positions never covered by a read stay NaN and are handled by the
    # NaN-aware smoother
    mean_raw = np.array([np.mean(acc[p]) if acc[p] else np.nan
                         for p in axis_list])
    mean_sig = raw_smooth(mean_raw)
    mean_sig_by_type[rtype] = mean_sig

    # compute lag★ (peak autocorrelation) for this type
    acf_vals = [pearson_acf(mean_sig, L) for L in lags]
    if np.all(np.isnan(acf_vals)):
        # np.nanargmax would fail here with an unspecific "All-NaN slice"
        # message; name the offending type instead.
        raise ValueError(
            f"[{rtype}] autocorrelation is NaN for every lag in "
            f"[{LAG_MIN}, {LAG_MAX}]; the mean signal has too little coverage"
        )
    lag_star = int(lags[np.nanargmax(acf_vals)])
    lag_star_by_type[rtype] = lag_star
    dbg(f"[{rtype}] lag★ = {lag_star} bp  (r_max = {np.nanmax(acf_vals):.4f})")

# ────────────────── 4. PER‑TYPE IDEALIZED PATTERN (shifted) ────────────────── #
# Ensure padding is sufficient
# The pattern is sampled up to WINDOW_SIZE beyond the outermost center and may
# additionally be shifted by up to SHIFT_RANGE, so each side needs at least
# WINDOW_SIZE + SHIFT_RANGE extra samples.
min_half_pad = WINDOW_SIZE + SHIFT_RANGE
if PAD_LEN // 2 < min_half_pad:
    dbg(f"PAD_LEN too small, increasing from {PAD_LEN} → {2 * min_half_pad}")
    PAD_LEN = 2 * min_half_pad
half_pad = PAD_LEN // 2

# Index of the padded pattern: the analysis axis extended by half_pad on each side.
pattern_idx_full = np.arange(
    -REL_POS_RANGE - half_pad,
     REL_POS_RANGE + half_pad + 1
)

# Function to build a "raw" pattern array given lag★ and pattern type
def build_raw_pattern(idx_full, lag_star):
    """Return the unshifted ideal pattern sampled at the positions *idx_full*.

    The shape is selected by the notebook-level ``PATTERN_TYPE``:

    * ``"boxcar"`` – 1 everywhere except 0 where ``|position| <= CORE_LEN // 2``
      (``lag_star`` is not used).
    * ``"cosine"`` – ``0.5 * (1 + cos(2π · position / lag_star))``, i.e. a
      cosine with period ``lag_star`` scaled to [0, 1], peaking at position 0.

    Raises
    ------
    ValueError
        For ``"binary_mean"``, which needs the per-type mean signal and must
        therefore be built by the caller, and for any unknown PATTERN_TYPE.
    """
    if PATTERN_TYPE == "boxcar":
        pat = np.ones_like(idx_full, float)
        pat[np.abs(idx_full) <= CORE_LEN // 2] = 0.0
        return pat
    if PATTERN_TYPE == "cosine":
        return 0.5 * (1 + np.cos(2 * np.pi * idx_full / lag_star))
    if PATTERN_TYPE == "binary_mean":
        # This branch used to compute a threshold from whatever global
        # `mean_sig` happened to exist before raising; the value was never
        # used, so the function now fails immediately.
        raise ValueError("binary_mean should be handled per type")
    raise ValueError("bad PATTERN_TYPE")

# Per‑type shifted patterns & segments
pattern_segments = {}  # nested: pattern_segments[rtype][center] -> array

# Window centers do not depend on the read type. They are defined once, ahead
# of the loop, so that code further down that reads CENTER_BINS still sees the
# value derived from the current settings when types_list is empty.
CENTER_BINS = list(range(-REL_POS_RANGE, REL_POS_RANGE + 1,
                         REL_POS_BIN_WIDTH))

# split axis into negative / positive for separate shifting
# (position 0 belongs to neither half and is never shifted)
neg_ax = axis_list[axis_list < 0]
pos_ax = axis_list[axis_list > 0]


def best_shift(p_series, ref_idx, ref_sig, rng):
    """Return the shift in [-rng, rng] that best aligns a pattern to a signal.

    For each candidate shift ``s`` the pattern ``p_series`` (a Series indexed
    by position) is sampled at ``ref_idx + s`` and Pearson-correlated with
    ``ref_sig`` over the positions where both are non-NaN. Shifts with fewer
    than two usable positions are skipped. The shift with the highest r is
    returned; 0 if no shift produced a comparable r.
    """
    best_r, best_s = -np.inf, 0
    for s in range(-rng, rng + 1):
        shifted = p_series.reindex(ref_idx + s).to_numpy()
        m       = ~np.isnan(ref_sig) & ~np.isnan(shifted)
        if m.sum() < 2:
            continue
        r = pearsonr(ref_sig[m], shifted[m])[0]
        if r > best_r:
            best_r, best_s = r, s
    return best_s


for rtype in types_list:
    mean_sig = mean_sig_by_type[rtype]
    lag_star = lag_star_by_type[rtype]

    # (re)build raw pattern for this type
    if PATTERN_TYPE == "binary_mean":
        # threshold the mean signal at its median and repeat it to cover the
        # padded index
        # NOTE(review): the tiling starts at the first padded position, so the
        # copy is not aligned with the original axis — confirm this is intended.
        thr   = np.nanmedian(mean_sig)
        base  = (mean_sig >= thr).astype(float)
        rep   = int(np.ceil(len(pattern_idx_full) / len(base)))
        pat   = np.tile(base, rep)[:len(pattern_idx_full)]
    else:
        pat = build_raw_pattern(pattern_idx_full, lag_star)

    pat_series = pd.Series(pat, index=pattern_idx_full)

    mean_neg = mean_sig[axis_list < 0]
    mean_pos = mean_sig[axis_list > 0]

    # the upstream and downstream halves are aligned independently
    shift_neg = best_shift(pat_series, neg_ax, mean_neg, SHIFT_RANGE)
    shift_pos = best_shift(pat_series, pos_ax, mean_pos, SHIFT_RANGE)
    dbg(f"[{rtype}] best shifts  upstream={shift_neg} bp  downstream={shift_pos} bp")

    # apply shifts to build the final per‑type “ideal pattern”
    # (positions whose shifted source falls outside the padded index become NaN)
    pat_shift = pat_series.copy()
    pat_shift.loc[pat_shift.index < 0] = pat_series.reindex(
        pat_shift.index[pat_shift.index < 0] + shift_neg
    ).values
    pat_shift.loc[pat_shift.index > 0] = pat_series.reindex(
        pat_shift.index[pat_shift.index > 0] + shift_pos
    ).values

    # store segments for each center bin
    pattern_segments[rtype] = {}
    for c in CENTER_BINS:
        idx = np.arange(c - WINDOW_SIZE, c + WINDOW_SIZE + 1)
        pattern_segments[rtype][c] = pat_shift.reindex(idx).to_numpy()

    # (Optional) compute and plot per‑type mean vs. its ideal pattern
    pattern_on_axis = pat_shift.reindex(axis_list).to_numpy()
    mask_mp = (~np.isnan(mean_sig)) & (~np.isnan(pattern_on_axis))
    r_mean_pattern = np.nan
    if mask_mp.sum() >= 2:
        r_mean_pattern = pearsonr(mean_sig[mask_mp], pattern_on_axis[mask_mp])[0]
    dbg(f"[{rtype}] Pearson r between mean_sig and ideal pattern = {r_mean_pattern:.4f}")

    # Example per‑type plot (comment out if not needed)
    fig_mp = go.Figure()
    fig_mp.add_trace(go.Scatter(
        x=axis_list, y=mean_sig,
        mode="lines", name=f"{rtype} mean_sig",
        line=dict(width=2, color="blue")
    ))
    fig_mp.add_trace(go.Scatter(
        x=axis_list, y=pattern_on_axis,
        mode="lines", name=f"{rtype} ideal_pattern_shifted",
        line=dict(width=2, color="red")
    ))
    fig_mp.update_layout(
        template=PLOT_TEMPLATE,
        title=f"{rtype}: Mean methylation (blue) vs. Idealized pattern (red)<br>"
              f"Pearson r = {r_mean_pattern:.4f}",
        xaxis_title="Relative Position (bp)",
        yaxis_title="Methylation signal (0–1)",
        width=900, height=400
    )
    fig_mp.show()

# ────────────────── 5. READ WINDOW EXTRACTION ────────────────── #
def _fill_met_domains(arr, width):
    idx = np.where(arr == 1)[0]
    if len(idx) < 2:
        return arr
    out = arr.copy()
    for a, b in zip(idx, idx[1:]):
        if b - a <= width:
            out[a:b+1] = 1
    return out

def _mode_interpolate(arr, radius):
    isnan = np.isnan(arr)
    if not isnan.any():
        return arr
    valid = (~isnan).astype(int)
    ones  = ((arr == 1) & ~isnan).astype(int)
    c_val = np.concatenate(([0], np.cumsum(valid)))
    c_one = np.concatenate(([0], np.cumsum(ones)))
    def cnt(c, i):
        lo, hi = max(0, i-radius), min(len(arr)-1, i+radius)
        return c[hi+1] - c[lo]
    out = arr.copy()
    for i in np.where(isnan)[0]:
        tot = cnt(c_val, i)
        out[i] = 0.0 if tot == 0 else (1.0 if cnt(c_one, i) > tot/2 else 0.0)
    return out

def get_windows(rec):
    """Cut one read into fixed-size windows around every center in CENTER_BINS.

    ``rec`` is a record dict with keys "read_id", "condition", "type",
    "rel_pos" and "mod_qual_bin". One tuple
    ``(read_id, condition, type, center, vector, has_marks)`` is returned per
    center. ``vector`` is the filled / interpolated / smoothed window of
    ``2 * WINDOW_SIZE + 1`` samples, or ``None`` (with ``has_marks`` False)
    when the window is not fully inside the read or holds no data.
    """
    rid, cond, rtype = rec["read_id"], rec["condition"], rec["type"]
    pos = np.asarray(rec["rel_pos"], int)
    met = np.asarray(rec["mod_qual_bin"], float)
    # position -> index into `met` (for a repeated position the last index wins)
    lut = {p: i for i, p in enumerate(pos)}
    span = 2 * WINDOW_SIZE + 1
    # a window is only used if it lies completely inside the read
    lo_center = pos.min() + WINDOW_SIZE
    hi_center = pos.max() - WINDOW_SIZE

    def unusable(center):
        return (rid, cond, rtype, center, None, False)

    rows = []
    for c in CENTER_BINS:
        if c < lo_center or c > hi_center:
            rows.append(unusable(c))
            continue

        # scatter the read's calls into the window; uncovered offsets stay NaN
        vec = np.full(span, np.nan)
        for p, i in lut.items():
            off = p - c
            if -WINDOW_SIZE <= off <= WINDOW_SIZE:
                vec[off + WINDOW_SIZE] = met[i]
        if np.isnan(vec).all():
            rows.append(unusable(c))
            continue

        if PERFORM_FILLING:
            vec = _fill_met_domains(vec, MET_DOMAIN_WIDTH)
        if PERFORM_INTERP:
            vec = _mode_interpolate(vec, INTERP_WINDOW)
        vec = raw_smooth(vec)
        if np.isnan(vec).all():
            rows.append(unusable(c))
            continue

        if ZSCORE_BEFORE_AC:
            finite = ~np.isnan(vec)
            if finite.sum() > 1:
                mu = np.nanmean(vec[finite])
                sd = np.nanstd(vec[finite])
                # a constant window is left as is rather than divided by zero
                if sd > 0:
                    vec[finite] = (vec[finite] - mu) / sd
        rows.append((rid, cond, rtype, c, vec, True))
    return rows

# Extract the windows of all reads in parallel. Every call returns a list of
# tuples (one per center); chain.from_iterable flattens them into one list
# while the pool is still open. Results arrive in completion order.
dbg(f"smoothing per‑read windows ({N_WORKERS} workers)…")
with mp.Pool(N_WORKERS) as pool:
    flat = list(itertools.chain.from_iterable(
        tqdm(pool.imap_unordered(
            get_windows, reads_df.to_dict("records"),
            chunksize=CHUNK_SIZE),
            total=len(reads_df), desc="windows")))
# Column order must match the tuple layout produced by get_windows.
all_win = pd.DataFrame(flat, columns=[
    "read_id", "condition", "type", "center", "vec_smoothed", "has_marks"
])
dbg(f"all_windows shape: {all_win.shape}")

# ────────────────── 6. ZERO‑MARK STATS ────────────────── #
# Windows per (condition, center) in total ...
tot = all_win.groupby(["condition", "center"]).size().reset_index(name="total")
# ... and those flagged has_marks == False.
nom = (all_win[~all_win["has_marks"]]
       .groupby(["condition", "center"])
       .size().reset_index(name="no_marks"))
# Left merge keeps bins without any unmarked window; their count becomes 0.
pct = pd.merge(tot, nom, on=["condition", "center"], how="left")
pct["no_marks"] = pct["no_marks"].fillna(0).astype(int)
# Stored as a fraction in [0, 1]; converted to percent when plotted.
pct["pct_zero"] = pct["no_marks"] / pct["total"]

# ────────────────── 7. READ‑vs‑PATTERN CORRELATIONS (per‑type) ────────────────── #
def corr_job(row):
    """Correlate one window with the ideal pattern of its own read type.

    ``row`` is a record dict of ``all_win``. Returns a dict with keys
    "condition", "type", "center" and "corr" (Pearson r clipped to [-1, 1]),
    or ``None`` when the window has no marks, no pattern segment exists for
    its center, or fewer than two positions are valid in both vectors.
    """
    if not row["has_marks"]:
        return None

    rtype = row["type"]
    center = row["center"]
    # look up this read's type‑specific pattern for its bin
    pattern = pattern_segments[rtype].get(center)
    if pattern is None:
        return None

    signal = row["vec_smoothed"]
    usable = ~(np.isnan(signal) | np.isnan(pattern))
    if usable.sum() < 2:
        return None

    r = pearsonr(signal[usable], pattern[usable])[0]
    return {
        "condition": row["condition"],
        "type":      rtype,
        "center":    center,
        "corr":      float(np.clip(r, -1, 1))
    }

# Correlate every window with its type-specific pattern in parallel; windows
# that cannot be scored come back as None and are dropped.
dbg("computing correlations …")
with mp.Pool(N_WORKERS) as pool:
    corr_dicts = list(tqdm(pool.imap_unordered(
        corr_job, all_win.to_dict("records"), chunksize=CHUNK_SIZE),
        total=len(all_win), desc="corr"))
corr_df = pd.DataFrame([d for d in corr_dicts if d])
dbg(f"corr rows: {len(corr_df)}")

# Fail fast when nothing could be scored. An empty DataFrame built from an
# empty list has no columns, so the groupby / column accesses below would
# otherwise die with a KeyError before this message could ever be shown.
if corr_df.empty:
    raise RuntimeError("No valid correlations — adjust WINDOW_SIZE, SHIFT_RANGE, or smoothing parameters.")

if DEBUG:
    # Spread of r per (type, bin); a zero or undefined std usually means a
    # single window or identical values in that bin.
    flat = (corr_df.groupby(["type", "center"])
                   .agg(n=("corr", "size"),
                        std=("corr", "std"),
                        unique_vals=("corr", lambda x: len(set(np.round(x,5)))))
                   .reset_index())
    print("\n[DEBUG] variance of r by type × bin")
    print(flat)

    empty = flat[flat["std"].fillna(0) == 0]
    if not empty.empty:
        print("\n[DEBUG] bins with zero variance (per type):")
        print(empty)

# ────────────────── 8½.  BIN LABELS & axis sanity  ──────────────────
# One categorical label per window center, in axis order.
bin_lbl = [str(c) for c in CENTER_BINS]
corr_df["bin_str"] = corr_df["center"].astype(str)
pct["bin_str"]     = pct["center"].astype(str)

if DEBUG:
    # Report any mismatch between the labels present in the data and the
    # labels the plots will use.
    print("\n[DEBUG] unique bin_str in corr_df:", sorted(corr_df["bin_str"].unique()))
    missing = set(bin_lbl) - set(corr_df["bin_str"])
    if missing:
        print("[DEBUG]  !!  These bins have no corr rows:", sorted(missing))
    extra = set(corr_df["bin_str"]) - set(bin_lbl)
    if extra:
        print("[DEBUG]  !!  These bins are in corr_df but not bin_lbl:", sorted(extra))

# ────────────────── 9. PLOT PREPARATION  (patched) ────────────────── #
# NaN-aware median and quartiles of r per (condition, bin).
summary = (corr_df.groupby(["condition", "bin_str"])["corr"]
           .agg(median=lambda x: float(np.nanmedian(x)),
                q1=lambda x: float(np.nanpercentile(x, 25)),
                q3=lambda x: float(np.nanpercentile(x, 75)))
           .reset_index())

# Per condition: lists aligned to bin_lbl (NaN for bins without data), ready
# to be used as y values.
stats = {}
for cond in summary["condition"].unique():
    sub = (summary[summary["condition"] == cond]
           .set_index("bin_str")[["median", "q1", "q3"]]
           .reindex(bin_lbl))
    stats[cond] = {
        "median": sub["median"].tolist(),
        "q1":     sub["q1"].tolist(),
        "q3":     sub["q3"].tolist()
    }

def _hex_to_rgba(hex_color, alpha):
    """Convert a "#RRGGBB" colour into an ``rgba(r,g,b,alpha)`` CSS string."""
    r, g, b = (int(hex_color[k:k + 2], 16) for k in (1, 3, 5))
    return f"rgba({r},{g},{b},{alpha})"


# Two stacked panels sharing the x axis: violins of the per-read r values on
# top, median line with an IQR ribbon underneath.
fig = make_subplots(
    rows=2, cols=1, shared_xaxes=True,
    vertical_spacing=0.10, row_heights=[0.7, 0.3],
    subplot_titles=[
        "Read vs Ideal‑Pattern Correlations",
        "Median ± IQR by Bin & Condition"
    ]
)
palette = px.colors.qualitative.Plotly
for i, cond in enumerate(sorted(corr_df["condition"].unique())):
    # colours wrap around if there are more conditions than palette entries
    col = palette[i % len(palette)]
    sub = corr_df[corr_df["condition"] == cond]
    # Row 1: outline-only violin with inner box and mean line.
    fig.add_trace(
        go.Violin(
            x=sub["bin_str"], y=sub["corr"], name=cond,
            legendgroup=cond, offsetgroup=cond, width=BOX_WIDTH,
            points=False, line_color=col, fillcolor="rgba(0,0,0,0)",
            box_visible=True, meanline_visible=True
        ),
        row=1, col=1
    )
    s = stats[cond]
    # Row 2: closed polygon for the IQR — q3 left-to-right, q1 right-to-left.
    band_x = bin_lbl + bin_lbl[::-1]
    band_y = s["q3"] + s["q1"][::-1]
    fig.add_trace(
        go.Scatter(
            x=band_x, y=band_y, fill="toself",
            # Was fully transparent, which made the IQR invisible although the
            # panel title announces it; use the condition colour at 20 % opacity.
            fillcolor=_hex_to_rgba(col, 0.2),
            line=dict(color="rgba(0,0,0,0)"),
            # same legend group as the median line so both toggle together
            legendgroup=cond,
            showlegend=False
        ),
        row=2, col=1
    )
    fig.add_trace(
        go.Scatter(
            x=bin_lbl, y=s["median"], mode="lines",
            line=dict(color=col, width=2),
            name=cond, legendgroup=cond
        ),
        row=2, col=1
    )

fig.update_layout(template=PLOT_TEMPLATE, height=900, width=1200)
# Categorical x axes in bin order for both rows (a call without row/col
# addresses every x axis); the top axis itself is hidden.
fig.update_xaxes(
    type="category", categoryorder="array",
    categoryarray=bin_lbl
)
fig.update_xaxes(visible=False, row=1, col=1)
fig.update_yaxes(title_text="Pearson r", row=1, col=1)
fig.update_yaxes(title_text="Median r (IQR)", row=2, col=1)

fig.show()

# ────────────────── 10. FIGURE 2: % ZERO‑MARK ────────────────── #
# One line per condition: share of windows without marks per bin.
fig2 = go.Figure()
for i, cond in enumerate(sorted(pct["condition"].unique())):
    # Align to the full label list. Bins with no rows for this condition are
    # filled with 0 in every column, so they plot as 0 % rather than as a gap.
    sub = (
        pct[pct["condition"] == cond]
        .set_index("bin_str")
        .reindex(bin_lbl, fill_value=0)
        .reset_index()
    )
    fig2.add_trace(go.Scatter(
        # pct_zero is a fraction; scale to percent for display
        x=sub["bin_str"], y=sub["pct_zero"] * 100,
        mode="lines+markers", name=cond,
        line=dict(width=2), marker=dict(size=6)
    ))
fig2.update_layout(
    template=PLOT_TEMPLATE,
    width=1200, height=500,
    title="% Windows with Zero Methylation",
    xaxis=dict(
        type="category", categoryorder="array",
        categoryarray=bin_lbl,
        title="Relative Position Bin (left edge, bp)"
    ),
    yaxis=dict(title="Zero‑mark windows (%)", ticksuffix="%")
)
fig2.show()


In [None]:
### single fiber autocorrelation
import numpy as np
import pandas as pd
import multiprocessing as mp
from functools import partial
from scipy.ndimage import gaussian_filter1d
from tqdm.auto import tqdm
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ────────────────── 1) CONFIGURATION ──────────────────
# Read-type selection. chip_rank is scaled to 0–100 before being compared
# with CHIP_RANK_CUTOFF; ABOVE_FLAG=True keeps types at or above the cutoff,
# False keeps those below. A non-empty TYPES_TO_INCLUDE bypasses the rank
# filter; an empty CHR_TYPE_INCLUDE means "all chr_type values".
CHIP_RANK_CUTOFF       = 80
ABOVE_FLAG             = True
TYPES_TO_INCLUDE       = []
CHR_TYPE_INCLUDE       = []

# Existing list of conditions (defined earlier in the notebook)
CONDITIONS_TO_INCLUDE  = [
    analysis_cond[0],
    analysis_cond[1],
    analysis_cond[2],
    analysis_cond[3]
]

# Read filter: with REQUIRE_CENTRAL the read must reach MIN_READ_LENGTH // 2
# on both sides of position 0; otherwise read_length >= MIN_READ_LENGTH.
MIN_READ_LENGTH        = 500
REQUIRE_CENTRAL        = False

ZSCORE_BEFORE_AC       = False

# Gap filling between marks at most MET_DOMAIN_WIDTH apart.
PERFORM_FILLING        = True
MET_DOMAIN_WIDTH       = 9

# NaN interpolation from neighbours within INTERP_WINDOW positions.
PERFORM_INTERP         = True
INTERP_WINDOW          = 5

# Smoothing of the per-read windows.
RAW_SMOOTH_KIND        = "gaussian"
RAW_MOVING_WIN         = 21
RAW_GAUSS_SIGMA        = 10
RAW_CLAMP_01           = False

REL_POS_RANGE          = 1000

# ── REL_POS BIN WIDTH (compute centers every bin_width) ──
REL_POS_BIN_WIDTH      = 100      # configurable
CENTER_BINS            = list(range(-REL_POS_RANGE,
                                     REL_POS_RANGE + 1,
                                     REL_POS_BIN_WIDTH))

# NOTE(review): the constraint stated in the comment below is not met by the
# current values (WINDOW_SIZE = 150 while LAG_MAX * 2 = 400) — confirm which
# of the two is intended.
WINDOW_SIZE            = 150       # ensure WINDOW_SIZE >= LAG_MAX * 2

# ── LAG SEARCH RANGE ──
LAG_MIN                = 140       # inclusive
LAG_MAX                = 200       # inclusive

# ── AC PEAK THRESHOLD ──
AC_THRESHOLD           = 0.1       # configurable

MAX_READS_PER_COND     = 0         # 0 → no down‑sampling
RANDOM_STATE           = 42

N_WORKERS              = max(2, mp.cpu_count() - 2)
CHUNK_SIZE             = 10
DEBUG                  = True

# Plot styling
PLOT_TEMPLATE          = "plotly_white"
BOX_WIDTH              = 0.4       # width of each box trace


def print_debug(msg: str):
    """Print *msg* with a ``[DEBUG]`` prefix; silent when DEBUG is falsy."""
    if not DEBUG:
        return
    print(f"[DEBUG] {msg}")


# ────────────────── 2) FILTERING ──────────────────
# Load the ChIP-rank table (whitespace separated) and prefix the type names
# with "MOTIFS_" so they match the `type` values used in merged_df.
chiprank_df = (
    pd.read_csv("/Data1/reference/rex_chiprank.bed", sep=r"\s+")
      .assign(type=lambda d: "MOTIFS_" + d["type"].astype(str))
)
# type -> chip_rank scaled by 100 and rounded to 3 decimals.
chip_rank_lookup = {
    t: round(float(rk) * 100, 3)
    for t, rk in zip(chiprank_df["type"], chiprank_df["chip_rank"])
}

keep_conds = set(CONDITIONS_TO_INCLUDE)
# An explicit type list wins; otherwise select by rank relative to the cutoff
# (>= cutoff when ABOVE_FLAG is True, < cutoff when it is False).
keep_types = (
    set(TYPES_TO_INCLUDE)
    if TYPES_TO_INCLUDE
    else {t for t, r in chip_rank_lookup.items()
          if (r >= CHIP_RANK_CUTOFF) == ABOVE_FLAG}
)
# An empty CHR_TYPE_INCLUDE keeps every chr_type present in merged_df.
keep_chr = (
    set(CHR_TYPE_INCLUDE)
    if CHR_TYPE_INCLUDE
    else set(merged_df["chr_type"].unique())
)

# Metadata filter; .copy() so later edits do not touch merged_df.
df0 = merged_df.query(
    "condition in @keep_conds and type in @keep_types and chr_type in @keep_chr"
).copy()
print(f"[INFO] after metadata filter: {len(df0)} reads")

# Keep reads with at least one position inside [-REL_POS_RANGE, +REL_POS_RANGE].
# assumes each rel_pos entry is a NumPy array (elementwise comparison, .any())
# — TODO confirm against the code that builds merged_df
mask_overlap = df0["rel_pos"].apply(
    lambda arr: ((arr >= -REL_POS_RANGE) & (arr <= REL_POS_RANGE)).any()
)
df0 = df0[mask_overlap].reset_index(drop=True)
print(f"[INFO] after overlap filter:  {len(df0)} reads")

# Either require the read to extend at least MIN_READ_LENGTH // 2 on both
# sides of position 0, or simply require a minimum read_length.
if REQUIRE_CENTRAL:
    half = MIN_READ_LENGTH // 2
    mask_central = df0["rel_pos"].apply(
        lambda arr: (arr.min() <= -half) and (arr.max() >= half)
    )
    filtered_reads_df = df0[mask_central].reset_index(drop=True)
else:
    filtered_reads_df = df0[df0["read_length"] >= MIN_READ_LENGTH].reset_index(drop=True)
print(f"[INFO] after length/central filter: {len(filtered_reads_df)} reads")

# drop reads with no methylation data
# (mod_qual_bin must be a list/array whose NaN-ignoring sum is positive)
mask_valid = filtered_reads_df["mod_qual_bin"].apply(
    lambda x: isinstance(x, (list, np.ndarray)) and np.nansum(x) > 0
)
filtered_reads_df = filtered_reads_df[mask_valid].reset_index(drop=True)
print(f"[INFO] after methylation filter: {len(filtered_reads_df)} reads")


# ────────────────── 3) OPTIONAL DOWNSAMPLING ──────────────────
# Cap the number of reads per condition at MAX_READS_PER_COND (reproducible
# through RANDOM_STATE); a value of 0 keeps every read.
if MAX_READS_PER_COND > 0:
    reads_to_process = (
        filtered_reads_df
          .groupby("condition", group_keys=False)
          .apply(lambda df: df.sample(
              n=min(len(df), MAX_READS_PER_COND),
              random_state=RANDOM_STATE
          ))
          .reset_index(drop=True)
    )
    print_debug(f"Down‑sampled: {len(reads_to_process)} / {len(filtered_reads_df)} reads")
else:
    reads_to_process = filtered_reads_df.copy()
    print_debug(f"Using all {len(reads_to_process)} reads")

# One plain dict per read: the unit of work handed to the worker pool.
records = reads_to_process.to_dict(orient="records")


# ────────────────── 4) SMOOTHING HELPERS ──────────────────
def _fill_met_domains(arr, width):
    idx = np.where(arr == 1)[0]
    if len(idx) < 2:
        return arr
    out = arr.copy()
    for a, b in zip(idx, idx[1:]):
        if b - a <= width:
            out[a : b + 1] = 1
    return out

def _mode_interpolate(arr, radius):
    isnan = np.isnan(arr)
    if not isnan.any():
        return arr
    valid = (~isnan).astype(int)
    ones  = ((arr == 1) & ~isnan).astype(int)
    c_val = np.concatenate(([0], np.cumsum(valid)))
    c_one = np.concatenate(([0], np.cumsum(ones)))
    def count(c_vec, i):
        lo, hi = max(0, i - radius), min(len(arr) - 1, i + radius)
        return c_vec[hi + 1] - c_vec[lo]
    out = arr.copy()
    for i in np.where(isnan)[0]:
        tot = count(c_val, i)
        out[i] = 0.0 if tot == 0 else (1.0 if count(c_one, i) > tot / 2 else 0.0)
    return out

def _apply_smoothing(y, kind, moving_win, gauss_sigma, clamp):
    """
    Optionally smooth a 1-D signal and optionally clip it to [0, 1].

    kind == "moving"   → boxcar average of width `moving_win` (only if > 1)
    kind == "gaussian" → Gaussian filter with `gauss_sigma` (only if > 0)
    anything else      → signal left unsmoothed
    When `clamp` is truthy the result is clipped to the [0, 1] range.
    """
    if kind == "moving":
        if moving_win > 1:
            boxcar = np.ones(moving_win) / moving_win
            y = np.convolve(y, boxcar, mode="same")
    elif kind == "gaussian":
        if gauss_sigma > 0:
            y = gaussian_filter1d(y, sigma=gauss_sigma, mode="nearest")

    if clamp:
        return np.clip(y, 0, 1)
    return y

# Pre-bind the smoothing settings for the raw per-window signal so that
# callers only pass the vector: raw_smooth(vec) is equivalent to
# _apply_smoothing(vec, kind=RAW_SMOOTH_KIND, moving_win=RAW_MOVING_WIN,
#                  gauss_sigma=RAW_GAUSS_SIGMA, clamp=RAW_CLAMP_01).
raw_smooth = partial(
    _apply_smoothing,
    kind=RAW_SMOOTH_KIND,
    moving_win=RAW_MOVING_WIN,
    gauss_sigma=RAW_GAUSS_SIGMA,
    clamp=RAW_CLAMP_01
)


# ────────────────── 5) COMPUTE PEAK NRL PER WINDOW ──────────────────
def compute_peak_nrl(rec):
    """
    Call a nucleosome repeat length (NRL) per sliding window for one read.

    For each center in CENTER_BINS a window of 2*WINDOW_SIZE + 1 positions is
    built around that center, optionally filled / interpolated / smoothed /
    z-scored, and its autocorrelation AC[k] is computed for k in
    [0, LAG_MAX].  The segment AC[LAG_MIN:LAG_MAX + 1] is Gaussian-smoothed
    and its argmax gives the called lag.

    Parameters
    ----------
    rec : dict
        One read record; the keys read here are "read_id", "condition",
        "rel_pos" (positions, cast to int) and "mod_qual_bin" (per-position
        signal, cast to float).

    Returns
    -------
    pandas.DataFrame
        One row per center in CENTER_BINS with the columns:
          - read_id, condition, center
          - lag_star : LAG_MIN + argmax of the smoothed AC segment; NaN only
            when no AC value could be computed.  It is NOT blanked when the
            peak is below AC_THRESHOLD; use `included` for that.
          - peak_val : smoothed AC value at lag_star (NaN when not computed)
          - included : True if peak_val >= AC_THRESHOLD
          - has_marks: True if the window holds at least one non-NaN
            observation (methylated or not); False for windows that are
            skipped or contain no data.
    """
    read_id = rec["read_id"]
    cond    = rec["condition"]
    rel_pos = np.asarray(rec["rel_pos"], dtype=int)
    signal  = np.asarray(rec["mod_qual_bin"], dtype=float)
    # position → index into `signal`; a repeated position keeps its last index
    pos_to_idx = {p: i for i, p in enumerate(rel_pos)}

    rows = []

    # A window is only evaluated when it lies completely inside the span
    # covered by the read's positions.
    min_center = rel_pos.min() + WINDOW_SIZE
    max_center = rel_pos.max() - WINDOW_SIZE

    for center in CENTER_BINS:
        lag_star = np.nan
        peak_val = np.nan
        included_flag = False
        has_marks = False

        # Skip windows that are not fully contained in the read; a placeholder
        # row is still emitted so every center appears once per read.
        if center < min_center or center > max_center:
            rows.append({
                "read_id":    read_id,
                "condition":  cond,
                "center":     center,
                "lag_star":   lag_star,
                "peak_val":   peak_val,
                "included":   included_flag,
                "has_marks":  False
            })
            continue

        # Build the raw window around 'center' (NaN where the read has no
        # observation at that offset)
        win_len = 2 * WINDOW_SIZE + 1
        vec = np.full(win_len, np.nan)
        for p, idx in pos_to_idx.items():
            offset = p - center
            if -WINDOW_SIZE <= offset <= WINDOW_SIZE:
                vec[offset + WINDOW_SIZE] = signal[idx]

        valid_pts = ~np.isnan(vec)
        if valid_pts.sum() == 0:
            # No observed (non-NaN) positions at all in this window
            rows.append({
                "read_id":    read_id,
                "condition":  cond,
                "center":     center,
                "lag_star":   lag_star,
                "peak_val":   peak_val,
                "included":   included_flag,
                "has_marks":  False
            })
            continue

        has_marks = True

        # Smoothing / filling / interpolation
        if PERFORM_FILLING:
            vec = _fill_met_domains(vec, MET_DOMAIN_WIDTH)
        if PERFORM_INTERP:
            vec = _mode_interpolate(vec, INTERP_WINDOW)
        vec = raw_smooth(vec)

        # Optional z-scoring of the observed positions; a constant window
        # (std == 0) is set to 0 instead of dividing by zero.
        valid_pts = ~np.isnan(vec)
        if ZSCORE_BEFORE_AC and valid_pts.sum() > 1:
            m, s = np.nanmean(vec[valid_pts]), np.nanstd(vec[valid_pts])
            vec[valid_pts] = (vec[valid_pts] - m)/s if s > 0 else 0.0

        # Compute raw AC[k] for k=0..LAG_MAX as the Pearson correlation between
        # the window and itself shifted by k, using only pairs where both
        # values are observed.  Lags with fewer than two usable pairs, or with
        # zero variance, stay NaN.
        ac_full = np.full(LAG_MAX + 1, np.nan)
        n = len(vec)
        for k in range(min(LAG_MAX, n - 1) + 1):
            x1, x2 = vec[:n - k], vec[k:]
            vp = (~np.isnan(x1)) & (~np.isnan(x2))
            if vp.sum() < 2:
                continue
            x1v, x2v = x1[vp], x2[vp]
            num = np.sum((x1v - x1v.mean()) * (x2v - x2v.mean()))
            den = np.sqrt(np.sum((x1v - x1v.mean())**2) *
                          np.sum((x2v - x2v.mean())**2))
            ac_full[k] = num/den if den > 0 else np.nan

        # Extract and smooth AC segment [LAG_MIN:LAG_MAX]; missing lags are
        # treated as 0 for the purpose of smoothing.
        ac_segment = ac_full[LAG_MIN : LAG_MAX + 1]
        if not np.all(np.isnan(ac_segment)):
            ac_smoothed = gaussian_filter1d(
                np.nan_to_num(ac_segment, nan=0.0),
                sigma=2,
                mode="nearest"
            )
            idx_peak = np.nanargmax(ac_smoothed)
            lag_star = LAG_MIN + int(idx_peak)
            peak_val = ac_smoothed[idx_peak]
            included_flag = (peak_val >= AC_THRESHOLD)

        rows.append({
            "read_id":    read_id,
            "condition":  cond,
            "center":     center,
            "lag_star":   lag_star,
            "peak_val":   peak_val,
            "included":   included_flag,
            "has_marks":  has_marks
        })

    return pd.DataFrame(rows)


# ────────────────── 6) MULTIPROCESSING & COLLECTION ──────────────────
print_debug(f"Launching pool with {N_WORKERS} workers")
# Fan the reads out to the worker pool; results arrive in arbitrary order,
# which is fine because every row carries its own read_id / center.
with mp.Pool(N_WORKERS) as pool:
    it  = pool.imap_unordered(compute_peak_nrl, records, chunksize=CHUNK_SIZE)
    dfs = [df for df in tqdm(it, total=len(records), desc="reads") if not df.empty]

# pd.concat([]) fails with an unhelpful "No objects to concatenate"; report the
# actual cause (no reads survived filtering, or CENTER_BINS is empty) instead.
if not dfs:
    raise ValueError(
        "No per-read NRL windows were produced "
        f"({len(records)} reads submitted); check the read filters and CENTER_BINS"
    )
per_read_nrl = pd.concat(dfs, ignore_index=True)
print(f"[INFO] per_read_nrl shape: {per_read_nrl.shape}")

# ────────────────── 7) EXCLUDE WINDOWS WITH NO MARKS ──────────────────
# Tally the windows that carried no observation before discarding them.
total_windows = per_read_nrl.shape[0]
no_mark_windows = per_read_nrl["has_marks"].eq(False).sum()
print_debug(f"Total windows: {total_windows}; windows with NO methylation marks: {no_mark_windows}")

# From here on only windows flagged has_marks are analysed.
per_read_nrl = per_read_nrl.loc[per_read_nrl["has_marks"]].copy()
print_debug(f"After excluding no-mark windows: {len(per_read_nrl)} windows remain")


# ────────────────── 8) BINNING ──────────────────

# Define bins from -REL_POS_RANGE to +REL_POS_RANGE in steps of REL_POS_BIN_WIDTH
bins = np.arange(-REL_POS_RANGE,
                 REL_POS_RANGE + REL_POS_BIN_WIDTH,
                 REL_POS_BIN_WIDTH)

# String labels (left edge of every bin) used to fix the category order on
# the plots, e.g. ["-2000", "-1900", ..., "1900"].
bin_labels_str = [str(int(b)) for b in bins[:-1]]

# Assign each window's center to the bin whose LEFT edge labels it.
# right=False makes the intervals [left, right): a center that sits exactly
# on an edge (e.g. 0) is filed under that edge's own label.  With pandas'
# default right-closed intervals it would be filed under the previous label
# and a center equal to -REL_POS_RANGE would be dropped.  A center equal to
# the last edge now falls outside the range and is dropped below.
per_read_nrl["bin"] = pd.cut(
    per_read_nrl["center"],
    bins=bins,
    labels=bins[:-1],
    right=False,
).astype(float)

# Drop any rows that did not fall into a bin (NaN)
nan_bins = per_read_nrl["bin"].isna().sum()
print_debug(f"Windows with center outside bin range (dropped): {nan_bins}")
per_read_nrl = per_read_nrl.dropna(subset=["bin"]).copy()

# String version of the bin label for categorical plotting
per_read_nrl["bin_str"] = per_read_nrl["bin"].astype(int).astype(str)

print_debug("After bin assignment, sample data:")
print_debug(per_read_nrl.head())


# ────────────────── 9) BOX PLOT (Grouped by Condition & Bin) ──────────────────

# Filter to only windows that passed the AC threshold
# Windows that passed the AC threshold and have both a bin and a called lag.
df_included = (
    per_read_nrl.loc[per_read_nrl["included"]]
                .copy()
                .dropna(subset=["bin_str", "lag_star"])
)
print_debug(f"Number of included windows (for box plot): {len(df_included)}")

# One box per (bin, condition), bins ordered left-to-right.
fig1 = px.box(
    data_frame=df_included,
    x="bin_str",
    y="lag_star",
    color="condition",
    points=False,
    category_orders=dict(bin_str=bin_labels_str),
)

fig1.update_layout(
    **{
        "template": PLOT_TEMPLATE,
        "title": "Called NRLs by Rel_Pos Bin and Condition (Box Plot)",
        "xaxis_title": "Relative Position Bin (left edge, bp)",
        "yaxis_title": "Called NRL (bp)",
        "boxmode": "group",   # boxes of one bin sit side by side
        "width": 1200,
        "height": 600,
    }
)

# Hollow boxes with a thin outline.
for trace in fig1.data:
    trace.update(dict(fillcolor="rgba(0,0,0,0)", line={"width": 1}))

fig1.show()

# ────────────────── 9b) LINE PLOT: MEDIAN NRL BY CONDITION ──────────────────

# Compute the median NRL (lag_star) for each condition
# Median called lag over all included windows, one value per condition.
median_df = (
    df_included
      .groupby("condition")["lag_star"]
      .median()
      .reset_index()
)

fig_median = go.Figure(
    data=[
        go.Scatter(
            x=median_df["condition"],
            y=median_df["lag_star"],
            name="Median NRL",
            mode="lines+markers",
            line={"width": 2},
            marker={"size": 6},
        )
    ]
)
fig_median.update_layout(
    **{
        "template": PLOT_TEMPLATE,
        "title": "Median Called NRL by Condition",
        "xaxis_title": "Condition",
        "yaxis_title": "Median NRL (bp)",
        "width": 800,
        "height": 500,
    }
)
fig_median.show()
# ────────────────── NEW: LINE PLOT OF MEDIAN NRL BY BIN & CONDITION ──────────────────

# Compute median lag_star per (condition, bin_str)
# Median called lag for every (condition, bin) pair.
median_per_bin = (
    df_included
      .groupby(["condition", "bin_str"])["lag_star"]
      .median()
      .rename("median_lag_star")
      .reset_index()
)

# bin_str is text, so order the rows by its integer value instead.
median_per_bin["bin_int"] = median_per_bin["bin_str"].astype(int)
median_per_bin = median_per_bin.sort_values(["condition", "bin_int"])

# One line per requested condition; conditions without data are left out.
fig_line = go.Figure()
for cond in CONDITIONS_TO_INCLUDE:
    subset = median_per_bin.loc[median_per_bin["condition"] == cond]
    if not subset.empty:
        fig_line.add_trace(
            go.Scatter(
                x=subset["bin_str"],
                y=subset["median_lag_star"],
                name=cond,
                mode="lines+markers",
                line={"width": 2},
                marker={"size": 6},
            )
        )

fig_line.update_layout(
    **{
        "template": PLOT_TEMPLATE,
        "title": "Median Called NRL by Bin and Condition (Line Plot)",
        "xaxis_title": "Relative Position Bin (left edge, bp)",
        "yaxis_title": "Median Called NRL (bp)",
        "width": 1200,
        "height": 600,
    }
)

# Categorical x-axis pinned to the numeric bin order.
fig_line.update_xaxes(
    **{"type": "category", "categoryorder": "array", "categoryarray": bin_labels_str}
)

fig_line.show()



# ────────────────── 10) INCLUDED vs. EXCLUDED COUNTS & PERCENTAGES ──────────────────

# Aggregate counts per condition and bin
# Per (condition, bin): how many windows passed the AC threshold, how many
# there were in total, and the derived excluded count / included fraction.
summary = (
    per_read_nrl
      .groupby(["condition", "bin"])
      .agg(
          included_count=("included", lambda flags: np.sum(flags)),
          total_count=("included", "size"),
      )
      .reset_index()
      .assign(
          excluded_count=lambda d: d["total_count"] - d["included_count"],
          pct_included=lambda d: d["included_count"] / d["total_count"],
      )
)

print_debug("Summary counts (first 10 rows):")
print_debug(summary.head(10))

# Create subplots: one row per condition
conds = sorted(summary["condition"].unique())
num_conds = len(conds)

# One stacked row per condition, each with a secondary y-axis for the % line.
fig2 = make_subplots(
    rows=num_conds,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.05,
    specs=[[dict(secondary_y=True)] for _ in range(num_conds)],
    subplot_titles=[f"Condition: {c}" for c in conds],
)

for idx, cond in enumerate(conds, start=1):
    # Rows of this condition, padded with zeros so that every bin is present.
    cond_df = (
        summary.loc[summary["condition"] == cond]
               .copy()
               .set_index("bin")
               .reindex(bins[:-1], fill_value=0)
               .reset_index()
    )
    cond_df["bin_str"] = cond_df["bin"].astype(int).astype(str)

    # Included / excluded counts as side-by-side bars on the primary axis;
    # legend entries are only shown for the first row.
    for _bar_col, _bar_name, _bar_color, _bar_tag in (
        ("included_count", "Included", "green", "_incl"),
        ("excluded_count", "Excluded", "lightgrey", "_excl"),
    ):
        fig2.add_trace(
            go.Bar(
                x=cond_df["bin_str"],
                y=cond_df[_bar_col],
                name=_bar_name,
                marker_color=_bar_color,
                offsetgroup=f"{idx}{_bar_tag}",
                showlegend=(idx == 1),
            ),
            row=idx, col=1, secondary_y=False
        )

    # Percentage of included windows as a line on the secondary axis.
    fig2.add_trace(
        go.Scatter(
            x=cond_df["bin_str"],
            y=cond_df["pct_included"] * 100,
            name="% Included",
            mode="lines+markers",
            line={"color": "black", "width": 1},
            marker={"size": 4},
            showlegend=(idx == 1),
        ),
        row=idx, col=1, secondary_y=True
    )

    fig2.update_yaxes(title_text="Count", row=idx, col=1, secondary_y=False)
    fig2.update_yaxes(title_text="% Included", row=idx, col=1, secondary_y=True)

fig2.update_layout(
    **{
        "template": PLOT_TEMPLATE,
        "title": "Included vs. Excluded Windows & % Included per Bin and Condition",
        "xaxis_title": "Relative Position Bin (left edge, bp)",
        "barmode": "group",   # bars of one bin sit side by side
        "width": 1200,
        "height": 300 * num_conds,
    }
)

# Same categorical bin order on every subplot.
fig2.update_xaxes(
    **{"type": "category", "categoryorder": "array", "categoryarray": bin_labels_str}
)

fig2.show()


# ────────────────── 11) HISTOGRAM OF ALL SLIDING-WINDOW MAX-AC VALUES ──────────────────

# Extract all peak_val (drop NaN)
# Peak AC of every remaining window for which one could be computed.
all_peak_vals = per_read_nrl["peak_val"].dropna()
print_debug(f"Windows with a valid peak_val: {len(all_peak_vals)} "
            f"out of {len(per_read_nrl)}")

fig3 = go.Figure()
fig3.add_trace(
    go.Histogram(x=all_peak_vals, nbinsx=50, marker_color="blue", opacity=0.75)
)
fig3.update_layout(
    **{
        "template": PLOT_TEMPLATE,
        "title": "Histogram of Sliding-Window Peak AC Values (All Bins)",
        "xaxis_title": "Peak AC Value",
        "yaxis_title": "Number of Windows",
        "width": 800,
        "height": 500,
    }
)
fig3.show()


In [None]:
# ╔══════════════════════════════════════════════════════════════════╗
# ║  Stratified empirical‑null simulator for template‑specific FDR   ║
# ╚══════════════════════════════════════════════════════════════════╝
#
#  • Builds a *full* null distribution of correlation scores r for every
#    (template, k‑bin [,condition]) combination.
#  • Provides helpers:
#        p_value_for_tpl_k(tpl_key, k, r_val [,cond=…])       → upper‑tail p‑value
#        alpha_threshold_for_tpl(tpl_key, alpha [,cond=…])    → per‑k‑bin r cut‑off
#
#  Assumes `merged_df` exists and supplies `mod_qual_bin`
#  vector‑of‑ints/NaNs for p(m6A) estimation.
# --------------------------------------------------------------------

import math, os, multiprocessing as mp
from collections import defaultdict
import numpy as np, pandas as pd
from tqdm.auto import tqdm

# ─────────────────────────  USER CONFIG  ─────────────────────────── #
# ───────── BOOSTING TOGGLE ───────── #
USE_M6A_BOOST = True          # False ⇒ every position weight = 1
# ─────────────────────────────────── #

PERFORM_FILLING       = False   # True ⇒ bridge short gaps between 1s before estimating / simulating
MET_DOMAIN_WIDTH      = 9       # max spacing (in positions) bridged by _fill_met_domains
# Templates are (u, core, d): u ones, core zeros, d ones (see the loop that
# fills _tpl_stats).  Total length = u + core + d.
TEMPLATES      = [
    #(14, 298, 14),                 # T1  (175 bp total)   ← legacy default
    (14, 147, 14),                 # 175 bp total
    (14, 166, 14),                 # 194 bp total
    # (14, 170, 14),                 # T1  (175 bp total)   ← legacy default
    # (14, 160, 14),                 # T1  (175 bp total)   ← legacy default
    # (14, 150, 14),                 # T1  (175 bp total)   ← legacy default
    # (14, 140, 14),                 # T1  (175 bp total)   ← legacy default
    # (14, 130, 14),                 # T2
    # (14, 120, 14),                 # T2
    (14, 110, 14),                 # 138 bp total
    # (14,  100, 14),                 # T3
    # (14,  90, 14),                 # T3
    (14,  80, 14),                 # 108 bp total
    # (14, 175, 14),                 # T1  (175 bp total)   ← legacy default
    # (14, 165, 14),                 # T1  (175 bp total)   ← legacy default
    # (14, 155, 14),                 # T1  (175 bp total)   ← legacy default
    # (14, 145, 14),                 # T1  (175 bp total)   ← legacy default
    # (14, 135, 14),                 # T2
    # (14, 125, 14),                 # T2
    # (14, 115, 14),                 # T2
    # (14,  105, 14),                 # T3
    # (14,  95, 14),                 # T3
    # (14,  85, 14),                 # T3
    # (10,  5, 10),                 # T3
    # (10,  10, 10),                 # T3
    # (10,  15, 10),                 # T3
    ]
# #(14,166,14), (14,147,14),(14,110,14), (14, 80,14)]               # (14,298,14), (u,core,d)
N_SIM_PER_BIN  = 50_000                                   # Monte‑Carlo reps per (template, k‑bin)
# Bin edges for k (number of observed positions); consecutive pairs (lo, hi)
# form the k‑bins.  NOTE(review): _build_null samples k from lo..hi inclusive,
# while _bin_for_k looks bins up as lo <= k < hi — confirm which is intended.
K_BIN_EDGES    = [10,12,15,17,20,22,24,26,28,30,31,32,33,34,35,36,37,38,39,40,42,44,46,48,50,53,55,60,70,80,100,150]
PER_CONDITION  = False                                    # True ⇒ per‑cond null
N_WORKERS      = max(2, mp.cpu_count() - 2)               # leave two cores free, but use at least two workers
# RND_SEED only seeds rng_global (used to draw k values in _build_null); the
# simulation workers create their own unseeded generators, so the simulated
# r values themselves are not reproducible from this seed.
RND_SEED       = 43
rng_global     = np.random.default_rng(RND_SEED)
# ──────────────────────────────────────────────────────────────────── #
# ------------------------------------------------------------------
# Template statistics – must run *before* the line that defines TPL_ONES,
# which reads _tpl_stats.
# ------------------------------------------------------------------
_tpl_cache = {}          # not referenced anywhere else in this cell
_tpl_stats = {}          # (u,core,d) → (tpl‑array, mean, std, half_len)

# Pre-compute, for every template, the 0/1 array (u ones, core zeros, d ones)
# together with its mean, standard deviation and half length.
for (u, core, d) in TEMPLATES:
    tpl        = np.r_[np.ones(u), np.zeros(core), np.ones(d)]
    tpl_m      = tpl.mean()
    tpl_s      = tpl.std()
    half_len   = tpl.size // 2
    _tpl_stats[(u, core, d)] = (tpl, tpl_m, tpl_s, half_len)

# handy aliases used by the caller and the plotting code
# (they always describe the FIRST entry of TEMPLATES)
(u1, c1, d1)  = TEMPLATES[0]
TPL_1, TPL_M_1, TPL_S_1, TPL_HALF_1 = _tpl_stats[(u1, c1, d1)]
CORE_SIZE_1    = c1
LINKER_LEN_1   = u1

# core size → core + 2.
# NOTE(review): the original comment called this the "full‑length
# (linker+core+linker)", but the value is 1 + core + 1, not u + core + d —
# confirm which one the consumers of CORE_TO_FULL_LEN expect.
CORE_TO_FULL_LEN = {core: 1 + core + 1 for (_, core, _) in TEMPLATES}
# ------------------------------------------------------------------

def _fill_met_domains(arr, width):
    """
    Set to 1 every position lying between two consecutive 1s that are at
    most `width` indices apart (NaNs and 0s in between are overwritten).

    Returns the input object itself when it contains fewer than two 1s,
    otherwise a modified copy.
    """
    ones_at = np.where(arr == 1)[0]
    if len(ones_at) < 2:
        return arr

    near = np.diff(ones_at) <= width
    starts = ones_at[:-1][near]
    stops = ones_at[1:][near]

    # Difference array: +1 where a bridged stretch opens, -1 just past its
    # end; a positive running total marks the covered positions.
    steps = np.zeros(len(arr) + 1, dtype=int)
    steps[starts] += 1
    steps[stops + 1] -= 1
    covered = np.cumsum(steps)[:-1] > 0

    out = arr.copy()
    out[covered] = 1
    return out

# ─── helper: global or per‑condition P(m6A|A) ────────────────────── #
def _estimate_pm6A(seq_like):
    """
    Estimate P(value == 1) over all observed positions of many reads.

    Parameters
    ----------
    seq_like : iterable
        Per-read vectors (lists or numpy arrays) of 0/1/NaN values; entries
        of any other type are skipped.  When PERFORM_FILLING is set, each
        vector is first passed through _fill_met_domains.

    Returns
    -------
    float
        (# positions equal to 1) / (# non-NaN positions), or 0.0 when there
        is no usable vector or no non-NaN position at all.
    """
    n_ones = 0
    n_observed = 0
    # Counts are accumulated per vector instead of concatenating everything
    # first, so no genome-sized temporary array is built.
    for x in seq_like:
        if not isinstance(x, (list, np.ndarray)):
            continue
        arr = np.asarray(x, dtype=float)
        if PERFORM_FILLING:
            arr = _fill_met_domains(arr, MET_DOMAIN_WIDTH)
        n_observed += int(np.count_nonzero(~np.isnan(arr)))
        n_ones += int(np.count_nonzero(arr == 1))

    # Previously an all-NaN input divided by zero here.
    if n_observed == 0:
        return 0.0
    return n_ones / n_observed

# P(m6A | A): one value per condition when PER_CONDITION, else one global value.
PM6A = (
    {c: _estimate_pm6A(g["mod_qual_bin"]) for c, g in merged_df.groupby("condition")}
    if PER_CONDITION
    else _estimate_pm6A(merged_df["mod_qual_bin"])
)

# f_tpl: fraction of 1-positions in every template array.
TPL_ONES = {key: stats[0].mean() for key, stats in _tpl_stats.items()}

# ───────── WEIGHT FOR vec==1  (derived once, reused everywhere) ───────── #
# Same shape as PM6A (dict per condition or scalar); 1.0 when boosting is off.
if not USE_M6A_BOOST:
    W1_BOOST = {c: 1.0 for c in PM6A} if isinstance(PM6A, dict) else 1.0
elif isinstance(PM6A, dict):
    W1_BOOST = {c: (1.0 - p) / p for c, p in PM6A.items()}
else:
    W1_BOOST = (1.0 - PM6A) / PM6A


# ─── simulation worker (fork‑safe) ───────────────────────────────── #
def _sim_worker(arg):
    """
    Simulate one null draw and return its weighted correlation with a template.

    Parameters
    ----------
    arg : tuple
        (tpl_key, k, pm6A, bin_key, tpl_arr, tpl_m, tpl_s).  `k` positions of
        the template are picked without replacement and each is set to 1 with
        probability `pm6A` (else 0); all other positions stay NaN.  `tpl_m`
        and `tpl_s` are accepted for interface compatibility but not used.

    Returns
    -------
    tuple
        (tpl_key, bin_key, r) where r is the weighted Pearson correlation
        between the simulated vector and the template over the observed
        positions.  r is 0.0 when fewer than two positions are observed and
        NaN when either side has zero (weighted) variance.
    """
    tpl_key, k, pm6A, bin_key, tpl_arr, _tpl_m, _tpl_s = arg

    # One generator per worker *process*: creating a fresh OS-seeded generator
    # on every call is pure overhead across millions of simulations.  The pid
    # check keeps it fork-safe (a forked child never reuses a parent stream).
    pid = os.getpid()
    cached = getattr(_sim_worker, "_rng_by_pid", None)
    if cached is None or cached[0] != pid:
        cached = (pid, np.random.default_rng())
        _sim_worker._rng_by_pid = cached
    rng = cached[1]

    # ------------ simulate vector & weights ------------
    n_pos = tpl_arr.size
    idx = rng.choice(n_pos, k, replace=False)
    vec = np.full(n_pos, np.nan, float)
    vec[idx] = (rng.random(k) < pm6A).astype(float)
    if PERFORM_FILLING:
        vec = _fill_met_domains(vec, MET_DOMAIN_WIDTH)

    # Positions equal to 1 are up-weighted by ((1-p)/p) / f_tpl.  With p == 0
    # no position can equal 1, so the boost is irrelevant; skipping it avoids
    # the division by zero the unguarded expression would hit.
    if USE_M6A_BOOST and pm6A > 0:
        boost = ((1.0 - pm6A) / pm6A) / TPL_ONES[tpl_key]
    else:
        boost = 1.0
    w_all = np.where(vec == 1, boost, 1.0)

    mask = ~np.isnan(vec)
    if mask.sum() < 2:
        # too little information: count as r = 0 so it is part of the null
        # but never in the tail
        return tpl_key, bin_key, 0.0

    v = vec[mask]
    t = tpl_arr[mask]
    w = w_all[mask]

    # Weighted means, variances and covariance → weighted Pearson r.
    w_sum  = w.sum()
    v_mean = np.dot(w, v) / w_sum
    t_mean = np.dot(w, t) / w_sum
    v_var  = np.dot(w, (v - v_mean) ** 2) / w_sum
    t_var  = np.dot(w, (t - t_mean) ** 2) / w_sum
    if v_var == 0 or t_var == 0:
        # constant vector: correlation undefined, reported as NaN
        return tpl_key, bin_key, np.nan
    cov = np.dot(w, (v - v_mean) * (t - t_mean)) / w_sum
    r   = cov / math.sqrt(v_var * t_var)
    return tpl_key, bin_key, r



# ─── build empirical null for one p(m6A) value ───────────────────── #
def _build_null(pm6A):
    """
    Build the empirical null of correlation scores for one p(m6A) value.

    For every template in TEMPLATES and every k‑bin (consecutive pairs of
    K_BIN_EDGES), N_SIM_PER_BIN values of k are drawn from rng_global and one
    simulation per k is run through _sim_worker in a process pool.

    Returns
    -------
    dict
        tpl_key → {bin_key → sorted 1‑D numpy array of simulated r values}.
        The arrays may contain NaN (zero‑variance simulations); np.sort
        places those at the end.
    """
    bin_ranges = list(zip(K_BIN_EDGES[:-1], K_BIN_EDGES[1:]))
    args = []
    for tpl_key in TEMPLATES:
        tpl_arr = np.r_[np.ones(tpl_key[0]), np.zeros(tpl_key[1]), np.ones(tpl_key[2])]
        tpl_m, tpl_s = tpl_arr.mean(), tpl_arr.std()
        for bin_key in bin_ranges:
            lo, hi = bin_key
            # k cannot exceed the template length (positions are drawn
            # without replacement) and must be at least 1.
            eff_lo = max(1, min(lo, tpl_arr.size))
            eff_hi = min(hi, tpl_arr.size)
            if eff_lo > eff_hi:
                continue
            # k is drawn uniformly from eff_lo..eff_hi, both ends included
            ks = rng_global.integers(eff_lo, eff_hi + 1, size=N_SIM_PER_BIN)
            args.extend(
                (tpl_key, int(k), pm6A, bin_key, tpl_arr, tpl_m, tpl_s)
                for k in ks
            )

    accum = defaultdict(lambda: defaultdict(list))          # tpl → k‑bin → [r]
    # Results come back in arbitrary order; each one carries its own keys.
    with mp.Pool(N_WORKERS) as pool:
        for tpl_key, bin_key, r in tqdm(
                pool.imap_unordered(_sim_worker, args, chunksize=2048),
                total=len(args), desc="null sims"):
            accum[tpl_key][bin_key].append(r)

    # ---------- convert lists to sorted numpy arrays ------------------------
    return {
        tpl_key: {
            bin_key: np.sort(np.asarray(r_list, float))     # NaNs, if any, end up last
            for bin_key, r_list in bin_dict.items()
        }
        for tpl_key, bin_dict in accum.items()
    }




# ═══════════════════════════════════════════════════════════════════ #
#  Runtime helpers                                                   #
# ═══════════════════════════════════════════════════════════════════ #
def _table(cond=None):
    """
    Return the null table to use: NULL_TABLE itself in global mode, or the
    sub-table of `cond` when PER_CONDITION is set (then `cond` is required).
    """
    if PER_CONDITION:
        if cond is None:
            raise ValueError("cond must be specified when PER_CONDITION=True")
        return NULL_TABLE[cond]
    return NULL_TABLE

def _bin_for_k(k):
    """
    Map a number of observed positions `k` to its k‑bin (lo, hi).

    Bins are the consecutive pairs of K_BIN_EDGES and are matched as
    lo <= k < hi.  Values outside the covered range are clamped to the
    nearest bin: k below the first edge uses the first bin (previously it
    fell through to the *last* bin, i.e. the null of the largest k), and k at
    or above the last edge uses the last bin.
    """
    k_bins = list(zip(K_BIN_EDGES[:-1], K_BIN_EDGES[1:]))
    if k < K_BIN_EDGES[0]:
        return min(k_bins)
    for lo, hi in k_bins:
        if lo <= k < hi:
            return (lo, hi)
    return max(k_bins)

def p_value_for_tpl_k(tpl_key, k, r_val, *, cond=None):
    """
    Upper‑tail p‑value of `r_val` under the template‑specific null.

    Parameters
    ----------
    tpl_key : tuple
        Template key (u, core, d) into the null table.
    k : int
        Number of observed positions; selects the k‑bin via _bin_for_k.
    r_val : float
        Observed correlation.
    cond : optional
        Condition name; required when PER_CONDITION is set.

    Returns
    -------
    float
        Fraction of finite null values strictly greater than `r_val`, or NaN
        when the bin holds no finite null value.
    """
    arr = _table(cond)[tpl_key][_bin_for_k(k)]
    # Zero‑variance simulations are stored as NaN and sorted to the end of the
    # array; left in place they would be counted in the upper tail and in the
    # denominator.  Drop them, as alpha_threshold_for_tpl does.
    arr = arr[~np.isnan(arr)]
    if arr.size == 0:
        return np.nan
    pos = np.searchsorted(arr, r_val, side="right")
    return 1.0 - pos / arr.size

def alpha_threshold_for_tpl(tpl_key, alpha=0.001, *, cond=None):
    """
    Per‑k‑bin r cut‑off for one template at significance level `alpha`
    (default 0.001).

    For every k‑bin of the template's null table the NaN entries are removed
    and the value at index ceil((1 - alpha) * n) - 1 of the sorted null is
    returned (index floored at 0).  Bins without any finite null value map
    to NaN.  Returns a dict bin_key → threshold.
    """
    thresholds = {}
    for bin_key, null_r in _table(cond)[tpl_key].items():
        finite = null_r[~np.isnan(null_r)]   # zero‑variance sims are NaN
        if finite.size:
            rank = int(np.ceil((1.0 - alpha) * finite.size)) - 1
            thresholds[bin_key] = finite[max(0, rank)]
        else:
            thresholds[bin_key] = np.nan
    return thresholds

ALPHA = 0.05
# ─── run sims (global or per condition), then report the thresholds ─── #
# NULL_TABLE has to exist before the thresholds are derived, because
# alpha_threshold_for_tpl reads it through _table().
if PER_CONDITION:
    NULL_TABLE = {cond: _build_null(p) for cond, p in tqdm(PM6A.items(), desc="conds")}
    ALPHA_THRESHOLDS = {
        cond: {tpl: alpha_threshold_for_tpl(tpl, ALPHA, cond=cond) for tpl in TEMPLATES}
        for cond in PM6A
    }

    print("\n[INFO] Template‑specific α=0.05 thresholds:")
    for cond, tpl_map in ALPHA_THRESHOLDS.items():
        print(f"\nCondition: {cond}")
        for tpl_key, bin_map in tpl_map.items():
            print(f"  Template {tpl_key}:")
            for (lo, hi), r_thr in bin_map.items():
                print(f"    {lo:3d}–{hi:3d}: r ≥ {r_thr:.3f}")
else:
    NULL_TABLE = _build_null(PM6A)
    ALPHA_THRESHOLDS = {tpl: alpha_threshold_for_tpl(tpl, ALPHA) for tpl in TEMPLATES}

    print("\n[INFO] Template‑specific α=0.05 thresholds:")
    for tpl_key, bin_map in ALPHA_THRESHOLDS.items():
        print(f"\nTemplate {tpl_key}:")
        for (lo, hi), r_thr in bin_map.items():
            print(f"    {lo:3d}–{hi:3d}: r ≥ {r_thr:.3f}")




In [None]:
###############################################################################
#  NUC-CALLING PIPELINE v5 + 3rd-pass ACCESSIBLE FILTER (debug only on first read)
#
#  • 1st-pass fixed-template corr-calls   → candidate centres (no pruning)
#  • 2nd-pass grid-search (±REFINE_WIN)  → best core/linker sizes per centre
#  • 3rd-pass accessible-filter → remove or trim any core overlapping ANY
#      high-confidence “open” (continuously methylated) footprint
#  • final overlap-prune keeps highest-r → nuc_coords / nuc_centers
#  • debug (detailed) only for the first read
###############################################################################

import os, math, itertools, warnings, multiprocessing as mp
import numpy as np, pandas as pd
from tqdm.auto import tqdm
from scipy.ndimage import uniform_filter1d, gaussian_filter1d
from scipy.stats  import gaussian_kde, pearsonr
from scipy.signal import find_peaks
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import display, Image   # Image used in read-scatter helper

# ────────────────── 0) CONFIGURATION ──────────────────
METHOD                = "corr-calls"         # "fiber-tools" | "corr-calls"

# ---- 1st-pass template (narrow, fast) --------------------------
# (these four import lines repeat the imports at the top of this cell)
import os, math, itertools, warnings, multiprocessing as mp
import numpy as np, pandas as pd
from tqdm.auto import tqdm
from scipy.ndimage import uniform_filter1d, gaussian_filter1d

# `thr` is not defined in this cell; it must already exist in the notebook
# namespace when this cell runs.
CORR_THRESHOLD    = thr                # min smoothed-r to accept a peak
# ───────────────────── USER-TUNABLE CONFIG ───────────────────── #
# ---- 1st-pass templates (can be 1 or many) ---------------------
# Each entry is (u, core, d); total template length = u + core + d.
TEMPLATES = [                       # (u, core, d) 80, 110, 147, 166
    # (14, 298, 14),                 # T1  (175 bp total)   ← legacy default
    (14, 80, 14),                 # 108 bp total
    (14, 110, 14),                 # 138 bp total
    (14, 147, 14),                 # 175 bp total
    (14, 166, 14),                 # 194 bp total
    # (14, 170, 14),                 # T1  (175 bp total)   ← legacy default
    # (14, 160, 14),                 # T1  (175 bp total)   ← legacy default
    # (14, 150, 14),                 # T1  (175 bp total)   ← legacy default
    # (14, 140, 14),                 # T1  (175 bp total)   ← legacy default
    # (14, 130, 14),                 # T2
    # (14, 120, 14),                 # T2
    # (14, 110, 14),                 # T2
    # (14,  100, 14),                 # T3
    # (14,  90, 14),                 # T3
    # (14,  80, 14),                 # T3
    # (14, 175, 14),                 # T1  (175 bp total)   ← legacy default
    # (14, 165, 14),                 # T1  (175 bp total)   ← legacy default
    # (14, 155, 14),                 # T1  (175 bp total)   ← legacy default
    # (14, 145, 14),                 # T1  (175 bp total)   ← legacy default
    # (14, 135, 14),                 # T2
    # (14, 125, 14),                 # T2
    # (14, 115, 14),                 # T2
    # (14,  105, 14),                 # T3
    # (14,  95, 14),                 # T3
    # (14,  85, 14),                 # T3
    # (10,  5, 10),                 # T3
    # (10,  10, 10),                 # T3
    # (10,  15, 10),                 # T3
]
# Optional per-template correlation thresholds (falls back to CORR_THRESHOLD)
CORR_THRESHOLDS = {
    #(14, 298, 14): 0.2,                 # T1  (175 bp total)   ← legacy default
    (14, 166, 14): 0.2,                 # 194 bp template
    (14, 147, 14): 0.2,   # 175 bp template
    (14, 110, 14): 0.2,   # 138 bp template – adjust as needed
    (14,  80, 14): 0.2,   # 108 bp template
}
USE_SIM_THRESHOLDS = True
print(f"Using {len(TEMPLATES)} templates: {TEMPLATES}")
print(f"Using correlation thresholds: {CORR_THRESHOLDS}")

CORE_RANGE        = range(147, 148, 1)  # 2nd-pass grid – core sizes to test (currently only 147)
LINK_RANGE        = range(14, 15, 1)     # 2nd-pass grid – each linker arm (currently only 14)
REFINE_WIN        = 1                  # ±bp shifts tested around a centre
MOV_AVG_WINDOW    = 10                  # smoothing 1st-pass r for peak pick

# Footprint sizes and thresholds for the accessible filter; the two lists
# have one entry each here.
ACCESS_FOOTPRINTS = [250]#25, 50, 75, 100, 150, 250]
ACCESS_THRESHOLDS = [0.8]#,0.8, 0.8, 0.8, 0.8, 0.8]
MIN_CORE_LEN      = min(CORE_RANGE)     # smallest core size in the 2nd-pass grid
N_WORKERS         = max(2, mp.cpu_count() - 2)
CHUNK_SIZE        = 10


CORE_LENGTH_PENALTY = 0        # 0 → no penalty
DEBUG             = True                # ⇢ extra logging on read 0 only

MIN_OVERLAP           = 100                  # bp minimum overlap to accept an r
MAX_NUC_SIZE          = 320                  # for 1st-pass clipping only
# keep only cores whose final r ≥ this value
MIN_FINAL_R = 0       # 0 → disable
USE_SECOND_PASS   = False    # ← set to False to skip grid-search & shift

# ---- misc heuristics (unchanged) ---------------------------------
NUC_MIN               = 75
NUC_COMBINED          = 100
NUC_EXTEND            = 25
ALLOWED_M6A_SKIPS     = 2
# ---- fill / interp --------------------------------------------------------
# ───────── BOOSTING TOGGLE ───────── #
USE_M6A_BOOST = True          # False ⇒ every position weight = 1
# ─────────────────────────────────── #

PERFORM_FILLING       = False
MET_DOMAIN_WIDTH      = 9
PERFORM_INTERP        = False
INTERP_WINDOW         = 1
GAUSS_SMOOTH_RAW      = False
GAUSS_SIGMA_RAW       = 1
# ---- plotting / globals -----------------------------------------------------
REL_POS_RANGE         = 3000
FIG_TEMPLATE          = "plotly_white"
DEFAULT_READ_CLRS     = ["#b12537", "#4974a5", "#47B562", "#808080"]
SUBSAMPLE_READS_PLOT  = 100
SMOOTH_WINDOW_PCT     = 50
CENTER_HIST_BIN       = 10
KDE_POINTS            = 10000
# DEBUG, N_WORKERS and CHUNK_SIZE are assigned a second time here with the
# same values as above; these later assignments are the ones that stick.
DEBUG                 = True
N_WORKERS             = max(2, mp.cpu_count() - 2)
CHUNK_SIZE            = 10

# ---- read / site selection ---------------------------------------------------
CHIP_RANK_CUTOFF      = 80
ABOVE_FLAG            = True
KEEP_FIRST_BED_START_ONLY = False   # True ⇒ drop all but first bed_start
TYPES_TO_INCLUDE      = ["MEX_motif"]  #"univ_nuc",,"MEXII_motif" "ALL" for all types, or list of specific types
CHR_TYPE_INCLUDE      = ["X"]#,"Autosome"]
BED_STRANDS_TO_INCLUDE = [] # Filter reads by strand. Examples: ["+"], ["+","-"]. [] ⇒ no filter.
# `analysis_cond` is defined outside this cell; entries 0 and 5 are selected.
CONDITIONS_TO_INCLUDE = [
    analysis_cond[0],
    # analysis_cond[3],
    # analysis_cond[4],
    analysis_cond[5]
]
MIN_READ_LENGTH        = 300
REQUIRE_CENTRAL        = False
USE_TRIM_READS         = False    # set False to skip the ±175 bp trimming step

def dbg(msg, always=False):
    """Print `msg` with a [DEBUG] prefix when DEBUG is on or `always` is set."""
    if not (DEBUG or always):
        return
    print(f"[DEBUG] {msg}")


# ────────────────── 0a) CACHING CONFIG ──────────────────
import os

# where to stash your final DataFrame
# NOTE: the save step at the end of this cell is commented out, so nothing in
# this cell currently writes this file; it is only ever read.
TEMP_DF_PATH   = "/tmp/filtered_reads_df.pkl"

# control flags
# The cache is loaded only when USE_CACHE is True, the file exists, and
# FORCE_REPLACE is False; every other combination reruns the pipeline.
USE_CACHE      = False   # if True, and file exists & not forcing replace, we’ll load
FORCE_REPLACE  = True   # if True, always rerun pipeline and overwrite existing file

# ────────────────── MAIN PIPELINE WRAPPER ──────────────────
# Load the cached DataFrame only when caching is on, the file exists and a
# rebuild is not forced; otherwise run the full filtering + calling pipeline.
if USE_CACHE and os.path.exists(TEMP_DF_PATH) and not FORCE_REPLACE:
    filesize = os.path.getsize(TEMP_DF_PATH)
    dbg(f"Cache file found at {TEMP_DF_PATH} ({filesize} bytes), loading…")
    # NOTE(review): unpickling can execute arbitrary code and /tmp is commonly
    # world-writable – only load this cache on a trusted machine.
    filtered_reads_df = pd.read_pickle(TEMP_DF_PATH)
    dbg(f"✔ Loaded DataFrame: shape={filtered_reads_df.shape}, columns={list(filtered_reads_df.columns)}")
else:
    dbg("⚙️ Running full NUC-CALLING pipeline (cache bypassed or not present)")

    # ────────────────── 1) FILTERING ──────────────────
    # chip-rank table: type → rank in percent (0-100, 3 decimals)
    chiprank_df = (
        pd.read_csv("/Data1/reference/rex_chiprank.bed", sep=r"\s+")
          .assign(type=lambda d: "MOTIFS_" + d["type"].astype(str))
    )
    chip_rank_lookup = {
        t: round(float(rk) * 100, 3)
        for t, rk in zip(chiprank_df["type"], chiprank_df["chip_rank"])
    }

    keep_conds = set(CONDITIONS_TO_INCLUDE)

    if TYPES_TO_INCLUDE == ["ALL"]:
        # pass every type that actually appears in merged_df
        keep_types = set(merged_df["type"].unique())
        print(f"Including all types: {keep_types}")

    elif TYPES_TO_INCLUDE:  # any non‑empty iterable (manual inclusion)
        keep_types = set(TYPES_TO_INCLUDE)
        print(f"Including specified types: {keep_types}")

    else:  # fall back to chip‑rank threshold logic
        keep_types = {
            t for t, r in chip_rank_lookup.items()
            if (r >= CHIP_RANK_CUTOFF) == ABOVE_FLAG
        }

    keep_chr = (
        set(CHR_TYPE_INCLUDE)
        if CHR_TYPE_INCLUDE
        else set(merged_df["chr_type"].unique())
    )

    # Derive allowed strands
    if "bed_strand" not in merged_df.columns:
        raise KeyError("Expected column 'bed_strand' not found in merged_df")
    if BED_STRANDS_TO_INCLUDE:
        keep_strands = set(BED_STRANDS_TO_INCLUDE)
        dbg(f"Including specified bed_strands: {keep_strands}")
    else:
        keep_strands = set(merged_df["bed_strand"].unique())
        dbg(f"No bed_strand filter set; using all: {keep_strands}")

    # Apply metadata filters, now including bed_strand
    df0 = merged_df.query(
        "condition in @keep_conds and type in @keep_types and chr_type in @keep_chr and bed_strand in @keep_strands"
    ).copy()

    # ───── optional "first few bed_starts" filter ─────
    if KEEP_FIRST_BED_START_ONLY and not df0.empty:
        # Keep reads belonging to the first N unique bed_start values, in
        # appearance order.  N is named so the log line cannot drift from the
        # slice again (the old message claimed 50 while the code kept 5).
        n_bed_starts_kept = 5
        unique_beds = pd.Index(df0["bed_start"]).unique()[:n_bed_starts_kept]
        pre_n = len(df0)
        df0 = df0[df0["bed_start"].isin(unique_beds)].reset_index(drop=True)
        dbg(f"kept only first {n_bed_starts_kept} bed_starts ({len(unique_beds)} IDs): {pre_n} → {len(df0)} reads")
    # ───
    dbg(f"after metadata filter: {len(df0)} reads")

    # Keep reads with at least one position inside ±REL_POS_RANGE.
    # .astype(bool) guarantees a boolean row mask: on an empty frame .apply can
    # return a non-boolean Series, which pandas would not treat as a row mask.
    mask_overlap = df0["rel_pos"].apply(
        lambda arr: ((arr >= -REL_POS_RANGE) & (arr <= REL_POS_RANGE)).any()
    ).astype(bool)
    df0 = df0[mask_overlap].reset_index(drop=True)
    dbg(f"after overlap filter:  {len(df0)} reads")

    # Types still present at this point (i.e. after metadata AND overlap filters)
    surviving_types = set(df0["type"].unique())
    dbg(f"Types surviving metadata + overlap filters: {surviving_types}")

    # Every read must cover ≥ MIN_READ_LENGTH bp inside [-REL_POS_RANGE, REL_POS_RANGE]
    overlap_mask = df0["rel_pos"].apply(
        lambda arr: (
            min(arr.max(), REL_POS_RANGE)
            - max(arr.min(), -REL_POS_RANGE)
            + 1
        ) >= MIN_READ_LENGTH
    ).astype(bool)

    if REQUIRE_CENTRAL:
        half = MIN_READ_LENGTH // 2
        # must additionally span ±half around the origin
        mask_central = df0["rel_pos"].apply(
            lambda arr: (arr.min() <= -half) and (arr.max() >= half)
        ).astype(bool)
        combined_mask = mask_central & overlap_mask
    else:
        # only enforce overlap-length
        combined_mask = overlap_mask
    filtered_reads_df = df0[combined_mask].reset_index(drop=True)
    dbg(f"after length/central filter: {len(filtered_reads_df)} reads")

    # Drop reads without any methylated call (or with a non-array signal)
    mask_valid = filtered_reads_df["mod_qual_bin"].apply(
        lambda x: isinstance(x, (list, np.ndarray)) and np.nansum(x) > 0
    ).astype(bool)
    filtered_reads_df = filtered_reads_df[mask_valid].reset_index(drop=True)
    dbg(f"after methylation filter: {len(filtered_reads_df)} reads")

    # ────────────────── 2) HELPER FUNCTIONS ──────────────────
    def _fill_met_domains(arr, width):
        idx = np.where(arr == 1)[0]
        if len(idx) < 2:
            return arr
        out = arr.copy()
        for a, b in zip(idx, idx[1:]):
            if b - a <= width:
                out[a : b + 1] = 1
        return out

    def _mode_interpolate(arr, radius):
        isnan = np.isnan(arr)
        if not isnan.any():
            return arr

        valid = (~isnan).astype(int)
        pos1  = (arr ==  1).astype(int) & ~isnan
        neg1  = (arr == -1).astype(int) & ~isnan

        c_val = np.concatenate(([0], np.cumsum(valid)))
        c_pos = np.concatenate(([0], np.cumsum(pos1)))
        c_neg = np.concatenate(([0], np.cumsum(neg1)))

        out = arr.copy()
        for i in np.where(isnan)[0]:
            lo, hi = max(0, i - radius), min(len(arr) - 1, i + radius)
            tot  = c_val[hi + 1] - c_val[lo]
            if tot == 0:
                out[i] = 0.0                # no data in window
                continue
            n_pos = c_pos[hi + 1] - c_pos[lo]
            n_neg = c_neg[hi + 1] - c_neg[lo]
            if n_pos > n_neg:
                out[i] =  1.0
            elif n_neg > n_pos:
                out[i] = -1.0
            else:
                out[i] =  0.0               # tie → neutral
        return out

    def _maybe_gauss(y):
        """Return *y* Gaussian-smoothed (sigma = GAUSS_SIGMA_RAW) when GAUSS_SMOOTH_RAW is set, else *y* unchanged."""
        if not GAUSS_SMOOTH_RAW:
            return y
        return gaussian_filter1d(y, GAUSS_SIGMA_RAW, mode="nearest")

    # ────────────────── 3) FIBER-TOOLS HEURISTIC ──────────────────
    def _fiber_find(m6a_sites):
        """
        Gap-based nucleosome caller over a list of m6A positions.

        m6a_sites : indexable sequence of methylated positions; the gap
                    arithmetic assumes ascending order – TODO confirm at caller.

        Walks the sites left→right and emits ``(start, length)`` tuples for
        unmethylated stretches that end at an m6A site:
          • a single gap of ≥ NUC_MIN bases, or
          • two (or, with ALLOWED_M6A_SKIPS ≥ 2, three) neighbouring gaps merged
            across the intervening m6A site(s) when together they reach
            NUC_COMBINED.
        The stretch after the last m6A site is never emitted, because a call
        is only made when a site closes a gap.

        NOTE(review): the section header calls this the "fiber-tools heuristic";
        presumably a port of that tool's logic – verify against upstream.
        """
        nucs=[]
        # pre_idx   : position of the previously processed m6A site (-1 = none yet)
        # pre_gap   : gap length that ended at pre_idx
        # pre_added : whether the previous iteration appended a call (so a merge can replace it)
        pre_idx, pre_gap, pre_added = -1, 0, False
        idx = 0
        while idx < len(m6a_sites):
            cur = m6a_sites[idx]
            gap = cur - pre_idx - 1          # unmethylated bases between previous and current site
            nxt = 0 if idx == len(m6a_sites) - 1 else m6a_sites[idx+1] - cur - 1   # gap to the next site (0 at the end)

            can_combine = (pre_idx >= 0 and pre_gap + gap + 1 >= NUC_COMBINED)
            cstart = cur - pre_gap - gap - 1   # start of the merged (previous + current) gap

            # case 1: merge previous + current gap when at least one is short,
            #         neither is too long and both are at least NUC_EXTEND
            if (ALLOWED_M6A_SKIPS >= 1 and can_combine and
                (gap < NUC_MIN or pre_gap < NUC_MIN) and
                max(gap, pre_gap) <= NUC_COMBINED + NUC_EXTEND and
                min(gap, pre_gap) >= NUC_EXTEND):
                if pre_added:
                    nucs.pop()               # the merged call supersedes the previous one
                nucs.append((cstart, cur - cstart))
                pre_added = True

            # case 2: current gap alone is long enough
            elif gap >= NUC_MIN:
                nucs.append((pre_idx + 1, gap))
                pre_added = True

            # case 3: merge with a short previous gap unless the next gap looks like the better partner
            elif (ALLOWED_M6A_SKIPS >= 1 and can_combine and pre_gap < NUC_MIN and
                  not (nxt > pre_gap and nxt < NUC_MIN)):
                nucs.append((cstart, cur - cstart))
                pre_added = True

            # case 4: skip two m6A sites – merge previous, current and next gap
            elif (ALLOWED_M6A_SKIPS >= 2 and pre_idx >= 0 and nxt > 0 and pre_gap < NUC_MIN and
                  pre_gap + gap + nxt + 2 >= NUC_COMBINED and nxt < NUC_MIN and
                  gap + nxt + 1 < NUC_COMBINED):
                cur = m6a_sites[idx + 1]
                nucs.append((cstart, cur - cstart))
                idx += 1                     # the next site has been consumed by the merge
                pre_added = True
            else:
                pre_added = False

            pre_gap, pre_idx = gap, cur
            idx += 1
        return nucs

    # ────────────────── TEMPLATE PRE-COMPUTE ────────────────── #
    # Per-template statistics, keyed by (u, core, d):
    #   value = (template array, its mean, its std, half of its length)
    # The template is `u` ones, `core` zeros, `d` ones (linker / core / linker).
    _tpl_cache = {}
    _tpl_stats = {}

    for u, core, d in TEMPLATES:
        tpl = np.concatenate([np.ones(u), np.zeros(core), np.ones(d)])
        tpl_m = tpl.mean()
        tpl_s = tpl.std()
        half_len = tpl.size // 2
        _tpl_stats[(u, core, d)] = (tpl, tpl_m, tpl_s, half_len)

    # The first template doubles as the reference ("legacy") template.
    u1, c1, d1 = TEMPLATES[0]
    TPL_1, TPL_M_1, TPL_S_1, TPL_HALF_1 = _tpl_stats[(u1, c1, d1)]
    CORE_SIZE_1 = c1      # alias for legacy constant
    LINKER_LEN_1 = u1

    # Footprint used for clash detection: the core plus one base on each side
    # (not the full u + core + d template length).
    CORE_TO_FULL_LEN = {core: core + 2 for (_u, core, _d) in TEMPLATES}

    def _greedy_non_overlapping(peaks):
        """
        Resolve clashing peaks with a single left→right sweep.

        peaks : list of (centre_bp, core_len_bp, r_val, *extra); anything
                after the third element is ignored.

        Peaks are ordered by (centre, r) ascending.  A running "best" peak is
        compared with each following peak: two peaks clash when their centres
        are closer than half the sum of their CORE_TO_FULL_LEN footprints.
        On a clash the newcomer replaces the running best only if it has a
        higher r, or the same r and a longer core; without a clash the running
        best is emitted and the newcomer becomes the new running best.

        Returns a list of (centre, core_len, r_val) in ascending centre order.
        """
        if not peaks:
            return []

        ordered = sorted(peaks, key=lambda t: (t[0], t[2]))

        best_cen, best_core, best_r = ordered[0][:3]
        kept = []

        for cen, core, r, *_ in ordered[1:]:
            footprint_best = CORE_TO_FULL_LEN[best_core]
            footprint_new = CORE_TO_FULL_LEN[core]
            clash = abs(cen - best_cen) < 0.5 * (footprint_best + footprint_new)

            if not clash:
                # running best is safe: emit it and start a new collision window
                kept.append((best_cen, best_core, best_r))
                best_cen, best_core, best_r = cen, core, r
            elif (r > best_r) or (r == best_r and core > best_core):
                best_cen, best_core, best_r = cen, core, r

        kept.append((best_cen, best_core, best_r))
        return kept


    # ───────────── shared sliding-r helper ───────────── #
    def _corr_window_nan(vec, tpl, *, debug=False, cond=None):
        """
        NaN-aware, weighted sliding Pearson r between *vec* and *tpl*.

        vec  : 1-D float array of 1 / 0 / NaN (NaN = unobserved position)
        tpl  : 1-D template of ones and zeros
        cond : condition name, used to look up W1_BOOST when that is a dict
        debug: accepted but not used

        Weights: with USE_M6A_BOOST, positions where vec == 1 get
        ``W1_BOOST / tpl.mean()`` and all others 1.0; otherwise every
        position weighs 1.0.

        Returns an array of length ``vec.size - tpl.size + 1`` (empty when
        the read is shorter than the template).  A window with fewer than two
        observed positions yields NaN; a window where either side has zero
        weighted variance yields 0.0.
        """
        win_len = tpl.size
        n_pos = vec.size
        if n_pos < win_len:
            return np.empty(0, float)

        # one weight per read position, computed once for all windows
        if USE_M6A_BOOST:
            base = W1_BOOST[cond] if isinstance(W1_BOOST, dict) else W1_BOOST
            boost = base / tpl.mean()
            weights = np.where(vec == 1, boost, 1.0)
        else:
            weights = np.ones_like(vec, dtype=float)

        r_out = np.full(n_pos - win_len + 1, np.nan, dtype=float)
        for start in range(len(r_out)):
            stop = start + win_len
            segment = vec[start:stop]
            observed = ~np.isnan(segment)
            if observed.sum() < 2:
                continue                      # leave NaN: not enough data

            v = segment[observed]
            t = tpl[observed]
            w = weights[start:stop][observed]

            w_sum = w.sum()
            v_mean = np.dot(w, v) / w_sum
            t_mean = np.dot(w, t) / w_sum

            v_var = np.dot(w, (v - v_mean) ** 2) / w_sum
            t_var = np.dot(w, (t - t_mean) ** 2) / w_sum
            if v_var == 0 or t_var == 0:
                r_out[start] = 0.0            # constant signal → r defined as 0
                continue

            cov = np.dot(w, (v - v_mean) * (t - t_mean)) / w_sum
            r_out[start] = cov / math.sqrt(v_var * t_var)

        return r_out


    # overwrite the old version
    def _corr_window(vec, tpl, tpl_m=None, tpl_s=None, *, cond=None):
        """
        Thin wrapper around _corr_window_nan.

        tpl_m and tpl_s are accepted so existing call sites keep working but
        are ignored; only vec, tpl and cond are forwarded.
        """
        # NOTE: no debug flag is forwarded, and _corr_window_nan does not read
        # its debug argument, so there is no "too-short" printout to enable.
        return _corr_window_nan(vec, tpl, cond=cond)


    # 1st pass  (k-aware threshold)
    def _corr_call_first_pass(vec, pos_min, cond=None):
        """
        Multi-template first pass.

        For every template in _tpl_stats the sliding correlation is computed,
        smoothed with a moving average (MOV_AVG_WINDOW) and searched for local
        maxima.  A maximum becomes a candidate when the *raw* r at that index
        reaches the threshold (simulation-derived and k-aware when
        USE_SIM_THRESHOLDS is set, otherwise per-template CORR_THRESHOLDS with
        CORR_THRESHOLD as default).  Peaks of the (14, 298, 14) template must
        additionally show at least one methylated position in each 14-bp
        flank.

        Clashing candidates are resolved by _greedy_non_overlapping.

        Returns a list of (centre_abs, core_len, r_val).
        """
        found = []                         # (centre, core_len, r, tpl_key)

        for tpl_key, (tpl, tpl_m, tpl_s, half_len) in _tpl_stats.items():
            r_raw = _corr_window(vec, tpl, tpl_m, tpl_s, cond=cond)
            if r_raw.size == 0:
                continue

            r_smooth = uniform_filter1d(r_raw, MOV_AVG_WINDOW, mode="nearest")
            x_corr = np.arange(pos_min, pos_min + r_raw.size) + half_len
            is_dinuc_tpl = tpl_key == (14, 298, 14)

            for p in find_peaks(r_smooth)[0]:
                # ─── pick the threshold for this peak ───
                if USE_SIM_THRESHOLDS:
                    # k = number of observed (non-NaN) positions in the window
                    k = int(np.count_nonzero(~np.isnan(vec[p : p + tpl.size])))
                    thr_table = ALPHA_THRESHOLDS[cond] if PER_CONDITION else ALPHA_THRESHOLDS
                    thr = thr_table[tpl_key][_bin_for_k(k)]
                else:
                    thr = CORR_THRESHOLDS.get(tpl_key, CORR_THRESHOLD)

                if not (r_raw[p] >= thr):   # also rejects NaN scores
                    continue

                # ─── extra gate for the 298-bp dinucleosome template ───
                if is_dinuc_tpl:
                    win_end = p + tpl.size
                    if win_end > vec.size:  # safety (truncated read)
                        continue
                    win_vec = vec[p:win_end]            # NaNs = un-observed
                    left_ok = np.nansum(win_vec[:14] == 1) > 0
                    right_ok = np.nansum(win_vec[-14:] == 1) > 0
                    if not (left_ok and right_ok):
                        continue            # a flank without any methylation

                found.append(
                    (int(x_corr[p]), tpl_key[1], float(r_raw[p]), tpl_key)
                )

        if not found:
            return []

        # A single candidate passes through the resolver unchanged, so no
        # special case is needed for it.
        return _greedy_non_overlapping(found)



    # 2nd pass (vectorised)
    # ───────── 2nd-pass: arg-max with core-length penalty ───────── #
    def _best_match_for_centre_vectorised(vec, cen_rel, pre_r_dict):
        """
        Pick the template / shift that maximises ``r – λ·ΔL`` around a centre,
        where ΔL = (core_len – CORE_SIZE_1) and λ = CORE_LENGTH_PENALTY.

        vec        : per-read signal vector (not read here; kept for the call signature)
        cen_rel    : first-pass centre, relative to the read start
        pre_r_dict : {(u, core, d): (r_array, offset)} – r_array is the sliding
                     correlation of that template, offset maps an r index to
                     its centre position (centre = index + offset)

        For each template the best index within ±REFINE_WIN of the centre is
        taken.  NaN entries (windows with too little data) are ignored; a
        template whose whole search window is NaN is skipped.  Previously a
        single NaN made ``argmax`` return the NaN slot, which silently
        disqualified the template.

        Returns (centre_rel, u, core, d, r_val, shift).  When no template can
        be scored, the reference template is returned at the original centre
        with r = -1.0 and shift = 0.
        """
        best_score, best_tpl = -np.inf, None

        for (u, core, d), (r_arr, off) in pre_r_dict.items():
            lo = max(0, cen_rel - off - REFINE_WIN)
            hi = min(r_arr.size - 1, cen_rel - off + REFINE_WIN)
            if lo > hi:
                continue

            window = r_arr[lo:hi + 1]
            if np.isnan(window).all():      # nothing scorable for this template
                continue

            k     = lo + np.nanargmax(window)
            r_val = r_arr[k]

            # longer cores pay a linear penalty relative to the reference core
            adj_score = r_val - CORE_LENGTH_PENALTY * (core - CORE_SIZE_1)

            if adj_score > best_score:
                best_score = adj_score
                best_tpl   = (u, core, d, off, k, r_val)

        if best_tpl is None:               # no template produced a usable score
            return cen_rel, LINKER_LEN_1, CORE_SIZE_1, LINKER_LEN_1, -1.0, 0

        u, core, d, off, k, r_val = best_tpl
        shift = (k + off) - cen_rel
        return k + off, u, core, d, r_val, shift

    # 3rd pass (open-region filter)
    def _compute_open_spans(vec, pos_min, dbg_flag=False):
        """
        Find accessible (highly methylated) intervals on a read.

        vec      : per-position signal, NaN = unobserved
        pos_min  : coordinate of vec[0]
        dbg_flag : emit debug lines through dbg()

        For every (footprint L, threshold) pair from ACCESS_FOOTPRINTS /
        ACCESS_THRESHOLDS, each window of L positions whose mean signal is
        ≥ threshold is marked open; touching or overlapping windows are
        merged.  Returns a sorted list of inclusive (start, end) tuples.

        Unobserved positions count as 0 in the window mean.  Without this the
        prefix sum turned NaN at the first unobserved position and no window
        after it could ever pass the threshold, which silently disabled the
        filter whenever NaNs were left in vec.
        """
        vec_len = vec.size
        filled  = np.where(np.isnan(vec), 0.0, vec)
        csum    = np.r_[0, np.cumsum(filled)]
        spans   = []
        for L, thr in zip(ACCESS_FOOTPRINTS, ACCESS_THRESHOLDS):
            if L > vec_len:
                continue
            mean = (csum[L:] - csum[:-L]) / L      # mean of every length-L window
            for i in np.where(mean >= thr)[0]:
                spans.append((pos_min + i, pos_min + i + L - 1))
        if not spans:
            if dbg_flag:
                dbg("    [accessible] no open intervals detected")
            return []

        # merge overlapping / adjacent windows
        spans.sort(key=lambda x: x[0])
        merged = [spans[0]]
        for s, e in spans[1:]:
            if s <= merged[-1][1] + 1:
                merged[-1] = (merged[-1][0], max(merged[-1][1], e))
            else:
                merged.append((s, e))
        if dbg_flag:
            dbg(f"    [accessible] merged open spans = {merged}")
        return merged

    def _remove_under_accessible_and_shift(vec, pos_min, refined, dbg_flag=False):
        """
        Trim or drop nucleosome cores that collide with accessible spans.

        refined : list of (centre, core_len, r)

        For each core and each open span from _compute_open_spans:
          • no overlap                          → untouched
          • the span lies strictly inside the core → core dropped
          • the core's right or left edge is inside the span → core_len is
            reduced by the overlap (the centre is not moved)
          • core_len falls below MIN_CORE_LEN   → core dropped
        Returns (kept_cores, n_removed).  When no open span exists the input
        list is returned unchanged with n_removed = 0.
        """
        open_spans = _compute_open_spans(vec, pos_min, dbg_flag)
        if not open_spans:
            return refined, 0

        survivors = []
        removed = 0
        for cen, core, r in refined:
            lo, hi = cen - core // 2, cen + core // 2
            for open_lo, open_hi in open_spans:
                if hi < open_lo or lo > open_hi:
                    continue                               # disjoint
                if lo < open_lo and hi > open_hi:          # span inside the core → drop
                    removed += 1
                    break
                if open_lo <= hi <= open_hi:               # right edge in span → trim
                    core -= (hi - open_lo + 1)
                elif open_lo <= lo <= open_hi:             # left edge in span → trim
                    core -= (open_hi - lo + 1)
                if core < MIN_CORE_LEN:
                    removed += 1
                    break
                lo, hi = cen - core // 2, cen + core // 2  # re-calc for next open span
            else:
                # loop finished without a drop: keep (centre unchanged)
                survivors.append((cen, core, r))

        if dbg_flag:
            dbg(f"    [accessible] removed={removed}, kept={len(survivors)}")
        return survivors, removed

    # prune overlaps (unchanged)
    def _prune_overlaps(cands):
        cands.sort(key=lambda x: x[0])     # left→right
        out = []
        for cen, core, r in cands:
            s, e = cen - core//2 - 10, cen + core//2 + 10
            updated = []
            replace = False

            for oc, oc_core, oc_r in out:
                os, oe = oc - oc_core//2, oc + oc_core//2
                if e < os or s > oe:                   # no overlap
                    updated.append((oc, oc_core, oc_r))
                else:                                  # overlap
                    if r > oc_r or (r == oc_r and core > oc_core):
                        # incoming peak is better → drop the older one
                        replace = True
                    else:
                        # older one is better → discard incoming
                        replace = False
                        break

            if replace or not any(
                    abs(cen - oc) < 0.5 * (core + oc_core) for oc, oc_core, _ in updated):
                updated.append((cen, core, r))

            out = updated
        return out


    # ───────────── per-read worker (returns stats) ───────────── #
    def _call_nucs_for_read(idx_rec):
        """
        Per-read worker: call nucleosomes on one read and return bookkeeping.

        idx_rec : (idx, rec) – the record's position in the task list and the
                  record itself (mapping with at least "rel_pos",
                  "mod_qual_bin", "condition" and "read_id").

        Steps: build a dense signal vector (NaN = unobserved) → optional
        fill / interpolation → first-pass template correlation → optional
        second-pass refinement → accessible-region filter → overlap pruning →
        MIN_FINAL_R cut → dinucleosome expansion.

        Returns (idx, centres, coords, stats) where
          centres : list of nucleosome centres
          coords  : list of (centre, core_len) or, for the halves of a
                    298-bp dinucleosome, (centre, 147, 298)
          stats   : dict with n_cands, n_refined, n_after_access, n_adjusted,
                    n_removed_access, n_removed_low_r, template_counts

        Any exception is printed with its traceback to stderr and re-raised.
        """
        try:
            idx, rec = idx_rec
            rel_pos  = np.asarray(rec["rel_pos"], int)
            signal   = np.asarray(rec["mod_qual_bin"], float)

            #signal[signal == 0] = -1          # vectorised in-place

            pos_min  = rel_pos.min()
            # np.ptp instead of the ndarray.ptp() method, which NumPy 2.0 removed
            span_len = np.ptp(rel_pos) + 1
            vec      = np.full(span_len, np.nan, float)
            vec[rel_pos - pos_min] = signal
            if PERFORM_FILLING:
                vec = _fill_met_domains(vec, MET_DOMAIN_WIDTH)
            if PERFORM_INTERP:
                vec = _mode_interpolate(vec, INTERP_WINDOW)
            # NaNs may remain in vec from here on; the correlation helper is NaN-aware.

            cand_peaks = _corr_call_first_pass(vec, pos_min, rec['condition'])
            if DEBUG and idx == 0:
                dbg(f"[{rec['read_id']}] 1st-pass peaks = {cand_peaks[:5]}")

            if not cand_peaks:
                empty_stats = dict(n_cands=0, n_refined=0, n_after_access=0,
                                   n_adjusted=0, n_removed_access=0,n_removed_low_r=0,
                                   template_counts={})
                return (idx, [], [], empty_stats)

            # ───────── second-pass handling ─────────
            if USE_SECOND_PASS:
                # Sliding r for every template, paired with its centre offset.
                # Reads _tpl_stats (the populated table); _tpl_cache is never
                # filled, so iterating it made every centre hit the fallback.
                pre_r = {
                    key: (_corr_window(vec, tpl, tpl_m, tpl_s, cond=rec['condition']), half_len)
                    for key, (tpl, tpl_m, tpl_s, half_len) in _tpl_stats.items()
                }

                refined, template_cnt = [], Counter()
                adjusted = 0
                # first-pass peaks are (centre, core_len, r) triples
                for cen_abs, _core_len, _r_val in cand_peaks:
                    cen_rel = cen_abs - pos_min
                    cen_best_rel, u, c_len, d, r_val, shift = (
                        _best_match_for_centre_vectorised(vec, cen_rel, pre_r)
                    )
                    if (shift != 0) or (u != LINKER_LEN_1) or (c_len != CORE_SIZE_1) or (d != LINKER_LEN_1):
                        adjusted += 1
                    template_cnt[(u, c_len, d)] += 1
                    refined.append((cen_best_rel + pos_min, c_len, r_val))
            else:
                # -------- bypass: keep 1st-pass peaks “as is”, but carry core_len --------
                refined = [
                    (cen_abs, core_len, r_val)
                    for cen_abs, core_len, r_val in cand_peaks
                ]

                adjusted     = 0
                # NOTE(review): linker length 14 is hard-coded in the histogram key –
                # confirm it matches every entry of TEMPLATES.
                template_cnt = Counter((14, core, 14) for _, core, _ in refined)

            if DEBUG and idx == 0:
                dbg(f"[{rec['read_id']}] 2nd-pass refined = {[(c,cl) for c,cl,_ in refined]}")

            refined, n_removed = _remove_under_accessible_and_shift(
                vec, pos_min, refined, dbg_flag=(idx == 0))

            final = _prune_overlaps(refined)

            # ── helper: expand a dinucleosome core (298 bp ≡ 147+4+147) ──
            def _expand_dinuc(cen, core_len):
                """
                Return nucleosome tuples for downstream storage.

                Regular nucleosome  → (centre, core_len)                (2‑tuple)
                Dinucleosome 298 bp → (centre, 147, 298)  for each half (3‑tuple)

                The third element “parent_core_len” lets the plotting layer colour both
                halves together without affecting any size‑based analytics.
                """
                if core_len == 298:
                    offset = int(round((147 + 4) / 2))      # ≈76 bp; 4 bp linker
                    return [
                        (cen - offset, 147, 298),           # left half
                        (cen + offset, 147, 298),           # right half
                    ]
                return [(cen, core_len)]                    # every other core size is stored unchanged

            # ── drop weak cores, then build expanded coord / centre lists ──
            kept = [t for t in final if t[2] >= MIN_FINAL_R]
            nuc_coords_exp = []
            for cen, core, _ in kept:
                nuc_coords_exp.extend(_expand_dinuc(cen, core))

            # tuples may now be length‑2 or length‑3 → grab only index 0
            nuc_centres_exp = [t[0] for t in nuc_coords_exp]

            # ---------- stats bookkeeping -----------------------------------
            stats = dict(
                n_cands          = len(cand_peaks),
                n_refined        = len(refined) + n_removed,
                n_after_access   = len(nuc_coords_exp),       # counts expanded halves
                n_adjusted       = adjusted,
                n_removed_access = n_removed,
                n_removed_low_r  = len(final) - len(kept),
                template_counts  = dict(template_cnt),
            )

            return (idx, nuc_centres_exp, nuc_coords_exp, stats)
        except Exception as e:
            import traceback, sys
            tb = traceback.format_exc()
            print(f"\n\n[WORKER EXCEPTION] read-idx {idx_rec[0]}\n{tb}\n",
                  file=sys.stderr, flush=True)
            raise  # re-raise so the main loop knows this task failed

    # ╔════════════════════════════════════════════════════════════════════════╗
    # ║  MULTI-PROCESS DRIVER + GLOBAL STATS  (clean, single instance)         ║
    # ╚════════════════════════════════════════════════════════════════════════╝
    from collections import Counter
    import gc, sys

    dbg("Running nucleosome caller v7.2 – starting pool", always=True)

    # One task per read; each worker returns (idx, centres, coords, stats).
    records = filtered_reads_df.to_dict("records")
    n_tasks = len(records)

    results_slots = [None] * n_tasks       # pre-allocate to keep order
    received = 0
    with mp.Pool(N_WORKERS) as pool, tqdm(total=n_tasks, desc="nuc-calls") as bar:
        # results arrive in arbitrary order; idx puts each one back in its slot
        for idx, *payload in pool.imap_unordered(
                _call_nucs_for_read, enumerate(records), chunksize=CHUNK_SIZE):
            if results_slots[idx] is None:
                results_slots[idx] = (idx, *payload)
                received += 1
                bar.update(1)

    if received != n_tasks:
        missing = n_tasks - received
        # report how many reads are missing (the count used to be computed and dropped)
        raise RuntimeError(
            f"Pool finished early: expected {n_tasks} results, got {received} "
            f"({missing} missing). "
            f"Check worker stderr for exceptions (↑).")

    gc.collect()  # explicitly free the per-read r-arrays

    # ────────── aggregate statistics ──────────
    tot_cand = tot_ref = tot_final = tot_adj = tot_removed = 0
    tpl_hist = Counter()
    tot_removed_low_r = 0

    for _, _, _, st in results_slots:
        tot_cand    += st["n_cands"]
        tot_ref     += st["n_refined"]
        tot_final   += st["n_after_access"]
        tot_adj     += st["n_adjusted"]
        tot_removed += st["n_removed_access"]
        tot_removed_low_r += st["n_removed_low_r"]
        tpl_hist.update(st["template_counts"])

    dbg("\n================  GLOBAL NUC-CALL STATS  ================ ", always=True)
    dbg(f"• 1st-pass peaks (candidates):       {tot_cand:,}", always=True)
    dbg(f"• After 2nd-pass refinement:         {tot_ref:,}", always=True)
    dbg(f"      ↳ adjusted template/centre:    {tot_adj:,}", always=True)
    dbg(f"• Removed by accessible filter:      {tot_removed:,}", always=True)
    dbg(f"• Final non-overlapping cores:       {tot_final:,}\n", always=True)
    dbg(f"• Removed for r < {MIN_FINAL_R}:      {tot_removed_low_r:,}", always=True)

    dbg("Top templates chosen (u,core,d)  count", always=True)
    for tpl, ct in tpl_hist.most_common(15):
        dbg(f"   {tpl}   {ct:,}", always=True)

    sys.stdout.flush()  # guarantee all DEBUG lines are shown before cell prompt

    # ────────── attach results to DataFrame ──────────
    # slots are already index-ordered; the sort is a cheap safety net
    results_slots.sort(key=lambda x: x[0])
    filtered_reads_df["nuc_centers"] = [cent for _, cent, _, _ in results_slots]
    filtered_reads_df["nuc_coords"]  = [coord for _, _, coord, _ in results_slots]

    # ╔════════════════════════════════════════════════════════════════════════╗
    # ║  4) TRIM READS TO ±175 bp AROUND FIRST/LAST NUCLEOSOME               ║
    # ╚════════════════════════════════════════════════════════════════════════╝
    from statistics import mean, median
    from collections  import Counter
    import numpy as np
    from tqdm.auto    import tqdm

    # Padding (in rel_pos units) kept on each side of the outermost
    # nucleosome centres when a read is trimmed.
    TRIM_PAD_BP = 175          # ⇠ single configurable constant

    # columns whose *array-like* entries must be sliced
    # Each is filtered with the mask built from rel_pos, so every column is
    # assumed to have one entry per rel_pos element – TODO confirm; zip() in
    # _trim_read would silently truncate on a length mismatch.
    ARRAY_COLS  = [
        "mod_qual", "base_qual",
        "forward_read_position", "ref_position", "rel_pos", "mod_qual_bin"
    ]

    # helper ------------------------------------------------------------------
    def _trim_read(row):
        """
        Trim one read to its nucleosome block plus TRIM_PAD_BP on each side.

        row : mapping / Series for one read; it is modified in place.

        Returns None when the read has no nucleosome centres, otherwise
        (row, n_drop_start, n_drop_end), where the two counts are the number
        of rel_pos entries removed below / above the kept window.
        """
        centres = row["nuc_centers"]
        if not centres:                 # reads without nucleosomes are dropped
            return None

        # kept window in rel_pos space
        keep_lo = int(min(centres) - TRIM_PAD_BP)
        keep_hi = int(max(centres) + TRIM_PAD_BP)

        rel = np.asarray(row["rel_pos"])
        n_start = int((rel < keep_lo).sum())
        n_end = int((rel > keep_hi).sum())
        mask = (rel >= keep_lo) & (rel <= keep_hi)

        # per-position columns: keep only entries inside the window
        for col in ARRAY_COLS:
            row[col] = [val for val, inside in zip(row[col], mask) if inside]

        # scalar read bounds follow the window
        row["rel_read_start"] = max(row["rel_read_start"], keep_lo)
        row["rel_read_end"] = min(row["rel_read_end"], keep_hi)
        row["read_length"] = row["rel_read_end"] - row["rel_read_start"] + 1

        # motifs: keep start / attribute pairs whose start is inside the window
        motif_starts = row["motif_rel_start"]
        if motif_starts is not None:
            keep_idx = [
                i for i, start in enumerate(motif_starts)
                if keep_lo <= start <= keep_hi
            ]
            motif_attrs = row["motif_attributes"]
            row["motif_rel_start"] = tuple(motif_starts[i] for i in keep_idx)
            row["motif_attributes"] = tuple(motif_attrs[i] for i in keep_idx)

        return row, n_start, n_end


    # ---------------------------- main loop -----------------------------------
    trimmed_records   = []
    drop_no_nuc       = 0
    drop_stats_start  = []   # entries discarded on 5' side
    drop_stats_end    = []   # entries discarded on 3' side
    example_before    = None
    example_after     = None

    if USE_TRIM_READS:

        # the pad is read from TRIM_PAD_BP so the log follows the constant
        dbg(f"⚙️  Trimming reads to ±{TRIM_PAD_BP} bp around nucleosome block…", always=True)
        for i, row in tqdm(filtered_reads_df.iterrows(), total=len(filtered_reads_df), desc="trim"):
            if example_before is None:          # stash an untouched copy of the 1st row
                example_before = row.copy(deep=True)

            out = _trim_read(row)
            if out is None:                     # dropped for empty nuc list
                drop_no_nuc += 1
                continue

            row, n_lo, n_hi = out
            drop_stats_start.append(n_lo)
            drop_stats_end.append(n_hi)

            if example_after is None:           # first surviving row after trim
                example_after = row.copy(deep=True)

            trimmed_records.append(row)

        # rebuild DataFrame ---------------------------------------------------------
        # columns= keeps the schema even when every read was dropped
        trimmed_df = pd.DataFrame(
            trimmed_records, columns=filtered_reads_df.columns
        ).reset_index(drop=True)

        # ----------------------------- DEBUG LINES ---------------------------------
        dbg("\n================  TRIM STATS  ================ ", always=True)
        dbg(f"• Reads dropped (no nucleosomes):   {drop_no_nuc:,}", always=True)

        if drop_stats_start:
            disc_start = np.array(drop_stats_start)
            disc_end   = np.array(drop_stats_end)
            disc_tot   = disc_start + disc_end

            dbg("• bp removed at start  (5'): "
                f"min={disc_start.min():,}  "
                f"median={int(np.median(disc_start)):,}  "
                f"mean={disc_start.mean():.1f}  "
                f"max={disc_start.max():,}", always=True)

            dbg("• bp removed at end    (3'): "
                f"min={disc_end.min():,}  "
                f"median={int(np.median(disc_end)):,}  "
                f"mean={disc_end.mean():.1f}  "
                f"max={disc_end.max():,}", always=True)

            dbg("• bp removed combined: "
                f"min={disc_tot.min():,}  "
                f"median={int(np.median(disc_tot)):,}  "
                f"mean={disc_tot.mean():.1f}  "
                f"max={disc_tot.max():,}\n", always=True)
        else:
            # reached only when no read came back from _trim_read
            dbg("• No reads were trimmed (none had nucleosome calls).\n", always=True)

        # pretty-print example ------------------------------------------------------
        if example_before is not None and example_after is not None:
            dbg("===== EXAMPLE ROW BEFORE TRIM =====", always=True)
            dbg(example_before, always=True)
            dbg("===== EXAMPLE ROW AFTER  TRIM =====", always=True)
            dbg(example_after,  always=True)

        # finally, swap the DataFrame so everything downstream sees the trimmed set
        filtered_reads_df = trimmed_df
        dbg(f"✔ Trimmed DataFrame shape = {filtered_reads_df.shape}", always=True)
    else:
        dbg("⚙️  Skipping trimming step (USE_TRIM_READS=False)", always=True)
    # filtered_reads_df.to_pickle(TEMP_DF_PATH)
    # dbg(f"💾 Saved to {TEMP_DF_PATH} ({os.path.getsize(TEMP_DF_PATH)} bytes)", always=True)

In [None]:
# print unique combinations of condition and exp_id in filtered_reads_df, with count of each
print("Unique combinations of condition and exp_id in filtered_reads_df, with count of each:")
print(filtered_reads_df.groupby(['condition', 'exp_id']).size().reset_index(name='count'))
# read counts per (type, bed_start) pair
print(filtered_reads_df.groupby(['type', 'bed_start']).size().reset_index(name='count'))

# print 1 row
# NOTE: .iloc[0] raises IndexError when filtered_reads_df is empty.
print("\nOne row from filtered_reads_df:")
print(filtered_reads_df.iloc[0])
# ─────────────────── CONFIG ───────────────────
TEMPLATE     = "plotly_white"
WIDTH, HEIGHT = 950, 600
OUTDIR       = "plots"
SAVE_HTML    = True
WITHIN_BIN   = "longest"   # 'longest' or 'random'
# Presumably forwarded as step_pct / upper_q to _bin_edges_from_p99 – call site
# not shown here. The value is 2.5, not the 0.5 the old comment mentioned.
STEP_PCT     = 2.5         # bin width as a percentage of the UPPER_Q-th percentile of read length
UPPER_Q      = 99.0
RANDOM_STATE = 42
DEBUG_SUMMARY= True

# ─────────────────── IMPORTS ──────────────────
import os, numpy as np, pandas as pd
import plotly.graph_objects as go

os.makedirs(OUTDIR, exist_ok=True)         # create the output folder if missing
rng = np.random.default_rng(RANDOM_STATE)  # seeded generator

# ──────────────── HELPERS ─────────────────────
def _ecdf_xy(v):
    """Return the empirical CDF of *v* as ``(x, y)`` arrays.

    NaNs are discarded. ``x`` holds the remaining values in ascending order and
    ``y[i] = (i + 1) / n``. Two empty arrays are returned when nothing remains.
    """
    values = np.asarray(v, float)
    kept = np.sort(values[~np.isnan(values)])
    n = kept.size
    if n == 0:
        return np.array([]), np.array([])
    return kept, np.arange(1, n + 1) / n

def _bin_edges_from_p99(lengths, step_pct=0.5, upper_q=99.0):
    """Build read-length bin edges from 0 up to the ``upper_q`` percentile.

    Parameters
    ----------
    lengths : array-like
        Read lengths; NaNs are ignored.
    step_pct : float
        Bin width as a percentage of the ``upper_q`` percentile (minimum 1 bp).
    upper_q : float
        Percentile (0-100) that anchors the bin width and the regular grid.

    Returns
    -------
    numpy.ndarray
        Strictly increasing edges. A final overflow edge at the maximum length
        is appended when the maximum lies beyond the regular grid.

    Raises
    ------
    ValueError
        If no non-NaN length is available.
    """
    v = np.asarray(lengths, float)
    v = v[~np.isnan(v)]
    if v.size == 0:
        raise ValueError("No read lengths available.")
    p99 = np.percentile(v, upper_q)
    vmax = float(np.max(v))
    step = max(1.0, (step_pct / 100.0) * p99)  # at least 1 bp
    edges = np.unique(np.arange(0.0, p99 + step, step))
    # include overflow bin up to max length if needed
    if edges.size and vmax > edges[-1]:
        edges = np.concatenate([edges, [vmax]])
    # Degenerate input (e.g. every length is 0) leaves fewer than two edges.
    # Falling back to [0, vmax] would yield duplicate edges, which pd.cut
    # rejects, so widen the single bin to at least one step.
    if edges.size < 2:
        edges = np.array([0.0, max(vmax, step)])
    return edges

def downsample_bin_matched(
    df, group_col="condition", len_col="read_length",
    within_bin="longest", edges=None, random_state=42
):
    """Downsample so every group contributes the same number of reads per length bin.

    Lengths are binned with ``pd.cut(..., bins=edges)``. A bin is *feasible* when
    every group that has any binned read also has at least one read in that bin;
    each group is then cut down to the smallest per-group count of that bin.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``group_col`` and ``len_col``.
    group_col, len_col : str
        Column names for the grouping label and the read length.
    within_bin : str
        ``"longest"`` keeps the longest reads of each (bin, group); any other
        value draws a random sample without replacement.
    edges : array-like
        Bin edges passed to ``pd.cut``.
    random_state : int
        Seed for the random sampling mode.

    Returns
    -------
    (pandas.DataFrame, pandas.Index)
        The selected rows of ``df`` (all columns) and the feasible bins.
    """
    rng = np.random.default_rng(random_state)
    d = df[[group_col, len_col]].dropna(subset=[len_col]).copy()
    d["_bin"] = pd.cut(d[len_col].astype(float), bins=edges, include_lowest=True, right=True)
    # Rows outside the edges (NaN bin) or without a group label can never be
    # selected, so drop them before counting.
    d = d.dropna(subset=[group_col, "_bin"])
    if d.empty:
        return df.iloc[0:0].copy(), pd.CategoricalIndex([], name="_bin")

    # Count per (group, bin) with an explicit observed=True, then unstack so a
    # missing combination becomes a real 0. This keeps the result independent of
    # the pandas default for `observed`, and unused categories in a categorical
    # group column no longer make every bin infeasible.
    counts = (
        d.groupby([group_col, "_bin"], observed=True)
        .size()
        .unstack(level=0, fill_value=0)
    )
    per_bin_min = counts.min(axis=1)
    per_bin_min = per_bin_min[per_bin_min > 0]
    feasible_bins = per_bin_min.index

    selected_idx = []
    for b, target in per_bin_min.items():
        target = int(target)
        in_bin = d[d["_bin"] == b]
        for _, sub_bc in in_bin.groupby(group_col, sort=False, observed=True):
            if within_bin == "longest":
                chosen = sub_bc.sort_values(len_col, ascending=False).index[:target]
            else:
                chosen = rng.choice(sub_bc.index.to_numpy(), size=target, replace=False)
            selected_idx.append(np.asarray(chosen))

    if selected_idx:
        out = df.loc[np.concatenate(selected_idx)].copy()
    else:
        out = df.iloc[0:0].copy()

    return out, feasible_bins

def plot_ecdf(df, title, group_col="condition", len_col="read_length", width=950, height=600, template="plotly_white", save_path=None):
    """Show one ECDF line of ``len_col`` per ``group_col`` value.

    The figure is always displayed; it is also written as HTML to ``save_path``
    when that argument is truthy.
    """
    fig = go.Figure()
    for label, group in df.groupby(group_col, sort=False):
        xs, ys = _ecdf_xy(group[len_col].values)
        line = go.Scatter(x=xs, y=ys, mode="lines", name=str(label))
        fig.add_trace(line)

    layout = dict(
        template=template, width=width, height=height,
        title=title, xaxis_title="Read length (bp)", yaxis_title="Cumulative fraction",
    )
    fig.update_layout(**layout)
    fig.show()
    if not save_path:
        return
    fig.write_html(save_path)

def plot_box(df, title, group_col="condition", len_col="read_length", width=950, height=600, template="plotly_white", save_path=None):
    """Show one box (with mean marker) of ``len_col`` per ``group_col`` value.

    The figure is always displayed; it is also written as HTML to ``save_path``
    when that argument is truthy.
    """
    fig = go.Figure()
    for label, group in df.groupby(group_col, sort=False):
        box = go.Box(y=group[len_col], name=str(label), boxmean=True)
        fig.add_trace(box)

    layout = dict(
        template=template, width=width, height=height,
        title=title, yaxis_title="Read length (bp)",
    )
    fig.update_layout(**layout)
    fig.show()
    if not save_path:
        return
    fig.write_html(save_path)

# ───────────── PREP + BINS ────────────────────
# Both columns are required; reads without a length cannot be binned.
assert {"condition", "read_length"}.issubset(filtered_reads_df.columns)
pre_df = filtered_reads_df.dropna(subset=["read_length"]).copy()

edges = _bin_edges_from_p99(
    pre_df["read_length"].values,
    step_pct=STEP_PCT,
    upper_q=UPPER_Q,
)

# ───────────── DOWNSAMPLE ─────────────────────
# NOTE: filtered_reads_df is replaced by the length-matched subset from here on.
filtered_reads_df, feasible_bins = downsample_bin_matched(
    pre_df,
    group_col="condition",
    len_col="read_length",
    within_bin=WITHIN_BIN,
    edges=edges,
    random_state=RANDOM_STATE,
)

# ───────────── STATS ──────────────────────────
# Per-condition read counts before/after and the retained fraction.
before = pre_df.groupby("condition").size().rename("n_before")
after  = filtered_reads_df.groupby("condition").size().rename("n_after")
stats  = pd.concat([before, after], axis=1)
stats  = stats.fillna(0).astype(int)
stats["retained_frac"] = np.where(stats["n_before"]>0, stats["n_after"]/stats["n_before"], np.nan)
stats = stats.sort_index()

total_before = int(stats["n_before"].sum())
total_after  = int(stats["n_after"].sum())
if total_before > 0:
    total_frac = total_after / total_before
else:
    total_frac = np.nan

if DEBUG_SUMMARY:
    print("Feasible bins retained:", len(feasible_bins), "of", len(edges)-1, "total bins")
    print("Reads before (total):", total_before)
    print("Reads after  (total):",  total_after)
    print(f"Overall retained fraction: {total_frac:.4f}")
    display(stats)

# ───────────── PLOTS: BEFORE/AFTER ────────────
# (plot function, data, title, output file name) — drawn in this order.
_before_after_plots = (
    (plot_ecdf, pre_df,
     "ECDF of read lengths by condition — BEFORE bin-matched downsampling",
     "ecdf_before.html"),
    (plot_ecdf, filtered_reads_df,
     "ECDF of read lengths by condition — AFTER bin-matched downsampling",
     "ecdf_after.html"),
    (plot_box, pre_df,
     "Read length by condition (box) — BEFORE downsampling",
     "box_before.html"),
    (plot_box, filtered_reads_df,
     "Read length by condition (box) — AFTER downsampling",
     "box_after.html"),
)

for _plot_fn, _frame, _title, _fname in _before_after_plots:
    _plot_fn(
        _frame,
        title=_title,
        save_path=os.path.join(OUTDIR, _fname) if SAVE_HTML else None,
        template=TEMPLATE, width=WIDTH, height=HEIGHT
    )

In [None]:
# ╔══════════════════════════════════════════════════════════════════╗
# ║  CELL X: edge‑growing / merging of nucleosomes (parallel)        ║
# ╚══════════════════════════════════════════════════════════════════╝
#
#  New behaviour (opt‑in):
#    ⋄ set SPLIT_LONG_NUCS = True
#    ⋄ any merged nucleosome whose length lies within [SPLIT_MIN_LEN, SPLIT_MAX_LEN]
#      (200–400 bp with the values set below) is split into two halves
#    ⋄ any merged nucleosome longer than SPLIT_MAX_LEN is dropped
# --------------------------------------------------------------------

import multiprocessing as mp
from copy import deepcopy
from itertools import repeat
from tqdm.auto import tqdm
import numpy as np
import pandas as pd

# When True, _grow_single returns a debug dict for the row whose index label is 0
# and the driver prints it.
DEBUG_FIRST      = True                 # original flag

# ─── NEW CONFIG ────────────────────────────────────────────────────
# Lengths below are measured as (end - start) of a merged interval.
SPLIT_LONG_NUCS  = True                 # ← toggle the feature
SPLIT_MIN_LEN    = 200                  # lower bound (inclusive)
SPLIT_MAX_LEN    = 400                  # upper bound (inclusive); longer intervals are dropped
# ───────────────────────────────────────────────────────────────────

# ─── helpers (unchanged) ──────────────────────────────────────────
def _interval_from_coord(cen, core):
    """Return the mutable ``[start, end]`` list reaching ``core // 2`` bp either side of *cen*."""
    reach = core // 2
    start, end = cen - reach, cen + reach
    return [start, end]

def _coords_from_interval(start, end):
    """Convert ``[start, end]`` back to ``(center, core_len)`` with ``core_len = end - start``."""
    span = end - start
    midpoint = start + span // 2
    return (int(midpoint), int(span))

# ─── worker ───────────────────────────────────────────────────────
def _grow_single(row_idx):
    """Grow, merge and optionally split the nucleosome intervals of one read.

    Reads row ``row_idx`` of the global ``filtered_reads_df``. Every interval
    derived from ``nuc_coords`` is extended 1 bp per pass on each side until the
    edge sits on, or the next base is, a position with ``mod_qual_bin == 1``,
    until it would reach the neighbouring interval, or until it would pass the
    outermost observed ``rel_pos``. Touching/overlapping intervals are then
    merged, and (when ``SPLIT_LONG_NUCS``) merged intervals are split in two or
    dropped according to ``SPLIT_MIN_LEN`` / ``SPLIT_MAX_LEN``.

    Returns
    -------
    tuple
        ``(row_idx, new_centers, new_coords, debug_info)`` where ``new_coords``
        is a list of ``(center, core_len)`` and ``debug_info`` is ``None`` unless
        ``DEBUG_FIRST`` is set and ``row_idx == 0``.
    """
    row   = filtered_reads_df.loc[row_idx]
    rel_p = np.asarray(row["rel_pos"],      dtype=int)
    m_bin = np.asarray(row["mod_qual_bin"], dtype=int)
    obs1  = {p for p, b in zip(rel_p, m_bin) if b == 1}   # positions called modified

    # starting intervals, sorted by start
    intervals = sorted(_interval_from_coord(c, cl)
                       for (c, cl, *_) in row["nuc_coords"])
    grown = deepcopy(intervals)

    # A read without observed positions has no extent to grow into. Skipping the
    # growth loop also avoids min()/max() on an empty array, which would raise
    # and abort the whole worker pool.
    if rel_p.size:
        left_edge, right_edge = rel_p.min(), rel_p.max()
        max_it = len(rel_p)
    else:
        left_edge = right_edge = None
        max_it = 0

    for _ in range(max_it):
        changed = False

        # grow starts leftwards
        for i, (s, _e) in enumerate(grown):
            if s in obs1:                       # edge already on a mod site → skip
                continue
            new_s = s - 1
            if new_s in obs1:                   # proposed base on a mod site → skip
                continue
            if ((i > 0 and new_s > grown[i-1][1])         # interior
                or (i == 0 and new_s >= left_edge)):       # left‑most
                grown[i][0] = new_s
                changed = True

        # grow ends rightwards
        for i in reversed(range(len(grown))):
            _s, e = grown[i]
            if e in obs1:                       # edge already on a mod site → skip
                continue
            new_e = e + 1
            if new_e in obs1:                   # proposed base on a mod site → skip
                continue
            if ((i < len(grown)-1 and new_e < grown[i+1][0])    # interior
                or (i == len(grown)-1 and new_e <= right_edge)): # right‑most
                grown[i][1] = new_e
                changed = True

        if not changed:
            break

    # merge touching / overlapping
    merged = []
    for s, e in sorted(grown):
        if not merged or s > merged[-1][1] + 1:
            merged.append([s, e])
        else:
            merged[-1][1] = max(merged[-1][1], e)

    # post‑processing: split or drop long cores
    processed = []
    for s, e in merged:
        core_len = e - s
        if core_len > SPLIT_MAX_LEN and SPLIT_LONG_NUCS:
            continue                              # drop
        if SPLIT_LONG_NUCS and SPLIT_MIN_LEN <= core_len <= SPLIT_MAX_LEN:
            half = core_len // 2
            left_end  = s + half - 1              # inclusive
            right_sta = left_end + 1
            processed.extend([[s, left_end],
                              [right_sta, e]])
        else:
            processed.append([s, e])

    new_coords  = [_coords_from_interval(s, e) for s, e in processed]
    new_centers = [c for c, _ in new_coords]

    # Named debug_info rather than dbg so it is not confused with the notebook's
    # dbg() logging helper.
    debug_info = None
    if DEBUG_FIRST and row_idx == 0:
        debug_info = dict(
            read_id=row["read_id"],
            original_intervals=intervals,
            grown_intervals=[tuple(iv) for iv in grown],
            merged_intervals=[tuple(iv) for iv in merged],
            final_intervals=[tuple(iv) for iv in processed],
            new_coords=new_coords
        )
    return row_idx, new_centers, new_coords, debug_info

# ─── multiprocess driver ──────────────────────────────────────────
# Runs _grow_single over every row label and writes the new centers/coords back.
idx_all   = filtered_reads_df.index.to_numpy()
n_workers = max(2, mp.cpu_count()-2)

print(f"⚙️  Growing nucleosome edges using {n_workers} workers…")
with mp.Pool(n_workers) as pool, tqdm(total=len(idx_all), desc="expand") as bar:
    # The per-row debug payload is bound to `dbg_info`, not `dbg`: a top-level
    # loop variable named `dbg` would overwrite the notebook's dbg() logging
    # helper and break any cell that calls it afterwards.
    for row_idx, centers, coords, dbg_info in pool.imap_unordered(_grow_single,
                                                                  idx_all,
                                                                  chunksize=128):
        filtered_reads_df.at[row_idx,"nuc_centers"]=centers
        filtered_reads_df.at[row_idx,"nuc_coords" ]=coords
        if dbg_info is not None:
            print("\n[DEBUG — first read]")
            for k,v in dbg_info.items(): print(f"  {k:20s}: {v}")
        bar.update(1)

print("✅ edge‑growing complete – DataFrame updated.")


In [None]:
# ──────────────────────────────────────────────────────────────────────────────
# Add ChIP signal columns from 50-bp bedgraphs using N nearest bins to center
# Fix: use float distances when masking the containing bin (prevents OverflowError)
# Adds: DEBUG prints and validation summary
# ──────────────────────────────────────────────────────────────────────────────
import os
import numpy as np
import pandas as pd
import multiprocessing as mp
from tqdm.auto import tqdm

# ─── Config ───────────────────────────────────────────────────────────────────
N_NEAREST_BINS = 4           # strictly this many selected when available
PROGRESS       = True        # tqdm progress
SUMMARIZE      = True        # final summary lines
DEBUG_VALIDATE = True        # extra validation checks and sample prints
DEBUG_SAMPLES  = 3           # number of rows to sample for detailed checks
N_WORKERS      = max(2, mp.cpu_count() - 2)   # pool size; never fewer than 2

# Output column name → bedgraph file providing its signal.
BEDGRAPH_PATHS = {
    "dpy27_chip": "/Data1/ext_data/qiming_2024/50bp_bins/BMQY009_2_DPY27_rep1_antimNeon.sort.bedgraph",
    "sdc2_chip":  "/Data1/ext_data/qiming_2024/50bp_bins/BMQY009_5_SDC2_rep1_antimNeon.sort.bedgraph",
    "sdc3_chip":  "/Data1/ext_data/qiming_2024/50bp_bins/BMQY010_2_SDC3_rep1_antimNeon.sort.bedgraph",
}

# Applied to the bedgraph chrom column by _load_bedgraph; bedgraph rows whose
# chromosome is not a key here are dropped there.
CHR_MAP = {
    "chrI":"CHROMOSOME_I", "chrII":"CHROMOSOME_II", "chrIII":"CHROMOSOME_III",
    "chrIV":"CHROMOSOME_IV", "chrV":"CHROMOSOME_V", "chrX":"CHROMOSOME_X",
}

# The bedgraph index is keyed by plain-string chromosome names, so store "chrom"
# as str; this single cast also covers a categorical column.
filtered_reads_df["chrom"] = filtered_reads_df["chrom"].astype(str)

# ─── Helpers ──────────────────────────────────────────────────────────────────
def _load_bedgraph(path):
    """Load a 4-column whitespace-separated bedgraph into per-chromosome arrays.

    Chromosome names are translated through ``CHR_MAP`` and rows with an
    unmapped chromosome are dropped. Rows are sorted by (chrom, start).

    Returns
    -------
    dict
        ``{chrom: {"start", "end", "center", "value"}}`` with numpy arrays;
        ``center`` is ``(start + end) // 2``.
    """
    col_types = {"chrom": "category", "start": np.int64, "end": np.int64, "value": np.float32}
    table = pd.read_csv(
        path, sep=r"\s+", header=None, names=["chrom","start","end","value"],
        dtype=col_types
    )
    table["chrom"] = table["chrom"].map(CHR_MAP).astype("category")
    table = table.dropna(subset=["chrom"])
    table = table.sort_values(["chrom","start"]).reset_index(drop=True)
    table["center"] = (table["start"].values + table["end"].values) // 2

    return {
        str(chrom): {
            "start":  rows["start"].to_numpy(np.int64, copy=False),
            "end":    rows["end"].to_numpy(np.int64, copy=False),
            "center": rows["center"].to_numpy(np.int64, copy=False),
            "value":  rows["value"].to_numpy(np.float32, copy=False),
        }
        for chrom, rows in table.groupby("chrom", sort=False)
    }

def _containing_index(starts, ends, pos):
    """Return the index i with ``starts[i] <= pos < ends[i]``, or None when no bin contains pos.

    The candidate is found by a binary search on ``ends``.
    """
    idx = int(np.searchsorted(ends, pos, side="right"))
    if idx >= len(ends):
        return None
    if pos < starts[idx] or pos >= ends[idx]:
        return None
    return idx

def _mean_nearest_bins(bdg_chrom, pos, n_bins):
    """Average the values of up to ``n_bins`` bins closest to ``pos``.

    The bin containing ``pos`` (if any) is always part of the selection; the
    remaining slots go to the bins whose centers are nearest to ``pos``.
    Returns NaN when ``bdg_chrom`` is None, holds no bins, or nothing is selected.
    """
    if bdg_chrom is None:
        return np.nan
    starts, ends = bdg_chrom["start"], bdg_chrom["end"]
    centers, values = bdg_chrom["center"], bdg_chrom["value"]
    if len(values) == 0:
        return np.nan

    hit = _containing_index(starts, ends, pos)
    picked = [] if hit is None else [hit]

    # float64 so the containing bin can be masked out with +inf
    gaps = np.abs(centers - int(pos)).astype(np.float64)

    n_extra = max(0, n_bins - len(picked))
    if n_extra > 0:
        if hit is not None:
            gaps[hit] = np.inf
        n_take = min(n_extra, int(np.isfinite(gaps).sum()))
        if n_take > 0:
            nearest = np.argpartition(gaps, kth=n_take-1)[:n_take]
            picked.extend(list(nearest))

    if not picked:
        return np.nan
    return float(np.nanmean(values[np.array(picked, dtype=int)]))

# Globals for workers
# Module-level defaults; _init_pool overwrites both when it is called.
_BDG_IDX = {}                 # track name → per-chromosome bedgraph arrays
_NBINS   = N_NEAREST_BINS     # number of nearest bins averaged per query

def _init_pool(bdg_idx, n_bins):
    """Pool initializer: store the bedgraph index and bin count in module globals."""
    global _BDG_IDX, _NBINS
    _BDG_IDX, _NBINS = bdg_idx, n_bins

def _worker(task):
    """Annotate one ``(chrom, bed_start, bed_end)`` key with the three ChIP signals.

    Returns ``(chrom, bed_start, bed_end, dpy27, sdc2, sdc3)``, each signal being
    the nearest-bin mean around the midpoint of the interval.
    """
    chrom, bstart, bend = task
    midpoint = (int(bstart) + int(bend)) // 2
    signals = [
        _mean_nearest_bins(_BDG_IDX[track].get(chrom), midpoint, _NBINS)
        for track in ("dpy27_chip","sdc2_chip","sdc3_chip")
    ]
    return (chrom, bstart, bend, *signals)

# ─── Load bedgraphs once ──────────────────────────────────────────────────────
_bdg_idx = {}
for name, path in BEDGRAPH_PATHS.items():
    _bdg_idx[name] = _load_bedgraph(path)

# ─── Unique queries to minimize work ──────────────────────────────────────────
keys_df = (
    filtered_reads_df.loc[:, ["chrom","bed_start","bed_end"]]
    .drop_duplicates()
    .reset_index(drop=True)
)
tasks = list(map(tuple, keys_df.to_records(index=False)))

# ─── Parallel compute ─────────────────────────────────────────────────────────
# One pool either way; PROGRESS only decides whether results stream through tqdm.
with mp.Pool(N_WORKERS, initializer=_init_pool, initargs=(_bdg_idx, N_NEAREST_BINS)) as pool:
    if PROGRESS:
        results = list(tqdm(pool.imap_unordered(_worker, tasks, chunksize=1024),
                            total=len(tasks), desc=f"chip-annot (N={N_NEAREST_BINS})"))
    else:
        results = pool.map(_worker, tasks, chunksize=1024)

res_df = pd.DataFrame(results, columns=["chrom","bed_start","bed_end","dpy27_chip","sdc2_chip","sdc3_chip"])

# ─── Merge back ───────────────────────────────────────────────────────────────
added = ["dpy27_chip","sdc2_chip","sdc3_chip"]

# Make the cell safe to re-run: if the chip columns already exist, a second
# merge would create *_x / *_y suffixed copies and trip the check below.
stale_cols = [c for c in added if c in filtered_reads_df.columns]
if stale_cols:
    filtered_reads_df = filtered_reads_df.drop(columns=stale_cols)

pre_cols = set(filtered_reads_df.columns)
# Left merge keeps every read row; note that merge() returns a frame with a
# fresh integer index.
filtered_reads_df = filtered_reads_df.merge(res_df, on=["chrom","bed_start","bed_end"], how="left")
post_cols = set(filtered_reads_df.columns)

# ─── Validation / Debug ───────────────────────────────────────────────────────
missing_cols = [c for c in added if c not in post_cols]
if missing_cols:
    raise RuntimeError(f"Expected columns missing after merge: {missing_cols}")

# Summary: number of unique (chrom, bed_start, bed_end) keys, number of rows,
# and how many NaNs each added column holds.
if SUMMARIZE:
    n_keys = len(keys_df)
    n_rows = len(filtered_reads_df)
    na_counts = filtered_reads_df[added].isna().sum().to_dict()
    print(f"[chip-annot] unique keys: {n_keys:,}  rows updated: {n_rows:,}  N_NEAREST_BINS={N_NEAREST_BINS}")
    print(f"[chip-annot] NaN counts: {na_counts}")

if DEBUG_VALIDATE:
    # 1) Show distinct chromosomes found in bedgraphs vs dataframe
    #    (a name mismatch here makes .get(chrom) return None and every value NaN)
    bg_chroms = sorted({c for d in _bdg_idx.values() for c in d.keys()})
    df_chroms = sorted(filtered_reads_df["chrom"].unique().tolist())
    print(f"[debug] bedgraph chroms: {bg_chroms}")
    print(f"[debug] df chroms:       {df_chroms}")

    # 2) Sample a few rows and recompute single-row values to verify containment logic
    samp = filtered_reads_df.sample(min(DEBUG_SAMPLES, len(filtered_reads_df)), random_state=1)
    for i, r in enumerate(samp.itertuples(index=False), 1):
        chrom  = r.chrom
        center = (int(r.bed_start) + int(r.bed_end)) // 2
        for key in ("dpy27_chip","sdc2_chip","sdc3_chip"):
            bdg_chrom = _bdg_idx[key].get(chrom)
            recomputed = _mean_nearest_bins(bdg_chrom, center, N_NEAREST_BINS)
            orig = getattr(r, key)
            # match = both NaN, or both finite and equal within 1e-6
            ok = (np.isnan(orig) and np.isnan(recomputed)) or (not np.isnan(orig) and np.isclose(orig, recomputed, rtol=0, atol=1e-6))
            print(f"[debug] row#{i} {key}: center={center}  orig={orig}  recomputed={recomputed}  match={ok}")

    # 3) Basic stats
    desc = filtered_reads_df[added].describe(percentiles=[0.1,0.5,0.9]).T
    with pd.option_context("display.max_columns", 8, "display.precision", 3):
        print("[debug] column summary:")
        print(desc)


In [None]:
# ──────────────────────────────────────────────────────────────────────────────
# Add/overwrite two columns on filtered_reads_df:
#   • ind           → distance between nearest left/right nucleosome around rel_pos=0
#                      configurable: "centers" or "edges", using the SAME center-bracketed pair
#                      optional gating by # of m6A (mod_qual_bin==1) between the pair
#   • percent_m6a   → % of mod_qual_bin==1 within rel_pos ∈ [M6A_WIN]
# Notes:
#   • Pair selection ALWAYS brackets x=0 by centers from nuc_coords for both modes.
#   • In "edges" mode: IND = (cR - kR//2) - (cL + kL//2). Negative gaps → NaN.
#   • If nuc_coords missing cores and IND_MODE=="centers", falls back to nuc_centers at x=0.
#   • Safe to rerun: columns are overwritten.
# ──────────────────────────────────────────────────────────────────────────────
import numpy as np
import pandas as pd
import multiprocessing as mp
from tqdm.auto import tqdm

# ========================== CONFIG ========================== #
IND_MODE                 = "edges"   # "edges" or "centers"
MIN_M6A_IN_IND           = 2         # require ≥ this many mod_qual_bin==1 between the chosen pair; 0 disables
INCLUDE_EDGES_IN_IND     = True      # gating interval closed vs open at boundaries
MAX_IND_BP               = None      # e.g., 600; None disables exclusion by max gap
USE_CANONICAL_CORE       = None      # e.g., 147 to force standard core; None uses per-nuc core lengths
M6A_WIN                  = (-150, 150) # rel_pos window for percent_m6a only
SHOW_PROGRESS            = True      # wrap the parallel map in a tqdm bar
SHOW_SUMMARY             = True      # print settings and NaN counts afterwards
DEBUG_VALIDATE           = True      # print describe() output and re-check sampled rows
DEBUG_SAMPLES            = 3         # number of rows re-checked when DEBUG_VALIDATE
_CHUNKSIZE               = 1024      # chunksize passed to the pool map calls

# ======================= WORKERS SETUP ====================== #
# Reuse N_WORKERS when an earlier cell defined it; otherwise (or if it cannot be
# converted to int) fall back to cpu_count() - 2. Never fewer than 2 workers.
try:
    _NWORKERS = max(2, mp.cpu_count() - 2) if 'N_WORKERS' not in globals() else max(2, int(N_WORKERS))
except Exception:
    _NWORKERS = max(2, mp.cpu_count() - 2)

# ======================= HELPERS (IND) ====================== #
def _centers_and_cores(nuc_coords):
    """Split ``[(center, core_len, ...), ...]`` into center-sorted int64 arrays.

    Entries whose first two items cannot be converted to int are skipped. When
    ``USE_CANONICAL_CORE`` is set, every core length is replaced by that value.
    Returns two empty arrays when nothing usable is present.
    """
    parsed = []
    for entry in (nuc_coords or ()):
        try:
            center, core = int(entry[0]), int(entry[1])
        except Exception:
            continue
        if USE_CANONICAL_CORE is not None:
            core = int(USE_CANONICAL_CORE)
        parsed.append((center, core))

    if not parsed:
        return np.array([], dtype=np.int64), np.array([], dtype=np.int64)

    centers = np.asarray([c for c, _ in parsed], dtype=np.int64)
    cores = np.asarray([k for _, k in parsed], dtype=np.int64)
    by_center = np.argsort(centers, kind="mergesort")  # stable sort
    return centers[by_center], cores[by_center]

def _pair_indices_by_center_at_x(C, x=0):
    """Return indices ``(l, r)`` of the sorted centers bracketing x, else ``(None, None)``.

    ``r`` is the first center ≥ x and ``l`` the last center < x; both must exist.
    """
    if C.size < 2:
        return None, None
    right = np.searchsorted(C, x, side="left")
    left = right - 1
    if left >= 0 and right < C.size:
        return left, right
    return None, None

def _count_m6a_between(rel_pos, mod_bin, lo, hi, include_edges=True):
    """Count entries with ``mod_bin == 1`` whose ``rel_pos`` lies between lo and hi.

    The interval is closed (``[lo, hi]``) when ``include_edges`` is true and open
    otherwise. Inputs of unequal length are truncated to the shorter one; missing
    or empty inputs count as 0.
    """
    if rel_pos is None or mod_bin is None:
        return 0
    positions, calls = np.asarray(rel_pos), np.asarray(mod_bin)
    n = min(positions.size, calls.size)
    if n == 0:
        return 0
    positions, calls = positions[:n], calls[:n]
    if include_edges:
        inside = (positions >= lo) & (positions <= hi)
    else:
        inside = (positions > lo) & (positions < hi)
    return int((calls[inside] == 1).sum())

def _ind_at_rel0_from_nuc_coords(nuc_coords, rel_pos, mod_bin, mode="edges",
                                 min_m6a=0, include_edges=True, max_ind_bp=None):
    """
    Inter-nucleosome distance (IND) for the pair of nucleosomes whose centers bracket x=0.
      centers: IND = cR - cL
      edges:   IND = (cR - kR//2) - (cL + kL//2); negative → NaN
    NaN is also returned when no bracketing pair exists, when IND exceeds
    ``max_ind_bp``, or when fewer than ``min_m6a`` modified calls lie between the
    two boundaries used for the measurement.
    """
    centers, cores = _centers_and_cores(nuc_coords)
    li, ri = _pair_indices_by_center_at_x(centers, x=0)
    if li is None or ri is None:
        return np.nan

    # lo/hi are the boundaries the distance is measured between
    if mode == "centers":
        lo, hi = centers[li], centers[ri]
    else:
        lo = centers[li] + (cores[li] // 2)
        hi = centers[ri] - (cores[ri] // 2)
    ind = float(hi - lo)

    too_wide = max_ind_bp is not None and ind > max_ind_bp
    if too_wide:
        return np.nan
    if mode == "edges" and ind < 0:
        return np.nan

    if min_m6a > 0 and ind >= 0:
        n_mod = _count_m6a_between(rel_pos, mod_bin, lo, hi, include_edges=include_edges)
        if n_mod < min_m6a:
            return np.nan

    return ind

def _ind_at_rel0_centers_fallback(nuc_centers):
    """Center-to-center distance of the two nucleosome centers bracketing x=0.

    Fallback for ``IND_MODE == 'centers'`` when core lengths are unavailable.
    The bracket uses the same convention as ``_pair_indices_by_center_at_x``:
    the left neighbour is the largest center < 0 and the right neighbour the
    smallest center ≥ 0, so a center at exactly 0 counts as the right one.

    Parameters
    ----------
    nuc_centers : sequence of int, numpy.ndarray or None

    Returns
    -------
    float
        The distance, or NaN when fewer than two centers are given or x=0 is
        not bracketed.
    """
    if nuc_centers is None:
        return np.nan
    arr = np.asarray(nuc_centers)
    # size check instead of truthiness so ndarray input does not raise
    if arr.size < 2:
        return np.nan
    arr = arr.astype(np.int64)
    left_mask  = arr <  0
    right_mask = arr >= 0
    if not left_mask.any() or not right_mask.any():
        return np.nan
    left  = arr[left_mask].max()
    right = arr[right_mask].min()
    return float(int(right - left))

# ======================= HELPERS (%m6A) ====================== #
def _pct_m6a_in_window(rel_pos, mod_bin, lo=-50, hi=50):
    """Percentage of 1s among the 0/1 calls whose ``rel_pos`` is within ``[lo, hi]``.

    Calls other than 0 or 1 are ignored. Inputs of unequal length are truncated
    to the shorter one. NaN is returned when no 0/1 call falls in the window.
    """
    if rel_pos is None or mod_bin is None:
        return np.nan
    positions, calls = np.asarray(rel_pos), np.asarray(mod_bin)
    n = min(positions.size, calls.size)
    if n == 0:
        return np.nan
    positions, calls = positions[:n], calls[:n]

    in_window = calls[(positions >= lo) & (positions <= hi)]
    scored = in_window[(in_window == 0) | (in_window == 1)]
    if scored.size == 0:
        return np.nan
    return 100.0 * (float((scored == 1).sum()) / float(scored.size))

# ========================== WORKER ========================== #
def _ind_with_fallback(nuc_coords, rel_pos, mod_bin, nuc_centers):
    """IND at x=0 from nuc_coords; use nuc_centers only when nuc_coords is unusable.

    The fallback is restricted to IND_MODE == "centers" AND nuc_coords yielding no
    usable (center, core) entry. A NaN produced by the m6A gate, by MAX_IND_BP,
    or by a missing bracketing pair is kept as NaN, so the fallback cannot bypass
    the configured filters.
    """
    ind_val = _ind_at_rel0_from_nuc_coords(
        nuc_coords, rel_pos, mod_bin,
        mode=IND_MODE,
        min_m6a=MIN_M6A_IN_IND,
        include_edges=INCLUDE_EDGES_IN_IND,
        max_ind_bp=MAX_IND_BP
    )
    if np.isnan(ind_val) and IND_MODE == "centers":
        usable_centers, _ = _centers_and_cores(nuc_coords)
        if usable_centers.size == 0:
            ind_val = _ind_at_rel0_centers_fallback(nuc_centers)
    return ind_val


def _compute_two_metrics(row_idx):
    """Worker: returns (row_idx, ind, pct_m6a) for one row of filtered_reads_df."""
    row = filtered_reads_df.loc[row_idx]

    # IND at x=0 using nuc_coords; nuc_centers fallback only if cores are missing
    ind_val = _ind_with_fallback(
        row.get("nuc_coords", []),
        row.get("rel_pos", []),
        row.get("mod_qual_bin", []),
        row.get("nuc_centers", []),
    )

    # %m6A in fixed window
    pct_val = _pct_m6a_in_window(row.get("rel_pos", []), row.get("mod_qual_bin", []),
                                 lo=M6A_WIN[0], hi=M6A_WIN[1])
    return (row_idx, ind_val, pct_val)

# ======================= PARALLEL MAP ======================== #
idx_all = filtered_reads_df.index.to_numpy()
results = []
if SHOW_PROGRESS:
    with mp.Pool(_NWORKERS) as pool:
        for rec in tqdm(pool.imap_unordered(_compute_two_metrics, idx_all, chunksize=_CHUNKSIZE),
                        total=len(idx_all), desc=f"ind({IND_MODE}, m6A≥{MIN_M6A_IN_IND}) + %m6A"):
            results.append(rec)
else:
    with mp.Pool(_NWORKERS) as pool:
        results = pool.map(_compute_two_metrics, idx_all, chunksize=_CHUNKSIZE)

# =================== WRITE BACK TO DATAFRAME ================= #
# Results may arrive in any order, so key them by row label and realign.
inds  = pd.Series({i: ind for (i, ind, pct) in results}, dtype=float)
pcts  = pd.Series({i: pct for (i, ind, pct) in results}, dtype=float)

filtered_reads_df["ind"]         = inds.reindex(filtered_reads_df.index).astype(float)
filtered_reads_df["percent_m6a"] = pcts.reindex(filtered_reads_df.index).astype(float)

# ========================= VALIDATION ======================== #
if SHOW_SUMMARY:
    n_rows = len(filtered_reads_df)
    na_ind = int(filtered_reads_df["ind"].isna().sum())
    na_pct = int(filtered_reads_df["percent_m6a"].isna().sum())
    print(f"[add-cols] mode={IND_MODE}  m6A_gate={MIN_M6A_IN_IND}  include_edges={INCLUDE_EDGES_IN_IND}  "
          f"max_ind={MAX_IND_BP}  canon_core={USE_CANONICAL_CORE}")
    print(f"[add-cols] rows={n_rows:,}  workers={_NWORKERS}")
    print(f"[add-cols] NaN counts → ind: {na_ind:,}  percent_m6a: {na_pct:,}")

if DEBUG_VALIDATE:
    with pd.option_context("display.precision", 3):
        print("[debug] ind summary:")
        print(filtered_reads_df["ind"].describe(percentiles=[0.1,0.5,0.9]))
        print("[debug] percent_m6a summary:")
        print(filtered_reads_df["percent_m6a"].describe(percentiles=[0.1,0.5,0.9]))
    sample_n = min(DEBUG_SAMPLES, len(filtered_reads_df))
    if sample_n > 0:
        samp = filtered_reads_df.sample(sample_n, random_state=11)
        for i, r in enumerate(samp.itertuples(index=False), 1):
            # Recompute with the same helper the worker uses so both paths agree.
            ind_re = _ind_with_fallback(
                r.nuc_coords, r.rel_pos, r.mod_qual_bin,
                getattr(r, "nuc_centers", []),
            )
            pct_re = _pct_m6a_in_window(r.rel_pos, r.mod_qual_bin, lo=M6A_WIN[0], hi=M6A_WIN[1])
            ok_ind = (np.isnan(r.ind) and np.isnan(ind_re)) or (np.isclose(r.ind, ind_re, rtol=0, atol=1e-9))
            ok_pct = (np.isnan(r.percent_m6a) and np.isnan(pct_re)) or (np.isclose(r.percent_m6a, pct_re, rtol=0, atol=1e-9))
            print(f"[debug] row#{i} ind({r.ind} vs {ind_re}) match={ok_ind} | %m6A({r.percent_m6a} vs {pct_re}) match={ok_pct}")


In [None]:
# ──────────────────────────────────────────────────────────────────────────────
# Motif-0 lnP ceiling filter (applies ONLY to selected types) + IND and m6A plots
#
# • Apply lnP@rel_start==0 filter ONLY for types in TARGET_TYPES.
#     – For TARGET_TYPES: keep iff motif at 0 exists AND lnP <= LNP_CEILING.
#     – For all other types: keep unconditionally (no lnP requirement).
# • Debug prints show what was filtered among TARGET_TYPES, then plots proceed.
# • MODE:
#     - "PAIR": show both conditions and connect with green (Δ>0) / orange (Δ<0).
#     - "DELTA": plot Δ vs ChIP.  (No regression overlay in this cell.)
#
# Requirements in `filtered_reads_df`:
#   ['condition','type','bed_start','ind','percent_m6a',
#    'motif_rel_start','motif_attributes','chr_type']
#   plus every ChIP column listed in CHIP_MAP (e.g. 'dpy27_chip','sdc2_chip','sdc3_chip')
# ──────────────────────────────────────────────────────────────────────────────
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from nanotools import get_color

# ─── CONFIG ───────────────────────────────────────────────────────────────────
# Conditions to compare; placeholder names are used if analysis_cond is missing
# or has fewer than six entries.
try:
    COND1, COND2 = analysis_cond[5], analysis_cond[0]
except Exception:
    COND1, COND2 = "COND_BASE", "COND_TREAT"

MODE         = "DELTA"      # "PAIR" or "DELTA"
LNP_CEILING  = -15.5      # lnP ceiling for motif at rel_start == 0 (keep <= this)
IND_CEILING  = 1000         # None to disable (applies to IND only)
TEMPLATE     = "plotly_white"
EPS          = 1e-6
CLR_POS      = "#2CA02C"    # green for positive Δ
CLR_NEG      = "#FF7F0E"    # orange for negative Δ
X_AXIS_MODE = "rank"   # NOTE(review): only "rank" is visible here — confirm the other accepted values
TARGET_TYPES = {"MEX_motif","MEXII_motif","motifC"}  # lnP filter applies only to these
CHR_ONLY = "X" # "X", "Autosome", or None (no chromosome-type filter)
TYPES_ONLY = ["MEX_motif","MEXII_motif","univ_nuc"]  # types kept for this cell; empty/None keeps all
# Display label → ChIP column of filtered_reads_df; commented entries are disabled.
CHIP_MAP = {
    "DPY-27 ChIP": "dpy27_chip",
    #"SDC-2 ChIP":  "sdc2_chip",
    #"SDC-3 ChIP":  "sdc3_chip",
}

# Fail early if any column this cell reads is absent.
need_cols = {"condition","type","bed_start","ind","percent_m6a",
             "motif_rel_start","motif_attributes","chr_type"} | set(CHIP_MAP.values())
missing = need_cols - set(filtered_reads_df.columns)
if missing:
    raise KeyError(f"Missing columns: {sorted(missing)}")

# ─── Extract lnP at rel_start == 0 ────────────────────────────────────────────
def _get_lnp_at_zero(row):
    rs = row["motif_rel_start"]; attrs = row["motif_attributes"]
    if rs is None or attrs is None:
        return np.nan
    try:
        rs_list = list(rs)
    except Exception:
        return np.nan
    if 0 not in rs_list:
        return np.nan
    i = rs_list.index(0)
    try:
        return float(attrs[i][1])  # (motif_name, lnP, strand, seq)
    except Exception:
        return np.nan

# Work on a copy restricted to the columns this cell needs.
df0 = filtered_reads_df[list(need_cols)].copy()
df0["condition"] = df0["condition"].astype(str)
df0["type"]      = df0["type"].astype(str)

# Optional restriction to a subset of site types.
if TYPES_ONLY:
    TYPES_ONLY = [str(t) for t in TYPES_ONLY]
    df0 = df0[df0["type"].astype(str).isin(TYPES_ONLY)].copy()
    if df0.empty:
        raise RuntimeError(f"No rows left after type filter: {TYPES_ONLY}")

# ── chr filter (normalize then subset) ─────────────────────────────────
# Anything that is not "X" (case-insensitive) is relabelled "Autosome".
df0["chr_type"] = np.where(df0["chr_type"].astype(str).str.upper()=="X", "X", "Autosome")
if CHR_ONLY in {"X","Autosome"}:
    df0 = df0[df0["chr_type"] == CHR_ONLY].copy()
    if df0.empty:
        raise RuntimeError(f"No rows after chr_type filter '{CHR_ONLY}'.")

# lnP of the motif at rel_start == 0 per row (NaN when there is none).
lnp0 = df0.apply(_get_lnp_at_zero, axis=1)
is_target   = df0["type"].isin(TARGET_TYPES)

# Apply filter ONLY on TARGET_TYPES
has_zero_t  = is_target & lnp0.notna()  # NOTE(review): computed but not used below
fail_t      = is_target & (lnp0 > float(LNP_CEILING)) & lnp0.notna()
keep_target = is_target & lnp0.notna() & (lnp0 <= float(LNP_CEILING))
keep_other  = ~is_target  # all non-target rows pass
mask_keep   = keep_other | keep_target

# ─── Debug summary ────────────────────────────────────────────────────────────
n_total       = len(df0)
n_target      = int(is_target.sum())
n_other       = n_total - n_target
n_t_no_zero   = int((is_target & lnp0.isna()).sum())
n_t_fail      = int(fail_t.sum())
n_t_keep      = int(keep_target.sum())
n_other_keep  = int((~is_target).sum())
n_keep_total  = int(mask_keep.sum())

print(f"[filter] total={n_total:,}  target_types={n_target:,}  other_types={n_other:,}")
print(f"[filter] TARGET drop (no motif@0): {n_t_no_zero:,}")
print(f"[filter] TARGET drop (lnP@0 > {LNP_CEILING}): {n_t_fail:,}")
print(f"[filter] kept TARGET: {n_t_keep:,}  kept OTHER: {n_other_keep:,}  kept TOTAL: {n_keep_total:,}")

# Show up to five example rows for each drop reason.
if n_t_no_zero > 0:
    ex = df0.loc[is_target & lnp0.isna(), ["type","bed_start","motif_rel_start"]].head(5)
    print("[filter] examples (TARGET missing 0-pos motif):")
    print(ex.to_string(index=False))
if n_t_fail > 0:
    cols = ["type","bed_start"];
    # NOTE(review): df0 was built from `need_cols`, which has no "read_id",
    # so this branch looks unreachable as written — confirm.
    if "read_id" in df0.columns: cols = ["read_id"] + cols
    ex = df0.loc[fail_t, cols].head(5).copy()
    ex["lnP_0"] = lnp0[fail_t].head(5).values
    print(f"[filter] examples (TARGET failing lnP ceiling > {LNP_CEILING}):")
    print(ex.to_string(index=False))

# Apply
df = df0[mask_keep].copy()
df["lnp0"] = lnp0[mask_keep].values  # same mask, so values align positionally with df
if df.empty:
    raise RuntimeError("All rows filtered; nothing to plot.")

# ─── Helpers for plots ────────────────────────────────────────────────────────
# ─── Shape helpers (place above the plotting loops) ───────────────────────────
MarkerTrace = go.Scatter  # SVG scatter → supports marker.symbol
def _rank_int(s):
    # unique ranks 1..N in ascending order; stable for ties
    return pd.Series(s).rank(method="first").astype(int)

def _xaxis_layout():
    """Return the x-axis layout dict matching the global X_AXIS_MODE.

    "rank" (case-insensitive) gives a linear rank axis; any other value gives
    a log-scaled ChIP-signal axis.
    """
    rank_mode = str(X_AXIS_MODE).upper() == "RANK"
    if rank_mode:
        return {"title": "Rank by ChIP (weak → strong)"}
    return {"title": "ChIP signal (log10)", "type": "log"}

def _attach_chr(dfin, df_source,
                chr_col_candidates=("chr_type","bed_chr_type","bed_chr","chr","chrom")):
    """Ensure a 'chr_type' column with values in {'X','Autosome'}."""
    out = dfin.copy()
    for c in chr_col_candidates:
        if c in df_source.columns:
            m = df_source[["bed_start", c]].drop_duplicates("bed_start")
            if c in ("bed_chr","chr","chrom"):
                v = m[c].astype(str).str.upper()
                is_x = v.eq("X") | v.eq("CHRX") | v.str.endswith("X")
                m = pd.DataFrame({
                    "bed_start": m["bed_start"],
                    "chr_type": np.where(is_x, "X", "Autosome")
                })
            else:
                m = m.rename(columns={c: "chr_type"})
            out = out.merge(m, on="bed_start", how="left")
            break
    if "chr_type" not in out.columns:
        out["chr_type"] = "X"
    else:
        out["chr_type"] = out["chr_type"].fillna("X")
        mask = ~out["chr_type"].astype(str).str.upper().isin(["X","AUTOSOME"])
        out.loc[mask, "chr_type"] = "Autosome"
    # quick sanity
    try:
        print(f"[chr] counts:", out["chr_type"].value_counts(dropna=False).to_dict())
    except Exception:
        pass
    return out

def _by_chr_groups(dfin):
    if "chr_type" not in dfin.columns:
        yield ("all", "circle", dfin)
        return
    su = dfin["chr_type"].astype(str).str.upper()
    m_auto = su != "X"
    yield ("X",        "circle", dfin[~m_auto])
    yield ("Autosome", "square", dfin[m_auto])

def _add_shape_legend(fig):
    """Append two data-less marker traces so the legend explains the shapes."""
    shape_guide = (
        ("X (circle)", "circle"),
        ("Autosome (square)", "square"),
    )
    for label, symbol in shape_guide:
        # x/y are [None]: nothing is drawn, only the legend entry appears
        guide = MarkerTrace(
            x=[None], y=[None], mode="markers",
            name=label, showlegend=True,
            marker=dict(symbol=symbol, color="#666", size=8),
        )
        fig.add_trace(guide)

def _agg_ind(dfin, chip_col, cond):
    sub = dfin[dfin["condition"] == cond]
    if sub.empty: return pd.DataFrame(columns=["type","bed_start","ind_med","chip_med"])
    return (sub.groupby(["type","bed_start"], as_index=False)
              .agg(ind_med=("ind","median"), chip_med=(chip_col,"median")))

def _agg_m6a(dfin, chip_col, cond):
    sub = dfin[dfin["condition"] == cond]
    if sub.empty: return pd.DataFrame(columns=["type","bed_start","m6a_mean","chip_med"])
    return (sub.groupby(["type","bed_start"], as_index=False)
              .agg(m6a_mean=("percent_m6a","mean"), chip_med=(chip_col,"median")))

def _pair_filter(dfin, x1, y1, x2, y2, apply_ind_ceiling=False):
    m = (np.isfinite(dfin[x1]) & np.isfinite(dfin[y1]) &
         np.isfinite(dfin[x2]) & np.isfinite(dfin[y2]) &
         (dfin[x1] > 0) & (dfin[x2] > 0))
    if apply_ind_ceiling and IND_CEILING is not None:
        m &= (dfin[y1] <= IND_CEILING) & (dfin[y2] <= IND_CEILING)
    return dfin[m].copy()

def _build_seg_arrays(dfin, x1, y1, x2, y2):
    """Build None-separated line-segment coordinate lists, split by sign of Δy.

    Each row contributes one segment (x1, y1) → (x2, y2) with EPS added to both
    x values. Segments with y2 - y1 >= 0 go to the "positive" lists, all others
    (including NaN differences) to the "negative" lists.

    Returns (pos_x, pos_y, neg_x, neg_y).
    """
    rising = ([], [])
    falling = ([], [])
    for xa, ya, xb, yb in zip(dfin[x1], dfin[y1], dfin[x2], dfin[y2]):
        xs_bucket, ys_bucket = rising if (yb - ya) >= 0 else falling
        # trailing None breaks the polyline between consecutive segments
        xs_bucket.extend((xa + EPS, xb + EPS, None))
        ys_bucket.extend((ya, yb, None))
    return rising[0], rising[1], falling[0], falling[1]

# --- Trend helpers (rank x) ---------------------------------------------------
TREND_WINDOW = None  # set to an odd int (e.g. 101) to override; None = auto (~5% of N)

def _spearman_no_scipy(x, y):
    xr = pd.Series(x).rank(method="average").to_numpy(float)
    yr = pd.Series(y).rank(method="average").to_numpy(float)
    xc = xr - xr.mean(); yc = yr - yr.mean()
    den = np.sqrt((xc**2).sum() * (yc**2).sum())
    return np.nan if den == 0 else float((xc*yc).sum() / den)

def _overlay_rank_trends(fig, dfin, xcol, ycol, window=None, legend_group="trend"):
    """Overlay a rolling median, a rolling IQR ribbon and a Spearman ρ label.

    Rows with NaN in either column are dropped and the rest sorted by `xcol`.
    When `window` is None it defaults to an odd size of about 5% of the rows,
    never below 31. Returns (rho, n), or (NaN, 0) when nothing could be drawn.
    """
    ordered = dfin[[xcol, ycol]].dropna().sort_values(xcol)
    if ordered.empty:
        return np.nan, 0
    n = len(ordered)
    if window is None:
        window = max(31, int(round(n * 0.05)) | 1)  # odd, ~5% of N
    min_obs = max(10, window // 3)

    xs = ordered[xcol].to_numpy(float)
    ys = ordered[ycol].reset_index(drop=True)

    def _roll():
        return ys.rolling(window, center=True, min_periods=min_obs)

    med = _roll().median()
    try:
        q25 = _roll().quantile(0.25)
        q75 = _roll().quantile(0.75)
    except Exception:
        # fallback when rolling quantile is unavailable
        q25 = _roll().apply(lambda v: np.nanpercentile(v, 25))
        q75 = _roll().apply(lambda v: np.nanpercentile(v, 75))

    valid = med.notna() & q25.notna() & q75.notna()
    if not valid.any():
        return np.nan, 0
    xv = xs[valid.to_numpy()]
    lo = q25[valid].to_numpy(float)
    hi = q75[valid].to_numpy(float)
    mv = med[valid].to_numpy(float)

    # IQR ribbon: invisible lower edge first, then the upper edge filled down to it
    lower_edge = go.Scatter(
        x=xv, y=lo, mode="lines",
        line=dict(width=0), showlegend=False, hoverinfo="skip",
        name="_iqr_lo", legendgroup=legend_group)
    upper_edge = go.Scatter(
        x=xv, y=hi, mode="lines", fill="tonexty",
        line=dict(width=0), fillcolor="rgba(0,0,0,0.15)",
        name="Rolling IQR", showlegend=True, hoverinfo="skip",
        legendgroup=legend_group)
    median_line = go.Scatter(
        x=xv, y=mv, mode="lines",
        line=dict(width=2, color="#222"),
        name="Rolling median", showlegend=True, hoverinfo="skip",
        legendgroup=legend_group)
    for trace in (lower_edge, upper_edge, median_line):
        fig.add_trace(trace)

    # Spearman (rank-based), annotated in the bottom-right corner
    rho = _spearman_no_scipy(ordered[xcol], ordered[ycol])
    fig.add_annotation(
        xref="paper", yref="paper", x=0.98, y=0.02,
        xanchor="right", yanchor="bottom",
        text=f"Spearman ρ={rho:.2f} (n={n})", showarrow=False,
        font=dict(size=12, color="#333"))
    return rho, n
# ─────────────────────────── IND plots ────────────────────────────────────────
# One figure per ChIP track in CHIP_MAP.
#   PAIR  : both conditions plotted per site, joined by a segment coloured by
#           the sign of ΔIND.
#   DELTA : ΔIND (COND2 − COND1) against the per-site ChIP value (or its rank).
for title, chip_col in CHIP_MAP.items():
    # Per-site aggregates (median IND, median ChIP) for each condition.
    a1 = _agg_ind(df, chip_col, COND1)
    a2 = _agg_ind(df, chip_col, COND2)
    if a1.empty or a2.empty:
        print(f"[IND] {title}: missing data for one condition after filter.")
        continue
    # Inner join: only sites observed in both conditions remain.
    pair = a1.merge(a2, on=["type","bed_start"], suffixes=("_c1","_c2"))
    if pair.empty:
        print(f"[IND] {title}: no overlapping (type, bed_start).")
        continue

    if MODE.upper() == "PAIR":
        pair = _pair_filter(pair, "chip_med_c1","ind_med_c1", "chip_med_c2","ind_med_c2", apply_ind_ceiling=True)
        if pair.empty:
            print(f"[IND-PAIR] {title}: nothing after numeric/ceiling filter.")
            continue

        # choose x
        if str(X_AXIS_MODE).upper() == "RANK":
            pair["x_c1"] = _rank_int(pair["chip_med_c1"])
            pair["x_c2"] = _rank_int(pair["chip_med_c2"])
            x1, x2 = "x_c1", "x_c2"
            x_hover = "rank"
        else:
            pair["x_c1"] = pair["chip_med_c1"].to_numpy(float) + EPS
            pair["x_c2"] = pair["chip_med_c2"].to_numpy(float) + EPS
            x1, x2 = "x_c1", "x_c2"
            x_hover = "ChIP"

        # segments by sign of ΔIND
        pos_x, pos_y, neg_x, neg_y = _build_seg_arrays(pair, x1,"ind_med_c1", x2,"ind_med_c2")
        fig = go.Figure()
        if pos_x:
            fig.add_trace(go.Scattergl(x=pos_x, y=pos_y, mode="lines",
                line=dict(color=CLR_POS, width=1), name="ΔIND>0", showlegend=True, hoverinfo="skip"))
        if neg_x:
            fig.add_trace(go.Scattergl(x=neg_x, y=neg_y, mode="lines",
                line=dict(color=CLR_NEG, width=1), name="ΔIND<0", showlegend=True, hoverinfo="skip"))

        # markers: colour by condition, symbol by chromosome class
        pair_chr = _attach_chr(pair, df)
        for cond, xcol, ycol in [(COND1,"x_c1","ind_med_c1"),
                                 (COND2,"x_c2","ind_med_c2")]:
            for grp, sym, sub in _by_chr_groups(pair_chr):
                if sub.empty:
                    continue
                fig.add_trace(MarkerTrace(
                    x=sub[xcol].to_numpy(float),
                    y=sub[ycol].to_numpy(float),
                    mode="markers",
                    name=str(cond) if grp == "X" else f"{cond} (auto)",
                    showlegend=True,
                    legendgroup=str(cond),
                    marker=dict(color=get_color(str(cond)), size=7, opacity=0.85, symbol=sym),
                    customdata=np.stack([sub["type"].to_numpy(str), sub["bed_start"].to_numpy()], axis=1),
                    hovertemplate=(f"cond=%{{fullData.name}}<br>type=%{{customdata[0]}}"
                                   f"<br>bed_start=%{{customdata[1]}}<br>IND=%{{y:.1f}} bp"
                                   f"<br>{x_hover}=%{{x}}<extra></extra>")
                ))

        _add_shape_legend(fig)
        fig.update_layout(
            title=f"{title}: IND paired ({COND1} ↔ {COND2}) [lnP0≤{LNP_CEILING} for {sorted(TARGET_TYPES)}]",
            template=TEMPLATE,
            xaxis=_xaxis_layout(),
            yaxis=dict(title="IND (bp)"),
            legend=dict(title="Condition / Shapes", itemsizing="constant", groupclick="toggleitem"),
            width=800
        )
        print(f"[IND-PAIR] {title}: pairs={len(pair):,}  pos={bool(pos_x)}  neg={bool(neg_x)}")
        fig.show()

    elif MODE.upper() == "DELTA":
        joined = pair.copy()
        if IND_CEILING is not None:
            # .copy() so the column assignments below write to an independent
            # frame instead of a filtered slice (avoids SettingWithCopyWarning)
            joined = joined[(joined["ind_med_c1"] <= IND_CEILING) & (joined["ind_med_c2"] <= IND_CEILING)].copy()
        if joined.empty:
            print(f"[IND-DELTA] {title}: nothing after ceiling.")
            continue

        joined["delta_ind"] = joined["ind_med_c2"] - joined["ind_med_c1"]
        # x value per site: median of the two conditions' ChIP values, ignoring NaN
        joined["chip_x"] = np.nanmedian(
            np.stack([joined["chip_med_c1"].to_numpy(float),
                      joined["chip_med_c2"].to_numpy(float)], axis=1), axis=1)
        m = np.isfinite(joined["delta_ind"]) & np.isfinite(joined["chip_x"]) & (joined["chip_x"] > 0)
        data = joined.loc[m, ["type","bed_start","delta_ind","chip_x"]].copy()
        if data.empty:
            print(f"[IND-DELTA] {title}: nothing after finite/log filter.")
            continue

        # choose x (rank or chip)
        if str(X_AXIS_MODE).upper() == "RANK":
            data["x_plot"] = _rank_int(data["chip_x"])
            x_hover = "rank"
        else:
            data["x_plot"] = data["chip_x"].to_numpy(float) + EPS
            x_hover = "ChIP"

        fig = go.Figure()
        # trend overlay is only drawn on the rank axis
        if str(X_AXIS_MODE).upper() == "RANK":
            _overlay_rank_trends(fig, data, "x_plot", "delta_ind", window=TREND_WINDOW)
        data_chr = _attach_chr(data, df)
        for typ, sub_t in data_chr.groupby("type", sort=False):
            for grp, sym, sub in _by_chr_groups(sub_t):
                if sub.empty:
                    continue
                fig.add_trace(MarkerTrace(
                    x=sub["x_plot"].to_numpy(float),
                    y=sub["delta_ind"].to_numpy(float),
                    mode="markers",
                    name=str(typ) if grp == "X" else f"{typ} (auto)",
                    showlegend=True,
                    legendgroup=str(typ),
                    marker=dict(color=get_color(str(typ)), size=7, opacity=0.85, symbol=sym),
                    customdata=np.stack([sub["type"].to_numpy(str), sub["bed_start"].to_numpy()], axis=1),
                    hovertemplate=(f"ΔIND=%{{y:.1f}} bp<br>type=%{{customdata[0]}}"
                                   f"<br>bed_start=%{{customdata[1]}}<br>{x_hover}=%{{x}}"
                                   "<extra></extra>")
                ))

        _add_shape_legend(fig)
        fig.update_layout(
            title=f"{title}: ΔIND ({COND2} − {COND1}) [lnP0≤{LNP_CEILING} for {sorted(TARGET_TYPES)}]",
            template=TEMPLATE,
            xaxis=_xaxis_layout(),
            yaxis=dict(title="ΔIND (bp)"),
            legend=dict(title="Type / Shapes", itemsizing="constant", groupclick="toggleitem"),
            width=800
        )
        print(f"[IND-DELTA] {title}: points={len(data):,}")
        fig.show()


    else:
        raise ValueError("MODE must be 'PAIR' or 'DELTA'")

# ───────────────────── m6A accessibility plots ────────────────────────────────
# One figure per ChIP track in CHIP_MAP.
#   PAIR  : mean %m6A of both conditions per site, joined by a segment coloured
#           by the sign of the change.
#   DELTA : percent change of mean %m6A relative to COND1, against the per-site
#           ChIP value (or its rank).
for title, chip_col in CHIP_MAP.items():
    # Per-site aggregates (mean %m6A, median ChIP) for each condition.
    b1 = _agg_m6a(df, chip_col, COND1)
    b2 = _agg_m6a(df, chip_col, COND2)
    # A paired comparison needs both conditions (same rule as the IND plots).
    if b1.empty or b2.empty:
        print(f"[m6A] {title}: no data after filter.")
        continue
    pair = b1.merge(b2, on=["type","bed_start"], suffixes=("_c1","_c2"))
    if pair.empty:
        print(f"[m6A] {title}: no overlapping (type, bed_start).")
        continue

    if MODE.upper() == "PAIR":
        pair = _pair_filter(pair, "chip_med_c1","m6a_mean_c1", "chip_med_c2","m6a_mean_c2", apply_ind_ceiling=False)
        if pair.empty:
            print(f"[m6A-PAIR] {title}: nothing after numeric filter.")
            continue

        if str(X_AXIS_MODE).upper() == "RANK":
            pair["x_c1"] = _rank_int(pair["chip_med_c1"])
            pair["x_c2"] = _rank_int(pair["chip_med_c2"])
            x1, x2 = "x_c1","x_c2"
            x_hover = "rank"
        else:
            pair["x_c1"] = pair["chip_med_c1"].to_numpy(float) + EPS
            pair["x_c2"] = pair["chip_med_c2"].to_numpy(float) + EPS
            x1, x2 = "x_c1","x_c2"
            x_hover = "ChIP"

        pos_x, pos_y, neg_x, neg_y = _build_seg_arrays(pair, x1,"m6a_mean_c1", x2,"m6a_mean_c2")
        fig = go.Figure()
        if pos_x:
            fig.add_trace(go.Scattergl(x=pos_x, y=pos_y, mode="lines",
                line=dict(color=CLR_POS, width=1), name="Δm6A>0", showlegend=True, hoverinfo="skip"))
        if neg_x:
            fig.add_trace(go.Scattergl(x=neg_x, y=neg_y, mode="lines",
                line=dict(color=CLR_NEG, width=1), name="Δm6A<0", showlegend=True, hoverinfo="skip"))

        # markers: colour by condition, symbol by chromosome class
        pair_chr = _attach_chr(pair, df)
        for cond, xcol, ycol in [(COND1,"x_c1","m6a_mean_c1"),
                                 (COND2,"x_c2","m6a_mean_c2")]:
            for grp, sym, sub in _by_chr_groups(pair_chr):
                if sub.empty:
                    continue
                fig.add_trace(MarkerTrace(
                    x=sub[xcol].to_numpy(float),
                    y=sub[ycol].to_numpy(float),
                    mode="markers",
                    name=str(cond) if grp == "X" else f"{cond} (auto)",
                    showlegend=True,
                    legendgroup=str(cond),
                    marker=dict(color=get_color(str(cond)), size=7, opacity=0.85, symbol=sym),
                    customdata=np.stack([sub["type"].to_numpy(str), sub["bed_start"].to_numpy()], axis=1),
                    # single "%": this is an f-string, not %-formatting, so "%%"
                    # would be shown literally as two percent signs in the hover
                    hovertemplate=(f"cond=%{{fullData.name}}<br>type=%{{customdata[0]}}"
                                   f"<br>bed_start=%{{customdata[1]}}<br>mean m6A=%{{y:.2f}}%"
                                   f"<br>{x_hover}=%{{x}}<extra></extra>")
                ))

        _add_shape_legend(fig)
        fig.update_layout(
            title=f"{title}: Accessibility (%m6A) paired ({COND1} ↔ {COND2}) [lnP0 filter only for {sorted(TARGET_TYPES)}]",
            template=TEMPLATE,
            xaxis=_xaxis_layout(),
            yaxis=dict(title="Accessibility (% m6A in ±50 bp)"),
            legend=dict(title="Condition / Shapes", itemsizing="constant", groupclick="toggleitem"),
            width=800
        )
        print(f"[m6A-PAIR] {title}: pairs={len(pair):,}  pos={bool(pos_x)}  neg={bool(neg_x)}")
        fig.show()


    elif MODE.upper() == "DELTA":
        pair = pair[
            np.isfinite(pair["m6a_mean_c1"]) & np.isfinite(pair["m6a_mean_c2"]) &
            np.isfinite(pair["chip_med_c1"]) & np.isfinite(pair["chip_med_c2"]) &
            (pair["chip_med_c1"] > 0) & (pair["chip_med_c2"] > 0)
        ].copy()
        if pair.empty:
            print(f"[m6A-DELTA] {title}: nothing after numeric/log filter.")
            continue

        # % change relative to base (COND1); sites whose base is ~0 become NaN
        # and are dropped by the finite filter below
        base  = pair["m6a_mean_c1"].to_numpy(float)
        treat = pair["m6a_mean_c2"].to_numpy(float)
        REL_EPS = 1e-6
        pair["delta_m6a_rel"] = np.divide(
            100.0 * (treat - base), base,
            out=np.full_like(base, np.nan, dtype=float),
            where=np.abs(base) >= REL_EPS
        )

        # central x (chip) then choose plotting x per X_AXIS_MODE
        pair["chip_x"] = np.nanmedian(
            np.stack([pair["chip_med_c1"].to_numpy(float),
                      pair["chip_med_c2"].to_numpy(float)], axis=1), axis=1)

        ok = np.isfinite(pair["delta_m6a_rel"]) & np.isfinite(pair["chip_x"]) & (pair["chip_x"] > 0)
        # .copy() so adding "x_plot" below does not write to a slice of `pair`
        data = pair.loc[ok, ["type","bed_start","delta_m6a_rel","chip_x"]].copy()
        if data.empty:
            print(f"[m6A-DELTA] {title}: nothing after finite/log filter.")
            continue

        # choose x: rank or chip
        if str(X_AXIS_MODE).upper() == "RANK":
            data["x_plot"] = _rank_int(data["chip_x"])
            x_hover = "rank"
        else:
            data["x_plot"] = data["chip_x"].to_numpy(float) + EPS
            x_hover = "ChIP"

        fig = go.Figure()
        # trend overlay is only drawn on the rank axis
        if str(X_AXIS_MODE).upper() == "RANK":
            _overlay_rank_trends(fig, data, "x_plot", "delta_m6a_rel", window=TREND_WINDOW)
        data_chr = _attach_chr(data, df)
        for typ, sub_t in data_chr.groupby("type", sort=False):
            for grp, sym, sub in _by_chr_groups(sub_t):
                if sub.empty:
                    continue
                fig.add_trace(MarkerTrace(
                    x=sub["x_plot"].to_numpy(float),
                    y=sub["delta_m6a_rel"].to_numpy(float),
                    mode="markers",
                    name=str(typ) if grp == "X" else f"{typ} (auto)",
                    showlegend=True,
                    legendgroup=str(typ),
                    marker=dict(color=get_color(str(typ)), size=7, opacity=0.85, symbol=sym),
                    customdata=np.stack([sub["type"].to_numpy(str), sub["bed_start"].to_numpy()], axis=1),
                    hovertemplate=(f"Δ% m6A (rel)=%{{y:.1f}}%<br>type=%{{customdata[0]}}"
                                   f"<br>bed_start=%{{customdata[1]}}<br>{x_hover}=%{{x}}"
                                   "<extra></extra>")
                ))

        _add_shape_legend(fig)
        fig.update_layout(
            title=f"{title}: ΔAccessibility (% change vs base) [lnP0 filter only for {sorted(TARGET_TYPES)}]",
            template=TEMPLATE,
            xaxis=_xaxis_layout(),
            yaxis=dict(title="ΔAccessibility (% change vs base)"),
            legend=dict(title="Type / Shapes", itemsizing="constant", groupclick="toggleitem"),
            width=800
        )
        print(f"[m6A-DELTA] {title}: points={len(data):,}")
        fig.show()

    else:
        raise ValueError("MODE must be 'PAIR' or 'DELTA'")




In [None]:
# ╔══════════════════════════════════════════════════════════════════╗
# ║  CELL A: Build per-read MNase dyad occupancy (parallel)          ║
# ╚══════════════════════════════════════════════════════════════════╝
# Loads two MNase dyad bedGraph WIG files, filters to chromosomes
# and windows present in `filtered_reads_df`, and computes for each
# read a binned dyad-occupancy vector in rel_pos space (pos - bed_start).
# Output: global `mnase_seq_nuc_occupancy_df`
# ────────────────────────────────────────────────────────────────────

# ========================== IMPORTS ========================= #
# Imports come first: MNASE_NUM_WORKERS below calls mp.cpu_count(), so `mp`
# must already be bound when the config section runs.
import os
import sys
import math
import multiprocessing as mp

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

# ========================== CONFIG ========================== #
MNASE_WIG_FILES = [
    "/Data1/ext_data/lieb_mnase_2017/GSM2098437_RT_rep1_MNaseTC_30m_smoothDyads_ce11.wig",
    "/Data1/ext_data/lieb_mnase_2017/GSM2098437_RT_rep2_MNaseTC_30m_smoothDyads_ce11.wig",
]
MNASE_BIN_BP       = 10        # bin step for MNase sampling
MNASE_CACHE_PATH   = None      # e.g. "/Data1/cache/mnase_seq_nuc_occupancy.parquet"
MNASE_NUM_WORKERS  = max(2, mp.cpu_count() - 2)   # leave two cores free, never fewer than 2 workers
MNASE_CHUNKSIZE    = 256       # tasks handed to each worker per batch
MNASE_DEBUG        = True      # progress messages on stderr
# relative-zero for MNase rel_pos
MNASE_RELATIVE_ZERO = "midpoint"   # 'midpoint' or 'start'
MNASE_ROUND_REL     = True         # round to nearest bp so rel_pos is int-like

# =================== CHROM NAME MAPPING ===================== #
_ROMAN = {"I":"I","II":"II","III":"III","IV":"IV","V":"V","X":"X","M":"M"}
def _chrom_ce11_from_df(name: str) -> str:
    # "CHROMOSOME_II" → "chrII"; "CHROMOSOME_X"→"chrX"
    if not isinstance(name, str): return None
    if name.startswith("CHROMOSOME_"):
        suf = name.split("_", 1)[1]
        return f"chr{_ROMAN.get(suf, suf)}"
    return None

import numpy as np, pandas as pd

# ---- Diagnostics: inspect the coordinate columns before building windows ----
print("dtypes:", filtered_reads_df[["chrom","bed_start","bed_end"]].dtypes.to_dict())

# Scratch copy with the WIG-style chromosome name ("CHROMOSOME_II" → "chrII");
# names without the "CHROMOSOME_" prefix map to None.
probe = filtered_reads_df[["chrom","bed_start","bed_end"]].copy()
probe["wig_chrom"] = probe["chrom"].map(lambda s: ("chr"+s.split("_",1)[1]) if isinstance(s,str) and s.startswith("CHROMOSOME_") else None)

# counts
print("NaN bed_start:", probe["bed_start"].isna().sum(),
      "NaN bed_end:", probe["bed_end"].isna().sum())

print("±inf bed_start:", np.isinf(probe["bed_start"]).sum(),
      "±inf bed_end:", np.isinf(probe["bed_end"]).sum())

print("wig_chrom value_counts:\n", probe["wig_chrom"].value_counts(dropna=False).head(20))

# rows that will poison the groupby
# (mapped chromosome but a non-finite start or end)
bad = (
    probe["wig_chrom"].notna()
    & (~np.isfinite(probe["bed_start"]) | ~np.isfinite(probe["bed_end"]))
)
print("bad rows:", int(bad.sum()))
if bad.any():
    # `display` is not imported in this cell — presumably the IPython/Jupyter builtin
    display(probe.loc[bad].head(20))


# Restrict to chromosomes and windows covered by reads
_reads = filtered_reads_df[["chrom","bed_start","bed_end"]].copy()

# coerce numerics, drop NaN and ±inf
# (unparseable values become NaN via errors="coerce", ±inf is mapped to NaN,
#  then every row with a NaN bound is dropped)
_reads["bed_start"] = pd.to_numeric(_reads["bed_start"], errors="coerce")
_reads["bed_end"]   = pd.to_numeric(_reads["bed_end"],   errors="coerce")
_reads = _reads.replace([np.inf, -np.inf], np.nan).dropna(subset=["bed_start","bed_end"])

# map chromosome names
def _chrom_ce11_from_df(name: str) -> str:
    if not isinstance(name, str) or not name.startswith("CHROMOSOME_"):
        return None
    suf = name.split("_", 1)[1]
    return "chr" + {"I":"I","II":"II","III":"III","IV":"IV","V":"V","X":"X","M":"M"}.get(suf, suf)

# Attach the WIG-style chromosome name and drop reads whose chromosome
# could not be mapped.
_reads["wig_chrom"] = _reads["chrom"].map(_chrom_ce11_from_df)
_reads = _reads[_reads["wig_chrom"].notna()]

# enforce sane windows
_reads = _reads[_reads["bed_end"] > _reads["bed_start"]]

# One bounding window per chromosome: smallest bed_start to largest bed_end.
agg = (
    _reads.groupby("wig_chrom", as_index=True)
          .agg(lo=("bed_start","min"), hi=("bed_end","max"))
)

# drop any remaining non-finite groups, then cast
finite_mask = np.isfinite(agg["lo"].values) & np.isfinite(agg["hi"].values)
if (~finite_mask).any():
    print("[MNase] Dropping groups with non-finite lo/hi:",
          agg.index[~finite_mask].tolist(), file=sys.stderr)
agg = agg.iloc[finite_mask]

agg = np.floor(agg).astype(np.int64)  # safe cast
# chrom → {"lo": int, "hi": int}
chrom_windows = agg.to_dict(orient="index")

allowed_chroms = set(chrom_windows.keys())
print("[MNase] Allowed chromosomes:", sorted(allowed_chroms))
print("[MNase] Windows per chromosome:",
      {k:(v["lo"], v["hi"]) for k,v in chrom_windows.items()})

# NOTE(review): with MNASE_DEBUG on, the same two summaries are printed a
# second time, to stderr.
if MNASE_DEBUG:
    print(f"[MNase] Allowed chromosomes: {sorted(allowed_chroms)}", file=sys.stderr)
    print(f"[MNase] Windows per chromosome: "
          f"{ {k:(v['lo'],v['hi']) for k,v in chrom_windows.items()} }", file=sys.stderr)

# =================== WIG PARSE + INDEX ====================== #
def _parse_wig_to_index(path, allowed, windows):
    """Return dict: chrom -> (starts, ends, values) arrays within requested windows."""
    out = {c: ([], [], []) for c in allowed}
    keep = 0
    with open(path, "r") as fh:
        for line in fh:
            if not line or line[0] == "#":
                continue
            parts = line.split()
            if len(parts) != 4:
                continue
            chrom, s, e, v = parts[0], int(parts[1]), int(parts[2]), float(parts[3])
            if chrom not in allowed:
                continue
            lo, hi = windows[chrom]["lo"], windows[chrom]["hi"]
            if e <= lo or s >= hi:
                continue
            s2, e2 = max(s, lo), min(e, hi)
            if s2 < e2:
                out[chrom][0].append(s2)
                out[chrom][1].append(e2)
                out[chrom][2].append(v)
                keep += 1
    for c in list(out.keys()):
        starts = np.asarray(out[c][0], dtype=np.int64)
        ends   = np.asarray(out[c][1], dtype=np.int64)
        vals   = np.asarray(out[c][2], dtype=np.float32)
        # ensure sorted
        idx = np.argsort(starts, kind="mergesort")
        out[c] = (starts[idx], ends[idx], vals[idx])
    return out

# Parse every replicate once; WIG_IDX[i] is the per-chromosome index of file i.
if MNASE_DEBUG:
    print("[MNase] Parsing WIGs …", file=sys.stderr)
WIG_IDX = []
for p in MNASE_WIG_FILES:
    if MNASE_DEBUG:
        print(f"  ↳ {p}", file=sys.stderr)
    parsed = _parse_wig_to_index(p, allowed_chroms, chrom_windows)
    WIG_IDX.append(parsed)
if MNASE_DEBUG:
    print("[MNase] Done parsing.", file=sys.stderr)

def _sample_bedgraph(idx_tuple, pos):
    """Vectorized lookup for positions 'pos' using (starts, ends, vals)."""
    starts, ends, vals = idx_tuple
    if starts.size == 0:
        return np.zeros_like(pos, dtype=np.float32)
    ii = np.searchsorted(starts, pos, side="right") - 1
    ok = (ii >= 0) & (pos < ends[np.clip(ii, 0, len(ends)-1)])
    out = np.zeros_like(pos, dtype=np.float32)
    out[ok] = vals[ii[ok]]
    return out

# =============== PER-READ EXTRACTION (PARALLEL) ============== #
def _prep_read_records(df):
    """Build the per-read task tuples handed to the worker pool.

    Parameters
    ----------
    df : DataFrame
        Needs columns read_id, condition, chrom, bed_start, bed_end.

    Returns
    -------
    list of tuple
        (read_id, condition, wig_chrom, bed_start, bed_end) with integer
        bounds. Reads are skipped when their chromosome cannot be mapped or is
        not in `allowed_chroms`, or when a bound is missing or non-finite.
    """
    recs = []
    for r in df.itertuples(index=False):
        wc = _chrom_ce11_from_df(r.chrom)
        if wc is None or wc not in allowed_chroms:
            continue
        try:
            bstart, bend = int(r.bed_start), int(r.bed_end)
        except (TypeError, ValueError, OverflowError):
            # None / NaN / ±inf bounds: the window computation drops these rows
            # as well, so skip them here instead of crashing on int().
            continue
        recs.append((r.read_id, r.condition, wc, bstart, bend))
    return recs

read_records = _prep_read_records(filtered_reads_df[["read_id","condition","chrom","bed_start","bed_end"]])
if MNASE_DEBUG: print(f"[MNase] Reads eligible: {len(read_records):,}", file=sys.stderr)

# Share WIG_IDX to workers
def _worker_read_to_mnase(args):
    """Compute the binned, replicate-averaged MNase dyad signal for one read.

    `args` is (read_id, condition, wig_chrom, bed_start, bed_end). Returns None
    when the span is empty, shorter than one bin, or the chromosome is missing
    from any replicate index; otherwise a dict holding the read metadata, the
    relative bin positions and the averaged float32 signal.
    """
    read_id, cond, wig_chrom, start, end = args

    # skip bad bounds
    span = end - start
    if span <= 0 or span < MNASE_BIN_BP:
        return None

    # 1) bin centers across the genomic read span [bed_start, bed_end)
    half_bin = MNASE_BIN_BP // 2
    centers = np.arange(start, end, MNASE_BIN_BP, dtype=np.int64) + half_bin

    # 2) average WIG values across replicates at genomic centers
    total = None
    for replicate in WIG_IDX:
        if wig_chrom not in replicate:
            return None
        sampled = _sample_bedgraph(replicate[wig_chrom], centers)
        total = sampled if total is None else (total + sampled)
    mean_occ = total / float(len(WIG_IDX))

    # 3) genomic centers → rel_pos, zero at the read midpoint or at its start
    if MNASE_RELATIVE_ZERO == "midpoint":
        zero = 0.5 * (start + end)   # may be .5 when read length is odd
    else:  # 'start'
        zero = float(start)
    offsets = centers.astype(np.float64) - zero
    if MNASE_ROUND_REL:
        rel_pos = np.rint(offsets).astype(np.int32)   # integer rel_pos like other tracks
    else:
        rel_pos = offsets.astype(np.float32)

    return {
        "read_id": read_id,
        "condition": cond,
        "wig_chrom": wig_chrom,
        "bed_start": int(start),
        "bed_end": int(end),
        "mnase_rel_pos": rel_pos,
        "mnase_dyad_occupancy": mean_occ.astype(np.float32),
    }

# Fan the per-read tasks out to a process pool. Results arrive in arbitrary
# order (imap_unordered); reads for which the worker returned None are dropped.
results = []
if read_records:
    with mp.Pool(processes=MNASE_NUM_WORKERS) as pool, tqdm(total=len(read_records), desc="MNase per-read") as bar:
        for out in pool.imap_unordered(_worker_read_to_mnase, read_records, chunksize=MNASE_CHUNKSIZE):
            if out is not None:
                results.append(out)
            bar.update(1)

mnase_seq_nuc_occupancy_df = pd.DataFrame(results)
if MNASE_DEBUG:
    n_total = len(read_records)
    n_kept  = len(mnase_seq_nuc_occupancy_df)
    print(f"[MNase] Built mnase_seq_nuc_occupancy_df: {n_kept:,}/{n_total:,} reads with data.", file=sys.stderr)

if MNASE_CACHE_PATH:
    # A bare file name has an empty dirname, and os.makedirs("") raises, so
    # only create the directory when the path actually has one.
    cache_dir = os.path.dirname(MNASE_CACHE_PATH)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    mnase_seq_nuc_occupancy_df.to_parquet(MNASE_CACHE_PATH, index=False)
    if MNASE_DEBUG: print(f"[MNase] Cached → {MNASE_CACHE_PATH}", file=sys.stderr)


In [None]:
# Sanity check of bed coordinates for minus-strand MEX sites.
m = (combined_bed_df['type'].str.contains('MEX', na=False)) & (combined_bed_df['bed_strand']=='-')
# Number of selected rows with start > end.
# NOTE(review): this is not the last expression in the cell and is not printed,
# so its value is discarded — wrap it in print() if the count is wanted.
(combined_bed_df.loc[m, 'bed_start'] > combined_bed_df.loc[m, 'bed_end']).sum()
# Centre estimated from the left edge and from the right edge, using the
# global `bed_window` as the half-width.
ctr_l = combined_bed_df.loc[m, 'bed_start'] + bed_window
ctr_r = combined_bed_df.loc[m, 'bed_end']   - bed_window
print((ctr_l - ctr_r).abs().describe())  # should be ~0 if starts are left edges

# Half-width distribution of minus-strand MEX_motif windows.
x = combined_bed_df.query("type=='MEX_motif' & bed_strand=='-'").copy()
x["half_width"] = (x.bed_end - x.bed_start)/2
print(x["half_width"].describe())


In [None]:
# ╔══════════════════════════════════════════════════════════════════╗
# ║  CELL B: Raster + %m6A + IND profile + normalized occ + MNase    ║
# ╚══════════════════════════════════════════════════════════════════╝

# ========================== USER CONFIG ========================== #
PLOT_COND    = analysis_cond[0]     # primary condition (required)
PLOT_COND2   = analysis_cond[5]     # secondary condition (None → off)
TYPE_TO_PLOT = "MEX_motif"           # set to a string to filter, or None to disable
STRAND_TO_PLOT = "-" # "+" or None or "-"
CHR_TYPE_FILTER = ["Autosome"]   # e.g., ["X"] or ["Autosome"]; None disables


SCATTER_READS_N   = 200              # presumably the number of reads drawn in the raster — TODO confirm
PLOT_WINDOW       = 2000             # raster keeps m6A positions within ±PLOT_WINDOW of zero
MOV_AVG_WINDOW    = 10               # kept for compatibility (unused here)
BIN_ENC_BP        = 30               # raster encoding bin
M6A_BIN           = 5
M6A_MOVAVG_WIN    = 1                # half-window (bins) for the %m6A moving average
HIST_BIN          = 10               # density smoothing for occupancy
GAUSS_SIG_HIST    = 2

# Evenness params (unused in this cell; kept for compatibility)
EVEN_WIN          = 90
EVEN_BIN          = 30
EVEN_STEP         = 30
SMOOTH_EVEN_BINS  = 0

MIN_GAP_BP        = 1                # ≥1 bp gap on a raster row
type_label        = "edge-grown"

# Tracks
TRACK_BIN         = HIST_BIN         # x-bin for nuc occupancy normalization
# Reuse the MNase sampling bin from Cell A when it has been run; otherwise 10 bp.
MNASE_PLOT_BIN    = MNASE_BIN_BP if "MNASE_BIN_BP" in globals() else 10

# IND profile config
IND_BIN_BP               = 10        # x-bin for IND(x)
MIN_M6A_IN_IND_PROFILE   = 0         # ≥ this many m6A within the gated span; 0 disables
INCLUDE_EDGES_IN_IND     = True      # closed vs open interval at boundaries
MAX_IND_BP               = 600       # exclude larger values
IND_MODE                 = "edges"   # "edges" for inner-edge gap; "centers" for center-to-center

# I/O and debug
SAVE_FIGS         = True
DEBUG             = True

# ───────────────── condition→colour helper ───────────────── #
from itertools import cycle

CLR_SDC = "#b12537"
CLR_N2 = "#4974a5"
CLR_DPY = "#47B562"
_GREY_CYCLE = cycle(["#4d4d4d", "#a3a3a3"])   # fallback greys, handed out in turn
_cond2clr_cache = {}                          # lower-cased condition -> assigned grey


def cond_color(cond: str) -> str:
    """Return the line colour for a condition name.

    The name is matched case-insensitively by substring, in this priority
    order: "sdc", then "dpy", then "n2".  Any other condition receives the
    next grey from a two-colour cycle; the grey is cached so the same
    condition always gets the same colour within a session.
    """
    lowered = cond.lower()
    for tag, colour in (("sdc", CLR_SDC), ("dpy", CLR_DPY), ("n2", CLR_N2)):
        if tag in lowered:
            return colour
    try:
        return _cond2clr_cache[lowered]
    except KeyError:
        grey = _cond2clr_cache[lowered] = next(_GREY_CYCLE)
        return grey

# ---------- libs & paths ------------------------------------------ #
from pathlib import Path
from datetime import datetime
import numpy as np, pandas as pd, sys
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.ndimage import uniform_filter1d

# Output directory for this cell's figures; created (with parents) if missing.
OUT_DIR_FIG = Path("/tmp/edge_grown_qc_plots"); OUT_DIR_FIG.mkdir(parents=True, exist_ok=True)
# Timestamp taken once per cell run; used as the filename suffix when saving.
STAMP       = datetime.now().strftime("%Y%m%d_%H%M%S")

# ---------- core-length palette ----------------------------------- #
def _apply_type_chr_filters(df: pd.DataFrame) -> pd.DataFrame:
    out = df
    if STRAND_TO_PLOT:
        out = out[out["bed_strand"]==STRAND_TO_PLOT]
    if TYPE_TO_PLOT:
        out = out[out["type"] == TYPE_TO_PLOT]
    if CHR_TYPE_FILTER not in (None, "", []):
        allowed = {str(x).lower() for x in CHR_TYPE_FILTER}
        out = out[out["chr_type"].astype(str).str.lower().isin(allowed)]
    return out


def _core_color(clen):
    if clen <= 100:   return "#e5c951"
    elif clen <= 128: return "#d07987"
    elif clen <= 180: return "#a06ab4"
    else:             return "#15819a"

# ---------- Raster traces ----------------------------------------- #
def raster_traces(df_sub:pd.DataFrame, cond:str):
    """Build the Plotly scatter traces for one raster panel.

    For every row of *df_sub* (which must carry a ``row_id``) this emits, in
    order: black dots at the m6A positions (``mod_qual_bin == 1``) inside
    ±PLOT_WINDOW, a translucent bar plus a centre dot for every nucleosome in
    ``nuc_coords`` coloured by core length, and a thin grey line spanning the
    read's full ``rel_pos`` extent.  *cond* is accepted but not used here.
    """
    quiet = dict(hoverinfo="skip", showlegend=False)
    collected = []
    for read in df_sub.itertuples(index=False):
        y_row = read.row_id
        positions = np.asarray(read.rel_pos)
        calls = np.asarray(read.mod_qual_bin)

        # m6A calls, clipped to the plotted window
        methylated = positions[calls == 1]
        in_view = (methylated >= -PLOT_WINDOW) & (methylated <= PLOT_WINDOW)
        methylated = methylated[in_view]
        if methylated.size:
            collected.append(
                go.Scatter(x=methylated, y=[y_row] * len(methylated),
                           mode="markers",
                           marker=dict(size=2, color="black"),
                           **quiet)
            )

        # nucleosome footprints: bar over the core + dot at the centre
        for centre, core_len, *_ in read.nuc_coords:
            shade = _core_color(core_len)
            half = core_len // 2
            collected.append(
                go.Scatter(x=[centre - half, centre + half], y=[y_row, y_row],
                           mode="lines",
                           line=dict(width=4, color=shade), opacity=0.25,
                           **quiet)
            )
            collected.append(
                go.Scatter(x=[centre], y=[y_row], mode="markers",
                           marker=dict(size=4, color=shade),
                           **quiet)
            )

        # read backbone
        collected.append(
            go.Scatter(x=[positions.min(), positions.max()], y=[y_row, y_row],
                       mode="lines",
                       line=dict(color="rgba(120,120,120,0.4)", width=1),
                       **quiet)
        )
    return collected

def build_plot_df(cond:str)->pd.DataFrame:
    """Select, subsample and row-pack the reads drawn in one raster.

    Steps: take the rows of ``filtered_reads_df`` for *cond*, apply the
    strand/type/chromosome filters, keep reads with at least one nucleosome
    centre inside ±PLOT_WINDOW, sample at most SCATTER_READS_N of them
    (``random_state=0``), then greedily assign each read to the first raster
    row whose last read ended at least MIN_GAP_BP before this one starts.
    The 1-based row number is returned in a new ``row_id`` column.  If the
    filters leave nothing, the empty frame is returned without ``row_id``.
    """
    is_cond = filtered_reads_df["condition"] == cond
    reads = filtered_reads_df.loc[is_cond].reset_index(drop=True)
    n_before = len(reads)
    reads = _apply_type_chr_filters(reads)
    if DEBUG:
        report = "\n".join([
            f"[INFO] {cond}: {n_before:,} reads total",
            f"[INFO] {cond}: {len(reads):,} after type/chr filters"
            f" (TYPE_TO_PLOT={TYPE_TO_PLOT}, CHR_TYPE_FILTER={CHR_TYPE_FILTER})",
        ])
        print(report, file=sys.stderr)
    if reads.empty:
        return reads

    # Keep reads with ≥1 nuc in window
    def _has_nuc_in_window(centres):
        return any((-PLOT_WINDOW <= c <= PLOT_WINDOW) for c in centres)

    reads = reads[reads.nuc_centers.apply(_has_nuc_in_window)]
    if DEBUG:
        print(f"[INFO] {cond}: {len(reads):,} reads inside window", file=sys.stderr)
    n_keep = min(SCATTER_READS_N, len(reads))
    reads = reads.sample(n_keep, random_state=0).reset_index(drop=True)

    # Greedy row-packing by extents
    row_last_end = []   # right-most rel_pos currently occupying each row
    assigned = []       # 1-based row per read, in frame order
    for read in reads.itertuples(index=False):
        left, right = min(read.rel_pos), max(read.rel_pos)
        slot = next(
            (k for k, prev_end in enumerate(row_last_end)
             if left > prev_end + (MIN_GAP_BP - 1)),
            None,
        )
        if slot is None:
            row_last_end.append(right)
            assigned.append(len(row_last_end))
        else:
            assigned.append(slot + 1)
            row_last_end[slot] = right
    return reads.assign(row_id=assigned)


# ---------- %m6A profile ------------------------------------------ #
def m6a_profile(df_full:pd.DataFrame):
    """Aggregate the per-bin m6A fraction over all reads in *df_full*.

    Every called position (``mod_qual_bin`` equal to 0 or 1) inside
    [-PLOT_WINDOW, PLOT_WINDOW) is counted in its M6A_BIN-wide bin; positions
    with any other value are ignored.

    Returns
    -------
    (bin_centres, frac) : tuple of 1-D arrays
        ``frac`` is methylated / called per bin, i.e. a fraction in [0, 1]
        (not a percentage).  With ``M6A_MOVAVG_WIN > 0`` numerator and
        denominator are box-smoothed over ``2*M6A_MOVAVG_WIN + 1`` bins before
        dividing and bins without coverage are 0; without smoothing, bins
        without coverage are NaN.
    """
    edges = np.arange(-PLOT_WINDOW, PLOT_WINDOW + M6A_BIN, M6A_BIN)
    bin_centres = (edges[:-1] + edges[1:]) / 2
    n_bins = len(edges) - 1
    c1 = np.zeros(n_bins, int)      # methylated calls per bin
    ctot = np.zeros(n_bins, int)    # all calls per bin
    for rel_pos, qual in df_full[['rel_pos', 'mod_qual_bin']].itertuples(index=False):
        x = np.asarray(rel_pos)
        q = np.asarray(qual)
        called = (q == 1) | (q == 0)
        idx = np.digitize(x[called], edges) - 1
        ok = (idx >= 0) & (idx < n_bins)
        idx = idx[ok]
        is_m6a = q[called][ok] == 1
        # A read usually has several positions in the same bin.  `arr[idx] += v`
        # applies only one update per distinct index, so count with bincount
        # to make every position contribute.
        ctot += np.bincount(idx, minlength=n_bins)
        c1 += np.bincount(idx[is_m6a], minlength=n_bins)

    if M6A_MOVAVG_WIN > 0:
        win = 2 * M6A_MOVAVG_WIN + 1
        num = uniform_filter1d(c1.astype(float), size=win, mode="nearest")
        den = uniform_filter1d(ctot.astype(float), size=win, mode="nearest")
        with np.errstate(divide='ignore', invalid='ignore'):
            pct = np.divide(num, den, out=np.zeros_like(num), where=den > 0)
    else:
        with np.errstate(divide='ignore', invalid='ignore'):
            pct = c1 / ctot.astype(float)
    return bin_centres, pct

# ---------- Normalized nuc occupancy ------------------------------- #
def normalized_nuc_occupancy(df_full: pd.DataFrame):
    """Histogram nucleosome centres and scale by the mean bin count.

    All ``nuc_centers`` of *df_full* are pooled and binned on a TRACK_BIN grid
    spanning ±PLOT_WINDOW; each count is divided by the mean count over the
    grid, so 1.0 means "average occupancy".  Returns ``(bin_centres, norm)``;
    both are empty arrays when there are no reads or no centres, and ``norm``
    is all zeros when no centre falls inside the window.
    """
    if not len(df_full):
        return np.array([]), np.array([])
    pooled = np.concatenate(df_full.nuc_centers.values)
    if pooled.size == 0:
        return np.array([]), np.array([])

    edges = np.arange(-PLOT_WINDOW, PLOT_WINDOW + TRACK_BIN, TRACK_BIN)
    per_bin = np.histogram(pooled, bins=edges)[0]
    mean_count = per_bin.mean() if per_bin.size else 0.0
    if mean_count > 0:
        scaled = per_bin / mean_count
    else:
        scaled = np.zeros_like(per_bin, dtype=float)
    midpoints = (edges[:-1] + edges[1:]) / 2
    return midpoints, scaled

# ---------- MNase dyad occupancy track ---------------------------- #
def mnase_dyad_occupancy_track(df_full: pd.DataFrame):
    """Mean MNase dyad occupancy per x-bin for the reads in *df_full*.

    Rows of the notebook-level ``mnase_seq_nuc_occupancy_df`` whose ``read_id``
    appears in *df_full* are selected; their ``mnase_rel_pos`` /
    ``mnase_dyad_occupancy`` values are averaged on a MNASE_PLOT_BIN grid over
    [-PLOT_WINDOW, PLOT_WINDOW).  Bins without data are 0.  Returns
    ``(bin_centres, mean_occupancy)``, or two empty arrays when the MNase
    table is not defined or has no matching reads.
    """
    if "mnase_seq_nuc_occupancy_df" not in globals():
        if DEBUG:
            print("[MNase] mnase_seq_nuc_occupancy_df not found → skipping track.", file=sys.stderr)
        return np.array([]), np.array([])

    panel_reads = set(df_full["read_id"])
    in_panel = mnase_seq_nuc_occupancy_df["read_id"].isin(panel_reads)
    mnase_rows = mnase_seq_nuc_occupancy_df[in_panel]
    if mnase_rows.empty:
        if DEBUG:
            print("[MNase] No MNase entries for reads in this panel.", file=sys.stderr)
        return np.array([]), np.array([])

    edges = np.arange(-PLOT_WINDOW, PLOT_WINDOW + MNASE_PLOT_BIN, MNASE_PLOT_BIN)
    centres = (edges[:-1] + edges[1:]) / 2
    occ_sum = np.zeros_like(centres, dtype=float)
    n_obs = np.zeros_like(centres, dtype=int)
    lower, upper = edges[0], edges[-1]

    per_read = zip(mnase_rows["mnase_rel_pos"], mnase_rows["mnase_dyad_occupancy"])
    for positions, values in per_read:
        positions = np.asarray(positions)
        values = np.asarray(values, float)
        inside = (positions >= lower) & (positions < upper)
        if inside.any():
            slot = np.digitize(positions[inside], edges) - 1
            # add.at accumulates repeated bin indices correctly
            np.add.at(occ_sum, slot, values[inside])
            np.add.at(n_obs, slot, 1)

    with np.errstate(divide='ignore', invalid='ignore'):
        mean_occ = np.divide(occ_sum, n_obs, out=np.zeros_like(occ_sum), where=n_obs > 0)
    return centres, mean_occ

# ─────────────────────────── IND helpers ────────────────────────── #
def _centers_and_cores(nuc_coords):
    """Return sorted centers and matching cores from [(center, core_len, ...), ...]."""
    if not nuc_coords: return np.array([], int), np.array([], int)
    cc, kk = [], []
    for t in nuc_coords:
        try:
            c = int(t[0]); k = int(t[1])
        except Exception:
            continue
        cc.append(c); kk.append(k)
    if not cc: return np.array([], int), np.array([], int)
    cc = np.array(cc, int); kk = np.array(kk, int)
    order = np.argsort(cc)
    return cc[order], kk[order]

def _ind_profile_for_read(
    nuc_coords, x_centers, pos1=None, min_m6a=0, include_edges=True,
    max_ind_bp=None, mode="edges"
):
    """
    Inter-nucleosome distance (IND) of one read, sampled at every x in `x_centers`.

    Always choose left/right by centers bracketing x: the left nucleosome is the
    last one whose center is < x, the right one the first whose center is >= x.
    mode="centers": IND = cR - cL
    any other mode: IND = (cR - kR//2) - (cL + kL//2)   # inner-edge gap ("edges")
    m6A gating uses the same pair: [cL,cR] for centers, [end_left,start_right] for
    edges; with include_edges=False the interval is shrunk by 1 bp on each side.

    Parameters
    ----------
    nuc_coords : per-read nucleosome list, as accepted by `_centers_and_cores`.
    x_centers  : 1-D array of x positions to evaluate.
    pos1       : positions of m6A calls on the read, or None if unavailable.
    min_m6a    : if > 0 and `pos1` is given, a value is kept only when at least
                 this many m6A positions fall inside the gated interval.
    max_ind_bp : if not None, values larger than this are dropped.

    Returns
    -------
    float ndarray shaped like `x_centers`; NaN where x is not bracketed by two
    nucleosomes or the value was dropped by `max_ind_bp` / the m6A gate.

    Notes
    -----
    * A read with no m6A at all (`pos1` given but empty) fails the gate instead
      of bypassing it.
    * When the gated interval is empty (hi < lo) or the gap is negative
      (overlapping cores), the m6A gate is not applied and the value is kept.
    """
    centres, cores = _centers_and_cores(nuc_coords)
    out = np.full_like(x_centers, np.nan, dtype=float)
    n_nucs = centres.size
    if n_nucs < 2:
        return out

    # One vectorised lookup for all x: index of the first centre >= x.
    right = np.searchsorted(centres, x_centers, side="left")
    bracketed = (right > 0) & (right < n_nucs)
    if not bracketed.any():
        return out
    right = right[bracketed]
    left = right - 1

    if mode == "centers":
        span_lo = centres[left]
        span_hi = centres[right]
    else:
        span_lo = centres[left] + cores[left] // 2      # end of left core
        span_hi = centres[right] - cores[right] // 2    # start of right core
    ind = span_hi - span_lo

    shrink = 0 if include_edges else 1
    lo = span_lo + shrink
    hi = span_hi - shrink

    keep = np.ones(ind.shape, dtype=bool)
    if max_ind_bp is not None:
        keep &= ind <= max_ind_bp

    if min_m6a > 0 and pos1 is not None:
        m6a_sorted = np.sort(np.asarray(pos1, int))
        gate_applies = (hi >= lo) & (ind >= 0)
        first = np.searchsorted(m6a_sorted, lo, side="left")
        last = np.searchsorted(m6a_sorted, hi, side="right")
        keep &= ~gate_applies | ((last - first) >= min_m6a)

    out[bracketed] = np.where(keep, ind, np.nan)
    return out

def ind_profile(df_full: pd.DataFrame, step_bp:int=IND_BIN_BP,
                min_m6a:int=MIN_M6A_IN_IND_PROFILE, include_edges:bool=INCLUDE_EDGES_IN_IND,
                max_ind_bp=None, mode:str=IND_MODE):
    """
    Aggregate IND(x) across reads for a condition.

    IND(x) is evaluated per read on a `step_bp` grid spanning ±PLOT_WINDOW via
    `_ind_profile_for_read`; when `min_m6a > 0` the read's m6A positions
    (`rel_pos` where `mod_qual_bin == 1`) are passed along for gating.

    Returns: bin_centers, median_IND, valid_fraction
      median_IND     – NaN-ignoring median over reads (NaN where no read has a value)
      valid_fraction – fraction of reads with a finite IND at each x
    For an empty `df_full` the median is all-NaN and the fraction all-zero.
    """
    # Local import: only needed to silence nanmedian's all-NaN warning below.
    import warnings

    edges = np.arange(-PLOT_WINDOW, PLOT_WINDOW + step_bp, step_bp)
    x_centers = (edges[:-1] + edges[1:]) / 2.0
    if df_full.empty:
        return x_centers, np.full_like(x_centers, np.nan, float), np.zeros_like(x_centers, float)

    ind_matrix = []
    for nuc_coords, rel_pos, mod_bin in df_full[["nuc_coords","rel_pos","mod_qual_bin"]].itertuples(index=False):
        pos1 = None
        if min_m6a > 0 and (rel_pos is not None) and (mod_bin is not None):
            rp = np.asarray(rel_pos); mb = np.asarray(mod_bin)
            if rp.size and mb.size:
                # tolerate length mismatches by truncating to the shorter array
                n = min(rp.size, mb.size)
                pos1 = rp[:n][mb[:n] == 1]
        ind_matrix.append(_ind_profile_for_read(
            nuc_coords, x_centers, pos1=pos1,
            min_m6a=min_m6a, include_edges=include_edges,
            max_ind_bp=max_ind_bp, mode=mode
        ))

    M = np.vstack(ind_matrix) if ind_matrix else np.empty((0, x_centers.size))
    # Columns where no read has a value make nanmedian emit a RuntimeWarning
    # through the `warnings` module; np.errstate does not suppress that.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        med = np.nanmedian(M, axis=0)
    frac = np.isfinite(M).mean(axis=0) if M.size else np.zeros_like(x_centers, float)
    return x_centers, med, frac

# -------------------------------------------------------------------
# Build dataframes for primary and optional secondary condition
# -------------------------------------------------------------------
# cond_specs holds (condition name, raster dataframe) pairs: one entry for
# PLOT_COND, plus one for PLOT_COND2 unless that is None or "".
cond_specs=[(PLOT_COND,   build_plot_df(PLOT_COND))]
if PLOT_COND2 not in (None,""):
    cond_specs.append((PLOT_COND2, build_plot_df(PLOT_COND2)))

# Highest packed raster row per condition (0 when the dataframe is empty);
# used further down to set the raster y-axis ranges.
max_row_A = cond_specs[0][1].row_id.max() if not cond_specs[0][1].empty else 0
max_row_B = cond_specs[1][1].row_id.max() if len(cond_specs)==2 and not cond_specs[1][1].empty else 0

# ----------------------- assemble traces -------------------------- #
# One list of line traces per middle panel; each condition appends to every list.
mid_traces = {"m6a": [], "ind": [], "occ": [], "mnase": []}

# NOTE: the aggregate tracks are computed from ALL filtered reads of the
# condition (df_full), not from the subsampled raster dataframe (_plot_df),
# which is unused inside this loop.
for cond, _plot_df in cond_specs:
    df_full = _apply_type_chr_filters(
        filtered_reads_df[filtered_reads_df.condition == cond]
    )
    if df_full.empty:
        continue
    col = cond_color(cond)

    # % m6A
    x_m6a, y_m6a = m6a_profile(df_full)
    mid_traces["m6a"].append(go.Scatter(x=x_m6a, y=y_m6a, mode="lines",
                                        line=dict(color=col, width=2),
                                        name=f"% m6A ({cond})"))

    # IND profile (median across reads), center-bracketed pairs in both modes
    # frac_valid is returned by ind_profile but not plotted here.
    x_ind, y_ind, frac_valid = ind_profile(
        df_full, step_bp=IND_BIN_BP,
        min_m6a=MIN_M6A_IN_IND_PROFILE,
        include_edges=INCLUDE_EDGES_IN_IND,
        max_ind_bp=MAX_IND_BP, mode=IND_MODE
    )
    mid_traces["ind"].append(go.Scatter(
        x=x_ind, y=y_ind, mode="lines",
        line=dict(color=col, width=2),
        name=f"IND ({'edges' if IND_MODE=='edges' else 'centers'}) ({cond})",
        hovertemplate="x=%{x:.0f} bp<br>IND=%{y:.1f} bp<extra></extra>"
    ))

    # normalized nuc occupancy (skipped when the helper returns empty arrays)
    x_occ, y_occ = normalized_nuc_occupancy(df_full)
    if y_occ.size:
        mid_traces["occ"].append(go.Scatter(x=x_occ, y=y_occ, mode="lines",
                                            line=dict(color=col, width=2),
                                            name=f"Normalized nuc occupancy ({cond})"))

    # MNase dyad occupancy (skipped when the helper returns empty arrays)
    x_mn, y_mn = mnase_dyad_occupancy_track(df_full)
    if y_mn.size:
        mid_traces["mnase"].append(go.Scatter(x=x_mn, y=y_mn, mode="lines",
                                              line=dict(color=col, width=2),
                                              name=f"MNase dyad occupancy ({cond})"))

# -------------------------------------------------------------------
# Figure: raster + middle tracks (m6A, IND, norm occ, MNase)
# -------------------------------------------------------------------
# Layout: row 1 = primary raster, rows 2-5 = aggregate tracks, and (only with
# two conditions) row 6 = secondary raster.  sec_raster_row is defined only
# in the two-condition branch and is only used under the same condition below.
if len(cond_specs) == 2:
    rows = 6  # primary raster, m6A, IND, norm occ, MNase, secondary raster
    row_heights = [0.38, 0.12, 0.12, 0.12, 0.12, 0.38]
    sec_raster_row = 6
else:
    rows = 5  # primary raster, m6A, IND, norm occ, MNase
    row_heights = [0.50, 0.13, 0.13, 0.12, 0.12]

fig_raster = make_subplots(rows=rows, cols=1, shared_xaxes=True,
                           vertical_spacing=0.02, row_heights=row_heights)

# primary raster
for tr in raster_traces(cond_specs[0][1], cond_specs[0][0]):
    fig_raster.add_trace(tr,row=1,col=1)
# y-range padded by half a row on both sides of rows 1..max_row_A
fig_raster.update_yaxes(title="Read #", range=[0.5, max_row_A + 0.5],
                        showticklabels=False, row=1, col=1)

# middle panels
for tr in mid_traces["m6a"]:  fig_raster.add_trace(tr,row=2,col=1)
fig_raster.update_yaxes(title="% m6A", row=2, col=1)

for tr in mid_traces["ind"]:  fig_raster.add_trace(tr,row=3,col=1)
fig_raster.update_yaxes(
    title="Edge gap (bp)" if IND_MODE=="edges" else "Center-to-center (bp)",
    row=3, col=1
)

for tr in mid_traces["occ"]:  fig_raster.add_trace(tr,row=4,col=1)
fig_raster.update_yaxes(title="normalized nuc occupancy", row=4, col=1)

# Row 5 is reserved for MNase even when no trace is available; in that case
# the panel is simply left without traces or axis title.
if mid_traces["mnase"]:
    for tr in mid_traces["mnase"]: fig_raster.add_trace(tr,row=5,col=1)
    fig_raster.update_yaxes(title="MNase dyad occupancy", row=5, col=1)
else:
    if DEBUG: print("[MNase] Track absent in this panel.", file=sys.stderr)

# secondary raster (optional)
if len(cond_specs) == 2:
    for tr in raster_traces(cond_specs[1][1], cond_specs[1][0]):
        fig_raster.add_trace(tr, row=sec_raster_row, col=1)
    fig_raster.update_yaxes(title="Read #", range=[0.5, max_row_B + 0.5],
                            showticklabels=False, row=sec_raster_row, col=1)

# axes + layout
# x title and range are set on the bottom row only.
fig_raster.update_xaxes(title="rel_pos (bp)", range=[-PLOT_WINDOW, PLOT_WINDOW], row=rows, col=1)
# switch off grid lines on every x/y axis present in the layout
for ax in fig_raster.layout:
    if ax.startswith(("xaxis","yaxis")):
        fig_raster.layout[ax].showgrid=False

title = f"{PLOT_COND}" + (f" vs {PLOT_COND2}" if len(cond_specs)==2 else "")
fig_raster.update_layout(template="plotly_white", width=1100,
                         height=1040 if rows==6 else 950,
                         title=title + " | raster + %m6A + IND + normalized nuc occupancy + MNase")

fig_raster.show()

# Static export: PNG at 2x scale plus SVG, both suffixed with the cell's STAMP.
if SAVE_FIGS:
    fig_raster.write_image(OUT_DIR_FIG/f"raster_{STAMP}.png", scale=2)
    fig_raster.write_image(OUT_DIR_FIG/f"raster_{STAMP}.svg")
    if DEBUG: print(f"[SAVE] raster_{STAMP}.png / .svg → {OUT_DIR_FIG}", file=sys.stderr)


In [None]:
# ╔══════════════════════════════════════════════════════════════════╗
# ║  STAND‑ALONE QC CELL – template profile + single‑read deep‑dive  ║
# ╚══════════════════════════════════════════════════════════════════╝
#
#   Generates only:
#     1.  All 1st‑pass templates on a centred profile plot
#     2.  A single‑read diagnostic panel (m6A + weighted Pearson‑r)
#   then shows & (optionally) saves the figures.
# -------------------------------------------------------------------

# ============================== CONFIG ============================== #
# NOTE: these are notebook-level globals.  Several names (PLOT_COND,
# TYPE_TO_PLOT, PLOT_WINDOW, SAVE_FIGS, DEBUG, ...) are also assigned by other
# cells and read at call time by helper functions defined there, so running
# this cell changes what those helpers see.
PLOT_COND      = analysis_cond[0]          # condition to draw from
TYPE_TO_PLOT   = None                      # e.g. "MOTIFS_rex48" or None
MOV_AVG_WINDOW = 10                        # window (positions) of the per-template r smoothing
PLOT_WINDOW    = 1000                      # for single‑read axes
SAVE_FIGS      = True
DEBUG          = True
COMPOSITE_R_SMOOTH = 20   # moving-average window for the composite max-r curve; values ≤ 1 (or 0) disable it

# ── colour controls ------------------------------------------------ #
# One colour per template key of _tpl_stats, assigned by cycling through the list.
TEMPLATE_COLOR_CYCLE = ["#333333", "#555555", "#777777", "#999999"]#["#e5c951", "#d07987", "#a06ab4", "#15819a"]
DINUC_PARENT_CORE    = 298
DINUC_CLR            = "#FF6DB6"          # captured in CORE_COLOR_MAP below; DINUC_CLR is reassigned later in this cell
TEMPLATE_COLOR_MAP = {
    tpl: TEMPLATE_COLOR_CYCLE[i % len(TEMPLATE_COLOR_CYCLE)]
    for i, tpl in enumerate(_tpl_stats.keys())
}
# Same colours keyed by the second element of each template key (shown as "core" in the legend).
CORE_COLOR_MAP = {tpl[1]: clr for tpl, clr in TEMPLATE_COLOR_MAP.items()}
CORE_COLOR_MAP[DINUC_PARENT_CORE] = DINUC_CLR

# ── exports -------------------------------------------------------- #
from pathlib import Path
from datetime import datetime
# Export directory for this cell's figures; created (with parents) if missing.
OUT_DIR_FIG = Path("/Data1/git/meyer-nanopore/scripts/analysis/images_20250707/qc_plots")
OUT_DIR_FIG.mkdir(parents=True, exist_ok=True)
# Timestamp taken once per cell run; used as the filename suffix when saving.
STAMP = datetime.now().strftime("%Y%m%d_%H%M%S")

# ============================== LIBS =============================== #
import numpy as np, pandas as pd, math, sys
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.ndimage import uniform_filter1d

# ====================== HELPERS / UTILITIES ======================== #
# ── colour controls ------------------------------------------------ #
RANGE_COLORS = ["#333333", "#555555", "#777777", "#999999"]  # one grey per core-length class, shortest first


def color_by_core(core_len: int) -> str:
    """Return the grey assigned to a core length (bp).

    Classes: < 100, 100–124, 125–169, and ≥ 170 map to RANGE_COLORS[0..3].
    """
    if core_len < 100:
        slot = 0
    elif core_len <= 124:
        slot = 1
    elif core_len <= 169:
        slot = 2
    else:
        slot = 3
    return RANGE_COLORS[slot]

DINUC_PARENT_CORE = 298        # same value as assigned in this cell's CONFIG section
# Replaces the DINUC_CLR set in CONFIG: 298 falls in the last class of
# color_by_core, so this evaluates to RANGE_COLORS[3] (a grey).
DINUC_CLR          = color_by_core(DINUC_PARENT_CORE)
NAVY               = "#1f4489"

def _hex_to_rgb(h): h=h.lstrip("#"); return tuple(int(h[i:i+2],16) for i in (0,2,4))
def _add_core_shapes(fig, core_list, y0, y1, row, col):
    """Shade every nucleosome core as a translucent rectangle on one subplot.

    Each entry of *core_list* starts with ``(center, core_len)``.  When an
    entry has exactly three fields, the third is used as the "parent" core
    length for colouring; otherwise the core length itself is used.  A parent
    equal to DINUC_PARENT_CORE is drawn in DINUC_CLR, anything else in
    ``color_by_core(parent)``.  Rectangles span ``center ± core_len // 2`` in
    x and ``y0``..``y1`` in y, are filled at 40 % opacity and have no outline.
    The x reference is derived from *col* and the y reference from *row*.
    """
    x_axis_ref = "x" if col <= 1 else f"x{col}"
    y_axis_ref = f"y{row}"
    for entry in core_list:
        centre, core_len = entry[:2]
        parent_len = entry[2] if len(entry) == 3 else core_len
        if parent_len == DINUC_PARENT_CORE:
            hex_colour = DINUC_CLR
        else:
            hex_colour = color_by_core(parent_len)
        red, green, blue = _hex_to_rgb(hex_colour)
        half = core_len // 2
        fig.add_shape(
            type="rect",
            x0=centre - half, x1=centre + half,
            y0=y0, y1=y1,
            xref=x_axis_ref,
            yref=y_axis_ref,
            fillcolor=f"rgba({red},{green},{blue},0.4)",
            line_width=0,
        )


# ============================ DATA SET ============================= #
# Reads of the chosen condition, optionally restricted to one "type".
df_plot = filtered_reads_df.loc[filtered_reads_df["condition"] == PLOT_COND]
if TYPE_TO_PLOT:
    df_plot = df_plot[df_plot["type"] == TYPE_TO_PLOT]
if DEBUG: print(f"[INFO] {len(df_plot):,} reads in subset", file=sys.stderr)
# Reset to a 0..n-1 index so that index labels and positions coincide below.
df_plot = df_plot.reset_index(drop=True)

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. TEMPLATE PROFILE — centred & padded
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Each _tpl_stats value starts with the template array; all templates are
# drawn on a common x-axis as long as the longest template and centred on 0.
max_tpl_len = max(len(v[0]) for v in _tpl_stats.values())
x_tpl       = np.arange(max_tpl_len) - max_tpl_len // 2

fig_tpl = go.Figure()
for tpl_key, (tpl_arr, *_ ) in _tpl_stats.items():
    # Pad shorter templates with NaN on both sides (any odd remainder goes to
    # the right) so they line up with x_tpl; NaN leaves a gap in the line.
    pad_left  = (max_tpl_len - len(tpl_arr)) // 2
    pad_right = max_tpl_len - len(tpl_arr) - pad_left
    fig_tpl.add_trace(
        go.Scatter(
            x=x_tpl,
            y=np.pad(tpl_arr, (pad_left, pad_right), constant_values=np.nan),
            mode="lines",
            line=dict(color=TEMPLATE_COLOR_MAP[tpl_key], width=4),
            name=f"u={tpl_key[0]}, core={tpl_key[1]}, d={tpl_key[2]}",
        )
    )

fig_tpl.update_layout(
    template="plotly_white", width=700, height=350,
    title="1st‑pass templates (centred & padded)",
    xaxis_title="relative position (bp)", yaxis_title="value",
    showlegend=True
)
fig_tpl.update_xaxes(showgrid=False); fig_tpl.update_yaxes(showgrid=False)

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 2. SINGLE‑READ DIAGNOSTIC PANEL
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ────────────────── choose a read with 50‑100 nucleosomes ───────── #
# Pick the first read that is >500 bp long and carries 50‑100 nucleosomes
# (fallback: the first read >500 bp if none meets the nuc‑count filter).
# df_plot has a 0..n-1 index, so the positions found here are also labels.
_is_long = (df_plot["read_length"] > 500).to_numpy(dtype=bool)
_n_nucs  = df_plot["nuc_coords"].map(len).to_numpy(dtype=int)
cand_idx = np.flatnonzero(_is_long & (_n_nucs >= 50) & (_n_nucs <= 100)).tolist()

if cand_idx:
    first_idx = cand_idx[0]            # take the first qualifying read
    if DEBUG:
        print(f"[INFO] picked read {df_plot.loc[first_idx,'read_id']} "
              f"with {len(df_plot.loc[first_idx,'nuc_coords'])} nucleosomes",
              file=sys.stderr)
else:
    _long_pos = np.flatnonzero(_is_long)
    if _long_pos.size == 0:
        # fail with a readable message instead of a bare StopIteration
        raise ValueError(
            f"No read longer than 500 bp for condition {PLOT_COND!r}; nothing to plot."
        )
    first_idx = int(_long_pos[0])
    if DEBUG:
        print("[WARN] no read with 50‑100 nucleosomes found; "
              "using first read >500 bp instead", file=sys.stderr)

rec = df_plot.iloc[first_idx]

x   = np.asarray(rec["rel_pos"], int)
sig = np.asarray(rec["mod_qual_bin"], float)

# Dense per-bp signal vector covering the read's span; NaN where there is no call.
# np.ptp() is used because the ndarray.ptp() method was removed in NumPy 2.0.
vec = np.full(np.ptp(x) + 1, np.nan); vec[x - x.min()] = sig
if PERFORM_FILLING: vec = _fill_met_domains(vec, MET_DOMAIN_WIDTH)
if PERFORM_INTERP:  vec = _mode_interpolate(vec, INTERP_WINDOW)

# Pearson‑r per template
# ── build a composite max‑r curve (smoothing applied per template) ──
max_r_by_pos = {}          # x‑coord -> running max r

for tpl_key, (tpl_arr, tpl_m, tpl_s, half_len) in _tpl_stats.items():
    r_raw = _corr_window(vec, tpl_arr, tpl_m, tpl_s, cond=rec["condition"])
    if r_raw.size == 0:
        continue
    r_s  = uniform_filter1d(r_raw, MOV_AVG_WINDOW, mode="nearest")
    # r_raw[k] is plotted at x.min() + k + half_len
    x_r  = np.arange(x.min(), x.min() + r_raw.size) + half_len
    for x_pos, r_val in zip(x_r, r_s):
        # keep the maximum r seen so far at this position
        if (x_pos not in max_r_by_pos) or (r_val > max_r_by_pos[x_pos]):
            max_r_by_pos[x_pos] = r_val

# convert the dict to sorted arrays for plotting
if max_r_by_pos:
    comp_x, comp_r = map(np.array, zip(*sorted(max_r_by_pos.items())))
else:
    # no template produced a correlation track (e.g. read shorter than every template)
    comp_x, comp_r = np.array([], int), np.array([], float)
    if DEBUG:
        print("[WARN] no template correlation available for this read; "
              "composite r curve is empty", file=sys.stderr)
# → optional moving‑average on the composite max‑r curve
if COMPOSITE_R_SMOOTH and COMPOSITE_R_SMOOTH > 1 and comp_r.size:
    comp_r = uniform_filter1d(comp_r, size=COMPOSITE_R_SMOOTH, mode="nearest")
    
# Two stacked panels sharing the x-axis: row 1 = m6A calls, row 2 = composite r.
fig_single = make_subplots(
    rows=2, cols=1, shared_xaxes=True,
    row_heights=[0.25, 0.75], vertical_spacing=0.02
)

# ── m6A markers & bars ─────────────────────────────────────────── #
mask1 = sig == 1; mask0 = sig == 0
# full-height black bars at every m6A = 1 position
fig_single.add_trace(
    go.Bar(
        x=x[mask1], y=[1] * mask1.sum(),
        base=0, width=2, marker_color="rgba(0,0,0,1)",
        hoverinfo="skip", showlegend=False),
    row=1, col=1
)
# filled circles at y = 1 for m6A = 1 ...
fig_single.add_trace(
    go.Scatter(
        x=x[mask1], y=sig[mask1],
        mode="markers", marker=dict(symbol="circle", size=4, color="black"),
        name="m6A = 1"),
    row=1, col=1
)
# ... and open circles at y = 0 for m6A = 0
fig_single.add_trace(
    go.Scatter(
        x=x[mask0], y=sig[mask0],
        mode="markers",
        marker=dict(symbol="circle-open", size=4, color="black",
                    line=dict(width=0.25)),
        name="m6A = 0"),
    row=1, col=1
)
# shade the called nucleosome cores behind the m6A panel
_add_core_shapes(fig_single, rec["nuc_coords"], -0.05, 1.05, 1, 1)
fig_single.update_yaxes(range=[-0.1, 1.1], title="m6A", row=1, col=1)

# composite (max over templates) correlation curve
fig_single.add_trace(
    go.Scatter(
        x=comp_x, y=comp_r, mode="lines",
        line=dict(color="black", width=2),
        name="max r (composite)"),
    row=2, col=1
)

# same core shading on the correlation panel
_add_core_shapes(fig_single, rec["nuc_coords"], -1.05, 1.05, 2, 1)
fig_single.update_yaxes(range=[-1.05, 1.05], title="weighted r", row=2, col=1)

fig_single.update_layout(
    template="plotly_white", width=1100, height=600,
    title=f"Read {rec['read_id']} ({rec['read_length']} bp)")
for ax in ("xaxis", "xaxis2", "yaxis", "yaxis2"):
    fig_single.layout[ax].showgrid = False

# ============================== SHOW ============================== #
fig_tpl.show()
fig_single.show()

# ============================== SAVE ============================== #
# Each figure is written as PNG (2x scale) and SVG; the single-read file name
# embeds the read_id, and both carry the cell's STAMP suffix.
if SAVE_FIGS:
    for name, fig in [("template", fig_tpl),
                      (f"single_{rec['read_id']}", fig_single)]:
        png_path = OUT_DIR_FIG / f"{name}_{STAMP}.png"
        svg_path = OUT_DIR_FIG / f"{name}_{STAMP}.svg"
        fig.write_image(png_path, scale=2)
        fig.write_image(svg_path)
        if DEBUG:
            print(f"[SAVE] {png_path.name} / {svg_path.name} → {OUT_DIR_FIG}",
                  file=sys.stderr)


In [None]:
# ╔══════════════════════════════════════════════════════════════════╗
# ║  CELL 3 (re‑styled) – QC plots + optional exp_id filtering       ║
# ╚══════════════════════════════════════════════════════════════════╝
from datetime import datetime
from pathlib import Path
import numpy as np, pandas as pd, math, random
import plotly.express as px, plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.ndimage import uniform_filter1d, gaussian_filter1d
from scipy.stats import gaussian_kde
from sklearn.decomposition import PCA

# ========================== USER CONFIG ========================== #
# NOTE: notebook-level globals; several names here (MOV_AVG_WINDOW,
# SCATTER_READS_N, BIN_ENC_BP, PLOT_WINDOW, SAVE_FIGS, STAMP) are also
# assigned by other cells, so the most recently run cell wins.
PLOT_COND_A      = analysis_cond[0]          # primary condition
PLOT_COND_B      = analysis_cond[2]          # secondary condition
MOV_AVG_WINDOW   = 10                        # window (positions) of the r smoothing
SCATTER_READS_N  = 50                        # raster subsample
CORE_HIST_BIN    = 1                         # histogram bin (bp)
BIN_ENC_BP, PLOT_WINDOW = 20, 1000           # encoding-grid bin (bp), plotted half-window (bp)
SAVE_FIGS        = True
# Output directory is created (with parents) as soon as this line runs.
FIG_DIR          = Path("/tmp/nuc_qc_plots"); FIG_DIR.mkdir(exist_ok=True, parents=True)
STAMP            = datetime.now().strftime("%Y%m%d_%H%M%S")

# —— TEMPORARY exclusion of specific exp_id values ———————— #
EXCLUDE_EXP_IDS = [
    # "exp12345", "exp98765"
]
# ------------------------------------------------------------------ #
# Working copy of the reads with the excluded exp_ids removed; the rest of
# this cell reads from df_base rather than filtered_reads_df.
df_base = filtered_reads_df.loc[~filtered_reads_df["exp_id"].isin(EXCLUDE_EXP_IDS)].copy()
if df_base.empty:
    raise ValueError("All rows filtered out by EXCLUDE_EXP_IDS – nothing to plot.")

# ───── 7‑bin discrete core‑length palette ───────────────────────── #
_COLOR_ANCHORS = { 80:"#e5c951", 110:"#d07987", 147:"#a06ab4", 200:"#15819a" }
def _hex_to_rgb(h): h=h.lstrip("#"); return tuple(int(h[i:i+2],16) for i in (0,2,4))
def _rgb_to_hex(rgb): return f"#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}"
def _mid_hex(c0,c1):
    r0,g0,b0=_hex_to_rgb(c0); r1,g1,b1=_hex_to_rgb(c1)
    return _rgb_to_hex(((r0+r1)//2,(g0+g1)//2,(b0+b1)//2))
_anchor_lens=sorted(_COLOR_ANCHORS); _anchor_hex=[_COLOR_ANCHORS[L] for L in _anchor_lens]
_discrete_lens=[]; _discrete_cols=[]
for i,L in enumerate(_anchor_lens):
    _discrete_lens.append(L); _discrete_cols.append(_anchor_hex[i])
    if i<len(_anchor_lens)-1:
        _discrete_lens.append((L+_anchor_lens[i+1])/2)
        _discrete_cols.append(_mid_hex(_anchor_hex[i],_anchor_hex[i+1]))
_boundaries=[(_discrete_lens[i]+_discrete_lens[i+1])/2 for i in range(len(_discrete_lens)-1)]
def _core_color(clen):
    if clen<=_boundaries[0]: return _discrete_cols[0]
    if clen> _boundaries[-1]:return _discrete_cols[-1]
    return _discrete_cols[np.digitize(clen,_boundaries)]

def _add_core_shapes(fig, core_list, y0, y1, row, col):
    """Shade every nucleosome core as a translucent rectangle on one subplot.

    Each entry of *core_list* starts with ``(center, core_len)``.  The
    rectangle spans ``center ± core_len // 2`` in x and ``y0``..``y1`` in y,
    is filled with the palette colour of the core length at 40 % opacity and
    has no outline.  The x reference is derived from *col*, the y reference
    from *row*.
    """
    x_axis_ref = "x" if col <= 1 else f"x{col}"
    y_axis_ref = f"y{row}"
    for entry in core_list:
        centre, core_len = entry[:2]
        red, green, blue = _hex_to_rgb(_core_color(core_len))
        half = core_len // 2
        fig.add_shape(
            type="rect",
            x0=centre - half, x1=centre + half,
            y0=y0, y1=y1,
            xref=x_axis_ref, yref=y_axis_ref,
            fillcolor=f"rgba({red},{green},{blue},0.4)",
            line_width=0,
        )

# ====================== 1) core‑length histogram ================== #
# Pool the core length (second field of every nuc_coords entry) per condition.
core_by_cond={}
for cond,grp in df_base.groupby("condition"):
    lengths=[core for coords in grp["nuc_coords"] for core in (c[1] for c in coords)]
    core_by_cond[cond]=np.array(lengths,int)
# Common bin edges across conditions so the histograms are comparable.
all_lengths=np.concatenate(list(core_by_cond.values()))
min_len,max_len=all_lengths.min(),all_lengths.max()
bins=np.arange(min_len,max_len+CORE_HIST_BIN,CORE_HIST_BIN)

fig_hist=go.Figure(); palette=px.colors.qualitative.Set2+px.colors.qualitative.Set1
for i,(cond,lens) in enumerate(core_by_cond.items()):
    counts,_=np.histogram(lens,bins=bins)
    # NOTE(review): 0.5 is added to every bin count before converting to a
    # percentage, so the bars sum to slightly more than 100 % — confirm this
    # pseudo-count is intended.
    pct=100*(counts.astype(float)+0.5)/counts.sum()
    centers=(bins[:-1]+bins[1:])/2
    fig_hist.add_trace(go.Bar(x=centers,y=pct,width=CORE_HIST_BIN,
                              marker_color=palette[i%len(palette)],
                              opacity=0.3,showlegend=False))
    # KDE overlay; the density is multiplied by bin width * 100 to put it on
    # the same percentage scale as the bars.
    kde=gaussian_kde(lens,bw_method=0.05)
    x_kde=np.linspace(min_len,max_len,500); y_kde=kde(x_kde)*CORE_HIST_BIN*100
    fig_hist.add_trace(go.Scatter(x=x_kde,y=y_kde,mode="lines",
                                  line=dict(color=palette[i%len(palette)],width=2),
                                  name=f"{cond} KDE"))
fig_hist.update_layout(template="plotly_white",barmode="group",
    title="Nucleosome core‑length distribution (% of nucs)",
    xaxis_title="core length (bp)",yaxis_title="percentage of nucleosomes",
    width=1000,height=500); fig_hist.show()

# =========================== 2) single read ======================= #
# Example read: the first read of PLOT_COND_A longer than 3500 bp.
df_cond=df_base.query("condition == @PLOT_COND_A")
# df_cond keeps the index labels of filtered_reads_df, so labels and positions
# generally differ.  Find the row by POSITION and use it with .iloc.
_long_pos=np.flatnonzero(df_cond["read_length"].to_numpy()>3500)
if _long_pos.size==0:
    raise ValueError(f"No read longer than 3500 bp for condition {PLOT_COND_A!r}; nothing to plot.")
i_example=int(_long_pos[0])          # positional index into df_cond
rec=df_cond.iloc[i_example]
x=np.asarray(rec.rel_pos,int); sig=np.asarray(rec.mod_qual_bin,float)
# Dense per-bp signal vector over the read's span; NaN where there is no call.
# np.ptp() is used because the ndarray.ptp() method was removed in NumPy 2.0.
vec=np.full(np.ptp(x)+1,np.nan); vec[x-x.min()]=sig
if PERFORM_FILLING: vec=_fill_met_domains(vec,MET_DOMAIN_WIDTH)
if PERFORM_INTERP:  vec=_mode_interpolate(vec,INTERP_WINDOW)

# Row 1: m6A calls; row 2: one smoothed correlation curve per template.
fig_single=make_subplots(rows=2,cols=1,shared_xaxes=True,row_heights=[0.55,0.45],
                         vertical_spacing=0.03)
mask1=sig==1; mask0=sig==0
fig_single.add_trace(go.Bar(x=x[mask1],y=[1]*mask1.sum(),base=0,width=2,
                            marker_color="rgba(0,0,0,1)",hoverinfo="skip"),row=1,col=1)
fig_single.add_trace(go.Scatter(x=x[mask1],y=sig[mask1],mode="markers",
                    marker=dict(symbol="circle",size=4,color="black"),name="m6A = 1"),row=1,col=1)
fig_single.add_trace(go.Scatter(x=x[mask0],y=sig[mask0],mode="markers",
                    marker=dict(symbol="circle-open",size=4,color="black",
                                line=dict(width=0.25)),name="m6A = 0"),row=1,col=1)
_add_core_shapes(fig_single,rec.nuc_coords,-0.05,1.05,1,1)
fig_single.update_yaxes(range=[-0.1,1.1],title="m6A",row=1,col=1)

for tpl_key,(tpl_arr,tpl_m,tpl_s,half_len) in _tpl_stats.items():
    r_raw=_corr_window(vec,tpl_arr,tpl_m,tpl_s,cond=rec.condition)
    if r_raw.size:
        r_s=uniform_filter1d(r_raw,MOV_AVG_WINDOW,mode="nearest")
        # r_raw[k] is plotted at x.min() + k + half_len
        x_r=np.arange(x.min(),x.min()+r_raw.size)+half_len
        fig_single.add_trace(go.Scatter(x=x_r,y=r_s,mode="lines",
                            line=dict(color=_core_color(tpl_key[1])),
                            name=f"r (core={tpl_key[1]})"),row=2,col=1)
_add_core_shapes(fig_single,rec.nuc_coords,-1.05,1.05,2,1)
fig_single.update_yaxes(title="weighted r",row=2,col=1)
fig_single.update_layout(template="plotly_white",width=900,height=500,
        title=f"{rec.read_id} ({rec.read_length} bp, {PLOT_COND_A})")
for ax in ("xaxis","xaxis2","yaxis","yaxis2"): fig_single.layout[ax].showgrid=False
fig_single.show()

# ========================= 3) raster etc. ========================= #
# grid_edges: BIN_ENC_BP-wide bin edges spanning ±PLOT_WINDOW; n_bins: number
# of bins on that grid; scan: positions every 30 bp across the same window.
grid_edges=np.arange(-PLOT_WINDOW,PLOT_WINDOW+BIN_ENC_BP,BIN_ENC_BP)
n_bins=len(grid_edges)-1; scan=np.arange(-PLOT_WINDOW,PLOT_WINDOW+1,30)

def _one_hot(centres):
    """
    Encode centre positions as a 0/1 occupancy vector over the global grid.

    Parameters
    ----------
    centres : sequence of numbers, ndarray or None
        Positions to bin against the module-level ``grid_edges``.

    Returns
    -------
    ndarray of int, length ``n_bins``
        1 in every bin that contains at least one centre, 0 elsewhere.
        Centres falling outside the grid are ignored.
    """
    # Size-based emptiness test: ``not centres`` is ambiguous (raises) for
    # ndarrays with more than one element.
    if centres is None:
        return np.zeros(n_bins, int)
    arr = np.atleast_1d(np.asarray(centres))
    if arr.size == 0:
        return np.zeros(n_bins, int)

    idx = np.digitize(arr, grid_edges) - 1
    idx = idx[(idx >= 0) & (idx < n_bins)]   # drop out-of-grid positions
    v = np.zeros(n_bins, int)
    v[np.unique(idx)] = 1
    return v

def _prep_condition(cond):
    """
    Build raster traces and summary curves for one condition.

    Reads of ``df_base`` whose ``condition`` equals *cond* are sub-sampled
    (at most ``SCATTER_READS_N``, fixed ``random_state=0``) and ordered by the
    nucleosome centre closest to position 0.

    Returns
    -------
    sub : DataFrame
        The sampled, ordered reads (before the window-overlap filter).
    scat : list of go.Scatter
        Raster traces: black dots where ``mod_qual_bin == 1`` inside
        ±PLOT_WINDOW, plus a translucent bar and a centre dot per entry of
        ``nuc_coords``.
    hist_sm : ndarray
        Centres per sampled read on ``grid_edges``, smoothed over 3 bins.
    even : list of float
        Clustering score ``1 - exp(H)/n_bins`` at every ``scan`` position,
        where H is the Shannon entropy of centre offsets in a ±90 bp window
        split into 30 bp bins (pseudo-count 0.5 per bin).
    """
    df = df_base.query("condition == @cond")
    sub = df.sample(min(SCATTER_READS_N, len(df)), random_state=0)

    def _closest(v):
        # Sort key: (|centre nearest to 0|, its signed value); reads without
        # centres sort last.  len() works for lists and ndarrays alike,
        # whereas ``not v`` raises for multi-element arrays.
        if v is None or len(v) == 0:
            return np.inf, np.inf
        arr = np.asarray(v)
        idx = np.argmin(np.abs(arr))
        return abs(arr[idx]), arr[idx]

    sub = (sub.assign(order_key=sub.nuc_centers.apply(_closest))
              .sort_values("order_key", kind="mergesort")
              .reset_index(drop=True))

    def _overlaps_window(rel):
        # True when some position is >= -PLOT_WINDOW and some is <= +PLOT_WINDOW
        # (a read spanning the whole window therefore also qualifies).
        rel = np.asarray(rel)
        return (rel >= -PLOT_WINDOW).any() and (rel <= PLOT_WINDOW).any()

    mask = sub["rel_pos"].apply(_overlaps_window)
    sub_scatter = sub[mask].reset_index(drop=True)

    scat = []
    all_cen = []
    for rid, row in enumerate(sub_scatter.itertuples(index=False), 1):
        rel = np.asarray(row.rel_pos)
        qual = np.asarray(row.mod_qual_bin)
        m_pos = rel[qual == 1]
        m_pos = m_pos[(m_pos >= -PLOT_WINDOW) & (m_pos <= PLOT_WINDOW)]
        if m_pos.size:
            scat.append(go.Scatter(x=m_pos, y=[rid] * len(m_pos),
                                   mode="markers", marker=dict(size=2, color="black"),
                                   hoverinfo="skip", showlegend=False))
        for cen, core, *_ in row.nuc_coords:
            clr = _core_color(core)
            scat.extend([go.Scatter(x=[cen - core // 2, cen + core // 2], y=[rid, rid],
                                    mode="lines",
                                    line=dict(width=4, color=clr), opacity=0.25,
                                    hoverinfo="skip", showlegend=False),
                         go.Scatter(x=[cen], y=[rid], mode="markers",
                                    marker=dict(size=4, color=clr),
                                    hoverinfo="skip", showlegend=False)])
        all_cen.append(np.asarray(row.nuc_centers))
    all_cen = np.concatenate(all_cen) if all_cen else np.array([])

    # Centres per sampled read (normalised by all sampled reads, not only the
    # ones that overlap the window).
    hist_vals, _ = np.histogram(all_cen, bins=grid_edges)
    hist_pct = hist_vals / len(sub) if len(sub) else np.zeros_like(hist_vals, dtype=float)
    hist_sm = uniform_filter1d(hist_pct, size=3, mode="nearest") if hist_pct.size else hist_pct

    # Symmetric ±90 bp window in 30 bp bins.  The previous edges
    # np.arange(-90, 90, 30) ended at +60, making the window lopsided.
    half_win, step = 90, 30
    even_edges = np.arange(-half_win, half_win + step, step)   # hoisted: loop-invariant
    even = []
    for c in scan:
        counts, _ = np.histogram(all_cen - c, bins=even_edges)
        counts = counts + 0.5                 # pseudo-count keeps log() finite
        p = counts / counts.sum()
        H = -(p * np.log(p)).sum()
        even.append(1 - math.exp(H) / len(counts))
    return sub, scat, hist_sm, even

sub_A, scat_A, hist_A, even_A = _prep_condition(PLOT_COND_A)
sub_B, scat_B, hist_B, even_B = _prep_condition(PLOT_COND_B)
bin_centres = (grid_edges[:-1] + grid_edges[1:]) / 2

fig_raster = make_subplots(
    rows=4,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.02,
    row_heights=[0.45, 0.25, 0.15, 0.45],
)

# Read rasters: condition A in row 1, condition B in row 4.
for _row, _traces, _label in ((1, scat_A, PLOT_COND_A),
                              (4, scat_B, PLOT_COND_B)):
    for tr in _traces:
        fig_raster.add_trace(tr, _row, 1)
    fig_raster.update_yaxes(title=f"{_label} read #", showticklabels=False,
                            row=_row, col=1)

# Summary curves: row 2 = cores per read, row 3 = clustering score.
# Condition A is drawn solid, condition B dotted.
for _row, _width, _suffix, _ytitle, _xs, _pair in (
    (2, 3, "cores/read", "# cores / read", bin_centres, (hist_A, hist_B)),
    (3, 2, "clustering", "clustering", scan, (even_A, even_B)),
):
    for _label, _ys, _dash in ((PLOT_COND_A, _pair[0], None),
                               (PLOT_COND_B, _pair[1], "dot")):
        _line = dict(color="black", width=_width)
        if _dash is not None:
            _line["dash"] = _dash
        fig_raster.add_trace(
            go.Scatter(x=_xs, y=_ys, mode="lines", line=_line,
                       name=f"{_label} {_suffix}"),
            row=_row,
            col=1,
        )
    fig_raster.update_yaxes(title=_ytitle, row=_row, col=1)

fig_raster.update_xaxes(title="rel_pos (bp)",
                        range=[-PLOT_WINDOW, PLOT_WINDOW], row=3, col=1)
# Switch off grid lines on all eight axes.
for ax in ("xaxis", "xaxis2", "xaxis3", "xaxis4",
           "yaxis", "yaxis2", "yaxis3", "yaxis4"):
    fig_raster.layout[ax].showgrid = False
fig_raster.update_layout(
    template="plotly_white",
    width=1050,
    height=1200,
    title=f"{PLOT_COND_A} vs {PLOT_COND_B} · raster / cores/read / clustering",
    legend=dict(yanchor="bottom", y=-0.05, xanchor="left", x=0.01),
)
fig_raster.show()

# ============================== SAVE ============================== #
if SAVE_FIGS:
    # Write each figure as a 2x-scaled PNG named <name>_<STAMP>.png in FIG_DIR.
    # `rec` is read via attribute access, matching every other use of it in
    # this cell (rec.read_id, rec.nuc_coords, ...); item access would fail if
    # `rec` is a namedtuple rather than a pandas Series.
    for name, fig in [("core_len_hist", fig_hist),
                      (f"single_{rec.read_id}", fig_single),
                      ("raster_100reads", fig_raster)]:
        fig.write_image(FIG_DIR / f"{name}_{STAMP}.png", scale=2)
        print(f"[SAVE] {name}_{STAMP}.png → {FIG_DIR}")


In [None]:
################################################################################
#  Integrated nucleosome analytics – EVENNESS-ONLY edition                     #
#  (run immediately after “CELL 1”)                                            #
################################################################################
import numpy as np, pandas as pd, plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.ndimage import uniform_filter1d, gaussian_filter1d
from scipy.signal import find_peaks   # for dynamic-seed helper
from pathlib import Path
from datetime import datetime
import math, colorsys                  # only used later

# ───────────────────────── USER CONFIG ───────────────────────── #
# Conditions compared below: A and B are plotted against each other, C is
# only passed to the balancing helper.  Note the index order 0, 2, 1 into
# `analysis_cond` (defined in an earlier cell).
COND_A, COND_B, COND_C = analysis_cond[0], analysis_cond[2], analysis_cond[1]
# Optional experiment-level override.  When non-empty (needs ≥2 entries) the
# filtering step below keeps only these exp_ids, relabels `condition` with the
# exp_id and replaces COND_A / COND_B (/ COND_C) with the listed ids.
EXP_ID_SUBSET = [
    # "BM_05_30_24_SMACseq_R10_rep1",
    # "AG-22_11_30_23"
]

print(f"Conditions: {COND_A} vs {COND_B}   (balanced with {COND_C})")
# 1) READ-FILTER / IQR-PLOTS
# FILTER_MODE as interpreted by motif_pass():
#   "all"   → no motif filtering (every read passes)
#   "under" → every one of the MOTIF_FILTER_N motifs closest to 0 must have a
#             nucleosome centre closer than THRESH_DIST
#   other   → every such motif must be farther than THRESH_DIST from all
#             centres AND have a modified call within THRESH_DIST
FILTER_MODE      = "all"
THRESH_DIST      = 88       # distance threshold used by motif_pass()
MOTIF_FILTER_N   = 1        # number of motifs checked; 0 disables the filter
MAX_READS        = 100000   # NOTE(review): not referenced in the visible part of this cell
READ_SORT_KEY    = "type"   # NOTE(review): not referenced in the visible part of this cell

# 2) CONSENSUS / ENTROPY
CONS_START_POS   = (-0, 0)     # (left seed, right seed) for _build_ladder; -0 == 0, so both are 0
NRL              = 10          # step between ladder rungs in _build_ladder
CONS_WINDOW_BP   = 90          # half-width of the window analysed around each rung
entropy_bin      = 10          # passed as bin_size to collect_kl_bootstrap_grouped
SMOOTH_COUNTS_HALF_BINS = 0    # k passed to centred_ma for the per-rung histogram (0 = no smoothing)
ENT_SMOOTH_WINDOW= 1           # when > 1, passed as k to centred_ma (kernel width 2k+1) for the quartile curves
ENT_IQR_LINE_W   = 1           # presumably line width of the IQR traces — confirm where used
ENT_MED_LINE_W   = 4           # presumably line width of the median trace — confirm where used
BOOTSTRAP_SPLITS = 1           # number of boot_idx groups per (condition, type, bed_start)
# ─────────────────────── BALANCING STRATEGY ───────────────────── #
# Normalisation strategy for comparing COND_A vs COND_B (and COND_C for balancing)
#   "centres"        → down-sample so each (type, bed_start) contributes
#                       the same *number of nucleosome centres*.
#   "scaled"         → keep all reads and scale per-rung histograms so the
#                       effective centre counts match (no read loss).
#   "equal_read_len" → down-sample so each (type, bed_start) contributes the
#                       same in-window total of `rel_pos` entries.
READ_BALANCE_MODE = "scaled"         # choose: "centres" | "scaled" | "equal_read_len"

# ╔══════════════  SYMMETRY / MIRRORING  ══════════════╗
SAVE_PER_TYPE = False     # ← toggle ON / OFF (not referenced in the visible part of this cell)
CONSIDER_BED_START = True # NOTE(review): not referenced in the visible part of this cell


MIRROR_ABS_DISTANCE = False   # True → pool –x & +x into |distance|
                              # False → keep signed positions
# ╚════════════════════════════════════════════════════╝

# ───────── MOTIF LINES (optional) ───────── #
SHOW_MOTIF_LINES = False          # collect motif start positions for dashed grey guides
MOTIF_LINE_STYLE = dict(color="grey", width=1, dash="dash")
SHOW_TSS_Q4_LINES  = False                     # green guides; the collector keeps TSS_q4 and TSS_q3 entries
TSS_Q4_LINE_STYLE  = dict(color="green", width=1, dash="dash")
SHOW_INDIVIDUAL_LINES = False    # if True, plot the thin per-replicate traces

# ─────────────────────────────────────────── #
# 3) GLOBAL & EXPORTS
PLOT_WINDOW      = 5000       # positions outside ±PLOT_WINDOW are ignored by the helpers below
SHOW_DEBUG       = True       # gate for dbg()
CLR_A , CLR_B    = "#b12537", "#4974a5"   # trace colours for COND_A / COND_B
STAMP = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")   # run timestamp string
OUT_DIR_NEW      = Path("/Data1/git/meyer-nanopore/scripts/analysis"
                        "/images_20250710/SDC2_N2_univnucX/bg2")
OUT_DIR_NEW.mkdir(parents=True, exist_ok=True)   # side effect: creates the output directory
rng              = np.random.default_rng(seed=43)   # NOTE(review): not used in the visible part of this cell
################################################################################
#                               Helper utilities                               #
################################################################################
def dbg(msg):
    """Print *msg* only while the module-level SHOW_DEBUG flag is truthy."""
    if not SHOW_DEBUG:
        return
    print(msg)

def centred_ma(arr, k):
    """
    Centred moving average with a window of ``2*k + 1`` samples.

    Parameters
    ----------
    arr : array-like (1-D)
        Values to smooth; lists and tuples are accepted as well as ndarrays.
    k : int
        Half-window in samples. ``k == 0`` returns an unsmoothed float copy.

    Returns
    -------
    ndarray of float, same length as *arr*
        Edges are handled by reflecting the signal, so no samples are lost.

    Raises
    ------
    ValueError
        If *k* is negative.
    """
    values = np.array(arr, dtype=float)   # always a fresh float copy
    if k == 0:
        return values
    if k < 0:
        raise ValueError("k must be >= 0")
    if values.size == 0:
        # np.pad cannot reflect an empty axis; nothing to smooth anyway.
        return values
    padded = np.pad(values, k, mode="reflect")
    kernel = np.ones(2 * k + 1) / (2 * k + 1)
    # "valid" on the padded signal yields exactly len(arr) samples.
    return np.convolve(padded, kernel, mode="valid")

# ---------------- motif / read filter (unchanged from previous cell) ---------
def motif_pass(row):
    """
    Row-wise read filter driven by FILTER_MODE, MOTIF_FILTER_N and THRESH_DIST.

    * FILTER_MODE == "all" or MOTIF_FILTER_N == 0: every row passes.
    * Rows without motifs or without nucleosome centres fail.
    * "under": pass when each of the MOTIF_FILTER_N motifs closest to 0 has a
      centre closer than THRESH_DIST.
    * any other mode: each such motif must be farther than THRESH_DIST from
      every centre, and at least one of them must have a modified call
      (``mod_qual_bin == 1``) within THRESH_DIST.
    """
    filtering_disabled = FILTER_MODE == "all" or MOTIF_FILTER_N == 0
    if filtering_disabled:
        return True
    if not row.motif_rel_start or not row.nuc_centers:
        return False

    # Motifs nearest to position 0 come first.
    nearest_motifs = sorted(row.motif_rel_start, key=abs)[:MOTIF_FILTER_N]
    centre_arr = np.asarray(row.nuc_centers)
    gap_to_centre = [np.min(np.abs(centre_arr - motif)) for motif in nearest_motifs]

    if FILTER_MODE == "under":
        return all(gap < THRESH_DIST for gap in gap_to_centre)

    for gap in gap_to_centre:
        if not gap > THRESH_DIST:
            return False

    positions = np.asarray(row.rel_pos)
    calls = np.asarray(row.mod_qual_bin)
    modified = positions[calls == 1]
    if modified.size == 0:
        return False
    return any(np.min(np.abs(modified - motif)) <= THRESH_DIST
               for motif in nearest_motifs)

# -----------------------------------------------------------------------------


################################################################################
#                      0)  filter & basic subsample                            #
################################################################################

# ─────────────── OPTIONAL exp_id OVERRIDE ────────────────
# ------------------------------------------------------------
# Build a temporary view called `work_df` and leave the master
# `filtered_reads_df` untouched.
# ------------------------------------------------------------
# Filter first, then take ONE explicit copy: this avoids duplicating the whole
# master frame and guarantees `work_df` is an independent object, so the
# column assignments below neither touch `filtered_reads_df` nor trigger
# pandas chained-assignment warnings.
work_df = filtered_reads_df.loc[
    filtered_reads_df["type"].isin(["TSS_q3"])
].copy()
# work_df["type"] = "strong_rex"

# work_df = work_df.loc[
#     work_df["bed_start"] == 11821084
# ]

if EXP_ID_SUBSET:                      # activate only when you list exp_ids
    # Validate before doing any work.
    if len(EXP_ID_SUBSET) < 2:
        raise ValueError("EXP_ID_SUBSET needs ≥2 exp_ids")
    dbg(f"[exp_id‑filter] keeping {len(EXP_ID_SUBSET)} experiments")

    # Pool selected experiments by renaming their exp_id to the id of the
    # experiment they are merged into (source → target).
    work_df["exp_id"] = work_df["exp_id"].replace({
        "N2_biorep2_fiber_old_R10_04_2025_D": "AG-22_11_30_23",
        "N2_biorep2_fiber_old_R10_04_2025_B": "AG-22_11_30_23",
        "BS_10_17_24_SMACseq_R10_rep2": "BM_05_30_24_SMACseq_R10_rep1",
    })

    # 1) keep only the requested experiments (explicit copy → safe to mutate)
    work_df = work_df.loc[
        work_df["exp_id"].isin(EXP_ID_SUBSET)
    ].copy()

    # 2) from here on the grouping column holds the experiment id
    work_df["condition"] = work_df["exp_id"]

    # 3) redefine labels for plots / masks
    COND_A, COND_B = EXP_ID_SUBSET[:2]
    COND_C = EXP_ID_SUBSET[2] if len(EXP_ID_SUBSET) > 2 else "BAL"

    dbg(f"[exp_id‑filter] A={COND_A}, B={COND_B}, C={COND_C}")
# ──────────────────────────────────────────────────────────

# ---------- collect unique motif start positions ----------
if SHOW_MOTIF_LINES:
    # Flatten every motif_rel_start entry (iterable or scalar) into one list.
    _raw = []
    for tup in work_df["motif_rel_start"].dropna():
        if hasattr(tup, "__iter__"):
            _raw.extend(tup)
        else:
            _raw.append(tup)

    if MIRROR_ABS_DISTANCE:
        motif_positions = sorted({abs(int(p)) for p in _raw})
    else:
        motif_positions = sorted({int(p) for p in _raw})
else:
    motif_positions = []

# ---------- collect unique TSS start positions (TSS_q4 and TSS_q3) ----------
if SHOW_TSS_Q4_LINES:
    _tss_raw = []
    _tss_frame = work_df[["tss_rel_start", "tss_attributes"]].dropna()
    for rels, attrs in _tss_frame.itertuples(index=False):
        # attr[2] is the TSS type label; both quartile labels are kept even
        # though the output variable is named after q4.
        for pos, attr in zip(rels, attrs):
            if attr[2] in ("TSS_q4", "TSS_q3"):
                _tss_raw.append(pos)

    if MIRROR_ABS_DISTANCE:
        tss_q4_positions = sorted({abs(int(p)) for p in _tss_raw})
    else:
        tss_q4_positions = sorted({int(p) for p in _tss_raw})
else:
    tss_q4_positions = []
# ------------------------------------------------------------
# ----------------------------------------------------------
def add_motif_guides(fig, *, positions, style, yref="paper"):
    """Add one vertical line shape to *fig* per entry of *positions*.

    Each line spans y = 0..1 in the *yref* coordinate system, uses *style*
    as its ``line`` property and is placed on the "below" layer so data
    traces stay on top.
    """
    shared_props = dict(
        type="line",
        y0=0,
        y1=1,
        xref="x",
        yref=yref,
        line=style,
        layer="below",
    )
    for x_pos in positions:
        fig.add_shape(x0=x_pos, x1=x_pos, **shared_props)


# Evaluate the row-wise motif filter ONCE; it does not depend on the
# condition, so re-running it per condition would triple the cost.
_motif_ok = work_df.apply(motif_pass, axis=1)

maskA   = work_df["condition"] == COND_A
maskB   = work_df["condition"] == COND_B
maskC   = work_df["condition"] == COND_C
dfA_all = work_df[maskA & _motif_ok]
dfB_all = work_df[maskB & _motif_ok]
dfC_all = work_df[maskC & _motif_ok]   # third condition: used for balancing only
dbg(f"{COND_C}: {len(dfC_all)} reads (balancing only)")
dbg(f"{COND_A}: {len(dfA_all)} reads  | {COND_B}: {len(dfB_all)} reads")

# ─────────────── BALANCE NUCLEOSOME CENTRES ────────────── #
# ──────────── BALANCE helper  (handles 2‑ OR 3‑condition input) ────────────
def match_total_readlen_three(dfA, dfB, dfC=None, *, mode="length", base_seed=43, verbose=False):
    """
    Down-sample reads so that, within each (type, bed_start) group, every
    condition contributes at most the same in-window total of a per-read metric.

    Parameters
    ----------
    dfA, dfB : DataFrame
        Reads of the two compared conditions. Need columns ``type``,
        ``bed_start`` and ``rel_pos`` (mode="length") or ``nuc_centers``
        (mode="centres"). Index labels must be unique within each frame.
    dfC : DataFrame or None
        Optional third condition; ignored when None or empty.
    mode : {"length", "centres"}
        Per-read metric, counted only inside ±PLOT_WINDOW:
        "length"  → number of ``rel_pos`` entries in the window,
        "centres" → number of ``nuc_centers`` entries in the window.
    base_seed : int
        Mixed into the per-group shuffling seed.
    verbose : bool
        Print one line per balanced group.

    Returns
    -------
    (dfA_bal, dfB_bal, dfC_bal) : tuple of DataFrame
        Selected rows with a fresh RangeIndex. Without a usable third
        condition ``dfC_bal`` is an empty frame with dfA's columns.

    Raises
    ------
    ValueError
        If *mode* is not one of the two supported values.

    Notes
    -----
    * The per-group target is the smallest total among the conditions; groups
      missing from any condition (target 0) are dropped entirely.
    * Selection is greedy over a shuffled read order: a read is taken only if
      it does not push the running total above the target, so a condition may
      end slightly below it.
    * The shuffle seed is a CRC32 of (condition label, type, bed_start,
      base_seed). Builtin ``hash()`` is salted per interpreter process for
      strings, so it cannot give a selection that is reproducible between
      kernel restarts.
    """
    import zlib  # function-scope import keeps this cell's import block untouched

    if mode not in {"length", "centres"}:
        raise ValueError("mode must be 'length' or 'centres'")

    use_third = dfC is not None and not dfC.empty
    if not use_third:
        dfC = pd.DataFrame(columns=dfA.columns)

    metric_col = "rel_pos" if mode == "length" else "nuc_centers"
    group_keys = ["type", "bed_start"]

    def _metric_per_read(df):
        # Number of entries of `metric_col` that fall inside ±PLOT_WINDOW.
        out = np.zeros(len(df), dtype=int)
        for i, values in enumerate(df[metric_col]):
            arr = np.asarray(values)
            out[i] = np.sum((arr >= -PLOT_WINDOW) & (arr <= PLOT_WINDOW))
        return out

    def _stable_seed(label, typ, bed):
        # Process-independent 32-bit seed.
        token = f"{label}|{typ}|{bed}|{base_seed}".encode("utf-8")
        return zlib.crc32(token)

    def _select(df, metrics, target, seed):
        # Greedy fill up to `target` over a seeded random read order.
        order = np.random.default_rng(seed).permutation(len(df))
        sel, acc = [], 0
        for idx in order:
            m = metrics[idx]
            if m == 0:              # read contributes nothing inside the window
                continue
            if acc + m > target:    # would overshoot → try the next read
                continue
            sel.append(df.index[idx])
            acc += m
            if acc >= target:
                break
        return sel

    frames = {"A": dfA, "B": dfB}
    if use_third:
        frames["C"] = dfC
    # One groupby per frame instead of a boolean scan per (group, frame).
    grouped = {label: dict(iter(df.groupby(group_keys)))
               for label, df in frames.items()}
    keep = {label: set() for label in frames}

    for key in grouped["A"]:
        # A group absent from any condition has target 0 → nothing to balance.
        if any(key not in grouped[label] for label in frames):
            continue
        typ, bed = key

        per_label = {}
        for label in frames:
            grp = grouped[label][key]
            per_label[label] = (grp, _metric_per_read(grp))   # metrics computed once

        target = min(metrics.sum() for _, metrics in per_label.values())
        if target == 0:
            continue  # no in-window contribution in at least one condition

        for label, (grp, metrics) in per_label.items():
            keep[label].update(
                _select(grp, metrics, target, _stable_seed(label, typ, bed))
            )

        if verbose:
            who = "A-B-C" if use_third else "A-B"
            unit = "centres" if mode == "centres" else "bp"
            print(f"[BAL-{mode[:3].upper()}] ({typ}, {bed}) → {target} {unit} each ({who})")

    dfA_bal = dfA.loc[sorted(keep["A"])].reset_index(drop=True)
    dfB_bal = dfB.loc[sorted(keep["B"])].reset_index(drop=True)
    dfC_bal = dfC.loc[sorted(keep["C"])].reset_index(drop=True) if use_third else dfC

    return dfA_bal, dfB_bal, dfC_bal

# ------ choose balancing strategy --------------------------------
# Every branch binds all three *_all_bal names, so later code never hits a
# NameError (or a stale value left over from an earlier notebook run) when
# the mode is switched.
if READ_BALANCE_MODE == "centres":
    dbg("[MODE] down‑sample by *in‑window nucleosome centres* per type/bed_start")
    dfA_all_bal, dfB_all_bal, dfC_all_bal = match_total_readlen_three(
        dfA_all, dfB_all, dfC_all, mode="centres"
    )

elif READ_BALANCE_MODE == "equal_read_len":
    dbg("[MODE] down‑sample by *in‑window read length* per type/bed_start")
    dfA_all_bal, dfB_all_bal, dfC_all_bal = match_total_readlen_three(
        dfA_all, dfB_all, dfC_all, mode="length"
    )

elif READ_BALANCE_MODE == "scaled":
    dbg("[MODE] keep all reads, use per‑rung histogram scaling (no down‑sampling)")
    # No read is dropped in this mode; the frames pass through unchanged.
    dfA_all_bal, dfB_all_bal, dfC_all_bal = dfA_all, dfB_all, dfC_all
else:
    raise ValueError(f"Unknown READ_BALANCE_MODE: {READ_BALANCE_MODE}")

# ---------- quick debug: in‑window totals ----------
def _tot_centres(df):
    """Count the ``nuc_centers`` entries of *df* that lie within ±PLOT_WINDOW."""
    total = 0
    for centres in df.nuc_centers:
        arr = np.asarray(centres)
        inside = (arr >= -PLOT_WINDOW) & (arr <= PLOT_WINDOW)
        total += np.sum(inside)
    return total

def _tot_len(df):
    """Count the ``rel_pos`` entries of *df* that lie within ±PLOT_WINDOW."""
    total = 0
    for read in df.itertuples(index=False):
        positions = np.asarray(read.rel_pos)
        inside = (positions >= -PLOT_WINDOW) & (positions <= PLOT_WINDOW)
        total += np.sum(inside)
    return total

dbg("── In‑window totals (after balancing) ──")
# Report the metric matching the balancing mode: centres for "centres",
# in-window rel_pos counts for every other mode.
if READ_BALANCE_MODE == "centres":
    _total_fn, _unit = _tot_centres, "centres"
else:
    _total_fn, _unit = _tot_len, "bp read length"
dbg(f"{COND_A}: {_total_fn(dfA_all_bal)} {_unit}   |   "
    f"{COND_B}: {_total_fn(dfB_all_bal)} {_unit}")
dbg("────────────────────────────────────────")
# ------------------------------------------------------

###############################################################################
#  PRE-COMPUTE BOOTSTRAP GROUPS  (deterministic per cond×type×bed_start)      #
###############################################################################
def add_bootstrap_idx(df, splits, base_seed=1234):
    """
    Return a *copy* of df with an added 'boot_idx' column ∈ {0 … splits-1}.

    Rows of every (condition, type, bed_start) group are shuffled with a
    group-specific seed and dealt into *splits* near-equal chunks; chunk j
    receives ``boot_idx == j``.

    Parameters
    ----------
    df : DataFrame
        Needs columns ``condition``, ``type`` and ``bed_start``. Any index is
        accepted (rows are addressed by position, not by label).
    splits : int
        Number of bootstrap groups per slice.
    base_seed : int
        Mixed into every slice seed.

    Raises
    ------
    AssertionError
        If some row could not be assigned (e.g. a missing value in one of the
        grouping columns).

    Notes
    -----
    The slice seed is a CRC32 of the group key, so the assignment depends only
    on (condition, type, bed_start, base_seed) and is identical across
    interpreter sessions. Builtin ``hash()`` is salted per process for strings
    and does not offer that guarantee.
    """
    import zlib  # function-scope import keeps this cell's import block untouched

    df = df.copy()
    boot_col = np.full(len(df), -1, int)

    # `.indices` maps each group key to integer ROW POSITIONS, which is what
    # indexing into `boot_col` needs; `.groups` would give index labels.
    group_positions = df.groupby(["condition", "type", "bed_start"]).indices
    for (cond, typ, bed), pos in group_positions.items():
        slice_seed = zlib.crc32(f"{cond}|{typ}|{bed}|{base_seed}".encode("utf-8"))
        local_rng = np.random.default_rng(slice_seed)

        order = local_rng.permutation(len(pos))
        for j, chunk in enumerate(np.array_split(order, splits)):
            boot_col[pos[chunk]] = j   # assign group j

    # Explicit check (not `assert`) so it also runs under `python -O`.
    if (boot_col < 0).any():
        raise AssertionError("bootstrap assignment failed")
    df["boot_idx"] = boot_col
    return df

# Stack both conditions on a fresh RangeIndex, tag every read with its
# bootstrap group, then split back into per-condition frames.
df_bal_all = add_bootstrap_idx(
    pd.concat([dfA_all_bal, dfB_all_bal], ignore_index=True),
    splits=BOOTSTRAP_SPLITS,
)

_is_cond_a = df_bal_all["condition"] == COND_A
_is_cond_b = df_bal_all["condition"] == COND_B
dfA_all_bal = df_bal_all[_is_cond_a].reset_index(drop=True)
dfB_all_bal = df_bal_all[_is_cond_b].reset_index(drop=True)

################################################################################
#                      1)  consensus ladder set-up                              #
################################################################################
def grouped_reads(df):
    """Yield ``(condition, type, bed_start, sub_frame)`` for every group of *df*.

    Each sub-frame is re-indexed from 0.
    """
    group_cols = ["condition", "type", "bed_start"]
    for group_key, frame in df.groupby(group_cols):
        cond, typ, bed = group_key
        yield cond, typ, bed, frame.reset_index(drop=True)

# fixed seeds → build per-group ladder
def _build_ladder(seed_pair):
    """Return the sorted, de-duplicated rung positions for one seed pair.

    Rungs run leftwards from the first seed down to -PLOT_WINDOW and
    rightwards from the second seed up to +PLOT_WINDOW, in steps of NRL.
    """
    left_seed, right_seed = seed_pair
    rungs = set()

    # leftward walk
    pos = left_seed
    while pos >= -PLOT_WINDOW:
        rungs.add(pos)
        pos -= NRL

    # rightward walk
    pos = right_seed
    while pos <= PLOT_WINDOW:
        rungs.add(pos)
        pos += NRL

    return sorted(rungs)

# Every (condition, type) pair present in the unbalanced reads gets the same
# fixed seed pair.
GROUP_SEEDS = {}
for _cond, _typ, _bed, _sub in grouped_reads(pd.concat([dfA_all, dfB_all])):
    GROUP_SEEDS[(_cond, _typ)] = CONS_START_POS

# Build one ladder per group and trim all of them to the shortest length.
_full_ladders = {key: _build_ladder(seeds) for key, seeds in GROUP_SEEDS.items()}
shared_len = min(map(len, _full_ladders.values()))
LADDER_GROUP = {key: rungs[:shared_len] for key, rungs in _full_ladders.items()}

# Consensus rung i = mean of rung i across all groups.
consensus = [float(np.mean(column)) for column in zip(*LADDER_GROUP.values())]
cat_labels = [str(round(c, 1)) for c in consensus]
DOT_OFFSET = NRL * 0.1        # used by several plots

# identity map (no adaptive shift)
CONS_REF = {key: {p: p for p in rungs} for key, rungs in LADDER_GROUP.items()}

# ── plotting x-axis: |position| when mirroring, signed positions otherwise ──
if not MIRROR_ABS_DISTANCE:
    PLOT_RUNG_AXIS = consensus
    AXIS_LABEL = "Consensus position (bp)"
else:
    PLOT_RUNG_AXIS = sorted({abs(c) for c in consensus})
    AXIS_LABEL = "Distance from recruitment (bp)"


###############################################################################
#  QUICK STATS – nucleosome counts per entropy window (±CONS_WINDOW_BP)       #
###############################################################################
def _window_counts(df, axis_rungs, win_bp):
    """
    Return {cond: {rung: centre_count}} for every condition in df.
    """
    out = {c: {r: 0 for r in axis_rungs} for c in df["condition"].unique()}
    for cond, sub in df.groupby("condition"):
        centres = np.concatenate(sub.nuc_centers.to_numpy()) if not sub.empty else np.array([])
        for rung in axis_rungs:
            out[cond][rung] += np.sum(np.abs(centres - rung) < win_bp)
    return out

def _summarise(count_dict):
    vals = np.asarray(list(count_dict.values()), int)
    return dict( min    = int(vals.min()),
                 Q1     = float(np.percentile(vals, 25)),
                 median = float(np.median(vals)),
                 mean   = float(vals.mean()),
                 Q3     = float(np.percentile(vals, 75)),
                 max    = int(vals.max()) )

# choose which dataframe represents the “final” inputs to entropy
df_stats_src = pd.concat([dfA_all_bal, dfB_all_bal], ignore_index=True)

if MIRROR_ABS_DISTANCE:
    axis_rungs = sorted({abs(c) for c in PLOT_RUNG_AXIS})
else:
    axis_rungs = list(PLOT_RUNG_AXIS)

win_counts = _window_counts(df_stats_src, axis_rungs, CONS_WINDOW_BP)

# One summary line per compared condition.
dbg(f"\n─── Nucleosomes per window (±{CONS_WINDOW_BP} bp) ───")
for cond in (COND_A, COND_B):
    stats = _summarise(win_counts[cond])
    _quartiles = (f"Q1={stats['Q1']:.1f}, median={stats['median']:.1f}, "
                  f"mean={stats['mean']:.1f}, Q3={stats['Q3']:.1f}")
    dbg(f"{cond}:  n per window  min={stats['min']}, {_quartiles}, max={stats['max']}")


# ################################################################################
# #                      2)  offsets to consensus (fig_iqr)                      #
# ################################################################################
# def collect_offsets_all(df):
#     out={COND_A:{c:[] for c in consensus},
#          COND_B:{c:[] for c in consensus}}
#     for cond,typ,_,sub in grouped_reads(df):
#         ladder=LADDER_GROUP[(cond,typ)]
#         for row in sub.itertuples(index=False):
#             for cen in row.nuc_centers:
#                 idx=np.argmin(np.abs(np.asarray(ladder)-cen))
#                 if idx>=shared_len: continue
#                 ref=CONS_REF[(cond,typ)][ladder[idx]]
#                 d = abs(cen-ref)
#                 if d<=CONS_WINDOW_BP:
#                     out[cond][consensus[idx]].append(d)
#     return out
#
# offsets = collect_offsets_all(pd.concat([dfA_all,dfB_all], ignore_index=True))
#
# fig_iqr = go.Figure()
# for j,(cond,clr) in enumerate(zip([COND_A,COND_B],[CLR_A,CLR_B])):
#     q1,med,q3=[],[],[]
#     for c in consensus:
#         v=offsets[cond][c]
#         q1.append(np.nan if not v else np.percentile(v,25))
#         med.append(np.nan if not v else np.percentile(v,50))
#         q3.append(np.nan if not v else np.percentile(v,75))
#     err_up   = np.asarray(q3)-np.asarray(med)
#     err_down = np.asarray(med)-np.asarray(q1)
#     x = [c-DOT_OFFSET if j==0 else c+DOT_OFFSET for c in consensus]
#     fig_iqr.add_trace(go.Scatter(
#         x=x,y=med,mode="markers",name=cond,
#         marker=dict(size=10,color=clr),
#         error_y=dict(type="data",symmetric=False,
#                      array=err_up,arrayminus=err_down,
#                      thickness=1,color=clr),
#         hovertemplate=("cons=%{customdata[2]}<br>"
#                        "Q1=%{customdata[0]:.1f}  "
#                        "median=%{y:.1f}  "
#                        "Q3=%{customdata[1]:.1f}"),
#         customdata=np.column_stack([q1,q3,consensus])
#     ))
# fig_iqr.add_shape(type="line",x0=0,x1=0,y0=0,y1=1,
#                   xref="x",yref="paper",
#                   line=dict(color="black",dash="dash"))
# fig_iqr.add_annotation(x=0,y=1.05,xref="x",yref="paper",
#                        text="best consensus match motif",
#                        showarrow=False)
# fig_iqr.update_xaxes(type="linear",tickmode="array",
#                      tickvals=consensus,ticktext=cat_labels,
#                      title="Consensus position (bp)")
# fig_iqr.update_layout(template="plotly_white",width=1000,height=500,
#                       xaxis=dict(showgrid=False,zeroline=False),
#                       yaxis=dict(showgrid=False,zeroline=False),
#                       title=(f"Offsets to consensus (NRL={NRL}, "
#                              f"window=±{CONS_WINDOW_BP} bp)"),
#                       yaxis_title="Centre – consensus (bp)")
# fig_iqr.show()
#
# ################################################################################
# #                      3)  σ (dispersion)  – fig_sigma                         #
# ################################################################################
# def collect_sigma_all(df):
#     sig={COND_A:{c:[] for c in consensus},
#          COND_B:{c:[] for c in consensus}}
#     for cond,typ,_,sub in grouped_reads(df):
#         ladder=LADDER_GROUP[(cond,typ)]
#         for i,rung in enumerate(ladder):
#             pool=[]
#             for row in sub.itertuples(index=False):
#                 pool.extend([c for c in row.nuc_centers
#                              if abs(c-rung)<=CONS_WINDOW_BP])
#             if len(pool)>1:
#                 sig[cond][consensus[i]].append(float(np.std(pool,ddof=0)))
#     return sig
#
# sigma_vals = collect_sigma_all(pd.concat([dfA_all,dfB_all], ignore_index=True))
#
# fig_sigma = go.Figure()
# for j,(cond,clr) in enumerate(zip([COND_A,COND_B],[CLR_A,CLR_B])):
#     q1,med,q3=[],[],[]
#     for c in consensus:
#         v=sigma_vals[cond][c]
#         q1.append(np.nan if not v else np.percentile(v,25))
#         med.append(np.nan if not v else np.percentile(v,50))
#         q3.append(np.nan if not v else np.percentile(v,75))
#     err_up   = np.asarray(q3)-np.asarray(med)
#     err_down = np.asarray(med)-np.asarray(q1)
#     x=[c-DOT_OFFSET if j==0 else c+DOT_OFFSET for c in consensus]
#     fig_sigma.add_trace(go.Scatter(
#         x=x,y=med,mode="markers",name=cond,
#         marker=dict(size=10,color=clr),
#         error_y=dict(type="data",symmetric=False,
#                      array=err_up,arrayminus=err_down,
#                      thickness=1,color=clr)))
# fig_sigma.update_xaxes(type="linear",tickmode="array",
#                        tickvals=consensus,ticktext=cat_labels,
#                        title="Consensus position (bp)")
# fig_sigma.update_layout(template="plotly_white",width=1000,height=500,
#                         xaxis=dict(showgrid=False,zeroline=False),
#                         yaxis=dict(showgrid=False,zeroline=False),
#                         title=(f"Centre dispersion σ within ±{CONS_WINDOW_BP} bp"),
#                         yaxis_title="σ of centre positions (bp)")
# fig_sigma.show()

################################################################################
#                      4)  Shannon evenness (KL proxy)                         #
################################################################################
# ───────────────── EVENNESS / KL-bootstrap helper ────────────────── #
###############################################################################
#  Clustering bootstrap helper  (+ optional ABS-mirroring)                    #
###############################################################################
def collect_kl_bootstrap_grouped(df_all, *, bin_size=20,
                                 splits=2, return_full=False):
    """
    Aggregate clustering (1−evenness) per rung for each
    (condition, type, bed_start, boot_idx) replicate.

    For every replicate and rung, the nucleosome centres lying strictly
    within ±CONS_WINDOW_BP of the rung are histogrammed in *bin_size* bins,
    optionally smoothed (SMOOTH_COUNTS_HALF_BINS), given a pseudo-count of
    0.5 per bin and turned into ``1 - exp(H) / n_bins`` where H is the
    Shannon entropy of the normalised histogram.

    When READ_BALANCE_MODE == "scaled" the histogram of the condition with
    more centres at a rung is down-weighted by ``n_min / n`` (counts pooled
    per condition over all of *df_all*) so that A and B contribute the same
    *effective* number of centres.  The weight is applied BEFORE the
    pseudo-count: applied after it, the factor would cancel in the
    normalisation and have no effect.  Rungs whose weight is 0 (the other
    condition has no centres there) are skipped.

    Parameters
    ----------
    df_all : DataFrame
        Needs columns ``condition`` (only COND_A / COND_B), ``type``,
        ``bed_start``, ``boot_idx`` and ``nuc_centers``.
    bin_size : int
        Histogram bin width.
    splits : int
        Unused; kept for call compatibility (replicates come from the
        ``boot_idx`` column).
    return_full : bool
        Also return the per-replicate map.

    Returns
    -------
    even_out : dict
        ``{condition: {rung: [cluster, ...]}}``.
    full_map : dict, only if *return_full*
        ``{(condition, type, bed_start, boot_idx): {rung: cluster}}``.
    """
    axis_rungs = (sorted({abs(c) for c in PLOT_RUNG_AXIS})
                  if MIRROR_ABS_DISTANCE else list(PLOT_RUNG_AXIS))

    def _blank(): return {c: [] for c in axis_rungs}
    even_out = {COND_A: _blank(), COND_B: _blank()}
    full_map = {}

    win_bp  = CONS_WINDOW_BP
    edges   = np.arange(-win_bp, win_bp + bin_size, bin_size)
    n_bins  = len(edges) - 1
    use_scaling = READ_BALANCE_MODE == "scaled"

    # ─────────────── 0. per-window weight table ───────────────
    scale_factor = {COND_A: {}, COND_B: {}}
    if use_scaling:
        # raw centre counts per condition & rung
        raw_counts = {COND_A: {r: 0 for r in axis_rungs},
                      COND_B: {r: 0 for r in axis_rungs}}
        for cond, sub in df_all.groupby("condition"):
            centres = np.concatenate(sub.nuc_centers.to_numpy())
            for rung in axis_rungs:
                raw_counts[cond][rung] += np.sum(np.abs(centres - rung) < win_bp)

        for rung in axis_rungs:
            nA, nB = raw_counts[COND_A][rung], raw_counts[COND_B][rung]
            n_min  = min(nA, nB)
            for cond, n in [(COND_A, nA), (COND_B, nB)]:
                scale_factor[cond][rung] = (n_min / n) if n else 0.0
    # ───────────────────────────────────────────────────────────

    # ─────────────── 1. loop over replicates ───────────────
    for (cond, typ, bed, boot_idx), grp in df_all.groupby(
            ["condition", "type", "bed_start", "boot_idx"]):

        centres = np.concatenate(grp.nuc_centers.to_numpy())
        this_rep = {}

        for rung in axis_rungs:
            rel = centres - rung
            rel = rel[np.abs(rel) < win_bp]
            if rel.size == 0:
                continue

            counts, _ = np.histogram(rel, bins=edges)
            counts = centred_ma(counts, SMOOTH_COUNTS_HALF_BINS)

            # -------- weighting toggle --------
            if use_scaling:
                weight = scale_factor[cond][rung]
                if weight == 0:
                    # Nothing to compare against at this rung; an all-zero
                    # histogram would previously produce 0/0 here.
                    continue
                counts = counts * weight
            # ----------------------------------

            counts  = counts + 0.5            # pseudo-count keeps log() finite
            p_vec   = counts / counts.sum()
            H       = -(p_vec[p_vec > 0] * np.log(p_vec[p_vec > 0])).sum()
            cluster = 1.0 - math.exp(H) / n_bins

            out_key = abs(rung) if MIRROR_ABS_DISTANCE else rung
            even_out[cond][out_key].append(cluster)
            this_rep[out_key] = cluster

        full_map[(cond, typ, bed, boot_idx)] = this_rep

    return (even_out, full_map) if return_full else even_out

# ------------------------------------------------------------------
#  Helper: paired Δ = B − A for replicates sharing the same
#  (type, bed_start, boot_idx) key
# ------------------------------------------------------------------
def paired_deltas(full_map_A, full_map_B):
    """Collect per-rung differences (B − A) between matching replicates.

    Both arguments map a replicate key to a ``{rung: value}`` dict.  Only keys
    present in *both* maps are compared, and for each such replicate only the
    rungs of ``PLOT_RUNG_AXIS`` that exist on both sides contribute.

    Returns
    -------
    dict
        ``{rung: [b - a, ...]}`` with an entry (possibly empty) for every rung
        in ``PLOT_RUNG_AXIS``.
    """
    deltas = {rung: [] for rung in PLOT_RUNG_AXIS}
    shared_keys = set(full_map_A) & set(full_map_B)
    for rep_key in shared_keys:
        rep_a, rep_b = full_map_A[rep_key], full_map_B[rep_key]
        for rung in PLOT_RUNG_AXIS:
            # a rung is skipped unless both replicates produced a value for it
            if rung not in rep_a or rung not in rep_b:
                continue
            deltas[rung].append(rep_b[rung] - rep_a[rung])
    return deltas
# ------ helper: bootstrap evenness identical to previous cell ---------------

# Bootstrap clustering values per condition (kl_vals) plus the per-replicate
# map, whose keys are (condition, type, bed_start, boot_idx).
kl_vals, full_map = collect_kl_bootstrap_grouped(
    pd.concat([dfA_all_bal, dfB_all_bal], ignore_index=True),
    bin_size=entropy_bin,
    splits=BOOTSTRAP_SPLITS,
    return_full=True,
)

# Per-condition views of the replicate map.  The leading condition element is
# dropped so that matching A/B replicates share the key (type, bed_start, boot).
full_map_A = {key[1:]: rep for key, rep in full_map.items() if key[0] == COND_A}
full_map_B = {key[1:]: rep for key, rep in full_map.items() if key[0] == COND_B}

fig_kl = go.Figure()

if SHOW_INDIVIDUAL_LINES:
    # one thin, semi-transparent line per (cond, type, bed_start, bootstrap)
    for rep_key, rep in full_map.items():
        line_clr = CLR_A if rep_key[0] == COND_A else CLR_B
        fig_kl.add_trace(go.Scatter(
            x=PLOT_RUNG_AXIS,
            y=[rep.get(r, np.nan) for r in PLOT_RUNG_AXIS],
            mode="lines",
            line=dict(width=0.5, color=line_clr),
            opacity=0.3,
            showlegend=False,
        ))

# Quartile band (dotted Q1/Q3) and median line for each condition
for cond, clr in ((COND_A, CLR_A), (COND_B, CLR_B)):
    q1, med, q3 = [], [], []
    for rung in PLOT_RUNG_AXIS:
        vals = kl_vals[cond][rung]
        if vals:
            lo_q, mid_q, hi_q = np.percentile(vals, [25, 50, 75])
        else:
            # no replicate produced a value at this rung
            lo_q = mid_q = hi_q = np.nan
        q1.append(lo_q)
        med.append(mid_q)
        q3.append(hi_q)

    if ENT_SMOOTH_WINDOW > 1:
        q1, med, q3 = (centred_ma(np.asarray(curve), ENT_SMOOTH_WINDOW)
                       for curve in (q1, med, q3))

    for y_vals, line_kw, legend_kw in (
        (q1,  dict(width=1, dash="dot", color=clr),   dict(showlegend=False)),
        (q3,  dict(width=1, dash="dot", color=clr),   dict(showlegend=False)),
        (med, dict(width=ENT_MED_LINE_W, color=clr),  dict(name=f"{cond} median")),
    ):
        fig_kl.add_trace(go.Scatter(x=PLOT_RUNG_AXIS, y=y_vals, mode="lines",
                                    line=line_kw, **legend_kw))

fig_kl.update_layout(
    template="plotly_white",
    width=1000,
    height=450,
    title=(f"Clustering of centre distribution (±{CONS_WINDOW_BP} bp window, "
           f"{entropy_bin} bp bins; 1 pt = one type × bootstrap)"),
    yaxis_title="Clustering  (0 → uniform, 1 → clustered)",
    xaxis_title=AXIS_LABEL,
)
fig_kl.update_xaxes(showgrid=False, zeroline=False)
fig_kl.update_yaxes(showgrid=False, zeroline=False)

# optional vertical guide lines
if SHOW_MOTIF_LINES:
    add_motif_guides(fig_kl, positions=motif_positions, style=MOTIF_LINE_STYLE)
if SHOW_TSS_Q4_LINES:
    add_motif_guides(fig_kl, positions=tss_q4_positions, style=TSS_Q4_LINE_STYLE)
fig_kl.show()


# ──────────────────────────────────────────────────────────────────
def bootstrap_deltas(even_dict_A, even_dict_B, *, max_rows=None):
    """Index-paired differences (B − A) of two per-rung value lists.

    Parameters
    ----------
    even_dict_A, even_dict_B : mapping
        ``rung → list`` of replicate values, as returned per condition by
        ``collect_kl_bootstrap_grouped``.
    max_rows : int, optional
        Upper bound on the number of pairs used per rung.

    Returns
    -------
    dict
        ``{rung: [b_i - a_i, ...]}`` for every rung in ``PLOT_RUNG_AXIS``.
        The i-th value of B is paired with the i-th value of A; the longer
        list is truncated to the shorter one (and to ``max_rows`` if given).
    """
    deltas = {rung: [] for rung in PLOT_RUNG_AXIS}
    for rung in PLOT_RUNG_AXIS:
        vals_a = np.asarray(even_dict_A[rung], float)
        vals_b = np.asarray(even_dict_B[rung], float)

        n_pairs = min(len(vals_a), len(vals_b))
        if max_rows is not None and max_rows < n_pairs:
            n_pairs = max_rows
        if not n_pairs:
            continue            # nothing to pair at this rung

        deltas[rung].extend(vals_b[:n_pairs] - vals_a[:n_pairs])
    return deltas


# ╔═══════════  ΔClustering (B – A)  pooled across types  ═══════════╗
SAVE_TYPE_DIFF_COMBINED = True          # toggle ON / OFF
#   Δ = clustering(B) – clustering(A); the clustering score is 0 for a
#   uniform centre distribution and approaches 1 for a tightly clustered one.
#   >0  →  B more clustered than A
#   <0  →  A more clustered than B
#    0  →  identical clustering
# ╚═══════════════════════════════════════════════════════════════════╝
if SAVE_TYPE_DIFF_COMBINED:
    # Kept from the original cell; nothing in this block fills or reads it.
    pooled_deltas = {r: [] for r in PLOT_RUNG_AXIS }

    # One Δ per replicate key (type, bed_start, boot) present in both maps.
    delta_vals = paired_deltas(full_map_A, full_map_B)

    # Quartiles of the Δ distribution at each rung (NaN where no pairs exist).
    q1  = [np.nanpercentile(delta_vals[r], 25) if delta_vals[r] else np.nan
           for r in PLOT_RUNG_AXIS]
    med = [np.nanmedian   (delta_vals[r])      if delta_vals[r] else np.nan
           for r in PLOT_RUNG_AXIS]
    q3  = [np.nanpercentile(delta_vals[r], 75) if delta_vals[r] else np.nan
           for r in PLOT_RUNG_AXIS]
    if ENT_SMOOTH_WINDOW > 1:
        q1  = centred_ma(np.asarray(q1),  ENT_SMOOTH_WINDOW)
        med = centred_ma(np.asarray(med), ENT_SMOOTH_WINDOW)
        q3  = centred_ma(np.asarray(q3),  ENT_SMOOTH_WINDOW)

    fig_diff = go.Figure()
    if SHOW_INDIVIDUAL_LINES:
        # — add one thin grey line per paired (type, bed_start, bootstrap) group —
        common = set(full_map_A) & set(full_map_B)
        for key in common:
            a_r = full_map_A[key]
            b_r = full_map_B[key]
            x, y = [], []
            for rung in PLOT_RUNG_AXIS:
                if rung in a_r and rung in b_r:
                    x.append(rung)
                    y.append(b_r[rung] - a_r[rung])
            fig_diff.add_trace(go.Scatter(
                x=x,
                y=y,
                mode="lines",
                line=dict(width=0.5, color="grey"),
                opacity=0.3,
                showlegend=False
            ))
    # dotted Q1 / Q3 band and solid median
    fig_diff.add_trace(go.Scatter(x=PLOT_RUNG_AXIS, y=q1 ,
                                  mode="lines", line=dict(color="black",
                                  dash="dot", width=1), showlegend=False))
    fig_diff.add_trace(go.Scatter(x=PLOT_RUNG_AXIS, y=q3 ,
                                  mode="lines", line=dict(color="black",
                                  dash="dot", width=1), showlegend=False))
    fig_diff.add_trace(go.Scatter(x=PLOT_RUNG_AXIS, y=med,
                                  mode="lines", line=dict(color="black",
                                  width=ENT_MED_LINE_W),
                                  name="median ΔClustering"))
    # zero reference line
    fig_diff.add_shape(type="line",
                       x0=min(PLOT_RUNG_AXIS), x1=max(PLOT_RUNG_AXIS),
                       y0=0, y1=0,
                       line=dict(color="grey", dash="dash"))
    # Axis legend follows the sign of Δ = B − A on a "higher = more clustered"
    # score: positive values mean B is the more clustered condition.
    fig_diff.update_layout(template="plotly_white", width=1000, height=450,
        title=(f"ΔClustering (B – A), paired by type & bootstrap"),
        xaxis_title=AXIS_LABEL,
        yaxis_title="<0 → A more clustered<br>>0 → B more clustered")
    fig_diff.update_xaxes(showgrid=False, zeroline=False)
    fig_diff.update_yaxes(showgrid=False, zeroline=False)
    if SHOW_MOTIF_LINES:
        add_motif_guides(fig_diff, positions=motif_positions,
                         style=MOTIF_LINE_STYLE)
    if SHOW_TSS_Q4_LINES:
        add_motif_guides(fig_diff, positions=tss_q4_positions,
                         style=TSS_Q4_LINE_STYLE)
    fig_diff.show()
# end SAVE_TYPE_DIFF_COMBINED

# ---------- B. per-type Evenness & ΔEvenness (optional) ----------


if SAVE_PER_TYPE:
    # Imported and defined once here rather than on every loop iteration.
    from plotly.subplots import make_subplots

    def _quartiles(e_dict):
        """Return (Q1, median, Q3) curves over PLOT_RUNG_AXIS for one condition.

        ``e_dict`` maps rung → list of clustering values; rungs without values
        yield NaN.  The curves are smoothed when ENT_SMOOTH_WINDOW > 1.
        """
        q1, med, q3 = [], [], []
        for rung in PLOT_RUNG_AXIS:
            vals = e_dict[rung]
            if vals:
                _q1, _m, _q3 = np.percentile(vals, [25, 50, 75])
            else:
                _q1 = _m = _q3 = np.nan
            q1.append(_q1); med.append(_m); q3.append(_q3)
        if ENT_SMOOTH_WINDOW > 1:
            q1  = centred_ma(np.asarray(q1),  ENT_SMOOTH_WINDOW)
            med = centred_ma(np.asarray(med), ENT_SMOOTH_WINDOW)
            q3  = centred_ma(np.asarray(q3),  ENT_SMOOTH_WINDOW)
        return q1, med, q3

    # build list of “groups” to iterate
    if CONSIDER_BED_START:
        # all unique (type, bed_start) pairs present in A or B
        groups = sorted(
            set(tuple(x) for x in pd.concat([dfA_all_bal, dfB_all_bal])[
                ["type","bed_start"]
             ].drop_duplicates().to_numpy()),
            key=lambda x: (x[0], x[1])
        )
    else:
        # just unique types
        groups = sorted(set(dfA_all_bal["type"]) | set(dfB_all_bal["type"]))

    for grp in groups:
        if CONSIDER_BED_START:
            typ, bed = grp
            sub_A = dfA_all_bal[(dfA_all_bal.type == typ) & (dfA_all_bal.bed_start == bed)]
            sub_B = dfB_all_bal[(dfB_all_bal.type == typ) & (dfB_all_bal.bed_start == bed)]
            title_suffix = f"{typ}_bed{bed}"
            # the figure title must name the bed_start too, otherwise all
            # plots of one type are indistinguishable
            title_label = f"type = {typ}, bed_start = {bed}"
        else:
            typ = grp
            sub_A = dfA_all_bal[dfA_all_bal.type == typ]
            sub_B = dfB_all_bal[dfB_all_bal.type == typ]
            title_suffix = typ
            title_label = f"type = {typ}"

        if sub_A.empty and sub_B.empty:
            continue

        sub_df_bal = pd.concat([sub_A, sub_B], ignore_index=True)
        even_vals = collect_kl_bootstrap_grouped(
            sub_df_bal,
            bin_size    = entropy_bin,
            splits      = BOOTSTRAP_SPLITS,
            return_full = False
        )

        q1_A, med_A, q3_A = _quartiles(even_vals[COND_A])
        q1_B, med_B, q3_B = _quartiles(even_vals[COND_B])

        # ---------- build paired-bootstrap Δ lists ----------
        # NOTE(review): bootstrap_deltas pairs the A and B lists by position,
        # not by replicate key; confirm both lists are filled in the same
        # replicate order before interpreting the Δ panel.
        delta_vals = bootstrap_deltas(even_vals[COND_A], even_vals[COND_B])

        # quartiles of the Δ distribution
        diff_q1  = [np.nanpercentile(delta_vals[r], 25) if delta_vals[r] else np.nan
                    for r in PLOT_RUNG_AXIS]
        diff_med = [np.nanmedian   (delta_vals[r])       if delta_vals[r] else np.nan
                    for r in PLOT_RUNG_AXIS]
        diff_q3  = [np.nanpercentile(delta_vals[r], 75) if delta_vals[r] else np.nan
                    for r in PLOT_RUNG_AXIS]

        if ENT_SMOOTH_WINDOW > 1:
            diff_q1  = centred_ma(np.asarray(diff_q1),  ENT_SMOOTH_WINDOW)
            diff_med = centred_ma(np.asarray(diff_med), ENT_SMOOTH_WINDOW)
            diff_q3  = centred_ma(np.asarray(diff_q3),  ENT_SMOOTH_WINDOW)

        # ─────────────── figure layout ───────────────
        fig_t = make_subplots(
            rows=2, cols=1,
            shared_xaxes=True,
            vertical_spacing=0.05,
            row_heights=[0.6, 0.4]
        )

        # ── Row 1 : clustering per condition ──
        for cond, clr, q1, med, q3 in [
            (COND_A, CLR_A, q1_A, med_A, q3_A),
            (COND_B, CLR_B, q1_B, med_B, q3_B)
        ]:
            fig_t.add_trace(
                go.Scatter(x=PLOT_RUNG_AXIS, y=q1,
                           mode="lines", line=dict(color=clr, dash="dot", width=1),
                           showlegend=False),
                row=1, col=1)
            fig_t.add_trace(
                go.Scatter(x=PLOT_RUNG_AXIS, y=q3,
                           mode="lines", line=dict(color=clr, dash="dot", width=1),
                           showlegend=False),
                row=1, col=1)
            fig_t.add_trace(
                go.Scatter(x=PLOT_RUNG_AXIS, y=med,
                           mode="lines", line=dict(color=clr, width=ENT_MED_LINE_W),
                           name=f"{cond} median"),
                row=1, col=1)

        # ── Row 2 : ΔClustering (B – A) ──
        fig_t.add_trace(
            go.Scatter(x=PLOT_RUNG_AXIS, y=diff_q1,
                       mode="lines", line=dict(color="black", dash="dot", width=1),
                       showlegend=False),
            row=2, col=1)
        fig_t.add_trace(
            go.Scatter(x=PLOT_RUNG_AXIS, y=diff_q3,
                       mode="lines", line=dict(color="black", dash="dot", width=1),
                       showlegend=False),
            row=2, col=1)
        fig_t.add_trace(
            go.Scatter(x=PLOT_RUNG_AXIS, y=diff_med,
                       mode="lines", line=dict(color="black", width=ENT_MED_LINE_W),
                       name="median ΔClustering"),
            row=2, col=1)

        # zero reference line
        fig_t.add_shape(type="line",
                        x0=min(PLOT_RUNG_AXIS), x1=max(PLOT_RUNG_AXIS),
                        y0=0, y1=0,
                        line=dict(color="grey", dash="dash"),
                        row=2, col=1)

        # ── cosmetics ──
        # Labels are derived from COND_A / COND_B so they stay correct for any
        # pair of conditions; Δ = B − A, so positive values mean B is the more
        # clustered condition.
        fig_t.update_layout(
            template     = "plotly_white",
            width        = 1000,
            height       = 600,
            title        = (f"Clustering & ΔClustering ({COND_B} – {COND_A}) "
                            f"({title_label})"),
            xaxis_title  = AXIS_LABEL,
            yaxis_title  = "Clustering  (0 → uniform, 1 → clustered)",
            yaxis2_title = (f"<0 → {COND_A} more clustered<br>"
                            f">0 → {COND_B} more clustered")
        )
        fig_t.update_xaxes(showgrid=False, zeroline=False)
        fig_t.update_yaxes(showgrid=False, zeroline=False)
        fig_t.update_yaxes(showgrid=False, zeroline=False, row=2, col=1)
        if SHOW_MOTIF_LINES:
            add_motif_guides(fig_t, positions=motif_positions,
                             style=MOTIF_LINE_STYLE)
        if SHOW_TSS_Q4_LINES:
            add_motif_guides(fig_t, positions=tss_q4_positions,
                             style=TSS_Q4_LINE_STYLE)

        # ── save files ──
        tag = f"evenness_{title_suffix}_{COND_A}_vs_{COND_B}_{STAMP}"
        fig_t.write_image(OUT_DIR_NEW / f"{tag}.png", scale=2)
        fig_t.write_image(OUT_DIR_NEW / f"{tag}.svg")
        dbg(f"[SAVE] per-type {title_suffix}: {tag}.png / .svg")
# end if SAVE_PER_TYPE

################################################################################
#                                   EXPORTS                                    #
################################################################################
# Per-condition clustering figure (always built above).
base_tag = f"evenness_{COND_A}_vs_{COND_B}_{STAMP}"
fig_kl.write_image(OUT_DIR_NEW / f"{base_tag}_kl.png", scale=2)
fig_kl.write_image(OUT_DIR_NEW / f"{base_tag}_kl.svg")
dbg(f"[SAVE] all evenness plots → {OUT_DIR_NEW}")

# Pooled Δ figure: fig_diff is only created when SAVE_TYPE_DIFF_COMBINED is
# enabled, so the export is guarded by the same toggle.  Without the guard a
# fresh kernel raises NameError and a re-run silently saves a stale figure.
if SAVE_TYPE_DIFF_COMBINED:
    base_tag = f"evenness_diff_{COND_A}_vs_{COND_B}_{STAMP}"
    fig_diff.write_image(OUT_DIR_NEW / f"{base_tag}_kl.png", scale=2)
    fig_diff.write_image(OUT_DIR_NEW / f"{base_tag}_kl.svg")
    dbg(f"[SAVE] all evenness plots → {OUT_DIR_NEW}")

In [None]:
# ╔════════════════════════════════════════════════════════════════╗
# ║  CELL Y++++: CV · min‑CV · Δmin‑CV · nuc‑per‑read ratio        ║
# ╚════════════════════════════════════════════════════════════════╝
#
#  • The CV / min‑CV logic (panels 1‑3) is **unchanged** and still uses
#    DIST_WINDOW_BP, CV_STEP, MIN_WINDOW_BP, SMOOTH_HALF_WIN.
#  • The nuc‑count panels (4‑5) now have their *own* bin‑size & smoothing
#    controls:  NUC_STEP  and  NUC_SMOOTH_HALF_WIN.
# -------------------------------------------------------------------

import numpy as np, pandas as pd, plotly.graph_objects as go
from plotly.subplots import make_subplots

# ───────── USER CONFIG – CV / min-CV ───────── #
# All distances are in base pairs on the same axis as PLOT_RUNG_AXIS.
DIST_WINDOW_BP      = 80      # ± window (bp) around each grid position; used by the CV curves AND by the nuc-count curves
CV_STEP             = 5      # grid spacing (bp) of the CV / min-CV axis
MIN_WINDOW_BP       = 80     # ± window (bp) searched for the local minimum of the CV curve
SMOOTH_HALF_WIN     = 5       # moving-average half-width (grid points) applied to the CV-family quartile curves
# ───────── USER CONFIG – nuc-count panels ────────── #
NUC_STEP            = 10       # grid spacing (bp) of the nuc-count axis
NUC_SMOOTH_HALF_WIN = 1      # moving-average half-width (grid points) for the nuc-count curves
# --------------------------------------------------------- #

# condition label → plotting colour
CLR_MAP = {COND_A: CLR_A, COND_B: CLR_B}

# ── shared helpers ────────────────────────────────────────
def _smooth(arr, half_win):
    arr = np.asarray(arr, float)
    if half_win == 0 or arr.size == 0:
        return arr
    pad = np.pad(arr, half_win, mode="edge")
    ker = np.ones(2*half_win + 1) / (2*half_win + 1)
    out = np.convolve(pad, ker, mode="valid")
    out[np.isnan(arr)] = np.nan
    return out

def _w_read(row):
    """Weight contributed by one read.

    Returns 1 when READ_BALANCE_MODE is "centres" or "equal_read_len";
    otherwise the read's ``read_length`` attribute is used as the weight.
    """
    if READ_BALANCE_MODE in {"centres", "equal_read_len"}:
        return 1
    return row.read_length
# --------------------------------------------------------- #

# ───────── grids ─────────
# Rungs spanned by the plots: folded to absolute distances when mirroring.
if MIRROR_ABS_DISTANCE:
    orig_rungs = sorted({abs(c) for c in PLOT_RUNG_AXIS})
else:
    orig_rungs = list(PLOT_RUNG_AXIS)
axis_min = min(orig_rungs)
axis_max = max(orig_rungs)

# Evaluation grids; each panel family has its own spacing.  The stop value is
# pushed one step past axis_max so that the arange includes axis_max.
axis_pos_cv  = np.arange(axis_min, axis_max + CV_STEP,  CV_STEP)
axis_pos_nuc = np.arange(axis_min, axis_max + NUC_STEP, NUC_STEP)

# ──────────────────────────────────────────────────────────
#  1)  replicate‑level CV / min‑CV / ratio curves (old code)
# ──────────────────────────────────────────────────────────
def _replicate_curves_CV(read_df):
    """Compute CV, local-min CV and nuc/read curves for one replicate.

    For every grid position ``p`` in ``axis_pos_cv`` only reads having at
    least one ``rel_pos`` value inside ``[p − DIST_WINDOW_BP, p + DIST_WINDOW_BP]``
    are considered.  Each such read adds its weight (see ``_w_read``) to the
    read tally, and its nucleosome centres strictly closer than
    DIST_WINDOW_BP to ``p`` add to the nucleosome tally and contribute their
    absolute offsets ``|centre − p|``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(cv_raw, min_raw, ratio_raw)``, each aligned with ``axis_pos_cv``:
        * cv_raw    – std / mean of the absolute offsets (NaN with fewer than
                      two offsets or a zero mean);
        * min_raw   – minimum of cv_raw within ±MIN_WINDOW_BP (NaN-aware);
        * ratio_raw – the nucleosome tally itself in "centres" /
                      "equal_read_len" mode, otherwise tally ÷ read tally
                      (NaN where no read covers the position).
    """
    n_pos = len(axis_pos_cv)
    offsets_at = [[] for _ in range(n_pos)]
    nuc_cnt = np.zeros(n_pos, dtype=float)
    read_cnt = np.zeros(n_pos, dtype=float)

    for read in read_df.itertuples(index=False):
        centres = np.asarray(read.nuc_centers, float)
        relpos = np.asarray(read.rel_pos, int)
        for idx, pos in enumerate(axis_pos_cv):
            covered = np.any((relpos >= pos - DIST_WINDOW_BP)
                             & (relpos <= pos + DIST_WINDOW_BP))
            if not covered:
                continue
            weight = _w_read(read)
            read_cnt[idx] += weight
            near = np.abs(centres - pos) < DIST_WINDOW_BP
            if near.any():
                nuc_cnt[idx] += near.sum() * weight
                offsets_at[idx].extend(np.abs(centres[near] - pos))

    # CV(|offset|) per grid position
    cv_raw = np.full(n_pos, np.nan)
    for idx, offs in enumerate(offsets_at):
        if len(offs) < 2:
            continue
        mean_off = np.mean(offs)
        if mean_off != 0:
            cv_raw[idx] = np.std(offs, ddof=0) / mean_off

    # local minimum of the CV curve inside a ±half_idx index window
    half_idx = int(np.ceil(MIN_WINDOW_BP / CV_STEP))
    min_raw = np.full(n_pos, np.nan)
    for i in range(n_pos):
        window = cv_raw[max(0, i - half_idx):i + half_idx + 1]
        finite = window[~np.isnan(window)]
        if finite.size:
            min_raw[i] = finite.min()

    # raw tally in the unit-weight modes, otherwise tally per weighted read
    if READ_BALANCE_MODE in {"centres", "equal_read_len"}:
        return cv_raw, min_raw, nuc_cnt
    ratio_raw = np.full_like(nuc_cnt, np.nan)
    np.divide(nuc_cnt, read_cnt, out=ratio_raw, where=read_cnt > 0)
    return cv_raw, min_raw, ratio_raw

# ──────────────────────────────────────────────────────────
#  2)  replicate‑level nuc‑count curves (NEW grid / smooth)
# ──────────────────────────────────────────────────────────
def _replicate_curves_NUC(read_df):
    """Count, per grid position, the reads with a nucleosome centre nearby.

    For every position ``p`` in ``axis_pos_nuc`` the result holds the number
    of reads in ``read_df`` that have at least one nucleosome centre strictly
    closer than DIST_WINDOW_BP to ``p``.

    NOTE(review): the tally grows by one per *read*, regardless of how many
    centres fall in the window, although the original inline comment said
    "one per centre" — confirm which of the two is intended.
    """
    n_pos = len(axis_pos_nuc)
    nuc_cnt = np.zeros(n_pos, dtype=float)
    for read in read_df.itertuples(index=False):
        centres = np.asarray(read.nuc_centers, float)
        # |centre − p| for every (centre, grid position) combination
        dist = np.abs(np.subtract.outer(centres, axis_pos_nuc)).reshape(-1, n_pos)
        nuc_cnt += (dist < DIST_WINDOW_BP).any(axis=0)
    return nuc_cnt

# ───────── gather replicate curves ─────────
# rep_cv[cond][key]  → (cv_raw, min_raw, ratio_raw) from _replicate_curves_CV
# rep_nuc[cond][key] → nuc-count curve from _replicate_curves_NUC
# where key = (type, bed_start, boot_idx) identifies one replicate.
rep_cv  = {COND_A: {}, COND_B: {}}
rep_nuc = {COND_A: {}, COND_B: {}}

for cond, df_cond in zip((COND_A, COND_B), (dfA_all_bal, dfB_all_bal)):
    for key, sub in df_cond.groupby(["type", "bed_start", "boot_idx"]):
        rep_df = sub.reset_index(drop=True)
        rep_cv[cond][key]  = _replicate_curves_CV(rep_df)
        rep_nuc[cond][key] = _replicate_curves_NUC(rep_df)

# ───────── quartiles helper ─────────
def _q123(mat_list, axis_len):
    if not mat_list:
        return (np.full(axis_len, np.nan),)*3
    mat = np.vstack(mat_list)
    return tuple(np.nanpercentile(mat, q, axis=0) for q in (25,50,75))

# ───────── CV-family quartiles ─────────
# Each replicate entry of rep_cv is (cv, min-cv, ratio); slot i of that tuple
# feeds the i-th target dict below.
cv_Q, min_Q, ratio_Q = {}, {}, {}
for cond in (COND_A, COND_B):
    rep_curves = list(rep_cv[cond].values())
    for slot, targ in enumerate((cv_Q, min_Q, ratio_Q)):
        targ[cond] = _q123([rep[slot] for rep in rep_curves], len(axis_pos_cv))

# smooth every (Q1, median, Q3) triple of the CV family
for targ in (cv_Q, min_Q, ratio_Q):
    for cond in (COND_A, COND_B):
        targ[cond] = tuple(_smooth(curve, SMOOTH_HALF_WIN) for curve in targ[cond])

# ───────── nuc-count quartiles ─────────
nuc_Q = {}
for cond in (COND_A, COND_B):
    nuc_raw = list(rep_nuc[cond].values())
    nuc_Q[cond] = tuple(_smooth(curve, NUC_SMOOTH_HALF_WIN)
                        for curve in _q123(nuc_raw, len(axis_pos_nuc)))

# Δ nuc-count (A−B), computed per replicate key present in both conditions
common = set(rep_nuc[COND_A]) & set(rep_nuc[COND_B])
delta_nuc = [rep_nuc[COND_A][k] - rep_nuc[COND_B][k] for k in common]
dN_Q = tuple(_smooth(arr, NUC_SMOOTH_HALF_WIN)
             for arr in _q123(delta_nuc, len(axis_pos_nuc)))

# ───────────────────────── plotting ─────────────────────────
# Five stacked panels, equal height, each with its own x-axis.
panel_titles = (
    f"Smoothed CV  (±{DIST_WINDOW_BP} bp)",
    f"Smoothed local‑min CV  (±{MIN_WINDOW_BP} bp)",
    f"Δ local‑min CV  ({COND_A} – {COND_B})",
    "Nucleosome count",
    f"Δ nuc count  ({COND_A} – {COND_B})",
)
fig = make_subplots(
    rows=5,
    cols=1,
    shared_xaxes=False,
    vertical_spacing=0.04,
    row_heights=[0.20] * 5,
    subplot_titles=panel_titles,
)

def _add_q(fig, x, qtuple, color, name, row):
    """Draw a quartile band plus median on one subplot row (column 1).

    ``qtuple`` is ``(q1, median, q3)``.  Q1 and Q3 are added first as thin
    dotted lines without legend entries, then the median as a solid line of
    width 3 labelled ``name``; all three use ``color``.
    """
    lower, median, upper = qtuple
    for band in (lower, upper):
        fig.add_trace(
            go.Scatter(x=x, y=band, mode="lines",
                       line=dict(color=color, dash="dot", width=1),
                       showlegend=False),
            row=row, col=1)
    fig.add_trace(
        go.Scatter(x=x, y=median, mode="lines",
                   line=dict(color=color, width=3), name=name),
        row=row, col=1)

# Panels 1-2 – CV and local-min CV per condition (axis_pos_cv grid)
for panel_row, quartiles in ((1, cv_Q), (2, min_Q)):
    _add_q(fig, axis_pos_cv, quartiles[COND_A], CLR_A, COND_A, row=panel_row)
    _add_q(fig, axis_pos_cv, quartiles[COND_B], CLR_B, COND_B, row=panel_row)

# Panel 3 – Δ local-min CV (A − B), paired per replicate key found in both
paired_min_cv = [curves_a[1] - rep_cv[COND_B][key][1]
                 for key, curves_a in rep_cv[COND_A].items()
                 if key in rep_cv[COND_B]]
q1_d, med_d, q3_d = (_smooth(curve, SMOOTH_HALF_WIN)
                     for curve in _q123(paired_min_cv, len(axis_pos_cv)))
_add_q(fig, axis_pos_cv, (q1_d, med_d, q3_d), "black", "median Δ", row=3)
fig.add_shape(type="line", x0=axis_min, x1=axis_max, y0=0, y1=0,
              line=dict(color="grey", dash="dash"), row=3, col=1)

# Panel 4 – nuc count per condition (axis_pos_nuc grid)
_add_q(fig, axis_pos_nuc, nuc_Q[COND_A], CLR_A, COND_A, row=4)
_add_q(fig, axis_pos_nuc, nuc_Q[COND_B], CLR_B, COND_B, row=4)
fig.update_yaxes(title="nuc count", row=4, col=1)

# Panel 5 – Δ nuc count, with a zero reference line
q1_n, med_n, q3_n = dN_Q
_add_q(fig, axis_pos_nuc, (q1_n, med_n, q3_n), "black", "median Δ nuc", row=5)
fig.add_shape(type="line", x0=axis_min, x1=axis_max, y0=0, y1=0,
              line=dict(color="grey", dash="dash"), row=5, col=1)
fig.update_yaxes(title="Δ nuc count", row=5, col=1)

# Cosmetics & guides
for r in (1, 2, 3):
    fig.update_yaxes(showgrid=False, zeroline=False, row=r, col=1)
fig.update_xaxes(showgrid=False, zeroline=False, row=1, col=1)
fig.update_xaxes(title=AXIS_LABEL, row=5, col=1)
if SHOW_MOTIF_LINES:
    add_motif_guides(fig, positions=motif_positions, style=MOTIF_LINE_STYLE)
if SHOW_TSS_Q4_LINES:
    add_motif_guides(fig, positions=tss_q4_positions, style=TSS_Q4_LINE_STYLE)

fig.update_layout(template="plotly_white", width=800, height=1600,
                  showlegend=False)

# every panel is shown on a fixed x range of -900 … 900
fig.update_xaxes(range=(-900, 900))
fig.show()

# ── DEBUG: in-window nucleosome counts ──
def _count_centres_in_window(df):
    """Total number of nucleosome centres within ±PLOT_WINDOW over all rows.

    Every entry of ``df.nuc_centers`` is treated as an array of positions;
    positions in the closed interval [−PLOT_WINDOW, PLOT_WINDOW] are counted.
    """
    total = 0
    for centres in df.nuc_centers:
        centres = np.asarray(centres)
        inside = np.logical_and(centres >= -PLOT_WINDOW, centres <= PLOT_WINDOW)
        total += np.sum(inside)
    return total

_n_in_A = _count_centres_in_window(dfA_all_bal)
_n_in_B = _count_centres_in_window(dfB_all_bal)
dbg(f"[In‑window Nucs] {COND_A}: {_n_in_A} | "
    f"{COND_B}: {_n_in_B}  (±{PLOT_WINDOW} bp)")


In [None]:
# ╔════════════════════════════════════════════════════════════════╗
# ║  Inter‑nucleosome distance  – single‑plot + per‑group exports  ║
# ╚════════════════════════════════════════════════════════════════╝
#
# ‣ Re‑uses global flags already defined elsewhere:
#     • SAVE_PER_TYPE        – toggle per‑(type[/bed_start]) exports
#     • CONSIDER_BED_START   – group by type *and* bed_start when True
#     • SHOW_MOTIF_LINES / SHOW_TSS_Q4_LINES  – dashed guide lines
#
# ‣ All CONFIG knobs relevant to *this* cell are (re)stated below.
# ‣ Set OUT_DIR_INT to change the export folder just for these plots.
# ‣ Requires `add_motif_guides`, `motif_positions`, `tss_q4_positions`,
#   and the balanced dataframes `dfA_all_bal`, `dfB_all_bal`
#   to be in scope (guaranteed by upstream cells).
# ------------------------------------------------------------------

import numpy as np, pandas as pd, plotly.graph_objects as go
from plotly.subplots import make_subplots
from pathlib import Path
import multiprocessing as mp
from tqdm.auto import tqdm

# ───────── USER CONFIG – inter-nuc distance ───────── #
INT_DIST_STEP        = 5       # grid spacing (bp)
INT_DIST_MAX         =  1000    # keep pairs with distance ≤ this
INT_SMOOTH_HALF_WIN  = 2       # moving-avg half-width (0 ⇒ no smoothing)
REQUIRE_MOD_IN_GAP   = True   # True → demand ≥1 mod hit between cores
INT_XWINDOW          = (-5000, 5000)   # (min, max) of the distance-curve grid built from this cell's axis_min / axis_max
INT_YWINDOW1         = (160, 320)   # Panel 1 (median distance)
INT_YWINDOW2         = (-120,  120)   # Panel 2 (Δ distance)
# ─── Dispersion statistic used for the curves ─── #
DISPERSION_METHOD = "iqr"   # "iqr" → median with Q1/Q3; any other value (e.g. "trimmed") → trimmed mean ± SD
TRIM_FRAC         = 0.10        # trimming fraction; ignored if DISPERSION_METHOD == "iqr"

# Saving / export
# NOTE: the two flags below are reassigned here, so they replace whatever
# value an earlier cell gave them for everything that runs afterwards.
SAVE_PER_TYPE        = False        # toggle per-group files
CONSIDER_BED_START   = False   # grouping policy: True → group by type AND bed_start
OUT_DIR_INT          = Path("/Data1/git/meyer-nanopore/scripts/analysis/images_20250725/DPY27_N2_dyad_to_dyad_900bp_strong_rex")   # ← change if needed
STAMP                = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")   # timestamp string, presumably for export file names
# ---------------------------------------------------- #
# ───────── GUIDE LINES (optional) ───────── #
SHOW_MOTIF_LINES = False          # dashed grey guides at motif positions
MOTIF_LINE_STYLE = dict(color="grey", width=1, dash="dash")
SHOW_TSS_Q4_LINES  = False                     # grey dash-dot guides at TSS positions (TSS_q4 and TSS_q3 entries)
TSS_Q4_LINE_STYLE  = dict(color="grey", width=1, dash="dashdot")

# make sure the export folder exists before anything is written to it
OUT_DIR_INT.mkdir(parents=True, exist_ok=True)

# ╔═════════════════  REQUIRED INPUTS  (checked up front)  ═══════╗
# All names are checked in one pass so a single error lists everything that
# is missing.  The exception is raised explicitly (rather than via a bare
# `assert`) so the check still runs when Python is started with -O.
_missing_inputs = [name for name in ("dfA_all_bal", "dfB_all_bal", "PLOT_WINDOW",
                                     "COND_A", "COND_B", "CLR_A", "CLR_B")
                   if name not in globals()]
if _missing_inputs:
    raise AssertionError(
        f"Variable {', '.join(_missing_inputs)} not defined – run earlier cells.")
# ╚══════════════════════════════════════════════════════════════╝
# ===============================================================
# Columns that together identify one replicate:
#   type + boot_idx, with bed_start inserted between them only when
#   CONSIDER_BED_START is True.
# ===============================================================
KEY_FIELDS = ["type", "boot_idx"]
if CONSIDER_BED_START:
    KEY_FIELDS.insert(1, "bed_start")
# ===============================================================
# ──────────────────────────────────────────────────────────────
#  NEW helper – consistent condition → colour mapping
# ──────────────────────────────────────────────────────────────
from itertools import cycle

# Fixed colours for the known condition families
CLR_SDC = "#b12537"     # red
CLR_N2  = "#4974a5"     # blue
CLR_DPY = "#47B562"     # green

# Greys handed out, in turn, to labels matching none of the families above;
# the cache keeps the grey of a given label stable across calls.
_GREY_CYCLE = cycle(["#4d4d4d", "#a3a3a3"])
_cond2clr_cache = {}

def cond_color(cond: str) -> str:
    """Map a condition label to its plotting colour.

    Matching is case-insensitive and by substring, tested in this order:
    'sdc' → red, 'dpy' → green, 'n2' → blue.  Any other label receives the
    next grey from a two-shade cycle the first time it is seen and the same
    grey on every later call.
    """
    label = cond.lower()
    for tag, colour in (("sdc", CLR_SDC), ("dpy", CLR_DPY), ("n2", CLR_N2)):
        if tag in label:
            return colour

    try:
        return _cond2clr_cache[label]
    except KeyError:
        grey = _cond2clr_cache[label] = next(_GREY_CYCLE)
        return grey

# Condition colours used by this cell
col_A = cond_color(COND_A)
col_B = cond_color(COND_B)

# ──────────────────────────────────────────────────────────────
#  Guide-line positions (motif / TSS) for this cell
#    • both lists are rebuilt here and stay empty when the matching
#      SHOW_* flag is off or the source columns are absent
#    • positions are folded to |p| when MIRROR_ABS_DISTANCE is set
# ──────────────────────────────────────────────────────────────
src_df = pd.concat([dfA_all_bal, dfB_all_bal], ignore_index=True)

motif_positions = []
if SHOW_MOTIF_LINES and "motif_rel_start" in src_df.columns:
    _raw = []
    for entry in src_df["motif_rel_start"].dropna():
        # an entry is either an iterable of positions or a single position
        if hasattr(entry, "__iter__"):
            _raw.extend(entry)
        else:
            _raw.append(entry)
    if globals().get("MIRROR_ABS_DISTANCE", False):
        motif_positions = sorted({abs(int(p)) for p in _raw})
    else:
        motif_positions = sorted({int(p) for p in _raw})

tss_q4_positions = []
if SHOW_TSS_Q4_LINES and {"tss_rel_start", "tss_attributes"}.issubset(src_df.columns):
    _tss_pairs = src_df[["tss_rel_start", "tss_attributes"]].dropna()
    # attr = (start, strand, type); keep positions whose type is TSS_q4 or TSS_q3
    _tss_raw = [pos
                for rels, attrs in _tss_pairs.itertuples(index=False)
                for pos, attr in zip(rels, attrs)
                if attr[2] in ("TSS_q4", "TSS_q3")]
    if globals().get("MIRROR_ABS_DISTANCE", False):
        tss_q4_positions = sorted({abs(int(p)) for p in _tss_raw})
    else:
        tss_q4_positions = sorted({int(p) for p in _tss_raw})

def add_motif_guides(fig, *, positions, style, yref="paper"):
    """Add one vertical guide line per entry of ``positions`` to ``fig``.

    Each line is a shape drawn below the traces at ``x = position`` on the
    "x" axis reference, running from y = 0 to y = 1 in ``yref`` coordinates,
    with ``style`` passed through as the line style.
    """
    shared = dict(type="line", y0=0, y1=1, xref="x", yref=yref,
                  line=style, layer="below")
    for x_pos in positions:
        fig.add_shape(x0=x_pos, x1=x_pos, **shared)

# ──────────────────────────────────────────────────────────────
#  Helper utilities
# ──────────────────────────────────────────────────────────────
def _smooth(arr, k):
    if k == 0:
        return arr.astype(float)
    pad = np.pad(arr.astype(float), k, mode="edge")
    ker = np.ones(2 * k + 1, float) / (2 * k + 1)
    return np.convolve(pad, ker, mode="valid")

def _q123(list_of_arrays, length):
    """Column-wise 25th / 50th / 75th percentiles over a list of curves.

    The curves are stacked row-wise and the percentiles are taken down the
    stack with NaNs ignored. An empty list yields three all-NaN arrays of
    size *length*.

    Returns
    -------
    (q1, med, q3) : tuple of numpy.ndarray
    """
    if not list_of_arrays:
        return tuple(np.full(length, np.nan) for _ in range(3))

    rows = np.vstack([np.asarray(curve, float) for curve in list_of_arrays])
    # Columns that are entirely NaN are expected; silence floating-point noise.
    with np.errstate(all="ignore"):
        q1, med, q3 = (np.nanpercentile(rows, pct, axis=0)
                       for pct in (25, 50, 75))
    return q1, med, q3

# ──────────────────────────────────────────────────────────────
#  Axis grid for distance curves
# ──────────────────────────────────────────────────────────────
# INT_XWINDOW supplies the (min, max) bounds of the distance-curve x axis.
axis_min, axis_max = INT_XWINDOW
# Evaluation grid: axis_min, axis_min + INT_DIST_STEP, ... The stop value is
# pushed one step past axis_max because np.arange excludes its stop.
axis_pos_int = np.arange(axis_min, axis_max + INT_DIST_STEP, INT_DIST_STEP)

# ──────────────────────────────────────────────────────────────
#  Core computation: per‑replicate (lower, centre, upper) distance
#  curves plus valid pair counts – see _replicate_curves_INT below
#    • "iqr"      → q1 / median / q3 inside each replicate
#    • "trimmed"  → trimmed‑mean ± trimmed‑SD inside each replicate
# ──────────────────────────────────────────────────────────────
# ──────────────────────────────────────────────────────────────
#  Helper: human‑readable label for plot titles
# ──────────────────────────────────────────────────────────────
def disp_label():
    """Return the human-readable dispersion label used in plot titles.

    Reads the module-level ``DISPERSION_METHOD`` and ``TRIM_FRAC`` settings.
    ``"iqr"`` gives ``"Median ± IQR"``; any other value gives the
    trimmed-mean label with the trim fraction shown as a whole percentage.
    """
    if DISPERSION_METHOD == "iqr":
        return "Median ± IQR"
    # round() instead of bare int(): 0.29 * 100 == 28.999999999999996, which
    # truncation would mislabel as "28 %".
    trim_pct = int(round(TRIM_FRAC * 100))
    return f"Trimmed mean ± SD (trim {trim_pct} %)"

# ──────────────────────────────────────────────────────────────
#  Compute (lower, centre, upper) curves for one replicate
#    • "iqr"      → (q1, median, q3)
#    • "trimmed"  → (mean‑SD, mean, mean+SD) after trimming
# ──────────────────────────────────────────────────────────────
from scipy.stats import trim_mean

def _replicate_curves_INT(read_df):
    """Inter-nucleosome distance band and pair counts for one replicate.

    For every grid position ``p`` of the module-level ``axis_pos_int`` and
    every row of *read_df*, the closest nucleosome centre at or left of ``p``
    is paired with the closest one at or right of ``p``. The pair is kept
    when ``0 < distance <= INT_DIST_MAX`` and, if ``REQUIRE_MOD_IN_GAP`` is
    set, at least one position strictly between the two centres has
    ``mod_qual_bin == 1``.

    Parameters
    ----------
    read_df : pandas.DataFrame
        Each row must expose ``nuc_centers``, ``rel_pos`` and
        ``mod_qual_bin`` (read via ``itertuples``).

    Returns
    -------
    ((lo, mid, hi), pair_cnt)
        Four arrays with one entry per grid position. With
        ``DISPERSION_METHOD == "iqr"`` the band is (q1, median, q3) of the
        kept distances; otherwise it is (mean - SD, mean, mean + SD) after
        dropping ``floor(TRIM_FRAC * n)`` values from each tail. Positions
        with no data (or too few values to trim) are NaN. ``pair_cnt`` is
        the number of kept pairs.
    """
    dist_lists = {p: [] for p in axis_pos_int}
    pair_cnt   = np.zeros_like(axis_pos_int, float)

    for row in read_df.itertuples(index=False):
        centres = np.asarray(row.nuc_centers, float)
        if centres.size < 2:
            continue

        relpos = np.asarray(row.rel_pos,      int)
        mods   = np.asarray(row.mod_qual_bin, int)

        # Only centres inside the plotting window can form a pair.
        in_win = centres[(centres >= -PLOT_WINDOW) & (centres <= PLOT_WINDOW)]
        if in_win.size < 2:
            continue

        for idx, p in enumerate(axis_pos_int):
            ups   = in_win[in_win <= p]
            downs = in_win[in_win >= p]
            if ups.size == 0 or downs.size == 0:
                continue

            up_c, down_c = ups.max(), downs.min()
            # dist == 0 happens when a centre sits exactly on p (it is then
            # both the upstream and downstream neighbour); such hits are
            # rejected together with over-long gaps.
            dist = down_c - up_c
            if not (0 < dist <= INT_DIST_MAX):
                continue

            if REQUIRE_MOD_IN_GAP:
                in_gap = (relpos > up_c) & (relpos < down_c)
                if not (in_gap.any() and (mods[in_gap] == 1).any()):
                    continue

            dist_lists[p].append(dist)
            pair_cnt[idx] += 1

    lo, mid, hi = [], [], []
    for vals in dist_lists.values():
        if not vals:
            lo.append(np.nan); mid.append(np.nan); hi.append(np.nan)
            continue
        vals = np.asarray(vals, float)

        if DISPERSION_METHOD == "iqr":
            q1, med, q3 = np.nanpercentile(vals, [25, 50, 75])
            lo.append(q1); mid.append(med); hi.append(q3)
        else:  # "trimmed"
            vals = np.sort(vals)
            n    = len(vals)
            k    = int(np.floor(TRIM_FRAC * n))
            # Need at least two values left after trimming k from each tail.
            if 2 * k >= n - 1:
                lo.append(np.nan); mid.append(np.nan); hi.append(np.nan)
                continue
            # Slice with an explicit upper bound: vals[k:-k] collapses to the
            # empty vals[0:0] whenever k == 0 (i.e. n < 1 / TRIM_FRAC), which
            # silently turned every sparsely covered position into NaN.
            trimmed = vals[k:n - k]
            mu = trimmed.mean()
            sd = trimmed.std(ddof=0)
            lo.append(mu - sd); mid.append(mu); hi.append(mu + sd)

    return (np.array(lo), np.array(mid), np.array(hi)), pair_cnt


# ──────────────────────────────────────────────────────────────
#  Parallel gathering of all replicate curves
# ──────────────────────────────────────────────────────────────
def _task_replicate(args):
    """Pool worker: compute the replicate curves for one task tuple.

    *args* is ``(cond, key, subdf)``. The condition and key are echoed back
    next to the ``_replicate_curves_INT(subdf)`` result so the caller can
    file results that arrive out of order.
    """
    condition, group_key, reads = args
    curves = _replicate_curves_INT(reads)
    return condition, group_key, curves
# ===============================================================
# One task per (condition, KEY_FIELDS group); every group is handed to the
# worker as a freshly re-indexed frame.
tasks = []
for cond, df_cond in ((COND_A, dfA_all_bal), (COND_B, dfB_all_bal)):
    tasks.extend(
        (cond, key, sub.reset_index(drop=True))
        for key, sub in df_cond.groupby(KEY_FIELDS, sort=False)
    )
# ===============================================================

# rep_int[cond][key] = ((lo, mid, hi), pair_cnt) as returned by the worker.
rep_int = {COND_A: {}, COND_B: {}}
with mp.Pool(max(2, mp.cpu_count() - 1)) as pool:
    finished = pool.imap_unordered(_task_replicate, tasks)
    # Completion order is arbitrary, so results are filed by (cond, key).
    for cond, key, result in tqdm(finished, total=len(tasks),
                                  desc="INT‑distance curves"):
        rep_int[cond][key] = result

# ──────────────────────────────────────────────────────────────
#  Pooled (mean‑of‑groups) dispersion bands & Δ‑distance
# ──────────────────────────────────────────────────────────────
# dist_Q[cond] = (lower, centre, upper) curves, each the NaN-aware mean over
# all replicate groups of the matching per-replicate curve, then smoothed.
# pair_Q[cond] holds the mean pair-count curve three times so it can be drawn
# with the same three-line helper.
dist_Q, pair_Q = {}, {}
for cond in (COND_A, COND_B):
    # collect each group's band tuple and pair-count curve
    q_tuples = [v[0] for v in rep_int[cond].values()]  # (lower, centre, upper)
    pair_raw = [v[1] for v in rep_int[cond].values()]

    if q_tuples:
        q1_mean  = np.nanmean([t[0] for t in q_tuples], axis=0)
        med_mean = np.nanmean([t[1] for t in q_tuples], axis=0)
        q3_mean  = np.nanmean([t[2] for t in q_tuples], axis=0)
        dist_Q[cond] = tuple(
            _smooth(a, INT_SMOOTH_HALF_WIN)
            for a in (q1_mean, med_mean, q3_mean)
        )
    else:
        # Build the placeholder as float explicitly: np.full_like inherits
        # the dtype of axis_pos_int, and on an integer grid NaN cannot be
        # represented (it is cast to an arbitrary integer).
        dist_Q[cond] = tuple(np.full(len(axis_pos_int), np.nan)
                             for _ in range(3))

    if pair_raw:
        pair_mean = np.nanmean(pair_raw, axis=0)
        pair_Q[cond] = (pair_mean, pair_mean, pair_mean)  # single curve
    else:
        pair_Q[cond] = tuple(np.full(len(axis_pos_int), np.nan)
                             for _ in range(3))

# Δ of the averaged bands (A – B), component by component.
# NOTE(review): the inputs are already smoothed above, so the Δ curves are
# smoothed twice – kept as-is, confirm this is intended.
delta_q1  = dist_Q[COND_A][0] - dist_Q[COND_B][0]
delta_med = dist_Q[COND_A][1] - dist_Q[COND_B][1]
delta_q3  = dist_Q[COND_A][2] - dist_Q[COND_B][2]
dD_Q = tuple(_smooth(a, INT_SMOOTH_HALF_WIN)
             for a in (delta_q1, delta_med, delta_q3))


# ──────────────────────────────────────────────────────────────
#  Plot helpers
# ──────────────────────────────────────────────────────────────
def _add_q(fig, x, qtuple, color, name, row):
    """Draw a three-curve band in subplot (*row*, column 1) of *fig*.

    *qtuple* is ``(lower, centre, upper)``. The two bounds are added first as
    thin dotted lines without legend entries; the centre curve is added last
    as a solid width-3 line carrying *name*.
    """
    lower, centre, upper = qtuple
    for bound in (lower, upper):
        bound_trace = go.Scatter(
            x=x, y=bound, mode="lines",
            line=dict(color=color, dash="dot", width=1),
            showlegend=False,
        )
        fig.add_trace(bound_trace, row=row, col=1)

    centre_trace = go.Scatter(
        x=x, y=centre, mode="lines",
        line=dict(color=color, width=3),
        name=name,
    )
    fig.add_trace(centre_trace, row=row, col=1)

# ╔═══════════════════  Pooled figure  ═════════════════════════╗
# Three stacked panels: (1) per-condition distance bands, (2) A – B
# difference, (3) valid pair counts. The x axes are independent
# (shared_xaxes=False) but all get the same INT_XWINDOW range further down.
fig_int = make_subplots(
    rows=3, cols=1, shared_xaxes=False,
    vertical_spacing=0.05,
    row_heights=[0.33, 0.33, 0.34],
    subplot_titles=(
        f"Inter‑nucleosome distance – {disp_label()} (≤ {INT_DIST_MAX} bp)",
        f"Δ distance  ({COND_A} – {COND_B})",
        "Valid pair count"
    )

)
# Row 1: pooled band per condition.
_add_q(fig_int, axis_pos_int, dist_Q[COND_A], col_A, COND_A, row=1)
_add_q(fig_int, axis_pos_int, dist_Q[COND_B], col_B, COND_B, row=1)
fig_int.update_yaxes(title="Dyad‑dyad dist (bp)",
                     range=INT_YWINDOW1, row=1, col=1)

# Row 2: Δ band in black (dotted bounds, solid centre) plus a y = 0 reference.
q1_d, med_d, q3_d = dD_Q
for arr in (q1_d, q3_d):
    fig_int.add_trace(go.Scatter(x=axis_pos_int, y=arr, mode="lines",
                                 line=dict(color="black", dash="dot", width=1),
                                 showlegend=False), row=2, col=1)
# The trace name below never reaches the legend: the showlegend loop further
# down only keeps traces named COND_A / COND_B.
fig_int.add_trace(go.Scatter(x=axis_pos_int, y=med_d, mode="lines",
                             line=dict(color="black", width=3),
                             name="Dyad‑dyad dist (bp)"), row=2, col=1)
fig_int.add_shape(type="line", x0=axis_min, x1=axis_max, y0=0, y1=0,
                  line=dict(color="grey", dash="dash"), row=2, col=1)
fig_int.update_yaxes(title="Δ median per rex (bp)",
                     range=INT_YWINDOW2, row=2, col=1)

# Row 3: pair counts (pair_Q carries the same curve three times).
_add_q(fig_int, axis_pos_int, pair_Q[COND_A], col_A, COND_A, row=3)
_add_q(fig_int, axis_pos_int, pair_Q[COND_B], col_B, COND_B, row=3)
fig_int.update_yaxes(title="# pairs", row=3, col=1)

for r in (1, 2, 3):
    fig_int.update_xaxes(range=INT_XWINDOW, showgrid=False, zeroline=False,
                         row=r, col=1)
    fig_int.update_yaxes(showgrid=False, zeroline=False, row=r, col=1)
# Legend: keep only the solid (width 3) condition curves.
# NOTE(review): rows 1 and 3 both add a width-3 trace per condition, so each
# condition name presumably shows up twice in the legend – confirm/dedupe.
for t in fig_int.data:
    t.showlegend = (t.name in (COND_A, COND_B) and t.line.width == 3)

# Guide lines use add_motif_guides' defaults (xref="x", yref="paper"), i.e.
# they are positioned on the first x axis and span the full figure height.
if SHOW_MOTIF_LINES: add_motif_guides(fig_int, positions=motif_positions,
                                      style=MOTIF_LINE_STYLE)
if SHOW_TSS_Q4_LINES: add_motif_guides(fig_int, positions=tss_q4_positions,
                                       style=TSS_Q4_LINE_STYLE)

fig_int.update_layout(
    template="plotly_white", width=800, height=1200,
    legend=dict(orientation='h', x=0.5, y=1.02,
                xanchor='center', yanchor='bottom'),
    margin=dict(t=100)
)
fig_int.show()

# Save the pooled inter‑nuc distance figure (PNG at 2× scale, plus SVG)
base_tag = f"intdist_pooled_{COND_A}_vs_{COND_B}_{STAMP}"
fig_int.write_image(OUT_DIR_INT / f"{base_tag}.png", scale=2)
fig_int.write_image(OUT_DIR_INT / f"{base_tag}.svg")
print(f"[SAVE] pooled plot: {base_tag}.png / .svg → {OUT_DIR_INT}")

# ╔═══════════  Per‑type(/bed_start) exports  ══════════════════╗
# One figure per region type (or per (type, bed_start) pair when
# CONSIDER_BED_START is set), built from the replicate curves in rep_int
# whose group key matches the selection.
if SAVE_PER_TYPE:
    if CONSIDER_BED_START:
        groups = sorted({tuple(x) for x in
                         src_df[["type", "bed_start"]].drop_duplicates().to_numpy()})
    else:
        groups = sorted(src_df["type"].unique())

    for grp in groups:
        # `sel` matches rep_int keys positionally: k[0] is compared with the
        # type and, when requested, k[1] with bed_start.
        if CONSIDER_BED_START:
            typ, bed = grp
            df_grp = src_df[(src_df.type == typ) & (src_df.bed_start == bed)]
            sel = lambda k: k[0] == typ and k[1] == bed
            suffix = f"{typ}_bed{bed}"
        else:
            typ = grp
            df_grp = src_df[src_df.type == typ]
            sel = lambda k: k[0] == typ
            suffix = typ

        # group-specific motif / TSS guide positions
        if SHOW_MOTIF_LINES and "motif_rel_start" in df_grp:
            raw = []
            for tup in df_grp["motif_rel_start"].dropna():
                raw.extend(tup if hasattr(tup, "__iter__") else [tup])
            mot_pos = sorted({abs(int(p)) if globals().get("MIRROR_ABS_DISTANCE", False)
                              else int(p) for p in raw})
        else:
            mot_pos = []
        if SHOW_TSS_Q4_LINES and {"tss_rel_start", "tss_attributes"} <= set(df_grp):
            raw = []
            for rels, attrs in df_grp[["tss_rel_start", "tss_attributes"]].dropna().itertuples(index=False):
                for pos, attr in zip(rels, attrs):
                    if attr[2] in ("TSS_q4", "TSS_q3"):
                        raw.append(pos)
            tss_pos = sorted({abs(int(p)) if globals().get("MIRROR_ABS_DISTANCE", False)
                              else int(p) for p in raw})
        else:
            tss_pos = []

        # bands for this group
        dist_Qg, pair_Qg = {}, {}
        for cond in (COND_A, COND_B):
            picked = [v for k, v in rep_int[cond].items() if sel(k)]
            if not picked:
                continue
            # v[0] is a (lower, centre, upper) tuple of curves. Average each
            # component across replicates, exactly like the pooled section;
            # passing the tuples straight to _q123 would vstack lower/centre/
            # upper rows together and take percentiles across the mixture.
            dist_Qg[cond] = tuple(
                _smooth(np.nanmean([v[0][i] for v in picked], axis=0),
                        INT_SMOOTH_HALF_WIN)
                for i in range(3)
            )
            # v[1] is a single pair-count curve → quartiles across replicates.
            pair_Qg[cond] = tuple(
                _smooth(a, INT_SMOOTH_HALF_WIN)
                for a in _q123([v[1] for v in picked], len(axis_pos_int))
            )

        if not dist_Qg:
            continue

        # Δ for group: component-wise A – B of the group bands (same recipe
        # as the pooled Δ); falls back to the pooled Δ when one condition has
        # no replicate in this group.
        if COND_A in dist_Qg and COND_B in dist_Qg:
            dQg = tuple(
                _smooth(a - b, INT_SMOOTH_HALF_WIN)
                for a, b in zip(dist_Qg[COND_A], dist_Qg[COND_B])
            )
        else:
            dQg = dD_Q

        # build figure
        # NOTE(review): the panel title hard-codes "median±IQR" although the
        # replicate bands follow DISPERSION_METHOD – confirm the wording.
        fig = make_subplots(rows=3, cols=1, shared_xaxes=False, vertical_spacing=0.05,
                            row_heights=[0.33, 0.33, 0.34],
                            subplot_titles=[
                                f"Inter‑nuc distance (median±IQR, ≤{INT_DIST_MAX}bp)",
                                f"Δ distance ({COND_A}–{COND_B})", "# pairs"])
        # A condition missing from this group is drawn with its pooled curve.
        _add_q(fig, axis_pos_int, dist_Qg.get(COND_A, dist_Q[COND_A]), col_A, COND_A, 1)
        _add_q(fig, axis_pos_int, dist_Qg.get(COND_B, dist_Q[COND_B]), col_B, COND_B, 1)
        fig.update_yaxes(range=INT_YWINDOW1, title="Dyad‑dyad dist", row=1, col=1)

        # ── Row 2 : Δ distance band ──
        q1_sub, med_sub, q3_sub = dQg

        # band bounds (dotted)
        for y_arr in (q1_sub, q3_sub):
            fig.add_trace(
                go.Scatter(
                    x=axis_pos_int,
                    y=y_arr,
                    mode="lines",
                    line=dict(color="black", dash="dot", width=1),
                    showlegend=False,
                ),
                row=2,
                col=1,
            )

        # centre Δ line (solid)
        fig.add_trace(
            go.Scatter(
                x=axis_pos_int,
                y=med_sub,
                mode="lines",
                line=dict(color="black", width=3),
                name="Δ median",
            ),
            row=2,
            col=1,
        )

        # zero reference
        fig.add_shape(
            type="line",
            x0=axis_min, x1=axis_max, y0=0, y1=0,
            line=dict(color="grey", dash="dash"),
            row=2,
            col=1,
        )
        # (A first update_yaxes call with another title used to sit here; it
        # was immediately overwritten by this one, so only this call is kept.)
        fig.update_yaxes(range=INT_YWINDOW2, title="Δ (bp)", row=2, col=1)

        _add_q(fig, axis_pos_int, pair_Qg.get(COND_A, pair_Q[COND_A]), col_A, COND_A, 3)
        _add_q(fig, axis_pos_int, pair_Qg.get(COND_B, pair_Q[COND_B]), col_B, COND_B, 3)
        fig.update_yaxes(title="# pairs", row=3, col=1)

        for r in (1, 2, 3):
            fig.update_xaxes(range=INT_XWINDOW, showgrid=False, zeroline=False, row=r, col=1)
            fig.update_yaxes(showgrid=False, zeroline=False, row=r, col=1)
        for t in fig.data:
            t.showlegend = (t.name in (COND_A, COND_B) and t.line.width == 3)

        if SHOW_MOTIF_LINES:
            add_motif_guides(fig, positions=mot_pos, style=MOTIF_LINE_STYLE)
        if SHOW_TSS_Q4_LINES:
            add_motif_guides(fig, positions=tss_pos, style=TSS_Q4_LINE_STYLE)

        fig.update_layout(template="plotly_white", width=800, height=1200,
                          legend=dict(orientation='h', x=0.5, y=1.02,
                                      xanchor='center', yanchor='bottom'),
                          margin=dict(t=100),
                          title=f"Inter‑nuc distance – {suffix}")
        tag = f"intdist_{suffix}_{COND_A}_vs_{COND_B}_{STAMP}"
        fig.write_image(OUT_DIR_INT / f"{tag}.png", scale=2)
        fig.write_image(OUT_DIR_INT / f"{tag}.svg")
        print(f"[SAVE] {suffix}: {tag}.png/.svg")


In [None]:
# ╔════════════════════════════════════════════════════════════════╗
# ║  CELL X: 1st→other inter‑nuc distance · NRL (± SD)             ║
# ╚════════════════════════════════════════════════════════════════╝
import numpy as np, plotly.graph_objects as go
from scipy.stats import gaussian_kde
from scipy.signal import find_peaks

# ───────────────────────── USER CONFIG ────────────────────────── #
MAX_CENTER_ABS  = 5000      # keep centres with |coord| < this
DIST_X_MAX      = 750       # x‑axis upper bound (bp); longer distances are dropped
GRID_STEP_BP    = 5         # grid step for KDE & x‑axis
KDE_BW_FACTOR   = 0.20      # passed to gaussian_kde as bw_method
PEAK_MIN_PROM   = 0.001     # min prominence for peak calling (in % per grid step)
# NOTE: BOOT_N is also assigned (to 20) by the bootstrap-panel cell further
# down, so the value in effect depends on which cell ran last.
BOOT_N          = 50       # bootstrap replicates for ± SD
# ───────────────────────────────────────────────────────────────── #

# helper: grey‑override when needed
def _pick_color(name, base, first):
    """Choose the plot colour for condition *name*.

    Conditions whose name contains ``"N2"`` or ``"SDC"`` keep *base*. Any
    other condition is drawn in grey: dark grey when *first* is truthy,
    light grey otherwise.
    """
    keeps_base = any(tag in name for tag in ("N2", "SDC"))
    if keeps_base:
        return base
    if first:
        return "#4d4d4d"
    return "#bfbfbf"

# assemble condition‑definitions
# Each entry is (condition name, reads dataframe, base colour).
cond_defs = [
    (COND_A, dfA_all_bal, CLR_A),
    (COND_B, dfB_all_bal, CLR_B),
]
# A third condition is optional: it is added only when dfC_all_bal exists in
# the notebook namespace, with "#d95f02" as the colour if CLR_C is undefined.
if "dfC_all_bal" in globals():
    cond_defs.append((COND_C, dfC_all_bal, globals().get("CLR_C", "#d95f02")))

# ─────────────────── Distance extraction per condition ─────────────────── #
# dist_by_cond[cond] = 1-D array of "leftmost centre → every other centre"
# distances, gathered separately on the negative and the positive side of
# the origin; col_by_cond[cond] = line colour. Conditions without any
# distance ≤ DIST_X_MAX get no entry in either dict.
dist_by_cond, col_by_cond = {}, {}
for idx, (cond, df_cond, base_col) in enumerate(cond_defs):
    per_read = []
    for row in df_cond.itertuples(index=False):
        centres = np.asarray(row.nuc_centers, float)
        centres = centres[np.abs(centres) < MAX_CENTER_ABS]
        if centres.size < 2:
            continue

        # Negative side first, then positive; a centre at exactly 0 belongs
        # to neither. On each side the anchor is the smallest coordinate.
        for side in (centres[centres < 0], centres[centres > 0]):
            if side.size < 2:
                continue
            anchor = side.min()
            per_read.append(np.abs(side[side != anchor] - anchor))

    dists = np.concatenate(per_read) if per_read else np.asarray([])
    dists = dists[dists <= DIST_X_MAX]
    if dists.size:
        dist_by_cond[cond] = dists
        col_by_cond[cond]  = _pick_color(cond, base_col, idx == 0)

# Evaluation grid shared by the KDE and the peak caller.
x_grid = np.arange(0, DIST_X_MAX + GRID_STEP_BP, GRID_STEP_BP)

# containers for peaks & NRL estimates
peaks_by_cond, slope_by_cond, sd_by_cond = {}, {}, {}

# ─────────────────────────── KDE + peak calling ─────────────────────────── #
# For every condition: fit a Gaussian KDE to the distances, evaluate it on
# x_grid, draw it, then call peaks on the drawn curve and mark each one.
# NOTE(review): gaussian_kde needs more than one (non-identical) value; a
# condition with a single distance would presumably raise here – confirm.
fig_kde = go.Figure()
for cond, dvec in dist_by_cond.items():
    kde  = gaussian_kde(dvec, bw_method=KDE_BW_FACTOR)
    pdf  = kde(x_grid)
    # density × bin width × 100 → approximate % of distances per grid step
    pct  = pdf * GRID_STEP_BP * 100
    fig_kde.add_trace(go.Scatter(
        x=x_grid, y=pct, mode="lines",
        name=cond, line=dict(width=3, color=col_by_cond[cond])
    ))

    # peaks (positions on x_grid, stored for the NRL fit in the next section)
    peaks_idx, _ = find_peaks(pct, prominence=PEAK_MIN_PROM)
    peaks_x = x_grid[peaks_idx]
    peaks_by_cond[cond] = peaks_x

    # one dotted marker + label per peak, drawn up to the curve's maximum
    for px in peaks_x:
        fig_kde.add_shape(
            type="line", x0=px, x1=px, y0=0, y1=pct.max(),
            line=dict(color=col_by_cond[cond], dash="dot", width=1)
        )
        fig_kde.add_annotation(
            x=px, y=pct.max(), text=f"{px:.0f}", showarrow=False,
            yshift=4, font=dict(size=10, color=col_by_cond[cond])
        )

fig_kde.update_layout(
    template="plotly_white",
    title="Inter‑nucleosome distance from first centre (KDE + peaks)",
    xaxis_title="distance from first centre (bp)",
    yaxis_title=f"% of distances per {GRID_STEP_BP} bp",
    width=800, height=500,
    legend=dict(orientation="h", x=0.5, y=-0.25, xanchor="center"),
    margin=dict(b=100)
)
# The displayed x range starts at 100, but peaks are called on the whole
# grid (from 0), so peaks below 100 are still stored in peaks_by_cond.
fig_kde.update_xaxes(range=[100, DIST_X_MAX], showgrid=False, zeroline=False)
fig_kde.update_yaxes(showgrid=False, zeroline=False)
fig_kde.show()

# ───────────────────── Scatter + linear NRL fit (± SD) ───────────────────── #
# NRL estimate = slope of a straight-line fit of peak position against peak
# order (1, 2, 3, ...). The ± value is the SD of the same slope recomputed on
# BOOT_N resamples (with replacement) of the condition's distances.
# NOTE(review): peak order is simply the rank among the called peaks, so a
# missed or spurious peak shifts every later order – confirm this is fine.
fig_nrl = go.Figure()

for cond, peaks_x in peaks_by_cond.items():
    # a line needs at least two peaks
    if len(peaks_x) < 2:
        continue
    orders = np.arange(1, len(peaks_x) + 1)

    # OLS slope
    slope, intercept = np.polyfit(orders, peaks_x, 1)
    slope_by_cond[cond] = slope

    # bootstrap SD (the generator is re-seeded with 0 for every condition)
    boot_slopes = []
    rng = np.random.default_rng(seed=0)
    dvec = dist_by_cond[cond]

    for _ in range(BOOT_N):
        resample = rng.choice(dvec, size=dvec.size, replace=True)
        kde_rs   = gaussian_kde(resample, bw_method=KDE_BW_FACTOR)(x_grid)
        pk_idx_rs, _ = find_peaks(kde_rs * GRID_STEP_BP * 100,
                                  prominence=PEAK_MIN_PROM)
        pk_rs = x_grid[pk_idx_rs]
        # resamples with fewer than two peaks contribute no slope
        if len(pk_rs) < 2:
            continue
        ord_rs = np.arange(1, len(pk_rs) + 1)
        s_rs, _ = np.polyfit(ord_rs, pk_rs, 1)
        boot_slopes.append(s_rs)

    # sample SD (ddof=1); NaN when no resample produced a slope
    sd = np.std(boot_slopes, ddof=1) if boot_slopes else np.nan
    sd_by_cond[cond] = sd

    # scatter + trend
    fig_nrl.add_trace(go.Scatter(
        x=orders, y=peaks_x, mode="markers",
        marker=dict(color=col_by_cond[cond], size=8),
        name=f"{cond} peaks", showlegend=False
    ))
    y_fit = slope * orders + intercept
    fig_nrl.add_trace(go.Scatter(
        x=orders, y=y_fit, mode="lines",
        line=dict(color=col_by_cond[cond], width=3),
        name=f"{cond}  NRL = {slope:.1f} ± {sd:.1f} bp"
    ))

fig_nrl.update_layout(
    template="plotly_white",
    title="NRL estimation from peak order vs. distance",
    xaxis_title="Nucleosome repeats (x)",
    yaxis_title="dyad-dyad distance (bp)",
    width=800, height=500,
    legend=dict(orientation="h", x=0.5, y=-0.25, xanchor="center"),
    margin=dict(b=100)
)
fig_nrl.update_xaxes(showgrid=False, zeroline=False)
fig_nrl.update_yaxes(showgrid=False, zeroline=False)
fig_nrl.show()


In [None]:
# ╔════════════════════════════════════════════════════════════════╗
# ║  NEW: inter‑nucleosome distance panels  – BOOTSTRAP VERSION    ║
# ╚════════════════════════════════════════════════════════════════╝
import numpy as np, pandas as pd, plotly.graph_objects as go
from plotly.subplots import make_subplots
import multiprocessing as mp
from tqdm.auto import tqdm
from functools import partial
from random import Random

# ───────── USER CONFIG – inter‑nuc distance ───────── #
INT_DIST_STEP       = 5       # grid spacing (bp)
INT_DIST_MAX        = 500     # keep pairs with distance ≤ this
INT_SMOOTH_HALF_WIN = 1       # moving‑avg half‑width (0 ⇒ no smoothing)
REQUIRE_MOD_IN_GAP  = False   # False ⇒ ignore modification filter
INT_XWINDOW         = (-900, +900)

INT_YWINDOW1        = (150, 275)   # Panel 1 y‑range
INT_YWINDOW2        = (-20,  60)   # Panel 2 y‑range

BOOT_N              = 20      # ◀── number of bootstrap groups
BOOT_SEED           = 42      # reproducible shuffling
# ---------------------------------------------------- #

# Derive the axis bounds from THIS cell's INT_XWINDOW. Without this line
# axis_min / axis_max were whatever an earlier cell left behind (stale grid
# after editing INT_XWINDOW above) or undefined on a fresh kernel.
axis_min, axis_max = INT_XWINDOW
axis_pos_int = np.arange(axis_min, axis_max + INT_DIST_STEP, INT_DIST_STEP)

# ──────────────────────────────────────────────────────────────────
#  Helper: one dataframe → median distance & pair‑count curves
# ──────────────────────────────────────────────────────────────────
def _curves_INT(read_df):
    """Median inter-nucleosome distance and pair count per grid position.

    For each position ``p`` of the module-level ``axis_pos_int`` and each row
    of *read_df*, the nearest centre at or left of ``p`` is paired with the
    nearest centre at or right of ``p``. A pair counts when its distance is
    in ``(0, INT_DIST_MAX]`` and, with ``REQUIRE_MOD_IN_GAP`` set, at least
    one position strictly inside the gap has ``mod_qual_bin == 1``.

    Returns
    -------
    (dist_med, pair_cnt)
        ``dist_med[i]`` is the median kept distance at grid position ``i``
        (NaN when none was kept); ``pair_cnt[i]`` is the number kept.
    """
    kept = {p: [] for p in axis_pos_int}
    pair_cnt = np.zeros_like(axis_pos_int, float)

    for read in read_df.itertuples(index=False):
        centres = np.asarray(read.nuc_centers, float)
        if centres.size < 2:
            continue

        relpos = np.asarray(read.rel_pos, int)
        mods = np.asarray(read.mod_qual_bin, int)

        # restrict to centres inside the plotting window
        inside = (centres >= -PLOT_WINDOW) & (centres <= PLOT_WINDOW)
        in_win = centres[inside]
        if in_win.size < 2:
            continue

        for idx, p in enumerate(axis_pos_int):
            left = in_win[in_win <= p]
            right = in_win[in_win >= p]
            if left.size == 0 or right.size == 0:
                continue

            up_c = left.max()
            down_c = right.min()
            dist = down_c - up_c
            in_range = 0 < dist <= INT_DIST_MAX
            if not in_range:
                continue

            if REQUIRE_MOD_IN_GAP:
                in_gap = (relpos > up_c) & (relpos < down_c)
                has_mod = in_gap.any() and (mods[in_gap] == 1).any()
                if not has_mod:
                    continue

            kept[p].append(dist)
            pair_cnt[idx] += 1

    medians = []
    for vals in kept.values():
        medians.append(np.median(vals) if vals else np.nan)
    dist_med = np.array(medians)
    return dist_med, pair_cnt


# ──────────────────────────────────────────────────────────────────
#  Split each condition into BOOT_N groups  (random, no replacement)
# ──────────────────────────────────────────────────────────────────
def _partition_df(df, n_groups, rand):
    """Randomly split *df* into *n_groups* disjoint, near-equal parts.

    Parameters
    ----------
    df : pandas.DataFrame
    n_groups : int
        Number of parts (sizes follow ``numpy.array_split``).
    rand : random.Random
        Source of the shuffle; its state is advanced by one ``shuffle`` call.

    Returns
    -------
    list of pandas.DataFrame
        Each part re-indexed from 0. Every input row lands in exactly one
        part.
    """
    # Shuffle row POSITIONS, not index labels: label-based df.loc returns
    # every row carrying a label, so a non-unique index (e.g. after concat or
    # resampling) duplicated rows. For a unique index the result is the same
    # as before, because the shuffle depends only on the list length.
    order = list(range(len(df)))
    rand.shuffle(order)
    chunks = np.array_split(np.asarray(order, dtype=int), n_groups)
    return [df.iloc[c].reset_index(drop=True) for c in chunks]

# One seeded generator is shared by both calls, so condition B's shuffle
# continues from the state condition A left behind (order of calls matters).
rand = Random(BOOT_SEED)
# groups[cond] = list of BOOT_N disjoint sub-dataframes of that condition.
groups = {COND_A: _partition_df(dfA_all_bal, BOOT_N, rand),
          COND_B: _partition_df(dfB_all_bal, BOOT_N, rand)}

# ──────────────────────────────────────────────────────────────────
#  Compute curves in parallel
# ──────────────────────────────────────────────────────────────────
def _task(subdf):
    """Pool worker: return ``_curves_INT(subdf)`` for one group dataframe."""
    return _curves_INT(subdf)

# results[cond][i] = (median-distance curve, pair-count curve) of group i.
results = {COND_A: [], COND_B: []}
n_workers = max(2, mp.cpu_count() - 1)
with mp.Pool(n_workers) as pool:
    for cond in (COND_A, COND_B):
        # Ordered imap (not imap_unordered): the Δ panel subtracts row i of
        # condition B from row i of condition A, so row order must follow the
        # seeded partition order or the Δ band changes from run to run.
        for med, cnt in tqdm(pool.imap(_task, groups[cond]),
                             total=len(groups[cond]), desc=f"bootstrap {cond}"):
            results[cond].append((med, cnt))

# stacks: shape = (BOOT_N, n_positions)
stack_med  = {c: np.stack([v[0] for v in results[c]]) for c in results}
stack_cnt  = {c: np.stack([v[1] for v in results[c]]) for c in results}

# helper to smooth 1‑d arrays
def _smooth(arr, half_win):
    """Edge-padded centred moving average over ``2 * half_win + 1`` samples.

    *arr* itself is returned (no copy) when ``half_win <= 0`` or when every
    entry is NaN. Otherwise the result has the same length as *arr*; NaNs
    are not skipped, so each window containing one evaluates to NaN.
    """
    if half_win <= 0 or np.isnan(arr).all():
        return arr
    width = 2 * half_win + 1
    padded = np.pad(arr, half_win, mode="edge")
    averaged = np.convolve(padded, np.ones(width) / width, mode="valid")
    # "valid" on the padded input already yields arr.size samples.
    return averaged[:arr.size]

# median & IQR across bootstrap groups (per position)
def _bootstrap_qtuple(stack, smooth_hw):
    """Per-position (q1, median, q3) down the rows of *stack*, each smoothed.

    NaNs are ignored when taking the statistics along axis 0; every resulting
    curve is then passed through ``_smooth`` with half-width *smooth_hw*.
    """
    lower = np.nanpercentile(stack, 25, axis=0)
    centre = np.nanmedian(stack, axis=0)
    upper = np.nanpercentile(stack, 75, axis=0)
    return tuple(_smooth(curve, smooth_hw) for curve in (lower, centre, upper))

# Per-condition (q1, median, q3) bands across the groups: dist_Q for the
# median-distance curves, pair_Q for the pair-count curves.
dist_Q = {c: _bootstrap_qtuple(stack_med[c], INT_SMOOTH_HALF_WIN)
          for c in (COND_A, COND_B)}
pair_Q = {c: _bootstrap_qtuple(stack_cnt[c], INT_SMOOTH_HALF_WIN)
          for c in (COND_A, COND_B)}

# Δ‑panel : use paired bootstrap index (0..BOOT_N‑1) for A – B
# (row i of A minus row i of B, so the band depends on the row order of the
# two stacks)
delta_stack = stack_med[COND_A] - stack_med[COND_B]
dD_Q = _bootstrap_qtuple(delta_stack, INT_SMOOTH_HALF_WIN)

# ──────────────────────────────────────────────────────────────────
#  Plotting
# ──────────────────────────────────────────────────────────────────
# Three stacked panels with independent x axes: distance band, A – B
# difference, and pair counts.
fig_int = make_subplots(
    rows=3, cols=1, shared_xaxes=False,
    vertical_spacing=0.05,
    row_heights=[0.33, 0.33, 0.34],
    subplot_titles=(
        f"Inter‑nucleosome distance (median ± IQR, ≤ {INT_DIST_MAX} bp)",
        f"Δ distance ({COND_A} – {COND_B})",
        "Valid pair count (all bootstrap groups)"
    )
)

def _add_q(fig, x, qtuple, color, name, row):
    """Add a (q1, median, q3) band to subplot (*row*, column 1) of *fig*.

    The quartile bounds go in first as thin dotted lines hidden from the
    legend; the median follows as a solid width-3 line labelled *name*.
    """
    lower, centre, upper = qtuple

    def _line_trace(y_vals, **trace_kwargs):
        # every curve of the band shares x and the "lines" mode
        return go.Scatter(x=x, y=y_vals, mode="lines", **trace_kwargs)

    for bound in (lower, upper):
        fig.add_trace(
            _line_trace(bound,
                        line=dict(color=color, dash="dot", width=1),
                        showlegend=False),
            row=row, col=1)
    fig.add_trace(
        _line_trace(centre, line=dict(color=color, width=3), name=name),
        row=row, col=1)

# Panel 1 – distance (median ± IQR band per condition)
_add_q(fig_int, axis_pos_int, dist_Q[COND_A], CLR_A, COND_A, row=1)
_add_q(fig_int, axis_pos_int, dist_Q[COND_B], CLR_B, COND_B, row=1)
fig_int.update_yaxes(title="median dist (bp)", row=1, col=1)

# Panel 2 – Δ distance (dotted quartile bounds, solid median, y = 0 reference)
q1_d, med_d, q3_d = dD_Q
for arr in (q1_d, q3_d):
    fig_int.add_trace(go.Scatter(x=axis_pos_int, y=arr, mode="lines",
                                 line=dict(color="black", dash="dot", width=1),
                                 showlegend=False), row=2, col=1)
fig_int.add_trace(go.Scatter(x=axis_pos_int, y=med_d, mode="lines",
                             line=dict(color="black", width=3),
                             name="median Δ"), row=2, col=1)
# axis_min / axis_max set the horizontal extent of the zero line
fig_int.add_shape(type="line", x0=axis_min, x1=axis_max, y0=0, y1=0,
                  line=dict(color="grey", dash="dash"), row=2, col=1)
fig_int.update_yaxes(title="Δ distance (bp)", row=2, col=1)

# Panel 3 – pair counts
# thin lines for every bootstrap group (optional, comment out if too busy)
def _rgba(hexclr, alpha):
    """Convert a hex colour to a CSS ``rgba(r,g,b,alpha)`` string.

    Parameters
    ----------
    hexclr : str
        ``"#rrggbb"`` or shorthand ``"#rgb"``; the leading ``#`` is optional.
        Characters after the first six are ignored.
    alpha : float
        Opacity, inserted into the string as given.

    Raises
    ------
    ValueError
        If the string does not contain valid hexadecimal colour digits.
    """
    h = hexclr.lstrip("#")
    if len(h) == 3:
        # expand CSS shorthand: "abc" → "aabbcc"
        h = "".join(ch * 2 for ch in h)
    r, g, b = (int(h[i:i + 2], 16) for i in (0, 2, 4))
    return f"rgba({r},{g},{b},{alpha})"

# Panel 3 – one (q1, median, q3) pair-count band per condition.
for cond, color in ((COND_A, CLR_A), (COND_B, CLR_B)):
    # `rgba` is only needed by the commented-out per-group traces below.
    # NOTE(review): _rgba still runs, so the colours must be hex strings it
    # can parse even though the value is currently unused.
    rgba = _rgba(color, 0.25)
    # for cnt in stack_cnt[cond]:
    #     fig_int.add_trace(go.Scatter(x=axis_pos_int, y=cnt, mode="lines",
    #                                  line=dict(color=rgba, width=1),
    #                                  showlegend=False), row=3, col=1)
    _add_q(fig_int, axis_pos_int, pair_Q[cond], color, cond, row=3)

fig_int.update_yaxes(title="# pairs", row=3, col=1)

# Cosmetics
for r in (1, 2, 3):
    fig_int.update_xaxes(showgrid=False, zeroline=False, range=INT_XWINDOW,
                         row=r, col=1)
    fig_int.update_yaxes(showgrid=False, zeroline=False, row=r, col=1)

fig_int.update_yaxes(range=INT_YWINDOW1, row=1, col=1)
fig_int.update_yaxes(range=INT_YWINDOW2, row=2, col=1)
# ──────────────────────────────────────────────────────────────────
#  Dynamically set y‑range for Panel 3 based on INT_XWINDOW
# ──────────────────────────────────────────────────────────────────
import numpy as np

# make a mask of the x‑positions within the zoom window
mask = (axis_pos_int >= INT_XWINDOW[0]) & (axis_pos_int <= INT_XWINDOW[1])

# collect all y‑values relevant to panel 3 within that x‑range
ys = []
for cond in (COND_A, COND_B):
    # raw per-group counts (included in the range even though the per-group
    # traces above are commented out)
    ys.append(stack_cnt[cond][:, mask])
    # envelope lines (q1 and q3)
    q1, _, q3 = pair_Q[cond]
    ys.append(q1[mask])
    ys.append(q3[mask])

# flatten and compute the NaN-aware min/max (no padding is applied)
all_vals = np.concatenate([arr.ravel() for arr in ys])
ymin, ymax = np.nanmin(all_vals), np.nanmax(all_vals)

# apply to row 3
fig_int.update_yaxes(range=(ymin , ymax ), row=3, col=1)

# dotted 175‑bp reference line
# fig_int.add_hline(y=175, line_dash="dot", line_color="grey", line_width=1,
#                   row=1, col=1)
# fig_int.add_annotation(xref="x domain", yref="y", x=1.0, y=175,
#                        text="175 bp", showarrow=False,
#                        xanchor="left", yanchor="bottom", row=1, col=1)

fig_int.update_layout(template="plotly_white",
                      width=800, height=1200,
                      showlegend=False)

fig_int.show()


In [None]:
# %%────────────────────────────────────────────────────────────────────────────
#  PLOT 100 RANDOM READS (one condition) + ENTROPY (J) UNDERNEATH
#  – Semi-transparent core / linker tracks per read
#  – mod_qual_bin==1 hits as coloured markers
#  – Entropy computed *only* from the same reads
# -----------------------------------------------------------------------------#
import numpy as np, pandas as pd, plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.ndimage import uniform_filter1d

# ───────── USER CONFIG ───────── #
CONDITION_TO_PLOT = COND_A          # ← choose one of your conditions
NUM_READS_PLOT    = 100             # how many reads to display (upper bound)
RAND_SEED         = 22              # seeds both the read sampling and `rng`
# visual
CORE_CLR   = "rgba(0,0,0,0.45)"         # semi-transparent black: core segments
LINK_CLR   = "rgba(160,160,160,0.45)"   # semi-transparent grey: linker segments
HIT_CLR    = "#E31A1C"                  # markers for mod_qual_bin == 1 hits
# NOTE(review): described as a half-window, but the smoothing code below uses
# `MA_WINDOW // 2` as its pad, which suggests a full window width – confirm.
MA_WINDOW  = 3                      # moving-average half-window for J curve
# entropy windowing (reuse globals WIN_BP, PLOT_WINDOW and STEP_BP)
HALF_WIN   = WIN_BP // 2
CENTRES    = np.arange(-PLOT_WINDOW, PLOT_WINDOW+1, STEP_BP)
# ─────────────────────────────── #

rng = np.random.default_rng(RAND_SEED)

# ╔══════════════ 1. SAMPLE READS ═══════════════╗
# Draw up to NUM_READS_PLOT reads of the chosen condition (all of them when
# fewer are available); sampling is seeded with RAND_SEED.
_is_cond = filtered_reads_df.condition == CONDITION_TO_PLOT
_n_take = min(NUM_READS_PLOT, _is_cond.sum())
reads_sub = filtered_reads_df[_is_cond].sample(n=_n_take, random_state=RAND_SEED)
reads_sub = reads_sub.reset_index(drop=True)

# make sure core_pos / link_pos exist  (should, from earlier cell)
if "core_pos" not in reads_sub.columns:
    raise RuntimeError("core_pos/link_pos columns missing – execute state-map cell first.")

# ╔══════════════ 2. COMPOSE READ-LEVEL TRACES ═══════════════╗
def _segments(arr):
    """Yield inclusive ``(first, last)`` value pairs, one per run of
    consecutive values (each element equal to its predecessor + 1) in the
    1-D array *arr*.  An empty array yields nothing."""
    if arr.size == 0:
        return
    # Position k is a break when arr[k + 1] does not continue the run at arr[k].
    breaks = np.flatnonzero(arr[1:] != arr[:-1] + 1)
    first_idx = np.concatenate(([0], breaks + 1))
    last_idx = np.concatenate((breaks, [arr.size - 1]))
    for lo, hi in zip(first_idx, last_idx):
        yield arr[lo], arr[hi]

# Collect every segment of every read into a handful of traces instead of one
# go.Scatter per segment: a NaN between segments breaks the line, so the
# picture is the same while the figure carries at most three traces rather
# than thousands.
scatter_traces = []
core_x, core_y = [], []
link_x, link_y = [], []
hit_x, hit_y = [], []

for ridx, row in reads_sub.iterrows():
    y = ridx + 1                        # 1-based on y-axis

    # core (dark)
    for s, e in _segments(np.sort(row.core_pos)):
        core_x.extend((s, e, np.nan))
        core_y.extend((y, y, np.nan))
    # linker (light grey)
    for s, e in _segments(np.sort(row.link_pos)):
        link_x.extend((s, e, np.nan))
        link_y.extend((y, y, np.nan))
    # accessible hits: positions whose mod_qual_bin flag equals 1
    hits = [p for p, q in zip(row.rel_pos, row.mod_qual_bin) if q == 1]
    hit_x.extend(hits)
    hit_y.extend([y] * len(hits))

for seg_x, seg_y, seg_clr in ((core_x, core_y, CORE_CLR),
                              (link_x, link_y, LINK_CLR)):
    if seg_x:
        scatter_traces.append(go.Scatter(
            x=np.asarray(seg_x, dtype=float),
            y=np.asarray(seg_y, dtype=float),
            mode="lines", line=dict(color=seg_clr, width=3),
            connectgaps=False,          # keep the NaN separators as gaps
            hoverinfo="skip", showlegend=False
        ))

if hit_x:
    scatter_traces.append(go.Scatter(
        x=hit_x, y=hit_y,
        mode="markers",
        marker=dict(symbol="circle", size=4, color=HIT_CLR),
        hoverinfo="skip", showlegend=False
    ))

# ╔══════════════ 3. ENTROPY (J) FOR THESE READS ═══════════════╗
def _pooled_sorted(column):
    """Concatenate the per-read position arrays in *column* into one sorted 1-D array."""
    arrays = [np.asarray(a).ravel() for a in column]
    if not arrays:
        return np.empty(0)
    return np.sort(np.concatenate(arrays))


# Pool positions once, then count per window with two binary searches instead
# of re-scanning every read for every centre.
_core_sorted = _pooled_sorted(reads_sub.core_pos)
_link_sorted = _pooled_sorted(reads_sub.link_pos)

# Inclusive window [lo, hi] around each centre (same bounds as before).
_win_lo = CENTRES - HALF_WIN
_win_hi = CENTRES + HALF_WIN - 1

core_cnt = (np.searchsorted(_core_sorted, _win_hi, side="right")
            - np.searchsorted(_core_sorted, _win_lo, side="left"))
linker_cnt = (np.searchsorted(_link_sorted, _win_hi, side="right")
              - np.searchsorted(_link_sorted, _win_lo, side="left"))

# Two-state Shannon entropy normalised by log(2); undefined (NaN) when either
# state is absent from the window.
J_vals = np.full(len(CENTRES), np.nan)
_both = (core_cnt > 0) & (linker_cnt > 0)
_tot = (core_cnt + linker_cnt)[_both]
_p1 = core_cnt[_both] / _tot
_p0 = linker_cnt[_both] / _tot
J_vals[_both] = -(_p1 * np.log(_p1) + _p0 * np.log(_p0)) / np.log(2)

# optional smoothing: NaN-aware moving average over MA_WINDOW grid points
if MA_WINDOW > 1:
    _finite = np.isfinite(J_vals)
    J_sm = uniform_filter1d(np.where(_finite, J_vals, 0.0),
                            MA_WINDOW, mode="nearest")
    # The weights must be float: on an integer array uniform_filter1d returns
    # integers, truncating the finite fraction to 0 whenever the window holds
    # any NaN, which blanked every partially-covered window.
    cnt = uniform_filter1d(_finite.astype(float),
                           MA_WINDOW, mode="nearest")
    # Smallest genuine non-zero fraction is 1/MA_WINDOW; the half-step
    # threshold ignores floating-point residue around zero.
    _has_data = cnt > 0.5 / MA_WINDOW
    J_plot = np.where(_has_data, J_sm / np.where(_has_data, cnt, 1.0), np.nan)
else:
    J_plot = J_vals

# ╔══════════════ 4. BUILD FIGURE ═══════════════╗
# Two stacked panels sharing the x-axis: per-read tracks on top, the entropy
# curve computed from the same reads underneath.
fig = make_subplots(
    rows=2, cols=1,
    shared_xaxes=True,
    row_heights=[0.7, 0.3],
    vertical_spacing=0.03
)

# top panel: reads
for tr in scatter_traces:
    fig.add_trace(tr, row=1, col=1)

# A descending explicit range puts read 1 at the top.  Passing `range`
# together with autorange="reversed" does not work: autorange recomputes the
# limits from the data and the explicit range is discarded.
fig.update_yaxes(title="Read #", row=1, col=1,
                 range=[len(reads_sub) + 1, 0], autorange=False)

# bottom panel: entropy
fig.add_trace(go.Scatter(
    x=CENTRES, y=J_plot,
    mode="lines",
    line=dict(width=4, color="#1F78B4"),
    name="J (core vs linker)"
), row=2, col=1)

fig.update_yaxes(title="Shannon J", row=2, col=1)

# shared x-axis
fig.update_xaxes(range=[-PLOT_WINDOW, PLOT_WINDOW],
                 title="rel_pos (bp)", row=2, col=1)

# Report the number of reads actually drawn, which can be smaller than
# NUM_READS_PLOT when the condition has fewer reads.
fig.update_layout(
    template="plotly_white",
    width=1000, height=700,
    title=(f"{len(reads_sub)} random reads – {CONDITION_TO_PLOT} "
           f"(cores/linkers & mod hits) + entropy")
)
fig.show()


In [None]:
################################################################################
#  Sliding-window variance profile  +  variance-difference significance        #
#  – One figure, two traces: σ_A(x), σ_B(x)  on left axis                      #
#    and −log10(q) on right axis                                               #
################################################################################
import numpy as np, pandas as pd, plotly.graph_objects as go
from scipy.stats import levene
from statsmodels.stats.multitest import multipletests
from tqdm.auto import tqdm   # progress bar (silently falls back if tqdm missing)

# ───────────────────────── USER CONFIG ────────────────────────── #
WINDOW_BP      = 88           # half-width (bp) of the window around each grid point
STEP_BP        = 5            # spacing (bp) between grid points
PVALUE_METHOD  = "fdr_bh"     # correction method handed to multipletests
ALPHA          = 0.05         # threshold applied to the adjusted p-values
PLOT_TEMPLATE  = "plotly_white"
DEBUG          = True         # set False to silence dbg()
# ───────────────────────────────────────────────────────────────── #

def dbg(msg):
    """Print *msg* only while the module-level DEBUG flag is truthy."""
    if not DEBUG:
        return
    print(msg)

# 1) Gather ALL dyad positions for each condition ──────────────────────────── #
def centres_from_df(df):
    """Return a flat float NumPy array of all nucleosome centres in the dataframe.

    Parameters
    ----------
    df : DataFrame with a ``nuc_centers`` column holding one sequence
        (list or ndarray) of centre positions per row.

    Returns
    -------
    1-D float ndarray with the centres of all rows in row order; an empty
    array when no row has any centre.
    """
    # len() instead of truthiness: `if ndarray` raises for arrays with more
    # than one element.  None entries are skipped like empty ones.
    per_read = [
        np.asarray(c, dtype=float).ravel()
        for c in df["nuc_centers"]
        if c is not None and len(c) > 0
    ]
    # np.concatenate refuses an empty list, so handle "no centres" explicitly.
    if not per_read:
        return np.empty(0, dtype=float)
    return np.concatenate(per_read)

centA, centB = centres_from_df(dfA_all), centres_from_df(dfB_all)
dbg(f"[DATA] {len(centA):,} centres in {COND_A}, {len(centB):,} in {COND_B}")

# 2) Evaluation grid along the x-axis ──────────────────────────────────────── #
xs = np.arange(-PLOT_WINDOW, PLOT_WINDOW + 1, STEP_BP)

# Result vectors start as NaN; a grid point stays NaN when it is skipped below.
σA = np.full(len(xs), np.nan)
σB = np.full(len(xs), np.nan)
pvals = np.full(len(xs), np.nan)

# 3) Sliding-window variance + Levene test ─────────────────────────────────── #
for i, x in enumerate(tqdm(xs, desc="sliding var")):
    winA = centA[np.abs(centA - x) <= WINDOW_BP]
    winB = centB[np.abs(centB - x) <= WINDOW_BP]

    # both conditions need at least 3 centres inside the window
    if min(winA.size, winB.size) < 3:
        continue

    σA[i] = np.std(winA, ddof=1)
    σB[i] = np.std(winB, ddof=1)

    # Brown-Forsythe = Levene with center='median'
    _, p = levene(winA, winB, center='median')
    pvals[i] = p

# 4) Multiple-testing correction ───────────────────────────────────────────── #
# Only finite p-values are corrected; the rest stay NaN in q_full.
mask_p = np.isfinite(pvals)
qvals = multipletests(pvals[mask_p], alpha=ALPHA, method=PVALUE_METHOD)[1]
q_full = np.full_like(pvals, np.nan)
q_full[mask_p] = qvals
sig = q_full < ALPHA

dbg(f"[STATS] {sig.sum()} / {mask_p.sum()} windows significant (FDR<{ALPHA})")

# 5) Build Plotly figure ───────────────────────────────────────────────────── #
# Left axis: per-condition σ of the centres in each window.
# Right axis (y2): −log10 of the adjusted p-value of the variance test.
fig = go.Figure()

fig.add_trace(go.Scatter(
    x=xs, y=σA, mode="lines", name=f"σ {COND_A}",
    line=dict(color=CLR_A, width=2)
))
fig.add_trace(go.Scatter(
    x=xs, y=σB, mode="lines", name=f"σ {COND_B}",
    line=dict(color=CLR_B, width=2)
))

# secondary axis: −log10(q); q == 0 maps to +inf, hence the silenced warning
with np.errstate(divide='ignore'):
    neglogQ = -np.log10(q_full)

fig.add_trace(go.Scatter(
    x=xs, y=neglogQ, mode="lines",
    name="−log10(q)", line=dict(color="black", dash="dot"), yaxis="y2"
))

# highlight significant windows.  The markers carry −log10(q) values, so they
# must sit on the same secondary axis as the dotted curve; on the primary
# axis they would be drawn in σ units, away from the curve.
fig.add_trace(go.Scatter(
    x=xs[sig], y=neglogQ[sig], mode="markers",
    marker=dict(size=6, color="red"),
    name=f"FDR<{ALPHA}", yaxis="y2"
))

fig.update_layout(
    template=PLOT_TEMPLATE,
    width=1000, height=500,
    title=f"Sliding-window variance (window=±{WINDOW_BP} bp, step={STEP_BP} bp)",
    xaxis_title="Genomic rel_pos (bp)",
    yaxis_title="σ (bp)",
    yaxis2=dict(
        title="−log10(q)",
        overlaying="y", side="right", showgrid=False
    ),
    legend=dict(x=0.02, y=0.98)
)
fig.show()
################################################################################


In [None]:
# %%  ───────────────────────────────────────────────────────────────
#  Peak-aligned per-type variance profile (mean ± IQR, adj_rel_pos)
#  ───────────────────────────────────────────────────────────────────
import numpy as np, pandas as pd, plotly.graph_objects as go
from tqdm.auto import tqdm

# ────────── USER CONFIG ADDITIONS ──────────
ALIGN_RANGE   = (-200, 200)   # (min, max) bp range searched for the variance peak
FIG_W, FIG_H  = 1000, 500     # figure width / height in pixels
# ───────────────────────────────────────────

def dbg(msg):
    """Print *msg* only while the module-level DEBUG flag is truthy."""
    if not DEBUG:
        return
    print(msg)

# 1) Centres per type  (same helper as before) ─────────────────────
def centres_by_type(df):
    """Group nucleosome centres by the ``type`` column.

    Parameters
    ----------
    df : DataFrame with a ``type`` column and a ``nuc_centers`` column
        holding one sequence (list or ndarray) of positions per row.

    Returns
    -------
    dict mapping each type (in order of first appearance) to a flat float
    ndarray of its centres; a type without any centre maps to an empty array.
    """
    groups = {}
    for t, sub in df.groupby("type", sort=False):
        # len() instead of truthiness: `if ndarray` raises for arrays with
        # more than one element.  None entries are skipped like empty ones.
        per_read = [
            np.asarray(c, dtype=float).ravel()
            for c in sub["nuc_centers"]
            if c is not None and len(c) > 0
        ]
        # np.concatenate refuses an empty list, so give such types an
        # explicit empty array instead of crashing.
        groups[t] = np.concatenate(per_read) if per_read else np.empty(0, dtype=float)
    return groups

centA_by_type = centres_by_type(dfA_all)
centB_by_type = centres_by_type(dfB_all)

# Only types present in BOTH conditions can be compared; taking the list from
# condition A alone raises KeyError in var_matrix as soon as B lacks a type.
type_list = sorted(set(centA_by_type) & set(centB_by_type))
_unpaired_types = set(centA_by_type) ^ set(centB_by_type)
if _unpaired_types:
    dbg(f"[DATA] {len(_unpaired_types)} types present in only one condition – skipped")

dbg(f"[DATA] {len(type_list)} types shared between conditions")

# 2) Evaluation grid ───────────────────────────────────────────────
xs  = np.arange(-PLOT_WINDOW, PLOT_WINDOW + 1, STEP_BP)
# Index of the grid point at (or nearest to) 0; identical to the exact match
# whenever 0 lies on the grid, and does not fail when it does not.
zero_idx = int(np.argmin(np.abs(xs)))

# 3) Utility: shift array w/ NaNs (no wrap) ────────────────────────
def shift_no_wrap(arr, shift):
    """Return a float copy of 1-D *arr* moved by *shift* index positions.

    Positive *shift* moves values towards higher indices, negative towards
    lower ones.  Values pushed past either end are dropped and the vacated
    positions hold NaN (nothing wraps around).
    """
    shifted = np.full_like(arr, np.nan, dtype=float)
    if shift == 0:
        shifted[:] = arr
        return shifted
    if shift > 0:
        shifted[shift:] = arr[:-shift]
    else:
        shifted[:shift] = arr[-shift:]
    return shifted

# 4) Build variance matrix [n_types × n_xs] with peak alignment ───
def var_matrix(cent_dict):
    """Build the peak-aligned variance matrix for one condition.

    For every type in ``type_list`` the sample variance (ddof=1) of the
    centres within ±WINDOW_BP of each grid point in ``xs`` is computed
    (NaN when fewer than two centres fall in the window).  The profile is
    then shifted so that its largest finite value inside ALIGN_RANGE lands on
    ``zero_idx``; profiles with no finite value there are left unshifted.

    Returns an array of shape (len(type_list), len(xs)).
    """
    in_align = (xs >= ALIGN_RANGE[0]) & (xs <= ALIGN_RANGE[1])
    rows = []
    for t in tqdm(type_list, desc="per-type variance"):
        centres = cent_dict[t]

        # sliding-window variance
        profile = np.full(len(xs), np.nan)
        for i, x in enumerate(xs):
            win = centres[np.abs(centres - x) <= WINDOW_BP]
            if win.size >= 2:
                profile[i] = np.var(win, ddof=1)

        # grid indices eligible as the alignment peak
        candidates = np.flatnonzero(in_align & np.isfinite(profile))
        if candidates.size:
            peak_idx = candidates[np.nanargmax(profile[candidates])]
            offset = peak_idx - zero_idx        # >0 ⇒ peak right of 0
            profile = shift_no_wrap(profile, -offset)

        rows.append(profile)
    return np.vstack(rows)     # shape: (n_types, n_xs)

# Peak-aligned variance matrices, one per condition: rows follow type_list,
# columns follow xs (see var_matrix).
matA = var_matrix(centA_by_type)   # σ² values per type, aligned
matB = var_matrix(centB_by_type)

# 5) Aggregate across types (ignoring NaNs) ────────────────────────
def summarise(mat):
    """Return the column-wise ``(mean, 25th percentile, 75th percentile)`` of
    *mat*, each computed along axis 0 while ignoring NaN entries."""
    centre = np.nanmean(mat, axis=0)
    lower, upper = (np.nanpercentile(mat, pct, axis=0) for pct in (25, 75))
    return centre, lower, upper

meanA, q1A, q3A = summarise(matA)
meanB, q1B, q3B = summarise(matB)

# 6) Plot on adjusted axis (adj_rel_pos == xs) ─────────────────────
# Per condition: a solid mean line followed by dotted Q1 and Q3 lines, all in
# the condition's colour.  Trace order: A (mean, Q1, Q3) then B (mean, Q1, Q3).
fig = go.Figure()

_summary_by_cond = (
    (COND_A, CLR_A, meanA, q1A, q3A),
    (COND_B, CLR_B, meanB, q1B, q3B),
)
for _cond, _clr, _mean, _q1, _q3 in _summary_by_cond:
    fig.add_trace(go.Scatter(
        x=xs, y=_mean, mode="lines", name=f"Mean σ² {_cond}",
        line=dict(color=_clr, width=2)
    ))
    for _label, _vals in (("Q1", _q1), ("Q3", _q3)):
        fig.add_trace(go.Scatter(
            x=xs, y=_vals, mode="lines", name=f"{_label} σ² {_cond}",
            line=dict(color=_clr, width=1, dash="dot")
        ))

_title = (f"Peak-aligned per-type sliding-window variance "
          f"(window ±{WINDOW_BP} bp, step {STEP_BP} bp)")
fig.update_layout(
    template=PLOT_TEMPLATE,
    width=FIG_W, height=FIG_H,
    title=_title,
    xaxis_title="adj_rel_pos (bp, peak variance aligned at 0)",
    yaxis_title="σ² of centres (bp²)",
    legend=dict(x=0.02, y=0.98)
)
fig.show()


In [None]:
###############################################################################
#  Cell 2b – Fast-start “filtered_reads_df” builder (no nucleosome calling)   #
#  ▸ Generates the filtered read table that Cell 3 expects                    #
#  ▸ Adds empty nuc_centers / nuc_coords columns                             #
#  ▸ Defines analysis_cond (edit as needed)                                   #
###############################################################################
import os, warnings
import numpy as np
import pandas as pd

# ───────────────────────── USER CONFIG ────────────────────────── #
# hard-coded condition aliases (Cell 3 expects these); taken from entries 0
# and 2 of analysis_cond, which must already exist in the notebook namespace
COND_A, COND_B = analysis_cond[0], analysis_cond[2]
# 1. The two conditions that are kept and balanced against each other
ANALYSIS_CONDS = [COND_A, COND_B]

DOWNSAMPLE_READS   = True   # False → skip the down-sampling block, keep native counts
DOWNSAMPLE_SEED    = 123    # seed of the RandomState used by the down-sampling block
# NOTE(review): BALANCE_BY_TOTAL_BP is not read by the down-sampling block
# further down in this cell, which always balances on bp – confirm whether the
# count-based mode is still meant to be selectable.
BALANCE_BY_TOTAL_BP  = True       # NEW → False keeps old count-based mode



# 2. Metadata filters (should match Cell 1’s originals)
# TYPES_TO_INCLUDE: ["ALL"] → every type found in merged_df; non-empty list →
# exactly those types; [] → types selected by the chip-rank threshold below
TYPES_TO_INCLUDE      = []       # "ALL" or list or "intergenic_control"
CHR_TYPE_INCLUDE      = []                        # [] → keep every chr_type present in merged_df
CHIP_RANK_CUTOFF      = 80       # compared against chip_rank * 100 from the chip-rank file
ABOVE_FLAG            = True     # True → keep ranks >= cutoff; False → keep ranks < cutoff
MIN_READ_LENGTH       = 500      # minimum read_length (or central span, see REQUIRE_CENTRAL)
REL_POS_RANGE         = 2000     # reads must have a rel_pos within ±REL_POS_RANGE
PLOT_WINDOW = REL_POS_RANGE
REQUIRE_CENTRAL       = False    # True → rel_pos must span ±MIN_READ_LENGTH//2 instead of the length filter
CLR_A , CLR_B    = "#b12537", "#4974a5"   # plot colours for COND_A / COND_B

# Path to rex_chiprank.bed (read when TYPES_TO_INCLUDE is empty or ["ALL"])
CHIPRANK_PATH         = "/Data1/reference/rex_chiprank.bed"

# ───────────────────────── sanity checks ───────────────────────── #
# Fail early with an explanatory message when the input table is missing.
if "merged_df" not in globals():
    raise RuntimeError(
        "Cell 2b needs a DataFrame named `merged_df` in memory.\n"
        "Make sure you’ve loaded it earlier in the notebook."
    )

# ───────────────────────── helper ───────────────────────── #
def _dbg(msg):
    """Print *msg* to stdout behind the ``[Cell 2b]`` tag."""
    tag = "[Cell 2b]"
    print(f"{tag} {msg}")

# ───────────────────────── helper: bp-aware sampler ───────────── #
def _sample_to_bp(df, target_bp, rng):
    """
    Randomly order the reads and keep the smallest non-empty prefix whose
    cumulative read_length ≥ target_bp.

    Parameters
    ----------
    df        : DataFrame with a ``read_length`` column.
    target_bp : base-pair budget to reach.
    rng       : NumPy random generator providing ``permutation``
                (``RandomState`` or ``Generator``).

    Returns
    -------
    ndarray of index labels of the kept reads, in shuffled order.  The kept
    total reaches target_bp whenever df holds that many bp (overshooting by
    less than the last kept read's length); if df holds fewer bp, all reads
    are returned.  At least one read is returned for a non-empty df.
    """
    # permutation() returns a shuffled *copy*.  Shuffling the array returned
    # by df.index.to_numpy() in place can scramble the DataFrame's own index
    # (the array may share memory with it) or fail on a read-only array.
    order = rng.permutation(df.index.to_numpy())

    cum_len = np.cumsum(df.loc[order, "read_length"].to_numpy())
    # first position where the running total reaches the target, inclusive
    keep_up_to = np.searchsorted(cum_len, target_bp, side="left") + 1
    return order[:keep_up_to]


# ───────────────────────── chip-rank lookup (optional) ───────────────────────── #
# The rank table is loaded when no explicit type list is given (empty list or
# ["ALL"]); ranks are stored as chip_rank * 100, keyed by "MOTIFS_<type>".
chip_rank_lookup = {}
if (not TYPES_TO_INCLUDE) or (TYPES_TO_INCLUDE == ["ALL"]):
    chiprank_df = pd.read_csv(CHIPRANK_PATH, sep=r"\s+")
    chiprank_df["type"] = "MOTIFS_" + chiprank_df["type"].astype(str)
    chip_rank_lookup = dict(zip(chiprank_df["type"], chiprank_df["chip_rank"] * 100))

# ───────────────────────── build filtered_reads_df ───────────────────────── #
keep_conds = set(ANALYSIS_CONDS)

# --- decide motif types to keep --------------------------------------------
if not TYPES_TO_INCLUDE:                        # chip-rank threshold
    keep_types = {
        t for t, rk in chip_rank_lookup.items()
        if (rk >= CHIP_RANK_CUTOFF) == ABOVE_FLAG
    }
elif TYPES_TO_INCLUDE == ["ALL"]:               # everything in the table
    keep_types = set(merged_df["type"].unique())
else:                                           # explicit list
    keep_types = set(TYPES_TO_INCLUDE)

# --- chromosomes ------------------------------------------------------------
if CHR_TYPE_INCLUDE:
    keep_chr = set(CHR_TYPE_INCLUDE)
else:
    keep_chr = set(merged_df["chr_type"].unique())

# --- metadata filter --------------------------------------------------------
df0 = merged_df.query(
    "condition in @keep_conds and type in @keep_types and chr_type in @keep_chr"
).copy()
_dbg(f"after metadata filter: {len(df0)} reads")

# --- overlap with REL_POS_RANGE --------------------------------------------
def _touches_window(arr):
    """True when any position of *arr* lies within ±REL_POS_RANGE."""
    return ((arr >= -REL_POS_RANGE) & (arr <= REL_POS_RANGE)).any()

mask_overlap = df0["rel_pos"].apply(_touches_window)
df0 = df0[mask_overlap].reset_index(drop=True)
_dbg(f"after overlap filter:  {len(df0)} reads")

# --- length / central coverage ---------------------------------------------
if REQUIRE_CENTRAL:
    half = MIN_READ_LENGTH // 2

    def _spans_centre(arr):
        """True when *arr* reaches from ≤ -half to ≥ +half."""
        return (arr.min() <= -half) and (arr.max() >= half)

    mask_central = df0["rel_pos"].apply(_spans_centre)
    df1 = df0[mask_central]
else:
    df1 = df0[df0["read_length"] >= MIN_READ_LENGTH]
df1 = df1.reset_index(drop=True)
_dbg(f"after length/central filter: {len(df1)} reads")

# --- methylation content ----------------------------------------------------
def _has_modification(x):
    """True for a list/ndarray whose NaN-ignoring sum is positive."""
    return isinstance(x, (list, np.ndarray)) and np.nansum(x) > 0

mask_valid = df1["mod_qual_bin"].apply(_has_modification)
filtered_reads_df = df1[mask_valid].reset_index(drop=True)
_dbg(f"after methylation filter: {len(filtered_reads_df)} reads")

# ──────────────────── NEW: per-locus balancing key ──────────────────── #
# (insert just **before** the DOWNSAMPLE_READS block)

BAL_KEY_COL = "__bal_type"                         # temporary helper col

# Balancing key: the motif type itself, except for intergenic controls, which
# get "<type>_<bed_start>" so that every control region forms its own group.
_is_control = filtered_reads_df["type"] == "intergenic_control"
_control_key = (
    filtered_reads_df["type"] + "_" + filtered_reads_df["bed_start"].astype(str)
)
filtered_reads_df[BAL_KEY_COL] = np.where(
    _is_control, _control_key, filtered_reads_df["type"]
)

# # temporarily keep only rows where type == the first unique type
# if len(filtered_reads_df[BAL_KEY_COL].unique()) > 1:
#     first_type = filtered_reads_df[BAL_KEY_COL].unique()[0]
#     filtered_reads_df = filtered_reads_df[
#         filtered_reads_df[BAL_KEY_COL] == first_type
#     ].reset_index(drop=True
#     )
# _dbg(f"after balancing key: {len(filtered_reads_df)} reads")

# ─── per-read bp covered *inside* the ±REL_POS_RANGE window ───
def _bp_in_window(rel):
    """Count the distinct integer positions of *rel* that fall within
    ±REL_POS_RANGE (bounds inclusive); repeated positions count once."""
    positions = np.asarray(rel, dtype=int)
    inside = positions[(positions >= -REL_POS_RANGE) & (positions <= REL_POS_RANGE)]
    return np.unique(inside).size

filtered_reads_df["bp_window"] = filtered_reads_df["rel_pos"].apply(_bp_in_window)

# ──────────────────── MODIFIED down-sampling loop ───────────────────── #
# For every balancing key, trim each condition's reads so that the summed
# bp_window does not exceed that of the poorest condition, then replace
# filtered_reads_df by the kept rows and print a balance report.
# NOTE(review): BALANCE_BY_TOTAL_BP is not consulted anywhere in this block.
if DOWNSAMPLE_READS:
    # One RandomState shared by all groups: results depend on group order.
    rng = np.random.RandomState(DOWNSAMPLE_SEED)
    kept_idx   = []                                 # index labels of rows to keep
    rows_drop  = 0                                  # running count of discarded rows
    groups_trim = 0                                 # (key × condition) groups counted as trimmed

    print("\n[Cell 2b] === DOWN-SAMPLING START (seed=", DOWNSAMPLE_SEED, ") ===")

    ##############################################################################
    #   BALANCE EACH (type or intergenic_control_locus) ACROSS CONDITIONS        #
    #   — aim for equal bp                                                       #
    ##############################################################################
    # NOTE(review): REQUIRE_EQUAL_READS does not match read counts.  Both
    # branches below add the same number to rows_drop; the flag only decides
    # whether groups_trim is incremented for groups that lost no reads.
    REQUIRE_EQUAL_READS = False      # set True if you also need n_reads balanced

    for bal_key, df_t in filtered_reads_df.groupby(BAL_KEY_COL, sort=False):

        # Summed in-window bp per condition, ordered like ANALYSIS_CONDS;
        # a condition without reads for this key counts as 0.
        bp_by_cond = (
            df_t.groupby("condition")["bp_window"].sum()
                .reindex(ANALYSIS_CONDS, fill_value=0)
        )
        target_bp = bp_by_cond.min()

        # "A" / "B" are the first and second entry of ANALYSIS_CONDS.
        print(f"   {bal_key!r}: bp A={bp_by_cond.iloc[0]:,}  "
              f"bp B={bp_by_cond.iloc[1]:,}  → target={target_bp:,} bp")

        # Only conditions that actually have reads for this key are visited.
        for cond, g in df_t.groupby("condition", sort=False):

            # random order for fairness
            g_shuf = g.sample(frac=1, random_state=rng)

            # cumulate bp within the window
            cum_bp = g_shuf["bp_window"].cumsum()

            # Reads whose running total stays within the budget; the kept
            # total is therefore ≤ target_bp, not necessarily equal to it.
            keep_mask = cum_bp <= target_bp          # still under budget?

            if not keep_mask.any():
                # No read fits the budget → keep the one with the smallest
                # bp_window.  NOTE(review): this branch is also taken when
                # target_bp is 0 because another condition has no reads for
                # this key, leaving one read in a one-sided group – confirm
                # that this is intended.
                shortest_idx = g_shuf["bp_window"].idxmin()
                kept_idx.append(shortest_idx)
                rows_drop   += len(g_shuf) - 1
                groups_trim += 1
                continue

            # keep all reads that keep us under budget
            kept_idx.extend(g_shuf.loc[keep_mask].index)

            if REQUIRE_EQUAL_READS:
                n_target = keep_mask.sum()
                if len(g_shuf) > n_target:
                    rows_drop   += len(g_shuf) - n_target
                    groups_trim += 1
            else:
                rows_drop   += len(g_shuf) - keep_mask.sum()
                groups_trim += 1

    # Rows appear in the order they were kept; the index is renumbered.
    filtered_reads_df = filtered_reads_df.loc[kept_idx].reset_index(drop=True)
    # ─── QUICK DIAGNOSTIC: are we balancing by reads or by bp? ───
    def _balance_snapshot(df):
        """Per (balancing key, condition): read count, summed and mean
        bp_window, with conditions spread into columns (missing → 0)."""
        snap = (
            df.groupby([BAL_KEY_COL, "condition"])
              .agg(n_reads=("bp_window", "size"),
                   sum_bp =("bp_window", "sum"),
                   mean_bp=("bp_window", "mean"))
              .unstack("condition", fill_value=0)
              .sort_index()
        )
        return snap

    print("\n[DEBUG]  Post-down-sampling balance check  (first 10 loci):")
    snap = _balance_snapshot(filtered_reads_df)
    # `display` is not imported here – presumably the IPython/Jupyter builtin.
    display(snap.head(10))     # if in Jupyter, else print(snap.head(10))

    # Summaries: True only when the between-condition differences are all zero
    same_reads = (snap.xs("n_reads", axis=1, level=0).diff(axis=1).abs().sum().sum() == 0)
    same_bp    = (snap.xs("sum_bp",  axis=1, level=0).diff(axis=1).abs().sum().sum() == 0)
    print(f"\n  • Equal #reads per locus?  {same_reads}")
    print(f"  • Equal bp per locus?      {same_bp}\n")

    print(f"[Cell 2b] === DOWN-SAMPLING DONE: total dropped {rows_drop} rows "
          f"from {groups_trim} (key×cond) combos ===\n")
else:
    print("[Cell 2b] down-sampling disabled – keeping native read counts\n")

# ───── clean-up helper column ─────
filtered_reads_df.drop(columns=BAL_KEY_COL, inplace=True)

# Per-condition summary: read count and summed read_length.
for cond, df_cond in filtered_reads_df.groupby("condition"):
    n_reads_cond = len(df_cond)
    tot_bp = df_cond["read_length"].sum()
    _dbg(f"{cond}: {n_reads_cond} reads, {tot_bp} bp after balancing")

# ───────────────────────── add empty nuc columns ──────────────────────────── #
# A fresh list per row and per column, so no two cells share one list object.
for _nuc_col in ("nuc_centers", "nuc_coords"):
    filtered_reads_df[_nuc_col] = [[] for _ in range(len(filtered_reads_df))]

_dbg("filtered_reads_df ready for Cell 3")
_dbg(f"columns: {list(filtered_reads_df.columns)}")

# ╔═════════════════════ EXTRA GLOBALS NEEDED BY LATER CELLS ══════════════════╗


# basic per-condition DataFrames (independent copies; no motif filtering here,
# later cells can refine them further if they wish)
dfA_all, dfB_all = (
    filtered_reads_df[filtered_reads_df["condition"] == _c].copy()
    for _c in (COND_A, COND_B)
)

_dbg(f"{COND_A}: {len(dfA_all)} reads")
_dbg(f"{COND_B}: {len(dfB_all)} reads")
_dbg("Cell 2b complete – you can run the next analysis cell now.")

In [None]:
from multiprocessing.dummy import Pool as ThreadPool
from numpy.lib.stride_tricks import as_strided
from scipy.ndimage import uniform_filter1d   # at top with other imports
# import gaussian filter 1d
from scipy.ndimage import gaussian_filter1d

SMOOTH_LINE_BP = 25       # width of running mean applied to y-traces
# width (in bp) of uniform box-car smoothing applied to all heat-maps (0 = no smoothing)
SMOOTH_HEAT_BP = 25

# ────────── OCCUPANCY THRESHOLD OPTIONS ──────────
#
# How do we decide whether a window is “occupied”?
#   • "fixed"      → use OCC_THRESH_FIXED for every condition
#   • "percentile" → per-condition threshold = P-th percentile
#   • "mean_sd"    → threshold = mean + K · sd   (per condition)
# ───── USER CONFIG toggles ─────
USE_FRAC_POS = True      # False → mean r  |  True → % reads with r>0
OCC_THRESH_METHOD      = "fixed"   # "fixed" | "percentile" | "mean_sd"
OCC_THRESH_FIXED       = 0.1           # only if method == "fixed"
OCC_THRESH_PERCENTILE  = 80             # only if method == "percentile"
OCC_THRESH_K_SD        = 1.0            # only if method == "mean_sd"

# ───────────────────────── USER CONFIG ────────────────────────── #
BIN_WIDTH_M6A      = 1      # bp per bin
SMOOTH_M6A_BINS    = 1     # Gaussian σ in bins (0 → no smoothing)
USE_THREADS        = True
# Never fewer than one worker: with an empty dfA_all the old value was 0,
# which is not a usable worker count (presumably handed to the thread pool
# imported at the top of this cell).
N_WORKERS          = max(1, min(32, len(dfA_all)))

### TEMPLATE DESIGN
PEAK_FRAC   = 0.18   # 0 – 1
TROUGH_FRAC = 0.8   # 0 – 1   (peak+trough must be < 1)

LINKER_OUTER_BP    = 15     # left & right outer linkers
TRANSITION_BP      = 1      # width of each cosine transition ("dual"/"triple" with
                            # EDGE_STYLE == "cosine"); previously assigned twice
                            # in this cell with the same value

TEMPLATE_FORM  = "single"      # "single" | "dual" | "triple"
EDGE_STYLE     = "binary"    # "cosine" | "binary"

OUT_LINKER_BP  = 15          # outer linkers for "dual" / "triple"
CORE_WIDTH_BP  = 145         # each core for "dual" / "triple"


# HEATMAP SETTINGS
# --------------------------------------------------- #
# define period grid (inclusive of PERIOD_MAX)
PERIOD_MIN         = 175 #10
PERIOD_MAX         = 175 #50
PERIOD_STEP        = 1
PERIODS_BP         = np.arange(PERIOD_MIN, PERIOD_MAX + 1, PERIOD_STEP)
# ── y-axis values for the heat-maps ────────────────────────────────
if TEMPLATE_FORM.lower() in ("dual", "triple"):
    PERIODS_DISPLAY = PERIODS_BP / 2          # core-to-core distance
    Y_AXIS_LABEL    = "Half-period (bp)"
else:                                         # "single"
    PERIODS_DISPLAY = PERIODS_BP
    Y_AXIS_LABEL    = "Period (bp)"


# specify which period to plot in the final line plot
SELECT_PERIOD      = 175
# Row of SELECT_PERIOD on the period grid.  Index the match explicitly:
# int() on a whole array is a deprecated conversion in recent NumPy, and an
# off-grid period used to fail with an unrelated-looking TypeError.
_period_hits = np.flatnonzero(PERIODS_BP == SELECT_PERIOD)
if _period_hits.size == 0:
    raise ValueError(
        f"SELECT_PERIOD={SELECT_PERIOD} is not on the period grid "
        f"({PERIOD_MIN}..{PERIOD_MAX}, step {PERIOD_STEP})"
    )
PERIOD_IDX = int(_period_hits[0])

SKIP_ZERO_WINDOWS = True     # True → zero-variance windows give NaN instead of r = 0

# ───────────────────────── SHARED GRID ────────────────────────── #
# Bin edges across ±PLOT_WINDOW and the corresponding bin centres.
bins_m6a = np.arange(-PLOT_WINDOW, PLOT_WINDOW + BIN_WIDTH_M6A, BIN_WIDTH_M6A)
cent_m6a = (bins_m6a[:-1] + bins_m6a[1:]) / 2

# -----------------------COMPUTE r Threshold per condition---------------------------- #
def _compute_occ_thresh(heats, condition_name):
    """Derive one scalar occupancy threshold from the finite values of
    *heats* according to OCC_THRESH_METHOD, print it tagged with
    *condition_name*, and return it.

    Returns 0.0 (after a debug message) when *heats* holds no finite value.
    Raises ValueError for an unrecognised OCC_THRESH_METHOD.
    """
    finite_vals = heats[np.isfinite(heats)]
    if not finite_vals.size:            # fallback
        print(f"[DEBUG] {condition_name}: no finite r – using 0.0")
        return 0.0

    if OCC_THRESH_METHOD not in ("fixed", "percentile", "mean_sd"):
        raise ValueError(f"unknown OCC_THRESH_METHOD '{OCC_THRESH_METHOD}'")

    rules = {
        "fixed":      lambda v: OCC_THRESH_FIXED,
        "percentile": lambda v: np.percentile(v, OCC_THRESH_PERCENTILE),
        "mean_sd":    lambda v: v.mean() + OCC_THRESH_K_SD * v.std(),
    }
    thr = rules[OCC_THRESH_METHOD](finite_vals)

    print(f"[DEBUG] {condition_name}: occupancy threshold = {thr:.4f} "
          f"({OCC_THRESH_METHOD})")
    return thr


def _make_template(period_bp: int) -> np.ndarray:
    """
    Build a mean-centred template of *exactly* `period_bp` bp
    (``round(period_bp / BIN_WIDTH_M6A)`` bins).

    TEMPLATE_FORM == "single"
        • half peak | fall | trough | rise | half peak, with the peak and
          trough widths given by PEAK_FRAC / TROUGH_FRAC and whatever is
          left split between the two cosine transitions.
        • EDGE_STYLE is not consulted for this form.

    TEMPLATE_FORM == "dual"
        • Pattern 1–0–1–0–1:
              out_linker  |  core  |  mid_linker  |  core  |  out_linker
        • OUT_LINKER_BP and CORE_WIDTH_BP are fixed.
        • mid_linker grows/shrinks with the period.
        • EDGE_STYLE == "cosine" → smooth edges (width = TRANSITION_BP);
          else hard binary.

    TEMPLATE_FORM == "triple"
        • Two nucleosome blocks (out_linker | core | out_linker) separated
          by a gap of NaN bins that absorbs the rest of the period.

    Returns
    -------
    1-D float array of length n_bins_tot.  The mean over the non-NaN bins is
    subtracted; NaN bins stay NaN.  A template that comes out short is
    padded with NaN at the end, one that comes out long is truncated.

    Raises
    ------
    ValueError
        for an unknown TEMPLATE_FORM, PEAK_FRAC + TROUGH_FRAC >= 1, or a
        period too small for the fixed "dual"/"triple" dimensions.
    """
    n_bins_tot = int(round(period_bp / BIN_WIDTH_M6A))
    form = TEMPLATE_FORM.lower()

    # ---- building blocks shared by "dual" and "triple" ----------------------
    def plateau(height, width):
        """Constant segment of `width` bins."""
        return np.full(width, height)

    def edge(a, b, trans_bins):
        """cos² transition from level a to level b (empty for binary edges)."""
        if trans_bins == 0 or a == b:
            return np.empty(0, dtype=float)
        x = np.linspace(0, np.pi / 2, trans_bins, endpoint=False)
        return np.cos(x) ** 2 if a > b else 1 - np.cos(x) ** 2

    def fixed_dimensions():
        """(out, core, transition) widths in bins from the module settings."""
        out_bins = int(round(OUT_LINKER_BP / BIN_WIDTH_M6A))
        core_bins = int(round(CORE_WIDTH_BP / BIN_WIDTH_M6A))
        trans_bins = (int(round(TRANSITION_BP / BIN_WIDTH_M6A))
                      if EDGE_STYLE == "cosine" else 0)
        return out_bins, core_bins, trans_bins

    # ------------------------------------------------ SINGLE -----------
    if form == "single":
        if PEAK_FRAC + TROUGH_FRAC >= 1:
            raise ValueError("PEAK_FRAC + TROUGH_FRAC must be < 1")

        peak_bins   = int(round(n_bins_tot * PEAK_FRAC / 2)) * 2  # split L/R
        trough_bins = int(round(n_bins_tot * TROUGH_FRAC))
        # Rounding can make peak + trough exceed the period (e.g. 17 bins
        # with 0.18 / 0.8); a negative count would crash np.linspace, so
        # clamp to 0 and let the length fix-up below trim the excess.
        trans_bins  = max(0, (n_bins_tot - peak_bins - trough_bins) // 2)
        # segments
        peak_L = np.ones(peak_bins // 2)
        peak_R = np.ones_like(peak_L)
        trough = np.zeros(trough_bins)
        x = np.linspace(0, np.pi, trans_bins, endpoint=False)
        fall  = 0.5 * (1 + np.cos(x))        # 1 → 0
        rise  = fall[::-1]                   # 0 → 1
        tpl = np.concatenate([peak_L, fall, trough, rise, peak_R])

    # ------------------------------------------------ DUAL -------------
    elif form == "dual":
        out_bins, core_bins, trans_bins = fixed_dimensions()

        fixed = 2*out_bins + 2*core_bins + 4*trans_bins
        mid_bins = n_bins_tot - fixed
        if mid_bins < 1:
            raise ValueError(f"period {period_bp} too small for fixed dimensions.")

        tpl = np.concatenate([
            plateau(1, out_bins),
            edge(1, 0, trans_bins),
            plateau(0, core_bins),
            edge(0, 1, trans_bins),
            plateau(1, mid_bins),
            edge(1, 0, trans_bins),
            plateau(0, core_bins),
            edge(0, 1, trans_bins),
            plateau(1, out_bins)
        ])

    # ------------------------------------------------ TRIPLE -----------
    elif form == "triple":
        out_bins, core_bins, trans_bins = fixed_dimensions()

        # one nucleosome block: 1 → 0 → 1
        nuc = np.concatenate([
            plateau(1, out_bins),
            edge(1, 0, trans_bins),
            plateau(0, core_bins),
            edge(0, 1, trans_bins),
            plateau(1, out_bins)
        ])

        fixed = 2 * nuc.size                # two nucleosomes, no gap yet
        gap_bins = n_bins_tot - fixed
        if gap_bins < 0:
            raise ValueError(f"period {period_bp} too small for 'triple' form")

        tpl = np.concatenate([nuc, np.full(gap_bins, np.nan), nuc])

    else:
        raise ValueError("TEMPLATE_FORM must be 'single', 'dual' or 'triple'.")

    # ----- centre on the *valid* (non-NaN) bins only --------------------------
    # NaN bins stay NaN through the subtraction.
    tpl = tpl - np.nanmean(tpl)

    # length fence-post: force exactly n_bins_tot bins
    if tpl.size < n_bins_tot:
        tpl = np.pad(tpl, (0, n_bins_tot - tpl.size), constant_values=np.nan)
    elif tpl.size > n_bins_tot:
        tpl = tpl[:n_bins_tot]

    return tpl




# ─────────────────── prebuild all templates ────────────────────── #
# One template per period on the grid, keyed by the period value.
TEMPLATES = {L: _make_template(L) for L in PERIODS_BP}


# ─────────── show up to five example templates ─────────── #
# Evenly spaced picks across the period grid.  np.unique drops repeated
# indices, which occur whenever the grid has fewer than five periods and used
# to produce identical, stacked traces.
idxs            = np.unique(np.linspace(0, len(PERIODS_BP)-1, 5, dtype=int))
example_periods = PERIODS_BP[idxs]

fig_templates = go.Figure()
for L in example_periods:
    tpl = TEMPLATES[L]
    # center x around zero
    x   = (np.arange(tpl.size) - tpl.size//2) * BIN_WIDTH_M6A
    fig_templates.add_trace(go.Scatter(
        x=x, y=tpl, mode="lines", name=f"{L} bp",connectgaps=False,
    ))
fig_templates.update_layout(
    template="plotly_white",
    title=f"Example {TEMPLATE_FORM.title()} Templates",
    xaxis_title="rel_pos within window (bp)",
    yaxis_title="Mean-centered amplitude",
    width = 800
)
fig_templates.show()

# ─────────────────── m6A occupancy track per read ────────────────── #
def _m6a_track(rel_pos, mod_bin):
    """Build the per-read signal on the ``cent_m6a`` grid.

    Bins containing at least one position whose *mod_bin* flag equals 1 are
    set to +1, every other bin to -1; positions outside the grid are
    ignored.  When SMOOTH_M6A_BINS > 0 the track is smoothed with a Gaussian
    of that sigma (in bins, edges extended with the nearest value).
    """
    n_bins = len(cent_m6a)
    is_modified = np.asarray(mod_bin) == 1
    bin_idx = np.digitize(rel_pos[is_modified], bins_m6a) - 1
    on_grid = bin_idx[(bin_idx >= 0) & (bin_idx < n_bins)]

    track = np.full(n_bins, -1.0)            # background = -1
    if on_grid.size:
        track[np.unique(on_grid)] = 1.0

    if SMOOTH_M6A_BINS > 0:
        track = gaussian_filter1d(track, sigma=SMOOTH_M6A_BINS, mode="nearest")
    return track

# ─────────── normalised sliding correlation (Pearson r) ─────────── #
def _normxcorr1d(track: np.ndarray, tpl: np.ndarray) -> np.ndarray:
    """
    Slide `tpl` along `track` and return the Pearson r of every window.

    NaN entries of `tpl` are excluded from the correlation.  The result has
    the same length as `track`; the r of the window starting at index i is
    stored at i + len(tpl)//2 and all remaining positions are NaN.  If the
    track is shorter than the template an all-NaN array is returned.

    Windows with zero variance give NaN when SKIP_ZERO_WINDOWS is truthy
    and 0.0 otherwise.
    """
    tpl_len, track_len = tpl.size, track.size
    if track_len < tpl_len:
        return np.full(track_len, np.nan)

    # the template's NaN pattern is identical for every window
    keep = ~np.isnan(tpl)
    tpl_centred = tpl[keep] - tpl[keep].mean()
    tpl_l2 = np.linalg.norm(tpl_centred)

    # overlapping windows without copying, then drop the masked columns
    # (the fancy-index step materialises a (n_windows, k) copy)
    step = track.strides[0]
    n_windows = track_len - tpl_len + 1
    all_windows = as_strided(track, (n_windows, tpl_len), (step, step))
    windows = all_windows[:, keep]
    win_centred = windows - windows.mean(axis=1, keepdims=True)

    numer = (win_centred * tpl_centred).sum(axis=1)
    denom = np.linalg.norm(win_centred, axis=1) * tpl_l2

    if SKIP_ZERO_WINDOWS:
        corr = np.divide(numer, denom,
                         out=np.full_like(numer, np.nan),
                         where=denom > 0)
    else:
        # divide by 1 where denom is 0, then overwrite those slots with 0
        safe_denom = np.where(denom == 0, 1, denom)
        corr = numer / safe_denom
        corr[denom == 0] = 0.0

    out = np.full(track_len, np.nan)
    first = tpl_len // 2
    out[first : first + corr.size] = corr
    return out


# ───────────── single-read periodicity heat-map ───────────── #
def _per_read_heat(track):
    """Stack the sliding correlation of `track` against every template.

    Returns a 2-D array with one row per entry of PERIODS_BP (same order).
    """
    rows = []
    for period in PERIODS_BP:
        rows.append(_normxcorr1d(track, TEMPLATES[period]))
    return np.vstack(rows)

# ───────────── wrapper for threading ───────────── #
def _process_m6a(args):
    """Worker: turn one (rel_pos, mod_bin) payload into (heat, zero_mask).

    `heat` is the per-period correlation matrix of the read.  `zero_mask`
    flags, on the PERIOD_IDX row, the positions that count as skipped:
    NaN when SKIP_ZERO_WINDOWS is truthy, exactly 0 otherwise.
    """
    heat = _per_read_heat(_m6a_track(*args))
    selected_row = heat[PERIOD_IDX]
    if SKIP_ZERO_WINDOWS:
        zero_mask = np.isnan(selected_row)
    else:
        zero_mask = (selected_row == 0)
    return heat, zero_mask


# ───────────────────────── condition-level helper ───────────────────────── #
def _condition_heat(df, *, thresh_val=None,return_tot=False):
    """
    Aggregate per-read correlation heat-maps over all rows of `df`.

    Each row must expose `rel_pos` and `mod_qual_bin`; reads are processed
    by `_process_m6a`, in a thread pool when USE_THREADS and N_WORKERS > 1.

    Returns
    -------
    USE_FRAC_POS truthy:
        (frac_pos, pct_zero, thresh_val) where frac_pos is the percentage of
        non-NaN reads with r > thresh_val per (period, position).  With
        `return_tot=True` a fourth item is appended: the non-NaN read count
        of the PERIOD_IDX row.  `thresh_val` is derived with
        `_compute_occ_thresh` when not supplied.
    otherwise:
        (mean_heat, pct_zero, None) – NaN-aware mean over reads.
        NOTE: `return_tot` is ignored in this mode.

    `pct_zero` is the percentage of reads flagged by the worker's mask.
    """
    jobs = []
    for rec in df.itertuples(index=False):
        jobs.append((np.asarray(rec.rel_pos,      dtype=int),
                     np.asarray(rec.mod_qual_bin, dtype=int)))

    if USE_THREADS and N_WORKERS > 1:
        # result order is arbitrary, which is fine: everything below
        # reduces over the read axis
        with ThreadPool(N_WORKERS) as pool:
            per_read = list(tqdm(pool.imap_unordered(_process_m6a, jobs),
                                 total=len(jobs),
                                 desc="m6A correlation"))
    else:
        per_read = [_process_m6a(job) for job in tqdm(jobs,
                                                      desc="m6A correlation")]

    heat_list, mask_list = zip(*per_read)
    heats = np.array(heat_list)        # (n_reads, n_periods, n_pos)
    masks = np.array(mask_list)        # (n_reads, n_pos)

    pct_zero = 100 * masks.sum(axis=0) / masks.shape[0]

    if not USE_FRAC_POS:
        return np.nanmean(heats, axis=0), pct_zero, None

    if thresh_val is None:
        thresh_val = _compute_occ_thresh(heats, df.iloc[0].condition)

    is_valid = ~np.isnan(heats)
    n_above  = np.sum((heats > thresh_val) & is_valid, axis=0)
    n_valid  = np.sum(is_valid, axis=0)
    # positions without any valid read become NaN instead of dividing by 0
    frac_pos = 100 * n_above / np.where(n_valid == 0, np.nan, n_valid)

    if return_tot:
        return frac_pos, pct_zero, thresh_val, n_valid[PERIOD_IDX]
    return frac_pos, pct_zero, thresh_val

# ╔════════════════════════════════════════════════════════════════╗
# ║           BUILD & PLOT HEAT-MAPS –  Pearson correlation        ║
# ╚════════════════════════════════════════════════════════════════╝
# ═════════ BUILD & PLOT HEAT-MAPS ═════════
print("[m6A-PERIODICITY] computing …")

# ---- full-condition calls ----------------------------------------
# No thresh_val is passed, so in USE_FRAC_POS mode each call derives its own
# threshold; in mean mode the third return value is None.
hA, pctA, THR_A = _condition_heat(dfA_all)
hB, pctB, THR_B = _condition_heat(dfB_all)
hD              = hB - hA          # element-wise difference B − A

# Per-condition thresholds, reused by the bootstrap / per-type helpers
COND_THRESH = {COND_A: THR_A, COND_B: THR_B}

# ───────────── helper to make symmetric heat-maps ───────────── #
def _heat(fig_title, Z, fixed_lim=None, y_vals=PERIODS_DISPLAY):
    """
    Build an RdBu heat-map figure of `Z` (rows = y_vals, columns = cent_m6a).

    Colour range:
      • `fixed_lim=(zmin, zmax)` when supplied (zmid is set to 0);
      • otherwise the 5th/95th percentiles of Z, made symmetric around 0
        when they straddle zero.
    No dedicated [0, 100] range exists for the %-reads mode; such maps use
    the percentile rule unless `fixed_lim` is given.

    When SMOOTH_HEAT_BP > 1 every row is first replaced by a NaN-aware
    trailing running mean over round(SMOOTH_HEAT_BP / BIN_WIDTH_M6A) bins;
    the first (window − 1) bins of each row become NaN.

    Returns the plotly Figure (not shown).
    """
    # ───── optional box-car smoothing along the rel_pos axis ─────
    if SMOOTH_HEAT_BP > 1:
        n_bins = max(1, int(round(SMOOTH_HEAT_BP / BIN_WIDTH_M6A)))
        smoothed = np.empty_like(Z)
        for r, values in enumerate(Z):
            ok  = ~np.isnan(values)
            tot = np.cumsum(np.where(ok, values, 0.0))
            num = np.cumsum(ok.astype(int))
            # running totals → totals over the last n_bins bins
            tot[n_bins:] -= tot[:-n_bins]
            num[n_bins:] -= num[:-n_bins]
            avg = tot / np.where(num == 0, np.nan, num)
            avg[:n_bins-1] = np.nan
            smoothed[r] = avg
        Z = smoothed

    # ───── colour limits ─────
    if fixed_lim is None:
        lo_pct, hi_pct = np.nanpercentile(Z, [5, 95])
        if lo_pct < 0 < hi_pct:
            # signed map: keep the scale symmetric so white means zero
            bound = max(abs(lo_pct), abs(hi_pct))
            zmin, zmax, zmid = -bound, bound, 0
        else:
            zmin, zmax, zmid = lo_pct, hi_pct, None
    else:
        zmin, zmax = fixed_lim
        zmid       = 0

    # ───── colour-bar label ─────
    if "Δ" in fig_title:
        cbar_title = "Δ r or Δ %"
    elif USE_FRAC_POS:
        cbar_title = "% reads\n(r > 0)"
    else:
        cbar_title = "Pearson r"

    heat_trace = go.Heatmap(
        x=cent_m6a,
        y=y_vals,
        z=Z,
        colorscale="RdBu",
        zmid=zmid,
        zmin=zmin,
        zmax=zmax,
        colorbar=dict(title=cbar_title),
    )
    fig = go.Figure(heat_trace)
    fig.update_layout(
        template="plotly_white",
        title=fig_title,
        xaxis_title="rel_pos (bp)",
        yaxis_title=Y_AXIS_LABEL,
        xaxis=dict(range=[-PLOT_WINDOW, PLOT_WINDOW]),
        width=800,
    )
    return fig



# No fixed_lim is passed below, so every map falls back to _heat's
# percentile-based colour range.  Re-enable corr_lim to pin the
# single-condition maps to one shared range.
#corr_lim = (-0.02, 0.08)
fig_hA = _heat(f"{COND_A} – Pearson correlation", hA)#, corr_lim)
fig_hB = _heat(f"{COND_B} – Pearson correlation", hB)#, corr_lim)

# difference map (B − A); the "Δ" in the title selects the Δ colour-bar label
fig_hD = _heat("Δ Pearson r (B – A)", hD)

fig_hA.show(); fig_hB.show(); fig_hD.show()
print("[m6A-PERIODICITY] done.")
#
# # ─────────────────── Example individual reads ─────────────────── #
# import plotly.graph_objects as go
#
# # pick two example reads (first two, or .sample(...))
# sample_reads = dfA_all.head(2)
#
# for row in sample_reads.itertuples(index=False):
#     # 1) build the occupancy track
#     track = _m6a_track(
#         np.asarray(row.rel_pos,      dtype=int),
#         np.asarray(row.mod_qual_bin, dtype=int)
#     )
#
#     # plot raw occupancy
#     fig_track = go.Figure(go.Scatter(
#         x=cent_m6a, y=track,
#         mode="lines", name="m6A occupancy"
#     ))
#     fig_track.update_layout(
#         template="plotly_white",
#         title=f"m6A occupancy – read {row.read_id}",
#         xaxis_title="rel_pos (bp)",
#         yaxis_title="Occupancy",
#         xaxis=dict(range=[-PLOT_WINDOW, PLOT_WINDOW]),
#         width = 800
#     )
#     fig_track.show()
#
#     # 2) compute and plot per‐read heatmap
#     H = _per_read_heat(track)
#     # use same fixed limits as single‐cond (–1…1)
#     fig_read = _heat(f"Pearson r – read {row.read_id}", H, fixed_lim=(-1,1))
#     fig_read.show()

# ─────────────────── Correlation vs rel_pos – slider plot ─────────────────── #
import plotly.graph_objects as go

def _nan_running_mean(arr, win):
    """
    NaN-aware running mean over a trailing window of `win` samples.

    Element i is the mean of the non-NaN values in arr[i-win+1 : i+1]
    (NaN when that stretch holds no valid value).  The first win-1
    elements are set to NaN, so the output has the same length as the
    input.  `win` must be >= 1.
    """
    values = np.asarray(arr, float)
    finite = ~np.isnan(values)

    # running totals of the valid values and of how many there were
    run_sum = np.cumsum(np.where(finite, values, 0.0))
    run_cnt = np.cumsum(finite.astype(int))

    # subtract the totals from `win` samples earlier → per-window totals
    run_sum[win:] = run_sum[win:] - run_sum[:-win]
    run_cnt[win:] = run_cnt[win:] - run_cnt[:-win]

    safe_cnt = np.where(run_cnt == 0, np.nan, run_cnt)
    result = run_sum / safe_cnt
    result[:win-1] = np.nan     # incomplete leading windows
    return result

# ─────────── Plot correlation vs rel_pos for one period ────────── #
# Row of the heat-maps that corresponds to SELECT_PERIOD.
# .item() raises unless exactly one entry of PERIODS_BP matches.
idx = np.argwhere(PERIODS_BP == SELECT_PERIOD).item()

# raw per-condition series (may contain NaNs from skipped windows)
yA_raw = hA[idx]
yB_raw = hB[idx]

# ---- smooth with NaN-aware running mean --------------------------
# window length in bins, never less than one bin
win_bins = max(1, int(round(SMOOTH_LINE_BP / BIN_WIDTH_M6A)))
yA  = _nan_running_mean(yA_raw, win_bins)
yB  = _nan_running_mean(yB_raw, win_bins)
yD  = yB - yA                       # Δ after smoothing

# ---- dynamic y-axis limits ---------------------------------------
# NOTE(review): the "% reads r>0" label is fixed text; the percentage itself
# is computed against the threshold used by _condition_heat – confirm the
# label matches that threshold.
if USE_FRAC_POS:
    top_label  = "% reads r>0"
    diff_label = "Δ % (B−A)"
    y_top_rng  = [0, 100]
else:
    top_label  = "Pearson r"
    diff_label = "Δ Pearson r"
    ymin, ymax = np.nanmin([yA, yB]), np.nanmax([yA, yB])
    # 5 % head-room; fixed 0.05 when both traces are flat
    pad        = 0.05 * (ymax - ymin) if ymax > ymin else 0.05
    y_top_rng  = [ymin - pad, ymax + pad]

# symmetric range for the difference panel, with 5 % head-room
diff_max = np.nanmax(np.abs(yD))
y_diff_rng = [-diff_max * 1.05, diff_max * 1.05]

# ---- build figure ------------------------------------------------
# three stacked panels sharing the x axis: traces, difference, % skipped
fig = make_subplots(
    rows=3, cols=1, shared_xaxes=True,
    row_heights=[0.6, 0.25, 0.15],
    vertical_spacing=0.08
)

# row 1: smoothed condition traces
fig.add_trace(go.Scatter(
    x=cent_m6a, y=yA,
    mode="lines", name=COND_A,
    line=dict(color=CLR_A)
), row=1, col=1)

fig.add_trace(go.Scatter(
    x=cent_m6a, y=yB,
    mode="lines", name=COND_B,
    line=dict(color=CLR_B)
), row=1, col=1)

# row 2: smoothed Δ
fig.add_trace(go.Scatter(
    x=cent_m6a, y=yD,
    mode="lines", name=f"{COND_B} − {COND_A}",
    line=dict(color="black", dash="dash")
), row=2, col=1)

# row 3: % of reads flagged as skipped at each position (not smoothed)
fig.add_trace(go.Scatter(
    x=cent_m6a, y=pctA,
    mode="lines", name=f"% zero {COND_A}",
    line=dict(color=CLR_A, dash="dot")
), row=3, col=1)
fig.add_trace(go.Scatter(
    x=cent_m6a, y=pctB,
    mode="lines", name=f"% zero {COND_B}",
    line=dict(color=CLR_B, dash="dot")
), row=3, col=1)

# ---- axes & layout ----------------------------------------------
fig.update_yaxes(title_text=top_label,  row=1, col=1, range=y_top_rng)
fig.update_yaxes(title_text=diff_label, row=2, col=1, range=y_diff_rng)
fig.update_yaxes(title_text="% skipped\n(r = 0)", row=3, col=1, range=[0, 100])

# x title only on the bottom panel (the axes are shared)
fig.update_xaxes(title_text="rel_pos (bp)",
                 row=3, col=1,
                 range=[-PLOT_WINDOW, PLOT_WINDOW])

fig.update_layout(
    template="plotly_white",
    title=(f"Period = {SELECT_PERIOD} bp  –  "
           f"{SMOOTH_LINE_BP} bp running mean"),
    height=700, width=800,
    showlegend=True
)
fig.show()



In [None]:
# %%  ───────────────────────────────────────────────────────────────
#  Box-plots of per-type×bootstrap “% Phased” + motif histogram
#  Keys switch to (type, bed_start) when ‘intergenic_control’ present
#  ───────────────────────────────────────────────────────────────────

# ─────── DISABLE ALL PROGRESS BARS ───────
# NOTE: the lambda below rebinds the name `tqdm` (hiding the module imported
# on the previous line) to a pass-through that returns the iterable and
# ignores keyword arguments such as total= / desc=.
import tqdm
tqdm = lambda iterable, **kwargs: iterable
# ───────── USER-CONFIG (add / update) ─────────
BIN_STEP_BP         = 10           # spacing between consecutive window centres
BIN_WIDTH_BP        = 200          # full width of each analysis window
FIRST_BIN_CENTER_BP = 0            # centre of the first non-negative window
N_BOOT_GROUPS       = 2            # groups the reads of each key are split into
BOOT_GROUP_SEED     = 43           # RNG seed for that split

SUB_BIN_SIZE_BP     = 20          # NEW ⟶ width of sub-bins inside each window
USE_WINDOW_ENTROPY  = True       # NEW ⟶ True ⇒ use Shannon entropy metric
# ──────────────────────────────────────────────

# ───────── ENTROPY BASELINE CONFIG ─────────
BASELINE_PCTL = 0      # percentile to subtract (0–100); 0 ⇒ subtract nothing
CLIP_BELOW_0 = True    # clip negatives to 0 before entropy
# -------------------------------------------


# ───────── PLOT-WINDOW ─────────
# Besides the x-axis range, this also sets how far the window grid extends.
PLOT_HALF_WINDOW_BP = 1000          # ← set any positive integer
PLOT_MOTIF_HIST     = False      # toggle the motif histogram on/off

# ─── decide key granularity ───
# Group by (type, bed_start) as soon as either condition contains an
# "intergenic_control" row; otherwise group by type alone.
has_ic = ((dfA_all["type"] == "intergenic_control").any() or
          (dfB_all["type"] == "intergenic_control").any())

KEY_COLS = ["type", "bed_start"] if has_ic else ["type"]
print(f"[DEBUG] grouping keys =", KEY_COLS)

# ───────────────── helper: baseline-adjusted sub-bin statistic ─────
def _subbin_stat(vec, mask, *, tlabel, cond):
    """
    Baseline-adjusted mean of `vec` inside one sub-bin.

    • vec[mask]     → values averaged with NaNs ignored (0.0 if the mask
                      selects nothing)
    • tlabel, cond  → select the baseline: BASE_A[tlabel] when cond is
                      COND_A, BASE_B[tlabel] otherwise
    The result is clipped at 0 when CLIP_BELOW_0 is truthy.
    """
    raw = np.nanmean(vec[mask]) if mask.sum() != 0 else 0.0

    baselines = BASE_A if cond == COND_A else BASE_B
    adj = raw - baselines[tlabel]

    if CLIP_BELOW_0:
        adj = max(adj, 0.0)
    return adj



def _assign_bootstrap_groups(df, n_groups, seed=42):
    """
    Return a copy of `df` with an integer "__boot" column.

    Within every KEY_COLS group the rows are shuffled (seeded RandomState)
    and dealt round-robin into groups 0 … n_groups-1, so group sizes inside
    a key differ by at most one.  The input frame is not modified.
    """
    rng = np.random.RandomState(seed)
    out = df.copy()
    rows_by_key = out.groupby(KEY_COLS).groups
    for row_labels in rows_by_key.values():
        order = rng.permutation(row_labels)
        out.loc[order, "__boot"] = np.arange(len(order)) % n_groups
    # the column is created as float by the partial assignments above
    out["__boot"] = out["__boot"].astype(int)
    return out

_LN = np.log                                     # natural log (base e)

def _shannon_entropy01(vals):
    """
    Normalised Shannon entropy, 0 ≤ H ≤ 1.

    `vals` can be any non-negative numbers (need not sum to 1); negative
    entries are treated as 0.  The input is never modified.

    Returns NaN when the values sum to 0 or NaN (e.g. any NaN entry), or
    when there is only one value (the maximum entropy ln(1) is 0).
    """
    # np.array always copies, so the clipping below cannot write into an
    # array owned by the caller (np.asarray would alias a float ndarray).
    v   = np.array(vals, dtype=float)
    v[v < 0] = 0
    tot = v.sum()
    if tot == 0 or np.isnan(tot):
        return np.nan
    p        = v / tot
    nz       = p > 0                             # 0·ln(0) is taken as 0
    h_raw    = -(p[nz] * _LN(p[nz])).sum()       # 0‒ln(n)
    h_max    = _LN(len(vals))                    # ln(n)
    return h_raw / h_max if h_max > 0 else np.nan   # 0‒1

# NEW – convert entropy → order once, everywhere
def _order_score(vals):
    """Order score = 1 − normalised Shannon entropy of `vals`.

    Returns NaN whenever the entropy itself is not finite.
    """
    entropy = _shannon_entropy01(vals)
    if not np.isfinite(entropy):
        return np.nan
    return 1 - entropy

# ───────────────── SELECT_PERIOD series per (key, boot) ───────────
def _per_key_boot_series(df):
    """
    Compute the PERIOD_IDX trace for every (KEY_COLS…, __boot) group.

    Each group is passed to `_condition_heat` with the threshold stored in
    COND_THRESH for the group's condition (taken from its first row).
    Returns {group key tuple: 1-D trace}.
    """
    traces = {}
    for key_tuple, group in df.groupby(KEY_COLS + ["__boot"], sort=False):
        if not group.empty:
            cond_thr = COND_THRESH[group.iloc[0].condition]
            heat, _pct_zero, _thr = _condition_heat(group, thresh_val=cond_thr)
            traces[key_tuple] = heat[PERIOD_IDX]     # 1-D trace
    return traces

# Split the reads of each key into N_BOOT_GROUPS random groups.  Both
# conditions use the same seed.
dfA_boot = _assign_bootstrap_groups(dfA_all, N_BOOT_GROUPS, BOOT_GROUP_SEED)
dfB_boot = _assign_bootstrap_groups(dfB_all, N_BOOT_GROUPS, BOOT_GROUP_SEED)

# One PERIOD_IDX trace per (key…, boot) group and condition.
series_A = _per_key_boot_series(dfA_boot)
series_B = _per_key_boot_series(dfB_boot)

# Union of the group keys of both conditions; a key may be missing from
# one of the two dicts.
keys_all = sorted(series_A.keys() | series_B.keys())
print(f"[DEBUG] {len(keys_all)} group×boot combos – first five:", keys_all[:5])

def _build_baseline(series_dict, condition_name):
    """
    Compute one baseline value per type label.

    All traces whose key starts with the same type label are pooled (NaNs
    removed).  The baseline is the BASELINE_PCTL-th percentile of that
    pool, or 0.0 when the pool is empty or BASELINE_PCTL is 0.  A debug
    line is printed for every type; `condition_name` only appears in that
    message.

    Returns {type_label: baseline_value}.
    """
    traces_by_type = {}
    for key, trace in series_dict.items():
        traces_by_type.setdefault(key[0], []).append(trace)   # key[0] is type

    baselines = {}
    for tlabel, traces in traces_by_type.items():
        pooled = np.concatenate([t[~np.isnan(t)] for t in traces])
        wants_percentile = pooled.size != 0 and BASELINE_PCTL != 0
        if wants_percentile:
            baselines[tlabel] = np.nanpercentile(pooled, BASELINE_PCTL)
        else:
            # nothing to subtract
            baselines[tlabel] = 0.0

        print(f"[DEBUG] {condition_name} | {tlabel}: "
              f"{BASELINE_PCTL}th-pct baseline = {baselines[tlabel]:.4f}")
    return baselines

BASE_A = _build_baseline(series_A, COND_A)   # dict {type: value}
BASE_B = _build_baseline(series_B, COND_B)


# ─────────────────── build window grid ────────────────
# Window centres on the non-negative side, from FIRST_BIN_CENTER_BP up to
# PLOT_HALF_WINDOW_BP in steps of BIN_STEP_BP.
pos_centres = np.arange(FIRST_BIN_CENTER_BP,
                        PLOT_HALF_WINDOW_BP + 1,
                        BIN_STEP_BP)

# If the first positive bin is zero, skip that same zero on the
# negative side to avoid duplication.
if FIRST_BIN_CENTER_BP == 0:
    neg_centres = -pos_centres[1:][::-1]   # skip the 0
else:
    neg_centres = -pos_centres[::-1]

# Ascending centres, each with a half-open window [lo, hi) of BIN_WIDTH_BP.
# Windows overlap whenever BIN_STEP_BP < BIN_WIDTH_BP.
centres = np.concatenate([neg_centres, pos_centres])
edges   = [(c - BIN_WIDTH_BP/2, c + BIN_WIDTH_BP/2) for c in centres]
labels  = [f"{int(lo)} to {int(hi)}" for lo, hi in edges]

# ─────────────────── % Phased per window (unchanged) ──────────────
def _window_metric(series_dict, cond):
    """
    Collect one metric value per (window, series) for a single condition.

    For every trace in `series_dict` and every window in `edges`:
      • USE_WINDOW_ENTROPY falsy → NaN-aware maximum of the trace inside
        the window;
      • otherwise → order score of the baseline-adjusted sub-bin means
        (sub-bins of SUB_BIN_SIZE_BP, see `_subbin_stat`).
    Windows that contain no bin of cent_m6a contribute nothing.

    Returns a list (one entry per window) of lists of values.
    """
    per_window = [[] for _ in edges]

    for key_tuple, trace in series_dict.items():
        type_label = key_tuple[0]             # first element is type

        for w, (lo, hi) in enumerate(edges):
            in_win = (cent_m6a >= lo) & (cent_m6a < hi)
            if not in_win.any():
                continue

            if USE_WINDOW_ENTROPY:
                sub_vals = []
                for sub_lo in np.arange(lo, hi, SUB_BIN_SIZE_BP):
                    # the last sub-bin is truncated at the window edge
                    sub_hi = min(sub_lo + SUB_BIN_SIZE_BP, hi)
                    in_sub = in_win & (cent_m6a >= sub_lo) & (cent_m6a < sub_hi)
                    sub_vals.append(
                        _subbin_stat(trace, in_sub, tlabel=type_label, cond=cond)
                    )
                score = _order_score(sub_vals)
            else:
                score = np.nanmax(trace[in_win])

            per_window[w].append(score)

    return per_window



# Per-window metric values for each condition (list of lists, one per window)
phased_A = _window_metric(series_A, COND_A)
phased_B = _window_metric(series_B, COND_B)


# pretty-print string used in titles / y-labels
BASE_TXT = (f" – baseline {BASELINE_PCTL}th pct"
            if BASELINE_PCTL else "")
# ─────────────────── motif-density histogram ──────────
# Distinct (type, motif start) pairs across both conditions.
unique_pairs = set()
for df in (dfA_all, dfB_all):
    for t, starts in zip(df["type"], df["motif_rel_start"]):
        for pos in starts:
            unique_pairs.add((t, pos))

# NOTE(review): the windows overlap when BIN_STEP_BP < BIN_WIDTH_BP, so the
# `break` counts every motif only in the left-most window containing it –
# confirm that this is the intended histogram.
hist_counts = [0] * len(edges)
for _, pos in unique_pairs:
    for i, (lo, hi) in enumerate(edges):
        if lo <= pos < hi:
            hist_counts[i] += 1
            break

# ─────────────────── debug first key ──────────────────
# keys_all is the union of both conditions' keys, so its first entry may be
# missing from series_A (or keys_all may be empty).  Skip the purely
# informational print in that case instead of aborting the cell.
if keys_all and keys_all[0] in series_A:
    first_key = keys_all[0]
    lo0, hi0  = edges[0]
    m0        = (cent_m6a >= lo0) & (cent_m6a < hi0)
    vals0_A   = series_A[first_key][m0]
    rng0      = np.nanmax(vals0_A) - np.nanmin(vals0_A)
    # range as a share of the full scale: 100 in %-reads mode, 2 for r ∈ [-1, 1]
    pct0      = 100*(rng0/100) if USE_FRAC_POS else 100*(rng0/2)
    print(f"[DEBUG] {first_key}, window '{labels[0]}': range={rng0:.4f}, % Phased={pct0:.2f}")
else:
    print("[DEBUG] first key has no condition-A series – window debug skipped")

# ─── helper: 3-bin running mean that ignores NaNs ─────────────────
import numpy as np
def _nan_smooth(arr, k=3):
    """Smooth `arr` with a centred k-bin mean that ignores NaNs.

    The input is NaN-padded by k//2 on both sides, so edge bins average
    over fewer values.  A bin whose whole window is NaN stays NaN.
    Returns a new float array of the same length.
    """
    values = np.asarray(arr, float)
    half = k // 2
    padded = np.pad(values, (half, half), constant_values=np.nan)

    smoothed = np.full_like(values, np.nan)
    for pos in range(len(values)):
        window = padded[pos : pos + k]
        if not np.all(np.isnan(window)):
            smoothed[pos] = np.nanmean(window)
    return smoothed

# ─────────────────── figure (unchanged) ───────────────────────────
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# decide how many rows: the motif histogram is an optional second panel
nrows = 2 if PLOT_MOTIF_HIST else 1
row_heights = [0.7, 0.3] if PLOT_MOTIF_HIST else [1.0]

fig = make_subplots(
    rows=nrows, cols=1, shared_xaxes=True,
    row_heights=row_heights,
    vertical_spacing=0.07
)

# row 1 – per-window median line with thin dashed Q1 / Q3 lines
for idx, (cond, clr, ranges) in enumerate(
        zip([COND_A, COND_B], [CLR_A, CLR_B], [phased_A, phased_B])):

    q1s, meds, q3s = [], [], []
    for win_vals in ranges:
        if win_vals:
            q1, med, q3 = np.percentile(win_vals, [25, 50, 75])
        else:
            # window without any value → gap in the line
            q1 = med = q3 = np.nan
        q1s.append(q1); meds.append(med); q3s.append(q3)

    # --- 3-bin smoothing ----------------------------------------
    q1s  = _nan_smooth(q1s)
    meds = _nan_smooth(meds)
    q3s  = _nan_smooth(q3s)

    # shift the two conditions by ∓DOT_OFFSET so their lines do not overlap
    x_pos = [c - DOT_OFFSET if idx == 0 else c + DOT_OFFSET for c in centres]

    # median
    fig.add_trace(go.Scatter(
        x=x_pos, y=meds, mode="lines",
        name=f"{cond} median",
        line=dict(color=clr, width=4)
    ), row=1, col=1)

    # Q1 & Q3
    for bound in (q1s, q3s):
        fig.add_trace(go.Scatter(
            x=x_pos, y=bound, mode="lines",
            showlegend=False,
            line=dict(color=clr, width=0.5, dash="dash")
        ), row=1, col=1)

# row 2 – motif histogram (optional)
if PLOT_MOTIF_HIST:
    fig.add_trace(go.Bar(
        x=centres,
        y=hist_counts,
        width=BIN_WIDTH_BP,
        marker_color="grey",
        name="unique motif starts / type"
    ), row=2, col=1)

# ─── axes & layout ───
# ───────── USER-CONFIG ─────────
X_TICK_BP = 100          # spacing (bp) between major x-ticks
# ───────────────────────────────────────────────────────────────────────────


# ─── 1. Windowed metric / IQR figure ────────────────────────────
for r in range(1, nrows + 1):
    fig.update_xaxes(
        row=r, col=1,
        type="linear",
        tickmode="linear",
        dtick=X_TICK_BP,                 # ← use configured spacing
        range=[-PLOT_HALF_WINDOW_BP, PLOT_HALF_WINDOW_BP],
        showgrid=False, zeroline=False
    )

# BASE_TXT is rebuilt here with the same content as the earlier definition.
BASE_TXT  = "" if BASELINE_PCTL == 0 else f" – baseline {BASELINE_PCTL}th pct"
ylab      = "Order score (0–1)" if USE_WINDOW_ENTROPY else "Max Pearson r"

# y range is pinned to [0, 1] only for the order score
fig.update_yaxes(
    title = ylab, row = 1, col = 1,
    showgrid = False, zeroline = False,
    range = [0, 1] if USE_WINDOW_ENTROPY else None
)
fig.update_layout(
    template = "plotly_white",
    width    = 1000, height = 650,
    title    = (f"Windowed {BIN_WIDTH_BP} bp analysis{BASE_TXT} "
                f"(centre ±{FIRST_BIN_CENTER_BP} bp; step {BIN_STEP_BP} bp; "
                f"{N_BOOT_GROUPS} bootstrap groups)")
)

fig.show()
# %%  ───────────────────────────────────────────────────────────────

# ──────────────────────────────────────────────────────────────────
#  NEW 3-in-1 COVERAGE FIGURE
#      • total overlap
#      • overlap of reads with r > thresh
#      • overlap of reads with NaN r
# ──────────────────────────────────────────────────────────────────
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ---------- helper: mean ± SEM -----------------------------------
def _mean_sem(vecs):
    """
    Column-wise mean and standard error of equally long vectors.

    `vecs` may be any iterable of 1-D arrays (list, tuple, dict view,
    generator); each vector becomes one row.  NaNs are ignored per column.

    Returns (mean, sem), where sem = sample std (ddof=1) / sqrt(n valid).
    Columns with fewer than two valid rows get a NaN sem.
    """
    # np.vstack only accepts real sequences; callers pass dict.values(),
    # which recent NumPy versions reject, so materialise a list first.
    M    = np.vstack(list(vecs))         # (n_groups, n_windows)
    mean = np.nanmean(M, axis=0)
    n    = np.sum(~np.isnan(M), axis=0)
    sem  = np.nanstd(M, axis=0, ddof=1) / np.sqrt(n)
    return mean, sem

def _hex_to_rgba(hexclr, alpha=0.2):
    """Convert a 6-digit hex colour (leading '#' optional) to an rgba() string.

    `alpha` is inserted verbatim as the fourth component.
    """
    digits = hexclr.lstrip('#')
    red = int(digits[0:2], 16)
    green = int(digits[2:4], 16)
    blue = int(digits[4:6], 16)
    return f"rgba({red},{green},{blue},{alpha})"

# ---------- per-(key,boot) coverage arrays -----------------------
def _per_key_boot_cov3(df):
    """
    Count reads per window for every (KEY_COLS…, __boot) group.

    A read overlaps a window when any of its rel_pos values lies in
    [lo, hi).  Each overlapping read is counted in `total` and then
    classified by its SELECT_PERIOD correlation inside the window:
        • all NaN                          → `nan`
        • NaN-aware max > group threshold  → `thresh`
          (only when COND_THRESH holds a non-None threshold)
        • otherwise                        → counted in `total` only

    Returns three dicts (total_cov, thresh_cov, nan_cov) keyed by the group
    tuple; every value is an int array of length len(edges).
    """
    n_win = len(edges)
    template = TEMPLATES[SELECT_PERIOD]
    # which bins of the r-vector belong to each window (same for all reads)
    win_bins = [(cent_m6a >= lo) & (cent_m6a < hi) for lo, hi in edges]

    total_cov, thresh_cov, nan_cov = {}, {}, {}

    for key_tuple, group in df.groupby(KEY_COLS + ["__boot"], sort=False):
        n_total = np.zeros(n_win, int)
        n_above = np.zeros(n_win, int)
        n_nan   = np.zeros(n_win, int)

        threshold = COND_THRESH[group.iloc[0].condition]

        for read in group.itertuples(index=False):
            rel = np.asarray(read.rel_pos, dtype=int)
            mod = np.asarray(read.mod_qual_bin, dtype=int)

            # correlation vector computed once per read
            r_vec = _normxcorr1d(_m6a_track(rel, mod), template)

            for w, (lo, hi) in enumerate(edges):
                if not ((rel >= lo) & (rel < hi)).any():
                    continue

                n_total[w] += 1
                r_win = r_vec[win_bins[w]]

                if np.all(np.isnan(r_win)):
                    n_nan[w] += 1
                elif (threshold is not None) and (np.nanmax(r_win) > threshold):
                    n_above[w] += 1

        total_cov[key_tuple]  = n_total
        thresh_cov[key_tuple] = n_above
        nan_cov[key_tuple]    = n_nan

    return total_cov, thresh_cov, nan_cov


# Per-group window counts for each condition: total / above-threshold / NaN
covA_tot, covA_thr, covA_nan = _per_key_boot_cov3(dfA_boot)
covB_tot, covB_thr, covB_nan = _per_key_boot_cov3(dfB_boot)

# ─── if we're plotting mean r (not % reads > thresh), use tot − nan ───
# i.e. the middle panel then shows the reads with at least one non-NaN r
if not USE_FRAC_POS:
    covA_thr = { k: covA_tot[k] - covA_nan[k] for k in covA_tot }
    covB_thr = { k: covB_tot[k] - covB_nan[k] for k in covB_tot }

# ---------- aggregate to condition-level mean ± SEM --------------
# each dict value is one group's per-window vector; groups become rows
meanA_tot, semA_tot = _mean_sem(covA_tot.values())
meanB_tot, semB_tot = _mean_sem(covB_tot.values())

meanA_thr, semA_thr = _mean_sem(covA_thr.values())
meanB_thr, semB_thr = _mean_sem(covB_thr.values())

meanA_nan, semA_nan = _mean_sem(covA_nan.values())
meanB_nan, semB_nan = _mean_sem(covB_nan.values())



# ---------- build 3-row figure -----------------------------------
# panels share the x axis: total, above-threshold, NaN coverage
fig_cov = make_subplots(
    rows=3, cols=1, shared_xaxes=True,
    row_heights=[0.4, 0.3, 0.3],
    vertical_spacing=0.05
)

def _add(row, mean, sem, name, clr, dash=None):
    """Add a mean line plus a ±SEM ribbon to panel `row` of fig_cov.

    `mean` and `sem` are per-window arrays aligned with `centres`; `clr`
    is a hex colour, reused (at alpha 0.2) as the ribbon fill.
    """
    # main line; the SEM is attached as customdata for the hover text
    line_trace = go.Scatter(
        x=centres, y=mean, mode="lines",
        name=name, line=dict(color=clr, dash=dash, width=2),
        hovertemplate="window=%{x}<br>mean=%{y:.1f} ± %{customdata:.1f}",
        customdata=sem
    )
    fig_cov.add_trace(line_trace, row=row, col=1)

    # ribbon: lower bound left→right, then upper bound right→left,
    # forming one closed polygon
    ribbon_x = np.concatenate([centres, centres[::-1]])
    ribbon_y = np.concatenate([mean - sem, (mean + sem)[::-1]])
    ribbon_trace = go.Scatter(
        x=ribbon_x,
        y=ribbon_y,
        fill="toself", fillcolor=_hex_to_rgba(clr, 0.2),
        line=dict(color="rgba(0,0,0,0)"), hoverinfo="skip",
        showlegend=False
    )
    fig_cov.add_trace(ribbon_trace, row=row, col=1)

# row 1 – raw coverage (condition B is drawn dashed in every panel)
_add(1, meanA_tot,  semA_tot,  f"{COND_A} total",  CLR_A)
_add(1, meanB_tot,  semB_tot,  f"{COND_B} total",  CLR_B, dash="dash")

# row 2 – r > threshold (or reads with a non-NaN r when USE_FRAC_POS is off)
_add(2, meanA_thr,  semA_thr,  f"{COND_A} r>thr", CLR_A)
_add(2, meanB_thr,  semB_thr,  f"{COND_B} r>thr", CLR_B, dash="dash")

# row 3 – NaN coverage
_add(3, meanA_nan, semA_nan, f"{COND_A} NaN", CLR_A)
_add(3, meanB_nan, semB_nan, f"{COND_B} NaN", CLR_B, dash="dash")

# ---------- axes & layout ----------------------------------------
# ─── 2. 3-row coverage figure ────────────────────────────────────
# same x-axis setup as the windowed-metric figure
for r in (1, 2, 3):
    fig_cov.update_xaxes(
        row=r, col=1,
        type="linear",
        tickmode="linear",
        dtick=X_TICK_BP,                 # ← use configured spacing
        range=[-PLOT_HALF_WINDOW_BP, PLOT_HALF_WINDOW_BP],
        showgrid=False, zeroline=False
    )

# every y axis starts at zero
fig_cov.update_yaxes(title="raw coverage",             row=1, col=1, rangemode="tozero")
fig_cov.update_yaxes(title="coverage (r > thr)",       row=2, col=1, rangemode="tozero")
fig_cov.update_yaxes(title="coverage (r = NaN)",       row=3, col=1, rangemode="tozero")



fig_cov.update_layout(
    template="plotly_white",
    width=1000, height=850,
    title=(f"Coverage metrics per {BIN_WIDTH_BP} bp window "
           f"({N_BOOT_GROUPS} bootstrap group"
           f"{'s' if N_BOOT_GROUPS>1 else ''}; keys = {', '.join(KEY_COLS)})"),
    showlegend=True
)

fig_cov.show()


# PE TYPE PLOTS
# %%────────────────────────────────────────────────────────────────────────────
#  PER-TYPE WINDOWED %-PHASED LINES  +  MOTIF MARKERS
#  (Figure 2 – leaves the combined plot & coverage plots untouched)
# -----------------------------------------------------------------------------#
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ─── gather the list of types present in either condition ────────────────────
# sorted union of the type labels found in the two conditions
types_all = sorted(
    set(dfA_all["type"].unique()) | set(dfB_all["type"].unique())
)

# ─── helper: phased-metric arrays restricted to one type ──────────────────────
def _window_metric_for_type(series_dict, type_label, cond):
    """
    Per-window metric values restricted to one type.

    Only the series whose key starts with `type_label` are kept; they are
    then handed to `_window_metric`, so the per-type figure reports exactly
    the same metric as the combined figure (the order score when
    USE_WINDOW_ENTROPY is on, rather than the raw entropy).

    Returns a list (one entry per window in `edges`) of lists of values.
    """
    of_type = {key: vec for key, vec in series_dict.items()
               if key[0] == type_label}
    return _window_metric(of_type, cond)



# ─── motif positions per type (set for quick lookup) ──────────────────────────
# {type: set of motif start positions}
motifs_by_type = {}
for t, pos in unique_pairs:         # `unique_pairs` built earlier
    motifs_by_type.setdefault(t, set()).add(pos)

# ─── build per-type figure ───────────────────────────────────────
# one equally tall panel per type, each with a secondary y axis
n_rows = len(types_all)
fig_t  = make_subplots(
    rows=n_rows, cols=1,
    shared_xaxes=True,
    specs=[[{"secondary_y": True}] for _ in range(n_rows)],
    vertical_spacing=0.04,
    row_heights=[1 / n_rows] * n_rows
)

def _window_baseline_adj(series_dict, type_label, cond):
    """Return ⟨metric − baseline⟩ per window for one type & condition."""
    means = np.full(len(edges), np.nan)
    baseline = (BASE_A if cond == COND_A else BASE_B)[type_label]

    # pool all vecs belonging to this type
    vecs = [v for k, v in series_dict.items() if k[0] == type_label]
    if not vecs:
        return means

    for idx, (lo, hi) in enumerate(edges):
        win_mask = (cent_m6a >= lo) & (cent_m6a < hi)
        vals = np.concatenate([v[win_mask] for v in vecs])
        vals = vals[~np.isnan(vals)]
        if vals.size:
            means[idx] = np.nanmean(vals) - baseline
            if CLIP_BELOW_0:
                means[idx] = max(means[idx], 0.0)
    return means


# ─── fill the per-type figure: one subplot row per type ───────────────────────
for row_idx, tlabel in enumerate(types_all, start=1):
    # ---- median + IQR of the windowed metric, per condition -----------------
    phased_tA = _window_metric_for_type(series_A, tlabel, COND_A)
    phased_tB = _window_metric_for_type(series_B, tlabel, COND_B)

    for cond_idx, (cond, clr, ranges) in enumerate(
            zip([COND_A, COND_B], [CLR_A, CLR_B], [phased_tA, phased_tB])):

        q1s, meds, q3s = [], [], []
        for win_vals in ranges:
            if win_vals:
                q1, med, q3 = np.percentile(win_vals, [25, 50, 75])
            else:
                q1 = med = q3 = np.nan      # no series contributed to this window
            q1s.append(q1)
            meds.append(med)
            q3s.append(q3)

        q1s  = _nan_smooth(q1s)
        meds = _nan_smooth(meds)
        q3s  = _nan_smooth(q3s)

        # shift the two conditions apart horizontally so their lines do not overlap
        x_pos = [c - DOT_OFFSET if cond_idx == 0 else c + DOT_OFFSET
                 for c in centres]

        # median line (legend entry only for the first row)
        fig_t.add_trace(go.Scatter(
            x=x_pos, y=meds, mode="lines",
            name=f"{tlabel} – {cond} median",
            legendgroup=f"{cond}", showlegend=(row_idx == 1),
            line=dict(color=clr, width=4)
        ), row=row_idx, col=1)

        # IQR bounds
        for bound in (q1s, q3s):
            fig_t.add_trace(go.Scatter(
                x=x_pos, y=bound, mode="lines",
                showlegend=False, legendgroup=f"{cond}",
                line=dict(color=clr, width=0.5, dash="dash")
            ), row=row_idx, col=1)

    # ---- baseline-subtracted binned metric on the secondary axis ------------
    # Added once per row.  This used to sit inside the per-condition loop
    # above, which drew both traces (and their legend entries) twice per row.
    pct_A = _window_baseline_adj(series_A, tlabel, COND_A)
    pct_B = _window_baseline_adj(series_B, tlabel, COND_B)

    fig_t.add_trace(go.Scatter(
        x=centres, y=pct_A, mode="lines",
        name=f"{tlabel} – {COND_A} adj", legendgroup=f"{COND_A}",
        line=dict(color=CLR_A, width=2, dash="dot")
    ), row=row_idx, col=1, secondary_y=True)

    fig_t.add_trace(go.Scatter(
        x=centres, y=pct_B, mode="lines",
        name=f"{tlabel} – {COND_B} adj", legendgroup=f"{COND_B}",
        line=dict(color=CLR_B, width=2, dash="dot")
    ), row=row_idx, col=1, secondary_y=True)

    # ---- type label, placed at the row's vertical mid-point (paper coords) ---
    y_pos = 1 - (row_idx - 0.5) / n_rows
    fig_t.add_annotation(
        text        = f"<b>{tlabel}</b>",
        xref="paper", x = 0.01,                    # left margin
        yref="paper", y = y_pos,
        showarrow   = False,
        font        = dict(size = 12)
    )

    # ---- vertical dashed lines for motif starts -----------------------------
    for mpos in motifs_by_type.get(tlabel, []):
        fig_t.add_vline(
            x=mpos, line_dash="dash", line_color="grey",
            row=row_idx, col=1, opacity=0.5
        )

# ─── y-axes: primary metric on the left, adjusted metric on the right ─────────
order_label = "Order score (0–1)" if USE_WINDOW_ENTROPY else "Max Pearson r"
for r in range(1, n_rows + 1):
    fig_t.update_yaxes(
        title_text = order_label if r == 1 else "",
        row        = r, col = 1,
        showgrid   = False, zeroline = False,
        range      = [0, 1] if USE_WINDOW_ENTROPY else None
    )
    # secondary (baseline-adjusted) axis on the right
    fig_t.update_yaxes(
        title_text = "Adj. metric",
        row        = r, col = 1, secondary_y = True,
        showgrid   = False, zeroline = False
    )


# ─── shared x-axis settings ---------------------------------------------------
fig_t.update_xaxes(
    type="linear", tickmode="linear", dtick=X_TICK_BP,
    range=[-PLOT_HALF_WINDOW_BP, PLOT_HALF_WINDOW_BP],
    showgrid=False, zeroline=False
)

# Single layout call.  A per-row `update_layout(title=...)` used to run inside
# the loop but was always overwritten by this one, so the title shown is unchanged.
fig_t.update_layout(
    template="plotly_white",
    width=1000, height=300 + 175 * n_rows,
    title=(f"Windowed {BIN_WIDTH_BP} bp analysis – per type "
           f"(centre ±{FIRST_BIN_CENTER_BP} bp; step {BIN_STEP_BP} bp)"),
    legend_tracegroupgap=60
)

fig_t.show()
# %%────────────────────────────────────────────────────────────────────────────


In [None]:
# %%  ───────────────────────────────────────────────────────────────
#  DEBUG: key with the most reads-considered  +  chr_type sanity     #
# ────────────────────────────────────────────────────────────────────
import numpy as np, itertools, collections
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ────────────────────────────────────────────────────────────────────
# 1)   find key with max mean reads-considered                       |
# ────────────────────────────────────────────────────────────────────
def _heaviest_key(count_dicts):
    best_key, best_mean = None, -np.inf
    for d in count_dicts:
        for k, vec in d.items():
            m = np.nanmean(vec)
            if m > best_mean:
                best_key, best_mean = k, m
    return best_key, best_mean

# Select the key with the largest mean reads-considered across both conditions.
# The last key element is treated as the bootstrap index; the leading elements
# are joined with "|" for display.
# NOTE(review): when both count dicts are empty `_heaviest_key` returns
# (None, -inf) and the slicing below raises TypeError — confirm this cannot occur.
key_max, mean_max = _heaviest_key([counts_A_cons, counts_B_cons])
type_key = "|".join(map(str, key_max[:-1]))
boot_idx = key_max[-1]

print(f"[DEBUG] heaviest key → '{type_key}', boot={boot_idx}  "
      f"(mean reads ≈ {mean_max:.1f})")

# ────────────────────────────────────────────────────────────────────
# 2)   helper → % r>thr  &  tot reads for SELECT_PERIOD              |
# ────────────────────────────────────────────────────────────────────
def _frac_tot(df_boot, cond):
    mask = (df_boot["__boot"] == boot_idx) & (df_boot["type"] == key_max[0])
    if len(key_max) == 3:                           # includes bed_start
        mask &= df_boot["bed_start"] == key_max[1]
    df_sel = df_boot[mask]
    if df_sel.empty:                                # not present in cond
        return None, None
    thr = COND_THRESH[cond]
    frac, _, _, tot = _condition_heat(
        df_sel, thresh_val=thr, return_tot=True
    )
    return frac[PERIOD_IDX], tot

# Per-condition curves for the selected key; both values are None when the key
# is absent from that condition.
fracA, totA = _frac_tot(dfA_boot, COND_A)
fracB, totB = _frac_tot(dfB_boot, COND_B)

# ────────────────────────────────────────────────────────────────────
# 3)   build two-row plot                                             |
# ────────────────────────────────────────────────────────────────────
# Row 1: fraction curve; row 2: total reads considered.  Rows share the x-axis.
fig_key = make_subplots(
    rows=2, cols=1, shared_xaxes=True,
    row_heights=[0.55, 0.45], vertical_spacing=0.07
)

# Condition A traces (skipped entirely when the key is missing in A)
if fracA is not None:
    fig_key.add_trace(go.Scatter(
        x=cent_m6a, y=fracA, mode="lines",
        name=f"{COND_A} – % r>thr", line=dict(color=CLR_A)
    ), row=1, col=1)
    fig_key.add_trace(go.Scatter(
        x=cent_m6a, y=totA, mode="lines",
        name=f"{COND_A} – tot", line=dict(color=CLR_A, dash="dot")
    ), row=2, col=1)

# Condition B traces (skipped entirely when the key is missing in B)
if fracB is not None:
    fig_key.add_trace(go.Scatter(
        x=cent_m6a, y=fracB, mode="lines",
        name=f"{COND_B} – % r>thr", line=dict(color=CLR_B)
    ), row=1, col=1)
    fig_key.add_trace(go.Scatter(
        x=cent_m6a, y=totB, mode="lines",
        name=f"{COND_B} – tot", line=dict(color=CLR_B, dash="dot")
    ), row=2, col=1)

# NOTE(review): the fixed [0, 100] range assumes `_condition_heat` reports the
# fraction in percent — confirm against that helper.
fig_key.update_yaxes(title="% reads r>thr",       row=1, col=1, range=[0, 100])
fig_key.update_yaxes(title="# reads considered", row=2, col=1, rangemode="tozero")
fig_key.update_xaxes(title="rel_pos (bp)",
                     range=[-PLOT_WINDOW, PLOT_WINDOW],
                     row=2, col=1)

fig_key.update_layout(
    template="plotly_white",
    width=900, height=600,
    title=(f"Key with most reads: '{type_key}', boot={boot_idx}  "
           f"(period {SELECT_PERIOD} bp; thresh={OCC_THRESH_METHOD})")
)

fig_key.show()
# %%  ───────────────────────────────────────────────────────────────


In [None]:
# ╔═══════════════════════════════════════════════════════════════╗
# ║  SPLIT-AND-FLIP  +  POSITIVE-SIDE ANALYTICS (SELF-CONTAINED)  ║
# ╚═══════════════════════════════════════════════════════════════╝
import numpy as np, pandas as pd
from copy import deepcopy
from scipy.ndimage import uniform_filter1d
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ───────────────────────── USER CONFIG ───────────────────────── #
# – Split-read-specific options live here so you can tweak them without
#   touching the first (full-read) cell.
SPLIT_POSITIVE_ONLY       = True        # keep only rows whose rel_read_end ≥ 0
SPLIT_CONS_START_POS      = 115         # first seed for consensus (bp)
SPLIT_NRL                 = 174         # nucleosome repeat length (bp)
SPLIT_CONSENSUS_MODE      = "per_condition_mean"   # "static" | "per_condition_mean"
SPLIT_CONS_ADAPT_WINDOW   = 75          # half-window for per-cond averaging
# ----------------------------------------------------------------

# (the rest of the global constants – BIN_WIDTH_OCC, PLOT_WINDOW, etc. –
#  are reused from the first cell and assumed to be already defined)

# ─── sanity check: keep an untouched copy of the original DataFrame ───
orig_filtered = filtered_reads_df.copy(deep=True)
dbg(f"[SANITY]  start  filtered_reads_df.shape = {filtered_reads_df.shape}")

# ───────────────────── split + optional flip ───────────────────── #
# NOTE(review): `split_and_flip` and its helpers are not defined in this cell;
# they must already exist in the notebook namespace.
# … _split_mask, _filter_motifs, _flip, _flip_np, _flip_coords, _relims …

# --- split the reads ---------------------------------------------------------
split_reads_df = split_and_flip(orig_filtered)
dbg(f"[SPLIT]  {len(orig_filtered)} → {len(split_reads_df)} rows")

# --- enforce positive-side only ----------------------------------------------
# The filter is on the fragment's end coordinate, so fragments that start
# below 0 but end at or above 0 are kept.
if SPLIT_POSITIVE_ONLY:
    split_reads_df = split_reads_df.loc[split_reads_df["rel_read_end"] >= 0]
    dbg(f"[POS-SIDE] kept {len(split_reads_df)} rows with rel_read_end ≥ 0")

# ───────────────────── bin grids (positive side) ───────────────────── #
# Bin edges run from BIN_MIN up to at least PLOT_WINDOW; `cent_*` are the
# mid-points of consecutive edges.
BIN_MIN = 0 if SPLIT_POSITIVE_ONLY else -PLOT_WINDOW
bins_occ = np.arange(BIN_MIN, PLOT_WINDOW + BIN_WIDTH_OCC, BIN_WIDTH_OCC)
cent_occ = (bins_occ[:-1] + bins_occ[1:]) / 2

bins_ctr = np.arange(BIN_MIN, PLOT_WINDOW + BIN_WIDTH_CTR, BIN_WIDTH_CTR)
cent_ctr = (bins_ctr[:-1] + bins_ctr[1:]) / 2

# ───────────────────── masking by condition / motifs ────────────────── #
maskA = (split_reads_df["condition"] == COND_A)
maskB = (split_reads_df["condition"] == COND_B)
# The row-wise apply is the expensive step and its result does not depend on
# the condition, so evaluate `motif_pass` once and reuse it for both subsets.
motif_ok = split_reads_df.apply(motif_pass, axis=1)
dfA_all = split_reads_df[maskA & motif_ok]
dfB_all = split_reads_df[maskB & motif_ok]

# optional sorting for scatter
def _subset(df):
    if READ_SORT_KEY and READ_SORT_KEY in df.columns:
        df = df.sort_values(READ_SORT_KEY)
    return df.head(MAX_READS).reset_index(drop=True)

# Capped subsets are used for the read scatter panels; the *_all frames feed
# the aggregate curves below.
dfA, dfB = _subset(dfA_all), _subset(dfB_all)
dbg(f"[SPLIT-PLOT] {COND_A}: {len(dfA_all)} reads ({len(dfA)} plotted)")
dbg(f"[SPLIT-PLOT] {COND_B}: {len(dfB_all)} reads ({len(dfB)} plotted)")

# ───────────────────── coverage + centre fractions ─────────────────── #
# NOTE(review): `frac_core` / `frac_cent` are defined in an earlier cell;
# presumably they return one fraction per bin — confirm.
fA_occ, fB_occ = (frac_core(dfA_all, bins_occ, cent_occ),
                  frac_core(dfB_all, bins_occ, cent_occ))
d_occ   = fB_occ - fA_occ                                  # B minus A
col_occ = [BAR_POS if x > 0 else BAR_NEG for x in d_occ]   # bar colour by sign

rA_ctr, rB_ctr = (frac_cent(dfA_all, bins_ctr, cent_ctr),
                  frac_cent(dfB_all, bins_ctr, cent_ctr))
# moving average over SMOOTH_BINS_CTR bins; edges padded with the nearest value
fA_ctr = uniform_filter1d(rA_ctr, SMOOTH_BINS_CTR, mode="nearest")
fB_ctr = uniform_filter1d(rB_ctr, SMOOTH_BINS_CTR, mode="nearest")
d_ctr  = fB_ctr - fA_ctr
col_ctr = [BAR_POS if x > 0 else BAR_NEG for x in d_ctr]

# ───────────────────── periodicity heat-maps ──────────────────────── #
def periodicity_heat(series):
    """Sliding-window autocovariance of *series* at each configured period.

    For every position ``i`` a window of ``±PERIOD_WINDOW_BP`` (converted to
    bins of ``BIN_WIDTH_CTR``) is cut out, demeaned with a NaN-ignoring mean,
    and correlated with itself at each lag in ``PERIODS_BP`` (converted to
    whole bins).

    Parameters
    ----------
    series : array-like of float
        Binned signal; may contain NaN for missing bins.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(PERIODS_BP), len(series))``.  Entries stay NaN where the
        window is entirely NaN or the lag does not fit inside the window.
    """
    series = np.asarray(series, dtype=float)
    n_bins = len(series)
    lag_bins = np.round(PERIODS_BP / BIN_WIDTH_CTR).astype(int)
    win_bins = int(PERIOD_WINDOW_BP // BIN_WIDTH_CTR)   # slice bounds must be ints
    H = np.full((len(PERIODS_BP), n_bins), np.nan)
    for i in range(n_bins):
        seg = series[max(0, i - win_bins):min(n_bins, i + win_bins + 1)]
        if np.isnan(seg).all():     # skip all-NaN windows
            continue
        # Demean first, then zero-fill.  Filling before subtracting would turn
        # every missing bin into -mean and leak spurious products into the sum.
        seg = np.nan_to_num(seg - np.nanmean(seg))
        for j, lag in enumerate(lag_bins):
            if 0 <= lag < len(seg):
                # `len(seg) - lag` instead of `-lag`: seg[:-0] would be empty
                # and make the dot product fail for a zero-bin lag.
                H[j, i] = np.dot(seg[:len(seg) - lag], seg[lag:])
    return H

# Periodicity maps for both conditions, scaled per period (row) by the largest
# absolute value found in either condition, so A and B share one scale per row.
hA_raw, hB_raw = periodicity_heat(fA_ctr), periodicity_heat(fB_ctr)
row_scale = np.nanmax(np.abs(np.c_[hA_raw, hB_raw]), axis=1)
row_scale[row_scale == 0] = 1      # avoid division by zero → keeps row at 0
hA, hB = hA_raw / row_scale[:,None], hB_raw / row_scale[:,None]
hD = hB - hA                       # difference map, B minus A

# NOTE(review): `make_heatmap` comes from an earlier cell.
fig_hA = make_heatmap(hA, f"Periodicity – {COND_A} (split)")
fig_hB = make_heatmap(hB, f"Periodicity – {COND_B} (split)")
fig_hD = make_heatmap(hD,  "Δ Periodicity (B–A) · split")
fig_hA.show(); fig_hB.show(); fig_hD.show()

# ───────────────────── multi-panel scatter + curves ────────────────── #
# Rows: 1 = read scatter A, 2 = core fraction A/B, 3 = Δ core,
#       4 = centre fraction A/B, 5 = Δ centre, 6 = read scatter B.
fig = make_subplots(
    rows=6, cols=1, shared_xaxes=True, vertical_spacing=0.02,
    row_heights=[0.22,0.07,0.07,0.07,0.07,0.22]
)
for tr in make_scatter(dfA, CLR_A): fig.add_trace(tr, row=1, col=1)
fig.add_trace(go.Scatter(x=cent_occ, y=fA_occ, name=f"{COND_A} core",
                         line=dict(color=CLR_A,width=2)), row=2,col=1)
fig.add_trace(go.Scatter(x=cent_occ, y=fB_occ, name=f"{COND_B} core",
                         line=dict(color=CLR_B,width=2)), row=2,col=1)
fig.add_trace(go.Bar(x=cent_occ, y=d_occ, marker_color=col_occ,
                     showlegend=False), row=3,col=1)
fig.add_trace(go.Scatter(x=cent_ctr, y=fA_ctr, name=f"{COND_A} centre",
                         line=dict(color=CLR_A,width=2)), row=4,col=1)
fig.add_trace(go.Scatter(x=cent_ctr, y=fB_ctr, name=f"{COND_B} centre",
                         line=dict(color=CLR_B,width=2)), row=4,col=1)
fig.add_trace(go.Bar(x=cent_ctr, y=d_ctr, marker_color=col_ctr,
                     showlegend=False), row=5,col=1)
for tr in make_scatter(dfB, CLR_B): fig.add_trace(tr, row=6, col=1)

# Only the bottom row gets the x title/range; the axes are shared.
fig.update_xaxes(range=[BIN_MIN, PLOT_WINDOW],
                 title="Genomic rel_pos (bp)", row=6,col=1)
fig.update_yaxes(title="Read #", row=1,col=1); fig.update_yaxes(title="Read #", row=6,col=1)
# NOTE(review): Plotly titles are HTML-like; "\n" may not render as a line
# break ("<br>" does) — confirm the intended layout.
fig.update_yaxes(title="% reads\ncore",   tickformat=".0%", row=2,col=1)
fig.update_yaxes(title="Δ core",          tickformat=".0%", row=3,col=1)
fig.update_yaxes(title="% reads\ncentre", tickformat=".0%", row=4,col=1)
fig.update_yaxes(title="Δ centre",        tickformat=".0%", row=5,col=1)

fig.update_layout(template="plotly_white", height=950, width=1000,
    title=(f"Split fragments – A:{COND_A} | B:{COND_B} "
           f"(filter={FILTER_MODE}, thresh={THRESH_DIST}, motifN={MOTIF_FILTER_N})"))
fig.show()

# ───────────────────── consensus offsets (positive side) ──────────── #
# Static consensus seeds: SPLIT_CONS_START_POS, then every SPLIT_NRL bp, up to
# and including PLOT_WINDOW.
consensus = np.arange(SPLIT_CONS_START_POS, PLOT_WINDOW+1, SPLIT_NRL)
consensus = list(consensus)  # plain list, ascending order

def calc_condition_consensus(df, static_pos, half_win=SPLIT_CONS_ADAPT_WINDOW):
    """Refine each static seed to the mean of the nucleosome centres near it.

    For every seed in *static_pos*, all values in the ``nuc_centers`` column
    of *df* lying within ``half_win`` bp of the seed are averaged.  A seed
    with no nearby centre maps to itself.

    Returns a dict ``{seed: refined_position}``.
    """
    refined = {}
    for seed in static_pos:
        nearby = []
        for read in df.itertuples(index=False):
            nearby.extend(c for c in read.nuc_centers if abs(c - seed) <= half_win)
        refined[seed] = np.mean(nearby) if nearby else seed
    return refined

# Reference position for every seed, per condition.  Any mode other than
# "per_condition_mean" is treated as "static": each seed is its own reference.
if SPLIT_CONSENSUS_MODE != "per_condition_mean":
    CONS_REF = {cond_name: dict(zip(consensus, consensus))
                for cond_name in (COND_A, COND_B)}
else:
    CONS_REF = {
        COND_A: calc_condition_consensus(dfA_all, consensus),
        COND_B: calc_condition_consensus(dfB_all, consensus),
    }

def collect_offsets(df, cond_key):
    """Collect |centre − reference| distances per consensus seed.

    Each value in the ``nuc_centers`` column of *df* is assigned to its
    nearest seed in the module-level ``consensus`` (the first seed wins ties).
    Its absolute distance to that seed's reference position
    ``CONS_REF[cond_key][seed]`` is kept when it does not exceed
    ``CONS_WINDOW_BP``.

    Returns a dict ``{seed: [abs_offset, ...]}`` with one (possibly empty)
    list per seed.  With no seeds the result is an empty dict.
    """
    ref_map = CONS_REF[cond_key]
    offs = {p: [] for p in consensus}
    if not offs:
        # no seeds → nothing to assign (argmin over an empty array would raise)
        return offs

    seeds = np.asarray(consensus)   # built once, not once per centre
    for r in df.itertuples(index=False):
        for cen in r.nuc_centers:
            seed = consensus[int(np.argmin(np.abs(seeds - cen)))]
            diff = abs(cen - ref_map[seed])
            if diff <= CONS_WINDOW_BP:
                offs[seed].append(diff)
    return offs

# Absolute offsets to the consensus reference, keyed by condition and seed.
offsets = {COND_A: collect_offsets(dfA_all, COND_A),
           COND_B: collect_offsets(dfB_all, COND_B)}

cat_lbls = [str(c) for c in consensus]

# debug print of the reference position of every seed (condition A only)
dbg(f"[SANITY] consensus positions:")
dbg(f"[SANITY] {COND_A}: {[round(CONS_REF[COND_A][p],1) for p in consensus]}")


# ╔════════════════════════════════════════════════════════════════╗
# ║  Consensus offsets – dot + IQR whisker  +  coverage bar row    ║
# ╚════════════════════════════════════════════════════════════════╝
# ────────────────────────── IQR MEDIAN DOTS + OFFSET + ANNOTATION ──────────────────────────
import numpy as np
import plotly.graph_objects as go

# Horizontal shift between the two conditions: a small fraction of the repeat
# length used to build the consensus seeds in this cell (SPLIT_NRL).
DOT_OFFSET = SPLIT_NRL * 0.1

fig_iqr = go.Figure()

for idx, (cond, clr) in enumerate(zip([COND_A, COND_B], [CLR_A, CLR_B])):
    # quartiles of the absolute offsets at each consensus position
    meds, q1s, q3s = [], [], []
    for cpos in consensus:
        vals = offsets[cond][cpos]
        if vals:
            q1, med, q3 = np.percentile(vals, [25, 50, 75])
        else:
            q1 = med = q3 = np.nan      # no centre assigned to this seed
        q1s.append(q1)
        meds.append(med)
        q3s.append(q3)

    # whisker lengths measured from the median
    err_up   = np.asarray(q3s) - np.asarray(meds)
    err_down = np.asarray(meds) - np.asarray(q1s)

    # first condition is shifted left, second right
    x_pos = [c - DOT_OFFSET if idx == 0 else c + DOT_OFFSET for c in consensus]

    fig_iqr.add_trace(go.Scatter(
        x            = x_pos,
        y            = meds,
        mode         = "markers",
        name         = cond,
        marker       = dict(size=10, color=clr),
        error_y      = dict(
                          type="data",
                          symmetric=False,
                          array=err_up,
                          arrayminus=err_down,
                          thickness=1,
                          color=clr
                       ),
        hovertemplate = (
            "cons=%{customdata[2]}<br>" +
            "Q1=%{customdata[0]:.1f}  median=%{y:.1f}  Q3=%{customdata[1]:.1f}"
        ),
        # columns: Q1, Q3, un-shifted consensus position
        customdata = np.column_stack([q1s, q3s, consensus])
    ))

# vertical line at x=0, spanning the full plot height
fig_iqr.add_shape(
    type="line",
    x0=0, x1=0,
    y0=0, y1=1,
    xref="x", yref="paper",
    line=dict(color="black", dash="dash")
)
fig_iqr.add_annotation(
    x=0, y=1.05,
    xref="x", yref="paper",
    text="best consensus match motif",
    showarrow=False,
    font=dict(color="black")
)

# Numeric (linear) axis with one tick per un-shifted consensus position, so
# the left/right shifted markers of both conditions sit around their tick.
fig_iqr.update_xaxes(
    type="linear",
    tickmode="array",
    tickvals=consensus,
    ticktext=[str(c) for c in consensus],
    title="Consensus position (bp)"
)
fig_iqr.update_layout(
    template="plotly_white",
    width=1000,
    height=500,
    # drop gridlines
    xaxis=dict(showgrid=False, zeroline=False),
    yaxis=dict(showgrid=False, zeroline=False),
    # Report the parameters that built this cell's consensus (the SPLIT_*
    # settings), not the first cell's neg_start / pos_start / NRL.
    title=(
        f"Offsets to consensus (start={SPLIT_CONS_START_POS}, "
        f"NRL={SPLIT_NRL}, mode={SPLIT_CONSENSUS_MODE}, "
        f"window=±{CONS_WINDOW_BP} bp)"
    ),
    # collect_offsets stores absolute distances
    yaxis_title="|Centre – consensus| (bp)"
)
fig_iqr.show()


# ───────────────────── EXPORTS: PNG, SVG, CSVs ───────────────────── #
from pathlib import Path
import pandas as pd

# ---------- config ----------
OUT_DIR   = Path("/Data1/git/meyer-nanopore/scripts/analysis/images_20250604")
OUT_DIR.mkdir(parents=True, exist_ok=True)

RUN_TAG   = "split"
# NOTE(review): BASE_TAG is not used below; it looks like the intended stem
# for a figure export that is currently absent — confirm before removing.
BASE_TAG  = f"iqr_{COND_A}_vs_{COND_B}_{RUN_TAG}"


# ---------- tidy CSV: consensus-offsets (one row per offset) ----------
rows_offsets = [
    {"condition": cond, "consensus_bp": cpos, "offset_bp": off}
    for cond in [COND_A, COND_B]
    for cpos, offs in offsets[cond].items()
    for off in offs
]
# explicit columns keep the header row even when no offsets were collected
pd.DataFrame(
    rows_offsets, columns=["condition", "consensus_bp", "offset_bp"]
).to_csv(
    OUT_DIR / f"offsets_{COND_A}_vs_{COND_B}_{RUN_TAG}.csv", index=False
)

# ---------- tidy CSV: median + IQR ----------
rows_quart = []
for cond in [COND_A, COND_B]:
    for cpos in consensus:
        vals = offsets[cond][cpos]
        if vals:
            q1, med, q3 = np.percentile(vals, [25, 50, 75])
        else:
            q1 = med = q3 = np.nan      # seed without any assigned centre
        rows_quart.append({
            "condition":  cond,
            "consensus_bp": cpos,
            "q1":       q1,
            "median":   med,
            "q3":       q3
        })
pd.DataFrame(
    rows_quart, columns=["condition", "consensus_bp", "q1", "median", "q3"]
).to_csv(
    OUT_DIR / f"iqr_stats_{COND_A}_vs_{COND_B}_{RUN_TAG}.csv", index=False
)

# Only the two CSV files are written here; no image export happens.
print(f"[EXPORT] saved offsets CSV and IQR CSV → {OUT_DIR}")
# ─────────────────────────────────────────────────────────────────── #


In [None]:
# ───────────────── Motif‑accessibility (0‑4 empty motifs) ─────────────────
import numpy as np, pandas as pd, plotly.graph_objects as go

# ---------------- USER KNOBS ----------------
# NOTE(review): THRESH_DIST is also read by other cells (e.g. a figure title);
# assigning it here changes the notebook-wide value — confirm that is intended.
CONDITIONS   = analysis_cond   # conditions to plot
THRESH_DIST  = 70          # bp; a motif counts as empty if no nuc centre is within this distance
NEAREST_N    = 4           # number of motifs closest to rel_pos 0 that are examined
BUFFER_BP    = 20          # bp a read must extend beyond the outermost examined motif
SHOW_DEBUG   = True        # enables dbg() output
# -------------------------------------------

def dbg(m):
    """Print *m* only while the module-level ``SHOW_DEBUG`` switch is truthy."""
    if not SHOW_DEBUG:
        return
    print(m)

# ── helper: count empty motifs (return None if read fails span check) ──
def empty_cnt(row):
    """Count how many of the motifs nearest to position 0 are nucleosome-free.

    The ``NEAREST_N`` motifs with the smallest ``|motif_rel_start|`` are
    examined.  A motif counts as empty when no nucleosome centre lies within
    ``THRESH_DIST`` bp of it.

    Parameters
    ----------
    row : object
        Needs attributes ``motif_rel_start``, ``nuc_centers`` (list or
        array), ``rel_read_start`` and ``rel_read_end``.

    Returns
    -------
    int or None
        Number of empty motifs, or ``None`` when the read has no motifs, no
        nucleosome centres, or does not span the examined motifs with
        ``BUFFER_BP`` bp to spare on both sides.
    """
    motifs = list(row.motif_rel_start)
    centers = row.nuc_centers
    if not motifs or centers is None:
        return None
    # Test emptiness via .size: truth-testing a multi-element ndarray raises
    # ValueError, and nuc_centers is handled as array-like elsewhere.
    centers = np.asarray(centers)
    if centers.size == 0:
        return None

    nearest = sorted(motifs, key=abs)[:NEAREST_N]
    span_ok = (row.rel_read_start <= min(nearest) - BUFFER_BP) and \
              (row.rel_read_end   >= max(nearest) + BUFFER_BP)
    if not span_ok:
        return None

    return int(sum(
        np.min(np.abs(centers - m)) > THRESH_DIST
        for m in nearest
    ))

# ── collect fractions: cond × k (0…4) → list[fractions per motif type] ──
# data[cond][k] collects, one entry per motif type, the fraction of eligible
# reads of that type having exactly k empty motifs.
data = {c:{k:[] for k in range(NEAREST_N+1)} for c in CONDITIONS}

for cond in CONDITIONS:
    df_c = filtered_reads_df[filtered_reads_df["condition"] == cond]
    for typ in df_c["type"].unique():
        df_t = df_c[df_c["type"] == typ]
        eligible = 0                              # reads passing empty_cnt's checks
        counts   = np.zeros(NEAREST_N+1, int)     # histogram over k = 0…NEAREST_N
        for row in df_t.itertuples(index=False):
            k = empty_cnt(row)
            if k is None:
                continue
            k = min(k, NEAREST_N)
            counts[k] += 1
            eligible  += 1
        if eligible == 0:
            continue
        for k in range(NEAREST_N+1):
            # NOTE(review): zero counts are skipped, so a type with no reads at
            # k contributes no 0.0 point to that box — confirm this is intended.
            if counts[k]:
                data[cond][k].append(counts[k]/eligible)
        dbg(f"{cond} | {typ}: eligible={eligible}  " +
            ", ".join(f"k{k}={counts[k]}" for k in range(NEAREST_N+1)))

# ── make grouped boxplots ──
# NOTE(review): the palette has five colours, so `palette[k]` needs NEAREST_N ≤ 4.
palette = ["#1b9e77","#d95f02","#7570b3","#e7298a","#66a61e"]  # 0→4
fig = go.Figure()

# one Box trace per (k, condition); k runs from NEAREST_N down to 0
for k in range(NEAREST_N, -1, -1):
    for cond in CONDITIONS:
        vals = data[cond][k]
        if not vals:
            continue
        fig.add_trace(go.Box(
            x=[cond]*len(vals),
            y=vals,                       # fractions (0–1)
            name=f"{k} motifs accessible",
            legendgroup=f"k{k}",
            marker_color=palette[k],
            boxpoints="all", jitter=0.3, pointpos=0,
            hovertemplate="%{y:.1%}"
        ))

fig.update_layout(
    template="plotly_white",
    boxmode="group",
    title=(f"% reads with 0–{NEAREST_N} motifs accessible "
           f"(nuc centre > {THRESH_DIST} bp; reads span motifs±{BUFFER_BP} bp)"),
    xaxis_title="Condition",
    yaxis_title="% reads with k motifs accessible",
    height=550, width=850,
    showlegend=True
)
fig.update_yaxes(tickformat=".0%")   # y values are fractions, shown as percent
fig.show()


In [None]:
import numpy as np
from scipy.stats import gamma
from scipy.ndimage import uniform_filter1d
from scipy.signal import find_peaks
import plotly.graph_objects as go

# ── PARAMETERS ───────────────────────────────────────────────────────────────
DEBUG                  = True   # enables dbg() output
bin_size               = CENTER_HIST_BIN
# bin edges cover [-REL_POS_RANGE, REL_POS_RANGE]; bin_centres are the mid-points
bins                   = np.arange(-REL_POS_RANGE, REL_POS_RANGE + bin_size, bin_size)
bin_centres            = (bins[:-1] + bins[1:]) / 2
SMOOTH_BINS            = 5      # moving-average width (bins) applied to the posteriors
NEAREST_CENTER_THRESH  = 70     # splits reads by their smallest |nuc centre|
EB_ALPHA0              = 1.0    # prior shape added to the per-bin counts
EB_BETA0               = 1.0    # prior rate added to the per-bin coverage
MIN_PEAK_DIST_BP       = 120   # minimum distance between consensus peaks (bp)
MAX_PEAK_GAP_BP        = 200   # maximum allowed gap before filling (bp)

conds                  = sorted(filtered_reads_df["condition"].unique())
# pick which conditions to use for consensus:
# e.g. conds               → all conditions
#      [analysis_cond[0]]  → just the first
CONSENSUS_CONDS        = conds

# ── DEBUG UTILITY ─────────────────────────────────────────────────────────────
def dbg(msg):
    """Print *msg* with a ``[DEBUG]`` prefix while ``DEBUG`` is truthy."""
    if not DEBUG:
        return
    print(f"[DEBUG] {msg}")

# ── 1) COMPUTE POSTERIORS ─────────────────────────────────────────────────────
def compute_posteriors(df):
    """Per-bin Gamma posterior of nucleosome-centre occurrence, per condition.

    For every condition in the module-level ``conds``:

    * ``n`` counts, per bin, the reads whose ``rel_pos`` range fully covers
      the bin;
    * ``k`` counts, per bin, the reads with at least one ``nuc_centers`` value
      in the bin (a read is counted once per bin).

    The posterior is ``Gamma(shape=EB_ALPHA0 + k, rate=EB_BETA0 + n)``.  Its
    mean and 2.5 % / 97.5 % quantiles are smoothed with a ``SMOOTH_BINS``-wide
    moving average.

    Returns ``(post_mean, post_lower, post_upper)``, each a dict keyed by
    condition with one value per bin.
    """
    dbg(f"compute_posteriors: {len(df)} reads")
    n_bins = len(bin_centres)
    left_edges, right_edges = bins[:-1], bins[1:]
    counts = {c: np.zeros(n_bins, dtype=int) for c in conds}
    denom = {c: np.zeros(n_bins, dtype=int) for c in conds}

    for c in conds:
        grp = df[df["condition"] == c]
        dbg(f"  '{c}': {len(grp)} reads")
        for row in grp.itertuples():
            rel = np.asarray(row.rel_pos)
            if rel.size == 0:
                continue
            # bins lying entirely inside the read's covered range
            spanned = (left_edges >= rel.min()) & (right_edges <= rel.max())
            denom[c][spanned] += 1

            centers = np.asarray(row.nuc_centers)
            if centers.size == 0:
                continue
            hit = np.floor((centers - bins[0]) / bin_size).astype(int)
            hit = np.unique(hit[(hit >= 0) & (hit < n_bins)])
            counts[c][hit] += 1     # indices are unique → each bin gains 1 per read

    post_mean, post_lower, post_upper = {}, {}, {}
    for c in conds:
        k, n = counts[c], denom[c]
        dbg(f"  posterior '{c}': k.sum()={k.sum()}, n.sum()={n.sum()}")
        shape = EB_ALPHA0 + k
        rate = EB_BETA0 + n
        raw = (
            shape / rate,
            gamma.ppf(0.025, a=shape, scale=1 / rate),
            gamma.ppf(0.975, a=shape, scale=1 / rate),
        )
        for store, values in zip((post_mean, post_lower, post_upper), raw):
            store[c] = uniform_filter1d(values, size=SMOOTH_BINS, mode="nearest")
    return post_mean, post_lower, post_upper

# Posteriors over all filtered reads (mean, lower and upper bound per condition).
pm_all, pl_all, pu_all = compute_posteriors(filtered_reads_df)

# ── 2) FIND CONSENSUS FROM PEAKS IN POSTERIOR MEAN ────────────────────────────
# aggregate the posterior means for the chosen conditions (plain average)
agg = np.zeros_like(bin_centres)
for c in CONSENSUS_CONDS:
    agg += pm_all[c]
agg /= len(CONSENSUS_CONDS)
dbg(f"Aggregated posterior mean over {CONSENSUS_CONDS}")

# convert bp thresholds into bin‐units
min_dist_bins = int(np.ceil(MIN_PEAK_DIST_BP / bin_size))
max_gap_bp    = MAX_PEAK_GAP_BP

# find initial peaks (local maxima at least `min_dist_bins` apart)
peak_idxs    = find_peaks(agg, distance=min_dist_bins)[0]
consensus    = bin_centres[peak_idxs]
dbg(f"Initial consensus peaks: {consensus}")

# fill gaps > MAX_PEAK_GAP_BP
# Repeat until no neighbouring peaks are further apart than max_gap_bp.  Each
# pass inserts one new position per oversized gap: the bin centre with the
# highest aggregated value strictly inside the gap, or the midpoint when no
# bin centre lies inside.
while True:
    diffs = np.diff(consensus)
    gaps  = np.where(diffs > max_gap_bp)[0]
    if gaps.size==0:
        break
    # walk right-to-left so earlier gap indices stay valid after each insert
    for i in reversed(gaps):
        lo, hi = consensus[i], consensus[i+1]
        mask    = (bin_centres>lo)&(bin_centres<hi)
        if mask.any():
            sub_vals = agg[mask]
            sub_pos  = bin_centres[mask]
            newp     = sub_pos[np.argmax(sub_vals)]
            dbg(f"  Filling gap ({lo},{hi}) with peak at {newp}")
        else:
            newp = (lo+hi)/2
            dbg(f"  Filling gap ({lo},{hi}) with midpoint {newp}")
        consensus = np.insert(consensus, i+1, newp)
dbg(f"Final consensus centers: {consensus}")

# ── 3) FORM SUBSETS FOR BOXPLOTS ────────────────────────────────────────────────
def has_valid(c):
    """Return True when *c*, viewed as an array, holds at least one element."""
    return np.size(c) != 0
# keep only reads that carry at least one nucleosome centre
valid_df = filtered_reads_df[filtered_reads_df["nuc_centers"].apply(has_valid)]

# smallest |centre| per read, used to split reads by distance from position 0
nearest = np.array([np.min(np.abs(c)) for c in valid_df["nuc_centers"]])
subsets = [
    ("All reads",      valid_df),
    (f"Nearest < {NEAREST_CENTER_THRESH}", valid_df[nearest<NEAREST_CENTER_THRESH]),
    (f"Nearest ≥ {NEAREST_CENTER_THRESH}", valid_df[nearest>=NEAREST_CENTER_THRESH]),
]
# from here on `conds` only lists conditions that still have valid reads
conds = sorted(valid_df["condition"].unique())

# ── 4) PLOT TWO BOXPLOTS PER SUBSET ────────────────────────────────────────────
for title, df in subsets:
    dbg(f"\nSubset '{title}': {len(df)} reads")

    # a) Inter-nuc distances vs consensus
    # Each gap between neighbouring centres of a read is assigned to the
    # consensus position nearest to the gap's midpoint.
    data_mid = {c:[] for c in conds}
    data_gap = {c:[] for c in conds}
    for c in conds:
        grp = df[df["condition"]==c]
        for centers in grp["nuc_centers"]:
            arr = np.sort(np.asarray(centers))
            if arr.size<2: continue
            diffs = np.diff(arr)
            mids  = (arr[:-1]+arr[1:])/2
            idxs  = np.argmin(np.abs(mids[:,None]-consensus[None,:]), axis=1)
            for i,d in zip(idxs,diffs):
                data_mid[c].append(consensus[i])
                data_gap[c].append(d)
        dbg(f"  {c}: {len(data_gap[c])} gaps")

    fig1 = go.Figure()
    for i,c in enumerate(conds):
        fig1.add_trace(go.Box(
            x=data_mid[c], y=data_gap[c], name=c,
            marker_color=DEFAULT_READ_CLRS[i%len(DEFAULT_READ_CLRS)],
            boxpoints="outliers"
        ))
    fig1.update_layout(
        template="plotly_white",
        title=f"{title}: Inter-nuc distances (consensus from {CONSENSUS_CONDS})",
        xaxis_title="Consensus center (bp)",
        yaxis_title="Gap size (bp)"
    )
    fig1.show()

    # b) Position-precision (|offset| vs consensus)
    # Every centre is paired with its nearest consensus position and the
    # absolute distance to it.
    data_off = {c: [] for c in conds}
    for c in conds:
        grp = df[df["condition"] == c]
        for centers in grp["nuc_centers"]:
            arr = np.asarray(centers)
            if arr.size == 0:
                continue
            idxs = np.argmin(np.abs(arr[:, None] - consensus[None, :]), axis=1)
            for i, val in zip(idxs, arr):
                offset = abs(val - consensus[i])
                data_off[c].append((consensus[i], offset))
        dbg(f"  {c}: {len(data_off[c])} absolute offsets")

    fig2 = go.Figure()
    for i, c in enumerate(conds):
        # unpack x and y from the stored tuples
        xvals = [center for center, _ in data_off[c]]
        yvals = [offset for _, offset in data_off[c]]
        fig2.add_trace(go.Box(
            x=xvals,
            y=yvals,
            name=c,
            marker_color=DEFAULT_READ_CLRS[i % len(DEFAULT_READ_CLRS)],
            boxpoints="outliers"
        ))
    fig2.update_layout(
        template="plotly_white",
        title=f"{title}: Precision vs. consensus (|spread|)",
        xaxis_title="Consensus center (bp)",
        yaxis_title="|Offset from consensus| (bp)"
    )
    fig2.show()


In [None]:
###############################################################################
#  Pairwise Δ‑distance box‑plots + separate Median‑IQR line plots
#  • Conditions are colored:
#      analysis_cond[0] → blue (#1F78B4)
#      analysis_cond[1] → red  (#E31A1C)
#      analysis_cond[2] → green (#33A02C)
#  • Box‑plots remain unchanged visually (no fill, transparent bg)
#  • Median‑IQR figure is a separate plot below each box‑plot
###############################################################################
import pandas as pd
import itertools
import plotly.express as px
import plotly.graph_objects as go
from tqdm.auto import tqdm  # progress bar (optional)

# ─────────────────── CONFIG ─────────────────── #
MAX_FLANK      = 3      # nucleosomes on each side
DEBUG_PROGRESS = True   # show tqdm bar
FIG_TEMPLATE   = "plotly_white"
Y_MIN          = 100    # fixed lower bound for y-axis

# ── Color mapping for conditions ────────────────────────────────────────────
# Pairs the leading entries of `analysis_cond` with a fixed three-colour
# palette.  `zip` stops at the shorter input, so a run with fewer than three
# conditions yields a smaller map instead of raising IndexError; conditions
# beyond the third are simply not present in the map.
_CONDITION_PALETTE = (
    "#1F78B4",  # blue
    "#E31A1C",  # red
    "#33A02C",  # green
)
COLOR_HEX = dict(zip(analysis_cond, _CONDITION_PALETTE))
# ────────────────────────────────────────────────────────────────────────────── #

# ---------- 1) helpers ------------------------------------------------------- #
def _adjacent_distances(centers, max_flank):
    """
    Distances between neighbouring nucleosome centres around position 0.

    Centres are split into strictly negative and strictly positive positions
    (a centre at exactly 0 belongs to neither side).  Returns None unless one
    side holds at least `max_flank` centres and the other at least one.
    Otherwise returns a dict with the two centres flanking 0 ("n1", "p1") plus
    one absolute distance per adjacent pair, keyed like "n-2 to n-1",
    "n-1 to n+1", "n+1 to n+2".
    """
    ordered = sorted(centers)
    upstream = [pos for pos in ordered if pos < 0]
    downstream = [pos for pos in ordered if pos > 0]

    upstream_ok = len(upstream) >= max_flank and len(downstream) >= 1
    downstream_ok = len(downstream) >= max_flank and len(upstream) >= 1
    if not upstream_ok and not downstream_ok:
        return None

    nearest_neg = upstream[-1]
    nearest_pos = downstream[0]
    out = {"n1": nearest_neg, "p1": nearest_pos}

    # upstream pairs, farthest from 0 first
    for k in reversed(range(2, max_flank + 1)):
        if len(upstream) < k:
            continue
        out[f"n-{k} to n-{k-1}"] = abs(upstream[-k] - upstream[-k + 1])

    # the gap that straddles position 0
    out["n-1 to n+1"] = abs(nearest_neg - nearest_pos)

    # downstream pairs, nearest to 0 first
    for i, (inner, outer) in enumerate(zip(downstream, downstream[1:]), start=1):
        if i >= max_flank:
            break
        out[f"n+{i} to n+{i+1}"] = abs(outer - inner)
    return out

def build_per_read_df(reads_df, max_flank):
    """
    Build one row per qualifying read with its adjacent-nucleosome distances.

    Each row of `reads_df` is passed through `_adjacent_distances`; reads for
    which that helper returns None are dropped.  The result carries `read_id`,
    `condition` and every key produced by the helper.  A tqdm progress bar is
    shown unless DEBUG_PROGRESS is falsy.
    """
    progress = tqdm(
        reads_df.itertuples(index=False),
        total=len(reads_df),
        disable=not DEBUG_PROGRESS,
        desc="scanning reads"
    )
    kept = []
    for read in progress:
        gaps = _adjacent_distances(read.nuc_centers, max_flank)
        if gaps is None:
            continue
        entry = {"read_id": read.read_id, "condition": read.condition}
        entry.update(gaps)
        kept.append(entry)
    return pd.DataFrame(kept)

def make_category_order(max_flank):
    """
    Return the adjacent-pair category labels in left→right plotting order:
    upstream pairs (farthest first), the pair straddling 0, then downstream pairs.
    """
    upstream = []
    for k in reversed(range(2, max_flank + 1)):
        upstream.append(f"n-{k} to n-{k-1}")

    downstream = []
    for i in range(1, max_flank):
        downstream.append(f"n+{i} to n+{i+1}")

    return [*upstream, "n-1 to n+1", *downstream]

def median_iqr_stats(df, cond, categories):
    """
    Summarise each category column of `df` by its quartiles.

    Categories that are missing from `df` or contain only NaN are skipped.
    Returns a DataFrame with columns condition, category, q1, median, q3
    (one row per summarised category); `cond` is copied into every row.
    """
    summary = []
    for name in categories:
        if name in df.columns:
            observed = df[name].dropna()
            if not observed.empty:
                lower, middle, upper = (
                    observed.quantile(q) for q in (0.25, 0.5, 0.75)
                )
                summary.append({
                    "condition": cond,
                    "category":  name,
                    "q1":        lower,
                    "median":    middle,
                    "q3":        upper
                })
    return pd.DataFrame(summary)

def hex_to_rgb(hex_str):
    """Turn a '#RRGGBB' colour string into an (r, g, b) tuple of ints."""
    digits = hex_str.lstrip("#")
    red = int(digits[0:2], 16)
    green = int(digits[2:4], 16)
    blue = int(digits[4:6], 16)
    return (red, green, blue)

def add_iqr_band(fig, x, q1, q3, rgb):
    """
    Add a translucent Q1–Q3 band to `fig`.

    The band is one closed polygon: Q3 traced left→right, then Q1 traced
    right→left.  `x`, `q1` and `q3` are concatenated with `+`, so they are
    expected to be lists; `rgb` is an (r, g, b) tuple.
    """
    outline_x = x + x[::-1]
    outline_y = q3 + q1[::-1]
    band = go.Scatter(
        x=outline_x,
        y=outline_y,
        fill="toself",
        fillcolor=f"rgba{rgb + (0.15,)}",
        line={"color": "rgba(0,0,0,0)"},   # fully transparent outline
        showlegend=False,
        hoverinfo="skip"
    )
    fig.add_trace(band)

def add_median_line(fig, x, med, rgb, name):
    """Add a lines+markers trace of the medians `med` over `x` to `fig`, coloured by the `rgb` tuple."""
    median_trace = go.Scatter(
        x=x,
        y=med,
        mode="lines+markers",
        line={"color": f"rgb{rgb}"},
        name=name
    )
    fig.add_trace(median_trace)

# ---------- 2) build per‑read DataFrame ------------------------------------- #
# Per-read table of adjacent nucleosome distances; reads for which
# _adjacent_distances returns None are left out by build_per_read_df.
df_reads = build_per_read_df(filtered_reads_df, MAX_FLANK)
if df_reads.empty:
    raise ValueError(f"No reads span ±{MAX_FLANK} nucleosomes!")

# Fixed left→right category order shared by both figure types.
cat_order = make_category_order(MAX_FLANK)

# compute global Q3 for y‑axis upper bound
# (largest per-category 75th percentile; within this cell it is referenced
# only by commented-out update_yaxes calls)
all_vals = df_reads[[c for c in cat_order if c in df_reads.columns]]
global_q3 = all_vals.quantile(0.75).max()

# ---------- 3) pairwise plotting ------------------------------------------------ #
# Every unordered pair of conditions present in df_reads gets its own figures.
all_conditions = sorted(df_reads["condition"].unique())
pairs = list(itertools.combinations(all_conditions, 2))

# COLOR_HEX only covers the conditions it was built from; any other condition
# present in df_reads falls back to neutral grey instead of raising KeyError
# when the median/IQR figure looks up its colour.
FALLBACK_HEX = "#7F7F7F"

for c1, c2 in pairs:
    sub = df_reads[df_reads["condition"].isin([c1, c2])]
    # one colour per condition of this pair, shared by both figures
    pair_colors = {cond: COLOR_HEX.get(cond, FALLBACK_HEX) for cond in (c1, c2)}

    # -- 3a) Box‑plot figure ---------------------------------------------- #
    # long format: one row per (read, category) so px.box can group by category
    long_df = sub.melt(
        id_vars=["read_id", "condition"],
        value_vars=[c for c in cat_order if c in sub.columns],
        var_name="category",
        value_name="distance"
    )
    fig_box = px.box(
        long_df,
        x="category",
        y="distance",
        color="condition",
        category_orders={"category": cat_order},
        color_discrete_map=pair_colors,  # enforce specific colors
        template=FIG_TEMPLATE,
        points=False,
        title=f"Box‑plots: {c1} vs {c2}"
    )
    # remove box fill and set transparent background
    fig_box.update_traces(fillcolor="rgba(0,0,0,0)")
    fig_box.update_layout(
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        xaxis_title="Adjacent‑pair category",
        yaxis_title="Distance (bp)",
        width=900
    )
    #fig_box.update_yaxes(range=[Y_MIN, global_q3])
    fig_box.show()

    # -- 3b) Median‑line + IQR band figure (separate) ------------------ #
    fig_med = go.Figure()
    for cond in [c1, c2]:
        dfc = sub[sub["condition"] == cond]
        stats = median_iqr_stats(dfc, cond, cat_order)
        if stats.empty:
            continue
        # plain lists: add_iqr_band concatenates them with `+`
        x   = stats["category"].tolist()
        med = stats["median"].tolist()
        q1  = stats["q1"].tolist()
        q3  = stats["q3"].tolist()
        rgb = hex_to_rgb(pair_colors[cond])

        add_iqr_band(fig_med, x, q1, q3, rgb)
        add_median_line(fig_med, x, med, rgb, f"{cond} median")

    fig_med.update_layout(
        title=f"Median ± IQR: {c1} vs {c2}",
        xaxis_title="Adjacent‑pair category",
        yaxis_title="Distance (bp)",
        template=FIG_TEMPLATE,
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        width=900
    )
    #fig_med.update_yaxes(range=[Y_MIN, global_q3])
    fig_med.update_xaxes(categoryorder="array", categoryarray=cat_order)
    fig_med.show()


In [None]:
###############################################################################
#  INTER‑NUCLEOSOME‑DISTANCE HISTOGRAMS  +  COMBINED KDE
#    • One histogram (% of all inter‑nucleosome gaps) per condition
#    • Combined KDE overlay (all conditions) on a separate figure
#
#  Configurable:
#      DIST_LOWER_BOUND   – exclude gaps < this (bp)          (currently 100)
#      DIST_UPPER_BOUND   – exclude gaps > this (bp)          (currently 400)
#      DIST_BIN_WIDTH     – histogram bin width (bp)          (currently 1)
#      KDE_BW             – gaussian_kde bandwidth            (currently 0.01)
#      REL_POS_WINDOWS    – list[(start,end)] to *keep*       (default whole range)
###############################################################################

# ─────────────────── 8A)  CONFIG ─────────────────── #
# Gap bounds are applied inclusively by _adjacent_diffs and also serve as the
# default histogram / KDE range of plot_inter_nuc_dist_hist_and_kde.
DIST_LOWER_BOUND   = 100          # bp
DIST_UPPER_BOUND   = 400          # bp
DIST_BIN_WIDTH     = 1            # bp
KDE_BW             = 0.01        # gaussian_kde bw_method

# Windows in rel_pos coordinates that *keep* nucleosome centres.
# Each entry is an inclusive (start, end) pair; gaps are only measured between
# centres that share a window.
# Example: [(-1000, -200), (200, 1000)]
# Leave empty → keep all centres in [-REL_POS_RANGE, +REL_POS_RANGE].
REL_POS_WINDOWS    = [(-1000, -200), (200, 1000)]

# ─────────────────── 8B)  HELPERS ─────────────────── #
def _centres_in_windows(centres, windows=None):
    """
    Return the centres that fall inside any keep-window, sorted ascending.

    Parameters
    ----------
    centres : iterable of numbers
        Nucleosome centre positions of one read (any order).
    windows : list[(lo, hi)] or None
        Inclusive keep-windows.  None (the default) uses the module-level
        REL_POS_WINDOWS; an empty list keeps every centre.

    Returns
    -------
    list
        Kept centres in ascending order.  The result is sorted on *every*
        path (including the keep-all path) because `_adjacent_diffs` takes
        differences between consecutive elements and therefore needs sorted
        input.  A centre lying in several overlapping windows is kept once.
    """
    if windows is None:
        windows = REL_POS_WINDOWS
    if not windows:                    # keep all
        return sorted(centres)
    return sorted(
        c for c in centres
        if any(lo <= c <= hi for lo, hi in windows)
    )

def _adjacent_diffs(sorted_centres):
    """
    Gaps between consecutive entries of `sorted_centres`, restricted to the
    inclusive range [DIST_LOWER_BOUND, DIST_UPPER_BOUND].  Fewer than two
    centres yield an empty list.
    """
    if len(sorted_centres) < 2:
        return []
    gaps = np.diff(sorted_centres).tolist()
    return [
        gap for gap in gaps
        if DIST_LOWER_BOUND <= gap <= DIST_UPPER_BOUND
    ]

def _collect_inter_nuc_dists():
    """
    Gather qualifying inter-nucleosome gaps for every condition.

    Returns dict[condition] → list of gap sizes taken from
    `filtered_reads_df`.  Gaps are measured within a single read and, when
    REL_POS_WINDOWS is non-empty, within a single window, so a gap never
    bridges two windows.
    """
    dbg("collecting inter‑nucleosome distances …")
    by_cond = {}
    for name in filtered_reads_df["condition"].unique():
        by_cond[name] = []

    for read in filtered_reads_df.itertuples():
        bucket = by_cond[read.condition]
        kept = _centres_in_windows(read.nuc_centers)
        if not REL_POS_WINDOWS:
            bucket.extend(_adjacent_diffs(kept))
            continue
        # one pass per window so gaps across windows are ignored
        for lo, hi in REL_POS_WINDOWS:
            inside = [c for c in kept if lo <= c <= hi]
            bucket.extend(_adjacent_diffs(inside))

    for cond, lst in by_cond.items():
        dbg(f"[{cond}] kept {len(lst)} gaps", always=False)
    return by_cond

# ─────────────────── 8C)  PLOTTER ─────────────────── #
def plot_inter_nuc_dist_hist_and_kde(
    lower=DIST_LOWER_BOUND, upper=DIST_UPPER_BOUND,
    bin_width=DIST_BIN_WIDTH, kde_bw=KDE_BW):
    """
    One bar‑histogram per condition (rows), y = % of all gaps in that condition.
    Combined KDE overlay (all conds) on a second figure.

    Parameters
    ----------
    lower, upper : histogram / KDE x-range in bp.  The gaps themselves come
        from `_collect_inter_nuc_dists`, which filters with the module-level
        distance bounds, not with these arguments.
    bin_width    : histogram bin width in bp.
    kde_bw       : `bw_method` passed to `gaussian_kde`.

    Both figures are shown via `.show()`; nothing is returned.
    """
    dbg("plotting inter‑nucleosome distance histograms …")
    dists_by_cond = _collect_inter_nuc_dists()
    conds = sorted(dists_by_cond)

    # ---------- build histogram data ---------- #
    bins = np.arange(lower, upper + bin_width, bin_width)
    bin_centres = (bins[:-1] + bins[1:]) / 2
    n_bins = len(bin_centres)

    pct_by_cond = {}
    ymax = 0.0   # shared y-range so the rows are directly comparable
    for cond in conds:
        dist_arr = np.asarray(dists_by_cond[cond])
        hist, _ = np.histogram(dist_arr, bins=bins)
        total = hist.sum()
        # fraction of this condition's binned gaps; all-zero when nothing was binned
        pct = hist / total if total else np.zeros_like(hist, dtype=float)
        pct_by_cond[cond] = pct
        ymax = max(ymax, pct.max())

    # ---------- figure 1 – histograms ---------- #
    rows = len(conds)
    hist_fig = make_subplots(rows=rows, cols=1, shared_xaxes=True,
                             vertical_spacing=0.03,
                             specs=[[{}] for _ in conds])

    for r, cond in enumerate(conds, start=1):
        colour = DEFAULT_READ_CLRS[(r-1) % len(DEFAULT_READ_CLRS)]
        hist_fig.add_trace(
            go.Bar(
                x=bin_centres,
                y=pct_by_cond[cond],
                width=bin_width,
                # hollow bars: transparent fill, coloured outline
                marker=dict(color="rgba(0,0,0,0)",
                            line=dict(color=colour)),
                name=cond,
                showlegend=False
            ),
            row=r, col=1
        )
        hist_fig.update_yaxes(
            range=[0, ymax * 1.05],
            tickformat=".1%",
            title_text=cond,
            row=r, col=1
        )

    hist_fig.update_xaxes(range=[lower, upper], dtick=25,
                          title_text="Inter‑nucleosome distance (bp)")
    hist_fig.update_layout(
        template=FIG_TEMPLATE,
        width=900,
        height=250 * rows,
        title="% of inter‑nucleosome gaps per condition"
    )
    hist_fig.show()

    # ---------- figure 2 – combined KDE ---------- #
    dbg("plotting KDE overlay …")
    x_grid = np.linspace(lower, upper, KDE_POINTS)
    kde_fig = go.Figure()
    for idx, cond in enumerate(conds):
        dists = np.asarray(dists_by_cond[cond], dtype=float)
        # gaussian_kde cannot be fitted to fewer than two points or to data
        # with zero variance; skip such conditions rather than abort the plot.
        if dists.size < 2 or dists.max() == dists.min():
            dbg(f"[{cond}] too few distinct gaps for a KDE; skipped", always=False)
            continue
        colour = DEFAULT_READ_CLRS[idx % len(DEFAULT_READ_CLRS)]
        kde = gaussian_kde(dists, bw_method=kde_bw)
        kde_fig.add_trace(
            go.Scatter(
                x=x_grid,
                y=kde.evaluate(x_grid),
                mode="lines",
                line=dict(width=2, color=colour),
                name=cond
            )
        )

    kde_fig.update_layout(
        template=FIG_TEMPLATE,
        width=900,
        height=350,
        title="Inter‑nucleosome distance density (KDE)",
        xaxis_title="Inter‑nucleosome distance (bp)",
        yaxis_title="Density",
        legend=dict(orientation="h", yanchor="bottom", y=1.02,
                    xanchor="right", x=1)
    )
    kde_fig.update_xaxes(range=[lower, upper], dtick=25)
    kde_fig.show()

    dbg("…done!", always=True)

# ─────────────────── 8D)  CALL PLOTTER ─────────────────── #
# Generates both figures when the cell runs; comment this line out to skip plotting:
plot_inter_nuc_dist_hist_and_kde()


In [None]:
###############################################################################
#  CONSENSUS NUCLEOSOME PEAKS  ➜  Δ‑position box‑plot (relative to 0th condition)
#  + DEBUG KDE PLOTS
#
#  • For *one* type (the first in sorted order), produce a debug KDE plot
#    for each condition with vertical lines at the chosen summits.
#
#  • Then run the usual “find summits per type–condition” logic, build
#    peaks_long_df, compute Δ_pos relative to the “0th” (baseline) condition
#    for each (type, nuc_index), and draw the box‑plot.
###############################################################################
import numpy as np
import pandas as pd
import itertools
from scipy.signal import find_peaks
from scipy.stats  import gaussian_kde
import plotly.graph_objects as go
import plotly.express as px
from IPython.display import display

# ------------ parameters -----------------------------------------------------
# x_grid is the common evaluation grid for every KDE in this cell.
# NOTE(review): peak_sep and peak_height are not read by _select_peak_centers
# below, which uses CORE_SIZE and a height of 0.0 directly — confirm whether
# they are meant to be wired in.
kde_bw      = 0.005                      # same bandwidth used before
x_grid      = np.linspace(-REL_POS_RANGE, REL_POS_RANGE, KDE_POINTS)
peak_sep    = CORE_SIZE                  # 147 bp minimum distance between summits
peak_height = 0                          # accept all peaks, filter by height later

# ------------ helper: peak selection ----------------------------------------
def _select_peak_centers(x: np.ndarray, y: np.ndarray,
                         min_sep=None, min_height: float = 0.0):
    """
    Identify all peaks in the KDE curve `y(x)` using `find_peaks`, then for each peak:
      1. Let h = y[p] be the peak height at index p.
      2. Define half_h = h/2.
      3. Scan left from p until y dips below half_h (or we reach array start) → index l.
      4. Scan right from p until y dips below half_h (or we reach array end) → index r.
      5. The “center” of that peak is (x[l] + x[r]) / 2.
    Finally, keep only those peak‐centers that remain ≥ `min_sep` apart,
    selecting the tallest peaks first.

    Parameters
    ----------
    x, y : arrays of equal length; `y` is the curve sampled at positions `x`.
    min_sep : minimum allowed distance between two accepted centers, in the
        units of `x`.  None (the default) uses the module-level CORE_SIZE,
        which was previously hard-coded here.
    min_height : minimum peak height handed to `find_peaks`
        (default 0.0, the previously hard-coded value).

    Returns
    -------
    List of floating‐point center positions (in the same coordinate system as x),
    sorted left→right.  Empty when no peak is found.
    """
    if min_sep is None:
        min_sep = CORE_SIZE

    # 1) find all local maxima at or above min_height
    peaks, props = find_peaks(y, height=min_height)
    if peaks.size == 0:
        return []

    # 2) (peak_index, peak_height) pairs, tallest first; sorted() is stable,
    #    so equally tall peaks keep their left→right order
    peak_info = sorted(
        zip(peaks, props["peak_heights"]),
        key=lambda item: -item[1]
    )

    n_points = len(y)
    selected_centers = []
    for p, h in peak_info:
        half_h = h * 0.5

        # 3) scan left from p until y < half_h (or we hit index 0)
        l = p
        while l > 0 and y[l] >= half_h:
            l -= 1
        # if we stepped one below half_h, step back up by one
        if y[l] < half_h and l < p:
            l += 1

        # 4) scan right from p until y < half_h (or we hit last index)
        r = p
        while r + 1 < n_points and y[r] >= half_h:
            r += 1
        if y[r] < half_h and r > p:
            r -= 1

        # 5) midpoint of the half-maximum span
        center_x = 0.5 * (x[l] + x[r])

        # 6) reject if within min_sep of any already‐accepted (taller) center
        if any(abs(center_x - c) < min_sep for c in selected_centers):
            continue

        selected_centers.append(center_x)

    # sort so that final list is left→right
    selected_centers.sort()
    return selected_centers

# ------------ pick the first type and plot debug KDE per condition ----------
# Only the first type (in sorted order) gets debug plots: one figure per
# condition showing the KDE curve and the summits chosen by _select_peak_centers.
all_types = sorted(filtered_reads_df["type"].unique())
if not all_types:
    raise ValueError("No types found in filtered_reads_df")
first_type = all_types[0]

for cond in sorted(filtered_reads_df["condition"].unique()):
    # collect all centres for this type–condition pair
    # (the per-read nuc_centers sequences are flattened into one list)
    centres = list(itertools.chain.from_iterable(
        filtered_reads_df.loc[
            (filtered_reads_df["type"] == first_type) &
            (filtered_reads_df["condition"] == cond),
            "nuc_centers"
        ].tolist()
    ))
    if not centres:
        continue

    # density of pooled centres evaluated on the shared x_grid
    kde = gaussian_kde(centres, bw_method=kde_bw)
    density = kde.evaluate(x_grid)
    summits = _select_peak_centers(x_grid, density)

    # build debug plot
    fig_dbg = go.Figure()
    fig_dbg.add_trace(
        go.Scatter(
            x=x_grid,
            y=density,
            mode="lines",
            line=dict(width=2, color="blue"),
            name=f"KDE ({first_type}, {cond})"
        )
    )
    # one dashed vertical line per selected summit, labelled with its rounded position
    for s in summits:
        fig_dbg.add_vline(
            x=s,
            line=dict(color="red", width=1, dash="dash"),
            annotation_text=f"{int(np.round(s))}",
            annotation_position="top right"
        )
    fig_dbg.update_layout(
        width=900,
        height=300,
        title=f"Debug KDE – type={first_type}, condition={cond}",
        xaxis_title="Relative Position (bp)",
        yaxis_title="Density"
    )
    fig_dbg.show()

# ------------ collect consensus peaks across all types × conditions ----------
# One KDE per (type, condition) group; each selected summit becomes one row
# (condition, type, nuc_index, nuc_pos) of peaks_long_df.
rows = []
for (typ, cond), grp in filtered_reads_df.groupby(["type", "condition"]):
    centres = list(itertools.chain.from_iterable(grp["nuc_centers"].tolist()))
    if not centres:
        continue

    kde = gaussian_kde(centres, bw_method=kde_bw)
    density = kde.evaluate(x_grid)
    peaks   = _select_peak_centers(x_grid, density)
    if not peaks:
        continue

    # assign nuc_index: n-1, n-2, … for negative; n+1, n+2, … for positive
    # (a summit at exactly 0 is in neither list and gets no row)
    neg_peaks = [p for p in peaks if p < 0]
    pos_peaks = [p for p in peaks if p > 0]

    # sort negative peaks ascending (most negative first) ...
    neg_peaks.sort()
    # ... then reverse, so that n-1 is the negative peak *closest* to zero
    neg_peaks = neg_peaks[::-1]

    for i, p in enumerate(neg_peaks, start=1):
        rows.append((cond, typ, f"n-{i}", p))
    for i, p in enumerate(sorted(pos_peaks), start=1):
        rows.append((cond, typ, f"n+{i}", p))

peaks_long_df = pd.DataFrame(rows, columns=["condition", "type", "nuc_index", "nuc_pos"])
if peaks_long_df.empty:
    raise ValueError("No consensus peaks found in any type–condition group")

# ------------ compute Δ_pos relative to 0th (baseline) condition -------------
# Baseline ("0th") condition: one condition, shared by all types, picked by its
# position in the sorted list of conditions.
# NOTE(review): index 1 selects the *second* condition in sorted order, not the
# first — confirm this is the intended baseline.  It also raises IndexError
# when filtered_reads_df holds a single condition.
conds_sorted = sorted(filtered_reads_df["condition"].unique())
baseline_condition = conds_sorted[1]  # treat this as 0th

# Build a lookup: (type, nuc_index) -> nuc_pos of baseline_condition
baseline_positions = {}
for (typ, cond), sub in peaks_long_df.groupby(["type", "condition"]):
    if cond != baseline_condition:
        continue
    for idx, row in sub.iterrows():
        key = (typ, row["nuc_index"])
        # collect every baseline row for this key; reduced to the mean below
        baseline_positions.setdefault(key, []).append(row["nuc_pos"])

# reduce lists to single baseline value (mean) per key
for key, vals in baseline_positions.items():
    baseline_positions[key] = float(np.mean(vals))

# Now subtract baseline from every row in peaks_long_df
def compute_delta(row):
    """
    Offset of this row's `nuc_pos` from the baseline position stored under
    (type, nuc_index) in `baseline_positions`; NaN when no baseline exists.
    """
    lookup = (row["type"], row["nuc_index"])
    return row["nuc_pos"] - baseline_positions.get(lookup, np.nan)

# Row-wise offset from the baseline position of the same (type, nuc_index).
peaks_long_df["Δ_pos"] = peaks_long_df.apply(compute_delta, axis=1)

# Rows whose baseline was missing have Δ_pos = NaN and are removed here, so
# peaks_long_df afterwards only holds peaks that have a baseline counterpart.
peaks_long_df = peaks_long_df.dropna(subset=["Δ_pos"]).copy()

# ------------ box‑plot of Δ_pos by nuc_index, colored by condition ------------
# Box-plot of Δ_pos per nucleosome index, one box group per condition;
# every individual point is drawn (points="all").
fig = px.box(
    peaks_long_df,
    x="nuc_index",
    y="Δ_pos",
    color="condition",
    points="all",
    title="Deviation from baseline (0th condition) nucleosome position",
    labels={"Δ_pos": "nuc_pos – baseline (bp)"}
)
fig.update_layout(
    template=FIG_TEMPLATE,
    width=900,
    height=600
)

# ensure x‑axis categories are ordered from smallest index to largest index:
# stripping the leading "n" leaves a signed integer string ("-2", "+1") that
# int() parses directly
unique_indices = sorted(
    peaks_long_df["nuc_index"].unique(),
    key=lambda s: int(s.replace("n", ""))
)
fig.update_xaxes(categoryorder="array", categoryarray=unique_indices)
fig.show()

# Optional: display first few rows for verification
display(peaks_long_df.head())

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px

# ------------ existing boxplot code up to fig creation ------------
# Rebuilds the same Δ_pos box-plot as above, then overlays a median line and a
# shaded Q1–Q3 band per condition before showing the figure.
fig = px.box(
    peaks_long_df,
    x="nuc_index",
    y="Δ_pos",
    color="condition",
    points="all",
    title="Deviation from baseline (0th condition) nucleosome position",
    labels={"Δ_pos": "nuc_pos – baseline (bp)"}
)
fig.update_layout(
    template=FIG_TEMPLATE,
    width=900,
    height=600
)

# Order labels by signed index: the digits give the magnitude, the "+"/"-"
# in the label gives the sign.
unique_indices = sorted(
    peaks_long_df["nuc_index"].unique(),
    key=lambda s: int(s.replace("n", "").replace("+", "").replace("-", "")) * (1 if "+" in s else -1)
)
fig.update_xaxes(categoryorder="array", categoryarray=unique_indices)

# Extract color mapping: condition -> color
color_map = {}
for trace in fig.data:
    # Box traces have name equal to condition; one trace per condition
    cond_name = trace.name
    if cond_name not in color_map:
        # Reuse the colour plotly assigned to this condition's box
        # (read from trace.marker.color).
        # NOTE(review): the fillcolor built further down slices this value as
        # '#RRGGBB' — confirm the palette in use always yields hex strings.
        color_map[cond_name] = trace.marker.color

# ------------ compute median, Q1, Q3 per condition and nuc_index ------------
summary = (
    peaks_long_df
    .groupby(["condition", "nuc_index"])["Δ_pos"]
    .agg(Q1=lambda x: np.percentile(x, 25),
         median="median",
         Q3=lambda x: np.percentile(x, 75))
    .reset_index()
)

# ------------ add median lines and shaded IQR for each condition ------------
for cond in summary["condition"].unique():
    cond_df = summary[summary["condition"] == cond].copy()
    # Ensure ordering by nuc_index according to unique_indices
    cond_df["order"] = cond_df["nuc_index"].apply(lambda x: unique_indices.index(x))
    cond_df = cond_df.sort_values("order")

    idxs = cond_df["nuc_index"].tolist()
    q1_vals = cond_df["Q1"].tolist()
    q3_vals = cond_df["Q3"].tolist()
    med_vals = cond_df["median"].tolist()
    # the "black" fallback is not a hex string and would not survive the
    # hex slicing in the fillcolor below
    col = color_map.get(cond, "black")

    # Trace for Q1 (invisible)
    # It must be added immediately before the Q3 trace: 'tonexty' fills
    # towards the previously added trace.
    fig.add_trace(
        go.Scatter(
            x=idxs,
            y=q1_vals,
            mode="lines",
            line=dict(color='rgba(0,0,0,0)', width=0),
            showlegend=False,
            hoverinfo='skip'
        )
    )
    # Trace for Q3 with fill to Q1
    fig.add_trace(
        go.Scatter(
            x=idxs,
            y=q3_vals,
            mode="lines",
            line=dict(color='rgba(0,0,0,0)', width=0),
            fill='tonexty',
            fillcolor=f'rgba({int(col[1:3],16)},{int(col[3:5],16)},{int(col[5:7],16)},0.2)',
            showlegend=False,
            hoverinfo='skip'
        )
    )
    # Median line
    fig.add_trace(
        go.Scatter(
            x=idxs,
            y=med_vals,
            mode="lines+markers",
            line=dict(color=col, width=2),
            marker=dict(color=col, size=6),
            name=f"{cond} median",
            hoverinfo='skip'
        )
    )

# Show final figure
fig.show()



In [None]:
###############################################################################
#  CONSISTENCY OFFSETS USING *EXISTING* CONSENSUS PEAKS
#  • “Consensus” positions come from `peaks_long_df` computed in the KDE cell
#  • Toggle consensus scope with:
#        CONS_BY_CONDITION = False  → (type, nuc_index)     ← previous behaviour
#        CONS_BY_CONDITION = True   → (type, condition, nuc_index)
#
#  • Toggle distance definition with USE_CONS_DIST (unchanged).
#  • Optional: restrict analysis to SELECT_TYPES.
###############################################################################
from itertools import combinations
import numpy as np, pandas as pd, plotly.graph_objects as go
from tqdm.auto import tqdm

# ───── user‑configurable flags ───── #
SELECT_TYPES       = []       # empty → all types
CONS_BY_CONDITION  = True      # True → one consensus per (type, condition, nuc_index)
MAX_OFFSET         = 60        # bp inclusion window around consensus peak
USE_CONS_DIST      = True     # True → |centre – consensus| ; False → pairwise
FIG_WIDTH          = 900       # figure width handed to update_layout
DEBUG_CONS_OFFSETS = True      # prints the consensus table and enables the tqdm bar

# --------------------------------------------------------------------------- #
# 1) BUILD a dict of consensus centres
#    key = (type, nuc_index)                     if CONS_BY_CONDITION is False
#    key = (type, condition, nuc_index)          if CONS_BY_CONDITION is True
# --------------------------------------------------------------------------- #
cons_dict = {}   # key → consensus_pos

# Consensus positions are taken from peaks_long_df (built in the KDE cell);
# the median collapses any duplicate rows that share a key.
if CONS_BY_CONDITION:
    # separate consensus per condition & type
    for (typ, cond, idx), sub in (
        peaks_long_df.groupby(["type", "condition", "nuc_index"])
    ):
        if SELECT_TYPES and typ not in SELECT_TYPES:
            continue
        cons_dict[(typ, cond, idx)] = float(sub["nuc_pos"].median())
else:
    # single consensus across conditions (original logic)
    for (typ, idx), sub in peaks_long_df.groupby(["type", "nuc_index"]):
        if SELECT_TYPES and typ not in SELECT_TYPES:
            continue
        # prefer baseline condition if available, else median of all
        # (baseline_condition is defined in the KDE cell)
        base_rows = sub[sub["condition"] == baseline_condition]
        if not base_rows.empty:
            cons_dict[(typ, idx)] = float(base_rows["nuc_pos"].median())
        else:
            cons_dict[(typ, idx)] = float(sub["nuc_pos"].median())

# optional sanity print
if DEBUG_CONS_OFFSETS:
    scope = "(type,cond,nuc_index)" if CONS_BY_CONDITION else "(type,nuc_index)"
    print(f"[consensus peaks] scope = {scope}")
    for k, pos in sorted(cons_dict.items()):
        print(f"  {k}: {pos:8.1f} bp")

# --------------------------------------------------------------------------- #
# 2) SUBSET reads
# --------------------------------------------------------------------------- #
# Work on a copy, restricted to SELECT_TYPES when that list is non-empty.
df_subset = (
    filtered_reads_df[filtered_reads_df["type"].isin(SELECT_TYPES)].copy()
    if SELECT_TYPES else
    filtered_reads_df.copy()
)
if df_subset.empty:
    raise ValueError(f"No reads found for type(s): {SELECT_TYPES}")

# --------------------------------------------------------------------------- #
# 3) COLLECT offsets
# --------------------------------------------------------------------------- #
# One record per measured distance: {"cons_label", "offset", "condition"}.
records = []
for (cond, typ), grp in tqdm(
        df_subset.groupby(["condition", "type"]),
        desc="computing offsets", disable=not DEBUG_CONS_OFFSETS):

    # pool the nucleosome centres of every read in this (condition, type) group
    all_centres = np.concatenate(grp["nuc_centers"].values)
    if all_centres.size == 0:
        continue

    # Build list of nuc_indices for this (typ, cond), depending on scope
    if CONS_BY_CONDITION:
        # Keys are (typ, cond, idx)
        nuc_indices = [
            key[2]
            for key in cons_dict.keys()
            if (key[0] == typ and key[1] == cond)
        ]
    else:
        # Keys are (typ, idx)
        nuc_indices = [
            key[1]
            for key in cons_dict.keys()
            if key[0] == typ
        ]

    for idx in nuc_indices:
        # Lookup consensus position with appropriate key
        if CONS_BY_CONDITION:
            key = (typ, cond, idx)
        else:
            key = (typ, idx)

        cpos = cons_dict.get(key, None)
        if cpos is None:
            continue

        # Find all centres within ±MAX_OFFSET of this consensus (bounds inclusive)
        mask  = (all_centres >= cpos - MAX_OFFSET) & (all_centres <= cpos + MAX_OFFSET)
        local = all_centres[mask]
        if local.size == 0:
            continue

        if USE_CONS_DIST:
            # One record per centre: distance to consensus
            for centre in local:
                records.append({
                    "cons_label": idx,
                    "offset":     abs(centre - cpos),
                    "condition":  cond
                })
        else:
            # Pairwise absolute distances between every pair in local
            # (the number of records grows quadratically with local.size)
            if local.size < 2:
                continue
            for i, j in combinations(local, 2):
                records.append({
                    "cons_label": idx,
                    "offset":     abs(i - j),
                    "condition":  cond
                })

offset_df = pd.DataFrame(records)
if offset_df.empty:
    raise ValueError("No offset data gathered – check MAX_OFFSET or filters.")

# --------------------------------------------------------------------------- #
# 4) PLOT (unchanged logic, apart from colour map keyed by cond)
# --------------------------------------------------------------------------- #
# Colours are assigned by position in the sorted condition list, cycling
# through DEFAULT_READ_CLRS when there are more conditions than colours.
conds   = sorted(offset_df["condition"].unique())
clr_map = {c: DEFAULT_READ_CLRS[i % len(DEFAULT_READ_CLRS)] for i, c in enumerate(conds)}

def numeric_key(lbl):
    """Sort key for labels such as "n-3" or "n+2": the signed integer they encode."""
    signed = lbl.replace("n-", "-")      # "n-3" → "-3"
    bare = signed.replace("n+", "")      # "n+2" → "2"
    return int(bare)

# x-axis order: consensus labels sorted by their signed index
ordered_labels = sorted(offset_df["cons_label"].unique(), key=numeric_key)

# One hollow box trace per condition; boxmode="group" (set below) places the
# conditions side by side within each consensus label.
fig = go.Figure()
for cond in conds:
    sub = offset_df[offset_df["condition"] == cond]
    fig.add_trace(
        go.Box(
            x=sub["cons_label"],
            y=sub["offset"],
            name=cond,
            line=dict(color=clr_map[cond]),
            fillcolor="rgba(0,0,0,0)",
            boxpoints=False
        )
    )

# axis label and title reflect the active distance definition / consensus scope
ylabel = "|centre − consensus|" if USE_CONS_DIST else "|centre₁ − centre₂|"
title  = f"Distances (±{MAX_OFFSET} bp) — consensus scope: " + \
         ("per condition" if CONS_BY_CONDITION else "global")

fig.update_layout(
    template=FIG_TEMPLATE,
    width=FIG_WIDTH,
    height=450,
    title=title,
    xaxis_title="Consensus peak (n‑index)",
    yaxis_title=f"{ylabel} (bp)",
    boxmode="group"
)
fig.update_xaxes(categoryorder="array", categoryarray=ordered_labels)
fig.show()


In [None]:
###############################################################################
# CELL 1 — per-read autocorrelation for filtered reads (with optional filling)
###############################################################################
import multiprocessing as mp
from functools import partial
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from scipy.ndimage import gaussian_filter1d

# ─────────────── Filtering & Configuration ───────────────
# Type selection: with TYPES_TO_INCLUDE empty, types are chosen by chip rank
# (scaled ×100 in chip_rank_lookup below) — ranks >= CHIP_RANK_CUTOFF when
# ABOVE_FLAG is True, ranks below it when ABOVE_FLAG is False.
CHIP_RANK_CUTOFF      = 80
ABOVE_FLAG            = True
TYPES_TO_INCLUDE      = [] #"MOTIFS_rex32"
CHR_TYPE_INCLUDE      = ["X"]
# NOTE: indexes the first four entries, so `analysis_cond` must hold at least four.
CONDITIONS_TO_INCLUDE = [
    analysis_cond[0],
    analysis_cond[1],
    analysis_cond[2],
    analysis_cond[3]
]

MIN_READ_LENGTH = 700          # same meaning as in v5
REQUIRE_CENTRAL = True         # toggle central‑overlap filter on/off

# ─────────────── Filling algorithm toggle & width ───────────────
PERFORM_FILLING       = True
MET_DOMAIN_WIDTH      = 10

# ────────────────── Core smoothing/interp config ──────────────────
REL_POS_RANGE     = 1000
LAPSE_WINDOW      = 0
PROCESSING_OPTION = 2          # 1 = whole-read; 2 = split & lapse
PERFORM_INTERP    = True
INTERP_WINDOW     = 5
# RAW_* feed the `raw_smooth` partial and FINAL_* the `final_smooth` partial
# defined further down in this cell.
RAW_SMOOTH_KIND   = "moving"
RAW_MOVING_WIN    = 10
RAW_GAUSS_SIGMA   = 5
RAW_CLAMP_01      = False
FINAL_SMOOTH_KIND = "moving"
FINAL_MOVING_WIN  = 25
FINAL_GAUSS_SIGMA = 2
FINAL_CLAMP_01    = False

N_WORKERS         = max(2, mp.cpu_count() - 2)   # all but two cores, never fewer than 2
DEBUG_PROGRESS    = True

# ───────────────────── Filtering ─────────────────────
# Chip-rank table; its `type` values are prefixed with "MOTIFS_" so they can be
# compared against the `type` column used in the queries below.
chiprank_df = (
    pd.read_csv("/Data1/reference/rex_chiprank.bed", sep=r"\s+")
      .assign(type=lambda d: "MOTIFS_" + d["type"].astype(str))
)
# type → chip rank scaled ×100 and rounded to 3 decimals
chip_rank_lookup = {
    t: round(float(rk) * 100, 3)
    for t, rk in zip(chiprank_df["type"], chiprank_df["chip_rank"])
}

keep_conds = set(CONDITIONS_TO_INCLUDE)
# explicit type list wins; otherwise select by chip rank relative to the cutoff
keep_types = (
    set(TYPES_TO_INCLUDE)
    if TYPES_TO_INCLUDE else
    {t for t, r in chip_rank_lookup.items() if (r >= CHIP_RANK_CUTOFF) == ABOVE_FLAG}
)
# explicit chromosome-type list wins; otherwise keep every chr_type in merged_df
keep_chr = (
    set(CHR_TYPE_INCLUDE)
    if CHR_TYPE_INCLUDE else
    set(merged_df["chr_type"].unique())
)

# --- base metadata filter ----------------------------------------------------
df0 = (
    merged_df
      .query("condition in @keep_conds and type in @keep_types and chr_type in @keep_chr")
      .reset_index(drop=True)
)

# --- central‑overlap / length filter (mirrors v5) ----------------------------
if REQUIRE_CENTRAL:
    # keep reads whose rel_pos values reach at least ±half around position 0
    # NOTE(review): min()/max() raise on an empty rel_pos — confirm that
    # merged_df never contains empty arrays.
    half = MIN_READ_LENGTH // 2
    mask_central = df0["rel_pos"].apply(
        lambda arr: (min(arr) <= -half) and (max(arr) >= half)
    )
    filtered_reads_df = df0[mask_central].reset_index(drop=True)
else:
    # fall back to simple read‑length threshold
    # (compares the number of rel_pos entries per read with MIN_READ_LENGTH)
    filtered_reads_df = df0[
        df0["rel_pos"].str.len() >= MIN_READ_LENGTH
    ].reset_index(drop=True)

# --- bookkeeping / sanity prints --------------------------------------------
types_kept = sorted(filtered_reads_df['type'].unique())
print(f"[INFO] types included after filtering ({len(types_kept)}): {types_kept}")

if DEBUG_PROGRESS:
    # requested types that have no read left after filtering
    missing = keep_types - set(types_kept)
    if missing:
        print(f"[WARN] types requested but not found after filtering ({len(missing)}): "
              f"{sorted(missing)}")


# ───────────────────── Helpers ─────────────────────
def _mode_interpolate(arr: np.ndarray, radius: int) -> np.ndarray:
    """Fill NaN entries with the majority (0/1) of the observed neighbours.

    For every NaN position the window ``[i - radius, i + radius]`` (clipped to
    the array) is inspected: the gap becomes 1.0 when strictly more than half
    of the non-NaN values in the window equal 1, otherwise 0.0 (this includes
    ties and windows without any observation).

    Returns ``arr`` itself when it holds no NaN, otherwise a filled copy.
    """
    missing = np.isnan(arr)
    if not missing.any():
        return arr

    # Prefix sums turn every window count into a difference of two lookups.
    observed_csum = np.concatenate(([0], np.cumsum(~missing)))
    ones_csum = np.concatenate(([0], np.cumsum(arr == 1)))

    gaps = np.nonzero(missing)[0]
    left = np.maximum(gaps - radius, 0)
    right = np.minimum(gaps + radius, len(arr) - 1) + 1   # exclusive end

    n_observed = observed_csum[right] - observed_csum[left]
    n_ones = ones_csum[right] - ones_csum[left]

    result = arr.copy()
    # With no observations both counts are 0, so the comparison yields 0.0.
    result[gaps] = np.where(n_ones > n_observed / 2, 1.0, 0.0)
    return result

def _apply_smoothing(y: np.ndarray, *, kind: str,
                     moving_win: int, gauss_sigma: float,
                     clamp: bool) -> np.ndarray:
    """Smooth a 1-D signal and optionally clip it to [0, 1].

    Parameters
    ----------
    y : signal to smooth.
    kind : "moving" for a zero-padded boxcar average, "gaussian" for a
        Gaussian filter with edge replication; any other value skips smoothing.
    moving_win : boxcar width in samples (used only when ``kind == "moving"``
        and the width exceeds 1).
    gauss_sigma : Gaussian sigma (used only when ``kind == "gaussian"`` and
        the value is positive).
    clamp : clip the result to [0, 1] when True.

    Returns
    -------
    Array with the same length as ``y``.
    """
    if kind == "moving" and moving_win > 1:
        kernel = np.ones(moving_win) / moving_win
        if len(y) >= moving_win:
            y = np.convolve(y, kernel, mode="same")
        else:
            # mode="same" returns max(len(y), len(kernel)) samples, so a signal
            # shorter than the window would come back *longer* than the input
            # and break callers that write it back in place.  Take the centred
            # len(y) samples of the full convolution instead, which is the
            # alignment mode="same" uses for longer inputs.
            offset = (moving_win - 1) // 2
            y = np.convolve(y, kernel, mode="full")[offset:offset + len(y)]
    elif kind == "gaussian" and gauss_sigma > 0:
        y = gaussian_filter1d(y, sigma=gauss_sigma, mode="nearest")
    return np.clip(y, 0, 1) if clamp else y

def _fill_met_domains(arr: np.ndarray, width: int) -> np.ndarray:
    """Bridge short gaps between consecutive 1-valued positions.

    Whenever two neighbouring entries equal to 1 are at most ``width``
    indices apart, everything between them (NaN included) is set to 1.0.
    Returns ``arr`` itself when it holds fewer than two 1s, otherwise a copy.
    """
    hits = np.nonzero(arr == 1)[0]
    if hits.size < 2:
        return arr

    bridged = arr.copy()
    close_enough = np.diff(hits) <= width
    for left, right in zip(hits[:-1][close_enough], hits[1:][close_enough]):
        bridged[left:right + 1] = 1.0
    return bridged

# Pre-configured smoothers: `raw_smooth` carries the RAW_* settings and
# `final_smooth` the FINAL_* settings, so both take just the signal array.
raw_smooth = partial(_apply_smoothing, **{
    "kind":        RAW_SMOOTH_KIND,
    "moving_win":  RAW_MOVING_WIN,
    "gauss_sigma": RAW_GAUSS_SIGMA,
    "clamp":       RAW_CLAMP_01,
})
final_smooth = partial(_apply_smoothing, **{
    "kind":        FINAL_SMOOTH_KIND,
    "moving_win":  FINAL_MOVING_WIN,
    "gauss_sigma": FINAL_GAUSS_SIGMA,
    "clamp":       FINAL_CLAMP_01,
})

def _autocorr_vec(x: np.ndarray, max_lag: int) -> np.ndarray:
    """Pearson autocorrelation of ``x`` for lags ``0 … max_lag``.

    Each lag correlates ``x[:n-k]`` with ``x[k:]`` using only the pairs in
    which both values are non-NaN, with means and variances computed
    separately for the two sides of that lag.

    Returns an array of length ``max_lag + 1``.  An entry stays NaN when the
    lag is not shorter than the signal, fewer than two complete pairs exist,
    or either side has zero variance.
    """
    size = len(x)
    result = np.full(max_lag + 1, np.nan)
    if size < 2:
        return result

    last_lag = min(max_lag, size - 1)
    for lag in range(last_lag + 1):
        head, tail = x[:size - lag], x[lag:]
        paired = ~np.isnan(head) & ~np.isnan(tail)
        if paired.sum() < 2:
            continue

        head_ok, tail_ok = head[paired], tail[paired]
        head_dev = head_ok - head_ok.mean()
        tail_dev = tail_ok - tail_ok.mean()

        scale = np.sqrt(np.sum(head_dev**2) * np.sum(tail_dev**2))
        if scale > 0:
            result[lag] = np.sum(head_dev * tail_dev) / scale

    return result

def _process_read(read_row: pd.Series) -> pd.DataFrame:
    """Compute the per-lag autocorrelation of one read's binary calls.

    ``read_row`` must provide ``read_id``, ``condition``, ``type``,
    ``chr_type``, ``rel_pos`` and ``mod_qual_bin``; the last two are parallel
    sequences that are cast to int arrays.

    Positions outside ±REL_POS_RANGE are dropped.  With
    ``PROCESSING_OPTION == 1`` the read is laid out on one vector spanning
    -REL_POS_RANGE … +REL_POS_RANGE ('whole').  Otherwise the positions left
    and right of the central lapse window (``LAPSE_WINDOW // 2`` bp each way)
    are laid out separately, indexed by absolute distance from position 0
    ('neg' and 'pos').

    Returns a frame with the metadata columns plus ``pos_neg``, ``lag`` and
    ``autocorr`` (one row per lag 0 … REL_POS_RANGE for every side produced),
    or an empty frame when the read contributes nothing.
    """
    meta     = read_row[['read_id','condition','type','chr_type']].to_dict()
    rel_pos  = np.asarray(read_row['rel_pos'], dtype=int)
    modq_bin = np.asarray(read_row['mod_qual_bin'], dtype=int)

    # mask to ±REL_POS_RANGE
    mask = (rel_pos >= -REL_POS_RANGE) & (rel_pos <= REL_POS_RANGE)
    rel_pos, modq_bin = rel_pos[mask], modq_bin[mask]
    if rel_pos.size == 0:
        return pd.DataFrame()

    def process_vec(vec, side_label):
        # restrict fill/smooth to covered span (first … last non-NaN index);
        # callers only pass vectors that hold at least one value
        idxs = np.where(~np.isnan(vec))[0]
        start, end = idxs.min(), idxs.max()
        region = vec[start:end+1].copy()
        if PERFORM_FILLING:
            region = _fill_met_domains(region, MET_DOMAIN_WIDTH)
        if PERFORM_INTERP:
            region = _mode_interpolate(region, INTERP_WINDOW)
        region = raw_smooth(region)
        vec[start:end+1] = region
        ac = _autocorr_vec(vec, REL_POS_RANGE)
        return pd.DataFrame({
            **meta,
            'pos_neg': side_label,
            'lag':     np.arange(REL_POS_RANGE+1),
            'autocorr': ac
        })

    rows = []
    if PROCESSING_OPTION == 1:
        full_len = 2 * REL_POS_RANGE + 1
        vec = np.full(full_len, np.nan)
        idx = rel_pos + REL_POS_RANGE      # shift so that -REL_POS_RANGE → 0
        vec[idx] = modq_bin
        rows.append(process_vec(vec, 'whole'))
    else:
        half = LAPSE_WINDOW // 2
        # NEG side: mirrored so that index = distance from position 0
        neg_mask = rel_pos <= -(half + 1)
        if neg_mask.any():
            rp = rel_pos[neg_mask]
            mb = modq_bin[neg_mask]
            new_pos = -rp
            keep = new_pos <= REL_POS_RANGE
            vec_neg = np.full(REL_POS_RANGE+1, np.nan)
            vec_neg[new_pos[keep]] = mb[keep]
            rows.append(process_vec(vec_neg, 'neg'))

        # POS side
        pos_mask = rel_pos >= (half + 1)
        if pos_mask.any():
            rp = rel_pos[pos_mask]
            mb = modq_bin[pos_mask]
            new_pos = rp
            keep = new_pos <= REL_POS_RANGE
            vec_pos = np.full(REL_POS_RANGE+1, np.nan)
            vec_pos[new_pos[keep]] = mb[keep]
            rows.append(process_vec(vec_pos, 'pos'))

    # In split mode a read whose positions all lie inside the lapse window
    # yields neither side.  pd.concat([]) raises ValueError, which would abort
    # the whole worker pool, so report "nothing to contribute" the same way
    # as the early return above.
    if not rows:
        return pd.DataFrame()
    return pd.concat(rows, ignore_index=True)

def _process_read_dict(rd: dict) -> pd.DataFrame:
    """Pool entry point: wrap one record dict in a Series and run `_process_read`."""
    return _process_read(pd.Series(rd))

# ───────────────────────── Run multiprocessing ────────────────────────────
# One plain dict per read; `_process_read_dict` rebuilds each into a Series.
read_records = filtered_reads_df.to_dict(orient='records')
if DEBUG_PROGRESS:
    print(f"[INFO] Queued reads: {len(read_records):,}")
with mp.Pool(N_WORKERS) as pool:
    # imap_unordered yields results as workers finish, so the row order of
    # per_read_ac does not follow the order of read_records.
    it = pool.imap_unordered(_process_read_dict, read_records, chunksize=128)
    if DEBUG_PROGRESS:
        it = tqdm(it, total=len(read_records), desc="reads")
    # Empty frames (reads that contributed nothing) are skipped.
    # NOTE(review): pd.concat raises ValueError if no non-empty frame arrives.
    per_read_ac = pd.concat((df for df in it if not df.empty),
                             ignore_index=True)
print(f"[INFO] per_read_ac shape: {per_read_ac.shape}")


In [None]:
# CELL 2 — aggregate, autocorrelation plots (split if PROCESSING_OPTION=2),
#          subtraction plots, and debug plots

import time, itertools, numpy as np, pandas as pd, plotly.graph_objects as go

# ─────────────── Configuration ───────────────
MIN_LAG, MAX_LAG       = 100, 600   # lag range (bp) kept for plotting
MIN_OVERLAP            = 100    # minimum bases overlapping between read & lagged read
ZSCORE_NORMALIZE       = False  # set True to z-score normalize
TYPES_TO_PLOT          = []     # [] → include all; else subset
SHOW_CI                = False  # set True to draw the 95 % CI envelope
# PROCESSING_OPTION, analysis_cond, per_read_ac, filtered_reads_df
# are assumed defined by Cell 1

# ─────────── Condition ordering & colour mapping ───────────
# Plot at most the first four analysis conditions.  Slicing (rather than
# indexing 0..3 explicitly) avoids an IndexError when fewer are configured.
conditions = list(analysis_cond[:4])
# Grey by default; later updates override earlier ones, so a name matching
# several keys ends up with the colour of the last matching key.
COND_COLORS = {c: "#888888" for c in conditions}
COND_COLORS.update({c: "#4974a5" for c in conditions if "N2"   in c})
COND_COLORS.update({c: "#51ab4d" for c in conditions if "DPY27" in c})
COND_COLORS.update({c: "#b12537" for c in conditions if "SDC2"  in c})

def _hex_rgba(hex_code, alpha=0.20):
    """Return ``rgba(r,g,b,alpha)`` for a 6-digit hex colour (leading '#' optional)."""
    digits = hex_code.lstrip('#')
    red, green, blue = [int(digits[start:start + 2], 16) for start in range(0, 6, 2)]
    return f"rgba({red},{green},{blue},{alpha})"

# ─────────────── Apply type filter & MIN_OVERLAP filter ───────────────
df_ac = per_read_ac.copy()
df_fr = filtered_reads_df.copy()

# filtered_reads_df must provide 'read_id' and either:
#  • a 'read_length' column, or
#  • 'start' and 'end' columns from which read_length is derived.
if 'read_length' not in df_fr.columns:
    # Compute read_length if start/end are present (inclusive coordinates)
    if {'start', 'end'}.issubset(df_fr.columns):
        df_fr['read_length'] = df_fr['end'] - df_fr['start'] + 1
    else:
        raise RuntimeError("filtered_reads_df must contain 'read_length' or both 'start' & 'end' columns.")

# Merge read_length into df_ac (requires per_read_ac to have 'read_id')
if 'read_id' not in df_ac.columns:
    raise RuntimeError("per_read_ac must contain 'read_id' for overlap filtering.")

# The length table is reduced to one row per read before merging.  If a
# read_id occurs on several rows of filtered_reads_df (e.g. once per `type`),
# merging on read_id alone would otherwise multiply that read's autocorr rows
# and over-weight it in the group means computed below.
# NOTE(review): assumes read_length is identical across a read's rows — confirm.
df_ac = df_ac.merge(
    df_fr[['read_id', 'read_length']].drop_duplicates(subset='read_id'),
    on='read_id', how='left'
)

# Keep only autocorr rows where overlap ≥ MIN_OVERLAP:
#   overlap = read_length − lag  ⇒ require lag ≤ read_length − MIN_OVERLAP
# (rows without a matched read_length compare False and are dropped)
df_ac = df_ac[df_ac['lag'] <= (df_ac['read_length'] - MIN_OVERLAP)].reset_index(drop=True)

# Now apply type filter
if TYPES_TO_PLOT:
    df_ac = df_ac[df_ac['type'].isin(TYPES_TO_PLOT)]
    df_fr = df_fr[df_fr['type'].isin(TYPES_TO_PLOT)]

# ─────────────────────────── Timing start ───────────────────────────
t0 = time.time()
print(f"[DEBUG] Cell 2 start at {t0:.2f}s")

# ───── 1. Aggregate across read-ids within each type ─────
# Mean autocorrelation per (side, condition, type, lag).  sort=False keeps the
# groups in order of first appearance in df_ac.
plot_df_auto = (
    df_ac
      .groupby(['pos_neg', 'condition', 'type', 'lag'], as_index=False, sort=False)['autocorr']
      .mean()
)
# Smooth each (side, condition, type) curve along its rows.
# NOTE(review): final_smooth treats consecutive rows of a group as consecutive
# lags, so this relies on the rows being in ascending-lag order — confirm.
plot_df_auto['autocorr_smooth'] = (
    plot_df_auto
      .groupby(['pos_neg', 'condition', 'type'])['autocorr']
      .transform(final_smooth)
)
# Global mean / population SD over all smoothed values (all lags, all groups);
# only used when ZSCORE_NORMALIZE is True.
μ, σ = plot_df_auto['autocorr_smooth'].mean(), plot_df_auto['autocorr_smooth'].std(ddof=0)
plot_df_auto['plot_val'] = (
    (plot_df_auto['autocorr_smooth'] - μ) / σ
    if ZSCORE_NORMALIZE else plot_df_auto['autocorr_smooth']
)
# Restrict to the plotted lag range only after smoothing and normalisation.
plot_df_auto = plot_df_auto[plot_df_auto['lag'].between(MIN_LAG, MAX_LAG)].reset_index(drop=True)
print(f"[DEBUG] per-type aggregation done in {(time.time() - t0):.2f}s")

# ───── 2. Collapse over type for plotting / Δ-plots ─────
# Mean, SD and count of the per-type curves at each (side, condition, lag);
# 'count' ignores NaN values.
plot_df_cond = (
    plot_df_auto
      .groupby(['pos_neg', 'condition', 'lag'], as_index=False)
      .agg(mean_val=('plot_val', 'mean'),
           sd_val  =('plot_val', 'std'),
           n       =('plot_val', 'count'))
)
# SEM across types (n == 0 is mapped to NaN to avoid dividing by zero) and
# the normal-approximation 95 % interval around the mean.
plot_df_cond['sem']   = plot_df_cond['sd_val'] / np.sqrt(plot_df_cond['n'].replace(0, np.nan))
plot_df_cond['ci_lo'] = plot_df_cond['mean_val'] - 1.96 * plot_df_cond['sem']
plot_df_cond['ci_hi'] = plot_df_cond['mean_val'] + 1.96 * plot_df_cond['sem']

y_label = "Autocorrelation (z-score)" if ZSCORE_NORMALIZE else "Mean autocorrelation"

# ───── 3. Helper to add a line + optional CI band ─────
def _add_trace_with_ci(fig, sub_df, lag_col, color, name):
    """Add one condition's mean curve to ``fig``, plus its CI band if enabled.

    ``sub_df`` supplies ``lag_col`` and ``mean_val``; the band additionally
    uses ``ci_lo``, ``ci_hi`` and ``n``.  The band is drawn only when SHOW_CI
    is set and at least one lag aggregates more than one value.
    """
    lags = sub_df[lag_col]

    mean_line = go.Scatter(
        x=lags, y=sub_df['mean_val'],
        mode='lines', name=name, line=dict(color=color)
    )
    fig.add_trace(mean_line)

    if not SHOW_CI or not (sub_df['n'].max() > 1):
        return

    # Closed polygon: upper bound left→right, then lower bound right→left.
    band = go.Scatter(
        x=pd.concat([lags, lags[::-1]]),
        y=pd.concat([sub_df['ci_hi'], sub_df['ci_lo'][::-1]]),
        fill='toself', fillcolor=_hex_rgba(color),
        line=dict(width=0), hoverinfo='skip', showlegend=False
    )
    fig.add_trace(band)

# ───── 4. Main autocorrelation plots ─────
# Split mode (PROCESSING_OPTION == 2) draws one figure per side; otherwise a
# single whole-read figure.  neg_df / pos_df are reused by the pair-wise
# difference plots further down.
if PROCESSING_OPTION == 2:
    # — Negative side —
    neg_df = plot_df_cond[plot_df_cond['pos_neg'] == 'neg'].copy()
    neg_df['plot_lag'] = -neg_df['lag']   # mirror lags onto the negative x-axis
    fig_neg = go.Figure(layout=dict(template="plotly_white"))
    for cond in conditions:
        _add_trace_with_ci(
            fig_neg,
            neg_df[neg_df['condition'] == cond].sort_values('plot_lag'),
            'plot_lag', COND_COLORS[cond], cond
        )
    fig_neg.update_layout(
        title="Read-level Autocorrelation (negative side)",
        xaxis_title="Lag (bp)", yaxis_title=y_label,
        legend_title="Condition", xaxis=dict(range=[-MAX_LAG, -MIN_LAG]),
        width=1200, height=600
    )
    fig_neg.show()
    print(f"[DEBUG] Negative-side plot {(time.time() - t0):.2f}s")

    # — Positive side —
    pos_df = plot_df_cond[plot_df_cond['pos_neg'] == 'pos'].copy()
    pos_df['plot_lag'] = pos_df['lag']
    fig_pos = go.Figure(layout=dict(template="plotly_white"))
    for cond in conditions:
        _add_trace_with_ci(
            fig_pos,
            pos_df[pos_df['condition'] == cond].sort_values('plot_lag'),
            'plot_lag', COND_COLORS[cond], cond
        )
    fig_pos.update_layout(
        title="Read-level Autocorrelation (positive side)",
        xaxis_title="Lag (bp)", yaxis_title=y_label,
        legend_title="Condition", xaxis=dict(range=[MIN_LAG, MAX_LAG]),
        width=1200, height=600
    )
    fig_pos.show()
    print(f"[DEBUG] Positive-side plot {(time.time() - t0):.2f}s")

else:
    # Whole-read mode: every row of plot_df_cond is used, without a side filter.
    whole_df = plot_df_cond.copy()
    fig = go.Figure(layout=dict(template="plotly_white"))
    for cond in conditions:
        _add_trace_with_ci(
            fig,
            whole_df[whole_df['condition'] == cond].sort_values('lag'),
            'lag', COND_COLORS[cond], cond
        )
    fig.update_layout(
        title="Read-level Autocorrelation (filtered)",
        xaxis_title="Lag (bp)", yaxis_title=y_label,
        legend_title="Condition", xaxis=dict(range=[MIN_LAG, MAX_LAG]),
        width=1000, height=600
    )
    fig.show()
    print(f"[DEBUG] Whole-read plot {(time.time() - t0):.2f}s")

# ───── 5. Pair-wise subtraction plots (mean-across-types) ─────
def _pairwise_diff(base_df, lag_col, xaxis_range, title):
    """Plot, for every pair of conditions, the difference of their mean curves.

    ``base_df`` supplies ``condition``, ``lag_col`` and ``mean_val``.  The two
    curves of a pair are aligned on ``lag_col`` before subtracting, and the
    figure is shown immediately.
    """
    fig = go.Figure(layout=dict(template="plotly_white"))

    def _curve(cond):
        # mean curve of one condition, indexed by lag for aligned subtraction
        rows = base_df[base_df['condition'] == cond]
        return rows.set_index(lag_col)['mean_val']

    for first, second in itertools.combinations(conditions, 2):
        delta = (_curve(first) - _curve(second)).reset_index(name='diff')
        fig.add_trace(go.Scatter(
            x=delta[lag_col], y=delta['diff'],
            mode='lines', name=f"{first} − {second}"
        ))

    fig.update_layout(
        title=title,
        xaxis_title="Lag (bp)",
        yaxis_title=f"Δ {y_label}",
        xaxis=dict(range=xaxis_range),
        width=1000, height=600
    )
    fig.show()

# Split mode: one difference figure per side, using the mirrored 'plot_lag'
# axis of neg_df / pos_df; whole-read mode: a single figure over 'lag'.
if PROCESSING_OPTION == 2:
    _pairwise_diff(
        neg_df, 'plot_lag', [-MAX_LAG, -MIN_LAG],
        "Pair-wise Differences (negative side)"
    )
    _pairwise_diff(
        pos_df, 'plot_lag', [MIN_LAG, MAX_LAG],
        "Pair-wise Differences (positive side)"
    )
else:
    _pairwise_diff(
        plot_df_cond, 'lag', [MIN_LAG, MAX_LAG],
        "Pair-wise Differences"
    )

print(f"[DEBUG] Cell 2 complete in {(time.time() - t0):.2f}s")


In [None]:
###############################################################################
# CELL 3 — NRL‑guided peak detection, distances, amplitudes & Δ‑amplitudes
###############################################################################
import numpy as np
import pandas as pd
import plotly.graph_objects as go

# ─────────────── High-level toggles ───────────────
# Each flag enables one of the box plots drawn at the end of this cell.
RUN_DISTANCES      = False   # peak→peak Δ
RUN_AMPLITUDES     = False   # peak amplitude
RUN_AMP_SUBTRACT   = True   # NEW: amplitude – baseline amplitude

# ─────────────── Parameters ───────────────
NUM_CYCLES   = 4           # how many successive peaks to analyse
NRL          = 177         # nominal nucleosome repeat length
WINDOW       = 30          # ± bp around each NRL multiple to look for a peak
FIG_W, FIG_H = 900, 450    # figure size in pixels
BASELINE     = analysis_cond[0]   # condition subtracted in the Δ-amplitude plot

# -----------------------------------------------------------------------------
# 1)  PEAK‑SELECTION + METRIC UTILS  (unchanged)
# -----------------------------------------------------------------------------
def _select_nrl_peaks(lag_arr, val_arr, n_cycles=NUM_CYCLES):
    """Pick the highest point near each multiple of the repeat length.

    For k = 1 … n_cycles the window ``[k*NRL - WINDOW, k*NRL + WINDOW]`` is
    searched and the lag with the largest value is kept.  NaN values are
    ignored; a window with no lag inside it, or with only NaN values, is
    skipped, so fewer than ``n_cycles`` peaks may be returned.

    Returns ``(peak_lags, peak_values)`` as two equally long arrays.
    """
    # np.argmax reports the first NaN as the maximum, so NaN samples must be
    # excluded before searching or a gap in the curve becomes a NaN "peak".
    finite = ~np.isnan(val_arr)
    chosen_lags, chosen_vals = [], []
    for k in range(1, n_cycles + 1):
        lo, hi = k * NRL - WINDOW, k * NRL + WINDOW
        mask = (lag_arr >= lo) & (lag_arr <= hi) & finite
        if mask.any():
            idx_max = np.argmax(val_arr[mask])
            sub_lags = lag_arr[mask]
            sub_vals = val_arr[mask]
            chosen_lags.append(sub_lags[idx_max])
            chosen_vals.append(sub_vals[idx_max])
    return np.asarray(chosen_lags), np.asarray(chosen_vals)

def _peak_metrics(lag_arr, val_arr, n_cycles=NUM_CYCLES):
    """Return ([distances], [amplitudes]) for successive NRL peaks.

    Distances are the gaps between consecutive selected peaks.  A peak's
    amplitude is its value minus the lowest value found between the previous
    peak (lag 0 for the first one) and the peak itself.  Both lists are empty
    when fewer than two peaks are found.
    """
    peak_lags, peak_vals = _select_nrl_peaks(lag_arr, val_arr, n_cycles)
    if peak_lags.size < 2:
        return [], []

    spacings = np.diff(peak_lags)[:n_cycles].tolist()

    amplitudes = []
    left_edge = 0
    for peak_lag, peak_val in zip(peak_lags, peak_vals):
        in_segment = (lag_arr >= left_edge) & (lag_arr <= peak_lag)
        if in_segment.any():
            floor = np.min(val_arr[in_segment])
        else:
            floor = np.nan
        amplitudes.append(peak_val - floor)
        left_edge = peak_lag
        if len(amplitudes) == n_cycles:
            break
    return spacings, amplitudes[:n_cycles]

# -----------------------------------------------------------------------------
# 2)  DATA‑FRAME BUILDERS
# -----------------------------------------------------------------------------
def _build_metrics_df(auto_df, side=None):
    """Compute distance & amplitude DataFrames for (side · condition · type).

    ``auto_df`` supplies ``pos_neg``, ``condition``, ``type``, ``lag`` and
    ``plot_val``.  When ``side`` is given only rows with that ``pos_neg``
    value are used.

    Returns ``(dist_df, amp_df)`` with columns ``pos_neg, condition, type,
    cycle`` plus ``distance`` / ``amplitude``; cycles are labelled ``d1, d2 …``
    and ``a1, a2 …``.  The columns are present even when no peaks were found,
    so callers can relabel or concatenate empty results without a KeyError.
    """
    sub = auto_df if side is None else auto_df[auto_df["pos_neg"] == side]
    dist_rows, amp_rows = [], []

    for (pos_neg, cond, typ), g in sub.groupby(["pos_neg", "condition", "type"]):
        g = g.sort_values("lag")
        dists, amps = _peak_metrics(g["lag"].to_numpy(), g["plot_val"].to_numpy())
        for i, d in enumerate(dists, 1):
            dist_rows.append(dict(pos_neg=pos_neg, condition=cond, type=typ,
                                  cycle=f"d{i}", distance=d))
        for i, a in enumerate(amps, 1):
            amp_rows.append(dict(pos_neg=pos_neg, condition=cond, type=typ,
                                 cycle=f"a{i}", amplitude=a))

    # Explicit column lists: pd.DataFrame([]) alone would have no columns.
    key_cols = ["pos_neg", "condition", "type", "cycle"]
    return (pd.DataFrame(dist_rows, columns=key_cols + ["distance"]),
            pd.DataFrame(amp_rows, columns=key_cols + ["amplitude"]))

# -----------------------------------------------------------------------------
# 3)  PLOTTING HELPER
# -----------------------------------------------------------------------------
def _plot_box(df, value_col, ylabel, title, order, colours):
    """Show grouped box plots of ``value_col`` per cycle, one box set per condition.

    Conditions follow the order of ``analysis_cond``; those without rows are
    skipped, as is the baseline condition when ``ylabel`` contains "Δ".
    ``order`` fixes the category order of the x-axis and ``colours`` maps
    condition → marker colour (grey when absent).
    """
    fig = go.Figure(layout=dict(template="plotly_white", width=FIG_W, height=FIG_H))

    def _is_plotted(cond):
        # skip baseline in subtraction plot
        return not (cond == BASELINE and "Δ" in ylabel)

    for cond in filter(_is_plotted, analysis_cond):
        rows = df.loc[df["condition"] == cond]
        if rows.empty:
            continue
        box = go.Box(
            x=rows["cycle"], y=rows[value_col], name=cond,
            marker_color=colours.get(cond, "#888888"),
            boxpoints="outliers", offsetgroup=cond
        )
        fig.add_trace(box)

    fig.update_xaxes(categoryorder="array", categoryarray=order, title_text="Cycle")
    fig.update_layout(title=title, yaxis_title=ylabel,
                      boxmode="group", legend_title="Condition")
    fig.show()

# -----------------------------------------------------------------------------
# 4)  BUILD METRICS, UNIFY SIDES (when needed), & PLOT
# -----------------------------------------------------------------------------
if PROCESSING_OPTION == 2:
    # build once per side, then concatenate with signed cycle labels
    dist_neg, amp_neg = _build_metrics_df(plot_df_auto, side="neg")
    dist_pos, amp_pos = _build_metrics_df(plot_df_auto, side="pos")

    # label cycles from NEG side as negative (d-1, a-1 …): a "-" is inserted
    # right after the leading "d"/"a".
    # NOTE(review): `.str` access requires both frames to have a 'cycle' column.
    for df in (dist_neg, amp_neg):
        df["cycle"] = df["cycle"].str.replace(r"^([da])", r"\1-",
                                              regex=True)
    dist_df = pd.concat([dist_neg, dist_pos], ignore_index=True)
    amp_df  = pd.concat([amp_neg,  amp_pos],  ignore_index=True)
    side_name = "Negative & Positive sides"
    # x-axis order: negative cycles from farthest to nearest, then positive
    # cycles from nearest to farthest.
    order_dist = [f"d-{i}" for i in range(NUM_CYCLES, 0, -1)] + \
                 [f"d{i}"  for i in range(1, NUM_CYCLES + 1)]
    order_amp  = [f"a-{i}" for i in range(NUM_CYCLES, 0, -1)] + \
                 [f"a{i}"  for i in range(1, NUM_CYCLES + 1)]
else:
    dist_df, amp_df = _build_metrics_df(plot_df_auto)    # whole read
    side_name = "Whole read"
    order_dist = [f"d{i}" for i in range(1, NUM_CYCLES + 1)]
    order_amp  = [f"a{i}" for i in range(1, NUM_CYCLES + 1)]

# ----- Δ-Amplitude (baseline-subtracted) -------------------------------------
# Always bind amp_sub_df: the plotting guard below reads it even when the
# subtraction step is skipped, which used to raise NameError (or silently
# reuse a frame from an earlier run of the cell).
amp_sub_df = pd.DataFrame()
if RUN_AMP_SUBTRACT and not amp_df.empty:
    # look-up table of baseline amplitudes keyed by (pos_neg, type, cycle)
    base_amp = amp_df[amp_df["condition"] == BASELINE] \
               .set_index(["pos_neg", "type", "cycle"])["amplitude"]

    amp_sub_df = amp_df[amp_df["condition"] != BASELINE].copy()
    # Vectorised look-up: reindex returns NaN for keys the baseline lacks.
    # Unlike a row-wise apply, this also works when amp_sub_df has no rows.
    # NOTE(review): relies on one baseline amplitude per (pos_neg, type, cycle).
    amp_sub_df["amp_sub"] = (
        amp_sub_df["amplitude"].to_numpy()
        - base_amp.reindex(
            pd.MultiIndex.from_frame(amp_sub_df[["pos_neg", "type", "cycle"]])
        ).to_numpy()
    )
    amp_sub_df = amp_sub_df.dropna(subset=["amp_sub"])

# ----- PLOTS -----------------------------------------------------------------
if RUN_DISTANCES and not dist_df.empty:
    _plot_box(dist_df, "distance",
              "Distance between peaks (bp)",
              f"Peak→Peak Distances – {side_name}",
              order_dist, COND_COLORS)

if RUN_AMPLITUDES and not amp_df.empty:
    _plot_box(amp_df, "amplitude",
              "Amplitude (peak − local trough)",
              f"Peak Amplitudes – {side_name}",
              order_amp, COND_COLORS)

if RUN_AMP_SUBTRACT and not amp_sub_df.empty:
    _plot_box(amp_sub_df, "amp_sub",
              f"Δ Amplitude vs {BASELINE}",
              f"Amplitude Subtraction (vs {BASELINE}) – {side_name}",
              order_amp, COND_COLORS)


In [None]:
###############################################################################
# CELL 3c — interactive autocorrelation viewer (2 dropdowns + peak markers)
###############################################################################
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import ipywidgets as wd
from IPython.display import display, clear_output

# ───────────────────────── CONFIG ───────────────────────── #
# NOTE: these assignments rebind the notebook-wide NUM_CYCLES / NRL / WINDOW /
# FIG_W / FIG_H that the peak-metrics cell above also sets.
NUM_CYCLES   = 3            # number of NRL windows searched for a peak
NRL          = 177          # expected repeat
WINDOW       = 30           # ± bp around k*NRL
FIG_W, FIG_H = 1200, 500    # figure size in pixels
# plot_df_auto, analysis_cond, COND_COLORS, PROCESSING_OPTION must exist (from Cell 1/2)

# ─────────────────── helper: NRL‑guided peaks ─────────────────── #
def _nrl_peaks(x, y, n_cycles=NUM_CYCLES):
    """
    Return array of x-positions of peaks:
    one per window [k*NRL ± WINDOW], k=1…n_cycles, picking the highest point.
    NaN values in y are ignored; a window holding no x value, or only NaN y
    values, contributes no peak.
    x must be ascending.
    """
    # np.argmax treats NaN as the maximum, so without this mask a gap in the
    # curve would be reported as the peak of its window.
    finite = ~np.isnan(y)
    out = []
    for k in range(1, n_cycles + 1):
        ctr = k * NRL
        mask = (x >= ctr - WINDOW) & (x <= ctr + WINDOW) & finite
        if mask.any():
            seg_x = x[mask]
            seg_y = y[mask]
            out.append(seg_x[np.argmax(seg_y)])
    return np.array(out)

# ─────────────────── pre-compute curves & peaks ─────────────────── #
# keys = (type, condition, side)  side = 'neg'|'pos'|'whole'
mean_curve = {}   # (ty,cond,side) → (x_vals, y_vals)
peak_dict  = {}   # (ty,cond,side) → peak_x array

# Whole-read mode uses every row of plot_df_auto under the single key
# 'whole'; split mode filters on the pos_neg column for each side.
sides = ['whole'] if PROCESSING_OPTION != 2 else ['neg', 'pos']
for side in sides:
    df_side = (plot_df_auto
               if side == 'whole'
               else plot_df_auto[plot_df_auto['pos_neg'] == side])

    for (typ, cond), grp in df_side.groupby(['type', 'condition']):
        # _nrl_peaks expects ascending x, hence the sort by lag
        g = grp.sort_values('lag')
        x_vals = g['lag'].to_numpy()
        y_vals = g['plot_val'].to_numpy()

        mean_curve[(typ, cond, side)] = (x_vals, y_vals)
        peak_dict [(typ, cond, side)] = _nrl_peaks(x_vals, y_vals)

# ─────────────────── interactive viewer ─────────────────── #
def viewer(side='whole'):
    """
    Build dropdown-controlled Plotly figure for the chosen side.

    Two dropdowns (type, condition) select one pre-computed curve from
    ``mean_curve``; the curve is drawn with dotted vertical lines at the
    positions stored in ``peak_dict``.  Prints a warning and returns when no
    curve exists for ``side``.
    """
    # available combos for this side
    combos = [(t, c) for (t, c, s) in mean_curve if s == side]
    if not combos:
        print(f"[WARN] no curves for side={side}")
        return

    type_options = sorted({t for (t, _) in combos})
    cond_options = sorted({c for (_, c) in combos})

    dd_type  = wd.Dropdown(options=type_options, value=type_options[0],
                           description='Type:')
    dd_cond  = wd.Dropdown(options=cond_options, value=cond_options[0],
                           description='Condition:')
    out_plot = wd.Output()

    def _redraw(*_):
        typ  = dd_type.value
        cond = dd_cond.value
        key  = (typ, cond, side)

        # The dropdowns offer every type × condition pair, but a curve only
        # exists for pairs that occur in the data; report a missing pair
        # instead of raising KeyError inside the widget callback.
        if key not in mean_curve:
            with out_plot:
                clear_output(wait=True)
                print(f"[WARN] no curve for type={typ}, condition={cond}, side={side}")
            return

        x, y = mean_curve[key]
        peaks = peak_dict.get(key, np.array([]))

        # Extent of the marker lines: NaN samples are ignored so the shapes
        # get finite coordinates; fall back to [0, 1] for an all-NaN curve.
        finite_y = y[~np.isnan(y)]
        y_lo, y_hi = (finite_y.min(), finite_y.max()) if finite_y.size else (0.0, 1.0)

        # build vertical-line shapes for peaks
        shapes = [dict(type='line',
                       x0=peak_x, x1=peak_x,
                       y0=y_lo, y1=y_hi,
                       line=dict(color='black', width=1, dash='dot'))
                  for peak_x in peaks]

        fig = go.Figure(layout=dict(template='plotly_white',
                                    width=FIG_W, height=FIG_H,
                                    shapes=shapes))
        fig.add_trace(go.Scatter(
            x=x, y=y, mode='lines',
            line=dict(color=COND_COLORS.get(cond, '#888888'), width=2),
            name=f'{cond} — {typ}'
        ))
        fig.update_layout(
            title=(f"Autocorrelation ({side.upper() if side!='whole' else 'WHOLE'})"
                   f" — {typ} | {cond}"),
            xaxis_title='Lag (bp)', yaxis_title='Autocorrelation'
        )

        with out_plot:
            clear_output(wait=True)
            display(fig)

    # link callbacks
    dd_type.observe(_redraw, names='value')
    dd_cond.observe(_redraw, names='value')

    # initial draw
    _redraw()

    caption = wd.HTML(f"<b>{'Whole read' if side=='whole' else side.upper()+' side'}</b>")
    display(wd.VBox([caption, wd.HBox([dd_type, dd_cond]), out_plot]))

# ─────────────────── render viewer(s) ─────────────────── #
# one widget per entry of `sides` ('whole', or 'neg' and 'pos')
for s in sides:
    viewer(s)


In [None]:
###############################################################################
# CELL 3 — Per-read FFT analysis, power spectra, & comparative plots
###############################################################################
import time
import numpy as np
import pandas as pd
#from scipy.fft import rfft, rfftfreq # No longer needed for metrics if all LS based
from scipy.signal import find_peaks, peak_widths, lombscargle
import plotly.graph_objects as go
from tqdm.auto import tqdm
import plotly.colors
from typing import Union, Optional, Tuple, List, Dict

# ─────────────── FFT Analysis Configuration ───────────────
# (The spectra in this cell are computed with scipy.signal.lombscargle.)
# For dominant peak identification (now from Lomb-Scargle spectrum):
# period band (bp) searched for the dominant peak, and the minimum number of
# non-NaN autocorrelation values a group needs to be analysed.
NUC_PERIOD_MIN = 150
NUC_PERIOD_MAX = 250
MIN_VALID_AUTOCORR_POINTS = 20

# For plotting average power spectra (using Lomb-Scargle):
# period grid (bp) on which every spectrum is evaluated.
SPECTRUM_PLOT_PERIOD_MIN = 50
SPECTRUM_PLOT_PERIOD_MAX = 500
SPECTRUM_PLOT_PERIOD_STEP = 1.0 # Desired 1bp resolution

# ───────────────────── Timing start ─────────────────────
# time.time() is seconds since the epoch, so the printed value is an absolute
# timestamp rather than an elapsed duration.
t_cell3_start = time.time()
print(f"[INFO] Cell 3: Analysis and Plotting started at {t_cell3_start:.2f}s")

# ───────────────── Color Helper ─────────────────
def hex_to_rgba_str(hex_color: str, alpha: float) -> str:
    """Convert a 6-digit hex colour (leading '#' optional) to ``rgba(r,g,b,alpha)``."""
    digits = hex_color.lstrip('#')
    red = int(digits[0:2], 16)
    green = int(digits[2:4], 16)
    blue = int(digits[4:6], 16)
    return f'rgba({red},{green},{blue},{alpha})'

# ───────────────── Lomb-Scargle Calculation Functions ─────────────────

def extract_power_spectrum_lombscargle(autocorr_values: np.ndarray,
                                       sampling_interval: float,
                                       plot_period_min: float,
                                       plot_period_max: float,
                                       period_step: float) -> pd.DataFrame:
    """Lomb-Scargle power of a Hann-windowed signal on a grid of periods.

    The samples are taken to lie at ``0, sampling_interval, 2*sampling_interval …``.
    Power is evaluated (un-normalised) at the periods
    ``arange(plot_period_min, plot_period_max + period_step, period_step)``,
    keeping only strictly positive periods.

    Returns a frame with columns ``period`` and ``power``; it is empty when the
    signal has fewer than two samples, is identically zero, or no positive
    period remains.
    """
    def _empty_spectrum():
        return pd.DataFrame({'period': [], 'power': []})

    n_samples = len(autocorr_values)
    if n_samples < 2 or np.all(autocorr_values == 0):
        return _empty_spectrum()

    tapered = autocorr_values * np.hanning(n_samples)
    sample_times = np.arange(n_samples) * sampling_interval

    period_grid = np.arange(float(plot_period_min),
                            float(plot_period_max) + period_step,
                            period_step)
    usable = period_grid > 1e-9
    if not np.any(usable):
        return _empty_spectrum()
    periods = period_grid[usable]

    # lombscargle expects angular frequencies
    omega = 2 * np.pi / periods
    power = lombscargle(sample_times, tapered, omega, normalize=False)
    return pd.DataFrame({'period': periods, 'power': power})

def find_dominant_peak_from_lombscargle_spectrum(
    lombscargle_spectrum_df: pd.DataFrame,
    metric_period_min: float,
    metric_period_max: float,
    period_step_in_spectrum: float = 1.0 # Assuming 1bp resolution from input spectrum
) -> pd.Series:
    """Locate the strongest spectral peak inside a period band.

    ``lombscargle_spectrum_df`` supplies ``period`` and ``power``.  Only rows
    with ``metric_period_min <= period <= metric_period_max`` are considered.

    Returns a Series with ``dominant_period`` and ``dominant_power`` (the row
    of maximum power) and ``sharpness_q`` = dominant period divided by the
    peak's full width at half prominence, the width being converted from
    samples to period units with ``period_step_in_spectrum``.  All three are
    NaN when the band is empty or holds no power of at least 1e-9;
    ``sharpness_q`` alone is NaN when no positive width can be measured.
    """
    def _no_peak():
        return pd.Series({'dominant_period': np.nan, 'dominant_power': np.nan, 'sharpness_q': np.nan})

    if lombscargle_spectrum_df.empty:
        return _no_peak()

    in_band = lombscargle_spectrum_df['period'].between(metric_period_min, metric_period_max)
    band = lombscargle_spectrum_df[in_band].copy()
    band_power = band['power']

    if band.empty or band_power.isnull().all() or (band_power < 1e-9).all():
        return _no_peak()

    top_label = band_power.idxmax()
    dominant_period = band.loc[top_label, 'period']
    dominant_power = band.loc[top_label, 'power']
    if dominant_power < 1e-9:
        return _no_peak() # No significant peak

    # peak_widths needs the peak as an integer position within the band
    power_arr = band_power.values
    top_pos = np.argmax(power_arr)
    widths = peak_widths(power_arr, [top_pos], rel_height=0.5)[0]

    sharpness_q = np.nan
    if widths.size > 0 and not np.isnan(widths[0]) and widths[0] > 0:
        fwhm_in_period_units = widths[0] * period_step_in_spectrum
        if fwhm_in_period_units > 1e-9:
            sharpness_q = dominant_period / fwhm_in_period_units

    return pd.Series({
        'dominant_period': dominant_period,
        'dominant_power': dominant_power,
        'sharpness_q': sharpness_q
    })

# ───────────────── Core Data Processing Function ─────────────────

def _apply_metric_and_spectrum_from_lombscargle(group: pd.DataFrame, min_lag: int, max_lag: int,
                                                metric_period_min: float, metric_period_max: float,
                                                spectrum_plot_period_min: float, spectrum_plot_period_max: float,
                                                spectrum_plot_period_step: float,
                                                min_valid_points: int) -> Tuple[pd.Series, Optional[pd.DataFrame]]:
    """Compute one group's Lomb-Scargle spectrum and its dominant-peak metrics.

    ``group`` supplies ``lag`` and ``autocorr``.  The autocorrelation is laid
    out on every integer lag in ``[min_lag, max_lag]`` (missing or NaN entries
    become 0), a spectrum is computed over the plotting period range, and the
    dominant peak is searched within the metric period range.

    Returns ``(metrics, spectrum)``.  ``metrics`` is all-NaN and ``spectrum``
    is ``None`` when fewer than ``min_valid_points`` non-NaN values lie in the
    lag range, the laid-out signal is all zero, or the spectrum is empty.
    """
    nan_metrics = pd.Series({'dominant_period': np.nan, 'dominant_power': np.nan, 'sharpness_q': np.nan})

    ordered = group.sort_values('lag')

    # Require enough real observations before zeros are substituted for gaps.
    in_lag_range = ordered['lag'].between(min_lag, max_lag)
    if ordered.loc[in_lag_range, 'autocorr'].count() < min_valid_points:
        return nan_metrics, None

    lag_grid = np.arange(min_lag, max_lag + 1)
    signal = ordered.set_index('lag')['autocorr'].reindex(lag_grid).fillna(0).values
    if np.all(signal == 0):
        return nan_metrics, None

    # 1. Spectrum over the broad plotting range (sampling interval = 1.0 bp)
    spectrum = extract_power_spectrum_lombscargle(
        signal, 1.0,
        spectrum_plot_period_min, spectrum_plot_period_max, spectrum_plot_period_step
    )
    if spectrum is None or spectrum.empty:
        return nan_metrics, None

    # 2. Dominant-peak metrics restricted to the metric period range
    metrics = find_dominant_peak_from_lombscargle_spectrum(
        spectrum,
        metric_period_min, metric_period_max,
        period_step_in_spectrum=spectrum_plot_period_step
    )
    return metrics, spectrum


def _process_group_for_lombscargle(args):
    """
    Worker wrapper for one group.

    ``args`` is the 11-tuple ``(name, group_df, grouping_cols, min_lag,
    max_lag, metric_period_min, metric_period_max, spectrum_plot_period_min,
    spectrum_plot_period_max, spectrum_plot_period_step, min_valid_points)``.

    Returns ``(name, metrics, spectrum, None)`` on success, or
    ``(name, None, None, exception)`` when the computation raised, so that one
    failing group does not abort the whole pool.
    """
    (name, group_df, grouping_cols,
     min_lag, max_lag,
     metric_period_min, metric_period_max,
     spectrum_plot_period_min, spectrum_plot_period_max, spectrum_plot_period_step,
     min_valid_points) = args

    try:
        metrics_s, spectrum_df = _apply_metric_and_spectrum_from_lombscargle(
            group_df,
            min_lag=min_lag,
            max_lag=max_lag,
            metric_period_min=metric_period_min,
            metric_period_max=metric_period_max,
            spectrum_plot_period_min=spectrum_plot_period_min,
            spectrum_plot_period_max=spectrum_plot_period_max,
            spectrum_plot_period_step=spectrum_plot_period_step,
            min_valid_points=min_valid_points,
        )
    except Exception as exc:
        return (name, None, None, exc)
    else:
        return (name, metrics_s, spectrum_df, None)

def generate_all_lombscargle_data(
    source_df: pd.DataFrame,
    min_lag: int, max_lag: int,
    metric_period_min: float, metric_period_max: float,
    spectrum_plot_period_min: float, spectrum_plot_period_max: float,
    spectrum_plot_period_step: float,
    processing_option: int, min_valid_points: int,
    n_workers: int = 50
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Run the Lomb-Scargle metric + spectrum extraction for every group in parallel.

    Rows of `source_df` are grouped by read_id / condition / type / chr_type
    (plus pos_neg when `processing_option == 2`). Each group is handed to
    `_process_group_for_lombscargle` in a multiprocessing pool.

    Parameters
    ----------
    source_df : DataFrame containing the grouping columns; the remaining
        arguments are forwarded unchanged to the per-group worker.
    n_workers : upper bound on pool size. The pool never gets more processes
        than there are groups.

    Returns
    -------
    (metrics_df, spectra_df)
        metrics_df has one row per successfully processed group (group keys
        plus the worker's metrics); rows whose dominant_period,
        dominant_power and sharpness_q are all NaN are dropped.
        spectra_df is the concatenation of every non-empty per-group
        spectrum, tagged with the group keys. Both are empty DataFrames when
        there is nothing to process.
    """
    grouping_cols = ['read_id', 'condition', 'type', 'chr_type']
    if processing_option == 2:
        grouping_cols.append('pos_neg')

    print(f"[INFO] Grouping by {grouping_cols} for Lomb-Scargle metrics and spectra extraction...")

    # Arguments shared by every task; order must match the unpacking in
    # _process_group_for_lombscargle.
    shared_args = (
        grouping_cols,
        min_lag, max_lag,
        metric_period_min, metric_period_max,
        spectrum_plot_period_min, spectrum_plot_period_max,
        spectrum_plot_period_step,
        min_valid_points,
    )
    # No defensive .copy(): every task is pickled on its way to the worker,
    # so the worker can never touch the parent's frame.
    group_items = [
        (name, group_df, *shared_args)
        for name, group_df in source_df.groupby(grouping_cols)
    ]

    if not group_items:
        # Nothing to do -- avoid forking a whole pool just to map over nothing.
        print("[WARN] No groups found in source_df; skipping Lomb-Scargle processing.")
        return pd.DataFrame(), pd.DataFrame()

    # Never start more processes than there are tasks (None keeps Pool's default).
    pool_size = None if n_workers is None else min(n_workers, len(group_items))

    # Parallel map with a tqdm progress bar; result order is arbitrary.
    with mp.Pool(processes=pool_size) as pool:
        results = list(tqdm(
            pool.imap_unordered(_process_group_for_lombscargle, group_items),
            total=len(group_items),
            desc="Processing Lomb-Scargle per group"
        ))

    # Collect results, skipping any group whose worker reported an error
    metrics_results_list = []
    spectra_results_list: List[pd.DataFrame] = []

    for name, metrics_s, spectrum_df, err in results:
        if err is not None:
            print(f"[WARN] Error processing group {name}: {err}")
            continue

        # Map the group key back onto the grouping column names
        if isinstance(name, tuple):
            meta_dict = dict(zip(grouping_cols, name))
        else:
            meta_dict = {grouping_cols[0]: name}

        metrics_results_list.append({**meta_dict, **metrics_s})

        # Tag the raw spectrum with the group keys so it can be aggregated later
        if spectrum_df is not None and not spectrum_df.empty:
            for col, val in meta_dict.items():
                spectrum_df[col] = val
            spectra_results_list.append(spectrum_df)

    lombscargle_metrics_df = pd.DataFrame(metrics_results_list)
    if not lombscargle_metrics_df.empty:
        # Drop groups for which no metric at all could be derived
        lombscargle_metrics_df.dropna(
            subset=['dominant_period', 'dominant_power', 'sharpness_q'],
            how='all', inplace=True)

    all_spectra_raw_df = (
        pd.concat(spectra_results_list, ignore_index=True)
        if spectra_results_list else pd.DataFrame()
    )

    return lombscargle_metrics_df, all_spectra_raw_df

# ───────────────── Perform Lomb-Scargle Analysis & Spectra Extraction ─────────────────
# Fall back to the conditions present in the data when no explicit list exists.
if 'conditions' not in globals():
    print("[WARN] 'conditions' list not found, inferring from per_read_ac['condition'].")
    conditions = sorted(per_read_ac['condition'].unique())

lombscargle_metrics_df, all_spectra_raw_df = generate_all_lombscargle_data(
    source_df=per_read_ac,
    min_lag=MIN_LAG,
    max_lag=MAX_LAG,
    metric_period_min=NUC_PERIOD_MIN,
    metric_period_max=NUC_PERIOD_MAX,
    spectrum_plot_period_min=SPECTRUM_PLOT_PERIOD_MIN,
    spectrum_plot_period_max=SPECTRUM_PLOT_PERIOD_MAX,
    spectrum_plot_period_step=SPECTRUM_PLOT_PERIOD_STEP,
    processing_option=PROCESSING_OPTION,
    min_valid_points=MIN_VALID_AUTOCORR_POINTS,
    n_workers=50,
)

# Short summary of both result tables
print(f"[INFO] Lomb-Scargle metrics calculation complete. Shape: {lombscargle_metrics_df.shape}")
if lombscargle_metrics_df.empty:
    print("[WARN] Lomb-Scargle metrics DataFrame is empty.")
else:
    print(lombscargle_metrics_df.head())

print(f"[INFO] Raw spectra (Lomb-Scargle) extraction complete. Shape: {all_spectra_raw_df.shape}")
if all_spectra_raw_df.empty:
    print("[WARN] Raw spectra DataFrame is empty.")
else:
    print(all_spectra_raw_df.head())

# ───────────────── Aggregate Power Spectra ─────────────────
# Mean and standard error of the per-group power at each period, per condition
# (and per side when the data are split by pos_neg).
aggregated_spectra_df = pd.DataFrame()
if all_spectra_raw_df.empty:
    print("[INFO] Skipping spectra aggregation as raw spectra DataFrame is empty.")
else:
    if PROCESSING_OPTION == 2 and 'pos_neg' in all_spectra_raw_df.columns:
        agg_group_cols = ['condition', 'pos_neg', 'period']
    else:
        agg_group_cols = ['condition', 'period']

    if not set(agg_group_cols).issubset(all_spectra_raw_df.columns):
        print(f"[WARN] Could not aggregate spectra. Missing one or more grouping cols: {agg_group_cols}")
    else:
        with pd.option_context('mode.chained_assignment', None):
            aggregated_spectra_df = (
                all_spectra_raw_df
                .groupby(agg_group_cols)['power']
                .agg(mean_power='mean', sem_power='sem')
                .reset_index()
            )
        # SEM is NaN for single-observation groups; plot those with a zero-width band.
        aggregated_spectra_df['sem_power'] = aggregated_spectra_df['sem_power'].fillna(0)
        print(f"[INFO] Aggregated spectra DataFrame created. Shape: {aggregated_spectra_df.shape}")
        if not aggregated_spectra_df.empty:
            print(aggregated_spectra_df.head())

# ───────────────── Plotting Functions ─────────────────
def plot_lombscargle_metric_distribution(df: pd.DataFrame, metric_col: str, title_suffix: str,
                                         plot_conditions: list, color_map: dict, p_option: int,
                                         y_axis_label: Optional[str] = None):
    """
    Show violin plots of one metric column, one violin per condition.

    With `p_option == 1` a single figure is drawn from all rows. Otherwise,
    when `df` has a 'pos_neg' column, one figure is drawn for the 'neg' rows
    and one for the 'pos' rows; without that column a single 'combined'
    figure is drawn. Conditions without any non-null value are skipped, and a
    figure with no violin at all is not shown. Colours come from `color_map`
    with '#888888' as fallback. Returns None; progress and warnings are
    printed.
    """
    has_metric = (not df.empty) and metric_col in df.columns and not df[metric_col].isnull().all()
    if not has_metric:
        print(f"[WARN] No data to plot for metric: {metric_col} {title_suffix}")
        return

    base_title = f"Distribution of {metric_col.replace('_', ' ').title()}"
    y_label = y_axis_label or metric_col.replace('_', ' ').title()

    # Decide which slices of the data get their own figure.
    if p_option == 1:
        sides = [('whole', df)]
    elif 'pos_neg' in df.columns:
        sides = [(tag, df[df['pos_neg'] == tag]) for tag in ('neg', 'pos')]
    else:
        sides = [('combined', df)]

    for side_label, side_df in sides:
        if side_df.empty:
            print(f"[WARN] No data for side: {side_label} for metric: {metric_col}")
            continue

        fig = go.Figure(layout=dict(template="plotly_white"))
        for cond in plot_conditions:
            cond_rows = side_df[side_df['condition'] == cond]
            if cond_rows.empty or cond_rows[metric_col].isnull().all():
                continue
            fig.add_trace(go.Violin(y=cond_rows[metric_col], name=cond, box_visible=True, meanline_visible=True,
                                    marker_color=color_map.get(cond, '#888888')))

        # No trace was added -> nothing worth showing for this side.
        if not fig.data:
            print(f"[WARN] No valid data to plot for metric: {metric_col}, side: {side_label} {title_suffix}")
            continue

        if p_option == 2 and side_label != 'combined':
            current_title = f"{base_title} ({side_label} side) {title_suffix}"
        else:
            current_title = f"{base_title} {title_suffix}"

        fig.update_layout(title=current_title, yaxis_title=y_label, xaxis_title="Condition", showlegend=True,
                          legend_title="Condition", width=max(600, 150 * len(plot_conditions)), height=600)
        fig.show()
        print(f"[INFO] Displayed violin plot: {current_title}")

def plot_average_power_spectrum(agg_spectra_df: pd.DataFrame, title_suffix: str,
                                plot_conditions: list, color_map: dict, p_option: int,
                                spectrum_plot_period_min_axis: float, spectrum_plot_period_max_axis: float):
    """
    Show the mean power spectrum per condition with a ±SEM band.

    Expects the columns 'condition', 'period', 'mean_power' and 'sem_power'
    (and optionally 'pos_neg'). With `p_option == 1` one figure is drawn;
    otherwise one figure per 'neg'/'pos' side when a 'pos_neg' column exists,
    or a single 'combined' figure when it does not. The x-axis is clipped to
    the given period range. Returns None; progress and warnings are printed.
    """
    if agg_spectra_df.empty:
        print(f"[WARN] No aggregated spectra data to plot {title_suffix}")
        return

    base_title = "Average Power Spectrum (Lomb-Scargle)"
    y_label = "Mean Power (a.u.)"
    x_label = "Period (bp)"

    # Decide which slices of the data get their own figure.
    if p_option == 1:
        sides = [('whole', agg_spectra_df)]
    elif 'pos_neg' in agg_spectra_df.columns:
        sides = [(tag, agg_spectra_df[agg_spectra_df['pos_neg'] == tag]) for tag in ('neg', 'pos')]
    else:
        sides = [('combined', agg_spectra_df)]

    for side_label, side_df in sides:
        if side_df.empty:
            if p_option == 2:
                print(f"[WARN] No aggregated spectra data for side: {side_label} {title_suffix}")
            elif p_option == 1:
                print(f"[WARN] No aggregated spectra data for {title_suffix}")
            continue

        fig = go.Figure(layout=dict(template="plotly_white"))
        for cond in plot_conditions:
            curve = side_df[side_df['condition'] == cond].sort_values('period')
            if curve.empty or 'mean_power' not in curve.columns or 'sem_power' not in curve.columns:
                continue

            periods = curve['period']
            mean_power = curve['mean_power']
            sem = curve['sem_power'].fillna(0)

            # Mean curve
            fig.add_trace(go.Scatter(x=periods, y=mean_power, mode='lines', name=cond,
                                     line=dict(color=color_map.get(cond, '#888888'))))
            # SEM band: upper edge left-to-right, then lower edge right-to-left,
            # forming one closed polygon for fill='toself'.
            band_x = np.concatenate([periods, periods[::-1]])
            band_y = np.concatenate([mean_power + sem, (mean_power - sem)[::-1]])
            fig.add_trace(go.Scatter(x=band_x, y=band_y,
                                     fill='toself', fillcolor=hex_to_rgba_str(color_map.get(cond, '#888888'), 0.2),
                                     line=dict(color='rgba(255,255,255,0)'), hoverinfo="skip", showlegend=False))

        # No trace was added -> nothing worth showing for this side.
        if not fig.data:
            print(f"[WARN] No valid data to plot for power spectrum, side: {side_label} {title_suffix}")
            continue

        if p_option == 2 and side_label != 'combined':
            current_title = f"{base_title} ({side_label} side) {title_suffix}"
        else:
            current_title = f"{base_title} {title_suffix}"

        fig.update_layout(title=current_title, yaxis_title=y_label, xaxis_title=x_label,
                          xaxis=dict(range=[spectrum_plot_period_min_axis, spectrum_plot_period_max_axis]),
                          showlegend=True, legend_title="Condition", width=1000, height=600)
        fig.show()
        print(f"[INFO] Displayed power spectrum plot: {current_title}")

# ───────────────── Generate and Display Plots ─────────────────
if lombscargle_metrics_df.empty:
    print("[INFO] Skipping Lomb-Scargle metric distribution plots as metrics DataFrame is empty.")
else:
    # (metric column, title suffix, y-axis label) for each violin figure
    _metric_plot_specs = (
        ('dominant_period',
         f"({NUC_PERIOD_MIN}-{NUC_PERIOD_MAX} bp range, Lomb-Scargle based)",
         "Dominant Period (bp)"),
        ('dominant_power',
         f"(from {NUC_PERIOD_MIN}-{NUC_PERIOD_MAX} bp period range, Lomb-Scargle based)",
         "Power of Dominant Frequency (a.u.)"),
        ('sharpness_q',
         f"(Q Factor, {NUC_PERIOD_MIN}-{NUC_PERIOD_MAX} bp period range, Lomb-Scargle based)",
         "Sharpness (Q Factor)"),
    )
    for _metric_col, _suffix, _axis_label in _metric_plot_specs:
        plot_lombscargle_metric_distribution(lombscargle_metrics_df, _metric_col, _suffix,
                                             conditions, COND_COLORS, PROCESSING_OPTION,
                                             y_axis_label=_axis_label)

if aggregated_spectra_df.empty:
    print("[INFO] Skipping average power spectrum plots as aggregated spectra DataFrame is empty.")
else:
    plot_average_power_spectrum(aggregated_spectra_df,
                                f"({SPECTRUM_PLOT_PERIOD_MIN}-{SPECTRUM_PLOT_PERIOD_MAX} bp, {SPECTRUM_PLOT_PERIOD_STEP}bp res.)",
                                conditions, COND_COLORS, PROCESSING_OPTION,
                                SPECTRUM_PLOT_PERIOD_MIN, SPECTRUM_PLOT_PERIOD_MAX)

# Wall-clock time since the start of this cell (t_cell3_start is set earlier).
print(f"[INFO] Cell 3: Analysis and Plotting finished in {(time.time() - t_cell3_start):.2f}s")

In [None]:
###############################################################################
# CELL 4 — window‑wise dampening metrics & box‑plots
###############################################################################
import pandas as pd
import numpy as np
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from scipy.signal import hilbert


# ────────────────── High‑level toggle ────────────────── #
# Grain of the data points fed into the box plots below:
#   "type" → use the per‑type table plot_df_auto as is
#   "read" → re‑smooth per_read_ac so every read contributes its own point
DATAPOINT_UNIT = "type"   # "type" (current)  |  "read" (one point per read)

# ─────────────── Metric configuration ───────────────
# Which per‑window summary of autocorr_smooth is computed further down:
#   0 = max − min, 1 = high − low percentile, 2 = median, 3 = RMS,
#   4 = trapezoid AUC over plot_lag, 5 = RMS of the Hilbert envelope
METRIC_OPTION = 5          # 0=max‑min, 1=robust, 2=median, 3=RMS, 4=AUC, 5=Hilbert RMS
HIGH_PCTL     = 95         # upper percentile for METRIC_OPTION 1
LOW_PCTL      = 5          # lower percentile for METRIC_OPTION 1
# Axis/title labels keyed by METRIC_OPTION.
# NOTE(review): options 2–4 are labelled "... Difference" but the code below
# uses the plain median / RMS / AUC of each window — confirm the wording.
METRIC_LABELS = {0:"Max − Min",1:"P95 − P5",2:"Median Difference",
                 3:"RMS Difference",4:"AUC Difference", 5:"Hilbert-RMS"}

# ─────────────── Window geometry ───────────────
# Windows are centred at ±(WINDOW_CENTER_START + k·BIN_STEP), k = 0..NUC_REPS−1,
# and span center ± WINDOW_WIDTH // 2 (i.e. 150 bp with the values below).
WINDOW_CENTER_START = 135   # Center of the first window (n+1; mirrored for n-1)
BIN_STEP            = 177   # Step size between window centers
WINDOW_WIDTH        = 150   # Window width (bp)
NUC_REPS            = 5     # Number of windows on each side (set as needed)

# NOTE(review): WIN_HALF does not appear to be referenced by the window
# construction in this cell — confirm whether it is still needed.
WIN_HALF = round(BIN_STEP / 2)  # 88 bp

# ─────────────── Build df_all at desired grain ───────────────
if DATAPOINT_UNIT == "type":
    # per‑type dataframe (already carries autocorr_smooth)
    df_all = plot_df_auto.copy()
else:  # "read" → regenerate per‑read smoothed autocorr
    df_all = per_read_ac.copy()
    # Smooth each read (and each side, when the data are split by pos_neg)
    # separately. Only group on columns that are actually present so that a
    # single‑sided table without 'pos_neg' does not raise a KeyError.
    smooth_keys = [col for col in ('read_id', 'pos_neg') if col in df_all.columns]
    df_all['autocorr_smooth'] = (
        df_all.groupby(smooth_keys)['autocorr']
              .transform(final_smooth)
    )
    # no aggregation → still one row per read/lag; 'type' is kept as a column

# Common lag re‑mapping: in two‑sided mode the 'neg' side is mirrored onto
# negative lags so both sides share one signed axis.
if PROCESSING_OPTION == 2:
    df_all['plot_lag'] = np.where(df_all['pos_neg'] == 'neg',
                                  -df_all['lag'],
                                   df_all['lag'])
else:
    df_all['plot_lag'] = df_all['lag']
# Keep only lags inside the analysed range (bounds inclusive)
df_all = df_all[df_all['plot_lag'].between(-MAX_LAG, MAX_LAG)].reset_index(drop=True)

# ─────────────── Build sliding windows ───────────────
def _centered_window(center, label):
    """Return one window record spanning center ± WINDOW_WIDTH // 2."""
    half_width = WINDOW_WIDTH // 2
    return {
        'center': center,
        'start': center - half_width,
        'end': center + half_width,
        'label': label
    }


# Positive side: n+1 … n+NUC_REPS (right of zero)
window_defs = [
    _centered_window(WINDOW_CENTER_START + (i - 1) * BIN_STEP, f"n+{i}")
    for i in range(1, NUC_REPS + 1)
]

# Negative side (two‑sided mode only): n-NUC_REPS … n-1, mirrored left of zero
if PROCESSING_OPTION != 1:
    window_defs = [
        _centered_window(-WINDOW_CENTER_START - (i - 1) * BIN_STEP, f"n-{i}")
        for i in range(NUC_REPS, 0, -1)
    ] + window_defs

# Left‑to‑right order is used both for iteration and for the plot's x‑axis
window_defs.sort(key=lambda w: w['center'])
ordered_windows = [w['label'] for w in window_defs]



# ─────────────── Compute metric per window  (vectorised) ───────────────
records = []

# 1.  Build a long dataframe in which every row is tagged with the window
#     it falls into (window bounds are inclusive).
if DATAPOINT_UNIT == "type":
    key_cols = ['condition', 'type']
else:
    key_cols = ['condition', 'read_id', 'type']
_kept_cols = key_cols + ['autocorr_smooth', 'plot_lag']

parts = []
for w in window_defs:
    in_window = (df_all['plot_lag'] >= w['start']) & (df_all['plot_lag'] <= w['end'])
    if not in_window.any():
        continue  # window lies outside the available lags
    parts.append(df_all.loc[in_window, _kept_cols].assign(window=w['label']))

long_df = pd.concat(parts, ignore_index=True)


# 2.  Reduce every (key_cols, window) group of autocorr_smooth to one scalar.
#     The result, `diff_vals`, is a Series indexed by key_cols + ['window'];
#     which summary is used depends on METRIC_OPTION.
if METRIC_OPTION == 0:
    # Max − Min: full peak‑to‑trough range inside the window
    wide = (
        long_df
        .groupby(key_cols + ['window'])['autocorr_smooth']
        .agg(['max', 'min'])
        .dropna(subset=['max', 'min'])
    )
    diff_vals = wide['max'] - wide['min']

elif METRIC_OPTION == 1:
    # Robust range: HIGH_PCTL‑th minus LOW_PCTL‑th percentile.
    # quantile() with a list adds the quantile as the innermost index level;
    # unstack(-1) turns the two quantiles into columns.
    q = (
        long_df
        .groupby(key_cols + ['window'])['autocorr_smooth']
        .quantile([LOW_PCTL/100, HIGH_PCTL/100])
        .unstack(-1)
        .rename(columns={LOW_PCTL/100: f'q{LOW_PCTL}', HIGH_PCTL/100: f'q{HIGH_PCTL}'})
        .dropna(subset=[f'q{LOW_PCTL}', f'q{HIGH_PCTL}'])
    )
    diff_vals = q[f'q{HIGH_PCTL}'] - q[f'q{LOW_PCTL}']

elif METRIC_OPTION == 2:
    # Median of the window
    median = (
        long_df
        .groupby(key_cols + ['window'])['autocorr_smooth']
        .median()
    )
    diff_vals = median  # plain per‑window median; no subtraction is performed


elif METRIC_OPTION == 3:
    # Root‑mean‑square of the window
    rms = (
        long_df
        .groupby(key_cols + ['window'])['autocorr_smooth']
        .apply(lambda x: np.sqrt((x**2).mean()))
    )
    diff_vals = rms


elif METRIC_OPTION == 4:
    # Signed area under the curve: trapezoid rule over plot_lag.
    # plot_lag is appended to the index so the lambda can read the x values.
    auc = (
        long_df
        .set_index('plot_lag', append=True)
        .groupby(key_cols + ['window'])['autocorr_smooth']
        .apply(lambda x: np.trapz(x.values, x.index.get_level_values('plot_lag')))
    )
    diff_vals = auc


elif METRIC_OPTION == 5:
    # RMS of the magnitude of scipy.signal.hilbert's output (the analytic
    # signal), computed independently inside each window.
    from scipy.signal import hilbert
    hilbert_rms = (
        long_df
        .groupby(key_cols + ['window'])['autocorr_smooth']
        .apply(lambda x: np.sqrt(np.mean(np.abs(hilbert(x.values))**2)))
    )
    diff_vals = hilbert_rms


else:
    # Any METRIC_OPTION other than 0–5 has no implementation here
    raise ValueError("Unsupported option in fast path.")

# 3.  Assemble diff_df: one row per (key_cols, window) holding the metric in
#     'diff', joined with that window's geometry (center / start / end / label).
_metric_long = diff_vals.rename('diff').reset_index()
_window_table = pd.DataFrame(window_defs)
diff_df = pd.merge(
    _metric_long, _window_table,
    how='left',
    left_on='window', right_on='label'
)

# Ordered categorical so that sorting and grouping follow left‑to‑right window order
diff_df['window'] = diff_df['window'].astype(
    pd.CategoricalDtype(categories=ordered_windows, ordered=True)
)


# ─────────────── Combined box‑plot across conditions ───────────────
from plotly.subplots import make_subplots

# 1. Quartiles of the metric per (window, condition).
#    observed=False is passed explicitly: it is the behaviour this code has
#    always relied on (every window category stays in the index), and pandas
#    >= 2.1 warns when a categorical key is grouped without it.
stats = (
    diff_df
    .groupby(['window', 'condition'], observed=False)['diff']
    .agg(
        q1=lambda x: np.percentile(x, 25),
        median='median',
        q3=lambda x: np.percentile(x, 75)
    )
)
# Wide form: rows = window, columns = condition (one table per statistic)
stats_wide = stats.unstack('condition')
median_df = stats_wide['median']
q1_df     = stats_wide['q1']
q3_df     = stats_wide['q3']

first_cond = conditions[0]
# 2. Express medians and quartiles relative to the first condition's median
rel_median = median_df.subtract(median_df[first_cond], axis=0)
rel_q1     = q1_df    .subtract(median_df[first_cond], axis=0)
rel_q3     = q3_df    .subtract(median_df[first_cond], axis=0)

# 3. Two stacked panels sharing the x‑axis: box plots on top, Δ‑median below.
#    The window size in the title follows WINDOW_WIDTH instead of a literal.
fig = make_subplots(
    rows=2, cols=1,
    shared_xaxes=True,
    vertical_spacing=0.02,
    row_heights=[0.8, 0.2],
    subplot_titles=[
        f"{METRIC_LABELS[METRIC_OPTION]} per {WINDOW_WIDTH} bp Window",
        f"Median Δ vs. {first_cond} (with IQR)"
    ]
)

# --- top row: one grouped box trace per condition ---
for cond in conditions:
    sub = diff_df[diff_df['condition'] == cond]
    # Drawing every point is only practical for small data sets
    jitter_on = len(sub) <= 10_000
    fig.add_trace(
        go.Box(
            x=sub['window'], y=sub['diff'], name=cond,
            marker_color=COND_COLORS[cond],
            boxpoints='all' if jitter_on else 'outliers',
            jitter=0.3 if jitter_on else 0,
            marker_size=4 if jitter_on else 3
        ),
        row=1, col=1
    )

# --- bottom row: line + asymmetric IQR error bars ---
# dashed zero line = the reference (first) condition
fig.add_trace(
    go.Scatter(
        x=rel_median.index, y=[0]*len(rel_median),
        mode='lines', line=dict(color='black', dash='dash'),
        showlegend=False
    ),
    row=2, col=1
)
for cond in conditions[1:]:
    fig.add_trace(
        go.Scatter(
            x=rel_median.index,
            y=rel_median[cond],
            mode='lines+markers',
            name=f"{cond} − {first_cond}",
            line=dict(color=COND_COLORS[cond]),
            error_y=dict(
                type='data',
                symmetric=False,
                array= rel_q3[cond] - rel_median[cond],      # upper bar: Q3 − median
                arrayminus= rel_median[cond] - rel_q1[cond]  # lower bar: median − Q1
            )
        ),
        row=2, col=1
    )

# 4. axes & layout tweaks: force left‑to‑right window order on both panels
fig.update_xaxes(categoryorder='array', categoryarray=ordered_windows, row=1, col=1)
fig.update_xaxes(categoryorder='array', categoryarray=ordered_windows, row=2, col=1)
fig.update_yaxes(title_text=METRIC_LABELS[METRIC_OPTION], row=1, col=1)
fig.update_yaxes(title_text="Δ Median (IQR)", row=2, col=1)

fig.update_layout(
    template='plotly_white',
    width=900, height=800,
    boxmode='group',
    showlegend=True
)

fig.show()

#
#
# # ─────────────── Faceted box‑plots per condition ───────────────
# n = len(conditions)
# fig_sub = make_subplots(rows=1, cols=n, shared_yaxes=True,
#                         subplot_titles=conditions, horizontal_spacing=0.03)
# for i, cond in enumerate(conditions, start=1):
#     sub = diff_df[diff_df['condition'] == cond]
#     jitter_on = len(sub) <= 10_000
#     fig_sub.add_trace(
#         go.Box(x=sub['window'], y=sub['diff'],
#                marker_color=COND_COLORS[cond],
#                boxpoints='all' if jitter_on else 'outliers',
#                jitter=0.3 if jitter_on else 0,
#                marker_size=4 if jitter_on else 3,
#                showlegend=False),
#         row=1, col=i
#     )
#     fig_sub.update_xaxes(categoryorder='array',
#                          categoryarray=ordered_windows, row=1, col=i)
# fig_sub.update_layout(
#     title=(f"{METRIC_LABELS[METRIC_OPTION]} per Condition"
#            + (" (auto‑flipped)" if AUTO_FLIP and PROCESSING_OPTION != 1 else "")
#            + (f" – one point per {DATAPOINT_UNIT}" if DATAPOINT_UNIT=="read" else "")),
#     xaxis_title="Window span (bp)",
#     yaxis_title=METRIC_LABELS[METRIC_OPTION],
#     width=300 * n, height=600,
#     template='plotly_white'
# )
# fig_sub.show()


In [None]:
# CELL 5 — subtraction of baseline (condition1) vs others per sliding windows
#            matching Cell 3’s binning

import pandas as pd
import numpy as np
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# ─────────────── Build df_all from plot_df_auto with a signed lag axis ───────────────
if PROCESSING_OPTION != 2:
    # single‑sided: the plotting lag is the lag itself
    df_all = plot_df_auto.assign(plot_lag=plot_df_auto['lag'])
else:
    # two‑sided: mirror the 'neg' rows onto negative lags, keep 'pos' as is,
    # then stack them (neg first); rows with any other pos_neg value are dropped
    neg = (plot_df_auto.loc[plot_df_auto['pos_neg'] == 'neg']
                       .assign(plot_lag=lambda d: -d['lag']))
    pos = (plot_df_auto.loc[plot_df_auto['pos_neg'] == 'pos']
                       .assign(plot_lag=lambda d: d['lag']))
    df_all = pd.concat([neg, pos], ignore_index=True)

# ─────────────── Configuration ───────────────
# Spacing between consecutive window pairs.
# NOTE(review): this rebinds BIN_STEP, which the window‑wise dampening cell
# sets to 177 — confirm that 173 is the intended value here.
BIN_STEP = 173
# (start, end) lag offsets of the trough ("min") and peak ("max") sub‑windows
MIN_OFF, MAX_OFF = (80, 119), (160, 199)

# ─────────────── Build sliding windows ───────────────
# Shift both sub‑windows by BIN_STEP per repeat for as long as the peak
# sub‑window still ends at or before MAX_LAG.
window_defs = []
k = 0
while MAX_OFF[1] + BIN_STEP * k <= MAX_LAG:
    shift = BIN_STEP * k
    min_s, min_e = MIN_OFF[0] + shift, MIN_OFF[1] + shift
    max_s, max_e = MAX_OFF[0] + shift, MAX_OFF[1] + shift

    window_defs.append({
        'side':      'pos',
        'min_range': (min_s, min_e),
        'max_range': (max_s, max_e),
        'span':      (min_s, max_e)
    })
    if PROCESSING_OPTION != 1:
        # mirrored counterpart on the negative‑lag side
        window_defs.append({
            'side':      'neg',
            'min_range': (-min_e, -min_s),
            'max_range': (-max_e, -max_s),
            'span':      (-max_e, -min_s)
        })

    k += 1

# Order left‑to‑right by span start and label each window with its span
window_defs.sort(key=lambda w: w['span'][0])
for w in window_defs:
    w['label'] = f"{int(w['span'][0])} to {int(w['span'][1])}"

# ─────────────── Compute max−min per condition/type/window ───────────────
# For every window: (max of autocorr_smooth in the peak sub‑window) minus
# (min of autocorr_smooth in the trough sub‑window), per condition and type.
_group_keys = ['condition', 'type']
records = []
for w in window_defs:
    trough_lo, trough_hi = w['min_range']
    peak_lo, peak_hi = w['max_range']
    trough_rows = df_all.loc[df_all['plot_lag'].between(trough_lo, trough_hi)]
    peak_rows = df_all.loc[df_all['plot_lag'].between(peak_lo, peak_hi)]
    if trough_rows.empty or peak_rows.empty:
        continue  # a window needs data in both sub‑windows

    peak_max = peak_rows.groupby(_group_keys)['autocorr_smooth'].max()
    trough_min = trough_rows.groupby(_group_keys)['autocorr_smooth'].min()
    records.append(
        peak_max.sub(trough_min)
                .reset_index(name='range_val')
                .assign(window=w['label'])
    )

range_df = pd.concat(records, ignore_index=True)

# ─────────────── Prepare subtraction DataFrame ───────────────
# The first condition is the baseline; every other condition is subtracted from it.
baseline    = conditions[0]
other_conds = conditions[1:]

# rows = (type, window), columns = condition
# (pivot_table averages duplicate entries, its default aggregation)
pivot = range_df.pivot_table(
    index=['type','window'],
    columns='condition',
    values='range_val'
)

sub_records = []
for (typ, win), row in pivot.iterrows():
    base_val = row.get(baseline, np.nan)
    if np.isnan(base_val):
        continue  # no baseline value for this type/window → nothing to compare
    for cond in other_conds:
        val = row.get(cond, np.nan)
        if np.isnan(val):
            continue
        sub_records.append({
            'type':        typ,
            'window':      win,
            'subtraction': f"{baseline} − {cond}",
            'diff':        base_val - val
        })

# Explicit columns keep the schema intact when no pair could be formed, so
# the plotting code can still filter on 'subtraction' without a KeyError.
sub_df = pd.DataFrame(sub_records, columns=['type', 'window', 'subtraction', 'diff'])

# ─────────────── Combined box‑plot of subtractions ───────────────
# One grouped box trace per "baseline − condition" pair, windows on the x‑axis.
fig_comb = go.Figure(layout={'template': 'plotly_white'})
for cond in other_conds:
    label = f"{baseline} − {cond}"
    sub = sub_df.loc[sub_df['subtraction'].eq(label)]
    box_trace = go.Box(
        x=sub['window'],
        y=sub['diff'],
        name=label,
        marker_color=COND_COLORS[cond],
        boxpoints='all',
        jitter=0.3,
    )
    fig_comb.add_trace(box_trace)

fig_comb.update_layout(
    title=f"Range Difference per Window: {baseline} minus Others",
    xaxis_title="Window span (bp)",
    yaxis_title="(Max–Min)₍baseline₎ − (Max–Min)₍cond₎",
    boxmode='group',
    width=1200,
    height=600,
)
# Keep the windows in left‑to‑right order rather than order of appearance
_window_order = [w['label'] for w in window_defs]
fig_comb.update_xaxes(categoryorder='array', categoryarray=_window_order)
fig_comb.show()

# ─────────────── Subplots per subtraction pair ───────────────
subs = [f"{baseline} − {c}" for c in other_conds]
n    = len(subs)

if n == 0:
    # make_subplots cannot build a grid with zero columns
    print("[WARN] Only one condition available; skipping per-pair subtraction subplots.")
else:
    fig_sub = make_subplots(
        rows=1, cols=n,
        shared_yaxes=True,
        subplot_titles=subs,
        horizontal_spacing=0.03
    )
    window_order = [w['label'] for w in window_defs]

    # Walk labels and conditions together instead of parsing the condition
    # back out of the label, which would break if a name contained ' − '.
    for i, (label, cond) in enumerate(zip(subs, other_conds), start=1):
        sub = sub_df[sub_df['subtraction'] == label]
        fig_sub.add_trace(
            go.Box(
                x=sub['window'], y=sub['diff'],
                name=label,
                marker_color=COND_COLORS[cond],
                boxpoints='all', jitter=0.3,
                showlegend=False
            ),
            row=1, col=i
        )
        # left‑to‑right window order on every panel
        fig_sub.update_xaxes(
            categoryorder='array',
            categoryarray=window_order,
            row=1, col=i
        )

    fig_sub.update_layout(
        title=f"Per‑window Range Difference: {baseline} minus Each Condition",
        xaxis_title="Window span (bp)",
        yaxis_title="Difference",
        width=300 * n, height=600
    )
    fig_sub.show()


In [None]:
# COLOR_LOOKUP        = {          # user palette
#     CONDITIONS_TO_PLOT[0]: "#1F78B4", # blue
#     CONDITIONS_TO_PLOT[1]: "#E31A1C", # red
# }

# COLOR_LOOKUP        = {          # user palette
#     CONDITIONS_TO_PLOT[0]: "#E31A1C", # red
#     CONDITIONS_TO_PLOT[1]: "#51ab4d", # green
# }
# 
# Active palette: first plotted condition in blue, second in green
_first_cond, _second_cond = CONDITIONS_TO_PLOT[0], CONDITIONS_TO_PLOT[1]
COLOR_LOOKUP = {
    _first_cond: "#1F78B4",   # blue
    _second_cond: "#51ab4d",  # green
}

# ───────────────────────────── 4.  Plotly figure ────────────────────────────
# One line per condition: smoothed autocorrelation against lag.
fig = go.Figure(layout={"template": "plotly_white"})
for cond in CONDITIONS_TO_PLOT:
    sub = plot_df_auto.loc[plot_df_auto["condition"].eq(cond)]
    trace = go.Scatter(
        x=sub["lag"],
        y=sub["autocorr_smooth"],
        mode="lines",
        name=cond,
        # conditions without a palette entry fall back to plotly's default colour
        line={"color": COLOR_LOOKUP.get(cond)},
    )
    fig.add_trace(trace)

fig.update_layout(
    title="Read‑level Autocorrelation",
    xaxis_title="Lag (bp)",
    yaxis_title="Mean autocorrelation",
    legend_title="Condition",
    # x‑axis starts at lag 20 and ends at LAG_RANGE; y‑axis range is fixed
    xaxis={"range": [20, LAG_RANGE]},
    yaxis={"range": [-0.025, 0.07]},
    width=900,
    height=600,
)
fig.show()

# ───────────────────────── QUICK‑DIAGNOSTIC SNIPPET ─────────────────────────
# Sanity checks on the per‑read table and the aggregated plotting table.
print("\n─ BASIC SHAPES ─")
print("per_read_ac :", per_read_ac.shape)
print("plot_df_auto     :", plot_df_auto.shape)

# Rows per condition (NaN conditions are counted as well)
print("\n─ UNIQUE CONDITIONS IN per_read_ac ─")
_read_cond_counts = per_read_ac["condition"].value_counts(dropna=False)
print(_read_cond_counts.head(10))

print("\n─ UNIQUE CONDITIONS IN plot_df_auto ─")
_plot_cond_counts = plot_df_auto["condition"].value_counts(dropna=False)
print(_plot_cond_counts)

# Requested conditions that never made it into the plotting table
print("\n─ CONDITIONS_TO_PLOT VS plot_df_auto ─")
_missing_conds = set(CONDITIONS_TO_PLOT).difference(plot_df_auto["condition"].unique())
print("Missing in plot_df_auto:", _missing_conds)

print("\n─ per_read_ac autocorr stats by condition (lag 0 only) ─")
_lag0_rows = per_read_ac.loc[per_read_ac["lag"] == 0]
_lag0_stats = _lag0_rows.groupby("condition")["autocorr"].agg(["count", "min", "max", "mean"])
print(_lag0_stats.head(10))

print("\n─ NaN fraction in per_read_ac['autocorr'] ─")
_nan_fraction = per_read_ac["autocorr"].isna().mean()
print(_nan_fraction.round(3))

print("\n─ First few rows of plot_df_auto ─")
print(plot_df_auto.head())

print("\n─ Check smoothed values ─")
print(plot_df_auto["autocorr_smooth"].describe())
# ─────────────────────────────────────────────────────────────────────────────


In [None]:
# Reimport nanotools
import importlib
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from scipy.ndimage import gaussian_filter1d
from scipy.signal import find_peaks
from plotly.subplots import make_subplots

# Assuming 'nanotools' is a module you have and 'down_sampled_plot_df', 'down_sampled_group_df', etc. are defined
# importlib.reload(nanotools)

def _ensure_list(x):
    """Coerce a motif cell into a list-like value.

    Strings are parsed with ``ast.literal_eval``, tuples are converted to
    lists, lists are returned unchanged and any other value is wrapped in a
    one-element list.
    """
    import ast  # local import keeps this block self-contained

    if isinstance(x, str):
        return ast.literal_eval(x)
    if isinstance(x, tuple):
        return list(x)
    if isinstance(x, list):
        return x
    return [x]


def _add_motifs_to_figure(fig, grouped_motifs):
    """Draw a dashed vertical line per motif position and label every motif.

    ``grouped_motifs`` has one row per ``motif_rel_start`` with the list of
    motif ids found at that position in ``motif_id``.
    """
    # Running counter across all positions so neighbouring labels alternate height.
    annotation_idx = 0

    for _, row in grouped_motifs.iterrows():
        motif_rel_start = row['motif_rel_start']

        # One vertical line per position, spanning the full plot height.
        fig.add_shape(
            type="line",
            x0=motif_rel_start,
            x1=motif_rel_start,
            y0=0,
            y1=1,
            line=dict(color="grey", width=1, dash="dash"),
            xref="x",
            yref="paper",
        )

        for motif_id in row['motif_id']:
            # Alternate the label height between 1.0 and 1.1 (paper coordinates).
            y = 1 + (annotation_idx % 2) * 0.1
            fig.add_annotation(
                x=motif_rel_start,
                y=y,
                yref="paper",
                text=f"{motif_id}",
                showarrow=False,
                yanchor='bottom',
                xanchor="center",
                font=dict(size=10),
                bgcolor="rgba(255,255,255,0.5)",  # semi-transparent background for readability
            )
            annotation_idx += 1


def create_plot(plot_df, group_df, condition, chr_type, data_type, plot_window, plot_motifs=True):
    """Build the per-read track and the aggregate m6A / nucleosome track.

    Parameters
    ----------
    plot_df : pandas.DataFrame
        Per-position rows. Columns read here: ``condition``, ``chr_type``,
        ``type``, ``rel_pos``, ``read_id``, ``mod_qual``, ``mod_qual_bin``,
        ``smallest_positive_nuc_midpoint``, ``greatest_negative_nuc_midpoint``
        and, optionally, ``motif_rel_start`` / ``motif_id``.
    group_df : pandas.DataFrame
        Rows carrying ``read_id`` and ``nucs_list`` (an iterable of nucleosome
        midpoints expressed in the same coordinates as ``rel_pos``).
    condition, chr_type, data_type
        Values used to filter ``plot_df`` on ``condition``, ``chr_type`` and
        ``type`` respectively.
    plot_window : int
        Half-width of the plotted region; only ``-plot_window < rel_pos <
        plot_window`` is kept.
    plot_motifs : bool, default True
        When True and the motif columns exist, motif positions are drawn on
        both figures.

    Returns
    -------
    tuple
        ``(fig1, fig2)``: fig1 shows one horizontal line per read with m6A
        calls and nucleosome footprints; fig2 shows the smoothed m6A fraction
        (row 1) and the nucleosome-midpoint histogram with a coverage-normalised,
        Gaussian-smoothed density on a secondary axis (row 2).

    Notes
    -----
    Differences from the earlier implementation:

    * the density normalisation passes ``out=`` to ``np.divide`` so positions
      without read coverage are 0 instead of uninitialised memory;
    * each read contributes its nucleosomes once, even if ``group_df`` holds
      several rows for the same read;
    * all read lines are emitted as a single trace (segments separated by
      ``None``) instead of one trace per read.
    """
    print("Creating dataframes...")
    in_view = (
        (plot_df['condition'] == condition)
        & (plot_df['chr_type'] == chr_type)
        & (plot_df['type'] == data_type)
        & (plot_df['rel_pos'] > -plot_window)
        & (plot_df['rel_pos'] < plot_window)
    )
    plot_df_copy = plot_df[in_view]

    # Drop rows where both flanking nucleosome midpoints are NaN, then order
    # the rows so that reads are stacked by their flanking nucleosome positions.
    no_nuc_info = (plot_df_copy['smallest_positive_nuc_midpoint'].isna()
                   & plot_df_copy['greatest_negative_nuc_midpoint'].isna())
    plot_df_copy = (
        plot_df_copy[~no_nuc_info]
        .sort_values(by=['smallest_positive_nuc_midpoint', 'greatest_negative_nuc_midpoint'],
                     ascending=[True, False])
        .reset_index(drop=True)
    )

    # Number the reads 1..n in order of first appearance; this is the y position of each read.
    read_order = plot_df_copy.drop_duplicates(subset=['read_id'])[['read_id']].reset_index(drop=True)
    read_order['read_count'] = range(1, len(read_order) + 1)
    plot_df_copy = pd.merge(plot_df_copy, read_order, on='read_id', how='left')

    # Keep only the group rows of plotted reads, attach their y position and
    # discard rows without a nucleosome list.
    nuc_df = group_df[group_df['read_id'].isin(read_order['read_id'])]
    nuc_df = pd.merge(nuc_df, read_order, on='read_id', how='left')
    nuc_df = nuc_df.dropna(subset=['nucs_list'])

    # One array slot per base pair of the plotted region.
    genome_size = 2 * plot_window
    read_counts = np.zeros(genome_size)  # number of reads covering each position

    # Fraction of modified calls at each rel_pos, smoothed with a centered moving average.
    agg_df = plot_df_copy.groupby('rel_pos')['mod_qual_bin'].agg(['sum', 'count']).reset_index()
    agg_df['ratio'] = agg_df['sum'] / agg_df['count']
    rolling_window_size = 50
    agg_df['moving_avg'] = agg_df['ratio'].rolling(window=rolling_window_size, center=True).mean()
    agg_df.dropna(inplace=True)

    print("Adding m6A line traces...")
    fig1 = make_subplots(rows=1, cols=1)
    fig1.update_xaxes(range=[-plot_window, plot_window])

    # Span of every read, computed in one pass; used for coverage and for the read lines.
    read_spans = plot_df_copy.groupby('read_count')['rel_pos'].agg(['min', 'max'])
    read_x, read_y = [], []
    for read_height, span_min, span_max in read_spans.itertuples(name=None):
        lo = max(int(span_min + plot_window), 0)
        hi = min(int(span_max + plot_window + 1), genome_size)
        if hi > lo:
            read_counts[lo:hi] += 1
        # None separates the individual line segments inside the single trace.
        read_x.extend([span_min, span_max, None])
        read_y.extend([read_height, read_height, None])

    fig1.add_trace(
        go.Scatter(x=read_x, y=read_y,
                   mode='lines', line=dict(color='#000000', width=0.2), showlegend=False),
        row=1, col=1
    )

    print("Plotting nucleosomes on read plot...")
    nuc_width = 147  # footprint drawn around each nucleosome midpoint (bp)
    midpoints_list = []
    x_coords = []
    y_coords = []

    # One row per read so duplicated group rows cannot inflate the midpoint counts.
    per_read_nucs = nuc_df.drop_duplicates(subset='read_count')
    for read_height, nucs in zip(per_read_nucs['read_count'], per_read_nucs['nucs_list']):
        for nuc in nucs:
            # Ignore nucleosomes whose midpoint lies outside the plotted window.
            if not -plot_window <= nuc <= plot_window:
                continue
            midpoints_list.append(nuc)
            x_coords.extend([nuc - nuc_width / 2, nuc + nuc_width / 2, None])
            y_coords.extend([read_height, read_height, None])

    # Marker outline: red for mod_qual >= 0.85, fully transparent otherwise.
    plot_df_copy['marker_color'] = np.where(
        plot_df_copy['mod_qual'] >= 0.85, '#FF0000', 'rgba(0,0,0,0)'
    )

    scatter_trace = go.Scatter(
        x=plot_df_copy['rel_pos'],
        y=plot_df_copy['read_count'],
        mode='markers',
        marker=dict(
            size=4,
            color='rgba(0,0,0,0)',  # no fill; only the outline is coloured
            line=dict(
                width=1,
                color=plot_df_copy['marker_color']
            ),
        ),
        name='Read Count'
    )
    fig1.add_trace(scatter_trace, row=1, col=1)

    # All nucleosome footprints as one trace.
    fig1.add_trace(
        go.Scatter(
            x=x_coords,
            y=y_coords,
            mode='lines',
            line=dict(color='rgba(51,0,141,45)', width=2),
            opacity=0.75,
            showlegend=False
        ),
        row=1,
        col=1
    )

    fig1.update_layout(template="simple_white",
                       height=300,
                       width=1000,
                       )
    fig1.update_yaxes(title_text="Read_ID", row=1, col=1)
    fig1.update_xaxes(title_text="Genomic location (bp)", row=1, col=1)

    print("Creating Figure 2 for the other plots...")
    fig2 = make_subplots(rows=2,
                         cols=1,
                         shared_xaxes=True,
                         vertical_spacing=0.02,
                         specs=[[{}], [{"secondary_y": True}]],
                         row_heights=[0.25, 0.25])
    fig2.update_xaxes(range=[-plot_window, plot_window])

    # Row 1: moving average of the modified fraction.
    line_trace = go.Scatter(x=agg_df['rel_pos'], y=agg_df['moving_avg'], mode='lines',
                            line=dict(color='#FF0000', width=2),
                            line_shape='spline')
    fig2.add_trace(line_trace, row=1, col=1)

    print("Plotting histogram and smoothed density on Figure 2...")
    # Row 2, primary axis: histogram of nucleosome midpoints.
    hist_bins = int(round(2 * plot_window / 10) + 1)
    midpoint_histogram = go.Histogram(x=midpoints_list,
                                      nbinsx=hist_bins,
                                      marker=dict(color='rgba(51,0,141,45)', opacity=0.6)
                                      )
    fig2.add_trace(midpoint_histogram, row=2, col=1, secondary_y=False)

    # Count nucleosome midpoints per base pair.
    nucleosome_array = np.zeros(genome_size)
    for midpoint in midpoints_list:
        position_index = int(midpoint + plot_window)
        if 0 <= position_index < genome_size:
            nucleosome_array[position_index] += 1

    # Normalise by read coverage. `out` must be supplied together with `where`,
    # otherwise positions with zero coverage are left uninitialised.
    nucleosome_normalized = np.divide(
        nucleosome_array,
        read_counts,
        out=np.zeros_like(nucleosome_array),
        where=read_counts != 0,
    )

    # Gaussian smoothing with sigma = 10 bp.
    smoothed_nucleosome_array = gaussian_filter1d(nucleosome_normalized, 10)

    x_values = np.arange(-plot_window, plot_window, 1)

    # Row 2, secondary axis: smoothed nucleosome density.
    smoothed_trace = go.Scatter(
        x=x_values,
        y=smoothed_nucleosome_array,
        mode='lines',
        name='Smoothed Nucleosome Density',
        line=dict(color='rgba(51,0,141,45)', width=2),
    )
    fig2.add_trace(smoothed_trace, row=2, col=1, secondary_y=True)

    # Tick labels only on the lower panel.
    fig2.update_xaxes(showticklabels=False, row=1, col=1)
    fig2.update_xaxes(title_text="Genomic location (bp)", row=2, col=1)

    fig2.update_yaxes(title_text="% m6A", row=1, col=1)
    fig2.update_yaxes(title_text="Nucleosome Count", row=2, col=1)
    fig2.update_yaxes(title_text="Nucleosome Density", row=2, col=1, secondary_y=True)

    fig2.update_layout(template="simple_white",
                       height=300,
                       width=1000,
                       )

    if plot_motifs:
        print("Plotting motifs on both figures...")
        if 'motif_rel_start' in plot_df_copy.columns and 'motif_id' in plot_df_copy.columns:
            data_filtered = plot_df_copy[['motif_rel_start', 'motif_id']].dropna()

            if not data_filtered.empty:
                # Normalise both columns to lists so they can be zipped pairwise.
                data_filtered['motif_rel_start'] = data_filtered['motif_rel_start'].apply(_ensure_list)
                data_filtered['motif_id'] = data_filtered['motif_id'].apply(_ensure_list)

                data_filtered['motif_pairs'] = data_filtered.apply(
                    lambda row: list(zip(row['motif_rel_start'], row['motif_id'])),
                    axis=1
                )

                # One (motif_rel_start, motif_id) pair per row.
                exploded_df = data_filtered.explode('motif_pairs')
                exploded_df[['motif_rel_start', 'motif_id']] = pd.DataFrame(
                    exploded_df['motif_pairs'].tolist(), index=exploded_df.index
                )
                exploded_df = exploded_df[['motif_rel_start', 'motif_id']].drop_duplicates()

                # Collect the motifs that share a position so the line is drawn once.
                grouped_motifs = exploded_df.groupby('motif_rel_start')['motif_id'].apply(list).reset_index()

                _add_motifs_to_figure(fig1, grouped_motifs)
                _add_motifs_to_figure(fig2, grouped_motifs)

    # Remove legends from both figures.
    fig1.update_layout(showlegend=False)
    fig2.update_layout(showlegend=False)

    return fig1, fig2

# Example run of create_plot for one condition and one site type.
# Replace the following variables with your actual data
selec_cond = analysis_cond[0]
selec_type = "rex32"
selected_chr_type = "X"
# bed_window = 1000  # Example value
# n_read_ids = 100   # Example value

# 'down_sampled_plot_df' and 'down_sampled_group_df' must already be defined.
fig1, fig2 = create_plot(
    down_sampled_plot_df,
    down_sampled_group_df,
    selec_cond,
    selected_chr_type,
    selec_type,
    int(round(bed_window, 0)),
)

# Write both figures to temp_files/ as PNG, then as SVG.
out_stem = "temp_files/" + selec_cond + "_" + selec_type
fig1.write_image(out_stem + "_read_track.png", width=1000, height=400)
fig2.write_image(out_stem + "_nuc_track.png", width=1000, height=400)
fig1.write_image(out_stem + "_read_track.svg", width=1000, height=400)
fig2.write_image(out_stem + "_nuc_track.svg", width=1000, height=400)

# Display inline in the notebook.
fig1.show(renderer='plotly_mimetype+notebook')
fig2.show(renderer='plotly_mimetype+notebook')


In [None]:
### Generate bedgraph from bam files
## Note: files are saved in the same folders as the original .bam files
regenerate_bit = False  # Set to True to force regeneration; otherwise existing outputs are reused.
num_processors = 10

# Output folder of each sample = the folder that holds its .bam file
input_bam_paths = list(map(os.path.dirname, new_bam_files))

# Function to run a single command
def _remove_bedgraph(path):
    """Delete *path*; a file already removed by a concurrent worker is ignored."""
    try:
        os.remove(path)
    except FileNotFoundError:
        return
    print("Deleting file: ", path)


def _merge_and_sort_bedgraph_files(output_stem, each_expid, each_condition, file_suffixes, num_processors=8):
    """Merge pairs of bedgraph files into files sorted by chromosome and start.

    ``file_suffixes`` is an iterable of ``(first_suffix, second_suffix,
    output_suffix)``. The two inputs are merged with ``sort -k1,1 -k2,2n``.
    Inputs whose suffix names a strand ("positive"/"negative") are deleted
    after a successful merge; other inputs are kept.
    """
    for first_suffix, second_suffix, output_suffix in file_suffixes:
        first_file = f"{output_stem}/{each_expid}-{each_condition}_{first_suffix}.bedgraph"
        second_file = f"{output_stem}/{each_expid}-{each_condition}_{second_suffix}.bedgraph"
        merged_file = f"{output_stem}/{each_expid}-{each_condition}_{output_suffix}.bedgraph"

        # Argument list instead of a shell string: paths with spaces or shell
        # metacharacters are passed through untouched. `sort` reads both inputs
        # itself, so no `cat` pipeline is needed.
        subprocess.run(
            ["sort", "-k1,1", "-k2,2n", f"--parallel={num_processors}",
             "-o", merged_file, first_file, second_file],
            check=True,
        )

        is_strand_pair = any(
            strand in suffix
            for suffix in (first_suffix, second_suffix)
            for strand in ("positive", "negative")
        )
        if is_strand_pair:
            os.remove(first_file)
            os.remove(second_file)


def modkit_pileup_extract(args):
    """Run ``modkit pileup --bedgraph`` for one BAM file and merge its outputs.

    ``args`` is the tuple ``(each_bam, each_thresh, each_condition, each_index,
    each_bamfrac, each_expid, modkit_path, output_stem, num_processors)``;
    ``each_index`` and ``each_bamfrac`` are accepted but not used here.

    Files are written to ``output_stem`` with the prefix
    ``{each_expid}-{each_condition}``. The final product is
    ``{prefix}_a_A0_m_GC1.bedgraph``; the intermediate merged ``a_A0`` and
    ``m_GC1`` files are kept as well.

    The notebook-level flag ``regenerate_bit`` controls caching: when it is
    falsy (or not defined) and the final file already exists, nothing is done.
    Returns ``None``.
    """
    (each_bam, each_thresh, each_condition, each_index, each_bamfrac, each_expid,
     modkit_path, output_stem, num_processors) = args

    # The flag lives at notebook level; treat "not defined" as "reuse existing files".
    regenerate = bool(globals().get("regenerate_bit", False))
    prefix = f"{each_expid}-{each_condition}"
    final_fn = f"{output_stem}/{prefix}_a_A0_m_GC1.bedgraph"

    if regenerate:
        # Only this sample's bedgraphs are removed. Several workers may share
        # one output folder, so wiping every .bedgraph would destroy results
        # that other workers have just produced.
        for file in os.listdir(output_stem):
            if file.startswith(f"{prefix}_") and file.endswith(".bedgraph"):
                _remove_bedgraph(os.path.join(output_stem, file))
    else:
        print("Checking if file exists: ", final_fn)
        if os.path.exists(final_fn):
            print(f"File already exists: {final_fn}")
            return
        # The final file is missing: remove leftovers of an interrupted run so
        # the merge below starts from fresh modkit output. (The earlier code
        # announced these deletions but never performed them.)
        for tag in ("a_A0_negative", "a_A0_positive", "a_A0",
                    "m_GC1_negative", "m_GC1_positive", "m_GC1"):
            _remove_bedgraph(f"{output_stem}/{prefix}_{tag}.bedgraph")

    print("Starting on bam file: ", each_bam)
    command = [
        modkit_path,
        "pileup",
        "--threads",
        f"{num_processors}",
        "--bedgraph",
        # Same pass threshold for 6mA (code "a") and 5mC (code "m").
        "--mod-thresholds",
        f"a:{each_thresh}",
        "--mod-thresholds",
        f"m:{each_thresh}",
        "--ref",
        "/Data1/reference/c_elegans.WS235.genomic.fa",
        "--motif",
        "GC",
        "1",
        "--motif",
        "A",
        "0",
        "--prefix",
        prefix,
        each_bam,
        output_stem
    ]
    result = subprocess.run(command, text=True)
    if result.returncode != 0:
        # Make a failed run visible; downstream steps will fail on missing files.
        print(f"WARNING: modkit exited with status {result.returncode} for {each_bam}")

    # Drop the mod/motif combinations that are not wanted (a_GC1, m_A0, a_CC0),
    # restricted to this sample's files.
    for file in os.listdir(output_stem):
        if not file.startswith(f"{prefix}_"):
            continue
        tail = file[len(prefix) + 1:]
        if "a_GC1" in tail or "m_A0" in tail or "a_CC0" in tail:
            _remove_bedgraph(os.path.join(output_stem, file))

    # A BAM without the corresponding mods yields no m_GC1 files; create
    # placeholders holding a single empty row so the merge still works.
    for strand in ("positive", "negative"):
        placeholder = f"{output_stem}/{prefix}_m_GC1_{strand}.bedgraph"
        if not os.path.exists(placeholder):
            with open(placeholder, "w") as f:
                f.write("\n")

    # Merge the strands per modification, then both modifications together.
    file_suffixes = [
        ("a_A0_negative", "a_A0_positive", "a_A0"),
        ("m_GC1_negative", "m_GC1_positive", "m_GC1"),
        ("a_A0", "m_GC1", "a_A0_m_GC1"),
    ]
    _merge_and_sort_bedgraph_files(output_stem, each_expid, each_condition, file_suffixes, num_processors)

    
    
# Build one argument tuple per BAM file, in the order modkit_pileup_extract unpacks it:
# (bam, threshold, condition, index, bam fraction, experiment id, modkit path, output dir, threads)
task_args = [
    (bam, thresh, cond, idx, frac, expid, modkit_path, out_dir, num_processors)
    for bam, thresh, cond, idx, frac, expid, out_dir in zip(
        new_bam_files,
        thresh_list,
        conditions,
        sample_indices,
        bam_fracs,
        exp_ids,
        input_bam_paths,
    )
]

# Select task_args where new_bam_files contains "AG1"
#task_args = [task for task in task_args if "AG1" in task[0]]

# Print the output folders for debugging
print("new_bam_files: ", input_bam_paths)

# Run the pileups in parallel, 8 BAM files at a time
with Pool(processes=8) as pool:
    pool.map(modkit_pileup_extract, task_args)

In [None]:
    # OPTIONAL
### Fill and Smooth
import os
import subprocess
import tempfile
import pandas as pd
import multiprocessing
import importlib
import nanotools

def process_bedgraph(args):
    """Convert one sample's ``a_A0`` bedgraph to bigWig and optional derived tracks.

    ``args`` is the tuple ``(each_bam, each_condition, each_expid,
    smoothing_window, imputation_window, bedgraphtobigwig_path, force_replace,
    raw_only)``. All files live next to ``each_bam`` and are named
    ``{each_expid}-{each_condition}_a_A0*``.

    Steps (each one is skipped when its output exists and ``force_replace`` is
    False):

    1. raw bigWig from the first four bedgraph columns;
    2. if ``raw_only`` is False: an NA-preserving bedgraph (``_nafilled``), a
       zero-filled bedgraph and bigWig (``_raw_filled``), and an imputed /
       smoothed bedgraph and bigWig (``_smoothed-{smooth}-{impute}``).

    Uses the notebook globals ``chrom_sizes`` and ``nanotools``.

    Returns the raw bigWig path. Raises ``subprocess.CalledProcessError`` when
    ``cut`` or ``bedGraphToBigWig`` fails.
    """
    (each_bam, each_condition, each_expid, smoothing_window, imputation_window,
     bedgraphtobigwig_path, force_replace, raw_only) = args

    out_dir = os.path.dirname(each_bam)
    bedgraph_fn = os.path.join(out_dir, f"{each_expid}-{each_condition}_a_A0.bedgraph")
    raw_bw_fn = os.path.join(out_dir, f"{each_expid}-{each_condition}_a_A0.bw")

    def needs_build(path):
        # An output is (re)built when it is missing or replacement is forced.
        return force_replace or not os.path.exists(path)

    def to_bigwig(source_bedgraph, target_bw):
        # Shared bedGraphToBigWig call; every conversion reports and re-raises failures.
        try:
            subprocess.run(
                [bedgraphtobigwig_path, source_bedgraph, chrom_sizes, target_bw],
                check=True,
            )
        except subprocess.CalledProcessError as e:
            print(f"An error occurred during bedgraph to bigwig conversion: {e}")
            raise

    if needs_build(raw_bw_fn):
        print("Converting raw bedgraph directly to raw bigwig...")
        temp_file = tempfile.NamedTemporaryFile(mode='w+t', delete=False, suffix='.bedgraph')
        temp_filename = temp_file.name
        try:
            try:
                with temp_file:
                    # Keep only chrom/start/end/score. Argument list, no shell,
                    # so unusual characters in the path are safe.
                    subprocess.run(["cut", "-f", "1-4", bedgraph_fn], check=True, stdout=temp_file)
            except subprocess.CalledProcessError as e:
                print(f"An error occurred while cutting the bedgraph file: {e}")
                raise
            to_bigwig(temp_filename, raw_bw_fn)
            # Reported only after a successful conversion.
            print("Saved raw bigwig file: ", raw_bw_fn)
        finally:
            os.unlink(temp_filename)
    else:
        print("Raw bigwig file already exists, skipping conversion.")

    # If only raw bigwig files are to be created, skip the rest
    if raw_only:
        return raw_bw_fn

    # Derived output file names
    filled_bedgraph_fn = os.path.join(out_dir,
                                      f"{each_expid}-{each_condition}_a_A0_raw_filled.bedgraph")
    filled_bw_fn = filled_bedgraph_fn.replace(".bedgraph", ".bw")
    nafilled_bedgraph_fn = os.path.join(out_dir,
                                        f"{each_expid}-{each_condition}_a_A0_nafilled.bedgraph")
    smoothed_bedgraph_fn = os.path.join(out_dir,
                                        f"{each_expid}-{each_condition}_a_A0_smoothed-{smoothing_window}-{imputation_window}.bedgraph")
    smoothed_bigwig_fn = smoothed_bedgraph_fn.replace(".bedgraph", ".bw")

    need_filled = needs_build(filled_bedgraph_fn)
    need_nafilled = needs_build(nafilled_bedgraph_fn)
    need_smoothed = needs_build(smoothed_bedgraph_fn)

    # Load the bedgraph once, and only if some table output has to be written.
    bedgraph_df = None
    if need_filled or need_nafilled or need_smoothed:
        print("Loading bedgraph file: ", bedgraph_fn)
        bedgraph_df = nanotools.load_bedgraph_file(bedgraph_fn)

    # The NA-preserving file is written first, while the scores are still as
    # loaded. Writing it after the zero-fill below would silently turn it into
    # a copy of the zero-filled file.
    if need_nafilled:
        print("Starting to fill raw bedgraph with NAs...")
        bedgraph_df[['chromosome', 'start', 'end', 'score']].to_csv(
            nafilled_bedgraph_fn,
            sep="\t", header=False, index=False)
    else:
        print(f"Raw NA filled bedgraph file already exists, skipping: {nafilled_bedgraph_fn}")

    # Both remaining table outputs work on zero-filled scores, whichever of
    # them is actually (re)built in this call.
    if need_filled or need_smoothed:
        bedgraph_df['score'] = bedgraph_df['score'].fillna(0)

    if need_filled:
        print("Starting to fill raw bedgraph...")
        bedgraph_df[['chromosome', 'start', 'end', 'score']].to_csv(
            filled_bedgraph_fn,
            sep="\t", header=False, index=False)
    else:
        print(f"Raw filled bedgraph file already exists, skipping: {filled_bedgraph_fn}")

    if needs_build(filled_bw_fn):
        print("Converting filled bedgraph to bigwig...")
        to_bigwig(filled_bedgraph_fn, filled_bw_fn)
    else:
        print(f"Raw filled bigwig file already exists, skipping: {filled_bw_fn}")

    if need_smoothed:
        print(f"Imputing and smoothing bedgraph file: {smoothed_bedgraph_fn}")
        (bedgraph_df['imputed_score'], bedgraph_df['imputed_coverage'],
         bedgraph_df['smoothed_score'], bedgraph_df['smoothed_coverage']) = nanotools.parallel_impute_and_smooth(
            bedgraph_df,
            impute_window=imputation_window,
            smooth_window=smoothing_window,
            fill_value=0
        )

        print(f"Saving smoothed bedgraph file: {smoothed_bedgraph_fn}")
        bedgraph_df[['chromosome', 'start', 'end', 'smoothed_score']].to_csv(
            smoothed_bedgraph_fn,
            sep="\t", header=False, index=False)
    else:
        print(f"Imputed and smoothed bedgraph file already exists, skipping: {smoothed_bedgraph_fn}")

    if needs_build(smoothed_bigwig_fn):
        print(f"Converting smoothed bedgraph to bigwig: {smoothed_bigwig_fn}")
        # Same failure handling as the other conversions (was silently ignored).
        to_bigwig(smoothed_bedgraph_fn, smoothed_bigwig_fn)
    else:
        print(f"Imputed and smoothed bigwig file already exists, skipping: {smoothed_bigwig_fn}")

    return raw_bw_fn

# Configurable parameters
smoothing_window = 20
imputation_window = 0
force_replace = False  # True forces existing files to be rebuilt
raw_only = True  # True creates the raw bw files only and skips every other step

# One argument tuple per sample: (bam, condition, exp id) followed by the shared settings
args_list = [
    sample + (smoothing_window, imputation_window, bedgraphtobigwig_path, force_replace, raw_only)
    for sample in zip(new_bam_files, conditions, exp_ids)
]

# Process the bedgraph files in parallel; each worker returns its raw bigwig path
with multiprocessing.Pool(processes=15) as pool:
    raw_bw_files = pool.map(process_bedgraph, args_list)

print("All processing completed.")
print("Raw bigwig files:", raw_bw_files)

if not raw_only:
    # Paths of the filled bigwig files, derived from the same naming scheme
    filled_bw_files = [
        os.path.join(os.path.dirname(bam), f"{expid}-{cond}_a_A0_raw_filled.bw")
        for bam, cond, expid in zip(new_bam_files, conditions, exp_ids)
    ]
    print("Filled bigwig files:", filled_bw_files)


In [None]:
### Plot size of files
import os
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots


def get_file_info(bam_files, conditions, exp_ids, smooth_win=None, impute_win=None):
    """Collect size and type information for each sample's bedgraph/bigWig files.

    For every ``(bam_file, condition, exp_id)`` triple the directory of the BAM
    file is checked for the known ``{exp_id}-{condition}_a_A0*`` outputs.

    Parameters
    ----------
    bam_files, conditions, exp_ids : iterables of equal length
    smooth_win, impute_win : optional
        Window sizes embedded in the smoothed file names. When omitted, the
        notebook globals ``smoothing_window`` / ``imputation_window`` are used.

    Returns
    -------
    (file_info, filled_bigwig_files, smoothed_bigwig_files)
        ``file_info`` is a list of dicts, one per existing file. NOTE: the key
        ``'File Size (bytes)'`` is kept for compatibility with existing callers
        but the value is expressed in GB (bytes / 1024**3).
        The other two lists hold the paths of the filled and smoothed bigWigs.
    """
    if smooth_win is None:
        smooth_win = smoothing_window
    if impute_win is None:
        impute_win = imputation_window

    file_info = []
    filled_bigwig_files = []
    smoothed_bigwig_files = []

    for bam_file, condition, exp_id in zip(bam_files, conditions, exp_ids):
        output_dir = os.path.dirname(bam_file)
        stem = f"{exp_id}-{condition}_a_A0"

        # (file name, processing label). The label is tied to the file name
        # pattern instead of being guessed from substrings of the full path,
        # which misfired when a directory name contained e.g. "raw".
        # The raw bigWig is written as "<stem>.bw"; "<stem>_raw.bw" is still
        # recognised for older outputs.
        candidates = [
            (f"{stem}.bedgraph", 'Original'),
            (f"{stem}.bw", 'Raw'),
            (f"{stem}_raw.bw", 'Raw'),
            (f"{stem}_raw_filled.bedgraph", 'Filled'),
            (f"{stem}_raw_filled.bw", 'Filled'),
            (f"{stem}_smoothed-{smooth_win}-{impute_win}.bedgraph", 'Smoothed'),
            (f"{stem}_smoothed-{smooth_win}-{impute_win}.bw", 'Smoothed'),
        ]

        for file_name, processing in candidates:
            file_path = os.path.join(output_dir, file_name)
            if not os.path.exists(file_path):
                continue

            file_size = os.path.getsize(file_path) / (1024 ** 3)  # Convert to GB
            file_type = 'Bigwig' if file_name.endswith('.bw') else 'Bedgraph'

            if file_type == 'Bigwig':
                if processing == 'Filled':
                    filled_bigwig_files.append(file_path)
                elif processing == 'Smoothed':
                    smoothed_bigwig_files.append(file_path)

            file_info.append({
                'File Name': file_name,
                'File Path': file_path,
                'File Size (bytes)': file_size,
                'File Type': file_type,
                'Processing': processing,
                'Experiment ID': exp_id,
                'Condition': condition
            })

    return file_info, filled_bigwig_files, smoothed_bigwig_files


# Use the variables from your original script.
# zip() stops at the shortest input, so these lines rebuild each list at the
# length zip() yields.
new_bam_files = [bam for bam, _, _ in zip(new_bam_files, conditions, exp_ids)]
conditions = [cond for _, cond, _ in zip(new_bam_files, conditions, exp_ids)]
exp_ids = [expid for _, _, expid in zip(new_bam_files, conditions, exp_ids)]

# Window sizes used in the smoothed file names (same values as in the processing cell)
smoothing_window = 20
imputation_window = 0

# Gather the file information
file_info, filled_bigwig_files, smoothed_bigwig_files = get_file_info(new_bam_files, conditions, exp_ids)

if not file_info:
    print("No matching files found in the specified directories.")
else:
    # Largest files first
    df_file_info = pd.DataFrame(file_info).sort_values('File Size (bytes)', ascending=False)

    # One subplot row per file type
    file_types = df_file_info['File Type'].unique()
    fig = make_subplots(
        rows=len(file_types),
        cols=1,
        subplot_titles=[f"{ftype} File Sizes" for ftype in file_types],
        vertical_spacing=0.1,
    )

    for i, file_type in enumerate(file_types, start=1):
        df_subset = df_file_info[df_file_info['File Type'] == file_type]

        hover_labels = [
            f"Exp ID: {exp}<br>Size: {size:.2f} GB<br>Processing: {proc}"
            for exp, size, proc in zip(
                df_subset['Experiment ID'],
                df_subset['File Size (bytes)'],
                df_subset['Processing'],
            )
        ]
        trace = go.Bar(
            x=df_subset['Experiment ID'],
            y=df_subset['File Size (bytes)'],
            name=file_type,
            text=df_subset['Processing'],
            hoverinfo='text+y',
            hovertext=hover_labels,
        )
        fig.add_trace(trace, row=i, col=1)

        fig.update_xaxes(title_text="Experiment ID", row=i, col=1)
        fig.update_yaxes(title_text="File Size (GB)", row=i, col=1)

    fig.update_layout(
        height=300 * len(file_types),
        title_text="File Sizes by Experiment ID and File Type",
        showlegend=False,
        template="plotly_white"
    )
    fig.show()

    # Summary: number of files and their combined size
    print(f"\nTotal number of files found: {len(file_info)}")
    total_size = sum(entry['File Size (bytes)'] for entry in file_info)
    print(f"Total size of all files: {total_size:.2f} GB")

    # Paths of the filled and smoothed bigwig files
    print("\nFilled Bigwig Files:")
    for file in filled_bigwig_files:
        print(file)

    print("\nSmoothed Bigwig Files:")
    for file in smoothed_bigwig_files:
        print(file)

In [None]:
import os
import pandas as pd
from multiprocessing import Pool
from tqdm import tqdm
import numpy as np


def get_optimal_dtypes(chunk):
    """Return the dtype mapping used to downcast one bedgraph chunk.

    Parameters:
        chunk (pd.DataFrame): Chunk whose last column holds the score values.

    Returns:
        dict: ``'chromosome'`` -> ``'category'``, ``'start'`` -> ``'uint32'`` and
        the last column of ``chunk`` -> ``'float32'``.
    """
    # The score column is always stored as float32 regardless of its incoming
    # dtype, so there is no need to inspect the values (the previous
    # element-wise integer check cost O(n) Python calls and changed nothing).
    score_col = chunk.columns[-1]
    return {
        'chromosome': 'category',
        'start': 'uint32',
        score_col: 'float32',
    }


def process_chunk(args):
    """Label and downcast a single ``(chunk, file_name)`` work item.

    The chunk's three columns are renamed to ``chromosome``, ``start`` and
    ``file_name`` (in place), then a copy cast to the dtypes chosen by
    ``get_optimal_dtypes`` is returned.
    """
    frame, score_label = args
    frame.columns = ['chromosome', 'start', score_label]
    target_dtypes = get_optimal_dtypes(frame)
    converted = frame.astype(target_dtypes)
    return converted


def process_bedgraph_in_chunks(file_path, chunk_size=1000000, rows_to_process=None, suffix=None):
    """Load one bedgraph file chunk by chunk and return it as a compact DataFrame.

    Parameters:
        file_path (str): Tab-separated, header-less bedgraph file.
        chunk_size (int): Number of rows read per chunk.
        rows_to_process (int | None): Optional cap on the number of rows read.
        suffix (str | None): Text removed from the file's base name to form the
            score column name. When None or empty the base name is used as is.

    Returns:
        pd.DataFrame: Columns ``chromosome``, ``start`` and one score column
        named after the file.
    """
    file_name = os.path.basename(file_path)
    # str.replace(None, '') raises TypeError, so only strip a suffix that was given.
    if suffix:
        file_name = file_name.replace(suffix, '')

    # Only columns 0, 1 and 3 are read (presumably chrom, start and score, with
    # the end coordinate in column 2 skipped -- confirm against the file layout).
    chunks = pd.read_csv(file_path, sep='\t', header=None, chunksize=chunk_size, usecols=[0, 1, 3],
                         nrows=rows_to_process)

    with Pool(processes=10) as pool:
        # imap keeps the chunks in file order.
        processed_chunks = list(pool.imap(process_chunk, ((chunk, file_name) for chunk in chunks)))

    return pd.concat(processed_chunks)


def check_and_concat_dataframes(base_df, new_df):
    """Append the score column of ``new_df`` to ``base_df`` after a coordinate check.

    Parameters:
        base_df (pd.DataFrame | None): Table merged so far, or None for the first file.
        new_df (pd.DataFrame): Table whose last column is the score to add.

    Returns:
        pd.DataFrame: ``new_df`` when ``base_df`` is None, otherwise ``base_df``
        (modified in place) with the extra score column.

    Raises:
        ValueError: If the ``chromosome``/``start`` columns differ row by row.
    """
    if base_df is None:
        return new_df

    key_cols = ['chromosome', 'start']
    if not np.array_equal(base_df[key_cols].values, new_df[key_cols].values):
        raise ValueError("chromosome and start columns are not identical across all files.")

    # The check above is positional, so the assignment must be positional too:
    # assigning the Series directly would align on the index and could insert
    # NaN when the two frames carry different index labels.
    base_df[new_df.columns[-1]] = new_df.iloc[:, -1].to_numpy()
    return base_df


def find_bedgraph_files(directory, suffix=None):
    """List the full paths of entries in ``directory`` whose names end with ``suffix``.

    Without a suffix nothing is searched: a notice is printed and an empty
    list is returned.
    """
    if suffix is not None:
        matches = []
        for entry in os.listdir(directory):
            if entry.endswith(suffix):
                matches.append(os.path.join(directory, entry))
        return matches

    print("Requires Suffix to find files")
    return []


def process_bedgraph_files(new_bam_files, file_prefix, force_replace=False, bam_filter=None, bedgraph_filter=None,
                           output_dir=None):
    """Merge the ``*_nafilled.bedgraph`` files found next to the BAM files into one CSV.

    Parameters:
        new_bam_files (list[str]): BAM paths; only their directories are used.
        file_prefix (str): Sub-folder of ``output_dir`` that receives the merged file.
        force_replace (bool): Rebuild the merged file even if it already exists.
        bam_filter: Currently unused; kept for interface compatibility.
        bedgraph_filter (list[str] | None): Keep only bedgraph paths containing at
            least one of these substrings. When empty or None a built-in list of
            sample tags is used instead.
        output_dir (str | None): Base output folder; defaults to the analysis
            ``temp_files`` directory.

    Returns:
        str: Path of the merged CSV. The path is returned even when no data
        could be processed and nothing was written.
    """
    # Sorted so that the processing (and therefore column) order is
    # reproducible; iterating a bare set gives an arbitrary order.
    unique_directories = sorted({os.path.dirname(path) for path in new_bam_files})
    print(f"Found {len(unique_directories)} unique directories.")
    suffix = "_nafilled.bedgraph"
    all_bedgraph_files = []
    for directory in unique_directories:
        all_bedgraph_files.extend(find_bedgraph_files(directory, suffix=suffix))

    # Note: substrings are matched against the full path, not just the file name.
    if bedgraph_filter:
        wanted_substrings = bedgraph_filter
    else:
        wanted_substrings = ("BM_", "BK_", "BN_", "AG-22", "AH-", "AM", "H1")
    all_bedgraph_files = [file for file in all_bedgraph_files
                          if any(substr in file for substr in wanted_substrings)]

    print(f"Found {len(all_bedgraph_files)} bedgraph files.")
    for file in all_bedgraph_files:
        print(os.path.basename(file))

    final_df = None
    if output_dir is None:
        output_dir = "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/"
    temp_folder = os.path.join(output_dir, file_prefix)
    input_file = os.path.join(temp_folder, "merged_bedgraph_test.csv")

    if not os.path.exists(input_file) or force_replace:
        for file in tqdm(all_bedgraph_files, desc="Processing files"):
            # Best effort: a file that fails to load or does not share the
            # coordinates of the first file is reported and left out.
            try:
                df = process_bedgraph_in_chunks(file, rows_to_process=None, suffix=suffix)
                final_df = check_and_concat_dataframes(final_df, df)
            except Exception as e:
                print(f"Error processing {file}: {str(e)}")
                continue

        if final_df is not None:
            print("Saving merged dataframe...")
            os.makedirs(temp_folder, exist_ok=True)
            # DataFrame.to_csv has no `force_replace` keyword (passing it raised
            # TypeError after all the processing was done); it overwrites by default.
            final_df.to_csv(input_file, index=False)
            print(
                f"Processed {len(all_bedgraph_files)} files from {len(unique_directories)} directories. Merged dataframe saved as '{input_file}'")
        else:
            print("No data was processed successfully. Please check the errors above.")
    else:
        print("Merged dataframe already exists, skipping. Use force_replace=True to overwrite.")

    return input_file


# All BAM files are passed through; the earlier tag-based BAM pre-filter is disabled
# and sample selection happens via `bedgraph_filter` below.
filtered_bam_files = new_bam_files

# Build (or reuse) the merged bedgraph table for the selected samples
result = process_bedgraph_files(
    filtered_bam_files,
    file_prefix,
    force_replace=False,  # Set to True if you want to overwrite existing files
    # Alternative tag-based selection: ["BM_", "BK_", "BN_", "AG-22", "AH-", "AM", "H1"]
    bedgraph_filter=[
        'Y9B-08_03_21_23',
        'D1A-nb_12_22_22',
        'AD1-nb_06_13_23',
        'AB-05_04_10_23',
        'AB-04_04_10_23',
        'AH-07_08_19_23',
        'AH-08_08_19_23',
        'H1-nb_12_10_22',
        'AB-10_04_10_23',
        'AB-09_04_10_23',
        'AH-09_08_19_23',
        'AG-22_11_30_23',
        'BN_05_24_24',
        'BM_05_30_24',
    ],
    output_dir="/Data1/git/meyer-nanopore/scripts/analysis/temp_files/"
)

print(f"Output file: {result}")

In [None]:
### Quantile normalization (qnorm)
import stat
import numpy as np

# https://pypi.org/project/qnorm/
from qnorm import quantile_normalize

# reimport nanotools
importlib.reload(nanotools)
import pyarrow.csv as pv


def import_normalize_smooth_and_convert(bedgraphtobigwig_path, chrom_sizes, imputation_window, smoothing_window):
    """Quantile-normalize the merged bedgraph table and export per-sample tracks.

    Reads ``merged_bedgraph_test.csv`` from the ``file_prefix`` sub-folder of the
    analysis ``temp_files`` directory (``file_prefix`` is a notebook global),
    quantile-normalizes all score columns together and, for every score column,
    writes a qnorm bedgraph, a qnorm bigWig, a smoothed bedgraph and a smoothed
    bigWig next to the input file. Existing bigWig / smoothed files are reused.

    Parameters:
        bedgraphtobigwig_path (str): Path to the bedGraphToBigWig executable.
        chrom_sizes (str): Chromosome sizes file passed to bedGraphToBigWig.
        imputation_window: Forwarded as ``impute_window`` to
            ``nanotools.parallel_impute_and_smooth``.
        smoothing_window: Forwarded as ``smooth_window`` to the same helper.

    Returns:
        tuple[list[str], list[str]]: Paths of the qnorm bedgraph files and of
        the qnorm bigWig files.

    Raises:
        FileNotFoundError: If the merged CSV does not exist.
        ValueError: If the CSV lacks the coordinate columns or any score column.
        subprocess.CalledProcessError: If the qnorm bedgraph-to-bigWig conversion fails.
    """
    temp_folder = "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/"
    input_file = temp_folder + file_prefix + "/merged_bedgraph_test.csv"
    output_folder = os.path.dirname(input_file)  # all outputs are written next to the input
    qnormalized_bedgraph_files = []
    qnormalized_bigwig_files = []

    # Check if the input file exists
    if not os.path.exists(input_file):
        raise FileNotFoundError(f"Input file not found: {input_file}")

    # Read the CSV file using pyarrow with multi-threaded parsing
    print("Importing data...")
    read_options = pv.ReadOptions(use_threads=True)
    table = pv.read_csv(input_file, read_options=read_options)

    # Convert the pyarrow Table to a pandas DataFrame
    df = table.to_pandas()

    print(f"Imported data shape: {df.shape}")

    # Check the structure of the dataframe
    if 'chromosome' not in df.columns or 'start' not in df.columns:
        raise ValueError("Input file must have 'chromosome' and 'start' columns")

    if df.shape[1] < 3:
        raise ValueError("Input file must have at least one data column besides 'chromosome' and 'start'")

    # Every column after the first two is treated as a score column
    columns_to_normalize = df.columns[2:]

    # Convert 'na' to NaN and columns to float
    df[columns_to_normalize] = df[columns_to_normalize].replace('na', np.nan).astype(float)

    # Prepare data for normalization
    data_to_normalize = df[columns_to_normalize].values

    # Apply qnorm with parallel processing
    print("Applying quantile normalization using 10 CPU cores...")
    normalized_data = quantile_normalize(data_to_normalize, ncpus=10)

    # Replace original data with normalized data
    df[columns_to_normalize] = normalized_data

    print("Normalization complete.")
    print(f"Final data shape: {df.shape}")

    # Process each column
    for column in columns_to_normalize:
        # Create a new dataframe for this column
        bedgraph_df = df[['chromosome', 'start']].copy()
        bedgraph_df['end'] = bedgraph_df['start'] + 1  # Assuming 1-base positions
        bedgraph_df['score'] = df[column]
        bedgraph_df['coverage'] = 1  # Placeholder for 'coverage' column

        # Generate output filenames
        original_filename = f"{column}_filled.bedgraph"
        qnorm_bedgraph_filename = original_filename.replace("filled.bedgraph", "qnorm.bedgraph")
        qnorm_bigwig_filename = qnorm_bedgraph_filename.replace(".bedgraph", ".bw")
        smoothed_bedgraph_filename = original_filename.replace("filled.bedgraph", "qnorm_smoothed.bedgraph")
        smoothed_bigwig_filename = smoothed_bedgraph_filename.replace(".bedgraph", ".bw")

        qnorm_bedgraph_path = os.path.join(output_folder, qnorm_bedgraph_filename)
        qnorm_bigwig_path = os.path.join(output_folder, qnorm_bigwig_filename)
        smoothed_bedgraph_path = os.path.join(output_folder, smoothed_bedgraph_filename)
        smoothed_bigwig_path = os.path.join(output_folder, smoothed_bigwig_filename)

        # Save the qnorm bedgraph file from the first 4 columns of bedgraph_df
        # (this file is always rewritten, even when it already exists)
        bedgraph_df[['chromosome', 'start', 'end', 'score']].to_csv(qnorm_bedgraph_path, sep='\t', index=False,
                                                                    header=False, na_rep='na')
        print(f"Saved qnorm bedgraph: {qnorm_bedgraph_path}")
        qnormalized_bedgraph_files.append(qnorm_bedgraph_path)

        # Convert qnorm bedgraph to bigwig
        if not os.path.exists(qnorm_bigwig_path):
            print(f"Converting {qnorm_bedgraph_filename} to bigwig...")
            try:
                bigwig_command = [bedgraphtobigwig_path, qnorm_bedgraph_path, chrom_sizes, qnorm_bigwig_path]
                subprocess.run(bigwig_command, check=True)
                print(f"Qnorm bigwig file created: {qnorm_bigwig_path}")
                qnormalized_bigwig_files.append(qnorm_bigwig_path)
            except subprocess.CalledProcessError as e:
                print(f"An error occurred during bedgraph to bigwig conversion: {e}")
                raise
        else:
            print(f"Qnorm bigwig file already exists, skipping: {qnorm_bigwig_path}")
            qnormalized_bigwig_files.append(qnorm_bigwig_path)

        # Apply smoothing
        if not os.path.exists(smoothed_bedgraph_path):
            print(f"Imputing and smoothing bedgraph file: {smoothed_bedgraph_path}")
            bedgraph_df['score'] = bedgraph_df['score'].fillna(0)
            bedgraph_df['imputed_score'], bedgraph_df['imputed_coverage'], bedgraph_df['smoothed_score'], bedgraph_df[
                'smoothed_coverage'] = nanotools.parallel_impute_and_smooth(
                bedgraph_df,
                impute_window=imputation_window,
                smooth_window=smoothing_window,
                fill_value=0
            )

            print(f"Saving smoothed bedgraph file: {smoothed_bedgraph_path}")
            bedgraph_df[['chromosome', 'start', 'end', 'smoothed_score']].to_csv(
                smoothed_bedgraph_path,
                sep="\t", header=False, index=False)
        else:
            print(f"Imputed and smoothed bedgraph file already exists, skipping: {smoothed_bedgraph_path}")

        # Convert smoothed bedgraph to bigwig
        if not os.path.exists(smoothed_bigwig_path):
            print(f"Converting smoothed bedgraph to bigwig: {smoothed_bigwig_path}")
            command = [bedgraphtobigwig_path, smoothed_bedgraph_path, chrom_sizes, smoothed_bigwig_path]
            completed = subprocess.run(command, text=True, capture_output=True)
            # The output is captured, so a failure would otherwise go unnoticed.
            # Report it but keep going so the remaining columns are still processed.
            if completed.returncode != 0:
                print(f"An error occurred during smoothed bedgraph to bigwig conversion "
                      f"(exit code {completed.returncode}): {completed.stderr}")
        else:
            print(f"Imputed and smoothed bigwig file already exists, skipping: {smoothed_bigwig_path}")

    print("All files have been processed, normalized, smoothed, and converted.")
    return qnormalized_bedgraph_files, qnormalized_bigwig_files


# Run normalization, smoothing and bigWig export with the chosen window sizes
imputation_window = 0  # Set this to your desired value
smoothing_window = 50  # Set this to your desired value
qnormalized_bedgraph_paths, qnormalized_bigwig_paths = import_normalize_smooth_and_convert(
    bedgraphtobigwig_path=bedgraphtobigwig_path,
    chrom_sizes=chrom_sizes,
    imputation_window=imputation_window,
    smoothing_window=smoothing_window,
)

# Report the qnormalized bedgraph outputs, one path per line
print("Qnormalized bedgraph file paths:")
for path in qnormalized_bedgraph_paths:
    print(path)

# Report the qnormalized bigwig outputs, one path per line
print("\nQnormalized bigwig file paths:")
for path in qnormalized_bigwig_paths:
    print(path)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
import os

def filter_and_plot(result_df, qvalue_cutoff=None, percentile_cutoff=None, num_categories=4):
    """Bin ChIP regions by enrichment and draw per-experiment boxplots.

    SDC2 and SDC3 regions are split into ``num_categories`` quantile bins
    (``D1`` ... ``Dn``) of their antimNeon / IgG ratio; all other types keep their
    type name as category. For each ``average_*`` column (paired in order with
    the notebook global ``analysis_cond``) the SDC3 bins are plotted on a wide
    axis and ``rex`` / ``intergenic`` on a narrow one. The figure is saved as
    PNG and SVG in the analysis ``temp_files`` directory and then shown.

    Parameters:
        result_df (pd.DataFrame): Region table; it is not modified.
        qvalue_cutoff (float | None): Keep rows with ``LOG10(qvalue)`` >= this value.
        percentile_cutoff (float | None): Within SDC2 / SDC3, keep rows at or above
            this quantile of the corresponding antimNeon column.
        num_categories (int): Number of quantile bins for SDC2 / SDC3.

    Returns:
        pd.DataFrame: Filtered copy with ``sdc2_norm``, ``sdc3_norm`` and
        ``chip_category`` columns added.
    """
    # Filter rows based on qvalue cutoff if provided
    if qvalue_cutoff is not None:
        result_df = result_df[result_df['LOG10(qvalue)'] >= qvalue_cutoff]

    # Work on a private copy: the column assignments below would otherwise
    # write into the caller's DataFrame (or into a filtered slice of it, which
    # triggers pandas' SettingWithCopyWarning).
    result_df = result_df.copy()

    # Split the type column by "_" and keep only the first element
    result_df['type'] = result_df['type'].str.split("_").str[0]
    # Rename 'all' to 'rex'
    result_df['type'] = result_df['type'].replace('all', 'rex')

    # Apply percentile cutoff for each type if provided
    if percentile_cutoff is not None:
        def filter_by_percentile(group):
            if group.name in ['SDC2', 'SDC3']:
                norm_col = f'mNeon{group.name}_rep1_antimNeon'
                threshold = group[norm_col].quantile(percentile_cutoff)
                return group[group[norm_col] >= threshold]
            else:
                return group  # No filtering for 'rex' and 'intergenic'

        result_df = result_df.groupby('type').apply(filter_by_percentile).reset_index(drop=True)

    # Create normalized columns
    result_df['sdc3_norm'] = result_df['mNeonSDC3_rep1_antimNeon'] / result_df['mNeonSDC3_rep1_IgG']
    result_df['sdc2_norm'] = result_df['mNeonSDC2_rep1_antimNeon'] / result_df['mNeonSDC2_rep1_IgG']

    # Create categories for each type
    def categorize(group):
        if group.name == 'SDC2':
            return pd.qcut(group['sdc2_norm'], q=num_categories, labels=[f'D{i + 1}' for i in range(num_categories)])
        elif group.name == 'SDC3':
            return pd.qcut(group['sdc3_norm'], q=num_categories, labels=[f'D{i + 1}' for i in range(num_categories)])
        else:
            return pd.Series([group.name] * len(group), index=group.index)

    result_df['chip_category'] = result_df.groupby('type').apply(categorize).reset_index(level=0, drop=True)

    # Experiment names used as subplot titles (notebook global)
    experiment_names = analysis_cond  # Replace with your actual variable

    # Create a custom RdBu colormap for the quantile bins (blue to red)
    colors_sdc = plt.cm.RdBu_r(np.linspace(0, 1, num_categories))
    cmap_sdc = LinearSegmentedColormap.from_list("custom_RdBu", colors_sdc)

    # Colors for 'rex' and 'intergenic'
    color_rex = 'green'
    color_intergenic = 'orange'

    def plot_boxplots(data, bw_column, ax1, ax2, title, palette_sdc, buffer=1):
        """Draw the SDC3 bins on ``ax1`` and rex / intergenic on ``ax2``."""
        # Order the quantile bins numerically (D1, D2, ..., D10)
        category_order = sorted(
            [cat for cat in data['chip_category'].unique() if cat.startswith('D')],
            key=lambda x: int(x[1:])
        )

        # Boxplot properties
        boxprops = dict(facecolor='none')
        medianprops = dict(color='red', linewidth=2)

        # Plot the SDC3 bins on the first subplot
        sns.boxplot(
            x='chip_category', y=bw_column, data=data[data['type'] == 'SDC3'],
            ax=ax1, palette=palette_sdc, showfliers=False, order=category_order,
            boxprops=boxprops, medianprops=medianprops, width=1
        )

        # Plot 'rex' and 'intergenic' on the second subplot
        sns.boxplot(
            x='chip_category', y=bw_column, data=data[data['type'].isin(['rex', 'intergenic'])],
            ax=ax2, palette=[color_rex, color_intergenic], showfliers=False, order=['rex', 'intergenic'],
            boxprops=boxprops, medianprops=medianprops, width=1
        )

        # Set titles and labels
        ax1.set_title(f'Boxplot for {title}')
        ax1.set_xlabel('ChIP Category')
        ax1.set_ylabel('Average Methylation')
        ax2.set_xlabel('ChIP Category')
        ax2.set_ylabel('')

        # Rotate x-axis labels and decrease font size
        ax1.tick_params(axis='x', rotation=45, labelsize=8)
        ax2.tick_params(axis='x', rotation=45, labelsize=8)

        # Adjust y-axis label font size
        ax1.yaxis.label.set_fontsize(12)

        # Set title font size
        ax1.title.set_fontsize(16)

        # Remove background
        ax1.set_facecolor('none')
        ax2.set_facecolor('none')
        ax1.grid(False)
        ax2.grid(False)

        # Remove legends from individual subplots
        ax1.legend().remove()
        ax2.legend().remove()

        # SDC2 rows share the D1..Dn labels with SDC3; drop them so the counts
        # below describe only what is actually plotted.
        data = data[data['type'] != 'SDC2']
        # Add (n=) for number of datapoints on the x-axis
        for ax in [ax1, ax2]:
            for i, label in enumerate(ax.get_xticklabels()):
                category = label.get_text()
                count = data[(data['chip_category'] == category) & (data[bw_column].notna())].shape[0]
                ax.text(i, ax.get_ylim()[0], f'(n={count})', ha='center', va='top', fontsize=8)

        # Set y-axis limits to 0 to 0.5
        ax1.set_ylim(0, 0.5)
        ax2.set_ylim(0, 0.5)

        # Add buffer space by setting xlim for each subplot
        ax1.set_xlim(-buffer, len(category_order) - 1 + buffer)  # Adjust limits to add buffer space
        ax2.set_xlim(-buffer, 1 + buffer)  # Add buffer space for two categories: rex and intergenic

    # Generate one palette color per quantile bin
    palette_sdc = [cmap_sdc(i / (num_categories - 1)) for i in range(num_categories)]

    # Set up the subplots: 6 rows, two experiments per row, each experiment
    # using a wide axis (bins) followed by a narrow one (rex / intergenic)
    fig, axes = plt.subplots(
        6, 4, figsize=(18, 28),
        gridspec_kw={'width_ratios': [5, 1, 5, 1]}
    )

    # Plot for each average_bw column
    for i, (col, exp_name) in enumerate(
        zip([col for col in result_df.columns if col.startswith('average_')], experiment_names)
    ):
        plot_boxplots(
            result_df, col,
            axes[i // 2, 2 * (i % 2)],
            axes[i // 2, 2 * (i % 2) + 1],
            exp_name, palette_sdc
        )

    # Remove overall background
    fig.patch.set_facecolor('none')

    # Adjust layout
    plt.tight_layout()

    # Adjust subplots to add buffer space around the boxplots
    plt.subplots_adjust(left=0.05, right=0.95)

    # Save as PNG and SVG
    # NOTE(review): the file names say SDC2 although the SDC3 bins are plotted -- confirm intent.
    save_path = '/Data1/git/meyer-nanopore/scripts/analysis/temp_files/'
    os.makedirs(save_path, exist_ok=True)

    png_path = os.path.join(save_path, 'SDC2_boxplot_figure.png')
    svg_path = os.path.join(save_path, 'SDC2_boxplot_figure.svg')

    plt.savefig(png_path, format='png', dpi=300, bbox_inches='tight', transparent=True)
    plt.savefig(svg_path, format='svg', bbox_inches='tight', transparent=True)

    print(f"Figures saved as:\n{png_path}\n{svg_path}")

    # Show plot
    plt.show()

    return result_df

# Example usage: ten enrichment bins, no percentile filtering.
# Requires result_df and analysis_cond to be defined by earlier cells.
result_df_cat = filter_and_plot(
    result_df,
    num_categories=10,
    percentile_cutoff=None,
)

# Preview a few rows of the annotated table and list its columns
nanotools.display_sample_rows(result_df_cat, 5)
print(result_df_cat.columns)

# Number of regions per type
print(result_df_cat['type'].value_counts())


In [None]:
### Generate dataframe for plotting correlation between chip and accessibility
import pandas as pd
import pyBigWig
import numpy as np


# Hard-coded bigWig tracks (one per sample) from the 082624 temp_files folder.
# NOTE(review): this list is not referenced by the code below in this cell --
# calculate_bigwig_scores is called with `raw_bw_files` instead; confirm which
# set of tracks is intended.
qnormalized_bigwig_paths_smoothed = [
    "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/082624/BM_05_30_24-N2_old_SMACseq_R10_a_A0_norm.bw",
    "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/082624/BK_05_30_24-N2_young_SMACseq_R10_a_A0_norm.bw",
    "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/082624/AG-22_11_30_23-N2_old_fiber_R10_a_A0_norm.bw",
    "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/082624/AM_10_08_22-N2_mixed_endogenous_R10_a_A0_norm.bw",
    "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/082624/BN_05_24_24-96_old_DPY27degron_SMACseq_R10_a_A0_norm.bw"
]


def calculate_bigwig_scores(df, qnormalized_bigwig_paths):
    """Summarize bigWig signal over each region of ``df``.

    For the i-th bigWig file (1-based) the columns ``average_bw_i``,
    ``median_bw_i``, ``sum_bw_i`` and ``max_bw_i`` are added, computed from the
    non-missing values inside ``[start, end)`` of every row. Rows whose region
    cannot be read, or that contain no values, stay NaN. Finally
    ``norm_avg_bw_i`` is set to ``average_bw_i / average_bw_1 - 1``.

    Parameters:
        df (pd.DataFrame): Regions with ``chr``, ``start`` and ``end`` columns.
            The frame is modified in place.
        qnormalized_bigwig_paths (list[str]): bigWig files to read, in column order.

    Returns:
        pd.DataFrame: The same ``df`` object with the score columns added.
    """
    stats = ['average', 'median', 'sum', 'max']
    n_files = len(qnormalized_bigwig_paths)

    # Initialize new columns for each bigWig file
    for i in range(1, n_files + 1):
        for stat in stats:
            df[f'{stat}_bw_{i}'] = np.nan

    # Process each bigWig file
    for i, bw_path in enumerate(qnormalized_bigwig_paths, start=1):
        print(f"Processing bigWig file {i}/{n_files}: {bw_path}")
        # Collect the results positionally and assign whole columns once;
        # this avoids one DataFrame write per cell and does not depend on the
        # index labels being unique.
        scores = np.full((len(df), len(stats)), np.nan)
        with pyBigWig.open(bw_path) as bw:
            for pos, (chrom, start, end) in enumerate(zip(df['chr'], df['start'], df['end'])):
                if pd.isna(start) or pd.isna(end):
                    continue  # no usable coordinates for this row

                try:
                    # Coordinates may have been upcast to float/object when the
                    # region tables were concatenated; pyBigWig needs integers.
                    values = bw.values(chrom, int(start), int(end))
                except RuntimeError:
                    # This can happen if the region is not in the bigWig file
                    continue

                values = [v for v in values if v is not None and not np.isnan(v)]
                if values:
                    scores[pos] = (np.mean(values), np.median(values), np.sum(values), np.max(values))

        for k, stat in enumerate(stats):
            df[f'{stat}_bw_{i}'] = scores[:, k]

        print(f"Finished processing bigWig file {i}")

    # Calculate normalized average columns (relative change versus the first file)
    for i in range(1, n_files + 1):
        df[f'norm_avg_bw_{i}'] = df[f'average_bw_{i}'] / df['average_bw_1'] - 1

    return df


# ChIP peak table to load. Alternatives: SDC2_SDC3_20_peaks_500_2000_RPKM.csv,
# or SDC2_SDC3_gt20_rpkm.csv for whole regions.
bed_file_path = "/Data1/ext_data/qiming_2024/SDC2_SDC3_20_peaks_500_2000_RPKM.csv" # SDC2_SDC3_20_peaks_500_2000_RPKM.csv or SDC2_SDC3_gt20_rpkm.csv for whole regions

# Region length to keep (matched against the table's 'length' column)
center_length = 500 # 500, 1000, or 2000

# Read the peak table (CSV with a header row)
chip_bed = pd.read_csv(bed_file_path)

# Drop rows whose length is not equal to center_length
chip_bed = chip_bed[chip_bed['length'] == center_length]

# Recalculate start and end as abs_summit -/+ half of length
chip_bed['start'] = chip_bed['abs_summit'] - chip_bed['length'] // 2
chip_bed['end'] = chip_bed['abs_summit'] + chip_bed['length'] // 2

nanotools.display_sample_rows(chip_bed, 3)
# Print columns
print(chip_bed.columns)

### Add bed regions of interest (e.g. control regions)
# Create an empty dataframe with the same columns as chip_bed
new_rows = pd.DataFrame(columns=chip_bed.columns)

# Map the columns from combined_bed_df_ext (defined in another cell) to chip_bed's names
new_rows['type'] = combined_bed_df_ext['type']
new_rows['chr'] = combined_bed_df_ext['chrom']
new_rows['start'] = combined_bed_df_ext['bed_start']
new_rows['end'] = combined_bed_df_ext['bed_end']

# Calculate length and abs_summit (midpoint of the region, truncated to int)
new_rows['length'] = new_rows['end'] - new_rows['start']
new_rows['abs_summit'] = ((new_rows['start'] + new_rows['end']) / 2).astype(int)

# Append the new rows to chip_bed; all other chip_bed columns stay empty for them
chip_bed = pd.concat([chip_bed, new_rows], ignore_index=True)

# Sorting by chr and start position is currently disabled:
#chip_bed = chip_bed.sort_values(['chr', 'start']).reset_index(drop=True)

# Score every region against the bigWig files in raw_bw_files (defined in another cell)
result_df = calculate_bigwig_scores(chip_bed, raw_bw_files)

# Display the first few rows of the resulting dataframe
nanotools.display_sample_rows(result_df, 10)
# Print column names
print(result_df.columns)

# Display number of rows by type (shown as the notebook cell output)
result_df['type'].value_counts()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
import os

def filter_and_plot(result_df, type_to_plot='SDC3', qvalue_cutoff=None, percentile_cutoff=None, num_categories=4):
    """
    Filters the ChIP-seq region table and plots one type against 'intergenic'.

    One boxplot panel is drawn per ``average_*`` column, paired in order with the
    experiment names in the notebook global ``analysis_cond``. The figure is
    saved as PNG and SVG in the analysis ``temp_files`` directory and then shown.

    Parameters:
        - result_df (pd.DataFrame): Region table; it is not modified.
        - type_to_plot (str): The specific type ('SDC2' or 'SDC3') to be plotted against 'intergenic'.
        - qvalue_cutoff (float): Keep rows with ``LOG10(qvalue)`` >= this value.
        - percentile_cutoff (float): Within SDC2 / SDC3, keep rows at or above this
          quantile of the corresponding antimNeon column.
        - num_categories (int): Currently unused in this version; kept for
          interface compatibility.

    Returns:
        pd.DataFrame: Filtered copy with ``sdc2_norm`` and ``sdc3_norm`` columns added.
    """
    # Filter rows based on qvalue cutoff if provided
    if qvalue_cutoff is not None:
        result_df = result_df[result_df['LOG10(qvalue)'] >= qvalue_cutoff]

    # Work on a private copy: the column assignments below would otherwise
    # write into the caller's DataFrame (or into a filtered slice of it, which
    # triggers pandas' SettingWithCopyWarning).
    result_df = result_df.copy()

    # Split the type column by "_" and keep only the first element
    result_df['type'] = result_df['type'].str.split("_").str[0]
    # Rename 'all' to 'rex'
    result_df['type'] = result_df['type'].replace('all', 'rex')

    # Apply percentile cutoff for each type if provided
    if percentile_cutoff is not None:
        def filter_by_percentile(group):
            if group.name in ['SDC2', 'SDC3']:
                norm_col = f'mNeon{group.name}_rep1_antimNeon'
                threshold = group[norm_col].quantile(percentile_cutoff)
                return group[group[norm_col] >= threshold]
            else:
                return group  # No filtering for 'rex' and 'intergenic'

        result_df = result_df.groupby('type').apply(filter_by_percentile).reset_index(drop=True)

    # Create normalized columns
    result_df['sdc3_norm'] = result_df['mNeonSDC3_rep1_antimNeon'] / result_df['mNeonSDC3_rep1_IgG']
    result_df['sdc2_norm'] = result_df['mNeonSDC2_rep1_antimNeon'] / result_df['mNeonSDC2_rep1_IgG']

    # Experiment names used as subplot titles (notebook global)
    experiment_names = analysis_cond  # Replace with your actual variable

    # Colors for the selected type and for 'intergenic'
    color_sdc = '#1F78B4'  # Blue
    color_intergenic = '#E31A1C'  # Red

    def plot_boxplots(data, bw_column, ax, title, type_to_plot):
        """Draw the selected type next to 'intergenic' for one score column."""
        # Filter data to include only the selected type and 'intergenic'
        data = data[data['type'].isin([type_to_plot, 'intergenic'])]

        # Boxplot properties
        boxprops = dict(facecolor='none')
        medianprops = dict(color='red', linewidth=2)

        # Plot the selected type and 'intergenic'
        sns.boxplot(
            x='type', y=bw_column, data=data,
            ax=ax, palette=[color_sdc, color_intergenic], showfliers=False, order=[type_to_plot, 'intergenic'],
            boxprops=boxprops, medianprops=medianprops, width=0.6
        )

        # Set titles and labels
        ax.set_title(f'Boxplot for {title} ({type_to_plot} vs. intergenic)')
        ax.set_xlabel('Type')
        ax.set_ylabel('Average Methylation')

        # Rotate x-axis labels and decrease font size
        ax.tick_params(axis='x', rotation=45, labelsize=8)

        # Adjust y-axis label font size
        ax.yaxis.label.set_fontsize(12)

        # Set title font size
        ax.title.set_fontsize(16)

        # Remove background
        ax.set_facecolor('none')
        ax.grid(False)

        # Add (n=) for number of datapoints on the x-axis
        for i, label in enumerate(ax.get_xticklabels()):
            category = label.get_text()
            count = data[(data['type'] == category) & (data[bw_column].notna())].shape[0]
            ax.text(i, ax.get_ylim()[0], f'(n={count})', ha='center', va='top', fontsize=8)

        # Set y-axis limits to 0 to 0.4
        ax.set_ylim(0, 0.4)

    # One subplot row per experiment name
    num_conditions = len(experiment_names)

    # Set width and height for the entire plot
    fig_width = 2  # Fixed width for one-column layout
    fig_height = 4 * num_conditions  # Adjust height based on number of rows

    # Create a figure with a single column and the right number of rows
    fig, axes = plt.subplots(
        nrows=num_conditions, ncols=1,
        figsize=(fig_width, fig_height)  # Adjust height dynamically
    )

    # plt.subplots returns a bare Axes for a single row; wrap it so axes[i] works
    if num_conditions == 1:
        axes = [axes]

    # One panel per (average_* column, experiment name) pair
    for i, (col, exp_name) in enumerate(
        zip([col for col in result_df.columns if col.startswith('average_')], experiment_names)
    ):
        plot_boxplots(result_df, col, axes[i], exp_name, type_to_plot)

    # Remove overall background
    fig.patch.set_facecolor('none')

    # Adjust layout
    plt.tight_layout()

    # Adjust subplots to add buffer space around the boxplots
    plt.subplots_adjust(left=0.05, right=0.95)

    # Save as PNG and SVG
    save_path = '/Data1/git/meyer-nanopore/scripts/analysis/temp_files/'
    os.makedirs(save_path, exist_ok=True)

    png_path = os.path.join(save_path, f'{type_to_plot}_ChIP_boxplot_figure.png')
    svg_path = os.path.join(save_path, f'{type_to_plot}_ChIP_boxplot_figure.svg')

    plt.savefig(png_path, format='png', dpi=300, bbox_inches='tight', transparent=True)
    plt.savefig(svg_path, format='svg', bbox_inches='tight', transparent=True)

    print(f"Figures saved as:\n{png_path}\n{svg_path}")

    # Show plot
    plt.show()

    return result_df


# Example usage: compare SDC2 regions with intergenic regions.
# Requires result_df and analysis_cond to be defined by earlier cells.
result_df_cat = filter_and_plot(
    result_df,
    type_to_plot='SDC2',
    num_categories=10,
    percentile_cutoff=None,
)

# Preview a few rows of the returned table and list its columns
nanotools.display_sample_rows(result_df_cat, 5)
print(result_df_cat.columns)

# Number of regions per type
print(result_df_cat['type'].value_counts())


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
import os
import matplotlib.ticker as ticker
from scipy import stats

def filter_and_plot(result_df, qvalue_cutoff=None, percentile_cutoff=None, num_categories=4):
    """Scatter ChIP fold enrichment against average m6A for SDC2 and SDC3 peaks.

    One row of panels is drawn per experiment and one column per ChIP type
    (SDC2, SDC3). Points are coloured by the quantile bin of their fold
    enrichment; peaks overlapping a 'rex' interval are drawn in green. The
    figure is written as PNG and SVG under
    /Data1/git/meyer-nanopore/scripts/analysis/temp_files/ and then shown.

    Parameters
    ----------
    result_df : pandas.DataFrame
        Needs the columns 'chr', 'start', 'end', 'type', the four
        'mNeonSDC{2,3}_rep1_{antimNeon,IgG}' columns, one or more columns whose
        name starts with 'average_' and, when ``qvalue_cutoff`` is given,
        'LOG10(qvalue)'.
    qvalue_cutoff : float, optional
        Keep only rows with 'LOG10(qvalue)' >= this value.
    percentile_cutoff : float, optional
        Quantile (0-1) of the per-type fold enrichment below which peaks are
        dropped.
    num_categories : int
        Number of quantile bins (labelled D1..Dn) per ChIP type; must be >= 2
        because the colour ramp is spread over ``num_categories - 1`` steps.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``num_categories`` is smaller than 2.

    Notes
    -----
    When ``qvalue_cutoff`` is None the caller's ``result_df`` is modified in
    place: 'sdc2_norm' and 'sdc3_norm' are added and 'type' is reduced to its
    first "_"-separated token with 'all' renamed to 'rex'.
    NOTE(review): other cells may rely on that side effect, so it is kept.
    The 'average_*' columns are paired with ``experiment_names`` purely by
    position; any surplus on either side is silently ignored by ``zip``.
    """
    # Fail before touching result_df: with a single category the colour ramp
    # below would divide by zero.
    if num_categories < 2:
        raise ValueError("num_categories must be at least 2")

    # Filter rows based on qvalue cutoff if provided
    if qvalue_cutoff is not None:
        # .copy() so the column assignments below write to an independent
        # frame rather than to a slice of the caller's data.
        result_df = result_df[result_df['LOG10(qvalue)'] >= qvalue_cutoff].copy()

    # Create the new normalized columns (anti-mNeon signal over IgG control)
    result_df['sdc3_norm'] = result_df['mNeonSDC3_rep1_antimNeon'] / result_df['mNeonSDC3_rep1_IgG']
    result_df['sdc2_norm'] = result_df['mNeonSDC2_rep1_antimNeon'] / result_df['mNeonSDC2_rep1_IgG']

    # Split the type column by "_" and keep only the first element
    result_df['type'] = result_df['type'].str.split("_").str[0]
    result_df['type'] = result_df['type'].replace('all', 'rex')

    # Filter to keep only SDC2 and SDC3 types
    sdc_df = result_df[result_df['type'].isin(['SDC2', 'SDC3'])].copy()

    # Get 'rex' type rows
    rex_df = result_df[result_df['type'] == 'rex']

    # Pre-group the rex intervals by chromosome so each peak is compared
    # against arrays instead of iterating every rex row in Python.
    rex_bounds_by_chr = {
        chrom: (chrom_rex['start'].to_numpy(), chrom_rex['end'].to_numpy())
        for chrom, chrom_rex in rex_df.groupby('chr')
    }

    # Function to check if a row overlaps with any 'rex' row
    def overlaps_with_rex(row):
        bounds = rex_bounds_by_chr.get(row['chr'])
        if bounds is None:
            return False
        rex_starts, rex_ends = bounds
        # Half-open style test: touching intervals do not count as overlap
        return bool(np.any((row['start'] < rex_ends) & (row['end'] > rex_starts)))

    # Mark overlapping rows
    sdc_df['overlaps_rex'] = sdc_df.apply(overlaps_with_rex, axis=1)

    # Apply percentile cutoff for each type if provided
    if percentile_cutoff is not None:
        def filter_by_percentile(group):
            norm_col = f'{group.name.lower()}_norm'
            threshold = group[norm_col].quantile(percentile_cutoff)
            return group[group[norm_col] >= threshold]

        sdc_df = sdc_df.groupby('type').apply(filter_by_percentile).reset_index(drop=True)

    # Create categories for each type using the appropriate normalized column
    def categorize(group):
        norm_col = f'{group.name.lower()}_norm'
        return pd.qcut(group[norm_col], q=num_categories, labels=[f'D{i + 1}' for i in range(num_categories)])

    sdc_df['chip_category'] = sdc_df.groupby('type').apply(categorize).reset_index(level=0, drop=True)

    # Extract experiment names
    experiment_names = [
        'H1_021_SDC2-AIDpAux',
        'AH-09_08_19_23-SDC2_degron',
        'AG-22_11_30_23-N2_old',
        'BN_05_24_24-96_old_DPY27degron',
        'BK_05_30_24-N2_young',
        'BM_05_30_24-N2_old',
    ]

    # Create a custom RdBu (reversed) colormap
    colors = plt.cm.RdBu_r(np.linspace(0, 1, num_categories))
    cmap = LinearSegmentedColormap.from_list("custom_RdBu_r", colors)

    # One colour per category, shared by every panel and by the legend
    palette = [cmap(i / (num_categories - 1)) for i in range(num_categories)]

    # Color for overlapping points
    overlap_color = 'green'

    # Create a function to plot scatter for each type and average_bw
    def plot_scatter(data, bw_column, ax, title, type_name):
        # Determine which normalized column to use
        norm_column = f'{type_name.lower()}_norm'

        # Plot non-overlapping points
        sns.scatterplot(x=norm_column, y=bw_column, hue='chip_category',
                        data=data[~data['overlaps_rex']], ax=ax,
                        palette=palette, alpha=0.6, legend=False)

        # Plot overlapping points
        ax.scatter(data[data['overlaps_rex']][norm_column],
                   data[data['overlaps_rex']][bw_column],
                   c=overlap_color, alpha=0.6, label='Overlaps with "rex"')

        ax.set_title(f'{title} - {type_name}')
        ax.set_xlabel('Fold change (mNeon/IgG RPKM, log scale)')
        ax.set_ylabel('Average m6A Methylation')

        # Calculate R-squared value (fit is done on the untransformed values,
        # even though the x axis is displayed on a log scale)
        x = data[norm_column].astype(float)
        y = data[bw_column].astype(float)
        mask = ~np.isnan(x) & ~np.isnan(y)
        if mask.sum() > 1:  # Ensure we have at least two valid points
            slope, intercept, r_value, p_value, std_err = stats.linregress(x[mask], y[mask])
            r_squared = r_value**2
        else:
            r_squared = np.nan

        # Add number of datapoints and R-squared value
        count = data[data[bw_column].notna()].shape[0]
        overlap_count = data[data['overlaps_rex']].shape[0]
        ax.text(0.95, 0.15, f'n={count}\nOverlaps={overlap_count}\nR²={r_squared:.3f}',
                ha='right', va='bottom', transform=ax.transAxes)

        # Set axes to log scale
        ax.set_xscale('log')

        # Fix the visible x range to 1-100 fold enrichment
        ax.set_xlim(1, 100)

        # Add minor ticks
        ax.xaxis.set_minor_locator(ticker.LogLocator(subs=np.arange(2, 10)))
        ax.xaxis.set_minor_formatter(ticker.LogFormatter(minor_thresholds=(2, 0.4)))

        # Set major ticks format to non-scientific notation
        ax.xaxis.set_major_formatter(ticker.ScalarFormatter(useOffset=False))
        ax.xaxis.get_major_formatter().set_scientific(False)

        # Remove background
        ax.set_facecolor('none')
        ax.grid(True, which="both", ls="-", alpha=0.2)
        # increase font sizes
        ax.tick_params(axis='both', which='major', labelsize=12)
        ax.tick_params(axis='x', which='minor', labelsize=8, labelbottom=True)
        ax.tick_params(axis='x', which='major', labelsize=8, labelbottom=True)
        ax.title.set_fontsize(16)
        ax.xaxis.label.set_fontsize(14)
        ax.yaxis.label.set_fontsize(14)

    # Set up the subplots
    n_experiments = len(experiment_names)
    n_types = 2  # Only SDC2 and SDC3
    fig, axes = plt.subplots(n_experiments, n_types, figsize=(10, 5*n_experiments), squeeze=False)

    # Plot for each average_bw column and type
    average_columns = [col for col in sdc_df.columns if col.startswith('average_')]
    for i, (col, exp_name) in enumerate(zip(average_columns, experiment_names)):
        for j, type_name in enumerate(['SDC2', 'SDC3']):
            ax = axes[i, j]
            data_subset = sdc_df[sdc_df['type'] == type_name]
            plot_scatter(data_subset, col, ax, exp_name, type_name)

    # Create a custom legend
    legend_elements = [plt.Line2D([0], [0], marker='o', color='w', label=f'D{i+1}',
                                  markerfacecolor=palette[i], markersize=10)
                       for i in range(num_categories)]
    legend_elements.append(plt.Line2D([0], [0], marker='o', color='w', label='Overlaps with "rex"',
                                      markerfacecolor=overlap_color, markersize=10))

    # Add the legend to the figure
    fig.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(1, 0.5), title='ChIP decile')

    # Remove overall background
    fig.patch.set_facecolor('none')

    # Adjust layout
    plt.tight_layout()

    # Save as PNG and SVG
    save_path = '/Data1/git/meyer-nanopore/scripts/analysis/temp_files/'
    os.makedirs(save_path, exist_ok=True)

    png_path = os.path.join(save_path, 'scatterplot_figure_0to10.png')
    svg_path = os.path.join(save_path, 'scatterplot_figure_0to10.svg')

    plt.savefig(png_path, format='png', dpi=300, bbox_inches='tight', transparent=True)
    plt.savefig(svg_path, format='svg', bbox_inches='tight', transparent=True)

    print(f"Figures saved as:\n{png_path}\n{svg_path}")

    # Show plot
    plt.show()

# Example usage: split each ChIP type into deciles (num_categories=10).
# percentile_cutoff=0 uses each type's minimum normalized value as the
# threshold, so the only rows it removes are those whose normalized value is NaN.
filter_and_plot(result_df, percentile_cutoff=0, num_categories=10)

In [None]:
### Calculate bam summary statistics
importlib.reload(nanotools)  # pick up edits made to nanotools.py since it was first imported
# plotly express supplies the qualitative colour scale used by create_plots below
import plotly.express as px
force_replace = False  # when False, summary/coverage tables already on disk are loaded instead of recomputed
sampling_frac = 0.25 # fraction of bam to sample for summary statistics

summary_bam_df = pd.DataFrame()

### Define filename for summary table based on selected conditions
# We'll start by defining a function to encapsulate the task you want parallelized
def process_bam(args):
    """Summarise a single BAM file via ``nanotools.get_summary_from_bam``.

    ``args`` is one ``(bam, condition, threshold, exp_id)`` tuple so that the
    function can be handed directly to ``Pool.map``. ``sampling_frac`` and
    ``modkit_path`` are read from the notebook's global scope. Returns whatever
    ``nanotools.get_summary_from_bam`` returns.
    """
    bam_path, condition_name, mod_thresh, experiment_id = args
    print("starting on bam:", bam_path, " | condition:", condition_name, "| each_exp_id:", experiment_id, "with thresh:", mod_thresh)
    bam_summary = nanotools.get_summary_from_bam(
        sampling_frac,
        mod_thresh,
        modkit_path,
        bam_path,
        condition_name,
        experiment_id,
        thread_ct=100,
    )
    return bam_summary

# Define filename for summary table based on selected conditions.
# The name only encodes the first/last condition, the sampling fraction and the
# first threshold, so different condition lists sharing those values reuse the
# same cache file; set force_replace = True to rebuild.
summary_table_name = "temp_files/" + "_" + conditions[0] + conditions[-1] + "_" + str(sampling_frac) + "_thresh" + str(thresh_list[0]) + "_summary_table.csv"

# Check if summary table exists
if not force_replace and os.path.exists(summary_table_name):
    print("Summary table exists, importing...")
    # The file is tab-separated despite its .csv extension
    summary_bam_df = pd.read_csv(summary_table_name, sep="\t", header=0)
else:
    print("Summary table does not exist, creating...")

    # A single worker processes the BAMs one after another; the context manager
    # guarantees the pool is torn down even if process_bam raises.
    with multiprocessing.Pool(1) as pool:
        results = pool.map(process_bam, zip(new_bam_files, conditions, thresh_list, exp_ids))

    # Stack the per-BAM summaries into one frame with a fresh 0..n-1 index
    summary_bam_df = pd.concat(results, ignore_index=True)

    # Make sure the cache directory exists before writing into it
    os.makedirs(os.path.dirname(summary_table_name), exist_ok=True)

    # Save the dataframe to a (tab-separated) file
    summary_bam_df.to_csv(summary_table_name, sep="\t", header=True, index=False)

### Build (or reload) the per-experiment coverage / modification-fraction table
coverage_df_name = "".join([
    "temp_files/", "_", conditions[0], conditions[-1], "_", str(sampling_frac),
    "_thresh", str(thresh_list[0]), "_coverage_df.csv",
])

if os.path.exists(coverage_df_name) and not force_replace:
    # Reuse the cached, tab-separated table
    print("Coverage table exists, importing...")
    coverage_df = pd.read_csv(coverage_df_name, sep="\t", header=0)
else:
    nanotools.display_sample_rows(summary_bam_df,2)

    group_keys = ['exp_id', 'condition']

    def _pass_count_total(row_mask, total_name):
        """Sum 'pass_count' per (exp_id, condition) over the rows selected by row_mask."""
        totals = summary_bam_df.loc[row_mask].groupby(group_keys)['pass_count'].sum().reset_index()
        return totals.rename(columns={'pass_count': total_name})

    # Rows with code '-' hold the calls that are not counted as modified
    is_unmodified = summary_bam_df['code'] == '-'

    total_m6a = _pass_count_total(summary_bam_df['code'] == 'a', 'total_m6a')
    total_5mc = _pass_count_total(summary_bam_df['code'] == 'm', 'total_5mc')
    total_A = _pass_count_total((summary_bam_df['base'] == 'A') & is_unmodified, 'total_A')
    total_C = _pass_count_total((summary_bam_df['base'] == 'C') & is_unmodified, 'total_C')

    # Outer merges keep experiments that lack one of the counts; gaps become 0
    coverage_df_A = total_m6a.merge(total_A, on=group_keys, how='outer').fillna(0)
    coverage_df_C = total_5mc.merge(total_C, on=group_keys, how='outer').fillna(0)
    coverage_df = coverage_df_A.merge(coverage_df_C, on=group_keys, how='outer').fillna(0)

    # Approximate fold coverage: scale the sampled A calls back up by
    # 1/sampling_frac, divide by a rounded 100 Mb genome (ce genome size =
    # 100,272,763) and multiply by 4 since As are treated as 1/4 of the genome.
    coverage_df['coverage'] = ((coverage_df['total_A'] + coverage_df['total_m6a']) * (1/sampling_frac)) / 100000000 * 4

    coverage_df['total_A_m6a'] = coverage_df['total_A'] + coverage_df['total_m6a']
    coverage_df['total_C_5mc'] = coverage_df['total_C'] + coverage_df['total_5mc']

    # Fraction of calls that are modified, per base
    coverage_df['m6A_frac'] = coverage_df['total_m6a'] / (coverage_df['total_A_m6a'])
    coverage_df['5mC_frac'] = coverage_df['total_5mc'] / (coverage_df['total_C_5mc'])

    # Drop rows where exp_id == AD1-nb_06_13_23
    #coverage_df = coverage_df[coverage_df.exp_id != 'AD1-nb_06_13_23']

    nanotools.display_sample_rows(coverage_df, 5)
    # Cache the table for later runs
    coverage_df.to_csv(coverage_df_name, sep="\t", header=True, index=False)

nanotools.display_sample_rows(coverage_df, 5)

### Plot
def create_plots(coverage_df):
    """Build bar charts of %m6A, %5mC and coverage per condition.

    Parameters
    ----------
    coverage_df : pandas.DataFrame
        Needs the columns 'condition', 'm6A_frac', '5mC_frac' and 'coverage'.

    Returns
    -------
    tuple of plotly.graph_objects.Figure
        ``(combined, m6a, mc5, coverage)``: a 1x3 subplot figure followed by
        one stand-alone figure per metric.

    Notes
    -----
    Bar colours are assigned from the notebook-level ``conditions`` list so a
    condition keeps its colour across plots. The palette is cycled when there
    are more conditions than colours, and conditions that appear only in
    ``coverage_df`` (e.g. a table loaded from an older cache) are given the
    next colour instead of raising ``KeyError``.
    """
    # Define the desired order of conditions
    #condition_order = ['N2_mixed_endogenous_R10', 'N2_old_fiber_R10', 'N2_old_SMACseq_R10', 'N2_mixed_fiber_R10', 'N2_young_SMACseq_R10', '96_old_DPY27degron_SMACseq_R10','51_old_dpy21null_fiber_R10']

    # Filter and sort the dataframe
    #coverage_df = coverage_df[coverage_df['condition'].isin(condition_order)]
    #coverage_df['condition'] = pd.Categorical(coverage_df['condition'], categories=condition_order, ordered=True)
    #coverage_df = coverage_df.sort_values('condition')

    # Select a color scale
    color_scale = px.colors.qualitative.Vivid
    n_colors = len(color_scale)

    # Assign a color from the color scale to each condition (wrapping around
    # when the palette is exhausted)
    color_dict = {condition: color_scale[i % n_colors] for i, condition in enumerate(conditions)}
    for condition in coverage_df['condition'].unique():
        if condition not in color_dict:
            color_dict[condition] = color_scale[len(color_dict) % n_colors]

    bar_colors = [color_dict[condition] for condition in coverage_df['condition']]

    def _bar_trace(values, name):
        """One bar per row of coverage_df, labelled with the value rounded to 2 decimals."""
        return go.Bar(
            x=coverage_df['condition'],
            y=values,
            name=name,
            marker_color=bar_colors,
            text=values.round(2)
        )

    def _single_figure(trace, title, y_title):
        """Wrap one trace in its own 600 px wide figure."""
        single_fig = go.Figure(trace)
        single_fig.update_layout(
            template="plotly_white",
            title_text=title,
            xaxis_title='Condition',
            yaxis_title=y_title,
            xaxis_tickangle=45,
            width = 600
        )
        return single_fig

    # Fractions are converted to percentages for display
    m6a_trace = _bar_trace(coverage_df['m6A_frac'] * 100, '%m6A')
    mc5_trace = _bar_trace(coverage_df['5mC_frac'] * 100, '%5mC')
    coverage_trace = _bar_trace(coverage_df['coverage'], 'Coverage')

    # Create subplots
    fig = make_subplots(rows=1, cols=3, subplot_titles=("%m6A by Condition", "%5mC by Condition", "Coverage by Condition"))
    fig.add_trace(m6a_trace, row=1, col=1)
    fig.add_trace(mc5_trace, row=1, col=2)
    fig.add_trace(coverage_trace, row=1, col=3)

    # Update layout using plotly_white theme
    fig.update_layout(
        template="plotly_white",
        title_text='%m6A, %5mC, and Coverage by Condition',
        showlegend=True,
        height=500,
        width=1500,  # Increased width to accommodate the third subplot
    )

    # Update axes
    for i in range(1, 4):
        fig.update_xaxes(title_text='Condition', tickangle=45, row=1, col=i)

    fig.update_yaxes(title_text='Percentage (%)', row=1, col=1)
    fig.update_yaxes(title_text='Percentage (%)', row=1, col=2)
    fig.update_yaxes(title_text='Coverage', row=1, col=3)

    # Create individual figures for each plot
    m6a_fig = _single_figure(m6a_trace, '%m6A by Condition', 'Percentage (%)')
    mc5_fig = _single_figure(mc5_trace, '%5mC by Condition', 'Percentage (%)')
    coverage_fig = _single_figure(coverage_trace, 'Coverage by Condition', 'Coverage')

    return fig, m6a_fig, mc5_fig, coverage_fig

# Build all four figures from the coverage table
combined_plot, m6a_plot, mc5_plot, coverage_plot = create_plots(coverage_df)

# Display the three stand-alone plots (combined_plot is created but not shown here)
m6a_plot.show()
mc5_plot.show()
coverage_plot.show()

"""
fig.write_image("images_11_14_23/bulk_m6Afrac_n2_sdc2degron_0p1sample.svg")
fig.write_image("images_11_14_23/bulk_m6Afrac_n2_sdc2degron_0p1sample.png")
fig2.write_image("images_11_14_23/coverage_n2_sdc2degron_0p1sample.svg")
fig2.write_image("images_11_14_23/coverage_n2_sdc2degron_0p1sample.png")

# Function call example
### Calculate N50s SKIP, NOT NECESSARY FOR ANY FOLLOWING STEPS
n50_fig = nanotools.calculate_and_plot_n50(new_bam_files, conditions, exp_ids)
n50_fig.show(renderer='plotly_mimetype+notebook')
n50_fig.write_image("images_11_14_23/n50_fig_n2_sdc2degron_0p1sample.svg")
n50_fig.write_image("images_11_14_23/n50_fig_n2_sdc2degron_0p1sample.png")"""

In [None]:
import pandas as pd
import numpy as np

def create_combined_bed_df(result_df, type_selected, num_categories):
    """Build a BED-like lookup table with SDC2/SDC3 peaks split into quantile bins.

    Parameters
    ----------
    result_df : pandas.DataFrame
        Needs the columns 'chr', 'start', 'end', 'type' and the four
        'mNeonSDC{2,3}_rep1_{antimNeon,IgG}' columns. The columns 'sdc2_norm'
        and 'sdc3_norm' are added to it in place; its 'type' column is left
        untouched.
    type_selected : iterable of str or None
        Values of 'type' to keep; ``None`` keeps every row.
    num_categories : int
        Number of quantile bins per ChIP type.

    Returns
    -------
    pandas.DataFrame
        Columns 'chrom', 'bed_start', 'bed_end', 'bed_strand' (always '+'),
        'type' and 'chr_type' ('X' for CHROMOSOME_X, otherwise 'Autosome'),
        with a fresh 0..n-1 index. Rows of type 'SDC2'/'SDC3' are relabelled
        'SDC2_D1'..'SDC2_Dn' / 'SDC3_D1'..'SDC3_Dn' by quantile of their
        normalized signal; every other type keeps its name.
    """
    # Normalize the columns (anti-mNeon signal over IgG control)
    result_df['sdc2_norm'] = result_df['mNeonSDC2_rep1_antimNeon'] / result_df['mNeonSDC2_rep1_IgG']
    result_df['sdc3_norm'] = result_df['mNeonSDC3_rep1_antimNeon'] / result_df['mNeonSDC3_rep1_IgG']

    # Work on a copy in both branches: the 'type' column is rewritten below and
    # that must not leak back into result_df, otherwise a second call would see
    # decile labels instead of 'SDC2'/'SDC3'.
    if type_selected is None:
        filtered_df = result_df.copy()
    else:
        # keep all rows where type in type_selected
        filtered_df = result_df[result_df['type'].isin(type_selected)].copy()

    norm_column_by_type = {'SDC2': 'sdc2_norm', 'SDC3': 'sdc3_norm'}

    def categorize(type_name, group):
        """Return one label per row of ``group`` as an object array."""
        print("grouping ", type_name)
        norm_col = norm_column_by_type.get(type_name)
        if norm_col is None:
            # Types without a ChIP signal column keep their original name
            return np.full(len(group), type_name, dtype=object)
        # Show the bin edges used for this type
        print(group[norm_col].quantile(np.linspace(0, 1, num_categories + 1)))
        labels = [f'{type_name}_D{i + 1}' for i in range(num_categories)]
        binned = pd.qcut(group[norm_col], q=num_categories, labels=labels)
        return binned.astype(object).to_numpy()

    # Fill the labels group by group using row positions. This avoids
    # groupby.apply, whose result changes shape when only one type is present,
    # and does not depend on the index being unique. Rows whose type is NaN
    # belong to no group and stay NaN.
    new_types = np.full(len(filtered_df), np.nan, dtype=object)
    grouped = filtered_df.groupby('type')
    for type_name, group in grouped:
        new_types[grouped.indices[type_name]] = categorize(type_name, group)
    filtered_df['type'] = new_types

    # Create the combined_bed_df
    combined_bed_df = pd.DataFrame({
        'chrom': filtered_df['chr'],
        'bed_start': filtered_df['start'],
        'bed_end': filtered_df['end'],
        'bed_strand': '+',  # Assuming all strands are positive
        'type': filtered_df['type'],
        'chr_type': np.where(filtered_df['chr'] == 'CHROMOSOME_X', 'X', 'Autosome')
    })

    # Reset the index
    combined_bed_df = combined_bed_df.reset_index(drop=True)

    return combined_bed_df

# Example usage:
# Assuming result_df is already defined
# NOTE: both type_selected assignments below are commented out, so the call
# uses whatever type_selected already holds in the notebook session.
#type_selected = ['SDC2'] #,'SDC3','intergenic_control','all']
#type_selected = None
num_categories = 10

combined_bed_df = create_combined_bed_df(result_df, type_selected, num_categories)
print(combined_bed_df)

# print count by type
print(combined_bed_df['type'].value_counts())


In [None]:
# to use mex combined bed dfs: combined_bed_df_mex_clust or combined_bed_df_mex_cat
#combined_bed_df = combined_bed_df_mex_cat.copy()

# keep only type that contain MEX_D10
#combined_bed_df = combined_bed_df[combined_bed_df['type'].str.contains('MEX_D10')]

# Test to keep only strand == "+"
#combined_bed_df = combined_bed_df[combined_bed_df['bed_strand'] == '+']

import pandas as pd
import tempfile
import os


def create_modkit_bed_df(filtered_df):
    """Turn a region table into a headerless, strand-duplicated BED6 frame.

    Parameters
    ----------
    filtered_df : pandas.DataFrame
        Needs the columns 'chrom', 'bed_start', 'bed_end' and 'bed_strand'.

    Returns
    -------
    pandas.DataFrame
        Integer-named columns 0-5 (chrom, start, end, '.', '.', strand). Every
        input region appears twice in consecutive rows: first with its own
        strand, then with the opposite one. Start and end are cast to int.
    """
    as_given = pd.DataFrame({
        0: filtered_df['chrom'],
        1: filtered_df['bed_start'],
        2: filtered_df['bed_end'],
        3: '.',
        4: '.',
        5: filtered_df['bed_strand']
        #'+' for mex combined bed dfs use filtered_df['bed_strand'] otherwise set strand to '+'
    })

    # Same regions on the opposite strand ('+' becomes '-', anything else '+')
    flipped = as_given.copy()
    flipped[5] = np.where(flipped[5] == '+', '-', '+')

    # Interleave the two copies. A stable sort is required so that each region
    # is always followed by its flipped partner; the default quicksort gives no
    # such guarantee for equal index labels.
    modkit_bed_df = (
        pd.concat([as_given, flipped])
        .sort_index(kind='mergesort')
        .reset_index(drop=True)
    )

    # convert bed_start and bed_end to ints
    modkit_bed_df[1] = modkit_bed_df[1].astype(int)
    modkit_bed_df[2] = modkit_bed_df[2].astype(int)

    return modkit_bed_df

def save_modkit_bed_to_temp(modkit_bed_df, filename,
                            temp_dir="/Data1/git/meyer-nanopore/scripts/analysis/temp_files/"):
    """Write a BED frame as a headerless, tab-separated file and return its path.

    Parameters
    ----------
    modkit_bed_df : pandas.DataFrame
        Frame to write; neither the header nor the index is written.
    filename : str
        File name inside ``temp_dir``.
    temp_dir : str, optional
        Target directory; created if it does not exist. Defaults to the
        notebook's temp_files directory.

    Returns
    -------
    str
        Full path of the written file.
    """
    # Make sure the target directory exists before writing into it
    os.makedirs(temp_dir, exist_ok=True)

    # Create the full path for the output file
    temp_file_path = os.path.join(temp_dir, filename)

    # Save the dataframe to the file
    modkit_bed_df.to_csv(temp_file_path, sep='\t', header=False, index=False)

    print(f"Modkit BED file saved to: {temp_file_path}")
    return temp_file_path

# Create modkit_bed_df
modkit_bed_df = create_modkit_bed_df(combined_bed_df)

# Save modkit_bed_df to a temporary file
modkit_bed_name = "modkit_temp.bed"
temp_file_path = save_modkit_bed_to_temp(modkit_bed_df, modkit_bed_name)

# save_modkit_bed_to_temp already prints this message, so it appears twice
print(f"Modkit BED file saved to: {temp_file_path}")

# drop rows with duplicate values
# NOTE(review): this runs after the file was written, so the BED file on disk
# still contains any duplicate rows; only the in-memory frame is de-duplicated.
modkit_bed_df = modkit_bed_df.drop_duplicates()

print(modkit_bed_df)

In [None]:
# Directory prefix for the modkit pileup outputs and log files written by later cells
output_stem = "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/082624/modkit/"

In [None]:
# Set type_selected to unique type in combined_bed_df
type_selected = combined_bed_df["type"].unique()
print(type_selected)

regenerate_bit = True # Set to True to force regeneration; when False an existing output file is left in place.

### Generate modkit pileup file, used for plotting m6A/A in a given region.
# Build one output path per BAM from condition, rounded threshold, sample index,
# BAM fraction and bed_window.
# NOTE(review): the two "_".join(...) terms iterate over the literal ["SDC2"],
# so every name contains "SDC2SDC2" regardless of type_selected; outputs made
# for different region sets therefore share the same file names.
out_file_names = [output_stem + "modkit-pileup-" + each_condition +"_"+ str(round(each_thresh,2))+"_"+str(each_index)+"_"+str(each_bamfrac)+ "_".join([each_type[-5:] for each_type in ["SDC2"]]) + "_".join([each_type[0:5] for each_type in ["SDC2"]]) + str(bed_window)+".bed" for each_condition,each_thresh,each_index, each_bamfrac in zip(conditions,thresh_list,sample_indices,bam_fracs)]

# Function to run a single command
def modkit_pileup_extract(args, index):
    """Run ``modkit pileup`` for one BAM file, restricted to a BED file of regions.

    Parameters
    ----------
    args : tuple
        ``(bam, threshold, condition, sample_index, bam_fraction, types,
        modkit_path, output_stem, bed_path)``. ``bam_fraction`` and ``types``
        are accepted for call compatibility but not used here.
    index : int
        Position of this task; selects the output path from the notebook-level
        ``out_file_names`` list.

    Returns
    -------
    None

    Notes
    -----
    Reads the notebook-level ``regenerate_bit``: when it is False and the
    output file already exists, nothing is run. An existing but empty or
    unreadable file is reported and left as it is.
    NOTE(review): regenerating such a file may be the intended behaviour -
    confirm before changing.
    A non-zero modkit exit status is reported but does not raise, so one
    failing BAM does not abort the whole pool.
    """
    (each_bam, each_thresh, each_condition, each_index, _each_bamfrac, _each_type,
     modkit_path, output_stem, modkit_bed_name) = args

    # Use the index to get the correct file name from out_file_names
    each_output = out_file_names[index]

    # Skip the run when an output already exists and regeneration is not forced
    if not regenerate_bit and os.path.exists(each_output):
        print(f"File already exists: {each_output}")
        # Peek at the first rows only to report an unusable file
        try:
            pd.read_csv(each_output, sep="\t", header=None, nrows=10)
        except pd.errors.EmptyDataError:
            print(f"File is empty: {each_output}")
        except Exception as err:  # report instead of failing the whole pool
            print(f"Could not read existing file {each_output}: {err}")
        return

    print(f"Starting on: {each_output}", "with bam file: ", each_bam,"and bedfile:", modkit_bed_name)
    command = [
        modkit_path,
        "pileup",
        "--only-tabs",
        #"--ignore",
        #"m",
        "--threads",
        "15",
        "--mod-thresholds",
        f"a:{each_thresh}",
        "--mod-thresholds",
        f"m:{each_thresh}",
        "--ref",
        "/Data1/reference/c_elegans.WS235.genomic.fa",
        "--filter-threshold",
        f"A:{1-each_thresh}",
        "--filter-threshold",
        f"C:{1-each_thresh}",
        "--motif",
        "GC",
        "1",
        "--motif",
        "A",
        "0",
        "--log-filepath",
        output_stem + each_condition + str(each_index) + "_modkit-pileup.log",
        "--include-bed",
        modkit_bed_name,
        each_bam,
        each_output
    ]
    completed = subprocess.run(command, text=True)
    if completed.returncode != 0:
        print(f"modkit pileup failed (exit code {completed.returncode}) for: {each_output}")

# Build one (args, index) pair per BAM file. The index is what
# modkit_pileup_extract uses to look up its output path in out_file_names, so
# the zipped lists must be in the same order as the lists used to build it.
task_args_with_index = [(args, index) for index, args in enumerate(zip(
    new_bam_files,
    thresh_list,
    conditions,
    sample_indices,
    bam_fracs,
    [type_selected]*len(new_bam_files),
    [modkit_path]*len(new_bam_files),
    [output_stem]*len(new_bam_files),
    [temp_file_path]*len(new_bam_files),  # unpacked as modkit_bed_name in the worker
))]

# Execute commands in parallel, unpacking the arguments and index within the map call.
# Up to 10 modkit runs are active at once and each is started with --threads 15.
with Pool(
    processes=10
) as pool:
    pool.starmap(modkit_pileup_extract, task_args_with_index)


In [None]:

### Add bed and condition details to modkit output for plotting
# Using DataFrame merging to achieve the task without explicit loops
## Looks up the bed_start and bed_end values for each row in bedmethyl_df

def add_bed_columns_no_loops(bedmethyl_df_loc, combined_bed_df):
    """Annotate bedMethyl rows with the BED region whose midpoint is nearest.

    Each position is matched, per chromosome, to the region with the closest
    midpoint and kept only if it lies within that region's
    [bed_start, bed_end] (both ends inclusive). Because only the single
    nearest midpoint is considered, a position inside a long region can be
    dropped when a different region's midpoint is closer.

    Parameters
    ----------
    bedmethyl_df_loc : pandas.DataFrame
        Needs 'chrom' and 'start_position'; 'start_position' is cast to int
        in place.
    combined_bed_df : pandas.DataFrame
        Needs 'chrom', 'bed_start', 'bed_end', 'bed_strand', 'type' and
        'chr_type'. A 'midpoint' column is added to it in place.
        NOTE(review): that side effect is kept in case other cells read it -
        confirm.

    Returns
    -------
    pandas.DataFrame
        The rows of ``bedmethyl_df_loc`` that fall inside a region, in their
        original order, with the region columns appended. Rows without a
        region (or whose region has a NaN 'type') are removed.
    """
    # Calculate midpoint in combined_bed_df
    combined_bed_df['midpoint'] = (combined_bed_df['bed_start'] + combined_bed_df['bed_end']) / 2
    # Convert midpoint to the same type as start_position (int, in this case)
    combined_bed_df['midpoint'] = combined_bed_df['midpoint'].astype(int)

    # merge_asof needs both sides sorted by their key
    combined_bed_df = combined_bed_df.sort_values(by='midpoint')

    # Ensure that start_position is of type int (if it's not already)
    bedmethyl_df_loc['start_position'] = bedmethyl_df_loc['start_position'].astype(int)

    # Merge bedmethyl_df with combined_bed_df based on the nearest midpoint
    merged_df = pd.merge_asof(bedmethyl_df_loc.sort_values('start_position'),
                              combined_bed_df,
                              by='chrom',
                              left_on='start_position',
                              right_on='midpoint',
                              direction='nearest')

    # Filter out rows where the start_position is not within the bed_start and bed_end range
    merged_df = merged_df.loc[(merged_df['start_position'] >= merged_df['bed_start']) &
                                (merged_df['start_position'] <= merged_df['bed_end'])]

    # One lookup row per position: without this, k input rows sharing the same
    # (chrom, start_position) would come back as k*k rows from the merge below.
    bed_lookup = merged_df[
        ['chrom', 'start_position', 'bed_start', 'bed_end', 'bed_strand', 'type', 'chr_type']
    ].drop_duplicates(subset=['chrom', 'start_position'])

    # Attach the region columns to the original rows, preserving their order
    final_df = pd.merge(bedmethyl_df_loc,
                        bed_lookup,
                        on=['chrom', 'start_position'],
                        how='left')

    # Drop all final_df rows where type == NaN
    final_df = final_df[final_df['type'].notna()]

    return final_df

def add_bed_columns_no_loops_small(bedmethyl_df_loc, combined_bed_df):
    """Annotate bedMethyl rows with every BED region that contains them.

    Pairs each position with all regions on the same chromosome and keeps the
    pairs where bed_start <= start_position <= bed_end, so a position inside
    several regions is returned once per region. The intermediate table holds
    every position/region pair per chromosome.

    ``bedmethyl_df_loc['start_position']`` is cast to int in place. Returns the
    matching pairs with all columns of both inputs; pairs whose 'type' is NaN
    are removed.
    """
    # Cast the caller's column to int before joining
    bedmethyl_df_loc['start_position'] = bedmethyl_df_loc['start_position'].astype(int)

    # All position/region combinations that share a chromosome
    candidate_pairs = bedmethyl_df_loc.merge(combined_bed_df, how='inner', on='chrom')

    # Keep positions lying inside the region (both ends inclusive)
    inside_region = candidate_pairs['start_position'].between(
        candidate_pairs['bed_start'], candidate_pairs['bed_end']
    )
    hits = candidate_pairs.loc[inside_region].reset_index(drop=True)

    # Discard regions without a type label
    return hits[hits['type'].notna()]


# combined_bed_df = nanotools.create_lookup_bed(new_bed_files)

# Initialize comb_bedmethyl_df; it stays an empty frame if no pileup file yields data
comb_bedmethyl_df = pd.DataFrame()

# Column names of the headerless modkit pileup (bedMethyl) output
bedmethyl_cols = ['chrom','start_position','end_position','modified_base_code','score','strand','start_position_compat','end_position_compat','color','Nvalid_cov','fraction_modified','Nmod','Ncanonical','Nother_mod','Ndelete','Nfail','Ndiff','Nnocall']


def process_group(group_tuple, num_bins, edge_window_size, sum_columns):
    """Assign a relative position ('rel_start') to every row of one region.

    ``group_tuple`` is a ``(key, frame)`` pair as produced by iterating a
    groupby. Positions in the leading flank (before
    bed_start + edge_window_size) get ``start_position - bed_start -
    edge_window_size``; positions in the trailing flank (after
    bed_end - edge_window_size) get ``num_bins + edge_window_size -
    (bed_end - start_position)``. Interior positions get the placeholder
    100000, which is replaced by a bin number only when the observed span of
    the group exceeds ``num_bins + 2 * edge_window_size``. ``sum_columns`` is
    accepted for call compatibility and not used.
    """
    _, group = group_tuple
    # Work on a copy so the assignments below do not write into a groupby slice
    group = group.copy()
    min_pos, max_pos = group['start_position'].min(), group['start_position'].max()
    bed_start, bed_end = group['bed_start'].iloc[0], group['bed_end'].iloc[0]

    group['rel_start'] = np.where(
        group['start_position'] < bed_start + edge_window_size,
        group['start_position'] - bed_start - edge_window_size,
        np.where(group['start_position'] > bed_end - edge_window_size,
                 num_bins + edge_window_size - (bed_end - group['start_position']),
                 100000))  # placeholder for interior positions

    # Bin the interior only when the region is wider than the bins plus both flanks
    if (max_pos - min_pos) > (num_bins + 2 * edge_window_size):
        binning_mask = (group['start_position'] >= bed_start + edge_window_size) & (group['start_position'] <= bed_end - edge_window_size)
        bin_edges = np.linspace(bed_start + edge_window_size, bed_end - edge_window_size, num_bins + 1)
        group.loc[binning_mask, 'rel_start'] = np.digitize(group.loc[binning_mask, 'start_position'], bins=bin_edges, right=True)

    return group


def map_to_metagene_bins_and_sum(df, num_bins=1000, edge_window_size=500):
    """Map positions to metagene bins per region and sum the count columns per bin."""
    # Columns for summing within bins
    sum_columns = ['Nmod', 'Ncanonical', 'Nother_mod', 'Ndelete', 'Nfail', 'Ndiff', 'Nnocall','Nvalid_cov']
    # Columns to retain in the final DataFrame
    retain_columns = ['bed_strand', 'chr_type', 'strand', 'bed_end','type']
    # One group per region and modification code
    group_columns = ['bed_start', 'chrom' ,'modified_base_code']

    # Splitting the DataFrame into groups
    groups = list(df.groupby(group_columns))

    # Process groups in parallel; never start more workers than there are
    # cores or groups (the pool size does not affect the result).
    n_workers = max(1, min(cpu_count(), len(groups)))
    with Pool(n_workers) as pool:
        processed_groups = pool.starmap(process_group, [(group, num_bins, edge_window_size, sum_columns) for group in groups])

    # Combine the processed groups into a single DataFrame
    result_df = pd.concat(processed_groups, ignore_index=True)

    # Summing within bins and merging
    sum_group_columns = group_columns + ['rel_start']
    summed_df = result_df.groupby(sum_group_columns)[sum_columns].sum()
    merged_df = pd.merge(result_df[sum_group_columns + retain_columns].drop_duplicates(), summed_df, on=sum_group_columns, how='left')

    return merged_df


# Create combined plotting dataframe, one pileup file per condition/experiment
for each_output,each_condition,each_exp_id in zip(out_file_names,conditions,exp_ids):
    bedmethyl_df = pd.read_csv(each_output, sep="\t", header=None, names=bedmethyl_cols)

    # keep only the m6A rows ("a,A,0"); add 'm,GC,1' to the list to include 5mC at GC
    bedmethyl_df = bedmethyl_df[bedmethyl_df['modified_base_code'].isin(['a,A,0'])]#,'m,GC,1'])]
    if bedmethyl_df.empty:
        print("!Read in empty csv!!")
        print("Tried to select:",each_output," ",each_condition," ",each_exp_id, "and failed...")
        continue

    # sort by start_position, drop incomplete and duplicate rows, renumber
    bedmethyl_df = (
        bedmethyl_df.sort_values(['start_position'], ascending=[True])
        .dropna()
        .drop_duplicates()
        .reset_index(drop=True)
    )

    # Attach the region (bed_start, bed_end, type, ...) each position belongs to
    bedmethyl_df = add_bed_columns_no_loops(bedmethyl_df, combined_bed_df)

    # Nothing fell inside a region: report and skip before any binning, which
    # cannot handle an empty frame
    if bedmethyl_df.empty:
        print("!Bedmethyl_df is empty!")
        print("Tried to select:",each_output," ",each_condition," ",each_exp_id, "and failed...")
        continue

    # if type_selected contains 'gene', map to metagene bins
    if 'gene' in type_selected[0] or 'damID' in type_selected[0]:
        print("Mapping to metagene bins...")
        bedmethyl_df = map_to_metagene_bins_and_sum(bedmethyl_df, num_bins=num_bins, edge_window_size=bed_window)

    else:
        # Position relative to the region, shifted by the flank width
        bedmethyl_df['rel_start'] = bedmethyl_df['start_position'] - bedmethyl_df['bed_start'] - bed_window + 2

        # Placeholder for a MEX-specific adjustment; currently leaves rel_start unchanged
        if "MEX_" in type_selected[0]:
            bedmethyl_df['rel_start'] = bedmethyl_df['rel_start']  #-1 # *-1

    # set rel_start to int
    bedmethyl_df['rel_start'] = bedmethyl_df['rel_start'].astype(int)

    bedmethyl_df['condition'] = each_condition
    bedmethyl_df['exp_id'] = each_exp_id

    # Accumulate; pd.concat replaces DataFrame.append, which was removed in pandas 2.0
    if comb_bedmethyl_df.empty:
        print("comb_bedmethyl_plot_df is empty, setting it equal to bedmethyl_plot...")
        comb_bedmethyl_df = bedmethyl_df
    else:
        print("comb_bedmethyl_plot_df is not empty, appending bedmethyl_plot...")
        comb_bedmethyl_df = pd.concat([comb_bedmethyl_df, bedmethyl_df])

# Renumber the rows, since the per-file frames were stacked with their own indexes
comb_bedmethyl_df.reset_index(inplace=True, drop=True)

#print("head")
#display(comb_bedmethyl_df.head(100))
print("sample")
# NOTE(review): display() is given the return value of display_sample_rows - confirm that is intended
display(nanotools.display_sample_rows(comb_bedmethyl_df,10))
#print("tail")
#display(comb_bedmethyl_df.tail(100))

# show the first 10 rows of columns 15-29 (iloc slice 15:30)
print("head")
display(comb_bedmethyl_df.iloc[:,15:30].head(10))

# print unique type counts
print("unique type counts")
print(comb_bedmethyl_df['type'].value_counts())

In [None]:
import pandas as pd
import pyranges as pr

# Load FIMO files (tab-separated motif hit tables; file names suggest one per motif)
fimo_files = [
    "/Data1/ext_data/motifs/fimo_MEX_0.01.tsv",
    "/Data1/ext_data/motifs/fimo_MEXII_0.01.tsv",
    "/Data1/ext_data/motifs/fimo_motifc_0.01.tsv"
]

# Hits are kept only when ln(p-value), rounded to one decimal, is <= this cutoff
# (more negative = stricter). Single place to adjust the threshold.
fimo_ln_p_value_cutoff = -8

dfs = []
for file in fimo_files:
    print(f"Loading {file}")
    df = pd.read_csv(file, sep='\t')
    print(f"Loaded {len(df)} rows from {file}")
    dfs.append(df)

# Combine dataframes
fimo_df = pd.concat(dfs, ignore_index=True)
print(f"Combined dataframe has {len(fimo_df)} rows")

# keep 200 random rows
#fimo_df = fimo_df.sample(n=200, random_state=1)

# Convert 'chrI' to 'CHROMOSOME_I' in 'sequence_name'.
# regex=False: 'chr' is a literal prefix, independent of the pandas default.
print("Converting 'sequence_name' to 'CHROMOSOME_*' format")
fimo_df['sequence_name'] = fimo_df['sequence_name'].str.replace('chr', 'CHROMOSOME_', regex=False)

# Keep only rows where ln(p-value) <= fimo_ln_p_value_cutoff.
# The comparison uses the rounded value, so e.g. -7.96 rounds to -8.0 and passes.
fimo_df['ln_p_value'] = np.log(fimo_df['p-value']).round(1)
fimo_df = fimo_df[fimo_df['ln_p_value'] <= fimo_ln_p_value_cutoff]
print(f"Filtered dataframe has {len(fimo_df)} rows")

# Select relevant columns
print("Selecting relevant columns")
fimo_df = fimo_df[['sequence_name', 'start', 'stop', 'strand', 'score', 'ln_p_value', 'motif_id']]
fimo_df = fimo_df.rename(columns={'sequence_name': 'chr'})

# Rank motifs so that, among overlapping hits, the lowest number wins.
# motif_id values missing from this mapping get NaN priority.
print("Assigning motif priorities")
motif_priority = {'MEXII': 1, 'MEX': 2, 'motifC': 3}
fimo_df = fimo_df.assign(motif_priority=fimo_df['motif_id'].map(motif_priority))

# Order hits by chromosome, start coordinate, then priority
print("Sorting dataframe")
fimo_df = fimo_df.sort_values(by=['chr', 'start', 'motif_priority'])

# Collapse overlapping hits to a single representative per overlap cluster
print("Deduplicating overlapping intervals using pyranges")

# PyRanges needs Chromosome/Start/End columns, with integer coordinates
fimo_df = fimo_df.rename(
    columns={'chr': 'Chromosome', 'start': 'Start', 'stop': 'End'}
).astype({'Start': int, 'End': int})

# Label every hit with the id of the overlap cluster it belongs to
pr_df = pr.PyRanges(fimo_df)
clusters = pr_df.cluster()

# Back to pandas, ordered so the best-priority hit comes first inside each cluster
clusters_df = clusters.df.sort_values(['Cluster', 'motif_priority'])

# One row per cluster: the first (lowest motif_priority) hit, with the
# original chr/start/stop column names restored
fimo_dedup_df = clusters_df.drop_duplicates(subset=['Cluster'], keep='first').rename(
    columns={'Chromosome': 'chr', 'Start': 'start', 'End': 'stop'}
)

print(f"Deduplicated dataframe has {len(fimo_dedup_df)} rows")

# Rename columns for PyRanges compatibility (copies; the source frames keep their names)
print("Renaming columns for PyRanges compatibility")
fimo_df_renamed = fimo_dedup_df.rename(columns={'chr': 'Chromosome', 'start': 'Start', 'stop': 'End'})
comb_bedmethyl_df_renamed = comb_bedmethyl_df.rename(columns={'chrom': 'Chromosome', 'bed_start': 'Start', 'bed_end': 'End'})

# Convert DataFrames to PyRanges objects
print("Converting DataFrames to PyRanges objects")
fimo_pr = pr.PyRanges(fimo_df_renamed)
comb_bedmethyl_df_pr = pr.PyRanges(comb_bedmethyl_df_renamed)

# Perform the join to find overlapping intervals; motif-side columns that
# clash with region-side names receive the "_fimo" suffix (e.g. Start_fimo)
print("Performing join to find overlapping intervals")
overlap_result = comb_bedmethyl_df_pr.join(fimo_pr, suffix="_fimo")

# Collect, per region (Chromosome, Start, End), the (motif start, motif id, ln p-value)
# triples of every overlapping motif hit.
# observed=True: if Chromosome is categorical, only combinations that actually
# occur are grouped (and the pandas default-change warning is avoided).
print("Aggregating overlapping intervals")
overlap_df = overlap_result.df
agg_df = overlap_df.groupby(['Chromosome', 'Start', 'End'], observed=True).apply(
    lambda x: list(zip(x['Start_fimo'], x['motif_id'], x['ln_p_value']))
).reset_index(name='motif_info')

# Remove duplicate triples. sorted() gives a reproducible order (by motif start
# first); a bare set() iterates in a hash-dependent order that can differ
# between interpreter sessions.
print("Removing duplicate motif info")
agg_df['motif_info'] = agg_df['motif_info'].apply(lambda x: sorted(set(x)))

# Separate 'motif_info' into parallel 'motif_start', 'motif_id' and 'motif_score'
# lists ('motif_score' holds the ln p-value, the third element of each triple)
print("Separating 'motif_info' into 'motif_start' and 'motif_id'")
agg_df['motif_start'] = agg_df['motif_info'].apply(lambda x: [item[0] for item in x])
agg_df['motif_id'] = agg_df['motif_info'].apply(lambda x: [item[1] for item in x])
agg_df['motif_score'] = agg_df['motif_info'].apply(lambda x: [item[2] for item in x])

# Compute 'motif_rel_start': motif start relative to the region, using the same
# "- bed_window + 2" offset as the rel_start computation for methylation positions.
# Built as an explicit object Series so it also works when agg_df has no rows.
print("Computing 'motif_rel_start'")
agg_df['motif_rel_start'] = pd.Series(
    [
        [motif_start - region_start - bed_window + 2 for motif_start in motif_starts]
        for region_start, motif_starts in zip(agg_df['Start'], agg_df['motif_start'])
    ],
    index=agg_df.index,
    dtype=object,
)

# PyRanges column names -> the names used in comb_bedmethyl_df
pyranges_to_bed_names = {
    'Chromosome': 'chrom',
    'Start': 'bed_start',
    'End': 'bed_end',
}

# NOTE(review): the PyRanges-style rename earlier in this cell was stored in
# comb_bedmethyl_df_renamed, so this rename of comb_bedmethyl_df looks like a
# no-op unless the frame already carries Chromosome/Start/End columns - confirm.
print("Renaming columns back to original names in comb_bedmethyl_df")
comb_bedmethyl_df = comb_bedmethyl_df.rename(columns=pyranges_to_bed_names)

# Same renaming for the per-region motif table
print("Renaming columns back to original names")
agg_df = agg_df.rename(columns=pyranges_to_bed_names)

# Attach the per-region motif lists to every row of that region (left join,
# so regions without motif hits are kept with missing values)
print("Merging aggregated data back into comb_bedmethyl_df")
region_keys = ['chrom', 'bed_start', 'bed_end']
comb_bedmethyl_df = comb_bedmethyl_df.merge(
    agg_df[region_keys + ['motif_rel_start', 'motif_id', 'motif_score']],
    how='left',
    on=region_keys,
)

print(comb_bedmethyl_df.columns)

# Lists are unhashable; turn them into tuples so the columns can serve as group keys
print("Converting lists to tuples")
for motif_col in ('motif_rel_start', 'motif_id', 'motif_score'):
    comb_bedmethyl_df[motif_col] = comb_bedmethyl_df[motif_col].apply(
        lambda cell: tuple(cell) if isinstance(cell, list) else cell
    )

# Missing motif info becomes an empty tuple.
# NOTE(review): 'motif_score' is not filled here, unlike the other two columns,
# so it keeps NaN for regions without motif hits - confirm this is intended.
print("Replacing NaN with empty tuples")
for motif_col in ('motif_rel_start', 'motif_id'):
    comb_bedmethyl_df[motif_col] = comb_bedmethyl_df[motif_col].apply(
        lambda cell: tuple() if pd.isna(cell) else cell
    )

# Normalise element types inside the tuples: int positions, float scores
for motif_col, element_type in (('motif_rel_start', int), ('motif_score', float)):
    comb_bedmethyl_df[motif_col] = comb_bedmethyl_df[motif_col].apply(
        lambda cell, element_type=element_type: (
            tuple(element_type(item) for item in cell) if isinstance(cell, tuple) else cell
        )
    )

# Debugging statements
print("Final DataFrame:")
print(comb_bedmethyl_df.head())

print("Unique type counts:")
print(comb_bedmethyl_df['type'].value_counts())

print("Rows containing 'rex' in 'type':")
rex_rows = comb_bedmethyl_df['type'].str.contains('rex')
print(comb_bedmethyl_df[rex_rows])


In [None]:
### SHIFT AND TRANSFORM (OPTIONAL)
# Configuration flags for this cell.
# NOTE(review): align_zero_bool and flip_bool are not read anywhere in the
# visible part of this cell - confirm whether they are still used elsewhere.
align_zero_bool = False
flip_bool = False
# When True, the motif columns (motif_rel_start, motif_id, motif_score) are kept
# as grouping keys below and motif positions are flipped for '-' strand rows.
plot_motifs = True

# Set normalization_type parameter here:
#   'global' -> normalise by the per-exp_id m6A_frac taken from coverage_df
#   'local'  -> normalise by the mod fraction of rows whose type contains "intergenic"
# Any other value raises ValueError further down.
normalization_type = 'global'  # or 'local'

def compute_lag_for_maximum_alignment(series1, bed_start1):
    """Return the shift that would move the maximum of series1 to its midpoint.

    Args:
        series1: Sequence or array of values; its arg-max is the peak position.
        bed_start1: Accepted for call compatibility; not used.

    Returns:
        Tuple ``(lag, flip)``: ``lag`` is ``round(len(series1) / 2)`` minus the
        index of the maximum (Python's round, i.e. halves go to the even
        integer); ``flip`` is always 0.
    """
    midpoint = round(len(series1) / 2)
    peak_index = np.argmax(series1)
    return (midpoint - peak_index, 0)

def get_continuous_series(df_subset):
    """Return 'weighted_norm_mod_frac' as an array over consecutive rel_start values.

    Missing values are first filled forward, then backward. The series is then
    re-indexed over every integer from the smallest to the largest rel_start;
    positions that were absent from the input are set to 0.

    Args:
        df_subset: DataFrame with 'rel_start' and 'weighted_norm_mod_frac' columns.

    Returns:
        numpy array of the filled values. If re-indexing is not possible
        (duplicate rel_start values, or an empty / non-numeric index) a
        diagnostic is printed and the filled but un-reindexed values are returned.
    """
    # Create a Series with rel_start as the index and use 'weighted_norm_mod_frac' as the value
    series_filled = df_subset.set_index('rel_start')['weighted_norm_mod_frac']
    # ffill()/bfill() replace the deprecated fillna(method=...) spelling
    series_filled = series_filled.ffill().bfill()

    try:
        first_pos = int(series_filled.index.min())
        last_pos = int(series_filled.index.max())
        series_filled = series_filled.reindex(range(first_pos, last_pos + 1), fill_value=0)
    except (ValueError, TypeError):
        # ValueError: duplicate labels or NaN bounds; TypeError: non-numeric bounds.
        # Best effort: report and fall through with the un-reindexed series.
        print("Failed series_filled:", series_filled)
        print("Duplicate indexes:",series_filled.index[series_filled.index.duplicated()])

    return series_filled.values

def align_profiles(df):
    """Annotate each bed_start group with the lag that centres its peak.

    The frame is sorted by ('bed_start', 'rel_start') and copied. For every
    distinct bed_start, the group's continuous 'weighted_norm_mod_frac' profile
    is built with get_continuous_series and passed to
    compute_lag_for_maximum_alignment; the resulting lag is stored in a 'shift'
    column and the flip indicator in a 'flipped' column (0/1).

    The earlier reference-profile computation (highest summed 'Nvalid_cov') was
    never read by the rest of the function and has been removed; as a result
    'Nvalid_cov' is no longer required here.

    Args:
        df: DataFrame with 'bed_start', 'rel_start' and
            'weighted_norm_mod_frac' columns.

    Returns:
        Sorted copy of df with 'flipped' and 'shift' columns added. A summary
        of flips and the lag distribution is printed.
    """
    df = df.sort_values(['bed_start', 'rel_start']).copy()
    bed_starts = df['bed_start'].unique()

    df["flipped"] = 0

    for other_bed_start in bed_starts:
        # Build the row selector once per group instead of three times
        group_rows = df['bed_start'] == other_bed_start
        series_to_shift = get_continuous_series(df[group_rows])
        lag, flip = compute_lag_for_maximum_alignment(series_to_shift, other_bed_start)

        df.loc[group_rows, 'shift'] = lag
        df.loc[group_rows, 'flipped'] = 1 if flip else 0

    total_flipped = df[df['flipped'] == 1]['bed_start'].nunique()
    lag_distribution = df['shift'].describe()
    print(f"Total bed_starts flipped: {total_flipped} out of {len(bed_starts) - 1}")
    print("Lag Distribution:")
    print(lag_distribution)

    return df

# Work on a copy so the adjustments below do not modify comb_bedmethyl_df.
# NOTE(review): the message mentions dropping rows, but no rows are dropped here.
print("Copying and dropping rows...")
comb_bedmethyl_plot_df = comb_bedmethyl_df.copy()

# NOTE(review): the return value is discarded here (an earlier cell wraps the
# same helper in display()); confirm the helper displays its sample itself.
nanotools.display_sample_rows(comb_bedmethyl_plot_df,5)

def debug_print(message, df=None, group_cols=None):
    """Print a DEBUG-prefixed message, optionally followed by a DataFrame summary.

    Args:
        message: Text printed after the "DEBUG: " prefix.
        df: Optional DataFrame; when given, its shape, column list and head()
            are printed.
        group_cols: Optional grouping column(s); when truthy (and df is given),
            the row count per group is printed before the head.

    Returns:
        None. Output always ends with a print of a newline character.
    """
    print(f"DEBUG: {message}")
    if df is None:
        # Nothing to summarise: just emit the trailing separator
        print("\n")
        return

    print(f"Shape: {df.shape}")
    print(f"Columns: {df.columns.tolist()}")
    if group_cols:
        group_sizes = df.groupby(group_cols).size()
        print(group_sizes)
    print(df.head())
    print("\n")

# Types whose profiles depend on strand orientation (substring match on the type name)
strand_sensitive_tags = ["TSS", "TES", "MEX", "MEXII", "gene"]

for each_type in type_selected:
    if not any(tag in each_type for tag in strand_sensitive_tags):
        continue

    print(f"Strand orientation sensitive {each_type} type selected, multiplying rel_start by -1 for '-' strand genes...")
    is_gene_type = 'gene' in each_type
    if is_gene_type:
        # Centre the metagene bins on 0 before mirroring (applied to every row;
        # undone for every row at the end of this iteration)
        comb_bedmethyl_plot_df['rel_start'] -= num_bins/2

    # Rows of this type that lie on the '-' strand
    mask = (comb_bedmethyl_plot_df['type'] == each_type) & (comb_bedmethyl_plot_df['bed_strand'] == '-')
    # Mirror the position and add the fixed offset of 13
    # NOTE(review): the meaning of 13 is not stated in the code - confirm.
    comb_bedmethyl_plot_df.loc[mask, 'rel_start'] = (
        comb_bedmethyl_plot_df.loc[mask, 'rel_start'] * -1 + 13
    )

    if plot_motifs:
        print("adjusting motif start...")
        # Same mirror + offset, applied to every position inside each motif tuple
        mirrored_motifs = comb_bedmethyl_plot_df.loc[mask, 'motif_rel_start'].apply(
            lambda motif_tuple: tuple(int(x) * -1 + 13 for x in motif_tuple)
        )
        comb_bedmethyl_plot_df.loc[mask, 'motif_rel_start'] = mirrored_motifs

    debug_print("After gene strand adjustment", comb_bedmethyl_plot_df, ['type'])

    if is_gene_type:
        print("Adjusting rel_start for gene type...")
        comb_bedmethyl_plot_df['rel_start'] += num_bins/2
        debug_print("After gene bin adjustment", comb_bedmethyl_plot_df, ['type'])

# Count columns are summed over rows sharing the same grouping keys
count_sums = {
    'Nvalid_cov': 'sum',
    'Nmod': 'sum',
    'Ncanonical': 'sum',
    'Nother_mod': 'sum'
}

if plot_motifs:
    # Motif columns are part of the key, so motif annotations survive the aggregation.
    # NOTE(review): pandas groupby excludes rows whose key is NaN by default; if
    # motif_score is still NaN for regions without motif hits, those rows are
    # not part of the result - confirm this is intended.
    print("Grouping by chrom, rel_start, exp_id, modified_base_code, condition, type, chr_type, strand, motif_rel_start, motif_id...")
    grouping_keys = [
        'chrom', 'rel_start', 'motif_rel_start', 'motif_id','motif_score',
        'exp_id', 'modified_base_code', 'condition',
        'type', 'chr_type', 'strand','bed_start','bed_strand'
    ]
else:
    print("Grouping by chrom, rel_start, exp_id, modified_base_code, condition, type, chr_type...")
    grouping_keys = [
        'chrom', 'rel_start', 'exp_id', 'modified_base_code', 'condition',
        'type', 'chr_type', 'strand'
    ]

grouped_df = comb_bedmethyl_plot_df.groupby(grouping_keys).agg(count_sums).reset_index()

print("Calculating normalized m6A...")
# Fraction of modified calls among modified + canonical calls
mod_plus_canonical = grouped_df['Nmod'] + grouped_df['Ncanonical']
grouped_df['raw_mod_frac'] = grouped_df['Nmod'] / mod_plus_canonical

# Strip surrounding whitespace from exp_id in both tables before joining on it
for exp_table in (coverage_df, grouped_df):
    exp_table['exp_id'] = exp_table['exp_id'].str.strip()

# Bring in the per-experiment m6A fraction as 'exp_id_m6A_frac'.
# NOTE(review): a left join multiplies rows if coverage_df holds more than one
# row per exp_id - confirm exp_id is unique there.
merged_df = grouped_df.merge(
    coverage_df[['exp_id', 'm6A_frac']], on=['exp_id'], how='left'
).rename(columns={'m6A_frac': 'exp_id_m6A_frac'})

if normalization_type == 'global':
    # Global normalization: denominator is the experiment-wide m6A fraction
    merged_df['norm_mod_frac_init'] = merged_df['exp_id_m6A_frac']
elif normalization_type == 'local':
    # Local normalization: denominator is the pooled mod fraction of the rows
    # whose type contains "intergenic" (case-insensitive), per exp_id
    intergenic_df = grouped_df[grouped_df['type'].str.contains("intergenic", case=False)]
    local_mod_frac = intergenic_df.groupby('exp_id')[['Nmod', 'Ncanonical']].sum().reset_index()
    local_mod_frac['mod_frac_local'] = local_mod_frac['Nmod'] / (
        local_mod_frac['Nmod'] + local_mod_frac['Ncanonical']
    )

    merged_df = merged_df.merge(
        local_mod_frac[['exp_id', 'mod_frac_local']], on='exp_id', how='left'
    )
    # Experiments without intergenic rows fall back to a denominator of 1
    merged_df['mod_frac_local'] = merged_df['mod_frac_local'].fillna(1)
    merged_df['norm_mod_frac_init'] = merged_df['mod_frac_local']
else:
    raise ValueError("Invalid normalization_type. Choose 'global' or 'local'.")

# Keep the grouped columns plus the normalisation denominator
plot_df = merged_df[list(grouped_df.columns) + ['norm_mod_frac_init']]

# Second aggregation: pool counts over chrom / exp_id / bed_start, keeping the
# motif columns as keys only when motifs are plotted
motif_keys = ['motif_rel_start', 'motif_id', 'motif_score'] if plot_motifs else []
plot_group_keys = (
    ['rel_start', 'modified_base_code']
    + motif_keys
    + ['condition', 'type', 'chr_type', 'strand', 'norm_mod_frac_init']
)
pooled_counts = plot_df.groupby(plot_group_keys)[['Nvalid_cov', 'Ncanonical', 'Nmod']].sum()
plot_df = pooled_counts.reset_index()

# Recompute the raw fraction from pooled counts, then divide by the
# normalisation denominator
plot_df['raw_mod_frac'] = plot_df['Nmod'] / (plot_df['Nmod'] + plot_df['Ncanonical'])
plot_df['weighted_norm_mod_frac'] = plot_df['raw_mod_frac'] / plot_df['norm_mod_frac_init']

# Order rows by position and renumber them
plot_df = plot_df.sort_values(['rel_start']).reset_index(drop=True)

# Row-count sanity checks before and after pooling
print("Count of rows by type in merged_df:")
print(merged_df.groupby(['type']).size())
print("Count of rows by type in plot_df:")
print(plot_df.groupby(['type']).size())

print("unique type counts")
print(plot_df['type'].value_counts())

print("plot_df:")
nanotools.display_sample_rows(plot_df,10)


In [None]:
# When True, cached CSVs are ignored and overwritten with the in-memory frames.
force_replace = True

# Make sure the cache directory exists so to_csv does not fail on a fresh checkout.
os.makedirs("temp_files", exist_ok=True)

# Cache file names encode the configuration: the first five selected types, the
# first threshold (2 decimals), the first BAM fraction and the bed window.
final_fn_config = "_".join(type_selected[:5]) + str(round(thresh_list[0],2)) + "_"+str(bam_fracs[0])+str(bed_window)+".csv"
final_fn = "temp_files/" + "final_df_" + final_fn_config
final_fn_chip = "temp_files/" + "final_df_chip" + final_fn_config

# plot_df: load from cache when allowed and present, otherwise (re)write the cache.
# Note: this branch is also taken when the file exists but force_replace is True.
if not force_replace and os.path.exists(final_fn):
    print("final_df already exists, importing it...")
    plot_df = pd.read_csv(final_fn)
    nanotools.display_sample_rows(plot_df,5)
else:
    print("final_df does not exist, saving it...")
    plot_df.to_csv(final_fn, index=False)

# Same for the optional ChIP/bigwig frame. Only a missing plot_comb_bigwig_df
# variable (NameError) is treated as "skip"; genuine read/write errors now
# surface instead of being reported as a missing frame.
try:
    if not force_replace and os.path.exists(final_fn_chip):
        print("final_df_chip already exists, importing it...")
        plot_comb_bigwig_df = pd.read_csv(final_fn_chip)
        nanotools.display_sample_rows(plot_comb_bigwig_df,5)
    else:
        print("final_df_chip does not exist, saving it...")
        plot_comb_bigwig_df.to_csv(final_fn_chip, index=False)
except NameError:
    print("plot_comb_bigwig_df does not exist, skipping...")

In [None]:
importlib.reload(nanotools)
from scipy.signal import gaussian
import scipy.ndimage

import importlib
import numpy as np
import pandas as pd
from scipy.signal import gaussian
import scipy.ndimage
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def plot_bedmethyl(comb_bedmethyl_df, conditions_input, chr_types=None, types=None, strands=["all"], window_size=50, metagene_bins=1000, smoothing_type="weighted", selection_indices=None, bed_window=[-500,500], mod_types=['5mC','m6A'], ignore_selec=[], bigwig_df=None, bw_selections=None, plot_motifs=False, plot_type="raw"):
    global analysis_cond
    fig = make_subplots(specs=[[{"secondary_y": True}]])
    cov_fig = make_subplots(specs=[[{"secondary_y": True}]])
    y_min = float('inf')
    y_max = float('-inf')
    line_counter = -1

    # Drop all rows with rel_start outside of bed_window
    comb_bedmethyl_df = comb_bedmethyl_df[(comb_bedmethyl_df['rel_start'] >= bed_window[0]) & (comb_bedmethyl_df['rel_start'] <= bed_window[1])]

    motifs_plotted = False

    if selection_indices is not None:
        conditions = [conditions_input[i] for i in selection_indices]
    else:
        conditions = conditions_input

    for selected_condition in conditions:
        for selected_modification in (mod_types or ["all"]):
            skip_line = False
            if selected_modification == '5mC':
                selected_mod = 'm,GC,1'
            elif selected_modification == 'm6A':
                selected_mod = 'a,A,0'
            else:
                selected_mod = 'all'

            for each in ignore_selec:
                if selected_condition == conditions_input[each[0]] and selected_modification == each[1]:
                    skip_line = True
                    break

            if skip_line:
                continue

            # Handle chr_types == ["all"] by combining 'Autosome' and 'X' types if needed
            if chr_types == ["all"]:
                selected_chr_types = ["all_combined"]
                comb_df = comb_bedmethyl_df.copy()
                comb_df['chr_type'] = 'all_combined'
                if plot_motifs:
                    group_fields = ['rel_start', 'motif_rel_start','motif_id','motif_score','condition', 'type', 'strand', 'modified_base_code']
                else:
                    group_fields = ['rel_start','condition', 'type', 'strand', 'modified_base_code']
                sum_fields = ['raw_mod_frac', 'weighted_norm_mod_frac', 'Nvalid_cov','Nmod']
                comb_df = comb_df.groupby(group_fields, as_index=False)[sum_fields].sum()
            else:
                selected_chr_types = chr_types or ["all"]
                comb_df = comb_bedmethyl_df

            for selected_chr_type in selected_chr_types:
                for selected_type in (types or ["all"]):
                    for selected_strand in (strands or ["all"]):
                        if selected_strand == "all":
                            selected_strand_types = ["all_combined"]
                            temp_df = comb_df.copy()
                            temp_df['strand'] = "all_combined"
                            if plot_motifs:
                                group_fields = ['rel_start', 'motif_rel_start','motif_id','motif_score','condition', 'type','chr_type', 'modified_base_code']
                            else:
                                group_fields = ['rel_start','condition', 'type','chr_type', 'modified_base_code']
                            sum_fields = ['raw_mod_frac', 'weighted_norm_mod_frac', 'Nvalid_cov','Nmod']
                            temp_df = temp_df.groupby(group_fields, as_index=False)[sum_fields].sum()
                            df_to_use = temp_df
                        else:
                            selected_strand_types = strands or ["all"]
                            df_to_use = comb_df

                        # Apply filters
                        filters = []
                        if selected_condition:
                            filters.append(df_to_use['condition'] == selected_condition)
                        if selected_chr_type != "all":
                            filters.append(df_to_use['chr_type'] == selected_chr_type)
                        if selected_type != "all":
                            filters.append(df_to_use['type'] == selected_type)
                        if selected_strand != "all":
                            filters.append(df_to_use['strand'] == selected_strand)
                        if selected_mod != "all":
                            filters.append(df_to_use['modified_base_code'] == selected_mod)
                        if not filters:
                            base_filter = np.full(len(df_to_use), True)
                        else:
                            base_filter = np.logical_and.reduce(filters)

                        data_filtered = df_to_use.loc[base_filter][['raw_mod_frac', 'weighted_norm_mod_frac', 'rel_start', 'motif_rel_start', 'motif_id', 'motif_score', 'Nvalid_cov', 'Nmod']].copy()
                        if data_filtered.empty:
                            continue
                        data_filtered_nonan = data_filtered.dropna(subset=['motif_rel_start', 'motif_id', 'motif_score'])
                        if not data_filtered_nonan.empty:
                            import ast

                            def ensure_list(x):
                                if isinstance(x, str):
                                    return ast.literal_eval(x)
                                elif isinstance(x, tuple):
                                    return list(x)
                                elif isinstance(x, list):
                                    return x
                                else:
                                    return [x]
                            if not motifs_plotted:
                                data_filtered_nonan['motif_rel_start'] = data_filtered_nonan['motif_rel_start'].apply(ensure_list)
                                data_filtered_nonan['motif_id'] = data_filtered_nonan['motif_id'].apply(ensure_list)
                                data_filtered_nonan['motif_score'] = data_filtered_nonan['motif_score'].apply(ensure_list)
                                data_filtered_nonan['motif_info'] = data_filtered_nonan.apply(
                                    lambda row: list(zip(row['motif_rel_start'], row['motif_id'], row['motif_score'])),
                                    axis=1
                                )
                                exploded_df = data_filtered_nonan.explode('motif_info')
                                exploded_df[['motif_rel_start', 'motif_id', 'motif_score']] = pd.DataFrame(
                                    exploded_df['motif_info'].tolist(), index=exploded_df.index
                                )
                                exploded_df = exploded_df[['motif_rel_start', 'motif_id', 'motif_score']].drop_duplicates()
    
                                # Group by position and collect all motifs and scores
                                grouped_motifs = exploded_df.groupby('motif_rel_start').agg({
                                    'motif_id': list,
                                    'motif_score': list
                                }).reset_index()
    
                                # Function to calculate vertical position for staggering
                                def get_y_position(idx):
                                    # Base step size
                                    step = -0.025
                                    # Stagger every other label
                                    return 1 + (idx % 3) * step
                                
                                id_rep = 0
                                # Add vertical dashed lines and staggered labels
                                for _, row in grouped_motifs.iterrows():
                                    motif_rel_start = row['motif_rel_start']
                                    motif_ids = row['motif_id']
                                    motif_scores = row['motif_score']
    
                                    # Add vertical dashed line
                                    fig.add_shape(
                                        type="line",
                                        x0=motif_rel_start, x1=motif_rel_start,
                                        y0=0, y1=1,
                                        line=dict(color="grey", width=1, dash="dash"),
                                        xref="x", yref="paper"
                                    )
    
                                    # Add labels with scores
                                    for idx, (motif_id, score) in enumerate(zip(motif_ids, motif_scores)):
                                        y_pos = get_y_position(id_rep)
                                        label_text = f"{motif_id} ln({score:.1f})"
                                        
                                        #print("idx:",idx,"y_pos:",y_pos,"label_text:",label_text)
                                        
                                        fig.add_annotation(
                                            x=motif_rel_start,
                                            y=y_pos,
                                            yref="paper",
                                            text=label_text,
                                            showarrow=False,
                                            #yanchor="top",
                                            #xanchor="center",
                                            font=dict(size=12, color="white"),
                                            bgcolor="rgba(0,0,0,0)"
                                        )
                                        id_rep += 1
                        motifs_plotted = True
                        
                        if data_filtered.empty or pd.isna(data_filtered['rel_start'].min()) or pd.isna(data_filtered['rel_start'].max()):
                            continue

                        full_range_df = pd.DataFrame({'rel_start': range(int(data_filtered['rel_start'].min()), int(data_filtered['rel_start'].max() + 1))})
                        merged_df = pd.merge(full_range_df, data_filtered, on='rel_start', how='left')

                        # Update raw_mod_frac to be equal to Nmod / Nvalid_cov
                        merged_df['raw_mod_frac'] = merged_df['Nmod'] / merged_df['Nvalid_cov']

                        if smoothing_type != "weighted":
                            merged_df.fillna({'weighted_norm_mod_frac': 0, 'Nvalid_cov': 0,'Nmod':0,'raw_mod_frac':0}, inplace=True)
                        else:
                            merged_df.dropna(subset=['weighted_norm_mod_frac', 'Nvalid_cov','Nmod','raw_mod_frac'], inplace=True)
                            merged_df.reset_index(drop=True, inplace=True)

                        # Allow selecting raw or normalized data
                        if plot_type == "raw":
                            m6A_data = merged_df['raw_mod_frac']
                        else:
                            m6A_data = merged_df['weighted_norm_mod_frac']

                        m6A_data_xaxis = merged_df['rel_start']
                        Nvalid_cov_data = merged_df['Nvalid_cov']

                        # Smooth coverage data
                        if smoothing_type == "none":
                            smoothed_cov_data = Nvalid_cov_data
                        else:
                            smoothed_cov_data = Nvalid_cov_data.rolling(window=window_size, center=True).mean()

                        # Apply smoothing to m6A_data
                        if smoothing_type == "weighted":
                            def weighted_rolling_average(values, weights, window_size):
                                def calculate_weighted_avg(window):
                                    return (window * weights[window.index]).sum() / weights[window.index].sum()
                                return values.rolling(window=window_size, center=True).apply(calculate_weighted_avg, raw=False)
                            smoothed_data = weighted_rolling_average(m6A_data, Nvalid_cov_data, window_size)

                        elif smoothing_type == "gaussian":
                            smoothed_data_array = scipy.ndimage.gaussian_filter1d(m6A_data, sigma=window_size)
                            smoothed_data = pd.Series(smoothed_data_array)

                        elif smoothing_type == "exponential":
                            def exponential_decay_smoothing(x, alpha=0.1):
                                sm = np.zeros_like(x)
                                sm[0] = x[0]
                                for t in range(1, len(x)):
                                    sm[t] = alpha * x[t] + (1 - alpha) * sm[t-1]
                                return sm

                            def symmetrical_exponential_smoothing(x, alpha=0.1):
                                forward_smoothed = exponential_decay_smoothing(x, alpha)
                                backward_smoothed = exponential_decay_smoothing(x[::-1], alpha)[::-1]
                                return (forward_smoothed + backward_smoothed) / 2

                            alpha = 0.05
                            smoothed_data_array = symmetrical_exponential_smoothing(m6A_data, alpha=alpha)
                            smoothed_data = pd.Series(smoothed_data_array)

                        elif smoothing_type == "lowess":
                            from statsmodels.nonparametric.smoothers_lowess import lowess
                            smoothed = lowess(m6A_data, m6A_data_xaxis, frac=0.05, it=0)
                            m6A_data_xaxis_array, smoothed_data_array = smoothed[:, 0], smoothed[:, 1]
                            smoothed_data = pd.Series(smoothed_data_array, index=m6A_data_xaxis_array)
                            m6A_data_xaxis = pd.Series(m6A_data_xaxis_array, index=m6A_data_xaxis_array)

                        elif smoothing_type == "none":
                            smoothed_data = m6A_data
                        else:
                            smoothed_data = m6A_data.rolling(window=window_size, center=True).mean()

                        y_min = min(y_min, smoothed_data.min())
                        y_max = max(y_max, smoothed_data.max())

                        label = f"{selected_condition}_{selected_chr_type}_{selected_type}_{selected_strand}_{selected_modification}"

                        color = nanotools.get_colors(selected_condition)
                        fig.add_trace(
                            go.Scatter(
                                x=m6A_data_xaxis.values,
                                y=smoothed_data.values,
                                mode='lines',
                                name=label,
                                opacity=0.9,
                                line=dict(width=3, color=color)
                            ),
                            secondary_y=False
                        )

                        color = nanotools.get_colors(selected_condition)
                        cov_fig.add_trace(
                            go.Scatter(
                                x=m6A_data_xaxis.values,
                                y=smoothed_cov_data.values,
                                mode='lines',
                                name=label + "_Nvalid_cov",
                                opacity=0.9,
                                line=dict(width=3, color=color)
                            ),
                            secondary_y=False
                        )

    if bigwig_df is not None:
        bigwig_df = bigwig_df[(bigwig_df['rel_start'] >= bed_window[0]) & (bigwig_df['rel_start'] <= bed_window[1])]
        if bw_selections is not None:
            for bw_selection in bw_selections:
                for selected_chr_type in (chr_types or ["all"]):
                    for selected_type in (types or ["all"]):
                        for selected_strand in (strands or ["all"]):
                            filters = []
                            if bw_selection:
                                filters.append(bigwig_df['condition'] == bw_selection)
                            if selected_chr_type != "all":
                                filters.append(bigwig_df['chr_type'] == selected_chr_type)
                            if selected_type != "all":
                                filters.append(bigwig_df['type'] == selected_type)
                            if selected_strand != "all":
                                filters.append(bigwig_df['strand'] == selected_strand)

                            if filters:
                                base_filter = np.logical_and.reduce(filters)
                            else:
                                base_filter = np.full(len(bigwig_df), True)

                            value_data = bigwig_df.loc[base_filter]['value']
                            value_data_xaxis = bigwig_df.loc[base_filter]['rel_start']

                            if smoothing_type == "none":
                                smoothed_data = value_data
                            else:
                                smoothed_data = value_data.rolling(window=window_size, center=True).mean()

                            y_min = min(y_min, smoothed_data.min())
                            y_max = max(y_max, smoothed_data.max())

                            label = f"{bw_selection}_{selected_chr_type}_{selected_type}_{selected_strand}"

                            fig.add_trace(
                                go.Scatter(
                                    x=value_data_xaxis.values,
                                    y=smoothed_data.values,
                                    mode='lines',
                                    name=label,
                                    opacity=0.9,
                                    line=dict(width=3)
                                ),
                                secondary_y=True,
                            )

    # Add a white border shape
    border_shape = dict(
        type="rect",
        x0=0, y0=0, x1=0.95, y1=1,
        xref="paper", yref="paper",
        line=dict(color="white", width=2, dash='solid'),
        fillcolor='rgba(0,0,0,0)',
    )

    # Title
    plot_title = "_".join([each_type for each_type in (types or [])])

    # Update main figure layout
    fig.update_xaxes(showgrid=False,showline=False)
             
    fig.update_yaxes(showgrid=False,showline=False)
    # Add border after all motif lines and annotations are added
    fig.add_shape(border_shape)
    
    fig.update_layout(
        title=plot_title,
        xaxis_title='Genomic Position',
        template="plotly_white",
        title_font=dict(size=24),
        width=900,
        height=900,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        font=dict(color='white')
    )
    fig.update_yaxes(title_text="modBase/Base", secondary_y=False)
    #fig.update_yaxes(title_text="ChIP enrichment", secondary_y=True)
    fig.update_yaxes(tickformat=".0%")
    fig.update_layout(legend=dict(
        traceorder="normal",
        y=-0.2,
        x=0.25,
        yanchor="top",
        orientation='h',
        font=dict(size=14),
    ))
    fig.update_xaxes(range=[bed_window[0], bed_window[1]])

    # Adjust tick fonts and axes lines
    fig.update_xaxes(
        tickfont=dict(size=24),
        ticks='inside',
        ticklen=10,
        tickwidth=2,
        tickcolor='white',
        showline=False,
        zeroline=False
        #linecolor='white',
    )

    fig.update_yaxes(
        tickfont=dict(size=24),
        ticks='inside',
        ticklen=10,
        tickwidth=2,
        tickcolor='white',
        showline=False,
        zeroline=False
        #linecolor='white',
    )

    # If plot_type is norm, update y-axis label and format
    if plot_type == "norm":
        fig.update_yaxes(title_text="Norm modBase/Base", secondary_y=False,
                         # increase size
                         title_font=dict(size=24))
        fig.update_yaxes(tickformat=".2")

    # Update coverage figure layout
    cov_fig.update_xaxes(showgrid=False)
    cov_fig.update_yaxes(showgrid=False)
    cov_plot_title = "Motif Count" + "_".join([each_type for each_type in (types or [])])
    cov_fig.update_layout(
        title=cov_plot_title,
        xaxis_title='Genomic Position',
        title_font=dict(size=24),
        template="plotly_white",
        width=900,
        height=900,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        shapes=[border_shape],
        font=dict(color='white')
    )
    cov_fig.update_yaxes(title_text="Nvalid_cov", secondary_y=False)
    cov_fig.update_layout(legend=dict(
        traceorder="normal",
        y=-1,
        x=0.25,
        yanchor="top",
        orientation='h',
        font=dict(size=14),
    ))
    cov_fig.update_xaxes(range=[bed_window[0], bed_window[1]])

    # Add TSS/TES lines if applicable
    if any('TSS' in (types or []) for _ in (types or [])) or any('gene' in (types or []) for _ in (types or [])) or any('damID' in (types or []) for _ in (types or [])):
        fig.add_shape(type="line", x0=0, y0=0, x1=0, y1=1, line=dict(color="grey", width=1.5), xref="x", yref="paper")
        fig.add_annotation(x=0, y=1, yref="paper", text="TSS", showarrow=False, yanchor="bottom", xanchor="center")
        cov_fig.add_shape(type="line", x0=0, y0=0, x1=0, y1=1, line=dict(color="grey", width=1.5), xref="x", yref="paper")

    if any('TES' in (types or []) for _ in (types or [])) or any('gene' in (types or []) for _ in (types or [])) or any('damID' in (types or []) for _ in (types or [])):
        fig.add_shape(type="line", x0=metagene_bins, y0=0, x1=metagene_bins, y1=1, line=dict(color="grey", width=1.5), xref="x", yref="paper")
        fig.add_annotation(x=metagene_bins, y=1, yref="paper", text="TES", showarrow=False, yanchor="bottom", xanchor="center")
        cov_fig.add_shape(type="line", x0=metagene_bins, y0=0, x1=metagene_bins, y1=1, line=dict(color="grey", width=1.5), xref="x", yref="paper")

    return fig, label


# Plot one methylation profile per region type found in plot_df and save each
# figure as SVG and PNG. `type_selected`, `analysis_cond`, `plot_df` and
# `plot_bedmethyl` must already be defined by earlier cells.
print("type_selected:",type_selected)
print("analysis_cond:",analysis_cond)
window_s = 50  # passed to plot_bedmethyl as window_size
smoothing_type = "weighted"  # other values mentioned in this notebook: "gaussian", "rolling"
bed_w = 1000  # plotted window is [-bed_w, bed_w] around each region
num_bins = 1500  # passed to plot_bedmethyl as metagene_bins
plot_type = "norm"

# To plot a subset only, slice the type list, e.g. plot_df['type'].unique()[15:16]
if True:
    out_dir = "temp_files/single_weak_mex/N2_DPY27old/"
    # Create the output folder once; exist_ok avoids the check-then-create race.
    os.makedirs(out_dir, exist_ok=True)

    for each_type in plot_df['type'].unique():
        print("Plotting:", each_type)
        region_fig = plot_bedmethyl(plot_df, analysis_cond, chr_types=["X"], types=[each_type], strands=["all"], window_size=window_s, metagene_bins=num_bins, smoothing_type=smoothing_type,selection_indices=[2,3], bed_window=[-bed_w,bed_w], mod_types=['m6A'],ignore_selec=[],plot_motifs = True,plot_type = plot_type)

        # Same stem for both formats so the SVG/PNG pair can never diverge.
        out_stem = out_dir+smoothing_type+"_"+each_type+str(bed_w)+"bp_withfiber"+"_smooth"+str(window_s)
        region_fig[0].write_image(out_stem+".svg")
        region_fig[0].write_image(out_stem+".png")

if False:
    # Interactive preview of a single region type (nothing is written to disk).
    region_fig = plot_bedmethyl(plot_df, analysis_cond, chr_types=["X"], types=["MEX_motif_weak_dcc_2"], strands=["all"], window_size=window_s, metagene_bins=num_bins, smoothing_type=smoothing_type,selection_indices=[0,1,2,3,4,5], bed_window=[-bed_w,bed_w], mod_types=['m6A'],ignore_selec=[],plot_motifs = True,plot_type = plot_type)

    region_fig[0].show()


    #region_fig[0].write_image("temp_files/"+smoothing_type+"_"+region_fig[1]+type_selected[0]+str(bed_w)+"bp_withfiber"+"_smooth"+str(window_s)+".svg")
    #region_fig[0].write_image("temp_files/"+smoothing_type+"_"+region_fig[1]+type_selected[0]+str(bed_w)+"bp_withfiber"+"_smooth"+str(window_s)+".png")

#"MEX_none","MEX_D1to5","MEX_D6to9","MEX_D10"

#,bw_selections=["sdc2_chip_albritton","sdc3_chip_albritton","sdc3_chip_anderson","dpy27_chip_anderson"],bigwig_df=plot_comb_bigwig_df)#[1,'5mC'],[3,'5mC']])# #
# smoothing types: "gaussian", "weighted", "rolling"

#analysis_cond = ["N2_mixed_DPY27_dimelo_pAHia5_R10","50_mixed_dpy27-3xGNB_GFP-Hia5_mcvipi_R10","66_old_sdc2_3xGNB_GFPHia5_mChMCVIPI","N2_mixed_endogenous_R10","54_mixed_sdc2_3xmCNB_mChMCVIPI_GFPHia5"]

# print unique count bed_start values for combination of chr_type, type and condition in each comb_bedmethyl_df
#print("Unique count of bed_start values for each combination of chr_type, type and condition in comb_bedmethyl_df:")
#print(plot_df.groupby(['chr_type','type','condition'])['bed_start'].nunique())

# # #rand_suffix = nanotools.random_alpha_numeric(8)

#"center_DPY27_chip_albretton_ONLY","center_DPY27_chip_albretton;gene_ol2000;TSS_ol2000","strong_rex;DPY27_ol2000;SDC_ol2000","center_DPY27_chip_albretton;SDC_ol2000"
#"center_DPY27_chip_albretton","intergenic_control","strong_rex","weak_rex","TSS_q4","TSS_1"

In [None]:
# # reset index of plot_df
# plot_df.reset_index(drop=True, inplace=True)
# # drop type == MEX_D10_3040109_minus
# plot_df = plot_df[plot_df['type'] != 'MEX_D10_9697989_plus']
# plot_df = plot_df[plot_df['type'] != 'MEX_D10_3040109_minus']

# Plot every region type with a doubled smoothing window and save SVG + PNG.
smooth_window = window_s*2  # smoothing actually applied in this cell
plot_half_width = 1000  # plotted window is [-1000, 1000]
out_dir = "temp_files/single_D10_MEX/"
# Make sure the target folder exists before any image is written
# (the previous plotting cell does the same for its own folder).
os.makedirs(out_dir, exist_ok=True)

for each_type in plot_df['type'].unique():
    print("Plotting:", each_type)
    region_fig = plot_bedmethyl(plot_df, analysis_cond, chr_types=["X"], types=[each_type], strands=["all"], window_size=smooth_window, metagene_bins=num_bins, smoothing_type=smoothing_type,selection_indices=[3,4], bed_window=[-plot_half_width,plot_half_width], mod_types=['m6A'],ignore_selec=[],plot_motifs = True)

    # Label the files with the window and smoothing values really used above;
    # previously the name carried window_s although window_s*2 was applied.
    out_stem = out_dir+smoothing_type+"_"+each_type+str(plot_half_width)+"bp_withfiber"+"_smooth"+str(smooth_window)
    region_fig[0].write_image(out_stem+".svg")
    region_fig[0].write_image(out_stem+".png")

In [None]:
import os
import pandas as pd
import numpy as np
import re
from sklearn.utils import resample

# Configuration
# When True, every shape dataframe is passed through downsample_df() after
# loading; when False the frames are used as loaded.
down_var = True  # Set to True to enable downsampling

# Folders to scan. Each folder is mapped to a type label by
# get_type_from_folder() and must contain one file per entry in `extensions`.
folders = [
    '/Data1/ext_data/motifs/MEX_none',
    '/Data1/ext_data/motifs/MEX_D1to5',
    '/Data1/ext_data/motifs/MEX_D10'
]
# Alternative input set (kept for reference):
# folders = [
#     '/Data1/ext_data/motifs/rohslab_D1_5_1000bp',
#     '/Data1/ext_data/motifs/rohslab_D6_9_1000bp',
#     '/Data1/ext_data/motifs/rohslab_D10_1000bp'
# ]

# File-name suffixes to load from each folder; one dataframe is built per
# suffix (named "<suffix lower-cased>_df").
extensions = ['HelT', 'MGW', 'ProT', 'Roll']

def get_type_from_folder(folder):
    """Return the MEX type label for a folder path.

    The label is chosen by substring match against the whole path; markers are
    tried in a fixed order and the first hit wins. Paths matching no marker
    yield 'Unknown'.
    """
    # Order is significant: it mirrors an if/elif chain.
    marker_labels = (
        ('D1to5', 'MEX_D1to5'),
        ('none', 'MEX_none'),
        ('D10', 'MEX_D10'),
    )
    for marker, label in marker_labels:
        if marker in folder:
            return label
    return 'Unknown'

# Function to process a single file
def process_file(file_path, data_type):
    """Parse a FASTA-like per-position value file into a long-format DataFrame.

    Expected layout: header lines of the form ``>chr:start-end`` (an optional
    ``(+)`` or ``(-)`` strand suffix is ignored) followed by one or more lines
    of comma-separated values, where ``NA`` marks a missing value.

    Parameters
    ----------
    file_path : str
        Path of the file to read.
    data_type : str
        Label stored in the ``type`` column of every row.

    Returns
    -------
    pandas.DataFrame
        One row per value with columns ``chr``, ``start``, ``end``,
        ``position`` (0-based index of the value within its record),
        ``value`` (float, NaN for ``NA``) and ``type``. An input without any
        values yields an empty frame that still has these columns.

    Raises
    ------
    ValueError
        If a header's coordinates or a value token cannot be parsed.
    """
    columns = ['chr', 'start', 'end', 'position', 'value', 'type']
    print(f"Processing file: {file_path}")
    data = []
    current_chr = current_start = current_end = None
    position = 0

    with open(file_path, 'r') as file:
        for raw_line in file:
            # Strand annotations are not part of the coordinates; drop both
            # orientations so "start-end(-)" parses like "start-end(+)".
            line = raw_line.replace("(+)", "").replace("(-)", "").strip()
            if not line:
                # Blank (e.g. trailing) lines carry no data; float('') would raise.
                continue
            if line.startswith('>'):
                # New record: positions restart at 0.
                position = 0
                chr_info = line[1:].split(':')
                current_chr = chr_info[0]
                current_start, current_end = map(int, chr_info[1].split('-'))
            else:
                for token in line.split(','):
                    token = token.strip()
                    data.append({
                        'chr': current_chr,
                        'start': current_start,
                        'end': current_end,
                        'position': position,
                        'value': float(token) if token != 'NA' else np.nan,
                        'type': data_type
                    })
                    position += 1

    # Passing the columns explicitly keeps the schema stable for empty inputs.
    df = pd.DataFrame(data, columns=columns)
    print(f"Created DataFrame with shape: {df.shape}")
    return df

# Function to downsample dataframe
def downsample_df(df):
    """Downsample every ``type`` to the same number of unique regions.

    A region ID ``"chr:start-end"`` is written to ``df['id']`` (the input frame
    is modified in place, as before). The smallest per-type count of unique IDs
    is the target; types with more IDs keep a random subset of exactly that
    many IDs, together with all rows belonging to them.

    IDs are drawn WITHOUT replacement from a generator seeded with 42, so the
    result is reproducible and each over-represented type really ends up with
    ``min_count`` distinct regions. (Drawing with replacement, which is the
    default of ``sklearn.utils.resample``, yields duplicates and therefore
    fewer distinct regions than intended.)

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``chr``, ``start``, ``end`` and ``type``.

    Returns
    -------
    pandas.DataFrame
        Concatenation of the per-type subsets with a fresh integer index.
    """
    # Create ID column
    df['id'] = df['chr'] + ':' + df['start'].astype(str) + '-' + df['end'].astype(str)

    if df.empty:
        # Nothing to balance; pd.concat([]) below would raise.
        return df.reset_index(drop=True)

    # Smallest number of unique regions over all types is the common target.
    min_count = int(df.groupby('type')['id'].nunique().min())

    downsampled_dfs = []
    for type_name in df['type'].unique():
        type_df = df[df['type'] == type_name]
        unique_ids = np.asarray(type_df['id'].unique(), dtype=object)
        if len(unique_ids) > min_count:
            # Fresh generator per type: the draw does not depend on type order.
            rng = np.random.default_rng(42)
            sampled_ids = rng.choice(unique_ids, size=min_count, replace=False)
            downsampled_dfs.append(type_df[type_df['id'].isin(sampled_ids)])
        else:
            downsampled_dfs.append(type_df)

    return pd.concat(downsampled_dfs, ignore_index=True)

# Process all files and create dataframes.
# For every folder, one file per suffix in `extensions` is parsed; frames that
# share a suffix are collected and concatenated once at the end (cheaper than
# re-concatenating inside the loop).
frames_by_name = {}
for folder in folders:
    print(f"Processing folder: {folder}")
    data_type = get_type_from_folder(folder)
    print(f"Data type: {data_type}")

    for ext in extensions:
        # Sort so the choice is deterministic if several files share a suffix
        # (os.listdir order is arbitrary).
        matches = sorted(f for f in os.listdir(folder) if f.endswith(ext))
        if not matches:
            raise FileNotFoundError(f"No file ending in '{ext}' found in {folder}")
        if len(matches) > 1:
            print(f"Note: {len(matches)} files end in '{ext}' in {folder}; using {matches[0]}")
        file_path = os.path.join(folder, matches[0])
        df_name = f"{ext.lower()}_df"
        frames_by_name.setdefault(df_name, []).append(process_file(file_path, data_type))

dfs = {
    df_name: pd.concat(frames, ignore_index=True)
    for df_name, frames in frames_by_name.items()
}

# Downsample if enabled
if down_var:
    for df_name in dfs:
        print(f"Downsampling {df_name}...")
        dfs[df_name] = downsample_df(dfs[df_name])

# Assign dataframes to variables
helt_df = dfs['helt_df']
mgw_df = dfs['mgw_df']
prot_df = dfs['prot_df']
roll_df = dfs['roll_df']

# Print debug information
for name, df in dfs.items():
    print(f"\nDataFrame: {name}")
    print(f"Shape: {df.shape}")
    print(f"Columns: {df.columns}")
    print(f"Sample data:\n{df.head()}")
    print(f"NaN values: {df['value'].isna().sum()}")
    print(f"Unique types: {df['type'].unique()}")
    print(f"Position range: {df['position'].min()} to {df['position'].max()}")
    print("Unique IDs per type:")
    # The 'id' column only exists after downsample_df(); derive it on the fly
    # otherwise so this report also works with down_var = False.
    if 'id' in df.columns:
        region_ids = df['id']
    else:
        region_ids = (df['chr'] + ':' + df['start'].astype(str) + '-' + df['end'].astype(str)).rename('id')
    print(region_ids.groupby(df['type']).nunique())

print("\nData processing complete.")

In [None]:
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

# Configuration
# Number of consecutive positions averaged by the centered rolling mean in
# process_data().
WINDOW_SIZE = 50  # Rolling average window size

# Function to process data
def process_data(df, window_size=None):
    """Smooth each region's values and average them per position and type.

    Steps: add a ``"chr:start-end"`` region ID to ``df['id']`` (written to the
    input frame in place), sort by region and position, apply a centered
    rolling mean within each region (``min_periods=1``, so edge positions are
    averaged over the values available), then average the smoothed values
    over regions for every (position, type) pair.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``chr``, ``start``, ``end``, ``position``, ``value``
        and ``type``.
    window_size : int, optional
        Rolling window length. Defaults to the module-level ``WINDOW_SIZE``.

    Returns
    -------
    pandas.DataFrame
        Columns ``position``, ``type``, ``smoothed_value`` and
        ``centered_position`` (``position`` minus the mean of the ``position``
        column of the returned table).
    """
    window = WINDOW_SIZE if window_size is None else window_size

    # Create ID column
    df['id'] = df['chr'] + ':' + df['start'].astype(str) + '-' + df['end'].astype(str)

    # Sort by ID and position so the rolling window runs along each region
    ordered = df.sort_values(['id', 'position'])

    # Apply rolling average within each region
    smoothed = ordered.groupby('id')['value'].transform(
        lambda x: x.rolling(window=window, center=True, min_periods=1).mean()
    )
    ordered = ordered.assign(smoothed_value=smoothed)

    # Calculate average by position and type
    avg_df = ordered.groupby(['position', 'type'])['smoothed_value'].mean().reset_index()

    # Center the x-axis
    avg_df['centered_position'] = avg_df['position'] - avg_df['position'].mean()

    return avg_df

# Function to create 2x2 subplots
def create_subplots(dfs, titles, y_labels):
    """Draw the averaged profiles of up to four dataframes on a 2x2 grid.

    Each entry of ``dfs`` (in insertion order) is run through
    ``process_data`` and plotted in its own panel, filled row by row. One line
    is drawn per ``type``; a type keeps the same colour in every panel and is
    listed in the legend only for the first panel.

    Parameters
    ----------
    dfs : dict
        Maps a name to a dataframe accepted by ``process_data``.
    titles : list of str
        Subplot titles, passed to ``make_subplots``.
    y_labels : list of str
        Y-axis title per panel, indexed like ``dfs``.

    Returns
    -------
    plotly Figure holding all panels.
    """
    # One colour per type, assigned over the sorted union of all types so the
    # mapping is shared across panels.
    palette = px.colors.qualitative.Plotly
    all_types = set()
    for frame in dfs.values():
        for type_name in frame['type'].unique():
            all_types.add(type_name)
    color_map = {}
    for color_idx, type_name in enumerate(sorted(all_types)):
        color_map[type_name] = palette[color_idx % len(palette)]

    fig = make_subplots(rows=2, cols=2, subplot_titles=titles)

    for panel_idx, frame in enumerate(dfs.values()):
        # Row-major placement on the 2-column grid (plotly indices are 1-based).
        zero_row, zero_col = divmod(panel_idx, 2)
        row = zero_row + 1
        col = zero_col + 1

        avg_df = process_data(frame)
        first_panel = (panel_idx == 0)

        for data_type in avg_df['type'].unique():
            type_data = avg_df[avg_df['type'] == data_type]
            trace = go.Scatter(
                x=type_data['centered_position'],
                y=type_data['smoothed_value'],
                mode='lines',
                name=data_type,
                legendgroup=data_type,
                showlegend=first_panel,  # Show legend only for the first subplot
                line=dict(color=color_map[data_type])
            )
            fig.add_trace(trace, row=row, col=col)

        fig.update_xaxes(title_text='Distance from MEX pos 1', row=row, col=col)
        fig.update_yaxes(title_text=y_labels[panel_idx], row=row, col=col)

    fig.update_layout(
        height=800,
        width=1000,
        title_text="DNA Shape Analysis",
        template='plotly_white',
        legend_title_text='Type'
    )

    return fig

# Prepare data for plotting.
# NOTE: this rebinds `dfs`; the keys are now the display names of the four
# shape features instead of the "<suffix>_df" names used while loading.
dfs = {
    'HelT': helt_df,
    'MGW': mgw_df,
    'ProT': prot_df,
    'Roll': roll_df
}

titles = [
    'HelT - DNA Shape Analysis',
    'MGW - DNA Shape Analysis',
    'ProT - DNA Shape Analysis',
    'Roll - DNA Shape Analysis'
]

# Axis labels, in the same order as `dfs`.
# MGW in DNAshape output is the MINOR groove width (previously mislabelled
# as major groove width).
y_labels = [
    'Helical twist (°)',
    'Minor groove width (Å)',
    'Propeller twist (°)',
    'Roll (°)'
]

# Create and display the subplot figure
fig = create_subplots(dfs, titles, y_labels)
fig.show()

# Print some statistics.
# The 'id' column read below is added to each frame in place by
# process_data() (called from create_subplots) or earlier by downsample_df().
for name, df in dfs.items():
    print(f"\n{name} DataFrame:")
    print(f"Total rows: {len(df)}")
    print(f"Unique IDs: {df['id'].nunique()}")
    print(f"Position range: {df['position'].min()} to {df['position'].max()}")
    print(f"Unique types: {df['type'].unique()}")
    # These are means of the raw 'value' column; the smoothed values only
    # exist inside process_data(), so do not label them as smoothed.
    print("Average raw values by type:")
    print(df.groupby('type')['value'].mean())

In [None]:
### Aligning on motif:

# Convert this pseudo code into python code to be executed in a jupyter notebook cell
# import tsv from: "/Data1/ext_data/motifs/fimo_D10_100bp_mex.tsv" into dataframe called fimo_df
# format looks like:
# motif_id	motif_alt_id	sequence_name	start	stop	strand	score	p-value	q-value	matched_sequence
# 1	TNYCCCTKCSCHWWT-MEME-1	chrV	20651347	20651361	+	21.7701	3.08E-09	0.381	TCCCCCTTCGCCATT
# 1	TNYCCCTKCSCHWWT-MEME-1	chrV	1743236	1743250	+	21.3448	7.40E-09	0.381	TCCCCCTGCCCATTT

# keep only sequence_name, start, stop, strand, p-value columns, and rename them to chr, start, stop, strand, p_value

# in chr column, replace all "chr" with "CHROMOSOME_"

# use IntervalTree to add column corresponding to the 'chip_category' in the 'result_df_cat' dataframe, if the start, end in fimo_df overlaps the start, end in result_df_cat and the 'chr's match.
# otherwise set to "none"
# result_df_cat is the dataframe with following columns:
# Index(['type', 'chr', 'start', 'end', 'length', 'abs_summit', 'pileup',
#        'LOG10(pvalue)', 'fold_enrichment', ... , 'chip_category'],
#       dtype='object')

# Add a 'cluster_count' column corresponding to the number of rows in fimo_df that are within a configurable window (e.g. 100bp) of the current row's start or end positions and on the same 'chr'

# plot a boxplot of 'cluster_count' for each 'chip_category' in result_df_cat using seaborn, showing scatter points

# add a configurable bed_window e.g. 500bp to either side of each region?


In [None]:

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from intervaltree import Interval, IntervalTree
from collections import defaultdict

# Configurable parameters
window_size = 200  # Flank added on each side of a motif when counting neighbouring motifs
bed_window = 1000  # Flank later added on each side of every motif row in fimo_expanded_df

# Step 1: Import the FIMO hits (comma-separated file) into a dataframe called fimo_df
fimo_df = pd.read_csv("/Data1/ext_data/motifs/fimo_JANS_rex.csv", sep=',')
# print head of fimo_df
print(fimo_df.head())
# Alternative input (tab-separated): /Data1/ext_data/motifs/fimo_D10_100bp_mex.tsv

# Step 2: Keep only the specified columns and rename them
fimo_df = fimo_df[['sequence_name', 'start', 'stop', 'strand', 'score','p-value']]
fimo_df = fimo_df.rename(columns={'sequence_name': 'chr', 'p-value': 'p_value'})
# keep only rows with score >= 12 (i.e. drop rows where score < 12)
fimo_df = fimo_df[fimo_df['score'] >= 12]
# drop score column (only needed for the filter above)
fimo_df.drop(columns=['score'], inplace=True)

# Step 3: Replace "chr" with "CHROMOSOME_" in the 'chr' column
fimo_df['chr'] = fimo_df['chr'].str.replace('chr', 'CHROMOSOME_')

# Assume result_df_cat is already loaded into the environment
# Ensure 'chr' column matches the format in fimo_df
# NOTE: this rewrites result_df_cat['chr'] in place.
result_df_cat['chr'] = result_df_cat['chr'].str.replace('chr', 'CHROMOSOME_')

# Step 5: Build IntervalTrees for each chromosome in result_df_cat
# Each interval stores the row's chip_category as its data payload.
interval_trees = defaultdict(IntervalTree)
for idx, row in result_df_cat.iterrows():
    chrom = row['chr']
    start = row['start']
    end = row['end']
    category = row['chip_category']
    interval_trees[chrom][start:end] = category

# Add an 'id' column to fimo_df to keep track of original rows
fimo_df['id'] = fimo_df.index

# Step 6: Expand fimo_df to account for multiple overlapping chip_categories
# A motif overlapping k intervals yields k rows (one per category); a motif
# overlapping none yields a single row with chip_category 'none'.
expanded_rows = []

for idx, row in fimo_df.iterrows():
    chrom = row['chr']
    start = row['start']
    end = row['stop']
    # The membership test avoids inserting empty trees into the defaultdict.
    overlaps = interval_trees[chrom][start:end] if chrom in interval_trees else []
    if overlaps:
        for interval in overlaps:
            new_row = row.copy()
            new_row['chip_category'] = interval.data
            expanded_rows.append(new_row)
    else:
        # No overlaps, chip_category is 'none'
        new_row = row.copy()
        new_row['chip_category'] = 'none'
        expanded_rows.append(new_row)

# Create the expanded DataFrame
fimo_expanded_df = pd.DataFrame(expanded_rows)

# Step 7: Build IntervalTrees for fimo_df for cluster counting (using original fimo_df)
fimo_trees = defaultdict(IntervalTree)
for idx, row in fimo_df.iterrows():
    chrom = row['chr']
    start = row['start']
    end = row['stop']
    fimo_trees[chrom][start:end] = row['id']  # Use 'id' to uniquely identify intervals

# Step 8: Calculate 'cluster_count' for each row in fimo_df
def get_cluster_count(row):
    """Count motif intervals near ``row`` on the same chromosome.

    Queries the module-level ``fimo_trees`` for intervals overlapping the
    row's ``start``..``stop`` span widened by the module-level ``window_size``
    on both sides. Returns 0 when the chromosome has no tree.
    """
    chrom = row['chr']
    query_begin = row['start'] - window_size
    query_end = row['stop'] + window_size
    if chrom not in fimo_trees:
        return max(0, 0)
    nearby = fimo_trees[chrom][query_begin:query_end]
    return max(len(nearby), 0)


# Number of motifs found within window_size of each motif (see get_cluster_count)
fimo_df['cluster_count'] = fimo_df.apply(get_cluster_count, axis=1)

# Merge 'cluster_count' back into the expanded DataFrame
fimo_expanded_df = fimo_expanded_df.merge(fimo_df[['id', 'cluster_count']], on='id', how='left')

# Step 9: Compute mean cluster_count and sample sizes per chip_category
category_stats = fimo_expanded_df.groupby('chip_category')['cluster_count'].agg(['mean', 'count']).reset_index()

# Sort by mean cluster_count
category_stats = category_stats.sort_values(by='mean', ascending=False)

# Create x-axis labels with 'n='
category_stats['label'] = category_stats.apply(lambda x: f"{x['chip_category']}\n(n={int(x['count'])})", axis=1)

# Plot the bar plot
plt.figure(figsize=(12, 6))
sns.barplot(x='label', y='mean', data=category_stats, palette='viridis')
plt.title('Average Cluster Count by Chip Category')
plt.xlabel('Chip Category')
plt.ylabel('Average Cluster Count')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

nanotools.display_sample_rows(fimo_df, 5)

# Add bed_window to either side of each motif row in fimo_expanded_df
# (start is floored at 0 so regions never extend before the chromosome start)
fimo_expanded_df['start'] = fimo_expanded_df['start'] - bed_window
fimo_expanded_df['start'] = fimo_expanded_df['start'].apply(lambda x: max(0, x))
fimo_expanded_df['stop'] = fimo_expanded_df['stop'] + bed_window

# drop rows with NaN in the 'chr' column
fimo_expanded_df.dropna(subset=['chr'], inplace=True)

# add "MEX_" to the start of each 'chip_category' in fimo_expanded_df
fimo_expanded_df['chip_category'] = 'MEX_' + fimo_expanded_df['chip_category']
# convert start and stop to int
fimo_expanded_df['start'] = fimo_expanded_df['start'].astype(int)
fimo_expanded_df['stop'] = fimo_expanded_df['stop'].astype(int)

import pandas as pd

# Assuming 'fimo_expanded_df' is available from previous code

# Create 'combined_bed_df_mex_cat' with specified columns
# (one row per motif/category pair; 'type' carries the MEX_-prefixed category)
combined_bed_df_mex_cat = fimo_expanded_df[['chr', 'start', 'stop', 'strand', 'chip_category']].copy()
combined_bed_df_mex_cat = combined_bed_df_mex_cat.rename(columns={
    'chr': 'chrom',
    'start': 'bed_start',
    'stop': 'bed_end',
    'strand': 'bed_strand',
    'chip_category': 'type'
})

# Add 'chr_type' column: 'X' for CHROMOSOME_X, 'Autosome' for everything else
combined_bed_df_mex_cat['chr_type'] = combined_bed_df_mex_cat['chrom'].apply(
    lambda x: 'X' if x == 'CHROMOSOME_X' else 'Autosome')

# Display the first few rows
print("combined_bed_df_mex_cat:")
nanotools.display_sample_rows(combined_bed_df_mex_cat, 5)

# Create 'combined_bed_df_mex_clust' with 'type' column as "clust_" + cluster_count
combined_bed_df_mex_clust = fimo_expanded_df[['chr', 'start', 'stop', 'strand', 'cluster_count']].copy()
combined_bed_df_mex_clust = combined_bed_df_mex_clust.rename(columns={
    'chr': 'chrom',
    'start': 'bed_start',
    'stop': 'bed_end',
    'strand': 'bed_strand'
})

# 'type' column is 'clust_' + cluster_count
combined_bed_df_mex_clust['type'] = 'clust_' + combined_bed_df_mex_clust['cluster_count'].astype(str)

# Add 'chr_type' column: 'X' for CHROMOSOME_X, 'Autosome' for everything else
combined_bed_df_mex_clust['chr_type'] = combined_bed_df_mex_clust['chrom'].apply(
    lambda x: 'X' if x == 'CHROMOSOME_X' else 'Autosome')

# drop cluster_count column (its value is now encoded in 'type')
combined_bed_df_mex_clust.drop(columns='cluster_count', inplace=True)

# Display the first few rows
print("\ncombined_bed_df_mex_clust:")
nanotools.display_sample_rows(combined_bed_df_mex_clust, 5)

# print columns
print("\ncombined_bed_df_mex_cat columns:")
print(combined_bed_df_mex_cat.columns)

# convert MEX_D1, MEX_D2, MEX_D3, MEX_D4, MEX_D5 to MEX_D1to5
combined_bed_df_mex_cat['type'] = combined_bed_df_mex_cat['type'].replace({
    'MEX_D1': 'MEX_D1to5',
    'MEX_D2': 'MEX_D1to5',
    'MEX_D3': 'MEX_D1to5',
    'MEX_D4': 'MEX_D1to5',
    'MEX_D5': 'MEX_D1to5'
})

# calculate end - start and return unique values
# (sanity check on the region widths after the bed_window expansion)
combined_bed_df_mex_cat['length'] = combined_bed_df_mex_cat['bed_end'] - combined_bed_df_mex_cat['bed_start']
print("\nUnique lengths:")
print(combined_bed_df_mex_cat['length'].unique())

# convert MEX_D6, MEX_D7, MEX_D8, MEX_D9 to MEX_D6to9
combined_bed_df_mex_cat['type'] = combined_bed_df_mex_cat['type'].replace({
    'MEX_D6': 'MEX_D6to9',
    'MEX_D7': 'MEX_D6to9',
    'MEX_D8': 'MEX_D6to9',
    'MEX_D9': 'MEX_D6to9'
})

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from intervaltree import Interval, IntervalTree
from collections import defaultdict

# Configurable parameters
window_size = 200  # Flank added on each side of a motif when counting neighbouring motifs
bed_window = 100  # Flank later added on each side of every motif row in fimo_expanded_df

# Step 1: Import the TSV files into dataframes
# Only the first file is active; the other two are commented out.
fimo_files = [
    "/Data1/ext_data/motifs/fimo_MEX_0.01.tsv"#,
    #"/Data1/ext_data/motifs/fimo_MEXII_0.01.tsv",
    #"/Data1/ext_data/motifs/fimo_motifc_0.01.tsv"
]

fimo_dfs = []
for file in fimo_files:
    # Read the TSV file
    df = pd.read_csv(file, sep='\t')
    fimo_dfs.append(df)

# Combine the dataframes into one
fimo_df = pd.concat(fimo_dfs, ignore_index=True)

# Step 2: Pre-filter by p-value < 0.00001 (1e-5)
fimo_df = fimo_df[fimo_df['p-value'] < 0.00001] # other thresholds tried: 0.000001

# Step 3: Deduplicate overlapping rows on the same sequence_name (chromosome)
# Priority order: a lower number wins when two motif hits overlap.
motif_priority = {'MEXII': 1, 'MEX': 2, 'motifC': 3}

# Attach priorities, drop motifs that are not in the priority list, and sort
# by chromosome, start and priority. Done as one chain on a new frame so no
# column is assigned on a filtered slice (avoids chained-assignment warnings).
fimo_df = (
    fimo_df.assign(motif_priority=fimo_df['motif_id'].map(motif_priority))
    .dropna(subset=['motif_priority'])
    .astype({'motif_priority': int})
    .sort_values(by=['sequence_name', 'start', 'motif_priority'])
)

# One IntervalTree per chromosome holding the hits kept so far; each interval
# carries (priority, row index) as payload.
chrom_trees = defaultdict(IntervalTree)
deduped_indices = set()

# Iterate over the dataframe to deduplicate
for idx, row in fimo_df.iterrows():
    chrom = row['sequence_name']
    start = row['start']
    end = row['stop']
    priority = row['motif_priority']

    overlaps = chrom_trees[chrom][start:end]

    # Keep the current hit only if it strictly outranks EVERY kept hit it
    # overlaps (trivially true when there is no overlap). Otherwise it would
    # remain in the result while still overlapping an equal or better hit.
    if any(priority >= kept.data[0] for kept in overlaps):
        continue

    # Current hit wins: evict all overlapped lower-priority hits.
    for kept in overlaps:
        chrom_trees[chrom].remove(kept)
        deduped_indices.discard(kept.data[1])

    chrom_trees[chrom][start:end] = (priority, idx)
    deduped_indices.add(idx)

# Keep only deduplicated rows. A boolean mask is used because pandas >= 2.0
# rejects a set as a .loc indexer; it also preserves the sorted row order.
fimo_df = fimo_df[fimo_df.index.isin(deduped_indices)]

# Step 4: Adjust columns
# (motif_priority is dropped here; it was only needed for deduplication)
fimo_df = fimo_df[['sequence_name', 'start', 'stop', 'strand', 'score', 'p-value', 'motif_id']]
fimo_df = fimo_df.rename(columns={'sequence_name': 'chr', 'p-value': 'p_value'})

# Step 5: Replace "chr" with "CHROMOSOME_" in the 'chr' column
fimo_df['chr'] = fimo_df['chr'].str.replace('chr', 'CHROMOSOME_')

# Ensure 'chr' column in result_df_cat matches the format in fimo_df
# NOTE: this rewrites result_df_cat['chr'] in place.
result_df_cat['chr'] = result_df_cat['chr'].str.replace('chr', 'CHROMOSOME_')

# Step 6: Build IntervalTrees for each chromosome in result_df_cat
# Each interval stores the row's chip_category as its data payload.
interval_trees = defaultdict(IntervalTree)
for idx, row in result_df_cat.iterrows():
    chrom = row['chr']
    start = row['start']
    end = row['end']
    category = row['chip_category']
    interval_trees[chrom][start:end] = category

# Add an 'id' column to fimo_df to keep track of original rows
# (index is reset first so ids run 0..n-1 after the filtering above)
fimo_df.reset_index(drop=True, inplace=True)
fimo_df['id'] = fimo_df.index

# Step 7: Expand fimo_df to account for multiple overlapping chip_categories
# A motif overlapping k intervals yields k rows (one per category); a motif
# overlapping none yields a single row with chip_category 'none'.
expanded_rows = []

for idx, row in fimo_df.iterrows():
    chrom = row['chr']
    start = row['start']
    end = row['stop']
    # The membership test avoids inserting empty trees into the defaultdict.
    overlaps = interval_trees[chrom][start:end] if chrom in interval_trees else []
    if overlaps:
        for interval in overlaps:
            new_row = row.copy()
            new_row['chip_category'] = interval.data
            expanded_rows.append(new_row)
    else:
        # No overlaps, chip_category is 'none'
        new_row = row.copy()
        new_row['chip_category'] = 'none'
        expanded_rows.append(new_row)

# Create the expanded DataFrame
fimo_expanded_df = pd.DataFrame(expanded_rows)

# Step 8: Build IntervalTrees for fimo_df for cluster counting
fimo_trees = defaultdict(IntervalTree)
for idx, row in fimo_df.iterrows():
    chrom = row['chr']
    start = row['start']
    end = row['stop']
    fimo_trees[chrom][start:end] = row['id']  # Use 'id' to uniquely identify intervals
# Step 9: Calculate 'cluster_count' for each row in fimo_df
def get_cluster_count(row, trees=None, flank=None):
    """Count the motif intervals found within a flank around one motif.

    Args:
        row: Mapping-like record with 'chr', 'start' and 'stop' entries
            (a ``fimo_df`` row when called through ``DataFrame.apply``).
        trees: Optional mapping of chromosome name to an interval index that
            supports ``index[begin:end]`` lookups. Defaults to the notebook
            global ``fimo_trees``.
        flank: Optional number of positions added on each side of the motif
            before the lookup. Defaults to the notebook global ``window_size``.

    Returns:
        The number of intervals returned by the lookup, or 0 when the
        chromosome has no index. The motif's own interval is counted too
        whenever it is stored in ``trees`` and falls inside the query range.
    """
    if trees is None:
        trees = fimo_trees
    if flank is None:
        # Resolved at call time. ``window_size`` is also reassigned by the
        # CG-smoothing cell of this notebook, so pass ``flank`` explicitly to
        # stay independent of cell execution order.
        flank = window_size

    chrom = row['chr']
    # Membership test first so a defaultdict-backed index is not grown by the lookup.
    if chrom not in trees:
        return 0
    return len(trees[chrom][row['start'] - flank:row['stop'] + flank])

# Neighbourhood size of every motif, evaluated row by row
fimo_df['cluster_count'] = fimo_df.apply(get_cluster_count, axis=1)

# Carry the per-motif count over to each expanded (motif, chip_category) row via the shared 'id'
fimo_expanded_df = fimo_expanded_df.merge(
    fimo_df[['id', 'cluster_count']],
    how='left',
    on='id',
)

# Step 10: Mean cluster_count and number of rows per chip_category, highest mean first
category_stats = (
    fimo_expanded_df
    .groupby('chip_category')['cluster_count']
    .agg(['mean', 'count'])
    .reset_index()
    .sort_values(by='mean', ascending=False)
)

# x-axis label: category name with its sample size on a second line
category_stats['label'] = [
    f"{category}\n(n={int(n_rows)})"
    for category, n_rows in zip(category_stats['chip_category'], category_stats['count'])
]

# Bar plot of the per-category means
plt.figure(figsize=(12, 6))
sns.barplot(data=category_stats, x='label', y='mean', palette='viridis')
plt.xlabel('Chip Category')
plt.ylabel('Average Cluster Count')
plt.title('Average Cluster Count by Chip Category')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# Pad every motif interval by bed_window on both sides; padded starts below 0 are raised to 0
fimo_expanded_df['start'] = (fimo_expanded_df['start'] - bed_window).apply(lambda pos: max(0, pos))
fimo_expanded_df['stop'] = fimo_expanded_df['stop'] + bed_window

# Discard rows that have no chromosome name
fimo_expanded_df.dropna(subset=['chr'], inplace=True)

# Prefix every category label with "MEX_"
fimo_expanded_df['chip_category'] = 'MEX_' + fimo_expanded_df['chip_category']

# Store both coordinates as integers
fimo_expanded_df['start'] = fimo_expanded_df['start'].astype(int)
fimo_expanded_df['stop'] = fimo_expanded_df['stop'].astype(int)

# 'combined_bed_df_mex_cat': BED-style frame whose 'type' is the chip category
combined_bed_df_mex_cat = (
    fimo_expanded_df.loc[:, ['chr', 'start', 'stop', 'strand', 'chip_category']]
    .copy()
    .rename(columns={
        'chr': 'chrom',
        'start': 'bed_start',
        'stop': 'bed_end',
        'strand': 'bed_strand',
        'chip_category': 'type',
    })
)

# 'chr_type' is 'X' for CHROMOSOME_X and 'Autosome' for everything else
combined_bed_df_mex_cat['chr_type'] = combined_bed_df_mex_cat['chrom'].map(
    lambda chrom_name: 'Autosome' if chrom_name != 'CHROMOSOME_X' else 'X')

# 'combined_bed_df_mex_clust': same regions, but 'type' encodes the cluster size
combined_bed_df_mex_clust = (
    fimo_expanded_df.loc[:, ['chr', 'start', 'stop', 'strand', 'cluster_count']]
    .copy()
    .rename(columns={
        'chr': 'chrom',
        'start': 'bed_start',
        'stop': 'bed_end',
        'strand': 'bed_strand',
    })
)

# 'type' is "clust_" followed by the cluster_count value
combined_bed_df_mex_clust['type'] = 'clust_' + combined_bed_df_mex_clust['cluster_count'].astype(str)

# Same X / Autosome labelling as above
combined_bed_df_mex_clust['chr_type'] = combined_bed_df_mex_clust['chrom'].map(
    lambda chrom_name: 'Autosome' if chrom_name != 'CHROMOSOME_X' else 'X')

# The numeric count is no longer needed once it is encoded in 'type'
combined_bed_df_mex_clust = combined_bed_df_mex_clust.drop(columns='cluster_count')

# Pool the per-decile categories: D1-D5 -> MEX_D1to5, D6-D9 -> MEX_D6to9 (other labels are left as they are)
combined_bed_df_mex_cat['type'] = combined_bed_df_mex_cat['type'].replace({
    **{f'MEX_D{decile}': 'MEX_D1to5' for decile in range(1, 6)},
    **{f'MEX_D{decile}': 'MEX_D6to9' for decile in range(6, 10)},
})

# Show a sample of both frames
print("combined_bed_df_mex_cat:")
nanotools.display_sample_rows(combined_bed_df_mex_cat, 5)

print("\ncombined_bed_df_mex_clust:")
nanotools.display_sample_rows(combined_bed_df_mex_clust, 5)




In [None]:
### Plot boxplot centered on MEX motifs:
### Generate dataframe for plotting correlation between chip and accessibility
import pandas as pd
import pyBigWig
import numpy as np

def calculate_bigwig_scores(df, qnormalized_bigwig_paths):
    """Summarise bigWig signal over every region of ``df``.

    For the i-th path (1-based) the columns ``average_bw_i``, ``median_bw_i``,
    ``sum_bw_i`` and ``max_bw_i`` are added and filled from the per-base
    values returned by ``bw.values(chr, start, end)``, ignoring NaN/None
    entries. Regions with no usable values, or that cannot be queried, keep
    NaN. Finally ``norm_avg_bw_i`` is set to
    ``average_bw_i / average_bw_1 - 1`` for every file.

    Args:
        df: DataFrame with 'chr', 'start' and 'end' columns. It is modified
            in place (the new columns are added to it).
        qnormalized_bigwig_paths: Sized sequence of bigWig file paths.

    Returns:
        The same ``df`` object, with the summary columns added.
    """
    stats = ['average', 'median', 'sum', 'max']
    n_files = len(qnormalized_bigwig_paths)

    # Initialize new columns for each bigWig file
    for i in range(1, n_files + 1):
        for stat in stats:
            df[f'{stat}_bw_{i}'] = np.nan

    # Process each bigWig file
    for i, bw_path in enumerate(qnormalized_bigwig_paths, start=1):
        print(f"Processing bigWig file {i}/{n_files}: {bw_path}")
        skipped = 0
        with pyBigWig.open(bw_path) as bw:
            for index, row in df.iterrows():
                try:
                    # Coordinates are coerced to plain ints: a float-typed column
                    # (e.g. after concatenating frames) may otherwise be rejected.
                    values = bw.values(row['chr'], int(row['start']), int(row['end']))
                except (RuntimeError, ValueError, TypeError):
                    # Region not present in this bigWig, or unusable coordinates:
                    # leave NaN, but keep count so the failure is visible.
                    skipped += 1
                    continue

                # dtype=float turns any None into NaN, so one mask removes both.
                values = np.asarray(values, dtype=float)
                values = values[~np.isnan(values)]

                if values.size:
                    df.at[index, f'average_bw_{i}'] = values.mean()
                    df.at[index, f'median_bw_{i}'] = np.median(values)
                    df.at[index, f'sum_bw_{i}'] = values.sum()
                    df.at[index, f'max_bw_{i}'] = values.max()

        if skipped:
            print(f"Skipped {skipped} region(s) that could not be queried in {bw_path}")
        print(f"Finished processing bigWig file {i}")

    # Calculate normalized average columns (relative to the first bigWig)
    for i in range(1, n_files + 1):
        df[f'norm_avg_bw_{i}'] = df[f'average_bw_{i}'] / df['average_bw_1'] - 1

    return df


bed_file_path = "/Data1/ext_data/qiming_2024/SDC2_SDC3_20_peaks_500_2000_RPKM.csv" # SDC2_SDC3_20_peaks_500_2000_RPKM.csv or SDC2_SDC3_gt20_rpkm.csv for whole regions

# Region width kept from the ChIP table: twice bed_window (bed_window is set in an earlier cell)
center_length = bed_window * 2

# Read the ChIP peak table (a CSV file, despite the variable name)
chip_bed = pd.read_csv(bed_file_path)

# Keep only rows whose 'length' equals center_length
chip_bed = chip_bed[chip_bed['length'] == center_length]

# Recalculate start and end so each region is centred on abs_summit (half of 'length' on each side)
chip_bed['start'] = chip_bed['abs_summit'] - chip_bed['length'] // 2
chip_bed['end'] = chip_bed['abs_summit'] + chip_bed['length'] // 2

nanotools.display_sample_rows(chip_bed, 3)
# Print columns
print(chip_bed.columns)

### Add bed regions of interest (e.g. control regions)
# Create an empty dataframe with the same columns as chip_bed
new_rows = pd.DataFrame(columns=chip_bed.columns)

# Map the columns from combined_bed_df_mex_cat onto the chip_bed column names
new_rows['type'] = combined_bed_df_mex_cat['type']
new_rows['chr'] = combined_bed_df_mex_cat['chrom']
new_rows['start'] = combined_bed_df_mex_cat['bed_start']
new_rows['end'] = combined_bed_df_mex_cat['bed_end']

# Calculate length and abs_summit (midpoint of the region, truncated to an integer)
new_rows['length'] = new_rows['end'] - new_rows['start']
new_rows['abs_summit'] = ((new_rows['start'] + new_rows['end']) / 2).astype(int)

# Append the new rows to chip_bed; all other chip_bed columns stay empty for these rows
chip_bed = pd.concat([chip_bed, new_rows], ignore_index=True)

# Optional sort of the resulting dataframe by chr and start position (currently disabled)
#chip_bed = chip_bed.sort_values(['chr', 'start']).reset_index(drop=True)

# Add per-bigWig summary columns; raw_bw_files is defined in an earlier cell
result_df = calculate_bigwig_scores(chip_bed, raw_bw_files)

# Display the first few rows of the resulting dataframe
nanotools.display_sample_rows(result_df, 10)
# print column names
print(result_df.columns)

# display number of rows by type
result_df['type'].value_counts()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
import os


def filter_and_plot(result_df, qvalue_cutoff=None, percentile_cutoff=None, num_categories=4):
    """Draw per-experiment boxplots of average signal for each motif category.

    Rows whose ``type`` contains "SDC" are dropped, ``MEX_intergenic`` is
    renamed to ``MEX_none``, the full label is stored in ``chip_category`` and
    ``type`` is reduced to the text before the first underscore. One pair of
    axes is drawn per ``average_*`` column (paired in order with the notebook
    global ``analysis_cond``): the wide axis shows rows of type ``MEX``, the
    narrow one rows of type ``rex``/``intergenic``. The figure is written to
    ``temp_files/SDC3_boxplot_figure.png`` and ``.svg`` and then shown.

    Args:
        result_df: DataFrame with a 'type' column, the ``average_*`` columns
            to plot and, when ``qvalue_cutoff`` is given, 'LOG10(qvalue)'.
        qvalue_cutoff: If not None, keep only rows with
            ``LOG10(qvalue) >= qvalue_cutoff``.
        percentile_cutoff: Accepted for backward compatibility; not used.
        num_categories: Number of colours sampled from ``RdBu_r`` for the
            category boxes.

    Returns:
        The filtered and relabelled DataFrame (a copy; the input frame is not
        modified).
    """
    # Filter rows based on qvalue cutoff if provided
    if qvalue_cutoff is not None:
        result_df = result_df[result_df['LOG10(qvalue)'] >= qvalue_cutoff]

    # Drop rows with SDC in type. Copy so the column assignments below act on an
    # independent frame instead of a filtered slice of the caller's data.
    result_df = result_df[~result_df['type'].str.contains('SDC')].copy()

    # Rename type "MEX_intergenic" to "MEX_none"
    result_df['type'] = result_df['type'].replace('MEX_intergenic', 'MEX_none')

    # chip_category keeps the full label; type becomes the part before the first "_"
    result_df['chip_category'] = result_df['type']
    result_df['type'] = result_df['chip_category'].apply(lambda x: x.split('_')[0])

    # Experiment names, paired in order with the average_* columns
    experiment_names = analysis_cond

    # Custom blue-to-red colormap sampled from RdBu_r
    colors_sdc = plt.cm.RdBu_r(np.linspace(0, 1, num_categories))
    cmap_sdc = LinearSegmentedColormap.from_list("custom_RdBu", colors_sdc)

    # Colors for 'rex' and 'intergenic'
    color_rex = 'green'
    color_intergenic = 'orange'

    def plot_boxplots(data, bw_column, ax1, ax2, title, palette_sdc, buffer=1):
        """Draw one experiment: MEX categories on ax1, rex/intergenic on ax2."""
        # Categories are laid out in plain sorted order
        category_order = sorted(data['chip_category'].unique())

        print("Adjusted category_order:", category_order)

        # Boxplot properties
        boxprops = dict(facecolor='none')
        medianprops = dict(color='red', linewidth=2)

        print("data[data['type'] == 'MEX']")
        print(data[data['type'] == 'MEX'])

        # Plot the MEX categories on the first subplot
        sns.boxplot(
            x='chip_category', y=bw_column, data=data[data['type'] == 'MEX'],
            ax=ax1, palette=palette_sdc, showfliers=False, order=category_order,
            boxprops=boxprops, medianprops=medianprops, width=1
        )

        # Plot 'rex' and 'intergenic' on the second subplot
        sns.boxplot(
            x='chip_category', y=bw_column, data=data[data['type'].isin(['rex', 'intergenic'])],
            ax=ax2, palette=[color_rex, color_intergenic], showfliers=False, order=['rex', 'intergenic'],
            boxprops=boxprops, medianprops=medianprops, width=1
        )

        # Set titles and labels
        ax1.set_title(f'Boxplot for {title}')
        ax1.set_xlabel('ChIP Category')
        ax1.set_ylabel('Average Methylation')
        ax2.set_xlabel('ChIP Category')
        ax2.set_ylabel('')

        # Rotate x-axis labels and decrease font size
        ax1.tick_params(axis='x', rotation=45, labelsize=8)
        ax2.tick_params(axis='x', rotation=45, labelsize=8)

        # Adjust y-axis label font size
        ax1.yaxis.label.set_fontsize(12)

        # Set title font size
        ax1.title.set_fontsize(16)

        # Remove background
        ax1.set_facecolor('none')
        ax2.set_facecolor('none')
        ax1.grid(False)
        ax2.grid(False)

        # Remove a legend only if one exists (calling ax.legend() would create
        # one first and warn when there are no labelled artists)
        for ax in (ax1, ax2):
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()

        # Fix the y-axis to 0-0.5 BEFORE annotating, so the (n=) labels are
        # anchored to the final lower limit rather than the autoscaled one
        ax1.set_ylim(0, 0.5)
        ax2.set_ylim(0, 0.5)

        # Add (n=) for the number of non-missing datapoints under each category
        for ax in (ax1, ax2):
            y_floor = ax.get_ylim()[0]
            for i, label in enumerate(ax.get_xticklabels()):
                category = label.get_text()
                count = data[(data['chip_category'] == category) & (data[bw_column].notna())].shape[0]
                ax.text(i, y_floor, f'(n={count})', ha='center', va='top', fontsize=8)

        # Add buffer space by setting xlim for each subplot
        ax1.set_xlim(-buffer, len(category_order) - 1 + buffer)
        ax2.set_xlim(-buffer, 1 + buffer)  # two categories: rex and intergenic

    # Palette for the category boxes; the max() guard keeps num_categories=1 from dividing by zero
    palette_sdc = [cmap_sdc(i / max(num_categories - 1, 1)) for i in range(num_categories)]

    # One panel (two axes) per average_* column that has a matching experiment name
    average_columns = [col for col in result_df.columns if col.startswith('average_')]
    panels = list(zip(average_columns, experiment_names))

    # Two panels per row; 6 rows by default, more only if the panels would not fit
    n_rows = max(6, (len(panels) + 1) // 2)
    fig, axes = plt.subplots(
        n_rows, 4, figsize=(18, 28 * n_rows / 6),
        gridspec_kw={'width_ratios': [5, 1, 5, 1]}
    )

    for i, (col, exp_name) in enumerate(panels):
        plot_boxplots(
            result_df, col,
            axes[i // 2, 2 * (i % 2)],
            axes[i // 2, 2 * (i % 2) + 1],
            exp_name, palette_sdc
        )

    # Remove overall background
    fig.patch.set_facecolor('none')

    # Adjust layout
    plt.tight_layout()

    # Adjust subplots to add buffer space around the boxplots
    plt.subplots_adjust(left=0.05, right=0.95)

    # Save as PNG and SVG
    save_path = '/Data1/git/meyer-nanopore/scripts/analysis/temp_files/'
    os.makedirs(save_path, exist_ok=True)

    png_path = os.path.join(save_path, 'SDC3_boxplot_figure.png')
    svg_path = os.path.join(save_path, 'SDC3_boxplot_figure.svg')

    plt.savefig(png_path, format='png', dpi=300, bbox_inches='tight', transparent=True)
    plt.savefig(svg_path, format='svg', bbox_inches='tight', transparent=True)

    print(f"Figures saved as:\n{png_path}\n{svg_path}")

    # Show plot
    plt.show()

    return result_df


# Run the boxplot step on result_df (built in the previous cell); analysis_cond must already be defined.
# No q-value filter is applied here, and 10 colours are requested for the category palette.
result_df_cat = filter_and_plot(result_df, percentile_cutoff=None, num_categories=10)

# Display sample rows and print columns
nanotools.display_sample_rows(result_df_cat, 5)
print(result_df_cat.columns)

# Print count by type (after filter_and_plot, 'type' holds only the prefix before the first "_")
print(result_df_cat['type'].value_counts())

In [None]:
from Bio import SeqIO
import pandas as pd
import numpy as np
import plotly.express as px

# Load every record of the reference FASTA into memory, keyed by record id
chrom_dict = SeqIO.to_dict(
    SeqIO.parse("/Data1/reference/c_elegans.WS235.genomic.fa", "fasta")
)

# Keep non-autosomal rows only, and remove the 'MEX_rex' category
combined_bed_df_mex_cat = combined_bed_df_mex_cat[
    (combined_bed_df_mex_cat['chr_type'] != 'Autosome')
    & ~combined_bed_df_mex_cat['type'].isin(['MEX_rex'])
]

# Relabel 'MEX_intergenic' as 'MEX_none'
combined_bed_df_mex_cat['type'] = combined_bed_df_mex_cat['type'].replace('MEX_intergenic', 'MEX_none')

# Define a function to get the sequence for each row
def get_sequence(row):
    """Return the reference sequence under one BED row as a plain string.

    The slice ``[bed_start:bed_end]`` is taken from ``chrom_dict[chrom].seq``
    and reverse-complemented when ``bed_strand`` is '-'.
    """
    chrom_name = row['chrom']
    lo, hi = int(row['bed_start']), int(row['bed_end'])
    on_minus_strand = row['bed_strand'] == '-'

    region = chrom_dict[chrom_name].seq[lo:hi]
    if on_minus_strand:
        # Report minus-strand regions in their own 5'->3' orientation
        region = region.reverse_complement()
    return str(region)

# Apply the function to the dataframe to create the 'sequence' column
combined_bed_df_mex_cat['sequence'] = combined_bed_df_mex_cat.apply(get_sequence, axis=1)

# If a sequence has odd length, drop its last character. The position vector built
# further down in this cell (np.arange(-(n // 2), n // 2)) only has n entries when n is even.
combined_bed_df_mex_cat['sequence'] = combined_bed_df_mex_cat['sequence'].apply(lambda x: x[:-1] if len(x) % 2 != 0 else x)

# Optional filter: keep only rows where bed_strand == '+' (currently disabled)
# combined_bed_df_mex_cat = combined_bed_df_mex_cat[combined_bed_df_mex_cat['bed_strand'] == '+']

# Display the first few rows
print("\ncombined_bed_df_mex_cat with sequence:")
print(combined_bed_df_mex_cat.head())

# Count the number of sequences for each combination of 'type' and 'chr_type'
type_chr_counts = combined_bed_df_mex_cat.groupby(['type', 'chr_type']).size().reset_index(name='count')

# Display the table
print("\nNumber of sequences included for each combination of 'type' and 'chr_type':")
print(type_chr_counts)

# Vertical bar plot of those counts, with data labels:
# 'type' on the x-axis, bars colored by 'chr_type'.
# Set the desired order for the 'type' column
type_order = ['MEX_none', 'MEX_D1to5', 'MEX_D6to9', 'MEX_D10']

# Aggregate counts by both 'type' and 'chr_type'.
# type_chr_counts already has one row per (type, chr_type), so this sum leaves the counts unchanged.
type_chr_counts_agg = type_chr_counts.groupby(['type', 'chr_type'], as_index=False).agg({'count': 'sum'})

# Ensure 'type' column is a categorical with the specified order
# (labels that are not in type_order become missing values here)
type_chr_counts_agg['type'] = pd.Categorical(type_chr_counts_agg['type'], categories=type_order, ordered=True)

# Create a MultiIndex for all combinations of 'type' and 'chr_type'
chr_types = type_chr_counts_agg['chr_type'].unique()
all_combinations = pd.MultiIndex.from_product([type_order, chr_types], names=['type', 'chr_type'])

# Convert the DataFrame to use this MultiIndex
type_chr_counts_agg = type_chr_counts_agg.set_index(['type', 'chr_type'])

# Reindex to include all combinations, filling missing values with 0
type_chr_counts_agg = type_chr_counts_agg.reindex(all_combinations, fill_value=0).reset_index()

# Plotting the counts with 'type' on x-axis, colored by 'chr_type'
fig_bar = px.bar(
    type_chr_counts_agg,
    x='type',
    y='count',
    color='chr_type',  # Color by 'chr_type'
    barmode='group',
    text='count',  # Add data labels
    labels={
        'type': 'Type',
        'count': 'Number of Sequences',
        'chr_type': 'Chromosome Type'
    },
    title='Number of Sequences per Type and Chromosome Type'
)

# Update the position of the text labels to appear on top of the bars
fig_bar.update_traces(textposition='outside')

# Adjust the layout and size
fig_bar.update_layout(template='plotly_white', width=900, height=500)

# Update y-axis range: headroom of 300 above the tallest bar of the FIRST trace only,
# so the outside text labels are not clipped.
# NOTE(review): if a later trace is taller than the first one, its bars may be cut off.
fig_bar.update_yaxes(range=[0, fig_bar.data[0].y.max() + 300])

# Display the bar plot
fig_bar.show()




### Bar plot by chr_type:
# Tally the sequences per 'chr_type' as a two-column frame
chr_counts = combined_bed_df_mex_cat['chr_type'].value_counts().reset_index()
chr_counts.columns = ['chr_type', 'count']

# Scale the tallies: X counts are divided by 170, Autosome counts by 830,
# any other label is passed through unchanged
chr_counts['adjusted_count'] = [
    n_seqs / 170 if kind == 'X' else (n_seqs / 830 if kind == 'Autosome' else n_seqs)
    for kind, n_seqs in zip(chr_counts['chr_type'], chr_counts['count'])
]

# Keep one decimal place
chr_counts['adjusted_count'] = chr_counts['adjusted_count'].round(1)

# One bar per 'chr_type', colored by 'chr_type', labelled with the scaled value
fig_bar_chr = px.bar(
    chr_counts,
    x='chr_type',
    y='adjusted_count',
    text='adjusted_count',
    color='chr_type',
    title='Adjusted Number of Sequences per Chromosome Type',
    labels={
        'chr_type': 'Chromosome Type',
        'adjusted_count': 'Motifs / 100kb',
    },
)

# Put the value labels above the bars
fig_bar_chr.update_traces(textposition='outside')

# Same template as the other figures, in a narrower frame
fig_bar_chr.update_layout(template='plotly_white', width=450, height=500)

# Render the figure
fig_bar_chr.show()


# Configurable smoothing parameter (rolling-window length used by smooth_group below).
# NOTE(review): the name window_size is also read by get_cluster_count in the motif-clustering
# cell; running this cell first changes that function's flank to 25.
window_size = 25 # Change this value to adjust the smoothing window size

# Initialize a list to store data for each sequence
sequence_data = []

# Iterate over each row in the DataFrame
for idx, row in combined_bed_df_mex_cat.iterrows():
    seq = row['sequence'].upper()  # Convert sequence to uppercase
    seq_type = row['type']
    chr_type = row['chr_type']
    seq_length = len(seq)

    # Create integer positions centered at 0: -(n // 2) ... (n // 2) - 1.
    # This yields exactly n positions only for even n (odd sequences were trimmed earlier in this cell).
    #positions = np.linspace(- (seq_length - 1) / 2, (seq_length - 1) / 2, seq_length)
    positions = np.arange(-(seq_length // 2), seq_length // 2)
    # Optional reversal of the positions (currently disabled)
    #positions = positions[::-1]
    # Shift all positions by +7.
    # NOTE(review): the previous comment said "Add 8" while the code adds 7 - confirm the intended offset.
    positions = positions + 7

    # Create a DataFrame for each sequence: one row per base
    df_seq = pd.DataFrame({
        'type': seq_type,
        'chr_type': chr_type,
        'position': positions,
        'base': list(seq)
    })

    # Mark bases that are 'C' or 'G' (single-base GC content, not CpG dinucleotides)
    df_seq['is_cg'] = df_seq['base'].apply(lambda x: 1 if x in ['C', 'G'] else 0)

    # Round positions to integer bins for grouping (positions are already integers here)
    df_seq['bin_position'] = df_seq['position'].round().astype(int)

    # Append to the list
    sequence_data.append(df_seq)

# Concatenate all sequence DataFrames
df_all = pd.concat(sequence_data, ignore_index=True)

# Group by 'type', 'chr_type', and 'bin_position' to calculate the average C/G fraction per position
df_avg_cg = df_all.groupby(['type', 'chr_type', 'bin_position'])['is_cg'].mean().reset_index()
df_avg_cg.rename(columns={'bin_position': 'position', 'is_cg': 'average_cg'}, inplace=True)
# Apply smoothing using a rolling mean
def smooth_group(group, window=None):
    """Add a smoothed 'average_cg_smooth' column to one group of positions.

    Args:
        group: DataFrame with 'position' and 'average_cg' columns.
        window: Optional rolling-window length. ``0`` copies 'average_cg'
            unchanged. Defaults to the notebook global ``window_size``.

    Returns:
        A copy of ``group`` sorted by 'position' with the extra column; the
        input frame is left untouched.
    """
    if window is None:
        # Resolved at call time. ``window_size`` is a notebook global that the
        # motif-clustering cell also reads, so pass ``window`` explicitly to be
        # independent of cell execution order.
        window = window_size

    group = group.sort_values('position')  # Ensure positions are in order
    if window == 0:
        group['average_cg_smooth'] = group['average_cg']
    else:
        # Centered rolling mean; min_periods=1 keeps the edges defined
        group['average_cg_smooth'] = group['average_cg'].rolling(
            window=window, center=True, min_periods=1
        ).mean()
    return group

# Apply the smoothing function to each combination of 'type' and 'chr_type'
df_avg_cg_smooth = df_avg_cg.groupby(['type', 'chr_type'], group_keys=False).apply(smooth_group)

# Set the desired order for the 'type' column.
# NOTE(review): 'MEX_D6to9' is not listed here although it is in the bar-plot order above;
# converting to a Categorical turns unlisted labels into missing values - confirm this is intended.
type_order = ['MEX_none', 'MEX_D1to5', 'MEX_D10']
df_avg_cg_smooth['type'] = pd.Categorical(df_avg_cg_smooth['type'], categories=type_order, ordered=True)

# Plot using Plotly: one line per type, dash pattern by chromosome type
fig = px.line(
    df_avg_cg_smooth,
    x='position',
    y='average_cg_smooth',
    color='type',
    line_dash='chr_type',
    labels={
        'position': 'Position (centered at 0)',
        'average_cg_smooth': 'Smoothed Average CG%',
        'type': 'Type',
        'chr_type': 'Chromosome Type'
    },
    title='Smoothed Average CG% at Each Position (Centered) by Type and Chromosome Type',
    category_orders={'type': type_order}  # Explicitly set the order
)

# Set x-axis range
fig.update_xaxes(range=[-400, 400])
# Set y-axis range to 0.27-0.63
fig.update_yaxes(range=[0.27, 0.63])

fig.update_layout(template='plotly_white')
fig.show()


In [None]:
### Save MEX file for structure determination:
# Export the CHROMOSOME_X regions of one category from combined_bed_df_mex_cat as a
# headerless, tab-separated BED file (chrom, start, end), with "CHROMOSOME_" rewritten to "chr".
mex_bed_path = "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/082624/result_df_MEX_D1to5_250bp.bed"

result_df_copy = combined_bed_df_mex_cat.copy()
# print unique types
print(result_df_copy['type'].unique())

# Keep only CHROMOSOME_X rows. Copy so the column assignment below acts on an
# independent frame rather than on a filtered slice.
result_df_copy = result_df_copy[result_df_copy['chrom'] == "CHROMOSOME_X"].copy()

#result_df_copy = result_df.copy()
result_df_copy['chrom'] = result_df_copy['chrom'].str.replace("CHROMOSOME_", "chr")
print(result_df_copy['chrom'].unique())

# Select the category to export. Selections used previously:
#result_df_copy = result_df_copy[result_df_copy['type'].isin(['SDC2_D10'])]
#['SDC2_D1','SDC2_D2','SDC2_D3','SDC2_D4','SDC2_D5']
#['SDC2_D6','SDC2_D7','SDC2_D8','SDC2_D9']
#['SDC2_D10']
result_df_copy = result_df_copy[result_df_copy['type'].isin(['MEX_D1to5'])]

# Create the output directory first; to_csv does not create missing directories
os.makedirs(os.path.dirname(mex_bed_path), exist_ok=True)
result_df_copy[['chrom', 'bed_start', 'bed_end']].to_csv(mex_bed_path, sep="\t", header=False, index=False)

# print
print(result_df_copy.head())