In [1]:
# Purpose of this script is to perform an unbiased analysis of Fiber-seq and Chip-seq data

__author__ = "Yuri Malina"
__contact__ = "ymalina@berkeley.edu"
__copyright__ = "The Meyer Lab, UC Berkeley"
__credits__ = [""]
__date__ = "6/23/2024"
__deprecated__ = False
__status__ = "In development"
__version__ = "0.0.1"

### TO DOS:

In [2]:
import importlib
import nanotools
importlib.reload(nanotools) # reload nanotools module
import numpy as np
import pandas as pd
import plotly.io as pio
import plotly
import plotly.express as px # Used for plotting
import plotly.graph_objects as go # Used for plotting
import multiprocessing
from multiprocessing import Pool, cpu_count # used for parallel processing
import subprocess
import os
import pywt # for wavelet transform
import matplotlib.pyplot as plt # Use for plotting m6A frac and coverage plot
from matplotlib import cm # Use for plotting m6A frac and coverage plot

#import tqdm
#import pysam
#import pyBigWig

# Plotly: combined renderer (mimetype output for JupyterLab / VS Code plus classic-notebook output)
plotly.offline.init_notebook_mode(connected=True)
pio.renderers.default = 'plotly_mimetype+notebook'

# pandas display: show every row and column of displayed DataFrames without truncation.
# NOTE: with max_rows=None a large frame is printed in full; prefer .head()/.sample() when displaying.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
# Do not truncate long cell contents (e.g. file paths)
pd.set_option('display.max_colwidth', None)

In [None]:
### Bed file configurations:
sample_source = "type" # "chr_type" or "type" or "chromosome"
chr_type_selected = ["X","Autosome"] # 'X' or "Autosome"
type_selected = ["gene_q4","gene_q1"]
max_regions = 1000 # max regions to consider; 0 = full set;
chromosome_selected = ["CHROMOSOME_V","CHROMOSOME_X"] #"CHROMOSOME_I", "CHROMOSOME_II", "CHROMOSOME_III", "CHROMOSOME_IV","CHROMOSOME_V",
strand_selected = ["+","-"] #+ and/or -
select_opp_strand = True #If you want to select both + and - strands for all regions set to True
down_sample_autosome = False # If you want to downsample autosome genes to match number of X genes set to True
if chr_type_selected == ["X"]:
    # Only X selected: there are no autosome genes to downsample.
    down_sample_autosome = False
bed_file = "/Data1/reference/tss_tes_rex_combined_v20_WS235.bed"
bed_window = 1000   # +/- around bed elements.
intergenic_window = 2000 # +/- around intergenic regions
num_bins = 1000 #bins for metagene plot
mods = "a" # NOTE(review): options were listed as {A,CG,A+CG} but the value is "a"; confirm expected values

# Pick the selection list matching sample_source. Fail fast on a typo: previously an unknown
# value left `selection` undefined (NameError later) or silently reused a stale kernel value.
_selection_by_source = {
    "chr_type": chr_type_selected,
    "type": type_selected,
    "chromosome": chromosome_selected,
}
if sample_source not in _selection_by_source:
    raise ValueError(
        f"sample_source must be one of {sorted(_selection_by_source)}, got {sample_source!r}"
    )
selection = _selection_by_source[sample_source]

# Filter input bed_file based on input parameters (e.g. chromosome, type, strand, etc.)
# Function saves a new filtered bed file to the same folder as the original bed file
# called temp_do_not_use_"type".bed
importlib.reload(nanotools)
new_bed_files=nanotools.filter_bed_file(
    bed_file,
    sample_source,
    selection,
    chromosome_selected,
    chr_type_selected,
    type_selected,
    strand_selected,
    max_regions,
    bed_window,
    intergenic_window
)

modkit_bed_name = "modkit_temp.bed"
modkit_bed_df = nanotools.generate_modkit_bed(new_bed_files, down_sample_autosome, select_opp_strand,modkit_bed_name)
nanotools.display_sample_rows(modkit_bed_df, 5)

In [None]:
### BAM Configurations
# Thresholds are given as fractions, converted to integers on the 0-258 scale used throughout
# this notebook, and converted back to fractions in thresh_list below.
# NOTE(review): MM/ML modification probabilities are normally encoded 0-255; confirm 258 is intended.
R9_m6A_thresh_percent = 0.8
R10_m6A_thresh_percent = 0.8
R10_5mC_thresh_percent = 0.8 # Note: 0.7 in R9 ~ 0.9 in R10
R9_m6A_thresh = int(round(R9_m6A_thresh_percent*258,0)) # 129 = 50%; 181 = 70%; 194 = 75%; 206 = 80%; 232 = 90%
m6A_thresh = int(round(R10_m6A_thresh_percent*258,0))
mC_thresh = int(round(R10_5mC_thresh_percent*258,0))
print("R9_m6A_thresh: ", R9_m6A_thresh)
print("m6A_thresh: ", m6A_thresh)
print("mC_thresh: ", mC_thresh)

# modkit is used for aggregating methylation data from .bam files
# https://nanoporetech.github.io/modkit/quick_start.html
modkit_path = "/Data1/software/modkit_v0.3/modkit"
bedgraphtobigwig_path = "/Data1/software/ucsc_genome_browser/bedGraphToBigWig"
danpos_path = "/Data1/software/DANPOS3/danpos.py"
chrom_sizes = "/Data1/reference/chrom.sizes.ce11.txt"

analysis_cond = ["N2_old_SMACseq_R10","N2_old_fiber_R10","N2_young_SMACseq_R10","96_old_DPY27degron_SMACseq_R10","SDC2_degron_mixed_fiber_R10"]
    #,"N2_mixed_fiber_R10","N2_mixed_fiber_R9","SDC2_degron_mixed_fiber_R9","N2_young_fiber_R9","N2_mixed_endogenous_R9","N2_old_fiber_R10","51_old_dpy21null_fiber_R10","52_old_dpy21jmjc_fiber_R10","SDC2_degron_mixed_fiber_R10"]

### IMPORT BAM FILES AND METADATA FROM CSV FILE

input_metadata = pd.read_csv("/Data1/git/meyer-nanopore/scripts/bam_input_metadata_3_25_2024.txt", sep="\t", header=0)
# Metadata rows for the selected conditions, filtered once and reused for every column below
selected_metadata = input_metadata[input_metadata["conditions"].isin(analysis_cond)]
missing_conditions = sorted(set(analysis_cond) - set(selected_metadata["conditions"]))
if missing_conditions:
    print("WARNING: conditions not found in BAM metadata:", missing_conditions)
if selected_metadata.empty:
    raise ValueError("None of the conditions in analysis_cond were found in the BAM metadata table")
bam_files = selected_metadata["bam_files"].tolist()
conditions = selected_metadata["conditions"].tolist()
exp_ids = selected_metadata["exp_id_date"].tolist()
flowcells = selected_metadata["flowcell"].tolist()
bam_fracs = [1] * len(bam_files) # For full .bam set to = 1
sample_indices = list(range(len(bam_files)))

### Import CHIP SEQ / EXTERNAL DATA
ext_target = [] # h3_chip, sdc2_chip, sdc3_chip, dhs, gro, mnase, h4k20me1_chip, ama1_chip, "dpy27_chip","ama1_chip","mnase","gro"
#"sdc2_chip_albritton","sdc3_chip_albritton","sdc3_chip_anderson","dpy27_chip_anderson"
ext_metadata = pd.read_csv("/Data1/git/meyer-nanopore/scripts/bw_input_metadata_2_21_2024.txt", sep="\t", header=0)
selected_ext_metadata = ext_metadata[ext_metadata["target"].isin(ext_target)]
ext_exp_ids = selected_ext_metadata["exp_id_date"].tolist()
ext_files = selected_ext_metadata["bw_files"].tolist()
ext_targets = selected_ext_metadata["target"].tolist()

output_stem = "/Data1/git/meyer-nanopore/scripts/analysis/result/"
# for dimelo: [181/258,194/258]
# Per-BAM m6A threshold as a fraction: R9 flow cells use R9_m6A_thresh, all others m6A_thresh (R10).
thresh_list = [
    (R9_m6A_thresh if "R9" in str(flowcell) else m6A_thresh) / 258
    for flowcell in flowcells
]

file_prefix = "jul_2024"

# Per-sample argument tuples (bam, condition, bam_frac, sample_index, output_stem, exp_id).
# Previously zipped with ext_exp_ids (the external/bigwig ids, empty when ext_target is empty),
# so this list was always empty or misaligned with the BAMs.
args_list = [
    (bam_file, condition, bam_frac, sample_index, output_stem, exp_id)
    for bam_file, condition, bam_frac, sample_index, exp_id
    in zip(bam_files, conditions, bam_fracs, sample_indices, exp_ids)
]
# Subsample bam based on bam_frac, used to accelerate testing
# if bam_frac = 1 will use original bam files, otherwise will save new subsampled bam files to output_stem.
new_bam_files = nanotools.parallel_subsample_bam(bam_files, conditions, bam_fracs, sample_indices, output_stem)

print("Program finished!")
print("new_bam_files: ", new_bam_files)
print("exp_ids: ", exp_ids)

In [None]:
### Calculate modkit sample-probs

# For each BAM, estimate the modification-probability distribution with `modkit sample-probs`.
# Output goes into <output_stem>/sample_probs/, the folder the cache check below and the reading
# cells use. Existing "<condition>-<exp_id>_thresholds.tsv" files are reused unless force_replace is True.
force_replace = False

def sample_probs_bam(arg, force_replace=False):
    """Run `modkit sample-probs` for one BAM unless its thresholds file already exists.

    Args:
        arg: tuple (bam_path, condition, bam_frac, sample_index, output_stem, exp_id);
            bam_frac and sample_index are unpacked but not used here.
        force_replace: recompute even if "<condition>-<exp_id>_thresholds.tsv" exists.

    Returns:
        None. modkit writes its outputs into <output_stem>/sample_probs/.
    Reads the notebook global modkit_path.
    """
    each_bam, each_condition, each_bamfrac, each_index, each_output_stem, each_exp_id = arg
    prefix = each_condition + "-" + each_exp_id
    # Write where we check/read; modkit was previously pointed at output_stem itself,
    # so this cache check never found the files it produced.
    out_dir = os.path.join(each_output_stem, "sample_probs")
    each_output = os.path.join(out_dir, prefix + "_thresholds.tsv")
    if not force_replace and os.path.exists(each_output):
        print(f"File already exists: {each_output}")
        return

    os.makedirs(out_dir, exist_ok=True)
    print(f"Starting on: {each_output}")
    command = [
        modkit_path,
        "sample-probs",
        each_bam,
        "-t",
        "10",
        "-o",
        out_dir,
        "--prefix",
        prefix,
        "--percentiles",
        "0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9",
        "--force",
        "--hist"
    ]
    completed = subprocess.run(command, text=True)
    if completed.returncode != 0:
        # Report instead of failing silently; the reader cell will then raise on the missing file.
        print(f"WARNING: modkit sample-probs exited with code {completed.returncode} for {each_bam}")

# One argument tuple per sample, in the order sample_probs_bam unpacks it.
sample_probs_args = list(zip(
    new_bam_files,
    conditions,
    bam_fracs,
    sample_indices,
    [output_stem] * len(new_bam_files),
    exp_ids,
))
# starmap calls sample_probs_bam(arg, force_replace) for each pair. force_replace must travel
# inside each pair: starmap's third positional argument is chunksize, and chunksize=False (== 0)
# makes the pool dispatch no tasks at all. (The old (args, index) pairs also bound the index to force_replace.)
with Pool(processes=15) as pool:
    pool.starmap(sample_probs_bam, [(each_args, force_replace) for each_args in sample_probs_args])

# Read all thresholds files into a single dataframe for downstream analysis
sample_probs_frames = []
for each_condition, each_exp_id in zip(conditions, exp_ids):
    each_output = os.path.join(output_stem, "sample_probs", each_condition + "-" + each_exp_id + "_thresholds.tsv")
    # Whitespace-delimited table (sep=r"\s+" replaces the deprecated delim_whitespace=True)
    each_df = pd.read_csv(each_output, sep=r"\s+", header=0)
    print(f"Columns in {each_output}: {each_df.columns.tolist()}")
    sample_probs_frames.append(each_df.assign(condition=each_condition, exp_id=each_exp_id))
if not sample_probs_frames:
    raise ValueError("No samples selected; nothing to read from sample_probs/")
# Concatenate once instead of growing the frame inside the loop
sample_probs_df = pd.concat(sample_probs_frames, ignore_index=True)
nanotools.display_sample_rows(sample_probs_df, 10)

print(sample_probs_df.columns)

# Round percentile to whole numbers.
# NOTE(review): if modkit reports percentiles as fractions (0.1..0.9) this collapses them to 0/1 — confirm the scale.
sample_probs_df['percentile'] = sample_probs_df['percentile'].round(0)

# Threshold by percentile for all conditions + exp_ids, one figure per canonical base (C, then A)
for base in ("C", "A"):
    fig = px.line(
        sample_probs_df[sample_probs_df['base'] == base],
        x="percentile",
        y="threshold",
        color="condition",
        line_group="exp_id",
        hover_name="exp_id",
        title="Threshold by Percentile for all conditions and exp_ids",
    )
    fig.show(renderer='plotly_mimetype+notebook')



In [None]:
# Plot the `modkit sample-probs --hist` histograms written to the sample_probs folder
import glob
import re

# Set the default theme to white
pio.templates.default = "plotly_white"

# Folder holding the "*_probabilities.txt" histogram files
# (previously assigned twice; the first, parent-folder value was dead).
FOLDER_PATH = "/Data1/git/meyer-nanopore/scripts/analysis/result/sample_probs"

def parse_probability_file(file_path):
    """Parse a modkit sample-probs histogram file into a tidy DataFrame.

    The text is split on "# code" headers. In each section the first line is the
    modification code, a "Number of samples = N" line gives that section's total, and
    every line containing "..", "[" and "]" is a bucket: its first whitespace token is
    the threshold and the integer inside the brackets is the count.

    Args:
        file_path: path to a "<condition>-<exp_id>_probabilities.txt" file.

    Returns:
        DataFrame with columns code, threshold, count (bucket count divided by the
        N of the same section, or the raw count if that section has no N) and
        condition_exp_id (file name without "_probabilities.txt"). Empty if nothing parsed.
    """
    print(f"Parsing file: {file_path}")
    with open(file_path, 'r') as f:
        content = f.read()

    data = []
    for section in re.split(r'# code', content)[1:]:  # one section per modification code
        lines = section.strip().split('\n')
        current_code = lines[0].strip()
        # Reset per section: previously the total of an earlier code leaked into later
        # sections without their own "Number of samples" line, normalising by the wrong N.
        number_of_samples = None
        print(f"Processing code: {current_code}")

        for line in lines:
            if "Number of samples =" in line:
                number_of_samples = int(line.split('=')[1].strip())
                print(f"Number of samples: {number_of_samples}")
            elif '..' in line and '[' in line and ']' in line:
                try:
                    threshold = float(line.split()[0])
                    # Extract count directly from brackets
                    count_match = re.search(r'\[(.*?)\]', line)
                    if not count_match:
                        print(f"Warning: Could not find count in line: {line}")
                        continue
                    count = int(count_match.group(1).strip())

                    if number_of_samples:
                        normalized_count = count / number_of_samples
                    else:
                        print(f"Warning: No number of samples found for code {current_code}")
                        normalized_count = count

                    data.append({'code': current_code, 'threshold': threshold, 'count': normalized_count})
                    print(f"Parsed data point: threshold={threshold}, normalized_count={normalized_count}")
                except (IndexError, ValueError) as e:
                    print(f"Warning: Error parsing line: {line}")
                    print(f"Error details: {str(e)}")

    # Files are named "<condition>-<exp_id>_probabilities.txt"; the label is that stem
    # (splitting on '-' and re-joining reproduced the same string).
    condition_exp_id = os.path.basename(file_path).split('_probabilities.txt')[0]

    df = pd.DataFrame(data)
    if not df.empty:
        df['condition_exp_id'] = condition_exp_id
    else:
        print(f"Warning: No valid data extracted from file {file_path}")
    return df

def main():
    """Plot the normalised sample-probs histograms of every *_probabilities.txt in FOLDER_PATH.

    Produces one figure per modification code in ('a', 'C', 'm', 'A'), one line per file.
    Codes with no data are reported and skipped.
    """
    file_pattern = os.path.join(FOLDER_PATH, "*_probabilities.txt")
    # glob order is arbitrary; sort so trace/legend order is reproducible between runs
    files = sorted(glob.glob(file_pattern))

    if not files:
        print(f"No '*_probabilities.txt' files found in {FOLDER_PATH}")
        return

    # Combine all data into a single DataFrame
    all_data = pd.concat([parse_probability_file(f) for f in files], ignore_index=True)

    if all_data.empty:
        print("No valid data found in the files. Please check the file contents and the console output for detailed warnings.")
        return

    print(f"Total data points collected: {len(all_data)}")

    codes = ['a', 'C', 'm', 'A']
    for code in codes:
        code_data = all_data[all_data['code'] == code]
        if code_data.empty:
            print(f"No data found for code {code}")
            continue

        print(f"Plotting data for code {code}")
        fig = go.Figure()

        for condition in code_data['condition_exp_id'].unique():
            condition_data = code_data[code_data['condition_exp_id'] == condition]
            fig.add_trace(
                go.Scatter(
                    x=condition_data['threshold'],
                    y=condition_data['count'],
                    mode='lines',
                    name=f'{condition}',
                )
            )

        fig.update_layout(
            title=f"Probability Analysis - Code {code}",
            xaxis_title="Threshold",
            yaxis_title="Normalized Count",
            legend_title="Condition + Exp ID",
            height=600,
            width=1000,
        )

        fig.show()

    print("Analysis complete. Plots displayed in the notebook.")

if __name__ == "__main__":
    main()

In [None]:
### Calculate bam summary statistics
import hashlib  # builds the cache-file key below

importlib.reload(nanotools)
force_replace = False
sampling_frac = 0.1 # fraction of bam to sample for summary statistics


def process_bam(args):
    """Summarise one BAM with nanotools.get_summary_from_bam (runs in a worker process).

    Args:
        args: tuple (bam_path, condition, m6A threshold fraction, exp_id).

    Returns:
        The result of nanotools.get_summary_from_bam; the caller concatenates these
        with pd.concat, so presumably one DataFrame per BAM.
    Reads the notebook globals sampling_frac and modkit_path.
    """
    each_bam, each_condition, each_thresh, each_exp_id = args
    print("starting on bam:", each_bam," | condition:", each_condition,"| each_exp_id:",each_exp_id, "with thresh:", each_thresh)
    return nanotools.get_summary_from_bam(sampling_frac, each_thresh, modkit_path, each_bam, each_condition, each_exp_id, thread_ct=50)


# Cache-file name. First/last condition keep it readable; the short hash covers every selected
# BAM/condition/exp_id/threshold. Before, changing a middle sample silently reloaded a stale table.
_summary_signature = "|".join(
    f"{bam}:{cond}:{exp}:{thr}"
    for bam, cond, exp, thr in zip(new_bam_files, conditions, exp_ids, thresh_list)
)
_summary_hash = hashlib.sha256(_summary_signature.encode("utf-8")).hexdigest()[:8]
summary_table_name = (
    "temp_files/" + "_" + conditions[0] + conditions[-1] + "_" + str(sampling_frac)
    + "_thresh" + str(thresh_list[0]) + "_" + _summary_hash + "_summary_table.csv"
)

if not force_replace and os.path.exists(summary_table_name):
    print("Summary table exists, importing...")
    summary_bam_df = pd.read_csv(summary_table_name, sep="\t", header=0)
else:
    print("Summary table does not exist, creating...")
    # The context manager shuts the pool down even if a worker raises.
    with multiprocessing.Pool(3) as pool:
        results = pool.map(process_bam, zip(new_bam_files, conditions, thresh_list, exp_ids))

    summary_bam_df = pd.concat(results, ignore_index=True)
    display(summary_bam_df.head(3))

    # to_csv does not create missing folders
    os.makedirs(os.path.dirname(summary_table_name), exist_ok=True)
    summary_bam_df.to_csv(summary_table_name, sep="\t", header=True, index=False)

### Plot m6A frac and coverage by condition
import hashlib  # builds the cache-file key below

# Genome size used to turn A counts into fold coverage. 1e8 approximates the
# C. elegans genome (100,272,763 bp).
APPROX_GENOME_SIZE = 100000000

# Cache-file name; the hash covers every selected sample so a changed middle sample is not
# served from a stale table (first/last condition alone collided).
_coverage_signature = "|".join(
    f"{bam}:{cond}:{exp}:{thr}"
    for bam, cond, exp, thr in zip(new_bam_files, conditions, exp_ids, thresh_list)
)
_coverage_hash = hashlib.sha256(_coverage_signature.encode("utf-8")).hexdigest()[:8]
coverage_df_name = (
    "temp_files/" + "_" + conditions[0] + conditions[-1] + "_" + str(sampling_frac)
    + "_thresh" + str(thresh_list[0]) + "_" + _coverage_hash + "_coverage_df.csv"
)

if not force_replace and os.path.exists(coverage_df_name):
    print("Coverage table exists, importing...")
    coverage_df = pd.read_csv(coverage_df_name, sep="\t", header=0)
else:
    nanotools.display_sample_rows(summary_bam_df, 2)
    group_keys = ['exp_id', 'condition']

    def _total_pass_count(mask, column_name):
        """Sum pass_count per (exp_id, condition) over the summary rows selected by mask."""
        return (
            summary_bam_df.loc[mask]
            .groupby(group_keys)['pass_count'].sum()
            .reset_index()
            .rename(columns={'pass_count': column_name})
        )

    total_m6a = _total_pass_count(summary_bam_df['code'] == 'a', 'total_m6a')
    total_5mc = _total_pass_count(summary_bam_df['code'] == 'm', 'total_5mc')
    # code '-' rows are the unmodified calls for each canonical base
    total_A = _total_pass_count((summary_bam_df['base'] == 'A') & (summary_bam_df['code'] == '-'), 'total_A')
    total_C = _total_pass_count((summary_bam_df['base'] == 'C') & (summary_bam_df['code'] == '-'), 'total_C')

    coverage_df_A = pd.merge(total_m6a, total_A, on=group_keys, how='outer').fillna(0)
    coverage_df_C = pd.merge(total_5mc, total_C, on=group_keys, how='outer').fillna(0)
    coverage_df = pd.merge(coverage_df_A, coverage_df_C, on=group_keys, how='outer').fillna(0)

    # Scale sampled A calls back to the full BAM; * 4 since As are ~1/4 of the genome
    coverage_df['coverage'] = ((coverage_df['total_A'] + coverage_df['total_m6a']) * (1/sampling_frac)) / APPROX_GENOME_SIZE * 4

    coverage_df['total_A_m6a'] = coverage_df['total_A'] + coverage_df['total_m6a']
    coverage_df['total_C_5mc'] = coverage_df['total_C'] + coverage_df['total_5mc']

    coverage_df['m6A_frac'] = coverage_df['total_m6a'] / coverage_df['total_A_m6a']
    coverage_df['5mC_frac'] = coverage_df['total_5mc'] / coverage_df['total_C_5mc']

    display(coverage_df)
    os.makedirs(os.path.dirname(coverage_df_name), exist_ok=True)
    coverage_df.to_csv(coverage_df_name, sep="\t", header=True, index=False)

def generate_color_map(n, cmap):
    """Return n hex colours from cmap, evenly spaced and avoiding both ends of the colormap."""
    # Normalize(-1, n) maps i = 0..n-1 to (i + 1) / (n + 1), strictly inside (0, 1)
    norm = plt.Normalize(-1, n)
    return [cm.colors.to_hex(cmap(norm(i))) for i in range(n)]

# Strip exp_id first so colour indices are computed on the same labels the plot groups by
# (previously stripping happened after the colours were assigned).
coverage_df['exp_id'] = coverage_df['exp_id'].str.strip()

unique_conditions = coverage_df['condition'].unique()

# Colormap family per condition; unlisted conditions fall back to viridis
condition_colormaps = {
    'N2_fiber': cm.Blues,
    'SDC2_degron_fiber': cm.Greens,
    'N2_bg': cm.Reds,
    'N2_fiber_R10': cm.viridis,
    'N2-DPY27_dimelo_pAHia5': cm.plasma,
    'N2-DPY27_dimelo_RbNbHia5': cm.viridis,
    '50_dpy27dimelo_mcvipi': cm.inferno
    # Add more conditions and their corresponding colormaps here
}

# One colour per exp_id within each condition's family
color_families = {}
for condition in unique_conditions:
    n_exp_ids = len(coverage_df[coverage_df['condition'] == condition]['exp_id'].unique())
    color_families[condition] = generate_color_map(n_exp_ids, condition_colormaps.get(condition, cm.viridis))

# Map each exp_id to its colour via its category code within the condition
coverage_df['custom_color'] = coverage_df.groupby('condition')['exp_id'].transform(lambda x: x.astype('category').cat.codes)
coverage_df['custom_color'] = coverage_df.apply(lambda row: color_families[row['condition']][row['custom_color']], axis=1)

display(coverage_df)

# Stacked bar plot: one trace per exp_id, stacked within each condition
fig = go.Figure()
for condition, condition_df in coverage_df.groupby('condition'):
    for exp_id, exp_df in condition_df.groupby('exp_id'):
        fig.add_trace(
            go.Bar(
                x=[condition],
                y=[exp_df['coverage'].iloc[0]],  # Assuming only one row per exp_id per condition
                name=exp_id,
                marker=dict(color=exp_df['custom_color'].iloc[0]),
                legendgroup=exp_id
            )
        )

fig.update_layout(
    title='Stacked Bar Plot of Coverage by Condition',
    xaxis_title='Condition',
    yaxis_title='Coverage',
    barmode='stack',
    legend_title="exp_id"
)

# autosize off with no explicit width/height: plotly uses its default figure size
fig.update_layout(
    autosize=False,
    #width=600,
    #height=700,
    template="simple_white"
)

# Group data by 'condition'
grouped = coverage_df.groupby('condition')


def _coverage_weighted_box(groups, frac_col, title, yaxis_title):
    """Box plot of frac_col per condition, each exp_id weighted by its coverage.

    Each exp_id's fraction is repeated ceil(coverage) times so higher-coverage samples
    weigh more; the median of those weighted points is annotated on each box.
    """
    box_fig = go.Figure()
    for name, group in groups:
        # ceil *before* the int cast: casting first truncated (e.g. 0.6x -> 0 copies),
        # which dropped low-coverage samples and gave an empty box / nan median.
        repeats = np.ceil(group['coverage'].fillna(0).clip(lower=0)).astype(int)
        weighted_points = np.repeat(group[frac_col].to_numpy(), repeats.to_numpy())
        box_fig.add_trace(go.Box(
            y=weighted_points,
            name=name,
            boxmean=True,
            boxpoints="all",  # draw every weighted point alongside the box
            jitter=0.3,       # Add some jitter for visibility
            pointpos=0,       # Points centred on the box
            marker_size=2,
            marker_opacity=0.5,
            fillcolor='rgba(0,0,0,0)'  # Transparent fill
        ))
        if weighted_points.size:
            weighted_median = float(np.median(weighted_points))
            box_fig.add_annotation(
                x=name,
                y=weighted_median,
                text=f"Median: {weighted_median:.2%}",
                arrowhead=1,
                ax=0,
                ay=-10
            )

    box_fig.update_layout(
        title=title,
        xaxis_title='Condition',
        yaxis_title=yaxis_title,
        template="simple_white",
        yaxis=dict(range=[0, 0.4])
    )
    box_fig.update_yaxes(tickformat='.0%')
    return box_fig


fig2 = _coverage_weighted_box(grouped, 'm6A_frac', 'Distribution of m6A by Condition', 'm6A frac')
fig3 = _coverage_weighted_box(grouped, '5mC_frac', 'Distribution of 5mC frac by Condition', '5mC frac')

fig3.show(renderer='plotly_mimetype+notebook')
fig2.show(renderer='plotly_mimetype+notebook')
fig.show(renderer='plotly_mimetype+notebook')

# Disabled export / N50 code, kept as comments (a bare string here was echoed as cell output).
# NOTE(review): the export file names look swapped — `fig` is the coverage bar plot, `fig2` the m6A box plot.
# fig.write_image("images_11_14_23/bulk_m6Afrac_n2_sdc2degron_0p1sample.svg")
# fig.write_image("images_11_14_23/bulk_m6Afrac_n2_sdc2degron_0p1sample.png")
# fig2.write_image("images_11_14_23/coverage_n2_sdc2degron_0p1sample.svg")
# fig2.write_image("images_11_14_23/coverage_n2_sdc2degron_0p1sample.png")
#
# ### Calculate N50s SKIP, NOT NECESSARY FOR ANY FOLLOWING STEPS
# n50_fig = nanotools.calculate_and_plot_n50(new_bam_files, conditions, exp_ids)
# n50_fig.show(renderer='plotly_mimetype+notebook')
# n50_fig.write_image("images_11_14_23/n50_fig_n2_sdc2degron_0p1sample.svg")
# n50_fig.write_image("images_11_14_23/n50_fig_n2_sdc2degron_0p1sample.png")

In [None]:
### Generate bedgraph from bam files

# True forces regeneration of the pileup bedgraphs; False reuses existing outputs.
regenerate_bit = False
num_processors = 15  # threads handed to modkit and sort for each BAM

# Pileup outputs are written next to each BAM, so collect each BAM's folder.
input_bam_paths = list(map(os.path.dirname, new_bam_files))

# Function to run a single command
def modkit_pileup_extract(args):
    """Run `modkit pileup --bedgraph` on one BAM and merge its per-strand outputs.

    Args:
        args: tuple (each_bam, each_thresh, each_condition, each_index, each_bamfrac,
            each_expid, each_type, modkit_path, output_stem, modkit_bed_name,
            num_processors). each_index, each_bamfrac and each_type are unpacked but
            unused. output_stem is the folder the bedgraphs are written to.

    Produces in output_stem, prefixed "<expid>-<condition>_":
        a_A0.bedgraph        code a at A (offset 0), both strands, sorted
        m_GC1.bedgraph       code m at GC (offset 1), both strands, sorted
        a_A0_m_GC1.bedgraph  both of the above, sorted
    The sample is skipped if a_A0_m_GC1.bedgraph exists and the notebook global
    `regenerate_bit` is False.

    Raises:
        subprocess.CalledProcessError: if a merge/sort pipeline fails.
    """
    import shlex  # quotes paths interpolated into the shell pipelines below

    (each_bam, each_thresh, each_condition, each_index, each_bamfrac, each_expid,
     each_type, modkit_path, output_stem, modkit_bed_name, num_processors) = args

    sample_prefix = f"{each_expid}-{each_condition}"
    strand_tags = ("", "_positive", "_negative", "_combined")

    def bedgraph_path(suffix):
        # e.g. <output_stem>/<expid>-<condition>_a_A0.bedgraph
        return os.path.join(output_stem, f"{sample_prefix}_{suffix}.bedgraph")

    def remove_sample_files(suffixes):
        # Only this sample's files are touched. Several workers can share output_stem,
        # and wiping every *.bedgraph there deleted other samples' finished outputs.
        for suffix in suffixes:
            path = bedgraph_path(suffix)
            if os.path.exists(path):
                print("Deleting file: ", path)
                os.remove(path)

    # Code/motif pairings that are not wanted (removed after the run)
    unwanted_outputs = [f"{combo}{tag}" for combo in ("a_GC1", "m_A0", "a_CC0") for tag in strand_tags]
    all_outputs = (
        [f"{combo}{tag}" for combo in ("a_A0", "m_GC1") for tag in strand_tags]
        + unwanted_outputs
        + ["a_A0_m_GC1"]
    )

    final_output = bedgraph_path("a_A0_m_GC1")
    if regenerate_bit:
        remove_sample_files(all_outputs)
    else:
        print("Checking if file exists: ", final_output)
        if os.path.exists(final_output):
            print(f"File already exists: {final_output}")
            return
        # No final file: clear leftovers of an interrupted run (previously only printed, never deleted).
        remove_sample_files(all_outputs)

    print("Starting on bam file: ", each_bam, "and bedfile:", modkit_bed_name)
    # NOTE(review): the m (5mC) threshold reuses each_thresh, i.e. the m6A value from thresh_list;
    # mC_thresh from the BAM config cell is not used here — confirm this is intended.
    # (--include-bed, --combine-strands and --filter-threshold were tried earlier and are disabled.)
    command = [
        modkit_path,
        "pileup",
        "--threads",
        f"{num_processors}",
        "--bedgraph",
        "--mod-thresholds",
        f"a:{each_thresh}",
        "--mod-thresholds",
        f"m:{each_thresh}",
        "--ref",
        "/Data1/reference/c_elegans.WS235.genomic.fa",
        "--motif",
        "GC",
        "1",
        "--motif",
        "A",
        "0",
        "--prefix",
        sample_prefix,
        each_bam,
        output_stem
    ]
    result = subprocess.run(command, text=True)
    if result.returncode != 0:
        print(f"modkit pileup failed (exit code {result.returncode}) for {each_bam}; skipping merge.")
        return

    remove_sample_files(unwanted_outputs)

    # A BAM without 5mC calls yields no m_GC1 strand files; write one-line placeholders so the merge still runs.
    for placeholder_suffix in ("m_GC1_positive", "m_GC1_negative"):
        placeholder = bedgraph_path(placeholder_suffix)
        if not os.path.exists(placeholder):
            with open(placeholder, "w") as f:
                f.write("\n")

    # (first input, second input, merged output, delete inputs afterwards?)
    merge_plan = [
        ("a_A0_negative", "a_A0_positive", "a_A0", True),
        ("m_GC1_negative", "m_GC1_positive", "m_GC1", True),
        ("a_A0", "m_GC1", "a_A0_m_GC1", False),
    ]
    for first, second, merged, delete_inputs in merge_plan:
        inputs = [bedgraph_path(s) for s in (first, second) if os.path.exists(bedgraph_path(s))]
        merged_path = bedgraph_path(merged)
        if not inputs:
            # `cat` with no file arguments would block reading stdin
            print(f"Warning: no inputs found for {merged_path}; writing an empty file.")
            open(merged_path, "w").close()
            continue
        quoted_inputs = " ".join(shlex.quote(p) for p in inputs)
        # C collation gives the byte-order sort bedGraphToBigWig expects, independent of locale
        pipeline = (
            f"cat {quoted_inputs} | LC_ALL=C sort -k1,1 -k2,2n --parallel={int(num_processors)}"
            f" > {shlex.quote(merged_path)}"
        )
        subprocess.run(pipeline, shell=True, check=True)
        if delete_inputs:
            for path in inputs:
                os.remove(path)

    
    
# One argument tuple per BAM, in the field order modkit_pileup_extract unpacks.
# Outputs are written next to each BAM (input_bam_paths), not to the global output_stem.
n_bams = len(new_bam_files)
per_sample_lists = {
    "thresh_list": thresh_list,
    "conditions": conditions,
    "sample_indices": sample_indices,
    "bam_fracs": bam_fracs,
    "exp_ids": exp_ids,
    "input_bam_paths": input_bam_paths,
}
# zip() would silently drop BAMs if any per-sample list disagreed in length
mismatched_lengths = {name: len(values) for name, values in per_sample_lists.items() if len(values) != n_bams}
if mismatched_lengths:
    raise ValueError(f"per-sample lists differ in length from new_bam_files ({n_bams}): {mismatched_lengths}")

task_args = list(zip(
    new_bam_files,
    thresh_list,
    conditions,
    sample_indices,
    bam_fracs,
    exp_ids,
    [type_selected] * n_bams,
    [modkit_path] * n_bams,
    input_bam_paths,
    [modkit_bed_name] * n_bams,
    [num_processors] * n_bams
))

# Optional: restrict to a subset of BAMs, e.g. those whose path contains "AG1"
#task_args = [task for task in task_args if "AG1" in task[0]]

# Output folders (one per BAM), for debugging
print("output folders: ", input_bam_paths)

# 4 BAMs at a time; each modkit call itself uses num_processors threads
with Pool(processes=4) as pool:
    pool.map(modkit_pileup_extract, task_args)


In [None]:
### Convert bam to bedgraph and bw files.
import tempfile

importlib.reload(nanotools)  # pick up any edits to nanotools before converting

# Names of the bigwig files produced below are collected here
bw_files = []

def process_bedgraph(args):
    """Derive zero-filled, NA-preserving and smoothed tracks for one sample.

    Parameters
    ----------
    args : tuple
        ``(each_bam, each_condition, each_expid, smoothing_window,
        imputation_window, bedgraphtobigwig_path)``, packed into one tuple so the
        function can be passed straight to ``Pool.map``.

    The input is ``{each_expid}-{each_condition}_a_A0.bedgraph`` in the BAM's
    directory, and every output is written next to it:

    * ``*_a_A0_nafilled.bedgraph``: scores exactly as loaded (missing stay NaN).
    * ``*_a_A0_raw_filled.bedgraph`` / ``.bw``: missing scores replaced by 0.
    * ``*_a_A0_smoothed-{smoothing_window}-{imputation_window}.bedgraph`` / ``.bw``:
      zero-filled scores passed through ``nanotools.parallel_impute_and_smooth``.

    Outputs that already exist are skipped. Bigwig conversion uses the
    module-level ``chrom_sizes``.

    Raises
    ------
    subprocess.CalledProcessError
        If the zero-filled (or, when enabled, raw) bigwig conversion fails.
    """
    each_bam, each_condition, each_expid, smoothing_window, imputation_window, bedgraphtobigwig_path = args

    out_dir = os.path.dirname(each_bam)
    stem = f"{each_expid}-{each_condition}_a_A0"

    # Raw input bedgraph and the (optional) raw bigwig
    bedgraph_fn = os.path.join(out_dir, f"{stem}.bedgraph")
    raw_bw_fn = os.path.join(out_dir, f"{stem}.bw")

    # Derived outputs
    filled_bedgraph_fn = os.path.join(out_dir, f"{stem}_raw_filled.bedgraph")
    filled_bw_fn = filled_bedgraph_fn.replace(".bedgraph", ".bw")
    nafilled_bedgraph_fn = os.path.join(out_dir, f"{stem}_nafilled.bedgraph")
    smoothed_bedgraph_fn = os.path.join(out_dir, f"{stem}_smoothed-{smoothing_window}-{imputation_window}.bedgraph")
    smoothed_bigwig_fn = smoothed_bedgraph_fn.replace(".bedgraph", ".bw")

    ### Convert raw file directly to bigwig, without filling or smoothing or normalizing.
    # Disabled for now (previously gated by `and False`); set to True to re-enable.
    convert_raw_bigwig = False
    if not convert_raw_bigwig:
        print("Raw bigwig conversion disabled, skipping.")
    elif os.path.exists(raw_bw_fn):
        print("Bigwig file already exists, skipping conversion.")
    else:
        print("Converting raw bedgraph directly to raw bigwig...")
        with tempfile.NamedTemporaryFile(mode='w+t', delete=False, suffix='.bedgraph') as temp_file:
            temp_filename = temp_file.name
        try:
            # bedGraphToBigWig takes exactly 4 columns; keep only the first four
            with open(temp_filename, 'w') as temp_out:
                subprocess.run(f"cut -f 1-4 {bedgraph_fn}", shell=True, check=True, stdout=temp_out)
            subprocess.run([bedgraphtobigwig_path, temp_filename, chrom_sizes, raw_bw_fn], check=True)
            print("Saved raw bigwig file: ", raw_bw_fn)
        except subprocess.CalledProcessError as e:
            print(f"An error occurred during raw bedgraph to bigwig conversion: {e}")
            raise
        finally:
            os.unlink(temp_filename)

    bedgraph_df = None  # raw bedgraph, loaded at most once per call

    # NA-filled file is written from the untouched scores BEFORE any fillna(0).
    # Previously, when the zero-filled file was produced in the same call, this
    # file silently received zero-filled scores instead of NaNs.
    if not os.path.exists(nafilled_bedgraph_fn):
        print("Starting to fill raw bedgraph with NAs...")
        print("Loading bedgraph file: ", bedgraph_fn)
        bedgraph_df = nanotools.load_bedgraph_file(bedgraph_fn)
        bedgraph_df[['chromosome', 'start', 'end', 'score']].to_csv(
            nafilled_bedgraph_fn,
            sep="\t", header=False, index=False)
    else:
        print(f"Raw NA filled bedgraph file already exists, skipping: {nafilled_bedgraph_fn}")

    # Everything below works on zero-filled scores
    need_filled = not os.path.exists(filled_bedgraph_fn)
    need_smoothed = not os.path.exists(smoothed_bedgraph_fn)
    if need_filled or need_smoothed:
        if bedgraph_df is None:
            print("Loading bedgraph file: ", bedgraph_fn)
            bedgraph_df = nanotools.load_bedgraph_file(bedgraph_fn)
        bedgraph_df['score'] = bedgraph_df['score'].fillna(0)

    if need_filled:
        print("Starting to fill raw bedgraph...")
        bedgraph_df[['chromosome', 'start', 'end', 'score']].to_csv(
            filled_bedgraph_fn,
            sep="\t", header=False, index=False)
    else:
        print(f"Raw filled bedgraph file already exists, skipping: {filled_bedgraph_fn}")

    if not os.path.exists(filled_bw_fn):
        print("Converting filled bedgraph to bigwig...")
        try:
            subprocess.run([bedgraphtobigwig_path, filled_bedgraph_fn, chrom_sizes, filled_bw_fn], check=True)
        except subprocess.CalledProcessError as e:
            print(f"An error occurred during bedgraph to bigwig conversion: {e}")
            raise
    else:
        print(f"Raw filled bigwig file already exists, skipping: {filled_bw_fn}")

    if need_smoothed:
        print(f"Imputing and smoothing bedgraph file: {smoothed_bedgraph_fn}")
        (bedgraph_df['imputed_score'], bedgraph_df['imputed_coverage'],
         bedgraph_df['smoothed_score'], bedgraph_df['smoothed_coverage']) = nanotools.parallel_impute_and_smooth(
            bedgraph_df,
            impute_window=imputation_window,
            smooth_window=smoothing_window,
            fill_value=0
        )

        print(f"Saving smoothed bedgraph file: {smoothed_bedgraph_fn}")
        bedgraph_df[['chromosome', 'start', 'end', 'smoothed_score']].to_csv(
            smoothed_bedgraph_fn,
            sep="\t", header=False, index=False)
    else:
        print(f"Imputed and smoothed bedgraph file already exists, skipping: {smoothed_bedgraph_fn}")

    if not os.path.exists(smoothed_bigwig_fn):
        print(f"Converting smoothed bedgraph to bigwig: {smoothed_bigwig_fn}")
        result = subprocess.run(
            [bedgraphtobigwig_path, smoothed_bedgraph_fn, chrom_sizes, smoothed_bigwig_fn],
            text=True, capture_output=True)
        # Output is captured, so surface failures instead of dropping them silently
        if result.returncode != 0:
            print(f"bedGraphToBigWig failed (exit code {result.returncode}) for {smoothed_bedgraph_fn}:\n{result.stderr}")
    else:
        print(f"Imputed and smoothed bigwig file already exists, skipping: {smoothed_bigwig_fn}")
        

if __name__ == "__main__":
    importlib.reload(nanotools)

    # Configurable parameters
    smoothing_window = 20
    imputation_window = 0

    # One argument tuple per sample for process_bedgraph
    args_list = [
        (each_bam, each_condition, each_expid, smoothing_window, imputation_window, bedgraphtobigwig_path)
        for each_bam, each_condition, each_expid in zip(new_bam_files, conditions, exp_ids)
    ]

    # Use multiprocessing to process bedgraph files in parallel
    with multiprocessing.Pool(processes=10) as pool:
        pool.map(process_bedgraph, args_list)

    print("All processing completed.")

    # Record the zero-filled bigwigs. process_bedgraph writes them as
    # "{expid}-{condition}_a_A0_raw_filled.bw"; the previous "_a_A0_filled.bw"
    # name pointed at files this cell never creates.
    for each_bam, each_condition, each_expid in zip(new_bam_files, conditions, exp_ids):
        filled_bw_fn = os.path.join(os.path.dirname(each_bam), f"{each_expid}-{each_condition}_a_A0_raw_filled.bw")
        bw_files.append(filled_bw_fn)

In [None]:
### Plot size of all files
def get_file_info(bam_files, conditions, exp_ids, smoothing_window=None, imputation_window=None):
    """Collect size/category metadata for the per-sample bedgraph and bigwig outputs.

    For each ``(bam_file, condition, exp_id)`` triple (zipped positionally), looks
    in the BAM's directory for the known output names and records one dict per
    file that exists.

    Parameters
    ----------
    bam_files, conditions, exp_ids : sequences
        Per-sample values; ``zip`` stops at the shortest.
    smoothing_window, imputation_window : optional
        Used to build the smoothed filenames. When ``None``, the module-level
        variables of the same name are used (previous behaviour).

    Returns
    -------
    list of dict
        Keys 'File Name', 'File Path', 'File Size (bytes)', 'File Type',
        'Processing', 'Experiment ID', 'Condition'. NOTE: despite its key,
        'File Size (bytes)' holds the size in GiB (bytes / 1024**3); the key is
        kept because the plotting code below reads it.
    """
    # Fall back to the notebook-level settings when not passed explicitly
    if smoothing_window is None:
        smoothing_window = globals().get('smoothing_window')
    if imputation_window is None:
        imputation_window = globals().get('imputation_window')

    # Filename suffixes appended to "{exp_id}-{condition}"
    suffixes = [
        "_a_A0.bedgraph",
        "_a_A0.bw",
        "_a_A0_raw.bw",
        "_a_A0_raw_filled.bedgraph",
        "_a_A0_raw_filled.bw",
        f"_a_A0_smoothed-{smoothing_window}-{imputation_window}.bedgraph",
        f"_a_A0_smoothed-{smoothing_window}-{imputation_window}.bw",
    ]

    file_info = []
    for bam_file, condition, exp_id in zip(bam_files, conditions, exp_ids):
        output_dir = os.path.dirname(bam_file)

        for suffix in suffixes:
            file_path = os.path.join(output_dir, f"{exp_id}-{condition}{suffix}")
            if not os.path.exists(file_path):
                continue

            file_size = os.path.getsize(file_path) / (1024**3)  # bytes -> GiB
            file_type = 'Bedgraph' if suffix.endswith('.bedgraph') else 'Bigwig'

            # Classify on the suffix only: directory names, experiment IDs or
            # conditions containing "raw"/"filled" must not change the category.
            if 'raw' in suffix and 'filled' not in suffix:
                processing = 'Raw'
            elif 'filled' in suffix:
                processing = 'Filled'
            elif 'smoothed' in suffix:
                processing = 'Smoothed'
            else:
                processing = 'Original'

            file_info.append({
                'File Name': os.path.basename(file_path),
                'File Path': file_path,
                'File Size (bytes)': file_size,
                'File Type': file_type,
                'Processing': processing,
                'Experiment ID': exp_id,
                'Condition': condition
            })

    return file_info

# make_subplots is not among the notebook's top-level imports; import it here so
# this cell runs on a fresh kernel.
from plotly.subplots import make_subplots

# Keep the three per-sample lists aligned: zip() stops at the shortest list, so
# these lines truncate all three to a common length (otherwise a no-op).
new_bam_files = [each_bam for each_bam, each_condition, each_expid in zip(new_bam_files, conditions, exp_ids)]
conditions = [each_condition for each_bam, each_condition, each_expid in zip(new_bam_files, conditions, exp_ids)]
exp_ids = [each_expid for each_bam, each_condition, each_expid in zip(new_bam_files, conditions, exp_ids)]

# Define smoothing_window and imputation_window as in your original script
smoothing_window = 20
imputation_window = 0

# Get the file information
file_info = get_file_info(new_bam_files, conditions, exp_ids)

if file_info:
    df_file_info = pd.DataFrame(file_info)

    # Sort by size, largest first. NOTE: the 'File Size (bytes)' column holds GB.
    df_file_info = df_file_info.sort_values('File Size (bytes)', ascending=False)

    # Display the DataFrame
    #display(df_file_info)

    # One bar-chart row per file type
    file_types = df_file_info['File Type'].unique()
    fig = make_subplots(rows=len(file_types), cols=1,
                        subplot_titles=[f"{ftype} File Sizes" for ftype in file_types],
                        vertical_spacing=0.1)

    for i, file_type in enumerate(file_types, start=1):
        df_subset = df_file_info[df_file_info['File Type'] == file_type]

        trace = go.Bar(
            x=df_subset['Experiment ID'],
            y=df_subset['File Size (bytes)'],
            name=file_type,
            text=df_subset['Processing'],
            hoverinfo='text+y',
            hovertext=[f"Exp ID: {exp}<br>Size: {size:.2f} GB<br>Processing: {proc}"
                       for exp, size, proc in zip(df_subset['Experiment ID'],
                                                  df_subset['File Size (bytes)'],
                                                  df_subset['Processing'])]
        )

        fig.add_trace(trace, row=i, col=1)

        fig.update_xaxes(title_text="Experiment ID", row=i, col=1)
        fig.update_yaxes(title_text="File Size (GB)", row=i, col=1)

    fig.update_layout(
        height=300 * len(file_types),
        title_text="File Sizes by Experiment ID and File Type",
        showlegend=False,
        template="plotly_white"
    )

    fig.show()

    # Optionally, you can save this DataFrame to a CSV file
    # df_file_info.to_csv('file_analysis_results.csv', index=False)
else:
    print("No matching files found in the specified directories.")

# Print total number of files found
print(f"\nTotal number of files found: {len(file_info)}")

# Print total size of all files (values are already in GB)
if file_info:
    total_size = sum(file['File Size (bytes)'] for file in file_info)
    print(f"Total size of all files: {total_size:.2f} GB")

In [None]:
### Create merged bedgraph file for qnormalization
import os
import pandas as pd
from multiprocessing import Pool
from tqdm import tqdm

def get_optimal_dtypes(chunk):
    """Return the ``astype`` mapping for one merged-bedgraph chunk.

    ``chromosome`` -> ``category``, ``start`` -> ``uint32`` and the score column
    (the chunk's last column, named after its sample) -> ``float32``.

    Every branch of the former dtype inspection chose ``float32``, so the
    per-row Python ``is_integer`` scan it ran on each chunk is skipped; float32
    also keeps missing scores (NaN) representable.
    """
    score_col = chunk.columns[-1]
    return {
        'chromosome': 'category',
        'start': 'uint32',
        score_col: 'float32',
    }

def process_chunk(args):
    """Label one bedgraph chunk's columns and cast them to compact dtypes.

    ``args`` is a ``(chunk, file_name)`` pair, packed so it can be fed to
    ``Pool.imap``. The chunk's three columns are renamed to ``chromosome``,
    ``start`` and ``file_name`` (the score column takes the sample's name), and
    the result of casting with ``get_optimal_dtypes`` is returned.
    """
    frame, sample_name = args
    frame.columns = ['chromosome', 'start', sample_name]
    return frame.astype(get_optimal_dtypes(frame))

def process_bedgraph_in_chunks(file_path, chunk_size=1000000, rows_to_process=None, suffix=None, processes=50):
    """Read a headerless bedgraph in chunks and return chromosome, start and score.

    Parameters
    ----------
    file_path : str
        Tab-separated bedgraph; columns 0 (chromosome), 1 (start) and 3 (score)
        are read, column 2 is skipped.
    chunk_size : int
        Rows per chunk handed to a worker.
    rows_to_process : int or None
        Optional cap on rows read; ``None`` reads the whole file.
    suffix : str or None
        Removed from the file's basename to name the score column. ``None``
        keeps the full basename (this previously raised ``TypeError``).
    processes : int
        Worker processes used to cast chunks (default 50, as before).

    Returns
    -------
    pandas.DataFrame
        The chunks processed by ``process_chunk``, concatenated in file order.

    Raises
    ------
    ValueError
        If no rows were read.
    """
    file_name = os.path.basename(file_path)
    if suffix:
        file_name = file_name.replace(suffix, '')

    # The context manager closes the underlying file even if a worker fails
    with pd.read_csv(file_path, sep='\t', header=None, chunksize=chunk_size,
                     usecols=[0, 1, 3], nrows=rows_to_process) as chunks:
        with Pool(processes=processes) as pool:
            processed_chunks = list(pool.imap(process_chunk, ((chunk, file_name) for chunk in chunks)))

    if not processed_chunks:
        raise ValueError(f"No rows read from {file_path}")
    return pd.concat(processed_chunks)

def check_and_concat_dataframes(base_df, new_df):
    """Add ``new_df``'s score column to ``base_df`` after checking coordinates match.

    Parameters
    ----------
    base_df : pandas.DataFrame or None
        Accumulated frame (``chromosome``, ``start`` and one column per sample
        so far), or ``None`` for the first file.
    new_df : pandas.DataFrame
        ``chromosome``, ``start`` and one score column (its last column).

    Returns
    -------
    pandas.DataFrame
        ``new_df`` when ``base_df`` is None; otherwise ``base_df``, modified in
        place with ``new_df``'s last column appended.

    Raises
    ------
    ValueError
        If the ``chromosome``/``start`` columns differ in length or content.
    """
    if base_df is None:
        return new_df

    if not np.array_equal(base_df[['chromosome', 'start']].values, new_df[['chromosome', 'start']].values):
        raise ValueError("chromosome and start columns are not identical across all files.")

    # Rows were just verified to match one-to-one, so assign by position;
    # index-label alignment could otherwise insert NaNs.
    base_df[new_df.columns[-1]] = new_df.iloc[:, -1].to_numpy()
    return base_df

def find_bedgraph_files(directory, suffix=None):
    """Return paths of the entries in ``directory`` whose names end with ``suffix``.

    ``suffix`` is required. If it is ``None`` a message is printed and an empty
    list is returned; previously ``None`` was returned, which made callers'
    ``list.extend`` raise. Paths are sorted so the order (and thus the merged
    column order) is reproducible, unlike raw ``os.listdir`` order.
    """
    if suffix is None:
        print("Requires Suffix to find files")
        return []
    return sorted(os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(suffix))

def main(new_bam_files):
    """Merge the per-sample ``*_nafilled.bedgraph`` files into one wide CSV.

    Searches the directories of ``new_bam_files`` for NA-filled bedgraphs, keeps
    those whose path contains one of the sample tags below, and column-binds
    their scores (after checking coordinates match) into
    ``<temp_files>/<file_prefix>/merged_bedgraph_test.csv``, which is also the
    path checked for an existing result. Uses the module-level ``file_prefix``.
    Files that fail to load are reported and skipped.
    """
    unique_directories = sorted(set(os.path.dirname(path) for path in new_bam_files))
    print(f"Found {len(unique_directories)} unique directories.")
    suffix = "_nafilled.bedgraph"
    all_bedgraph_files = []
    for directory in unique_directories:
        all_bedgraph_files.extend(find_bedgraph_files(directory, suffix=suffix))

    # keep only bedgraph files that have the required substrings (BM_, BK_, BN_, AG-22, AH-)
    all_bedgraph_files = [file for file in all_bedgraph_files if "BM_" in file or "BK_" in file or "BN_" in file or "AG-22" in file or "AH-" in file]

    print(f"Found {len(all_bedgraph_files)} bedgraph files.")
    for file in all_bedgraph_files:
        print(os.path.basename(file))

    temp_folder = "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/" + file_prefix + "/"
    input_file = os.path.join(temp_folder, "merged_bedgraph_test.csv")
    if os.path.exists(input_file):
        print("Merged dataframe already exists, skipping.")
        return

    final_df = None
    for file in tqdm(all_bedgraph_files, desc="Processing files"):
        try:
            df = process_bedgraph_in_chunks(file, rows_to_process=None, suffix=suffix)
            final_df = check_and_concat_dataframes(final_df, df)
        except Exception as e:
            print(f"Error processing {file}: {str(e)}")
            continue  # Skip to the next file if there's an error

    if final_df is None:
        print("No data was processed successfully. Please check the errors above.")
        return

    print("Saving merged dataframe...")
    # Save to the same path that is checked above and read by the normalization
    # step (previously written one level up, outside the file_prefix folder).
    os.makedirs(temp_folder, exist_ok=True)
    final_df.to_csv(input_file, index=False)

    print(f"Processed {len(all_bedgraph_files)} files from {len(unique_directories)} directories. Merged dataframe saved as '{input_file}'")

print(new_bam_files)
if __name__ == '__main__':
    # Merge only samples whose BAM path carries one of these tags
    sample_tags = ("BM_", "BK_", "BN_", "AG1_", "AH_")
    main([bam for bam in new_bam_files if any(tag in bam for tag in sample_tags)])

In [None]:
### Quantile normalization
import stat

# https://pypi.org/project/qnorm/
from qnorm import quantile_normalize
# reimport nanotools
importlib.reload(nanotools)
import pyarrow.csv as pv

def import_normalize_smooth_and_convert(bedgraphtobigwig_path, chrom_sizes, imputation_window, smoothing_window):
    """Quantile-normalise the merged per-sample scores, then smooth and export each sample.

    Reads ``<temp_files>/<file_prefix>/merged_bedgraph_test.csv`` (columns
    ``chromosome``, ``start`` and one score column per sample; ``file_prefix``
    is the module-level setting), quantile-normalises the score columns with
    ``qnorm.quantile_normalize`` and, per sample, writes
    ``<sample>_qnorm_smoothed.bedgraph`` and ``.bw`` into the same folder.
    Existing outputs are skipped.

    Parameters
    ----------
    bedgraphtobigwig_path : str
        Path to the ``bedGraphToBigWig`` executable.
    chrom_sizes : str
        Chromosome sizes file passed to ``bedGraphToBigWig``.
    imputation_window, smoothing_window : int
        Forwarded to ``nanotools.parallel_impute_and_smooth``. NOTE: they are not
        part of the output filenames, so changing them does not regenerate
        existing smoothed files.

    Raises
    ------
    FileNotFoundError
        If the merged CSV does not exist.
    ValueError
        If the required columns are missing or there is no score column.
    """
    temp_folder = "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/"+ file_prefix + "/"
    input_file = temp_folder + "/merged_bedgraph_test.csv"

    # Check if the input file exists
    if not os.path.exists(input_file):
        raise FileNotFoundError(f"Input file not found: {input_file}")

    print("Importing data...")
    # The merged CSV has a header row; read it with pyarrow's multithreaded reader
    read_options = pv.ReadOptions(use_threads=True)
    table = pv.read_csv(input_file, read_options=read_options)
    df = table.to_pandas()
    print(f"Imported data shape: {df.shape}")

    # Check the structure of the dataframe
    if 'chromosome' not in df.columns or 'start' not in df.columns:
        raise ValueError("Input file must have 'chromosome' and 'start' columns")

    if df.shape[1] < 3:
        raise ValueError("Input file must have at least one data column besides 'chromosome' and 'start'")

    # Every column after 'chromosome' and 'start' is one sample's score
    columns_to_normalize = df.columns[2:]

    # Convert 'na' to NaN and columns to float
    df[columns_to_normalize] = df[columns_to_normalize].replace('na', np.nan).astype(float)

    # Prepare data for normalization
    data_to_normalize = df[columns_to_normalize].values

    # Apply qnorm with parallel processing
    print("Applying quantile normalization using 10 CPU cores...")
    normalized_data = quantile_normalize(data_to_normalize, ncpus=10)

    # Replace original data with normalized data
    df[columns_to_normalize] = normalized_data

    print("Normalization complete.")
    print(f"Final data shape: {df.shape}")

    # Save the normalized dataframe (disabled)
    if (False):
        qnorm_output_file = os.path.join(temp_folder, "qnormalized_bedgraph.tsv")
        df.to_csv(qnorm_output_file, sep='\t', index=False)
        print(f"Normalized data saved to: {qnorm_output_file}")

    output_dir = os.path.dirname(input_file)

    # Process each sample column
    for column in columns_to_normalize:
        # Per-sample bedgraph frame
        bedgraph_df = df[['chromosome', 'start']].copy()
        bedgraph_df['end'] = bedgraph_df['start'] + 1  # each row is a single-base interval
        bedgraph_df['score'] = df[column]
        bedgraph_df['coverage'] = 1  # Placeholder for 'coverage' column

        # Output filenames: "<sample>_qnorm.*" and "<sample>_qnorm_smoothed.*"
        original_filename = f"{column}_filled.bedgraph"
        qnorm_bedgraph_filename = original_filename.replace("filled.bedgraph", "qnorm.bedgraph")
        qnorm_bigwig_filename = qnorm_bedgraph_filename.replace(".bedgraph", ".bw")
        smoothed_bedgraph_filename = original_filename.replace("filled.bedgraph", "qnorm_smoothed.bedgraph")
        smoothed_bigwig_filename = smoothed_bedgraph_filename.replace(".bedgraph", ".bw")

        qnorm_bedgraph_path = os.path.join(output_dir, qnorm_bedgraph_filename)
        qnorm_bigwig_path = os.path.join(output_dir, qnorm_bigwig_filename)
        smoothed_bedgraph_path = os.path.join(output_dir, smoothed_bedgraph_filename)
        smoothed_bigwig_path = os.path.join(output_dir, smoothed_bigwig_filename)

        # Unsmoothed qnorm bedgraph/bigwig export (disabled)
        if(False):
            bedgraph_df[['chromosome', 'start', 'end', 'score']].to_csv(qnorm_bedgraph_path, sep='\t', index=False, header=False, na_rep='na')
            print(f"Saved qnorm bedgraph: {qnorm_bedgraph_path}")

            # Convert qnorm bedgraph to bigwig
            if not os.path.exists(qnorm_bigwig_path):
                print(f"Converting {qnorm_bedgraph_filename} to bigwig...")
                try:
                    bigwig_command = [bedgraphtobigwig_path, qnorm_bedgraph_path, chrom_sizes, qnorm_bigwig_path]
                    subprocess.run(bigwig_command, check=True)
                    print(f"Qnorm bigwig file created: {qnorm_bigwig_path}")
                except subprocess.CalledProcessError as e:
                    print(f"An error occurred during bedgraph to bigwig conversion: {e}")
                    raise
            else:
                print(f"Qnorm bigwig file already exists, skipping: {qnorm_bigwig_path}")

        # Apply smoothing (missing scores are treated as 0)
        if not os.path.exists(smoothed_bedgraph_path):
            print(f"Imputing and smoothing bedgraph file: {smoothed_bedgraph_path}")
            bedgraph_df['score'] = bedgraph_df['score'].fillna(0)
            bedgraph_df['imputed_score'], bedgraph_df['imputed_coverage'], bedgraph_df['smoothed_score'], bedgraph_df['smoothed_coverage'] = nanotools.parallel_impute_and_smooth(
                bedgraph_df,
                impute_window=imputation_window,
                smooth_window=smoothing_window,
                fill_value=0
            )

            print(f"Saving smoothed bedgraph file: {smoothed_bedgraph_path}")
            bedgraph_df[['chromosome', 'start', 'end', 'smoothed_score']].to_csv(
                smoothed_bedgraph_path,
                sep="\t", header=False, index=False)
        else:
            print(f"Imputed and smoothed bedgraph file already exists, skipping: {smoothed_bedgraph_path}")

        # Convert smoothed bedgraph to bigwig
        if not os.path.exists(smoothed_bigwig_path):
            print(f"Converting smoothed bedgraph to bigwig: {smoothed_bigwig_path}")
            command = [bedgraphtobigwig_path, smoothed_bedgraph_path, chrom_sizes, smoothed_bigwig_path]
            result = subprocess.run(command, text=True, capture_output=True)
            # Output is captured, so report failures instead of dropping them silently
            if result.returncode != 0:
                print(f"bedGraphToBigWig failed (exit code {result.returncode}) for {smoothed_bedgraph_path}:\n{result.stderr}")
        else:
            print(f"Imputed and smoothed bigwig file already exists, skipping: {smoothed_bigwig_path}")

    print("All files have been processed, normalized, smoothed, and converted.")
    
# Smoothing/imputation settings for the quantile-normalised tracks
smoothing_window = 50
imputation_window = 0
import_normalize_smooth_and_convert(
    bedgraphtobigwig_path,
    chrom_sizes,
    imputation_window=imputation_window,
    smoothing_window=smoothing_window,
)

In [None]:
### BINNING AND PREP FOR WAVELET ANALYSIS

import pandas as pd
import pyarrow as pa
import pyarrow.csv as csv
import pyarrow.compute as pc
import numpy as np
import os

def import_bedgraph_files(file_paths, num_lines, binsize=0, chromosomes=None, temp_folder=None):
    """Load headerless bedgraphs, optionally filter/truncate/bin them, and cache as CSV.

    Parameters
    ----------
    file_paths : list of str
        Tab-separated 4-column bedgraphs (chromosome, start, end, score), no header.
    num_lines : int
        If > 0, keep only the first ``num_lines`` rows after chromosome filtering.
    binsize : int
        If > 0, average scores over consecutive ``binsize``-bp bins per chromosome.
    chromosomes : list of str or None
        Short names (e.g. ``["X"]``) expanded to ``CHROMOSOME_<name>``; ``None``
        keeps every chromosome (previously this raised ``TypeError``).
    temp_folder : str or None
        Folder for the cached CSVs; ``None`` uses each input file's directory.

    Returns
    -------
    list of pandas.DataFrame
        One frame per input path, each with a ``shortened_name`` attribute used
        downstream as its label.
    """
    dataframes = []
    if chromosomes is not None:
        chromosomes = [f"CHROMOSOME_{chrom}" for chrom in chromosomes]
        print(chromosomes)
    chrom_tag = chromosomes[0] if chromosomes else "all"

    # The bedgraphs have no header line; supplying the names explicitly stops
    # pyarrow from consuming (and silently dropping) the first data row as a header.
    # NOTE: caches written before this fix lack that first row; delete them to rebuild.
    read_options = csv.ReadOptions(use_threads=True, column_names=['chromosome', 'start', 'end', 'score'])
    parse_options = csv.ParseOptions(delimiter='\t')

    for file_path in file_paths:
        base_name = os.path.basename(file_path)
        shortened_name = "_".join(base_name.split("_")[:-1]) + f"_chr_{chrom_tag}_{binsize}"
        cache_dir = temp_folder if temp_folder is not None else os.path.dirname(file_path)
        # The chromosome/bin tag appears twice in the cache name; kept so existing caches are still found
        output_fn = os.path.join(cache_dir, shortened_name + f"_chr_{chrom_tag}_{binsize}" + ".csv")

        if os.path.exists(output_fn):
            print("file already exists, loading file...")
            df = pd.read_csv(output_fn)
        else:
            print(f"Reading file: {file_path}")
            table = csv.read_csv(file_path, read_options=read_options, parse_options=parse_options)
            df = table.to_pandas()

            print("Filtering chromosomes...")
            if chromosomes is not None:
                df = df[df['chromosome'].isin(chromosomes)]

            if num_lines > 0:
                df = df.head(num_lines)

            print("binning...")
            if binsize > 0:
                binned_dfs = []
                for chrom, chrom_df in df.groupby('chromosome'):
                    bin_edges = np.arange(int(chrom_df['start'].min()), int(chrom_df['end'].max()) + int(binsize), int(binsize))
                    # Bin labels live in a separate Series instead of being assigned onto
                    # chrom_df (a slice of df), which avoids SettingWithCopyWarning.
                    # NOTE(review): pd.cut is right-closed, so the first bin holds
                    # binsize + 1 start positions; confirm whether right=False was intended.
                    bins = pd.cut(chrom_df['start'], bins=bin_edges, labels=bin_edges[:-1], include_lowest=True)

                    # observed=False keeps empty bins (as rows of NaN) so bins stay evenly
                    # spaced; pinned explicitly because pandas is changing the default.
                    binned = chrom_df.groupby(bins, observed=False).agg({
                        'chromosome': 'first',
                        'start': 'min',
                        'end': 'max',
                        'score': 'mean'
                    }).reset_index(drop=True)

                    binned_dfs.append(binned)

                df = pd.concat(binned_dfs, ignore_index=True)

            df.to_csv(output_fn, index=False)

        df.shortened_name = shortened_name
        dataframes.append(df)
    return dataframes

# Example inputs:
# file_paths = [
#     "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/jun_2024/BK_05_30_24-N2_young_SMACseq_R10_a_A0_qnorm_smoothed.bedgraph",
#     "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/jun_2024/BM_05_30_24-N2_old_SMACseq_R10_a_A0_qnorm_smoothed.bedgraph",
#     "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/jun_2024/BN_05_24_24-96_old_DPY27degron_SMACseq_R10_a_A0_qnorm_smoothed.bedgraph"
# ]

temp_folder = "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/"+ file_prefix + "/"
# All files ending in "smoothed.bedgraph" in temp_folder, sorted so the sample
# (and plot) order is reproducible; os.listdir order is arbitrary.
extracted_elements = sorted(file for file in os.listdir(temp_folder) if file.endswith("smoothed.bedgraph"))
file_paths = [os.path.join(temp_folder, element) for element in extracted_elements]

num_lines = 0  # rows to read from each file; 0 = all
binsize = 100  # bp per bin
bg_dfs = import_bedgraph_files(file_paths, num_lines, binsize=binsize, chromosomes=["X"], temp_folder=temp_folder)



In [None]:
def plot_wavelet(time, signal,
                 waveletname='cmor',
                 min_period=None,
                 max_period=None,
                 num_scales=100,
                 cmap=plt.cm.magma,
                 title='Wavelet Transform (Power Spectrum) of signal',
                 ylabel='Period (bp)',
                 xlabel='Position (bp)',
                 normalize=True,
                 vmin=-4,
                 vmax=10,
                 xlim=(9000000, 11000000)):
    """Plot the continuous-wavelet power of ``signal`` on a log2-period axis.

    Parameters
    ----------
    time : array_like
        Evenly spaced positions; ``time[1] - time[0]`` is the sampling period.
    signal : array_like
        Values at ``time``; z-scored first when ``normalize`` is True.
    waveletname : str
        pywt continuous wavelet name.
    min_period, max_period : float or None
        Period range to analyse, in ``time`` units; default from ``2*dt`` to half
        the signal length.
    num_scales : int
        Number of log-spaced scales.
    vmin, vmax : float
        Colour limits in log2(power) units (defaults as before).
    xlim : tuple or None
        x-axis limits; ``None`` leaves matplotlib's autoscaling.

    Returns
    -------
    fig : matplotlib.figure.Figure
    power : ndarray
        Linear power ``|coefficients| ** 2`` (not log); rows follow ``period``,
        columns follow ``time``.
    period : ndarray
        Period for each scale (``1 / frequencies``), in ``time`` units.
    time : the input ``time``.
    """
    dt = time[1] - time[0]

    if normalize:
        signal = (signal - np.mean(signal)) / np.std(signal)

    # Convert periods to scales
    def period_to_scale(period, waveletname, dt):
        if waveletname == 'cgau6':
            fc = 0.6678700143418761  # Center frequency for cgau6
            return (period * fc) / dt
        return period / (4 * dt)  # General approximation for other wavelets

    if min_period is None:
        min_period = 2 * dt  # Nyquist limit
    if max_period is None:
        max_period = len(time) * dt / 2  # Half the signal length

    min_scale = period_to_scale(min_period, waveletname, dt)
    max_scale = period_to_scale(max_period, waveletname, dt)

    # Logarithmically spaced scales
    scales = np.geomspace(min_scale, max_scale, num=num_scales)

    coefficients, frequencies = pywt.cwt(signal, scales, waveletname, dt)
    power = np.abs(coefficients) ** 2
    period = 1. / frequencies
    log2_power = np.log2(power)  # computed once; fixed colour limits are applied below

    fig, ax = plt.subplots(figsize=(15, 4))
    im = ax.pcolormesh(time, np.log2(period), log2_power, cmap=cmap, vmin=vmin, vmax=vmax)

    ax.set_title(title, fontsize=20)
    ax.set_ylabel(ylabel, fontsize=18)
    ax.set_xlabel(xlabel, fontsize=18)
    if xlim is not None:
        ax.set_xlim(*xlim)

    # y ticks at powers of two, labelled with the actual period values
    yticks = 2**np.arange(np.ceil(np.log2(period.min())), np.ceil(np.log2(period.max())))
    ax.set_yticks(np.log2(yticks))
    ax.set_yticklabels(yticks)

    ax.invert_yaxis()

    # Thin semi-transparent gridlines
    ax.grid(which='both', color='white', linestyle='-', linewidth=0.5, alpha=0.3)

    # Colorbar in a dedicated axes to the right of the plot
    cbar_ax = fig.add_axes([0.95, 0.5, 0.03, 0.25])
    cbar = fig.colorbar(im, cax=cbar_ax, orientation="vertical")
    cbar.ax.set_ylabel('Power', rotation=270, labelpad=20)

    # Three colorbar ticks (log2 values) across the colour range
    tick_locations = np.linspace(vmin, vmax, 3)
    cbar.ax.yaxis.set_tick_params(width=2)
    cbar.set_ticks(tick_locations)
    cbar.set_ticklabels([f'{x:.0f}' for x in tick_locations])

    return fig, power, period, time

# import fft from scipy
import numpy as np
from scipy.fft import fft
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

def get_fft_values(y_values, T, N, f_s):
    """One-sided FFT amplitude spectrum of ``y_values``.

    Parameters
    ----------
    y_values : array_like
        Evenly spaced samples.
    T : float
        Sample spacing; defines the frequency axis.
    N : int
        Number of samples (``len(y_values)``).
    f_s : float
        Sampling rate; unused (``T`` already defines it), kept for compatibility.

    Returns
    -------
    f_values : ndarray
        The ``N // 2`` non-negative FFT bin frequencies ``k / (N * T)``.
    fft_values : ndarray
        ``2 / N * |FFT|`` at those bins (a unit sinusoid on a bin gives 1).
    """
    # Exact bin frequencies. The former np.linspace(0, 1/(2T), N//2) stretched
    # the axis by N / (N - 2), mislabelling every peak's frequency.
    f_values = np.fft.fftfreq(N, d=T)[:N // 2]
    fft_values_ = fft(y_values)
    fft_values = 2.0 / N * np.abs(fft_values_[0:N // 2])
    return f_values, fft_values

def plot_fft_plus_power(time, signal, min_period=10000, max_period=200000):
    """Plot the FFT power spectrum of ``signal`` against period with a smoothed trend.

    ``time`` is assumed evenly spaced (its first step is the sample spacing).
    Power is ``var(signal) * amplitude ** 2`` using ``get_fft_values``; only
    periods within ``[min_period, max_period]`` are drawn, sorted by period, with
    a Gaussian-filtered trendline on top. The figure is shown, not returned.
    """
    step = time[1] - time[0]
    n_samples = len(signal)
    sampling_rate = 1 / step

    fig, ax = plt.subplots(figsize=(15, 5))
    signal_variance = np.std(signal)**2
    freqs, amplitudes = get_fft_values(signal, step, n_samples, sampling_rate)
    power = signal_variance * abs(amplitudes) ** 2  # FFT power spectrum

    # Frequency -> period (basepairs); the zero-frequency bin maps to infinity
    periods = np.where(freqs != 0, 1 / freqs, np.inf)

    # Keep the requested period window, then order it by period for plotting
    in_window = (periods >= min_period) & (periods <= max_period)
    periods = periods[in_window]
    power = power[in_window]
    order = np.argsort(periods)
    periods, power = periods[order], power[order]

    ax.plot(periods, power, 'k-', linewidth=1, label='FFT Power Spectrum')

    # Gaussian-smoothed trendline
    window = 50
    ax.plot(periods, gaussian_filter1d(power, sigma=window/5), 'r-', linewidth=2, label='Trendline (Rolling Average)')

    ax.set_yscale('log')
    ax.set_xlabel('Period [basepairs]', fontsize=18)
    ax.set_ylabel('Power', fontsize=18)
    ax.set_xlim(min_period, max_period)
    ax.legend()
    ax.set_title('FFT Power Spectrum', fontsize=20)
    ax.grid(True, which="both", ls="-", alpha=0.2)
    plt.show()

# For every binned track: FFT spectrum + wavelet scalogram. The wavelet outputs
# are collected for the combined figures in the following cells.
wavelet_results = []

for i, df in enumerate(bg_dfs):
    print(f"Dataframe {i}: {df.shortened_name}")

    t0 = 0
    dt = binsize
    N = len(df)
    time = t0 + dt * np.arange(N)  # bin positions t0, t0 + dt, ...
    signal = df["score"].to_numpy().squeeze()

    plot_fft_plus_power(time, signal, min_period=10000, max_period=200000)

    plt_obj, log_power, period, time = plot_wavelet(
        time,
        signal,
        waveletname="cgau6",  # alternative: "mexh"
        cmap=plt.cm.coolwarm,
        min_period=100,
        max_period=10000,
        num_scales=100,
        title=df.shortened_name,
    )
    wavelet_results.append((plt_obj, log_power, period, time, df.shortened_name, signal))

In [None]:
# Combine all plt_obj into a single figure
# Re-draws every stored scalogram from wavelet_results (tuples of
# figure, power, period, time, name, signal) as one row of a stacked figure
# that shares fixed colour limits and a single colorbar.
n_plots = len(wavelet_results)
fig, axs = plt.subplots(n_plots, 1, figsize=(15, 4 * n_plots))

# With a single row plt.subplots returns a bare Axes; wrap it so axs[i] works.
if n_plots == 1:
    axs = [axs]

for i, (plt_obj, log_power, period, time, title, signal) in enumerate(wavelet_results):
    # fig_tmp (the per-sample figure's current Axes) is not used below.
    # log_power holds linear power as returned by plot_wavelet, hence np.log2 here;
    # vmin/vmax are fixed so all rows share one colour scale.
    fig_tmp, ax_tmp = plt_obj.gca(), axs[i]
    im = ax_tmp.pcolormesh(time, np.log2(period), np.log2(log_power), cmap=plt.cm.coolwarm, vmin=-2, vmax=8)
    # Split subtitle based on "-" and "R10"
    # (keep the text after the last "-" and before "R10")
    subtitle = title.split('-')[-1]
    subtitle = subtitle.split('R10')[0]

    ax_tmp.set_title(subtitle, fontsize=20)
    ax_tmp.set_ylabel('Period (bp)', fontsize=18)
    ax_tmp.set_xlabel('Position (bp)', fontsize=18)
    
    # Set y-axis ticks to show actual period values
    yticks = 2**np.arange(np.ceil(np.log2(period.min())), np.ceil(np.log2(period.max())))
    ax_tmp.set_yticks(np.log2(yticks))
    ax_tmp.set_yticklabels(yticks)
    
    ax_tmp.invert_yaxis()
    ax_tmp.grid(which='both', color='white', linestyle='-', linewidth=0.5, alpha=0.3)

# Add a single colorbar for all subplots
# (built from the last row's mesh; valid for all rows since they share vmin/vmax)
fig.subplots_adjust(right=0.9, top=0.85)
cbar_ax = fig.add_axes([0.92, 0.7, 0.02, 0.2])  # Adjust size and position
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.ax.set_ylabel('Power', rotation=270, labelpad=25, fontsize=21)  # Increase fontsize by 50% from 14


# NOTE(review): tight_layout likely overrides the subplots_adjust margins above and
# may warn about the manually added colorbar axes; confirm the intended layout.
plt.tight_layout(rect=[0, 0, 0.9, 1])
plt.show()

In [None]:
def plot_combined_power_plotly_smooth(wavelet_results, period_of_interest=100000, window_size=100,
                                      y_log10_range=(-4, 30)):
    """
    Overlay, per condition, the smoothed wavelet power at a single period.

    For each ``(plt_obj, log_power, period, time, condition_name, signal)`` tuple
    in ``wavelet_results``, the row of ``log_power`` whose period is nearest to
    ``period_of_interest`` is converted back to linear power (``2**log_power``),
    smoothed with a centred rolling mean and drawn as one line.

    Args:
        wavelet_results: list of tuples as built by the wavelet cell.
        period_of_interest: requested period (bp); the closest computed period is used.
        window_size: rolling-mean window, in samples (not bp).
        y_log10_range: range of the log-scaled y axis. Plotly expects log-axis
            ranges in log10 units, so the default (-4, 30) spans 1e-4 to 1e30.

    Returns:
        plotly.graph_objects.Figure
    """
    fig = go.Figure()

    for _plt_obj, log_power, period, time, condition_name, _signal in wavelet_results:
        # Row of the computed period closest to the requested one.
        idx = np.argmin(np.abs(period - period_of_interest))

        # log_power holds log2(power); convert the selected row back to linear power.
        power_at_period = 2**log_power[idx, :]

        # Centred rolling mean; min_periods=1 keeps the edges instead of NaN.
        power_smooth = pd.Series(power_at_period).rolling(window=window_size, center=True, min_periods=1).mean()

        fig.add_trace(go.Scatter(
            x=time,
            y=power_smooth,
            mode='lines',
            name=condition_name,
            line=dict(shape='spline', smoothing=1.3)  # spline interpolation for a smoother line
        ))

    fig.update_layout(
        title=f'Smoothed Power at period {period_of_interest} bp across conditions (Rolling mean, window={window_size})',
        xaxis_title='Position (bp)',
        yaxis_title='Smoothed Power',
        legend_title='Conditions',
        font=dict(size=14),
        hovermode='x unified',
        template='plotly_white',
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-1,
            xanchor="center",
            x=0.5
        )
    )

    # Log-scaled y axis; the range is given in log10 units (see docstring).
    fig.update_yaxes(type="log", range=list(y_log10_range))

    # Extra bottom margin for the horizontal legend below the chart.
    fig.update_layout(margin=dict(l=50, r=50, t=80, b=120))

    return fig

# Smoothed power at a 10 kb period for every condition in wavelet_results.
fig = plot_combined_power_plotly_smooth(
    wavelet_results,
    period_of_interest=10000,
    window_size=1000,
)
fig.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pycwt import wavelet
import pywt
from matplotlib.gridspec import GridSpec

def get_pycwt_wavelet(waveletname):
    """
    Resolve a wavelet name to a pyCWT mother-wavelet object.

    Accepts a pyCWT-style name ('morlet', 'paul', 'dog', 'mexican_hat',
    'cgau6') or a PyWavelets-style alias ('morl', 'paul', 'mexh', 'cgau6');
    aliases are translated to the pyCWT name first.

    Raises:
        ValueError: if the name is in neither table.
    """
    # Factories rather than instances, so only the requested wavelet is built.
    factories = {
        'morlet': lambda: wavelet.Morlet(),
        'paul': lambda: wavelet.Paul(),
        'dog': lambda: wavelet.DOG(),
        'mexican_hat': lambda: wavelet.DOG(m=2),  # Mexican hat == DOG of order 2
        'cgau6': lambda: wavelet.DOG(m=6),  # real 6th-order DOG approximates cgau6
    }
    aliases = {
        'morl': 'morlet',
        'paul': 'paul',
        'mexh': 'mexican_hat',
        'cgau6': 'cgau6',
    }

    key = waveletname if waveletname in factories else aliases.get(waveletname)
    if key is None:
        supported = list(factories.keys()) + list(aliases.keys())
        raise ValueError(f"Wavelet '{waveletname}' is not supported. "
                         f"Supported wavelets are: {supported}")
    return factories[key]()

def plot_mother_wavelet(waveletname):
    """
    Draw the mother wavelet ``waveletname`` on t in [-20, 20] (1000 samples).

    The real part is always plotted; the imaginary part is added only when the
    wavelet evaluates to complex values. Returns the matplotlib Figure.
    """
    mother = get_pycwt_wavelet(waveletname)

    t = np.linspace(-20, 20, 1000)
    psi = mother.psi(t)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(t, psi.real, label='Real part')
    if np.iscomplexobj(psi):
        ax.plot(t, psi.imag, label='Imaginary part')

    ax.set(title=f'Mother Wavelet: {waveletname}', xlabel='Time', ylabel='Amplitude')
    ax.legend()
    ax.grid(True)
    return fig

def wavelet_coherence(time, signal1, signal2,
                      waveletname='morlet',
                      min_scale=None,
                      max_scale=None,
                      num_scales=100,
                      cmap=plt.cm.viridis,
                      title='Wavelet Coherence',
                      ylabel='Period (bp)',
                      xlabel='Position (bp)'):
    """
    Compute and plot the wavelet coherence of two equally spaced signals.

    Args:
        time: 1-D array of evenly spaced positions (>= 2 samples); the sampling
            step is taken as ``time[1] - time[0]``.
        signal1, signal2: 1-D arrays with the same length as ``time``.
        waveletname: any name accepted by ``get_pycwt_wavelet``.
            NOTE(review): the notebook uses 'morlet' as the fully supported choice
            for pyCWT coherence; confirm other wavelets work before using them.
        min_scale, max_scale: smallest / largest wavelet scale, in ``time`` units.
            Default to 2*dt and half the span of ``time``.
        num_scales: number of log2-spaced scales from min_scale to max_scale.
        cmap, title, ylabel, xlabel: plot styling.

    Returns:
        (fig, wct, period, time): the figure, the coherence matrix (one row per
        scale), the Fourier period of each row, and ``time`` unchanged.

    Raises:
        ValueError: on an invalid scale range/count or a constant input signal.
    """
    dt = time[1] - time[0]

    # z-score both signals; a constant signal would divide by zero.
    std1, std2 = np.std(signal1), np.std(signal2)
    if std1 == 0 or std2 == 0:
        raise ValueError("Cannot compute wavelet coherence of a constant signal")
    signal1 = (signal1 - np.mean(signal1)) / std1
    signal2 = (signal2 - np.mean(signal2)) / std2

    if min_scale is None:
        min_scale = 2 * dt  # smallest resolvable (Nyquist) scale
    if max_scale is None:
        max_scale = (time[-1] - time[0]) / 2  # half the series span
    if num_scales < 2 or not 0 < min_scale < max_scale:
        raise ValueError(f"Need num_scales >= 2 and 0 < min_scale < max_scale "
                         f"(got min_scale={min_scale}, max_scale={max_scale}, num_scales={num_scales})")

    # pyCWT builds scales as s0 * 2**(j * dj) for j = 0..J, so num_scales
    # log2-spaced scales from min_scale to max_scale need J = num_scales - 1 and
    # this dj. Passing them makes min_scale/max_scale/num_scales take effect.
    J = num_scales - 1
    dj = np.log2(max_scale / min_scale) / J

    mother = get_pycwt_wavelet(waveletname)

    wct, aWCT, coi, freq, sig = wavelet.wct(signal1, signal2, dt,
                                            dj=dj, s0=min_scale, J=J,
                                            sig=False, sig_test=0,
                                            significance_level=0.95,
                                            wavelet=mother,
                                            normalize=True)

    period = 1 / freq

    # Layout: small time-series panel on top, coherence map below, colorbar on the right.
    fig = plt.figure(figsize=(15, 10))
    gs = GridSpec(2, 2, width_ratios=[1, 0.05], height_ratios=[0.1, 1], wspace=0.02, hspace=0.2)

    ax = fig.add_subplot(gs[1, 0])
    im = ax.contourf(time, np.log2(period), wct, cmap=cmap, levels=np.linspace(0, 1, 11))
    ax.set_title(title, fontsize=16)
    ax.set_ylabel(ylabel, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=14)

    # Power-of-two period ticks restricted to the computed period range.
    min_period = period.min()
    max_period = period.max()
    tick_locations = 2**np.arange(np.floor(np.log2(min_period)), np.ceil(np.log2(max_period)) + 1)
    tick_locations = tick_locations[(tick_locations >= min_period) & (tick_locations <= max_period)]

    ax.set_yticks(np.log2(tick_locations))
    ax.set_yticklabels([f'{x:.0f}' for x in tick_locations])
    ax.invert_yaxis()

    cax = fig.add_subplot(gs[1, 1])
    cbar = fig.colorbar(im, cax=cax)
    cbar.set_label('Coherence', rotation=270, labelpad=15)

    ax.grid(which='both', color='white', linestyle='-', linewidth=0.5, alpha=0.3)

    # Normalised input signals for visual reference.
    ax_ts = fig.add_subplot(gs[0, 0])
    ax_ts.plot(time, signal1, 'b-', label='Signal 1')
    ax_ts.plot(time, signal2, 'r-', label='Signal 2')
    ax_ts.set_ylabel('Amplitude')
    ax_ts.legend(loc='upper right')
    ax_ts.set_xlim(time.min(), time.max())

    plt.tight_layout()
    return fig, wct, period, time

# Plot the mother wavelet used for the single-track CWT (any supported name,
# including the PyWavelets alias 'cgau6') and save it as a PNG.
wavelet_name_plot = "cgau6"
mother_wavelet_fig = plot_mother_wavelet(wavelet_name_plot)
mother_wavelet_fig.savefig(f"mother_wavelet_{wavelet_name_plot}.png")
plt.close(mother_wavelet_fig)

# Pairwise coherence between every pair of tracks in wavelet_results
# (Morlet: fully supported by pyCWT's coherence routine).
wavelet_name_coherence = "morlet"
# Only the first `nrows` samples of each track are analysed to keep wct tractable.
nrows = 20000
coherence_results = []

for i in range(len(wavelet_results)):
    for j in range(i+1, len(wavelet_results)):
        _, _, period, time, name1, signal1 = wavelet_results[i]
        _, _, _, _, name2, signal2 = wavelet_results[j]

        print(f"Calculating coherence between {name1} and {name2}")
        signal1_sub = signal1[:nrows]
        signal2_sub = signal2[:nrows]
        time_sub = time[:nrows]

        fig, wct, period_sub, time_sub = wavelet_coherence(
            time_sub, signal1_sub, signal2_sub,
            waveletname=wavelet_name_coherence,
            min_scale=100,
            max_scale=5000,
            num_scales=100,
            title=f'Wavelet Coherence: {name1} vs {name2}'
        )

        coherence_results.append((fig, wct, period_sub, time_sub, name1, name2))
        fig.savefig(f"coherence_{name1}_vs_{name2}.png")
        plt.close(fig)

# Summarise each pair and render its figure.
for fig, wct, period, time, name1, name2 in coherence_results:
    high_coherence = wct > 0.8  # Adjust threshold as needed
    print(f"High coherence regions between {name1} and {name2}:")
    print(f"Total area of high coherence: {high_coherence.sum() / high_coherence.size:.2%}")
    # The figures were closed above, and Figure.show() does not render under the
    # inline backend; display() renders the Figure through its rich repr.
    display(fig)

In [None]:
### RUN MACS3 Peak Caller, in terminal
'''e.g.
# To determine cutoffs:
macs3 bdgpeakcall --ifile /Data1/git/meyer-nanopore/scripts/analysis/temp_files/BK_05_30_24-N2_young_SMACseq_R10_a_A0_qnorm_smoothed.bedgraph --cutoff-analysis --verbose 2 --o-prefix N2_young
# Then to call peaks:
macs3 bdgpeakcall --ifile /Data1/git/meyer-nanopore/scripts/analysis/temp_files/BK_05_30_24-N2_young_SMACseq_R10_a_A0_qnorm_smoothed.bedgraph --cutoff 0.085 --verbose 2 --o-prefix N2_young
'''

### Now MACS3 peaks analysis:
from venn import venn


def import_narrowpeak_files(directory):
    """
    Load every ``*.narrowPeak`` file in ``directory`` whose name contains "7p5".

    The condition label is the file name up to its first ".", with the
    "_7p5_c0" tag removed (e.g. "X_7p5_c0.1_peaks.narrowPeak" -> "X").
    Header lines starting with "track", "browser" or "#" (which some peak
    callers write) are dropped so they are not counted as peaks.

    Returns:
        dict condition -> DataFrame with the ten narrowPeak columns
        (lower-case names), in sorted file-name order.
    """
    dfs = {}
    for filename in sorted(os.listdir(directory)):
        if not (filename.endswith('.narrowPeak') and "7p5" in filename):
            continue
        # split('.') has already removed the dot, so match the tag without it.
        condition = filename.split('.')[0].replace("_7p5_c0", "")
        file_path = os.path.join(directory, filename)
        df = pd.read_csv(file_path, sep='\t', header=None,
                         names=['chromosome', 'start', 'end', 'name', 'score',
                                'strand', 'signalValue', 'pValue', 'qValue', 'peak'])
        # Drop header lines that would otherwise be parsed as a bogus peak row.
        is_header = df['chromosome'].astype(str).str.match(r'track|browser|#')
        df = df[~is_header].reset_index(drop=True)
        df = df.astype({'start': 'int64', 'end': 'int64'})
        print(f"Imported {condition}: {len(df)} peaks")
        dfs[condition] = df
    return dfs

def create_venn_diagram(dfs):
    """
    Draw a Venn diagram of peaks shared between conditions.

    Peaks are compared as exact (chromosome, start, end) tuples, so overlapping
    but non-identical peaks count as different. ``dfs`` maps
    condition -> DataFrame with 'chromosome', 'start' and 'end' columns.
    """
    peak_sets = {condition: set(zip(df['chromosome'], df['start'], df['end'])) for condition, df in dfs.items()}

    # Debug: print peak set sizes
    for condition, peaks in peak_sets.items():
        print(f"{condition} peak set size: {len(peaks)}")

    # Draw into an explicit Axes. Per the `venn` package API, venn() creates its
    # own figure when no ax is given, which left a pre-created figure blank.
    fig, ax = plt.subplots(figsize=(8, 8))
    venn(peak_sets, ax=ax)
    ax.set_title("Shared and Unique Peaks Across Conditions")
    plt.show()

def create_boxplot(dfs):
    """Show one box per condition of peak widths (end - start, in bp)."""
    traces = [go.Box(y=df['end'] - df['start'], name=condition) for condition, df in dfs.items()]
    fig = go.Figure(data=traces)

    fig.update_layout(
        title="Peak Width Distribution",
        yaxis_title="Width (bp)",
        width=600,
        height=400,
        template="plotly_white"
    )
    fig.show()

def create_histogram(dfs):
    """Show a bar chart of the number of peaks per condition, each bar labelled with its count."""
    labels = list(dfs.keys())
    counts = [len(df) for df in dfs.values()]

    fig = go.Figure(data=[go.Bar(x=labels, y=counts)])
    for label, count in zip(labels, counts):
        fig.add_annotation(x=label, y=count, text=str(count), showarrow=False, yshift=10)

    # Fixed ten-colour palette (matplotlib "tab10" colours).
    palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
               '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
    fig.update_traces(marker_color=palette)

    fig.update_layout(
        title="Number of Peaks per Condition",
        xaxis_title="Condition",
        yaxis_title="Peak Count",
        width=600,
        height=400,
        template="plotly_white",
    )
    fig.show()

if __name__ == "__main__":
    directory = "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/"
    dfs = import_narrowpeak_files(directory)
    if not dfs:
        raise FileNotFoundError(f"No matching narrowPeak files found in {directory}")

    create_venn_diagram(dfs)
    create_boxplot(dfs)
    create_histogram(dfs)

    # All peaks in one table; the outer index level holds the condition name.
    comprehensive_df = pd.concat(dfs.values(), keys=dfs.keys())
    output_file = "/Data1/git/meyer-nanopore/scripts/analysis/comprehensive_peaks_7p5.csv"
    export_csv = False  # set True to write comprehensive_df to output_file
    if export_csv:
        # Move the condition key into a column so index=False does not drop it.
        (comprehensive_df
         .reset_index(level=0)
         .rename(columns={'level_0': 'condition'})
         .to_csv(output_file, index=False))
        print(f"Comprehensive peak data exported to {output_file}")
    else:
        print(f"CSV export disabled; set export_csv = True to write {output_file}")

    # Display the first few rows of the DataFrame
    print(comprehensive_df.head())


In [None]:
def import_narrowpeak_files(directory, debug=False):
    """
    Load every ``*.narrowPeak`` file in ``directory`` whose name contains "8p5".

    NOTE: this redefines (and shadows) any earlier ``import_narrowpeak_files``.

    The condition label is the file name up to its first ".", with the
    "_8p5_c0" tag removed. A header line starting with "track", "browser" or
    "#" is dropped when present. Peaks in files without such a header are kept;
    a blind ``skiprows=1`` would lose their first peak.

    Args:
        directory: folder to scan (files are read in sorted name order).
        debug: print each file path and row count.

    Returns:
        dict condition -> DataFrame with capitalised narrowPeak columns plus a
        'Condition' column.
    """
    print("Importing narrowPeak files...")
    dfs = {}
    for filename in sorted(os.listdir(directory)):
        if not (filename.endswith('.narrowPeak') and "8p5" in filename):
            continue
        condition = filename.split('.')[0].replace("_8p5_c0", "")
        file_path = os.path.join(directory, filename)
        if debug:
            print(f"Reading file: {file_path}")
        df = pd.read_csv(file_path, sep='\t', header=None,
                         names=['Chromosome', 'Start', 'End', 'Name', 'Score',
                                'Strand', 'SignalValue', 'PValue', 'QValue', 'Peak'])
        is_header = df['Chromosome'].astype(str).str.match(r'track|browser|#')
        df = df[~is_header].reset_index(drop=True)
        df = df.astype({'Start': 'int64', 'End': 'int64'})
        df['Condition'] = condition
        dfs[condition] = df
        if debug:
            print(f"Imported {condition}: {df.shape[0]} rows")
    print(f"Imported {len(dfs)} narrowPeak files.")
    return dfs



def concatenate_conditions_on_overlap(dfs, progress=True, debug=False):
    """
    Combine peaks of several conditions with their pairwise and three-way overlaps.

    The result contains every original peak (labelled with its condition), every
    pairwise intersection and every three-way intersection (labelled with the
    sorted condition names joined by "-"). Among rows with identical
    coordinates, only the one involving the most conditions is kept.

    Args:
        dfs: dict condition -> DataFrame with 'Chromosome', 'Start', 'End' and
            'Condition' columns.
        progress: print one line per converted condition.
        debug: print heads of intermediate results.

    Returns:
        DataFrame with columns Chromosome, Start, End, Strand ('.'),
        CombinedCondition, indexed 0..n-1.

    NOTE(review): uses a module-level ``pr`` (pyranges) that is not imported in
    the visible cells - confirm it is imported earlier in the notebook.
    """
    from itertools import combinations

    pyranges_dict = {}
    total_files = len(dfs)
    for i, (condition, df) in enumerate(dfs.items(), 1):
        # Copy the column subset so the conversions below do not write into
        # (or warn about) a view of the caller's DataFrame.
        df = df[['Chromosome', 'Start', 'End', 'Condition']].copy()

        # Coordinates to integers; non-numeric values become 0.
        df['Start'] = pd.to_numeric(df['Start'], errors='coerce').fillna(0).astype(int)
        df['End'] = pd.to_numeric(df['End'], errors='coerce').fillna(0).astype(int)

        pyr = pr.PyRanges(df)
        pyranges_dict[condition] = pyr
        if progress:
            print(f"Processed {i}/{total_files} files: {condition}")
        if debug and i <= 5:  # Limit the debug output to the first 5 files
            print(f"{condition} PyRanges:\n{pyr.df.head()}")

    conditions = list(pyranges_dict.keys())
    combined_df_list = []

    # All original peaks (each involves exactly one condition).
    for condition in conditions:
        original_df = pyranges_dict[condition].df.copy()
        original_df['CombinedCondition'] = condition
        original_df['ConditionCount'] = 1
        combined_df_list.append(original_df)

    # Pairwise, then three-way intersections (same order as nested index loops).
    for size in (2, 3):
        for combo in combinations(conditions, size):
            combined = pyranges_dict[combo[0]]
            for other in combo[1:]:
                combined = combined.intersect(pyranges_dict[other])
            combined_df = combined.df.copy()
            combined_df['CombinedCondition'] = '-'.join(sorted(combo))
            # Count from the combination itself rather than splitting the label,
            # so condition names that contain "-" are counted correctly.
            combined_df['ConditionCount'] = size
            combined_df_list.append(combined_df)
            if debug:
                print(f"Intersection of {', '.join(combo)}:\n{combined_df.head(5)}")  # Limit to first 5 rows

    if combined_df_list:
        final_df = pd.concat(combined_df_list, ignore_index=True)
    else:
        final_df = pd.DataFrame(columns=['Chromosome', 'Start', 'End', 'Condition',
                                         'CombinedCondition', 'ConditionCount'])

    # Per identical coordinates keep the row spanning the most conditions.
    final_df = (final_df
                .sort_values(by=['Chromosome', 'Start', 'End', 'ConditionCount'],
                             ascending=[True, True, True, False])
                .drop_duplicates(subset=['Chromosome', 'Start', 'End'], keep='first'))

    # BED-like layout with an unstranded ('.') Strand column after End.
    final_result = final_df[['Chromosome', 'Start', 'End', 'CombinedCondition']].copy()
    final_result['Strand'] = '.'
    final_result = (final_result[['Chromosome', 'Start', 'End', 'Strand', 'CombinedCondition']]
                    .reset_index(drop=True))

    if debug:
        print("Final result sample:\n", final_result.head(5))  # Limit to first 5 rows

    return final_result



# Build the combined 8.5-cutoff peak set from the narrowPeak files in the temp folder.
directory = "/Data1/git/meyer-nanopore/scripts/analysis/temp_files"  # Replace with your directory
dfs = import_narrowpeak_files(directory, debug=False)
result_df = concatenate_conditions_on_overlap(dfs, progress=False, debug=True)

# Headerless, tab-separated output: chromosome, start, end, strand, combined label.
# NOTE(review): standard BED puts name in column 4 and strand in column 6; here
# strand is column 4 - confirm downstream tools expect this layout.
result_df.to_csv(
    "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/combined_peaks_8p5.bed",
    sep='\t',
    header=False,
    index=False,
)

# Preview only the first rows.
print(result_df.head())

In [None]:
from matplotlib_venn import venn3
import matplotlib.pyplot as plt

def plot_venn_diagram(df, condition1, condition2, condition3, chromosome='CHROMOSOME_X'):
    """
    Three-set Venn diagram of the regions involving each of three conditions.

    A region belongs to a condition when that condition is one of the
    "-"-separated names in its 'CombinedCondition' label (exact match).
    Regions are compared as exact (Chromosome, Start, End) tuples.

    Args:
        df: DataFrame with Chromosome, Start, End and CombinedCondition columns.
        condition1, condition2, condition3: condition names to compare.
        chromosome: only regions on this chromosome are used (default
            'CHROMOSOME_X', the previously hard-coded value); None uses all.
    """
    if chromosome is not None:
        df = df[df['Chromosome'] == chromosome]

    # Exact label membership: str.contains would treat the name as a regex and
    # as a substring, so e.g. "A" would also match "AB".
    label_sets = df['CombinedCondition'].astype(str).str.split('-').map(set)
    coords = list(zip(df['Chromosome'], df['Start'], df['End']))

    def _regions_with(condition):
        # Coordinates of every region whose label includes `condition`.
        return {coord for coord, labels in zip(coords, label_sets) if condition in labels}

    region_sets = [_regions_with(c) for c in (condition1, condition2, condition3)]

    plt.figure(figsize=(8, 8))
    venn3(region_sets, (condition1, condition2, condition3))
    plt.title(f"Venn Diagram of {condition1}, {condition2}, and {condition3}")
    plt.show()

# Compare the three conditions' regions in result_df (restricted to CHROMOSOME_X).
plot_venn_diagram(
    result_df,
    condition1='N2_young',
    condition2='N2_old',
    condition3='DPY27_degron',
)

In [None]:
import pandas as pd
import math

# Function to merge overlapping regions with debugging
def merge_overlapping_regions(df, debug=False):
    """
    Collapse overlapping regions on the same chromosome into single rows.

    Rows are sorted by chromosome and start, then merged greedily: a row on the
    same chromosome that starts at or before the current region's end is folded
    into it, and the "-"-separated 'type' labels are unioned (sorted). A label
    set is "qy" when any label contains "qy" (case-insensitive). Coordinates:

    * both sides qy, or both non-qy: union of the two intervals;
    * exactly one side qy: the non-qy side's coordinates are kept.

    Args:
        df: DataFrame with 'chromosome', 'start', 'end' and 'type' columns; other
            columns come from the first row of each merged group. Not modified.
        debug: print intermediate frames and each merge step.

    Returns:
        DataFrame of merged regions, indexed by the row label of each group's
        first row (empty input gives an empty frame with the same columns).
    """
    if debug:
        print("Initial DataFrame:")
        print(df.head())

    # Work on a copy so the caller's DataFrame keeps its original dtypes.
    df = df.copy()
    df['start'] = df['start'].astype(int)
    df['end'] = df['end'].astype(int)
    df = df.sort_values(by=['chromosome', 'start'])

    if debug:
        print("\nSorted DataFrame:")
        print(df.head())

    def _has_qy(types):
        # True when any label in `types` contains "qy" (case-insensitive).
        return any("qy" in t.lower() for t in types)

    merged_regions = []
    current_region = None

    for _, row in df.iterrows():
        if current_region is None:
            current_region = row
            continue

        if debug:
            print("\nCurrent Region:", current_region)
            print("Next Row:", row)

        # Regions on different chromosomes never overlap, whatever their coordinates.
        same_chromosome = row['chromosome'] == current_region['chromosome']
        if not (same_chromosome and row['start'] <= current_region['end']):
            merged_regions.append(current_region)
            current_region = row
            continue

        if debug:
            print("Overlapping detected.")
        current_types = set(current_region['type'].split('-'))
        new_types = set(row['type'].split('-'))
        current_region['type'] = '-'.join(sorted(current_types | new_types))

        # NOTE(review): the current region's qy status comes from its accumulated
        # label, so once a non-qy region has absorbed a qy one, a later
        # overlapping non-qy row replaces (rather than unions) its coordinates.
        # Behaviour kept as-is - confirm this is intended.
        qy_current = _has_qy(current_types)
        qy_new = _has_qy(new_types)
        if qy_current == qy_new:
            current_region['start'] = min(current_region['start'], row['start'])
            current_region['end'] = max(current_region['end'], row['end'])
        elif qy_current:
            # Only the new row is non-qy: adopt its coordinates.
            current_region['start'] = row['start']
            current_region['end'] = row['end']
        # else: only the current region is non-qy, so its coordinates are kept.

    if current_region is not None:
        merged_regions.append(current_region)

    if not merged_regions:
        return df.iloc[0:0]

    # iterrows() yields object-typed rows; restore numeric dtypes.
    merged_df = pd.DataFrame(merged_regions).infer_objects()

    if debug:
        print("\nMerged DataFrame with updated start and end:")
        print(merged_df.head())

    return merged_df

# Load the peak table to collapse (file name suggests QY ChIP + Fiber-seq peaks).
all_peaks_df = pd.read_csv(
    "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/QY_chip_fiber.csv",
    sep=',',
    header=None,
    index_col=False,
    names=['chromosome', 'start', 'end', 'strand', 'type'],
)
display(all_peaks_df.head())

# Collapse overlapping regions and preview the result.
merged_df = merge_overlapping_regions(all_peaks_df, debug=False)
display(merged_df.head())

# Headerless CSV with the same column order as the input.
merged_df.to_csv(
    "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/QY_fiber_merged.csv",
    sep=',',
    header=False,
    index=False,
)


In [None]:
# Compare the quantile-normalised, smoothed A0 bigWig tracks with deepTools:
# multiBigwigSummary (binned signal) -> plotCorrelation and plotPCA.
temp_dir = "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/"

# Alternative inputs (kept for reference):
#bw_files = [os.path.join(os.path.dirname(each_bam), f"{each_expid}-{each_condition}_a_A0_filled.bigwig") for each_bam, each_condition, each_expid in zip(new_bam_files, conditions, exp_ids)]
#bw_files = [os.path.join(os.path.dirname(each_bam), f"{each_expid}-{each_condition}_a_A0_smoothed-20-0.bw") for each_bam, each_condition, each_expid in zip(new_bam_files, conditions, exp_ids)]
# Optional filter keeping only tracks that contain "SMAC":
#conditions_exp_ids = [condition_exp_id for condition_exp_id in conditions_exp_ids if "SMAC" in condition_exp_id]
#bw_files = [bw_file for bw_file in bw_files if "SMAC" in bw_file]

# Every "*A0_qnorm_smoothed.bw" in the temp folder, sorted for a reproducible order.
bw_files = [os.path.join(temp_dir, f) for f in sorted(os.listdir(temp_dir)) if f.endswith("A0_qnorm_smoothed.bw")]
assert bw_files, f"No *A0_qnorm_smoothed.bw files found in {temp_dir}"
# Track label = file name (directory stripped) up to "_R10".
conditions_exp_ids = [os.path.basename(f).split("_R10")[0] for f in bw_files]

print("bw_files:",  len(bw_files))
print(bw_files)

print("Conditions and exp_ids:", len(conditions_exp_ids))
print(conditions_exp_ids)

# Optional filter dropping tracks that contain "AH-0":
#conditions_exp_ids = [condition_exp_id for condition_exp_id in conditions_exp_ids if "AH-0" not in condition_exp_id]
#bw_files = [bw_file for bw_file in bw_files if "AH-0" not in bw_file]

bin_size = 100
summary_npz = f"{temp_dir}multiBigwigSummary_results-na-allSMAC-{bin_size}.npz"

command = [
    "multiBigwigSummary",
    "bins",
    "--verbose",
    # "--chromosomesToSkip",
    # "CHROMOSOME_I",
    # "CHROMOSOME_II",
    # "CHROMOSOME_III",
    # "CHROMOSOME_IV",
    # "CHROMOSOME_V",
    "--numberOfProcessors",
    str(50),
    "--binSize",
    str(bin_size),
    #"--BED",
    #"/Data1/git/meyer-nanopore/scripts/analysis/temp_files/ce11_first_100kb.bed",
    "--outFileName",
    summary_npz,
    "--bwfiles",
] + bw_files + [
    "--labels",
] + conditions_exp_ids

# Echo the command as it would be typed in a shell, then run it.
print("Executing command:")
print(' '.join(command))
subprocess.run(command, text=True, check=True)

# Pairwise Pearson correlation of the binned signal, drawn as scatterplots.
command = [
    "plotCorrelation",
    "-in",
    summary_npz,
    "-c",
    "pearson",
    "-p",
    "scatterplot",
    "--labels",
] + conditions_exp_ids + [
    #"--plotNumbers",
    "-o",
    f"{temp_dir}multiBigWigSummary_Correlation-na-allSMAC_{bin_size}.png"
]
subprocess.run(command, text=True, check=True)

# PCA of the binned signal.
# NOTE(review): incl_PCs is [1, 2] but the command below plots PCs 2 and 3;
# one of them is presumably stale - confirm which components are intended.
incl_PCs = [1,2]
command = [
    "plotPCA",
    "-in",
    summary_npz,
    "--labels",
] + conditions_exp_ids + [
    "--PCs",
    "2",
    "3",
    "-o",
    f"{temp_dir}multiBigWigSummary_PCA-na-allSMAC_{bin_size}.png"
]
subprocess.run(command, text=True, check=True)

# Display the correlation plot


In [None]:
### Convert Qiming's .bw files to bedgraph:
# bigWig -> bedGraph uses UCSC `bigWigToBedGraph in.bw out.bedGraph`.
# bedGraphToBigWig converts in the opposite direction and cannot read a .bw.
import shutil

qiming_dir = "/Data1/ext_data/qiming_2024/"

# Look for bigWigToBedGraph next to the configured bedGraphToBigWig binary,
# falling back to PATH.
# NOTE(review): assumes both UCSC tools are installed side by side - confirm.
bigwigtobedgraph_path = os.path.join(os.path.dirname(bedgraphtobigwig_path), "bigWigToBedGraph")
if not os.path.isfile(bigwigtobedgraph_path):
    bigwigtobedgraph_path = shutil.which("bigWigToBedGraph") or "bigWigToBedGraph"

for f in sorted(os.listdir(qiming_dir)):
    if f.endswith(".bw"):
        f = os.path.join(qiming_dir, f)
        print("Converting to bedgraph: ", f)
        try:
            # Swap only the extension, never a ".bw" occurring elsewhere in the path.
            output_bedgraph_fn = os.path.splitext(f)[0] + ".bedgraph"
            subprocess.run([bigwigtobedgraph_path, f, output_bedgraph_fn], check=True)
        except subprocess.CalledProcessError as e:
            print(f"An error occurred during bigwig to bedgraph conversion: {e}")


In [None]:
from scipy.ndimage import gaussian_filter1d
# Reimport nanotools
importlib.reload(nanotools)

# Base folder path as temp folder
base_folder = "/Data1/git/meyer-nanopore/scripts/analysis/temp_files/"

# Every "*R10_a_A0_qnorm_smoothed.bedgraph" in the temp folder, keyed by the file
# name up to its first "."; sorted so the key order (and the colours assigned
# from it in the next cell) is reproducible.
files = {f.split(".")[0]: f for f in sorted(os.listdir(base_folder)) if f.endswith("R10_a_A0_qnorm_smoothed.bedgraph")}
assert files, f"No *R10_a_A0_qnorm_smoothed.bedgraph files found in {base_folder}"

# ### PREP DF FOR PLOTTING NON NORMALIZED FILES
# # Base folder path
# base_folder = "/Data1/seq_data/BM_N2_old_Fiber_Hia5_MCVIPI_05_30_24/no_sample/20240530_1930_X2_FAX32001_56dbbc37/basecalls/"
# # File names
# files = {
#     'a_A0': 'BM_05_30_24-N2_old_SMACseq_R10_a_A0.bedgraph',
#     'm_GC1': 'BM_05_30_24-N2_old_SMACseq_R10_m_GC1.bedgraph',
#     'a_A0_m_GC1': 'BM_05_30_24-N2_old_SMACseq_R10_a_A0_m_GC1.bedgraph',
# }

# Genomic window loaded from every file.
region_chrom, region_start, region_end = 'CHROMOSOME_I', 500000, 505000

data = {}
for key, file_name in files.items():
    file_path = os.path.join(base_folder, file_name)
    # NOTE(review): the trailing False is a positional load_bedgraph_file option
    # whose meaning is not visible here - confirm against nanotools.
    data[key] = nanotools.load_bedgraph_file(file_path, region_chrom, region_start, region_end, False)

# One long table; the 'type' column records which file each row came from.
combined_data = pd.concat([df.assign(type=key) for key, df in data.items()], ignore_index=True)
nanotools.display_sample_rows(combined_data, 10)

In [None]:
# reimport nanotools local library
importlib.reload(nanotools)

# Windows passed to nanotools.impute_and_smooth; 0 for both disables it.
smoothing_window = 0
imputation_window = 0
# Number of tracks (in `files` order) included in the correlation matrices.
max_corr_tracks = 7

fig = go.Figure()

color_scale = px.colors.qualitative.Vivid
assert len(files) <= len(color_scale), "Not enough colors in the color scale for all types"
# One fixed colour per track, shared by all plots below.
color_dict = {file_key: color_scale[i] for i, file_key in enumerate(files.keys())}

smoothed_data = {}

# `sample_type` rather than `type`, so the builtin type() is not shadowed in the notebook.
for sample_type in files.keys():
    df_type = combined_data[combined_data['type'] == sample_type].sort_values('start')

    if imputation_window > 0 or smoothing_window > 0:
        df_type['imputed_score'], df_type['imputed_coverage'], df_type['smoothed_score'], df_type['smoothed_coverage'] = nanotools.impute_and_smooth(
            df_type['score'],
            df_type['coverage'],
            impute_window=imputation_window,
            smooth_window=smoothing_window,
            fill_value=0  # Use 0 if you prefer to fill with zeros
        )
    else:
        df_type['smoothed_score'] = df_type['score']

    # Kept for the correlation matrices below.
    smoothed_data[sample_type] = df_type['smoothed_score'].values

    fig.add_trace(go.Scatter(
        x=df_type['start'],
        y=df_type['score'],
        mode='markers',
        name=f"{sample_type} (raw)",
        marker=dict(color=color_dict[sample_type]),
    ))

    fig.add_trace(go.Scatter(
        x=df_type['start'],
        y=df_type['smoothed_score'],
        mode='lines',
        name=f"{sample_type} (smoothed)",
        line=dict(color=color_dict[sample_type], width=2)
    ))

fig.update_layout(
    title='Bedgraph Scatter Plot with Raw and Smoothed Data',
    xaxis_title='Position',
    yaxis_title='Score',
    legend_title='Type',
    hovermode='closest',
    template="plotly_white"
)
fig.show()

# Score distribution per track, with every point drawn beside its box.
fig2 = go.Figure()
for sample_type in files.keys():
    df_type = combined_data[combined_data['type'] == sample_type]
    fig2.add_trace(go.Box(
        y=df_type['score'],
        name=sample_type,
        boxpoints='all',
        jitter=0.3,
        pointpos=-1.8,
        marker_color=color_dict[sample_type],
    ))

fig2.update_layout(
    title='Box Plot of Scores by Type',
    yaxis_title='Score',
    legend_title='Type',
    template="plotly_white"
)
fig2.show()

# Correlation matrices over the first `max_corr_tracks` tracks, truncated to the
# shortest track.
# NOTE(review): values are paired by row position, which assumes every track has
# the same bins in the same order; a track with missing bins would be compared
# at mismatched positions.
min_length = min(len(arr) for arr in smoothed_data.values())
selected_keys = list(smoothed_data.keys())[:max_corr_tracks]

for key in selected_keys:
    smoothed_data[key] = smoothed_data[key][:min_length]

smoothed_df = pd.DataFrame({key: smoothed_data[key] for key in selected_keys})
pearson_corr_matrix = smoothed_df.corr()
spearman_corr_matrix = smoothed_df.corr(method='spearman')

fig3 = px.imshow(pearson_corr_matrix, text_auto=True, color_continuous_scale='Viridis')
fig3.update_layout(
    title=f'Pearson Correlation Matrix of Smoothed Scores (First {len(selected_keys)} Keys)',
    xaxis_title='Type',
    yaxis_title='Type',
    template="plotly_white",
    height=800
)
fig3.show()

fig4 = px.imshow(spearman_corr_matrix, text_auto=True, color_continuous_scale='Viridis')
fig4.update_layout(
    title=f'Spearman Correlation Matrix of Smoothed Scores (First {len(selected_keys)} Keys)',
    xaxis_title='Type',
    yaxis_title='Type',
    template="plotly_white",
    height=800
)
fig4.show()

In [None]:

### Add bed and condition details to modkit output for plotting
# Using DataFrame merging to achieve the task without explicit loops
## Looks up the bed_start and bed_end values for each row in bedmethyl_df
def add_bed_columns_no_loops(bedmethyl_df_loc, combined_bed_df):
    """Annotate bedmethyl rows with the bed region that contains them.

    Each row of ``bedmethyl_df_loc`` is matched (per ``chrom``) to the bed
    region whose integer midpoint is nearest to its ``start_position``; the
    match is kept only if ``bed_start <= start_position <= bed_end``. Rows
    with no containing region are dropped.

    NOTE(review): because matching is by nearest midpoint, a position inside a
    region can still be dropped if another (overlapping / much larger) region's
    midpoint is closer; overlapping regions are not multiply assigned.

    Parameters
    ----------
    bedmethyl_df_loc : pd.DataFrame
        Must contain ``chrom`` and ``start_position``.
    combined_bed_df : pd.DataFrame
        Must contain ``chrom``, ``bed_start``, ``bed_end``, ``bed_strand``,
        ``type`` and ``chr_type``.

    Returns
    -------
    pd.DataFrame
        The matching rows of ``bedmethyl_df_loc`` (original row order,
        ``start_position`` cast to int) with ``bed_start``, ``bed_end``,
        ``bed_strand``, ``type`` and ``chr_type`` appended. ``bed_start`` and
        ``bed_end`` keep the dtype they have in ``combined_bed_df``.

    Neither input DataFrame is modified.
    """
    # Work on copies so the caller's frames are not mutated.
    bed_lookup = combined_bed_df.copy()
    bed_lookup['midpoint'] = ((bed_lookup['bed_start'] + bed_lookup['bed_end']) / 2).astype(int)
    bed_lookup = bed_lookup.sort_values(by='midpoint')

    positions = bedmethyl_df_loc.copy()
    # merge_asof needs matching dtypes on the join keys.
    positions['start_position'] = positions['start_position'].astype(int)

    # Nearest-midpoint match per chromosome (both sides sorted on the join key).
    nearest = pd.merge_asof(positions.sort_values('start_position'),
                            bed_lookup,
                            by='chrom',
                            left_on='start_position',
                            right_on='midpoint',
                            direction='nearest')

    # Keep only matches where the position actually lies inside the region.
    inside = nearest.loc[(nearest['start_position'] >= nearest['bed_start']) &
                         (nearest['start_position'] <= nearest['bed_end'])]

    key_cols = ['chrom', 'start_position']
    bed_cols = ['bed_start', 'bed_end', 'bed_strand', 'type', 'chr_type']
    # Identical keys always get the same region, so one row per key suffices;
    # without this the merge below would multiply duplicate-key rows.
    annotations = inside[key_cols + bed_cols].drop_duplicates(subset=key_cols)

    final_df = pd.merge(positions, annotations, on=key_cols, how='left')

    # Drop rows that fell outside every region.
    final_df = final_df[final_df['type'].notna()]

    # The left merge made bed_start/bed_end float (NaN for misses); restore dtype.
    return final_df.astype({col: bed_lookup[col].dtype for col in ('bed_start', 'bed_end')})

combined_bed_df = nanotools.create_lookup_bed(new_bed_files)

# Column names of the modkit bedmethyl output files, in file order.
bedmethyl_cols = ['chrom','start_position','end_position','modified_base_code','score','strand','start_position_compat','end_position_compat','color','Nvalid_cov','fraction_modified','Nmod','Ncanonical','Nother_mod','Ndelete','Nfail','Ndiff','Nnocall']
# Modification codes kept from each file (all others are dropped).
KEPT_MOD_CODES = ['a,A,0', 'm,GC,1']


def process_group(group_tuple, num_bins, edge_window_size, sum_columns):
    """Assign metagene ``rel_start`` coordinates to the rows of one bed region.

    Parameters
    ----------
    group_tuple : tuple
        ``(group_key, DataFrame)`` pair as produced by ``DataFrame.groupby``.
    num_bins : int
        Number of bins the region body is scaled into.
    edge_window_size : int
        Unscaled flank (bp) kept at each end of the region.
    sum_columns : list
        Unused; kept so the ``starmap`` argument tuple is unchanged.

    Returns
    -------
    pd.DataFrame
        The group with a ``rel_start`` column:
        * ``start_position < bed_start + edge`` -> ``start_position - bed_start - edge``
        * ``start_position > bed_end - edge``   -> ``num_bins + edge - (bed_end - start_position)``
        * otherwise -> ``np.digitize`` bin index (0..num_bins) over the body, but
          only when the observed span exceeds ``num_bins + 2*edge``; otherwise the
          placeholder 100000 is left in place.
          NOTE(review): those placeholder rows are not dropped and are later
          summed into a rel_start == 100000 bin — confirm this is intended.
    """
    _, group = group_tuple
    min_pos, max_pos = group['start_position'].min(), group['start_position'].max()
    bed_start, bed_end = group['bed_start'].iloc[0], group['bed_end'].iloc[0]

    group['rel_start'] = np.where(group['start_position'] < bed_start + edge_window_size,
                                  group['start_position'] - bed_start - edge_window_size,  # left flank, bp offset
                                  np.where(group['start_position'] > bed_end - edge_window_size,
                                           num_bins + edge_window_size - (bed_end - group['start_position']),  # right flank
                                           100000))  # body placeholder until binned

    # Scale the body into num_bins bins when the region is long enough.
    if (max_pos - min_pos) > (num_bins + 2 * edge_window_size):
        binning_mask = (group['start_position'] >= bed_start + edge_window_size) & (group['start_position'] <= bed_end - edge_window_size)
        bin_edges = np.linspace(bed_start + edge_window_size, bed_end - edge_window_size, num_bins + 1)
        group.loc[binning_mask, 'rel_start'] = np.digitize(group.loc[binning_mask, 'start_position'], bins=bin_edges, right=True)

    return group


def map_to_metagene_bins_and_sum(df, num_bins=1000, edge_window_size=500):
    """Map every region of ``df`` to metagene coordinates and sum counts per bin.

    Groups by (bed_start, chrom, modified_base_code), applies ``process_group``
    to each region in a process pool, then sums the count columns within each
    (region, rel_start) bin and re-attaches the per-region metadata columns.
    Returns an empty frame with the output columns when ``df`` has no groups.
    """
    # Columns for summing within bins
    sum_columns = ['Nmod', 'Ncanonical', 'Nother_mod', 'Ndelete', 'Nfail', 'Ndiff', 'Nnocall','Nvalid_cov']
    # Columns to retain in the final DataFrame
    retain_columns = ['bed_strand', 'chr_type', 'strand', 'bed_end','type']
    group_columns = ['bed_start', 'chrom' ,'modified_base_code']
    sum_group_columns = group_columns + ['rel_start']

    groups = list(df.groupby(group_columns))
    if not groups:
        return pd.DataFrame(columns=sum_group_columns + retain_columns + sum_columns)

    # Size the pool to the machine (and never larger than the work available)
    # instead of spawning a fixed, very large number of processes.
    n_workers = max(1, min(cpu_count(), len(groups)))
    with Pool(n_workers) as pool:
        processed_groups = pool.starmap(process_group, [(group, num_bins, edge_window_size, sum_columns) for group in groups])

    result_df = pd.concat(processed_groups, ignore_index=True)

    # Summing within bins and re-attaching per-region metadata
    summed_df = result_df.groupby(sum_group_columns)[sum_columns].sum()
    merged_df = pd.merge(result_df[sum_group_columns + retain_columns].drop_duplicates(), summed_df, on=sum_group_columns, how='left')

    return merged_df


# Build one frame per input file, then concatenate once at the end
# (DataFrame.append was removed in pandas 2.0).
bedmethyl_frames = []
for each_output, each_condition, each_exp_id in zip(out_file_names, conditions, exp_ids):
    bedmethyl_df = pd.read_csv(each_output, sep="\t", header=None, names=bedmethyl_cols)
    bedmethyl_df = bedmethyl_df[bedmethyl_df['modified_base_code'].isin(KEPT_MOD_CODES)]
    if bedmethyl_df.empty:
        print("!Read in empty csv!!")
        print("Tried to select:",each_output," ",each_condition," ",each_exp_id, "and failed...")
        continue

    bedmethyl_df = (
        bedmethyl_df
        .sort_values(['start_position'], ascending=[True])
        .dropna()              # drop any rows with a nan
        .drop_duplicates()
        .reset_index(drop=True)
    )

    bedmethyl_df = add_bed_columns_no_loops(bedmethyl_df, combined_bed_df)

    # if the first selected type is gene/damID based, map to metagene bins
    if 'gene' in type_selected[0] or 'damID' in type_selected[0]:
        print("Mapping to metagene bins...")
        bedmethyl_df = map_to_metagene_bins_and_sum(bedmethyl_df, num_bins=num_bins, edge_window_size=bed_window)
    else:
        # assign() returns a new frame, avoiding writes into a filtered slice
        bedmethyl_df = bedmethyl_df.assign(
            rel_start=bedmethyl_df['start_position'] - bedmethyl_df['bed_start'] - bed_window + 1)

    bedmethyl_df['rel_start'] = bedmethyl_df['rel_start'].astype(int)
    bedmethyl_df['condition'] = each_condition
    bedmethyl_df['exp_id'] = each_exp_id

    if bedmethyl_df.empty:
        print("!Bedmethyl_df is empty!")
        print("Tried to select:",each_output," ",each_condition," ",each_exp_id, "and failed...")
        continue

    print(f"Adding {len(bedmethyl_df)} rows for {each_exp_id} ({each_condition})...")
    bedmethyl_frames.append(bedmethyl_df)

comb_bedmethyl_df = pd.concat(bedmethyl_frames, ignore_index=True) if bedmethyl_frames else pd.DataFrame()

print("sample")
display(nanotools.display_sample_rows(comb_bedmethyl_df,10))


In [None]:
### SHIFT AND TRANSFORM (OPTIONAL)
# align_zero_bool: True  -> per-region profiles are shifted so their maxima line up
#                  False -> regions are strand-oriented and pooled without shifting
# flip_bool: only referenced by the disabled flip logic; currently has no effect.
align_zero_bool, flip_bool = False, False
def compute_lag_for_maximum_alignment(series1, bed_start1):
    """Return the shift that moves the maximum of ``series1`` onto its centre.

    Parameters
    ----------
    series1 : array-like of numbers
        A region's continuous profile.
    bed_start1 : object
        Region identifier; unused, kept for call compatibility.

    Returns
    -------
    tuple[int, int]
        ``(lag, flip)`` where ``lag = len(series1) // 2 - argmax(series1)``.
        Integer floor division gives the true centre index for both odd and even
        lengths (``round(n / 2)`` uses banker's rounding and picks
        inconsistent centres for odd ``n``). ``flip`` is always 0: correlation-
        based flipping is disabled. Empty input returns ``(0, 0)``.
    """
    values = np.asarray(series1)
    if values.size == 0:
        return (0, 0)
    flip = 0
    lag = values.size // 2 - int(np.argmax(values))
    return (lag, flip)

def get_continuous_series(df_subset):
    """Return ``weighted_norm_mod_frac`` as a gap-free array indexed by ``rel_start``.

    NaN values are forward-filled, then back-filled; missing ``rel_start``
    positions between the min and max are inserted with value 0.

    Parameters
    ----------
    df_subset : pd.DataFrame
        Must contain ``rel_start`` and ``weighted_norm_mod_frac``.

    Returns
    -------
    np.ndarray
        Filled values. If ``rel_start`` contains duplicates (or is otherwise
        not reindexable) a diagnostic is printed and the filled-but-unreindexed
        values are returned. Empty input returns an empty array.
    """
    series_filled = (
        df_subset.set_index('rel_start')['weighted_norm_mod_frac']
        .ffill()
        .bfill()
    )
    if series_filled.empty:
        return series_filled.values

    try:
        full_range = range(int(series_filled.index.min()), int(series_filled.index.max()) + 1)
        series_filled = series_filled.reindex(full_range, fill_value=0)
    except ValueError:
        # Best effort: report the problem (usually duplicate rel_start labels)
        # but keep going with the unreindexed series.
        print("Failed series_filled:",series_filled)
        print("Duplicate indexes:",series_filled.index[series_filled.index.duplicated()])
    return series_filled.values

def align_profiles(df):
    """Compute, per region, the shift that centres that region's profile maximum.

    Each ``bed_start`` region is turned into a continuous series with
    ``get_continuous_series`` and passed to
    ``compute_lag_for_maximum_alignment``; the returned lag and flip decision
    are written back to every row of that region. Regions are aligned to their
    own centre, not to a reference region.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain ``bed_start``, ``rel_start`` and ``weighted_norm_mod_frac``.

    Returns
    -------
    pd.DataFrame
        A copy of ``df`` sorted by (bed_start, rel_start) with added columns
        ``flipped`` (int, 0/1) and ``shift`` (float lag). Summary statistics
        are printed.
    """
    df = df.sort_values(['bed_start', 'rel_start']).copy()

    lag_by_region = {}
    flip_by_region = {}
    # groupby is a single pass instead of one boolean mask per region.
    for region_start, region_rows in df.groupby('bed_start', sort=False):
        lag, flip = compute_lag_for_maximum_alignment(get_continuous_series(region_rows), region_start)
        lag_by_region[region_start] = lag
        flip_by_region[region_start] = 1 if flip else 0

    # Rows with a missing bed_start get flipped=0 and shift=NaN.
    df["flipped"] = df['bed_start'].map(flip_by_region).fillna(0).astype(int)
    df["shift"] = df['bed_start'].map(lag_by_region).astype(float)

    # Every region is processed, so the denominator is the full region count.
    n_regions = df['bed_start'].nunique()
    total_flipped = df[df['flipped'] == 1]['bed_start'].nunique()
    lag_distribution = df['shift'].describe()

    print(f"Total bed_starts flipped: {total_flipped} out of {n_regions}")
    print("Lag Distribution:")
    print(lag_distribution)

    return df

print("Copying and dropping rows...")
comb_bedmethyl_plot_df = comb_bedmethyl_df.copy()

# Strip stray whitespace from experiment ids once, so every merge against
# coverage_df below (in either branch) matches on clean keys.
coverage_df['exp_id'] = coverage_df['exp_id'].str.strip()


def _add_norm_mod_frac_weighted(df):
    """Return a copy of ``df`` with m6A normalisation columns added.

    Adds ``raw_mod_frac = Nmod / (Nmod + Ncanonical)``, ``exp_id_m6A_frac``
    (looked up from ``coverage_df['m6A_frac']`` by stripped ``exp_id``),
    ``norm_mod_frac_init = raw_mod_frac / exp_id_m6A_frac`` and
    ``norm_mod_frac_weighted = norm_mod_frac_init * (Nmod + Ncanonical)``.
    """
    out = df.copy()
    out['exp_id'] = out['exp_id'].str.strip()
    out['raw_mod_frac'] = out['Nmod'] / (out['Nmod'] + out['Ncanonical'])
    out = pd.merge(out, coverage_df[['exp_id', 'm6A_frac']], on=['exp_id'], how='left')
    out = out.rename(columns={'m6A_frac': 'exp_id_m6A_frac'})
    out['norm_mod_frac_init'] = out['raw_mod_frac'] / out['exp_id_m6A_frac']
    # Weight by informative calls so later sums give a coverage-weighted mean.
    out['norm_mod_frac_weighted'] = out['norm_mod_frac_init'] * (out['Nmod'] + out['Ncanonical'])
    return out


# If aligning to 0
if align_zero_bool:
    final_df = comb_bedmethyl_plot_df.groupby(
        ['chrom', 'chr_type', 'rel_start', 'exp_id', 'condition', 'type', 'bed_start']).agg({
        'Nvalid_cov': 'sum',
        'Nmod': 'sum',
        'Ncanonical': 'sum',
        'Nother_mod': 'sum'
    }).reset_index()
    final_df = _add_norm_mod_frac_weighted(final_df)

    # Region id = bed_start + chrom, so regions are unique across chromosomes.
    final_df['bed_start'] = final_df['bed_start'].astype(str) + "_" + final_df['chrom'].astype(str)

    # Pool experiments per (position, condition, type, chr_type, region)
    final_df = final_df.groupby(
        ['rel_start', 'condition', 'type', 'chr_type','bed_start']
    )[['Nvalid_cov', 'Ncanonical', 'Nmod', 'norm_mod_frac_weighted']].sum().reset_index()

    final_df['weighted_norm_mod_frac'] = final_df['norm_mod_frac_weighted'] / (final_df['Nmod'] + final_df['Ncanonical'])
    final_df['raw_mod_frac'] = final_df['Nmod'] / (final_df['Nmod'] + final_df['Ncanonical'])

    # Drop rows at or below the 10th percentile of Nvalid_cov, then order by region/position.
    final_df = (
        final_df[final_df['Nvalid_cov'] > final_df['Nvalid_cov'].quantile(0.1)]
        .sort_values(['bed_start', 'rel_start'])
        .reset_index(drop=True)
    )

    # print rows with duplicate rel_pos, bed_start, condition and type
    print("Duplicate rows:")
    display(final_df[final_df.duplicated(['rel_start','bed_start','condition','type'])].head(10))

    # Displaying the first 100 rows
    display(final_df.head(100))

    # Collect per-(condition, type) frames and concatenate once
    # (DataFrame.append was removed in pandas 2.0).
    aligned_cluster_frames = []
    for each_condition in final_df['condition'].unique():
        for each_type in final_df['type'].unique():
            print("Starting on:",each_condition,each_type)
            final_df_cluster = final_df[(final_df['condition'] == each_condition) & (final_df['type'] == each_type)]
            if final_df_cluster.empty:
                # combination absent after filtering: nothing to align
                continue

            aligned_df = align_profiles(final_df_cluster).drop_duplicates(subset=['bed_start'])

            # Re-aggregate the raw counts for this cluster, keeping exp_id so each
            # experiment is normalised by its own global m6A fraction.
            final_df_cluster = comb_bedmethyl_df[(comb_bedmethyl_df['condition'] == each_condition) & (comb_bedmethyl_df['type'] == each_type)].groupby(
                ['chrom', 'rel_start', 'exp_id', 'modified_base_code', 'condition', 'type', 'chr_type', 'bed_start']
            ).agg({
                'Nvalid_cov': 'sum',
                'Nmod': 'sum',
                'Ncanonical': 'sum',
                'Nother_mod': 'sum'
            }).reset_index()
            final_df_cluster['bed_start'] = final_df_cluster['bed_start'].astype(str) + "_" + final_df_cluster['chrom'].astype(str)
            final_df_cluster = _add_norm_mod_frac_weighted(final_df_cluster)

            # Apply each region's flip/shift decision, then pool across samples.
            print("Merging final_df_cluster with aligned_df...")
            final_df_cluster = pd.merge(final_df_cluster, aligned_df[['bed_start','shift','flipped']], on=['bed_start'], how='left')
            final_df_cluster.loc[final_df_cluster['flipped'] == 1, 'rel_start'] *= -1
            final_df_cluster['rel_start'] += final_df_cluster['shift']
            final_df_cluster = final_df_cluster.groupby(
                ['rel_start','modified_base_code','condition','type','chr_type']
            )[['Nvalid_cov','Ncanonical','Nmod','norm_mod_frac_weighted']].sum().reset_index()

            final_df_cluster['weighted_norm_mod_frac'] = final_df_cluster['norm_mod_frac_weighted']/(final_df_cluster['Nmod']+final_df_cluster['Ncanonical'])
            final_df_cluster['raw_mod_frac'] = final_df_cluster['Nmod']/(final_df_cluster['Nmod']+final_df_cluster['Ncanonical'])
            final_df_cluster = final_df_cluster.sort_values(['rel_start']).reset_index(drop=True)

            aligned_cluster_frames.append(final_df_cluster)

    # ignore_index avoids duplicate index labels across clusters.
    plot_df = pd.concat(aligned_cluster_frames, ignore_index=True) if aligned_cluster_frames else pd.DataFrame()

else:
    # Orient strand-sensitive types: rel_start of '-' strand regions is mirrored.
    for each_type in type_selected:
        if any(x in each_type for x in ["TSS", "TES", "MEX", "MEXII", "gene"]):
            print(f"Strand orientation sensitive {each_type} type selected, multiplying rel_start by -1 for '-' strand genes...")
            is_metagene = 'gene' in each_type
            if is_metagene:
                # Centre metagene coordinates on 0 so the flip mirrors about the body midpoint.
                comb_bedmethyl_plot_df['rel_start'] -= num_bins/2
            mask = (comb_bedmethyl_plot_df['type'] == each_type) & (comb_bedmethyl_plot_df['bed_strand'] == '-')
            comb_bedmethyl_plot_df.loc[mask, 'rel_start'] *= -1

            # Do the same for bigwig lines.
            # NOTE(review): bw_df is mirrored about 0 without the num_bins/2
            # centring applied above — confirm this is intended for metagenes.
            if ext_target != []:
                bw_mask = (bw_df['type'] == each_type) & (bw_df['bed_strand'] == '-')
                bw_df.loc[bw_mask, 'rel_start'] *= -1
            if is_metagene:
                # Undo the centring shift applied above.
                comb_bedmethyl_plot_df['rel_start'] += num_bins/2

    print("Grouping by chrom, rel_start, exp_id, modified_base_code, condition, type, chr_type, bed_start...")
    grouped_df = comb_bedmethyl_plot_df.groupby(['chrom', 'rel_start','exp_id','modified_base_code','condition','type','chr_type']).agg({
        'Nvalid_cov': 'sum',
        'Nmod': 'sum',
        'Ncanonical': 'sum',
        'Nother_mod': 'sum'
    }).reset_index()

    print("Calculating normalized m6A...")
    grouped_df['raw_mod_frac'] = grouped_df['Nmod'] / (grouped_df['Nmod'] + grouped_df['Ncanonical'])

    grouped_df['exp_id'] = grouped_df['exp_id'].str.strip()
    nanotools.display_sample_rows(coverage_df,10)
    nanotools.display_sample_rows(grouped_df,10)
    merged_df = pd.merge(grouped_df, coverage_df[['exp_id', 'm6A_frac']],
                         on=['exp_id'], how='left')
    nanotools.display_sample_rows(merged_df,10)
    merged_df.rename(columns={'m6A_frac': 'exp_id_m6A_frac'}, inplace=True)
    merged_df['norm_mod_frac_init'] = merged_df['raw_mod_frac'] / merged_df['exp_id_m6A_frac']

    # Explicit copy: plot_df is modified below and must not be a view of merged_df.
    plot_df = merged_df[grouped_df.columns.tolist() + ['norm_mod_frac_init']].copy()
    # Use plot_df's own columns (not grouped_df's) so rows always line up.
    plot_df['norm_mod_weighted'] = plot_df['norm_mod_frac_init'] * (plot_df['Nmod'] + plot_df['Ncanonical'])

    ### Since multiple samples have same condition, pool them:
    plot_df = plot_df.groupby(['rel_start','modified_base_code','condition','type','chr_type'])[['Nvalid_cov','Ncanonical','Nmod','norm_mod_weighted']].sum()
    plot_df.reset_index(inplace=True)

    if ext_target != []:
        plot_comb_bigwig_df = bw_df.groupby(['rel_start','chrom','condition','type','chr_type'])['value'].mean().reset_index()
    else:
        plot_comb_bigwig_df = pd.DataFrame()

    # Coverage-weighted normalised fraction, and the raw fraction
    plot_df['weighted_norm_mod_frac'] = plot_df['norm_mod_weighted']/(plot_df['Nmod']+plot_df['Ncanonical'])
    plot_df['raw_mod_frac'] = plot_df['Nmod']/(plot_df['Nmod']+plot_df['Ncanonical'])
    plot_df = plot_df.sort_values(['rel_start']).reset_index(drop=True)

print("plot_df:")
nanotools.display_sample_rows(plot_df,10)
# plot_comb_bigwig_df is only built by the non-aligned branch.
if ext_target != [] and not align_zero_bool:
    nanotools.display_sample_rows(plot_comb_bigwig_df,10)

In [None]:
### Plot correlation plot between replicates
# NOTE(review): this cell reads the global `merged_df` and needs a per-base
# `start_position` column. The non-aligned branch of the SHIFT AND TRANSFORM
# cell assigns `merged_df` from data grouped by rel_start (no start_position),
# so on a fresh Restart & Run All this cell may stop at the check below —
# confirm which frame it is meant to use.
CORR_BIN_SIZE = 50  # bp per genomic bin for the replicate correlation

merged_df_correlation = merged_df.copy()
_required_cols = {'chrom', 'start_position', 'exp_id', 'condition', 'exp_id_m6A_frac', 'Nvalid_cov', 'Nmod'}
_missing_cols = _required_cols - set(merged_df_correlation.columns)
if _missing_cols:
    raise KeyError(f"merged_df is missing columns needed for the replicate correlation: {sorted(_missing_cols)}")
merged_df_correlation["exp_condition_id"] = merged_df_correlation["exp_id"] + "_" + merged_df_correlation["condition"]


def binning_func(x, bin_size=CORR_BIN_SIZE):
    """Floor position(s) ``x`` (scalar or Series) to the start of their ``bin_size``-bp bin."""
    return np.floor(x / bin_size) * bin_size


# Vectorised binning of genomic positions (no per-row apply)
merged_df_correlation['binned_start_position'] = binning_func(merged_df_correlation['start_position'])
nanotools.display_sample_rows(merged_df_correlation)

# Sum coverage and modified calls per (chrom, bin, experiment)
binned_df = merged_df_correlation.groupby(['chrom', 'binned_start_position', 'exp_condition_id','exp_id_m6A_frac']).agg({
    'Nvalid_cov': 'sum',
    'Nmod': 'sum'
}).reset_index()

# Per-bin modified fraction.
# NOTE(review): denominator is Nvalid_cov here, whereas the plotting cells use
# Nmod + Ncanonical — confirm which is intended.
binned_df['m6A_frac'] = binned_df['Nmod'] / binned_df['Nvalid_cov']

# Fraction normalised by the experiment-wide m6A fraction
binned_df['norm_mod_frac'] = binned_df['m6A_frac'] / binned_df['exp_id_m6A_frac']
nanotools.display_sample_rows(binned_df)
# Arcsine-square-root (variance-stabilising) transform for proportions
binned_df['transformed_mod_frac'] = np.arcsin(np.sqrt(binned_df['m6A_frac']))

# One column per experiment/condition, one row per genomic bin
pivoted_df = binned_df.pivot_table(
    index=['chrom', 'binned_start_position'],
    columns='exp_condition_id',
    values='transformed_mod_frac'
)

nanotools.display_sample_rows(pivoted_df)

# Pearson correlation between experiments, then r²
correlation_matrix = pivoted_df.corr(method='pearson')
nanotools.display_sample_rows(correlation_matrix)

r_squared_matrix = correlation_matrix ** 2
nanotools.display_sample_rows(r_squared_matrix)

# r² lies in [0, 1]; pin the colour scale so heatmaps are comparable across runs.
fig = go.Figure(data=go.Heatmap(
    z=r_squared_matrix.values,
    x=r_squared_matrix.columns,
    y=r_squared_matrix.index,
    zmin=0,
    zmax=1,
    colorscale='Oranges'))

fig.update_layout(
    template='plotly_white',
    title='Pearson r² Values Heatmap',
    xaxis_title='Experiment / condition',
    yaxis_title='Experiment / condition'
)

fig.show()


In [None]:
force_replace = True
# Cache plot_df (and plot_comb_bigwig_df when it exists) under temp_files/, with
# the run configuration encoded in the file name. With force_replace=False an
# existing cache file is loaded instead of being overwritten.
TEMP_DIR = "temp_files/"
os.makedirs(TEMP_DIR, exist_ok=True)  # to_csv fails if the folder is missing

# Shared configuration suffix (file names are unchanged from earlier runs).
_config_suffix = ("_".join([each_type for each_type in type_selected])
                  + str(round(thresh_list[0],2)) + "_" + str(bam_fracs[0]) + str(bed_window) + ".csv")
final_fn = TEMP_DIR + "final_df_" + _config_suffix
final_fn_chip = TEMP_DIR + "final_df_chip" + _config_suffix

if not force_replace and os.path.exists(final_fn):
    print("final_df already exists, importing it...")
    plot_df = pd.read_csv(final_fn)
    nanotools.display_sample_rows(plot_df,5)
else:
    print(f"Saving final_df to {final_fn}...")
    plot_df.to_csv(final_fn, index=False)

# plot_comb_bigwig_df is only created by some code paths; check explicitly
# rather than swallowing every exception (which would also hide I/O errors).
if not force_replace and os.path.exists(final_fn_chip):
    print("final_df_chip already exists, importing it...")
    plot_comb_bigwig_df = pd.read_csv(final_fn_chip)
    nanotools.display_sample_rows(plot_comb_bigwig_df,5)
elif 'plot_comb_bigwig_df' in globals():
    print(f"Saving final_df_chip to {final_fn_chip}...")
    plot_comb_bigwig_df.to_csv(final_fn_chip, index=False)
else:
    print("plot_comb_bigwig_df does not exist, skipping...")

In [None]:
# Folder holding the per-modification, per-strand bedgraph files below.
# Set the FIBERSEQ_BEDGRAPH_DIR environment variable to point at another copy
# of the data; otherwise the original lab path is used.
base_folder = os.environ.get(
    "FIBERSEQ_BEDGRAPH_DIR",
    "/Data1/seq_data/BM_N2_old_Fiber_Hia5_MCVIPI_05_30_24/no_sample/20240530_1930_X2_FAX32001_56dbbc37/basecalls/",
)

# Bedgraph file per "<mod>_<strand>" key ('a' = a,A,0 calls; 'm' = m,GC,1 calls).
files = {
    'a_positive': 'a_A0_positive.bedgraph',
    'a_negative': 'a_A0_negative.bedgraph',
    'm_positive': 'm_GC1_positive.bedgraph',
    'm_negative': 'm_GC1_negative.bedgraph'
}

def load_data(file_path, chromosome='CHROMOSOME_I', region_start=1000000, region_end=1010000):
    """Load a 5-column bedgraph and restrict it to one genomic window.

    Parameters
    ----------
    file_path : str
        Tab-separated file with columns chromosome, start, end, score, coverage
        (no header).
    chromosome : str, optional
        Chromosome name to keep (default ``'CHROMOSOME_I'``).
    region_start, region_end : int, optional
        Half-open window ``[region_start, region_end)`` on ``start``
        (default ``[1000000, 1010000)``).

    Returns
    -------
    pd.DataFrame
        Columns ``start``, ``score`` (mean of rows sharing a start) and
        ``coverage`` (sum), one row per unique ``start``.
    """
    df = pd.read_csv(file_path, sep='\t', header=None, names=['chromosome', 'start', 'end', 'score', 'coverage'])
    in_window = ((df['chromosome'] == chromosome)
                 & (df['start'] >= region_start)
                 & (df['start'] < region_end))
    df = df[in_window]
    # Ensure 'start' is unique by aggregating scores if necessary
    return df.groupby('start').agg({'score': 'mean', 'coverage': 'sum'}).reset_index()

def weighted_average(x):
    """Linearly weighted mean of the non-NaN entries of a numpy array.

    After NaNs are removed, the k-th remaining value (1-based) gets weight k,
    so values later in the window count more. Returns NaN if nothing remains.

    NOTE(review): with a centred rolling window these increasing weights favour
    the right-hand side of the window and shift the smoothed curve — confirm
    this asymmetry is intended.
    """
    finite = x[~np.isnan(x)]
    if finite.size == 0:
        return np.nan
    return np.average(finite, weights=np.arange(1, finite.size + 1))

def apply_weighted_rolling_average(series, window=10):
    """Smooth ``series`` with ``weighted_average`` over a centred rolling window.

    ``min_periods=1`` keeps the edges (partial windows) instead of returning NaN
    there; the window is passed to ``weighted_average`` as a raw numpy array.
    """
    rolling_view = series.rolling(window, min_periods=1, center=True)
    smoothed = rolling_view.apply(weighted_average, raw=True)
    return smoothed

# Load and process all files
data = {}
for key, file_name in files.items():
    file_path = os.path.join(base_folder, file_name)
    data[key] = load_data(file_path)
    print(f"Loaded {key}: {len(data[key])} rows")  # Debug print

# Complete range of start positions; must match load_data's window
# (CHROMOSOME_I, [1000000, 1010000)).
all_starts = pd.DataFrame({'start': range(1000000, 1010000)})

# Combine all data into one wide DataFrame (one score column per mod/strand), keeping NaN values.
# `mod_type` (not `type`) avoids shadowing the builtin `type` in the notebook namespace.
combined_data = all_starts.copy()
for key, df in data.items():
    mod_type, strand = key.split('_')
    combined_data = combined_data.merge(
        df[['start', 'score']].rename(columns={'score': f'score_{mod_type}_{strand}'}),
        on='start', how='left'
    )

print("Combined data shape:", combined_data.shape)  # Debug print
print("Combined data non-null counts:\n", combined_data.notnull().sum())  # Debug print

# Melt the DataFrame to long format
combined_data_long = pd.melt(
    combined_data,
    id_vars=['start'],
    value_vars=[col for col in combined_data.columns if col.startswith('score')],
    var_name='type_strand',
    value_name='score'
)
# "score_<type>_<strand>" -> separate type and strand columns
combined_data_long[['type', 'strand']] = combined_data_long['type_strand'].str.split('_', expand=True).iloc[:, 1:]
combined_data_long = combined_data_long.drop('type_strand', axis=1)

print("Long data shape:", combined_data_long.shape)  # Debug print
print("Long data non-null counts:\n", combined_data_long.notnull().sum())  # Debug print

# Create scatter plot
fig = go.Figure()

colors = {'a': 'red', 'm': 'blue'}

for mod_type in ['a', 'm']:
    # Both strands of this modification type, ordered for correct line connections
    df_type = (combined_data_long[combined_data_long['type'] == mod_type]
               .sort_values('start')
               .set_index('start'))

    df_type['smoothed'] = apply_weighted_rolling_average(df_type['score'])

    print(f"Type {mod_type} data points: {len(df_type)}")  # Debug print
    print(f"Type {mod_type} non-null counts:\n", df_type.notnull().sum())  # Debug print

    # Plot raw data line
    fig.add_trace(go.Scatter(
        x=df_type.index,
        y=df_type['score'],
        mode='markers+lines',
        name=f"{mod_type} (raw)",
        line=dict(color=colors[mod_type], width=1),
        legendgroup=mod_type,
        showlegend=True
    ))

    # Plot smoothed line
    fig.add_trace(go.Scatter(
        x=df_type.index,
        y=df_type['smoothed'],
        mode='lines',
        name=f"{mod_type} (smoothed)",
        line=dict(color=colors[mod_type], width=2, dash='dash'),
        legendgroup=mod_type,
        showlegend=True
    ))

# Update layout
fig.update_layout(
    title='Bedgraph Scatter Plot with Raw and Weighted Rolling Average Smoothed Lines',
    xaxis_title='Position',
    yaxis_title='Score',
    legend_title='Type',
    hovermode='closest',
    template="simple_white"
)

# Show the plot
fig.show()

# Print statistics to verify zero and NA handling
for mod_type in ['a', 'm']:
    df_type = combined_data_long[combined_data_long['type'] == mod_type]
    total_count = len(df_type)
    na_count = df_type['score'].isna().sum()
    zero_count = (df_type['score'] == 0).sum()
    non_zero_count = ((df_type['score'] != 0) & (~df_type['score'].isna())).sum()

    print(f"Type {mod_type}:")
    print(f"  Total data points: {total_count}")
    print(f"  NA values: {na_count} ({na_count/total_count*100:.2f}%)")
    print(f"  Zero values: {zero_count} ({zero_count/total_count*100:.2f}%)")
    print(f"  Non-zero values: {non_zero_count} ({non_zero_count/total_count*100:.2f}%)")
    print(f"  Min non-zero value: {df_type['score'][df_type['score'] != 0].min()}")
    print(f"  Max value: {df_type['score'].max()}")
    print()

In [None]:

def filter_data(df, condition, chr_type, type_,strand):
    """Return the rows of ``df`` that satisfy every active filter.

    Parameters
    ----------
    df : pd.DataFrame
        Frame with whichever of ``condition``, ``chr_type``, ``type``,
        ``strand`` columns are filtered on.
    condition : object
        Keep rows with ``df['condition'] == condition``; a falsy value
        disables this filter.
    chr_type, type_, strand : object
        Keep rows equal to the given value; the string ``"all"`` disables
        that filter.

    Returns
    -------
    pd.DataFrame
        ``df`` restricted to matching rows, in original order. All rows are
        returned when no filter is active. Works with duplicate index labels.
    """
    # AND the filters together element-wise on df's own index; this avoids the
    # concat/groupby(level=0) trick, which breaks on duplicate labels and
    # raises when no filter is active.
    mask = pd.Series(True, index=df.index)
    if condition:
        mask &= (df['condition'] == condition)
    if chr_type != "all":
        mask &= (df['chr_type'] == chr_type)
    if type_ != "all":
        mask &= (df['type'] == type_)
    if strand != "all":
        mask &= (df['strand'] == strand)

    # Positional boolean array: no index alignment needed.
    return df.loc[mask.to_numpy()]

def _weighted_window_average(sub_df):
    """Nvalid_cov-weighted mean of 'weighted_norm_mod_frac' over sub_df; NaN if the weights sum to 0."""
    weights = sub_df['Nvalid_cov']
    values = sub_df['weighted_norm_mod_frac']
    if weights.sum() == 0:
        return np.nan
    return np.average(values, weights=weights)


def _centered_weighted_smooth(df, window_size):
    """Centered, coverage-weighted rolling mean of 'weighted_norm_mod_frac'.

    Row i averages rows i - window_size//2 .. i + window_size//2 (inclusive), clipped at the
    frame edges. Windows are taken by position, so df is expected to have a 0..n-1 RangeIndex
    (as produced by reset_index(drop=True)).
    """
    half = window_size // 2
    smoothed = [
        _weighted_window_average(df.iloc[max(i - half, 0): i + half + 1])
        for i in range(len(df))
    ]
    return pd.Series(smoothed, index=df.index, dtype=float)


def plot_bedmethyl_diff(bed_window,final_df, conditions, window_size=25, *args):
    """Plot the relative difference of smoothed m6A fractions between pairs of selections.

    Parameters
    ----------
    bed_window : int
        The x axis is limited to [-bed_window, bed_window].
    final_df : pandas.DataFrame
        Must provide:
        - the filter columns used by filter_data;
        - 'Nvalid_cov' (weights);
        - 'weighted_norm_mod_frac' (values);
        - 'rel_start' (x positions).
    conditions : sequence
        Looked up by the selection indices in each comparison tuple.
    window_size : int
        Width of the centered weighted smoothing window, in rows.
    *args : tuple
        One 8-tuple per comparison:
        (selection_index1, chr_type1, type1, strand1,
        selection_index2, chr_type2, type2, strand2).

    Returns
    -------
    (plotly Figure, str or None)
        The figure, and the label of the last comparison plotted (None if no comparisons).
    """
    fig = go.Figure()
    fig.update_layout(
        title='m6A Fraction Difference vs Genomic Position N2 SDC3 - N2 intergenic control',
        xaxis_title='Genomic Position',
        yaxis_title='% change in norm m6A/A',
        template="plotly_white",
        width=1000,
        height=600,
        font=dict(size=14),
        # Values are fractions; show them as whole percentages.
        yaxis_tickformat='.0%',
        # Legend placed below the plot area.
        legend=dict(y=-0.4, x=0.25),
    )
    fig.update_xaxes(range=[-bed_window, bed_window])

    label = None
    for (selection_index1, chr_type1, type1, strand1,
         selection_index2, chr_type2, type2, strand2) in args:
        condition1 = conditions[selection_index1]
        condition2 = conditions[selection_index2]

        # Each side is filtered with its own strand. The previous unpacking reused one
        # name, so df1 silently used the second strand.
        df1 = filter_data(final_df, condition1, chr_type1, type1, strand1).reset_index(drop=True)
        df2 = filter_data(final_df, condition2, chr_type2, type2, strand2).reset_index(drop=True)

        df1['weighted_norm_mod_frac_smooth'] = _centered_weighted_smooth(df1, window_size)
        df2['weighted_norm_mod_frac_smooth'] = _centered_weighted_smooth(df2, window_size)

        # Relative difference with respect to selection 1.
        # NOTE(review): rows of df1 and df2 are paired by position (after reset_index), not by
        # rel_start. This assumes both selections cover the same rel_start grid in the same order.
        diff_data = (df1['weighted_norm_mod_frac_smooth'] - df2['weighted_norm_mod_frac_smooth'])/df1['weighted_norm_mod_frac_smooth']
        diff_data_xaxis = df1['rel_start']

        label = f"Diff_{condition1}_{chr_type1}_{type1} - {condition2}_{chr_type2}_{type2}"
        print(label)

        fig.add_trace(go.Scatter(
            name=label,
            x=diff_data_xaxis.values,
            y=diff_data.values,
            mode='lines',
        ))

    fig.show(renderer='plotly_mimetype+notebook')
    return fig,label

# capture fig and label
# plot_bedmethyl_diff returns (figure, label of the last comparison); only one comparison is
# active below, and the commented tuples are alternative comparisons kept for reference.
# NOTE(review): this reads whatever `plot_df` is bound when the cell runs. plot_bedmethyl_diff
# needs 'rel_start', 'Nvalid_cov' and 'weighted_norm_mod_frac' columns; confirm the current
# plot_df provides them.

diff_fig = plot_bedmethyl_diff(1000, plot_df, conditions, 200,
                               (8,"X",type_selected[0],"all",1,"X",type_selected[0],"all"))
                               #(8,"X",type_selected[0],"all",8,"X",type_selected[1],"all"),
                               #(8,"X",type_selected[0],"all",1,"X",type_selected[0],"all"),
                               #(8,"X",type_selected[1],"all",1,"X",type_selected[1],"all"))
                               #(8,"X",type_selected[1],1,"X",type_selected[1]))
                    #(8,"X","all",1,"X","all"))
                    #(8,"X","strong_rex",1,"X","strong_rex"),
                    #(8,"X","weak_rex",1,"X","weak_rex"))
                    #(8,"X","center_SDC3_chip_albretton",1,"X","center_SDC3_chip_albretton"),
                    #(8,"Autosome","center_SDC3_chip_albretton",1,"Autosome","center_SDC3_chip_albretton"))
                    #(1,"Autosome","center_SDC3_chip_albretton",1,"X","center_SDC3_chip_albretton"),
                    #(8,"Autosome","center_SDC3_chip_albretton",8,"X","center_SDC3_chip_albretton"))
                    #(1, "X", "TSS_q4", 1, "Autosome", "TSS_q4"),
                    #(8, "X", "TSS_q4", 8, "Autosome", "TSS_q4"))
                    #(8, "X", "TSS_q4", 1, "X", "TSS_q4"),
                    #(8, "Autosome", "TSS_q4", 1, "Autosome", "TSS_q4"))
                    #(8, "X", "TSS_q3", 1, "X", "TSS_q3"),
                    #(8, "X", "TSS_q2", 1, "X", "TSS_q2"),
                    #(8, "X", "TSS_q1", 1, "X", "TSS_q1"))

# Export the figure as SVG and PNG.
# NOTE(review): the file name uses region_fig[1], which is not produced in this cell (the label
# just captured is diff_fig[1]). The fixed suffix also looks copied from an earlier
# strong_rex comparison; confirm both are intended.
diff_fig[0].write_image("images_11_14_23/"+region_fig[1]+"sdc2degron_minus_N2_strong_rex_1000_centered.svg")
diff_fig[0].write_image("images_11_14_23/"+region_fig[1]+"sdc2degron_minus_N2_strong_rex_1000_centered.png")


In [None]:
force_replace = False

### Extracting per read modifications
# Expected `modkit extract` output path for each sample. The name is built from:
# - the condition, rounded threshold, sample index, BAM fraction and bed_window;
# - the first 7 characters of each selected type;
# - the first character plus the last 3 characters of each selected chromosome.
# modkit_extract() builds the same name for its output.
out_file_names = [
    f"{output_stem!s}modkit-extract-{cond!s}_{round(thresh, 2)!s}_{idx!s}_{frac!s}_{bed_window!s}"
    f"{'-'.join(str(t)[0:7] for t in type_selected)}_"
    f"{'-'.join(str(c)[0] + str(c)[-3:] for c in chromosome_selected)}_.bed"
    for cond, thresh, idx, frac in zip(conditions, thresh_list, sample_indices, bam_fracs)
]

modkit_bed_df = pd.read_csv(modkit_bed_name,sep='\t',header=None)
### Define bed file for modkit

# Function to run a single extract command
def modkit_extract(args):
    """Run `modkit extract` on one BAM, restricted to the regions in the modkit BED file.

    Parameters
    ----------
    args : tuple
        (each_bam, each_thresh, each_condition, each_index, each_bamfrac,
        modkit_path, output_stem, modkit_bed_name, bed_window), packed into one tuple
        so the function can be used with ``Pool.map``.

    Side effects
    ------------
    - Writes the extract table to a configuration-derived path, using the same naming
      scheme as the notebook-level ``out_file_names``.
    - Writes a log file whose path is that name plus the condition and
      "_modkit-extract.log".
    - Skips the run when the output already exists, unless the global ``force_replace``
      is True.
    - Reads the globals ``type_selected`` and ``chromosome_selected``.
    - A non-zero modkit exit status is reported rather than ignored.
    """
    each_bam, each_thresh, each_condition, each_index, each_bamfrac,modkit_path, output_stem, modkit_bed_name, bed_window = args

    each_output = output_stem + "modkit-extract-" + each_condition +"_"+ str(round(each_thresh,2))+"_"+str(each_index)+ "_"+str(each_bamfrac)+"_"+str(bed_window)+ "-".join([str(x)[0:7] for x in type_selected])+"_"+ "-".join([str(x)[0]+str(x)[-3:] for x in chromosome_selected])+"_"+ ".bed"
    log_path = each_output + each_condition + "_modkit-extract.log"

    ### NOTE: Name of pileup file is not based on configurations
    ### TODO: Name of output file should be based on configs so that we aren't recomputing pileups withidentical conditions.

    # Reuse an existing output unless a rebuild is forced.
    if not force_replace and os.path.exists(each_output):
        print(f"Skipping: {each_output}")
        return

    print(f"Starting on: {each_bam}")
    # NOTE(review): every pool worker asks modkit for 128 threads. With Pool() sized to all
    # CPUs this can heavily oversubscribe the machine.
    command = [
        modkit_path,
        "extract",
        "--threads",
        "128",
        "--force",
        "--mapped",
        #"--ignore",
        #"m",
        "--include-bed",
        modkit_bed_name,
        "--log-filepath",
        log_path,
        each_bam,
        each_output
    ]
    result = subprocess.run(command, text=True)
    if result.returncode != 0:
        # A failed run may leave a missing or partial table. Later cells would read it, and
        # a re-run would skip it because the file exists; make the failure visible.
        print(f"WARNING: modkit extract exited with code {result.returncode} for {each_bam}; "
              f"output {each_output} may be missing or incomplete (see {log_path})")

    # Create a list of arguments for each task
task_args = list(zip(
    new_bam_files,
    thresh_list,
    conditions,
    sample_indices,
    bam_fracs,
    [modkit_path]*len(new_bam_files),
    [output_stem]*len(new_bam_files),
    [modkit_bed_name]*len(new_bam_files),
    [bed_window]*len(new_bam_files)
))

# Execute commands in parallel
with Pool() as pool:
    pool.map(modkit_extract, task_args)

print("finished with:")
print(out_file_names)

In [None]:
### Add bed and condition details to modkit output for plotting
# Using DataFrame merging to achieve the task without explicit loops
## Looks up the bed_start and bed_end values for each row in bedmethyl_df
def add_bed_columns_no_loops(bedmethyl_df, bed_df):
    """Attach the containing BED region to every stranded modification call.

    Parameters
    ----------
    bedmethyl_df : pandas.DataFrame
        Per-call table with at least 'chrom', 'ref_position' and 'mod_strand'.
        Rows whose mod_strand is not '+' or '-' are dropped.
    bed_df : pandas.DataFrame
        Regions with 'chrom', 'bed_start', 'bed_end', 'bed_strand', 'chr_type' and 'type'.

    Returns
    -------
    pandas.DataFrame
        The calls lying within [bed_start, bed_end] of a region on the same chromosome,
        with the region's bed_start, bed_end, bed_strand, chr_type and type columns added.
        bed_start and bed_end are int64. The index is reset. The frame is empty (with these
        columns) when nothing matches.

    Notes
    -----
    Each call is matched to the region with the greatest bed_start <= ref_position
    (merge_asof, direction='backward'), which is the only candidate that can contain it.
    Matching to the *nearest* start instead would drop calls lying inside a region but
    closer to the next region's start. With overlapping regions, only that one candidate
    is checked.
    """
    bed_cols = ['bed_start', 'bed_end', 'bed_strand', 'chr_type', 'type']

    # Keep only calls on a defined strand.
    stranded = bedmethyl_df[bedmethyl_df['mod_strand'].isin(['+', '-'])]

    merged_parts = []
    for chrom in stranded['chrom'].unique():
        bed_subset = bed_df[bed_df['chrom'] == chrom]
        if bed_subset.empty:
            # No regions on this chromosome.
            continue

        # merge_asof needs identical key dtypes and both sides sorted on the key.
        # A stable sort keeps the caller's order for already-sorted input.
        bedmethyl_subset = (stranded[stranded['chrom'] == chrom]
                            .astype({'ref_position': 'int64'})
                            .sort_values('ref_position', kind='mergesort'))
        bed_subset = (bed_subset.astype({'bed_start': 'int64'})
                      .sort_values('bed_start', kind='mergesort'))

        merged_parts.append(pd.merge_asof(
            bedmethyl_subset,
            bed_subset[bed_cols],  # exclude 'chrom' from the right-hand frame
            left_on='ref_position',
            right_on='bed_start',
            direction='backward'
        ))

    if not merged_parts:
        print("Found 0 rows in filtered_df")
        return pd.DataFrame(columns=list(stranded.columns) + bed_cols)

    # Concatenate once (DataFrame.append was removed in pandas 2.0).
    merged_data = pd.concat(merged_parts, ignore_index=True)

    # Keep calls inside their matched region. Unmatched rows (NaN bed_start) fail these
    # comparisons and are dropped as well.
    in_region = (
        merged_data['bed_start'].notna()
        & (merged_data['ref_position'] >= merged_data['bed_start'])
        & (merged_data['ref_position'] <= merged_data['bed_end'])
    )
    filtered_df = merged_data.loc[in_region].reset_index(drop=True)
    # Unmatched rows turned the region coordinates into floats; restore integer positions
    # for downstream relative-position arithmetic and array indexing.
    filtered_df = filtered_df.astype({'bed_start': 'int64', 'bed_end': 'int64'})

    print("Found {} rows in filtered_df".format(len(filtered_df)))
    return filtered_df

# Initialize comb_readmethyl_df as an empty dataframe
comb_readmethyl_df = pd.DataFrame()

# Build the combined region table from the uncompressed copy of each BED file.
# Frames are collected in a list and concatenated once: DataFrame.append was removed
# in pandas 2.0, and growing a frame inside a loop is quadratic.
bed_frames = []
for each_bed in new_bed_files:
    bed_path = each_bed[:-3] # remove .gz
    print("Starting on:",bed_path)
    bed_frames.append(pd.read_csv(bed_path, sep="\t", header=None))#skiprows=1))
if not bed_frames:
    raise ValueError("new_bed_files is empty; there are no BED files to combine")
combined_bed_df = pd.concat(bed_frames, ignore_index=True)
del bed_frames
combined_bed_df.columns = ['chrom','bed_start','bed_end','bed_strand','type','chr_type']

# Sorted by chromosome and start: add_bed_columns_no_loops matches with merge_asof, which needs sorted keys.
combined_bed_df.sort_values(['chrom','bed_start'], inplace=True)
combined_bed_df.reset_index(drop=True, inplace=True)
display(combined_bed_df.head(10))

# Columns of the `modkit extract` table, in file order (the file's header row is skipped on read).
bedmethyl_cols = ['read_id',
    'forward_read_position',
    'ref_position',
    'chrom',
    'mod_strand',
    'ref_strand',
    'ref_mod_strand',
    'fw_soft_clipped_start',
    'fw_soft_clipped_end',
    'read_length',
    'mod_qual',
    'mod_code',
    'base_qual',
    'ref_kmer',
    'query_kmer',
    'canonical_base',
    'modified_primary_base',
    'inferred']

# Per-condition frames are collected here and concatenated once after the loop:
# DataFrame.append was removed in pandas 2.0, and growing a frame in a loop is quadratic.
bedmethyl_frames = []

# Create combined plotting dataframe
for each_output,each_condition in zip(out_file_names,conditions):
    print("Starting on:",each_output)
    bedmethyl_df=pd.read_csv(each_output, sep="\t", header=None, names=bedmethyl_cols,skiprows=1)

    # Sort by chrom and position; add_bed_columns_no_loops matches with merge_asof, which needs sorted keys.
    bedmethyl_df.sort_values(['chrom','ref_position'], inplace=True)
    bedmethyl_df.reset_index(drop=True, inplace=True)

    # Attach the containing region (bed_start, bed_end, bed_strand, chr_type, type) to each call.
    bedmethyl_df = add_bed_columns_no_loops(bedmethyl_df.copy(), combined_bed_df)
    # Position relative to the region: ref_position == bed_start + bed_window - 1 maps to 0.
    bedmethyl_df['rel_pos'] = bedmethyl_df['ref_position'] - bedmethyl_df['bed_start'] -bed_window +1

    # add condition column
    bedmethyl_df['condition'] = each_condition
    bedmethyl_frames.append(bedmethyl_df)

comb_bedmethyl_plot_df = pd.concat(bedmethyl_frames, ignore_index=True) if bedmethyl_frames else pd.DataFrame()
del bedmethyl_frames

# For strand-sensitive region types, rel_pos is mirrored on '-' strand regions.
# NOTE(review): only type_selected[0] is checked, so other selected types follow its setting.
if any(x in type_selected[0] for x in ("TSS", "TES", "MEX")):
    print("Strand orientation sensitive type selected, multiplying rel_start by -1 for '-' strand genes...")
    mask = comb_bedmethyl_plot_df['bed_strand'] == '-'
    comb_bedmethyl_plot_df.loc[mask, 'rel_pos'] *= -1

# Binarize each call: mod_qual_bin is 1 where mod_qual exceeds m6A_thresh / 255, else 0.
threshold = m6A_thresh / 255
comb_bedmethyl_plot_df['mod_qual'] = comb_bedmethyl_plot_df['mod_qual'].astype(float)
comb_bedmethyl_plot_df['mod_qual_bin'] = np.where(comb_bedmethyl_plot_df['mod_qual'] > threshold, 1, 0)

# Add each read's smallest and largest rel_pos as rel_read_start / rel_read_end.
print("Calculating read start and end...")
read_span_df = (
    comb_bedmethyl_plot_df[['read_id', 'rel_pos']]
    .groupby('read_id')['rel_pos'].agg(['min', 'max'])
    .reset_index()
    .rename(columns={'min': 'rel_read_start', 'max': 'rel_read_end'})
)
comb_bedmethyl_plot_df = pd.merge(comb_bedmethyl_plot_df, read_span_df, on='read_id', how='left')
del read_span_df

comb_bedmethyl_plot_df = comb_bedmethyl_plot_df[comb_bedmethyl_plot_df['bed_start'].notna()].reset_index(drop=True)

# print count of unique read_ids for each condition
print("All unique read_ids for each condition:")
print(comb_bedmethyl_plot_df.groupby(['condition'])['read_id'].nunique())

# Drop calls from short reads. The threshold is a single constant so the filter and the
# message below cannot disagree.
min_read_length = 500
comb_bedmethyl_plot_df = comb_bedmethyl_plot_df[comb_bedmethyl_plot_df['read_length'] >= min_read_length]

# print count of unique read_ids for each condition
print(f"All unique read_ids for each condition >= read_len of {min_read_length}:")
print(comb_bedmethyl_plot_df.groupby(['condition'])['read_id'].nunique())

# Define SQLite database file name
#db_fn = "temp_files/" + "plot_db_" + "-".join([str(x)[:6] for x in type_selected]) + "_" + str(round(thresh_list[0],2)) + "_" + str(bam_fracs[0]) + ".db"
# Create a SQLite database connection
#conn = sqlite3.connect(db_fn)
# Save the DataFrame to SQLite database
#comb_bedmethyl_plot_df.to_sql('bedmethyl_plot', conn, if_exists='replace', index=False)
# Close the connection
#conn.close()

In [None]:
# Cache the combined per-call table on disk under a name derived from the configuration.
# If a matching file exists it is loaded, and the table built in the previous cell is not used.
# NOTE(review): the name does not encode chromosome_selected, conditions or max_regions,
# so a cache written under a different configuration can be picked up silently.
plot_df_fn = "".join([
    "temp_files/", "plot_df_",
    "-".join(str(x)[0:12] for x in type_selected), "_",
    str(thresh_list[0]), "_", str(bam_fracs[0]), str(bed_window), ".csv",
])
if not os.path.exists(plot_df_fn):
    print("plot_df_fn does not exist, creating...")
    comb_bedmethyl_plot_df.to_csv(plot_df_fn, index=False)
    plot_df = comb_bedmethyl_plot_df.copy()
else:
    print("plot_df_fn exists, importing...")
    plot_df = pd.read_csv(plot_df_fn)

metadata_cols = ['chrom', 'chr_type', 'condition', 'bed_start','type', 'read_id', 'rel_read_start','rel_read_end']

nanotools.display_sample_rows(plot_df)

In [None]:
def _kmer_panel(merged_kmer_counts, canonical_base, skip_occurrence_condition='N2_bg'):
    """Two-row bar figure for one canonical base: percent modified (top), 3mer occurrence (bottom).

    The occurrence row omits `skip_occurrence_condition`.
    """
    from plotly.subplots import make_subplots

    fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.25, subplot_titles=["Percent Methylated", "Trinucleotide Occurrence"])
    base_counts = merged_kmer_counts[merged_kmer_counts['canonical_base'] == canonical_base]
    for condition in base_counts['condition'].unique():
        df_condition = base_counts[base_counts['condition'] == condition]
        fig.add_trace(go.Bar(x=df_condition['3mer'], y=df_condition['percent'], name=f"Percent - {condition}", text=df_condition['percent'], textposition='outside', texttemplate='%{text:.0%}'), row=1, col=1)
        if condition != skip_occurrence_condition:
            fig.add_trace(go.Bar(x=df_condition['3mer'], y=df_condition['3mer_occurence'], name=f"Occurrence - {condition}"), row=2, col=1)

    fig.update_yaxes(tickformat='.0%', row=1, col=1)
    fig.update_layout(template="plotly_white", title_text=f"3mer Analysis for Canonical Base '{canonical_base}'", width=1200)

    # 20% headroom above the tallest bar (presumably so the outside text labels fit).
    # Skipped when there are no rows for this base, where max() is NaN.
    max_percent = base_counts['percent'].max()
    if pd.notna(max_percent):
        fig.update_yaxes(range=[0, max_percent + max_percent * 0.2], row=1, col=1)
    return fig


def output_kmer_plots(plot_df, mod_qual_cutoff=0.9, min_occurrence=100):
    """Plot, per condition, the fraction of each 3mer called modified, for 'A' and for 'C'.

    Parameters
    ----------
    plot_df : pandas.DataFrame
        Per-call table with 'query_kmer', 'canonical_base', 'condition' and 'mod_qual'.
    mod_qual_cutoff : float
        A call counts as modified when mod_qual is strictly above this value.
    min_occurrence : int
        3mers seen fewer times than this (per base and condition) are dropped.

    Returns
    -------
    (fig_50, fig_bg)
        plotly figures for canonical base 'A' and 'C' respectively (both are also shown).
    """
    keys = ['3mer', 'canonical_base', 'condition']

    # Copy only the columns needed here instead of the whole per-call table.
    kmer_plot_df = plot_df[['query_kmer', 'canonical_base', 'condition', 'mod_qual']].copy()
    # Characters 1..3 of query_kmer; assumes a 5-mer centred on the called base -- confirm.
    kmer_plot_df['3mer'] = kmer_plot_df['query_kmer'].str[1:4]

    query_kmer_counts = kmer_plot_df.groupby(keys).size().reset_index(name='3mer_occurence')
    nanotools.display_sample_rows(query_kmer_counts, 2)

    kmer_mod_plot_df = kmer_plot_df[kmer_plot_df['mod_qual'] > mod_qual_cutoff]
    mod_kmer_counts = kmer_mod_plot_df.groupby(keys).size().reset_index(name='mod_occurence')
    nanotools.display_sample_rows(mod_kmer_counts, 2)

    merged_kmer_counts = query_kmer_counts.merge(mod_kmer_counts, on=keys, how='left')
    # 3mers with no call above the cutoff have no row in mod_kmer_counts: they are 0% modified,
    # not missing.
    merged_kmer_counts['mod_occurence'] = merged_kmer_counts['mod_occurence'].fillna(0).astype('int64')
    merged_kmer_counts['percent'] = merged_kmer_counts['mod_occurence'] / merged_kmer_counts['3mer_occurence']
    merged_kmer_counts = merged_kmer_counts.sort_values(by=['condition', 'percent'], ascending=[False, False])
    merged_kmer_counts = merged_kmer_counts[merged_kmer_counts['3mer_occurence'] >= min_occurrence]
    nanotools.display_sample_rows(merged_kmer_counts, 2)

    fig_50 = _kmer_panel(merged_kmer_counts, 'A')
    fig_bg = _kmer_panel(merged_kmer_counts, 'C')

    fig_50.show()
    fig_bg.show()

    return fig_50, fig_bg

# Example usage
fig_50, fig_bg = output_kmer_plots(plot_df)

In [None]:

### Create read_id to m6A 5mC lookup table
# For each (read, canonical base, condition, type) count all calls, and count the confident
# modified calls (mod_qual > 0.8) with |rel_pos| <= 250. Reads are then split into two
# equal-count bins per condition by that fraction.
lookup_keys = ['read_id', 'canonical_base', 'condition', 'type']

plot_df_cutoff = plot_df#[plot_df['condition'] == '50_dpy27-3xGNB_GFP-Hia5_mcvipi']
canonical_count = plot_df_cutoff.groupby(lookup_keys).size().reset_index(name='canonical_count')

plot_df_cutoff = plot_df[plot_df['mod_qual'] > 0.8]
plot_df_cutoff = plot_df_cutoff[(plot_df_cutoff['rel_pos'] <= 250) & (plot_df_cutoff['rel_pos'] >= -250)]
mod_count = plot_df_cutoff.groupby(lookup_keys).size().reset_index(name='mod_count')

# NOTE(review): canonical_count spans every call of the read in plot_df, while mod_count is
# limited to |rel_pos| <= 250, so mod_frac is not a fraction over the same positions.
# Confirm this is intended.
mod_lookup = canonical_count.merge(mod_count, on=lookup_keys, how='left')
# Reads with no confident modified call in the window get 0.
mod_lookup['mod_frac'] = (mod_lookup['mod_count']/mod_lookup['canonical_count']).fillna(0)
nanotools.display_sample_rows(mod_lookup,10)

# Two-bin split per condition, computed separately for 'A' and 'C' ("quartile" is historical).
# duplicates='drop' avoids a ValueError when tied values make the bin edges coincide
# (e.g. many reads with mod_frac == 0).
quartiles_A = mod_lookup[mod_lookup['canonical_base'] == 'A'].groupby('condition')['mod_frac'].transform(lambda x: pd.qcut(x, 2, labels=False, duplicates='drop'))
quartiles_C = mod_lookup[mod_lookup['canonical_base'] == 'C'].groupby('condition')['mod_frac'].transform(lambda x: pd.qcut(x, 2, labels=False, duplicates='drop'))

# 1-based bin numbers
quartiles_A += 1
quartiles_C += 1

mod_lookup.loc[mod_lookup['canonical_base'] == 'A', 'm6A_quartile'] = quartiles_A
mod_lookup.loc[mod_lookup['canonical_base'] == 'C', '5mC_quartile'] = quartiles_C
# Labels use only the first character of the number, e.g. "m6A_q_1" (NaN -> "m6A_q_n").
mod_lookup['m6A_quartile'] = "m6A_q_" + mod_lookup['m6A_quartile'].astype(str).str[0]
mod_lookup['5mC_quartile'] = "5mC_q_" + mod_lookup['5mC_quartile'].astype(str).str[0]

nanotools.display_sample_rows(mod_lookup,10)

# One row per (read, condition, type) with both labels. Joining on these keys rather than on
# read_id alone prevents row duplication for reads that span several region types or appear
# in several conditions.
join_keys = ['read_id', 'condition', 'type']
mod_lookup_A = mod_lookup[mod_lookup['canonical_base'] == 'A'][join_keys + ['m6A_quartile']]
mod_lookup_C = mod_lookup[mod_lookup['canonical_base'] == 'C'][join_keys + ['5mC_quartile']]
mod_lookup_final = pd.merge(mod_lookup_A, mod_lookup_C, on=join_keys)

# Drop labels from a previous run of this cell first, so re-running does not create _x/_y columns.
plot_df = pd.merge(plot_df.drop(columns=['m6A_quartile', '5mC_quartile'], errors='ignore'),
                   mod_lookup_final, on=join_keys, how='left')

nanotools.display_sample_rows(plot_df,10)


In [None]:

### PROCESSING AUTOCORRELATON
from scipy.ndimage import gaussian_filter
from scipy.interpolate import interp1d
import tqdm
metadata_cols = ['chrom', 'chr_type', 'condition', 'bed_start','type', 'read_id', 'rel_read_start','rel_read_end']#,'m6A_quartile']


def _build_read_figure(calling_vec_raw, calling_vec_filled, calling_vec_smoothed, autocorr_limited, corr_start):
    """Diagnostic plotly figure for one read: modification tracks (row 1) and autocorrelation (row 2)."""
    from plotly.subplots import make_subplots

    window = slice(100, 1100)
    read_fig = make_subplots(rows=2, cols=1, subplot_titles=('Methylation', 'Autocorrelation'), shared_xaxes=False)
    colors_dict = dict(zip(["Raw", "Extrapolated", "Smoothed", "Autocorrelation"], plotly.colors.qualitative.Prism))

    for name, track in (("Raw", calling_vec_raw), ("Extrapolated", calling_vec_filled), ("Smoothed", calling_vec_smoothed)):
        segment = track[window]
        read_fig.add_trace(go.Scatter(
            x=np.arange(len(segment)),
            y=segment,
            mode='lines',
            name=name,
            marker=dict(color=colors_dict[name])
        ), row=1, col=1)

    # x is the actual lag: autocorr_limited[0] corresponds to lag corr_start.
    lags = np.arange(corr_start, corr_start + len(autocorr_limited))
    read_fig.add_trace(go.Scatter(
        x=lags[window],
        y=autocorr_limited[window],
        mode='markers',
        name="Autocorrelation",
        marker=dict(size=2.5, color=colors_dict["Autocorrelation"])
    ), row=2, col=1)

    read_fig.update_layout(title='Read Analysis', width=800, height=400, template='plotly_white')
    read_fig.update_xaxes(title_text='Genomic Position', row=1, col=1)
    read_fig.update_yaxes(title_text='Mod Probability', row=1, col=1)
    # The lag axis is on row 2 (the grid has a single column).
    read_fig.update_xaxes(title_text='Lag (bp)', row=2, col=1)
    read_fig.update_yaxes(title_text='Autocorrelation Value', row=2, col=1)
    return read_fig


#version that requires a nuc and linker region (does not allow nucleosomes to be stuck next to eachother)
def process_raw_read(read_id, group, metadata_cols,gauss_std_dev, crr_length,corr_start,mod_code):
    """Autocorrelation of one read's binarized, gap-filled, Gaussian-smoothed modification track.

    Parameters
    ----------
    read_id : hashable
        Returned unchanged.
    group : pandas.DataFrame
        All rows of one read. Needs:
        - 'mod_code';
        - 'rel_pos' (integer-valued offsets);
        - 'mod_qual' (0-1 probabilities);
        - every column in metadata_cols.
    metadata_cols : list of str
        Columns copied from the read's first (mod_code-filtered) row into the result.
    gauss_std_dev : float
        Sigma of the Gaussian smoothing, in positions.
    crr_length : int
        Largest lag kept.
    corr_start : int
        Smallest lag kept.
    mod_code : str
        Only rows with this 'mod_code' are used.

    Returns
    -------
    tuple
        (read_id, metadata Series, autocorr, figure).
        For a skipped read, autocorr is np.zeros(1) and figure is None. A read is skipped when:
        - it has no rows for mod_code;
        - its mean mod_qual is above the threshold;
        - the smoothed track has zero mean or zero variance;
        - it is too short for crr_length.
        Otherwise autocorr holds lags corr_start..crr_length and figure is a plotly diagnostic.

    Reads the module-level m6A_thresh; it is divided by 255 before comparison with mod_qual.
    """
    # Keep only calls of the requested modification. If none remain, report the read as skipped
    # (metadata from its first row) instead of failing on an empty selection.
    mod_rows = group[group['mod_code'] == mod_code]
    if mod_rows.empty:
        return read_id, group.iloc[0][metadata_cols], np.zeros(1), None
    group = mod_rows
    metadata = group.iloc[0][metadata_cols]

    threshold = m6A_thresh / 255
    if np.mean(group['mod_qual']) > threshold:
        return read_id, metadata, np.zeros(1), None

    # Scatter the calls into a dense track covering min..max rel_pos (one slot per position);
    # positions without a call stay NaN.
    rel_pos = group['rel_pos'].to_numpy().astype(np.int64)
    offsets = rel_pos - rel_pos.min()
    calling_vec_raw = np.full(offsets.max() + 1, np.nan)
    mod_qual = group['mod_qual'].to_numpy(dtype=float)
    # If an offset repeats, the last row's value is kept (same as assigning rows in order).
    last_offsets, last_rows = np.unique(offsets[::-1], return_index=True)
    calling_vec_raw[last_offsets] = mod_qual[::-1][last_rows]

    # Binarize observed calls against the threshold; NaN gaps are left untouched.
    calling_vec = calling_vec_raw.copy()
    calling_vec[calling_vec > threshold] = 1
    calling_vec[calling_vec <= threshold] = 0

    # Fill gaps with the nearest observed call, extrapolating past both ends.
    indices = np.arange(len(calling_vec))
    observed = ~np.isnan(calling_vec)
    interp_func = interp1d(indices[observed], calling_vec[observed], bounds_error=False, copy=False, fill_value="extrapolate", kind='nearest')
    calling_vec_filled = interp_func(indices)
    if np.isnan(calling_vec_filled).any():
        calling_vec_filled = np.nan_to_num(calling_vec_filled, nan=0.0)

    calling_vec_smoothed = gaussian_filter(calling_vec_filled, sigma=gauss_std_dev)
    read_mean = np.mean(calling_vec_smoothed)
    read_std = np.std(calling_vec_smoothed)
    if read_std == 0 or read_mean == 0:
        return read_id, metadata, np.zeros(1), None

    # Mean-centred autocorrelation normalised by variance * length; the second half of the
    # 'same'-mode output starts at lag 0 (the result is symmetric).
    centered = calling_vec_smoothed - read_mean
    autocorr = np.correlate(centered, centered, mode='same')/(read_std * read_std * len(calling_vec_smoothed))
    autocorr_centered = autocorr[autocorr.size // 2:]

    if len(autocorr_centered) < (crr_length+1):
        return read_id, metadata, np.zeros(1), None

    autocorr_limited = autocorr_centered[corr_start:crr_length + 1]  # lags corr_start..crr_length
    read_fig = _build_read_figure(calling_vec_raw, calling_vec_filled, calling_vec_smoothed, autocorr_limited, corr_start)
    return read_id, metadata, autocorr_limited, read_fig

### CONFIGS
# Settings for the per-read autocorrelation pass; crr_length, gauss_std and corr_start are
# forwarded positionally to process_raw_read below (their exact meaning is defined there).
#grouped = grouped_subset.groupby('read_id')
crr_length = 1000
gauss_std = 10
corr_start = 100
# show this many single read tracks:
# NOTE(review): figures_shown is not referenced in this section -- confirm it is still used.
figures_shown = 0
# Process this many reads (0 = all reads; see the reads_to_process > 0 guard below)
reads_to_process = 0
# Extra bases required on top of 2*crr_length for a read to be kept
corr_buff = 500

print("Grouping df...")


# Keep reads whose span (rel_read_end - rel_read_start) is at least 2*crr_length + corr_buff,
# excluding the N2_mixed_endogenous_R10 condition and keeping only mod_code == 'm' rows.
# NOTE(review): in SAM MM-tag convention 'm' is 5mC and 'a' is 6mA, while 'a' is passed to
# process_raw_read below -- confirm which mod_code this m6A analysis intends.
grouped_auto = plot_df.copy() #[plot_df['chr_type'] == 'X']
grouped_auto = grouped_auto[grouped_auto['condition'] != 'N2_mixed_endogenous_R10']
#grouped_auto = grouped_auto[plot_df['type'].str.contains('all_rex')]
grouped_auto = grouped_auto[(grouped_auto['rel_read_end'] - grouped_auto['rel_read_start']) >= (2*crr_length + corr_buff)]
grouped_auto = grouped_auto[grouped_auto['mod_code'] == 'm']
#set grouped_auto type column equal to m6A_quartile
#grouped_auto['type'] = grouped_auto['m6A_quartile']
nanotools.display_sample_rows(grouped_auto,10)

# Order rows within each read by position before grouping
grouped_auto.sort_values(by=["read_id","rel_pos"], inplace=True)
grouped_auto.reset_index(inplace=True, drop=True)

# Optionally keep only the first reads_to_process read_ids (in sorted order)
if reads_to_process > 0:
    first_rows = grouped_auto['read_id'].unique()[:reads_to_process]#[plot_df['type'] == 'intergenic_control']
    grouped_auto = grouped_auto[grouped_auto['read_id'].isin(first_rows)]
    # reset index
    grouped_auto.reset_index(inplace=True, drop=True)

grouped_auto = grouped_auto.groupby('read_id')
#display(grouped_auto.head(10))
#nanotools.display_sample_rows(grouped_auto)

print("Processing autocorrelations...")
# One argument tuple per read for process_raw_read.
# NOTE(review): metadata_cols is only assigned in a later cell of this notebook (the MAD/NUC
# section); on a fresh kernel this line needs it defined earlier.
grouped_data_with_constants = [(read_id,group,metadata_cols,gauss_std,crr_length,corr_start,'a') for read_id,group in grouped_auto]

#processes=multiprocessing.cpu_count()
# Parallel map over reads; each result is (read_id, metadata, autocorrelation, figure).
# NOTE(review): `import tqdm` is commented out in the import cell -- confirm tqdm is imported elsewhere.
with multiprocessing.Pool() as pool:
    # set results equal to pool.starmap() with the function and grouped_data_with_constants as arguments using tqdm to track progress
    results = pool.starmap(process_raw_read, tqdm.tqdm(grouped_data_with_constants, total=len(grouped_data_with_constants)))

# Release the grouped reads; only `results` is needed from here on
grouped_auto = None

# Collect autocorrelations per (type, chr_type, condition) and build per-read metadata lists.
grouped_autocorrelations = {}
# Per-read condition / chr_type / type, appended in the same order (and with the same
# len(autocorr) > 1 filter) as autocorrelation_data at the bottom of this cell, so the
# lists align row-for-row with it.
conditions_list = []
chr_types_list = []
types_list = []

for read_id, metadata, autocorr, read_fig in results:
    # Skip reads whose autocorrelation has length <= 1 (presumably process_raw_read's
    # placeholder for reads it could not process -- confirm there)
    if len(autocorr) >1:
        # Create a unique key for each combination of type, chr_type, and condition
        key = (metadata['type'], metadata['chr_type'], metadata['condition'])
        if key not in grouped_autocorrelations:
            grouped_autocorrelations[key] = []
        grouped_autocorrelations[key].append(autocorr)

        conditions_list.append(metadata['condition'])
        chr_types_list.append(metadata['chr_type'])
        types_list.append(metadata['type'])

# NOTE(review): `fig`, `y_labels` and `z_data` below are built but no heatmap is drawn from
# them in this section.
fig = go.Figure()

y_labels = []  # To store y-axis labels
z_data = []  # To store autocorrelation data for heatmap

for group_key, autocorrs in grouped_autocorrelations.items():
    group_label = f"{group_key[0]}, {group_key[1]}, {group_key[2]}"
    for i, autocorr in enumerate(autocorrs, start=1):
        read_label = f"{group_label} - Read {i}"
        y_labels.append(read_label)
        z_data.append(autocorr)

from scipy.signal import find_peaks, peak_prominences, peak_widths
# Autocorrelations used for clustering: same len(...) > 1 filter and order as the loop above
autocorrelation_data = [result[2] for result in results if len(result[2])>1]

In [None]:
### PLOTTING AUTOCORRELATIONS
# Resolution passed to sc.tl.leiden further down in this cell
leiden_res = 0.4

def extract_peak_features(autocorr, num_peaks=4, min_distance=50):
    """Summarise the strongest distinct peaks of an autocorrelogram as a fixed-length vector.

    Local maxima are found with ``scipy.signal.find_peaks`` and visited tallest first.
    A peak is kept only if it is at least ``min_distance`` samples away from every
    peak already kept, so of any two peaks closer than ``min_distance`` the taller
    one survives (the lower-index one on a height tie). The first ``num_peaks``
    survivors are reported.

    Parameters
    ----------
    autocorr : array-like of float
        1-D autocorrelation values.
    num_peaks : int, default 4
        Number of peak slots in the output vector.
    min_distance : int, default 50
        Minimum separation, in samples, between two reported peaks.

    Returns
    -------
    numpy.ndarray
        Float array of length ``num_peaks * 4`` laid out as
        ``[position, height, prominence, width]`` per peak, tallest peak first.
        Unused slots hold -1.0.
    """
    values = np.asarray(autocorr, dtype=float)
    features = np.full(num_peaks * 4, -1.0)  # 4 features per peak

    peaks, _ = find_peaks(values)
    if len(peaks) == 0 or num_peaks <= 0:
        return features

    # Tallest first; a stable sort keeps the lower-index peak first on ties.
    order = np.argsort(-values[peaks], kind="stable")
    kept = []
    for peak in peaks[order]:
        # Compare by absolute distance against every kept peak. The previous version
        # compared height-ranked neighbours by signed difference (so any lower-index
        # peak counted as "close") and used -1 as an in-place sentinel.
        if all(abs(peak - other) >= min_distance for other in kept):
            kept.append(peak)
            if len(kept) == num_peaks:
                break

    kept = np.asarray(kept, dtype=int)
    heights = values[kept]
    prominences = peak_prominences(values, kept)[0]
    widths = peak_widths(values, kept)[0]

    for i, (pos, height, prom, width) in enumerate(zip(kept, heights, prominences, widths)):
        features[i * 4:i * 4 + 4] = (pos, height, prom, width)

    return features

# Apply the function to all autocorrelograms
# NOTE(review): expanded_peak_features is not used further down in this cell -- the distance
# matrix below is computed from the raw autocorrelograms (autocorr_df); the peak-feature
# variant is commented out.
expanded_peak_features = np.array([extract_peak_features(autocorr) for autocorr in autocorrelation_data])
import scanpy as sc
from scipy.spatial.distance import pdist, squareform

# One row per read, one column per lag
autocorr_df = pd.DataFrame(autocorrelation_data)

# Alternative (disabled): Euclidean distance on the peak-feature vectors
#distance_matrix_expanded = squareform(pdist(expanded_peak_features, metric='euclidean'))

# Pairwise correlation distance (1 - Pearson r) between read autocorrelograms
distance_matrix_expanded = squareform(pdist(autocorr_df, metric='correlation'))

# Map distances to similarities; correlation distance lies in [0, 2], so values are in [1/3, 1]
similarity_matrix_expanded = 1 / (1 + distance_matrix_expanded)

# AnnData whose features are each read's similarity to every read (n_reads x n_reads)
adata_expanded = sc.AnnData(similarity_matrix_expanded)
adata_expanded.obs_names = [f'Read_{i}' for i in range(adata_expanded.shape[0])]
adata_expanded.var_names = adata_expanded.obs_names

# kNN graph built directly on the similarity rows (use_rep='X')
sc.pp.neighbors(adata_expanded, use_rep='X', metric='correlation')

### Clustering resolution is set by leiden_res at the top of this cell ###
# Leiden community detection on the kNN graph
sc.tl.leiden(adata_expanded, resolution=leiden_res)

# Integer cluster label per read with a positional 0..n-1 index so it aligns with autocorr_df
cluster_labels = adata_expanded.obs['leiden'].astype(int).reset_index(drop=True)

# Sanity check: exactly one label per autocorrelogram
if len(cluster_labels) != len(autocorr_df):
    raise ValueError("Mismatch in length between autocorrelation data and cluster labels")

# Add all cluster labels to the DataFrame
autocorr_df['cluster'] = cluster_labels
autocorr_df_plotting = autocorr_df.copy()
# Same len(...) > 1 filter and order as autocorrelation_data, so read_ids line up row-for-row
autocorr_df_plotting['read_id'] = [result[0] for result in results if len(result[2])>1]

# Count the number of reads in each cluster
cluster_counts = cluster_labels.value_counts()
print("cluster_counts = ", cluster_counts)

# Total number of reads
total_reads = len(cluster_labels)

# Minimum cluster size: 5% of all clustered reads
threshold_count = 0.05 * total_reads

# Keep clusters holding >= 5% of reads; smaller clusters are excluded from all plots below
significant_clusters = cluster_counts[cluster_counts >= threshold_count].index

print("Significant clusters:", significant_clusters)

# Filter autocorrelations based on significant clusters
filtered_df = autocorr_df[autocorr_df['cluster'].isin(significant_clusters)]
filtered_df_plotting = autocorr_df_plotting[autocorr_df_plotting['cluster'].isin(significant_clusters)]

# Group the filtered autocorrelation data by cluster
grouped_by_cluster = filtered_df.groupby('cluster')

# Per-cluster mean autocorrelogram (index: integer cluster id, columns: lag positions)
representative_autocorrs = grouped_by_cluster.mean()

### Plot unique read plots:
# Integer cluster ids present after filtering (unique_clusters is reassigned later in this cell)
unique_clusters = filtered_df_plotting['cluster'].unique()

# representative_reads_tuple = autocorr_df_plotting.groupby('cluster')['read_id'].nth(1)
# Plot one figure from one read for each unique cluster_name
#for cluster, rep_read_id in representative_reads_tuple.iteritems():
# Example reads: unique read_ids 25-34 of the hard-coded cluster 3
representative_reads_list = autocorr_df_plotting[autocorr_df_plotting['cluster'] == 3]['read_id'].unique()[25:35]
# NOTE(review): with disp_fig.show() commented out, this loop only prints the read ids and
# mutates the stored per-read figures (result[3]) in place.
for rep_read_id in representative_reads_list:
    # Loop through the results to find the figure for the representative read
    for result in results:
        if result[0] == rep_read_id:
            print("read_id", rep_read_id)
            # Per-read figure returned by process_raw_read
            disp_fig = result[3]
            # set x range for disp fig to 0-1000
            disp_fig.update_xaxes(range=[0, 1000])
            # set width to 800 and height to 500
            disp_fig.update_layout(width=800, height = 500)
            #disp_fig.show()
            #if rep_read_id == "0b9e1dda-a4ae-4841-869b-533dd829136a":
                # save png and svg to images_11_14_23/raw_read_extrapolation_smoothing_strong_rex
                #disp_fig.write_image("images_11_14_23/raw_read_extrapolation_smoothing_strong_rex" + rep_read_id[0:6] + ".png")
                #disp_fig.write_image("images_11_14_23/raw_read_extrapolation_smoothing_strong_rex" + rep_read_id[0:6] + ".svg")
            break

"""def determine_cluster_name2(autocorr, cluster_id, prominence_threshold=0.1):
    # Find peaks and their prominences
    peaks, _ = find_peaks(autocorr, prominence=prominence_threshold,width=25)
    prominences = peak_prominences(autocorr, peaks)[0]
    print("peaks = ", peaks)
    print("prominences = ", prominences)

    # Check if there are any prominent peaks
    if len(peaks) > 0 and any(prominences >= prominence_threshold):
        # Find the position of the first prominent peak
        first_prominent_peak = peaks[np.argmax(prominences >= prominence_threshold)]
        return f'NRL-{first_prominent_peak + 100}-C{cluster_id}'
    else:
        # For clusters without prominent peaks, use 'NP' followed by the cluster id
        return f'NP{cluster_id}'"""

def determine_cluster_name(autocorr, cluster_id, prominence_threshold=0.06, distance_threshold=50, verbose=True):
    """Name a cluster after the periodicity of its mean autocorrelogram.

    Peaks are detected with ``scipy.signal.find_peaks`` (minimum prominence
    ``prominence_threshold``, minimum width 25 samples). The function walks them
    left to right. Any peak closer than ``distance_threshold`` samples to the
    previously kept peak is merged with it, and the taller of the two is kept
    (the earlier one on a tie). The label encodes the mean spacing between the
    surviving peaks.

    Parameters
    ----------
    autocorr : array-like of float
        1-D mean autocorrelogram, read positionally (a pandas Series row works).
    cluster_id : int or str
        Identifier embedded in the returned name.
    prominence_threshold : float, default 0.06
        Minimum peak prominence.
    distance_threshold : int, default 50
        Peaks closer than this many samples are treated as a single peak.
    verbose : bool, default True
        Print the intermediate peak arrays (the original always printed them).

    Returns
    -------
    str
        ``'NRL-<d>.0-C<cluster_id>'``, where ``<d>`` is the mean spacing between
        consecutive kept peaks rounded to an integer. Returns ``'NP<cluster_id>'``
        when fewer than two peaks remain. Never returns ``None``.
    """
    values = np.asarray(autocorr, dtype=float)

    # find_peaks already applies the prominence floor; the mask keeps the intent explicit.
    peaks, properties = find_peaks(values, prominence=prominence_threshold, width=25)
    prominences = properties["prominences"]
    significant_peaks = peaks[prominences >= prominence_threshold]
    if verbose:
        print("peaks = ", peaks)
        print("prominences = ", prominences)
        print("significant_peaks = ", significant_peaks)

    # Merge near-duplicate peaks against the last *kept* peak. The previous in-place
    # version overwrote entries with a -1 sentinel, so later comparisons used -1 as a
    # position and autocorr[-1] as a height.
    kept = []
    for peak in significant_peaks:
        if kept and peak - kept[-1] < distance_threshold:
            if values[peak] > values[kept[-1]]:
                kept[-1] = peak
        else:
            kept.append(peak)

    # Periodicity estimate: spacing between consecutive kept peaks
    peak_distances = np.diff(kept)
    if verbose:
        print("peak_distances_with_first = ", peak_distances)
    if len(peak_distances) > 0:
        average_distance = round(np.mean(peak_distances))
        return f'NRL-{average_distance:.1f}-C{cluster_id}'

    # Zero or one peak left (including after merging): no periodic pattern.
    # Previously the "one peak left after merging" case fell through and returned None.
    return f'NP{cluster_id}'

# Name each significant cluster from its mean autocorrelogram (keys: integer cluster ids)
cluster_name_mapping = {cluster_id: determine_cluster_name(autocorr, cluster_id)
                        for cluster_id, autocorr in representative_autocorrs.iterrows()}



# Re-index representative_autocorrs by cluster name instead of integer cluster id
representative_autocorrs.index = representative_autocorrs.index.map(cluster_name_mapping)

# Sort representative_autocorrs by index (cluster names)
#representative_autocorrs_sorted = representative_autocorrs.sort_index(ascending=False)
# Per-cluster dynamic range (max - min) over positional lag columns 220 onward, used as the sort key.
# The 'range' column stays on representative_autocorrs itself; only the sorted copy drops it.
representative_autocorrs['range'] = representative_autocorrs.apply(lambda row: row.iloc[220:].max() - row.iloc[220:].min(), axis=1)

# Smallest-range (flattest) clusters first
representative_autocorrs_sorted = representative_autocorrs.sort_values(by='range', ascending=True)

# Remove the helper column from the sorted copy so only lag columns are plotted
representative_autocorrs_sorted.drop(columns=['range'], inplace=True)

# Prepare data for the heatmap
heatmap_data = representative_autocorrs_sorted.values
heatmap_labels = representative_autocorrs_sorted.index.to_list()

# Create the heatmap
fig_heatmap = go.Figure(go.Heatmap(
    z=heatmap_data,
    x=list(range(len(heatmap_data[0]))),
    y=heatmap_labels,
    colorscale='Inferno'
))
fig_heatmap.update_layout(
    title='Heatmap of Representative Autocorrelograms for Each Cluster',
    xaxis_title='Lag',
    yaxis_title='Cluster',
    yaxis={'type': 'category'},
    width=800,
    height=600,
    template='plotly_white'
)
# Relabel lag ticks so column 0 reads as lag corr_start.
# NOTE(review): tickvals has (crr_length - corr_start) / 100 entries but ticktext has one more
# (9 vs 10 with the current config) -- confirm the intended last tick.
fig_heatmap.update_xaxes(tickmode='array', tickvals=list(range(0, crr_length-corr_start, 100)), ticktext=list(range(corr_start, crr_length+1, 100)))


#display(autocorr_df.head(10))

### PRINT PER-READ HEATMAP
# Per-read autocorrelograms (significant clusters only), ordered like the representative
# heatmap and down-sampled to at most per_read_max_rows rows.
per_read_max_rows = 500
# Seed for the down-sampling; None draws fresh OS entropy each run (original behaviour).
per_read_sample_seed = None

print("Number of rows after filtering:", len(filtered_df))
# Sort the DataFrame based on the cluster labels
sorted_df = filtered_df.sort_values(by='cluster')

# Display order = row order of representative_autocorrs_sorted, whose index already holds the
# full cluster names. Mapping those names through cluster_name_mapping again (its keys are
# integer cluster ids) turned every entry into NaN and silently lost the intended ordering.
unique_clusters = representative_autocorrs_sorted.index.unique()
sort_order = {name: i for i, name in enumerate(unique_clusters)}

# Attach each read's full cluster name and its display rank, then order by rank.
sorted_df['full_cluster_name'] = sorted_df['cluster'].map(cluster_name_mapping)
sorted_df['sort_order'] = sorted_df['full_cluster_name'].map(sort_order)
sorted_df = sorted_df.sort_values(by='sort_order', kind='stable')

if sorted_df.shape[0] > per_read_max_rows:
    # Random subsample, then restore the display order (sorting by name here used to
    # reintroduce alphabetical order).
    sampled_df = (
        sorted_df.sample(n=per_read_max_rows, random_state=np.random.RandomState(per_read_sample_seed))
        .sort_values(by='sort_order', kind='stable')
        .reset_index(drop=True)
    )
else:
    # Few enough rows: plot them all
    sampled_df = sorted_df

# The rank column is only needed for ordering; drop it before building the matrix.
sorted_df = sorted_df.drop(columns=['sort_order'])
sampled_df = sampled_df.drop(columns=['sort_order'])

# Row labels "<cluster name>_<row index>"; filtered_df carries no read_id column, so the
# DataFrame index stands in for the read.
y_labels_sampled = [f'{row["full_cluster_name"]}_{row.name}' for index, row in sampled_df.iterrows()]

# Remaining columns are the lag values
z_data_sampled = sampled_df.drop(['cluster', 'full_cluster_name'], axis=1).values

# Create the heatmap with the sampled data
read_fig = go.Figure(go.Heatmap(
    z=z_data_sampled,
    x=list(range(z_data_sampled.shape[1])),  # Use the number of columns for the x-axis
    y=y_labels_sampled,
    colorscale='Inferno'
))

read_fig.update_layout(
    title='Random Per-Read Autocorrelation Heatmaps Grouped by Full Cluster Name',
    xaxis_title='Lag',
    yaxis_title='Read ID with Cluster Name',
    yaxis={'type': 'category'},
    width=800,
    height=800,
    template='plotly_white',
)
# Relabel x-axis labels to start from corr_start
read_fig.update_xaxes(tickmode='array', tickvals=list(range(0, z_data_sampled.shape[1], 100)), ticktext=list(range(corr_start, z_data_sampled.shape[1]+corr_start, 100)))


# Fixed colour per cluster name (Prism palette, cycled) in representative-heatmap order;
# the stacked bar chart later in this cell reuses cluster_colors so colours match.
colors_scheme = plotly.colors.qualitative.Prism
cluster_colors = {cluster_name: colors_scheme[i % len(colors_scheme)] for i, cluster_name in enumerate(representative_autocorrs_sorted.index)}

# One line trace per cluster mean autocorrelogram, labelled by cluster name
fig_scatter = go.Figure()

for cluster_name, autocorr in representative_autocorrs_sorted.iterrows():
    fig_scatter.add_trace(go.Scatter(
        x=list(range(len(autocorr))),
        y=autocorr,
        mode='markers+lines',
        # reduce line width
        line=dict(width=3, color=cluster_colors[cluster_name]),
        # reduce marker size
        marker=dict(size=3, color=cluster_colors[cluster_name]),
        name=cluster_name,  # Updated cluster name
    ))

fig_scatter.update_layout(
    title='Scatter Plots of Representative Autocorrelograms for Each Cluster',
    xaxis_title='Lag',
    yaxis_title='Autocorrelation',
    width=800,
    height=600,
    template='plotly_white'
)
# Same lag relabelling as the representative heatmap (tickvals/ticktext lengths differ by one)
fig_scatter.update_xaxes(tickmode='array', tickvals=list(range(0, crr_length-corr_start, 100)), ticktext=list(range(corr_start, crr_length+1, 100)))

# Per-read metadata table for the cluster-composition chart.
# conditions_list / chr_types_list / types_list (built from `results` with the
# len(autocorr) > 1 filter) and adata_expanded's rows follow the same read order, so the
# columns below align row-for-row.
"""conditions = [metadata['condition'] for _, metadata, _, _ in results
chr_types = [metadata['chr_type'] for _, metadata, _ , _ in results]"""
clusters = adata_expanded.obs['leiden'].astype(int).tolist()

"""for read_id, metadata, autocorr, read_fig in results:
    #if autocorrs has nan values skip
    if len(autocorr) < 100:
        continue"""

# Create DataFrame for plotting
metadata_df = pd.DataFrame({
    'condition': conditions_list,
    'chr_type': chr_types_list,
    'type': types_list,
    'cluster': clusters
})

# Replace integer cluster ids with names. cluster_name_mapping only covers the significant
# clusters, so reads from other clusters get NaN here and are left out of the groupby below
# (pandas drops NaN group keys by default).
metadata_df['cluster'] = metadata_df['cluster'].map(cluster_name_mapping)
metadata_df_sorted = metadata_df.sort_values(by=['type', 'condition','cluster'])

# 'combined' label: "<type>, <condition>, <chr_type>"
metadata_df_sorted['combined'] = metadata_df_sorted['type'] + ', '+ metadata_df_sorted['condition']  +  ', ' + metadata_df_sorted['chr_type']

# Read counts per (combined label, cluster): rows = combined labels, columns = cluster names
cluster_counts_by_combined = metadata_df_sorted.groupby(['combined', 'cluster']).size().unstack(fill_value=0)

# Order rows by (condition, chr_type) -- the key drops the leading 'type' field
cluster_counts_by_combined = cluster_counts_by_combined.sort_index(key=lambda x: [tuple(i.split(', ')[1:]) for i in x])


# Row-normalise: percentage of each combination's reads that fall in each cluster
cluster_percentages_by_combined = cluster_counts_by_combined.div(cluster_counts_by_combined.sum(axis=1), axis=0) * 100

# x-axis order for the stacked bars: by (type, condition)
sorted_combinations = sorted(cluster_percentages_by_combined.index, key=lambda x: (x.split(', ')[0], x.split(', ')[1]))

nanotools.display_sample_rows(cluster_percentages_by_combined)

# Stacked bar per combination, one segment per cluster
fig_stacked = go.Figure()

for cluster_name in cluster_percentages_by_combined.columns:
    fig_stacked.add_trace(go.Bar(
        x=sorted_combinations,
        y=cluster_percentages_by_combined.loc[sorted_combinations, cluster_name],
        name=cluster_name,
        # same per-cluster colour as the scatter plot
        marker=dict(color=cluster_colors[cluster_name]),
        # Add % data labels
        text=cluster_percentages_by_combined.loc[sorted_combinations, cluster_name].round(1),
    ))

fig_stacked.update_layout(
    barmode='stack',
    title='Percentage of Reads in Each Cluster by Condition and Chr_Type',
    xaxis_title='Condition, Chr_Type',
    yaxis_title='Percentage of Reads',
    yaxis=dict(type='linear', ticksuffix='%'),
    legend_title='Clusters',
    template='plotly_white',
    width=800,
    height=600
)
fig_heatmap.show()
fig_scatter.show()
fig_stacked.show()
read_fig.show()

# Save all figures as svgs and pngs to images_11_14_23/ (currently disabled)
"""fig_heatmap.write_image("images_11_14_23/average_autocorrelograms_heatmap_2000bp_10gs_0p5res.png")
fig_scatter.write_image("images_11_14_23/average_autocorrelograms_scatter_2000bp_10gs_0p5res.png")
fig_stacked.write_image("images_11_14_23/average_autocorrelograms_stacked_box_2000bp_10gs_0p5res.png")
read_fig.write_image("images_11_14_23/per_read_500_autocorrelograms_heatmap_2000bp_10gs_0p5res.png")
fig_heatmap.write_image("images_11_14_23/average_autocorrelograms_heatmap_2000bp_10gs_0p5res.svg")
fig_scatter.write_image("images_11_14_23/average_autocorrelograms_scatter_2000bp_10gs_0p5res.svg")
fig_stacked.write_image("images_11_14_23/average_autocorrelograms_stacked_box_2000bp_10gs_0p5res.svg")
read_fig.write_image("images_11_14_23/per_read_500_autocorrelograms_heatmap_2000bp_10gs_0p5res.svg")"""

In [None]:

# Grouped bar chart: mean cluster percentage per condition, after excluding one region type.
# The comment, title and output tag previously disagreed with the filter (comment said
# weak_rex, title said STRONG REX, code drops strong_rex). The title is now derived from the
# filter; output_tag must describe the rows that remain.
excluded_type = 'strong_rex'
output_tag = 'WEAKrex'

# Drop rows whose 'combined' label (type, condition, chr_type) contains excluded_type
cluster_percentages_by_combined = cluster_percentages_by_combined[~cluster_percentages_by_combined.index.to_series().str.contains(excluded_type, regex=False)]

# Create a bar plot
fig = go.Figure()

condition_colors = {
    'N2_fiber': 'rgba(0, 77, 153, 0.8)',  # Darker Blue
    'SDC2_degron_fiber': 'rgba(204, 0, 0, 0.8)'  # More Vivid Red
}

# Extract clusters
clusters = cluster_percentages_by_combined.columns
clusters_reversed = clusters[::-1]  # Reverse the cluster order

# One trace per condition; each bar is the unweighted mean of that cluster's percentage
# over all remaining rows (types x chr_types) belonging to the condition.
for condition in ['N2_fiber', 'SDC2_degron_fiber']:
    condition_df = cluster_percentages_by_combined[cluster_percentages_by_combined.index.to_series().str.contains(condition, regex=False)]
    fig.add_trace(go.Bar(
        name=condition,
        x=clusters_reversed,
        y=[condition_df[cluster].mean() for cluster in clusters_reversed],  # Use mean percentage for each cluster
        marker_color=condition_colors[condition]
    ))

# Update layout to group bars
fig.update_layout(
    title=f'Bar Plot of Cluster Percentages by Condition (excluding {excluded_type})',
    xaxis_title='Cluster',
    yaxis_title='Percentage of Reads',
    yaxis=dict(type='linear', ticksuffix='%'),
    barmode='group',  # Group bars by cluster
    template='plotly_white',
    width=800,
    height=600
)
fig.update_xaxes(showgrid=False)  # Remove gridlines from X-axis
fig.update_yaxes(showgrid=False)  # Remove gridlines from Y-axis

# Display the figure
fig.show()

# write_image fails if the target directory does not exist
os.makedirs("images_11_14_23", exist_ok=True)
fig.write_image(f"images_11_14_23/average_autocorrelograms_bar_plot_2000bp_10gs_0p5res_{output_tag}.png")
fig.write_image(f"images_11_14_23/average_autocorrelograms_bar_plot_2000bp_10gs_0p5res_{output_tag}.svg")


In [None]:
### Calculating Nucleosome, Met Accessible Domain regions and midpoints
# Simple Algo
import multiprocessing
# Columns copied from each read's first row into every per-read result; must include
# 'read_id', which create_dataframe_from_results merges on. Also used by the autocorrelation cell.
metadata_cols = ['chrom', 'chr_type', 'condition', 'bed_start','type', 'read_id', 'rel_read_start','rel_read_end']
MAD_dist_max = 65 # Consecutive m6A marks at most this far apart (<=) are merged into one MAD
NUC_max_width  = 170 # A NUC span (end - start) is recorded only if <= this
NUC_min_width  = 110 # ... and strictly > this

def calculate_midpoints_for_group(read_id,group,MAD_dist_max,NUC_max_width ,NUC_min_width ,metadata_cols):
    """Segment one read's m6A marks into MAD regions and nucleosome (NUC) regions.

    Walks ``group['rel_pos']`` in row order with two running intervals:

    * MAD: a mark no more than ``MAD_dist_max`` past ``MAD_end`` extends the MAD.
      A larger gap closes it; it is recorded if it spans more than 0 bp.
    * NUC: restarted at every mark that extends a MAD. It is extended across gaps
      larger than ``MAD_dist_max`` while ``x - NUC_start <= NUC_max_width``. It is
      recorded only when ``NUC_min_width < NUC_end - NUC_start <= NUC_max_width``.

    Parameters
    ----------
    read_id : hashable
        Returned unchanged as the first element of the result.
    group : pandas.DataFrame
        Rows of one read; needs ``rel_pos`` and every column in ``metadata_cols``.
        Assumed sorted by ``rel_pos`` (the driver cell sorts before grouping);
        this function does not sort.
    MAD_dist_max, NUC_max_width, NUC_min_width : int
        Thresholds as described above.
    metadata_cols : list of str
        Columns taken from the first row as the read's metadata.

    Returns
    -------
    tuple
        ``(read_id, regions_MAD_list, regions_NUC_list, metadata)``. The two lists
        hold ``(start, end)`` tuples; ``metadata`` is ``group.iloc[0][metadata_cols]``.

    NOTE(review): suspected logic issues, left unchanged pending confirmation:
      * When an over-long NUC fails the width check, NUC_start is not reset (the reset
        sits inside that check). Every later mark then keeps satisfying
        ``(x - NUC_start) > NUC_max_width`` until a mark within MAD_dist_max resets
        the NUC, and nucleosome-sized gaps in between are never recorded.
      * ``MAD_start, MAD_end = NUC_end, x`` is always overwritten by the
        ``MAD_start, MAD_end = x, x`` that follows it, so it has no effect.
      * ``group.iloc[i]['rel_pos']`` per row is slow; iterating
        ``group['rel_pos'].to_numpy()`` would be equivalent and much faster.
    """
    #iter=0
    #if iter % 5000 == 0:
    #    print("Processing read:", iter, sep='\n')
    #if group.empty:
    #    print(f"Warning: Empty group for read_id: {read_id}. Skipping.")
    #    return read_id, [], [], None  # Return None to indicate empty group for later checks

    regions_MAD_list = []
    regions_NUC_list = []
    #print("Processing read:", read_id, sep='\n')
    """print("group:",group.head(10))"""

    min_position = min(group['rel_pos'])

    # Both intervals start empty at the first (smallest) position
    MAD_start, MAD_end = min_position, min_position
    NUC_start, NUC_end = min_position, min_position

    for i in range(len(group['rel_pos'])):
        x = group.iloc[i]['rel_pos']
        '''if iter % 2000 == 0:
            print("x: ", x)'''
        """if read_id == "0037677b-c871-4a92-943a-2062ebecaf2f" and x > 200:
            print("MAD_start: ", MAD_start, " | MAD_end: ", MAD_end, " | dist from x:", x - MAD_end)
            print("NUC_start: ", NUC_start, "NUC_end: ", NUC_end, " | dist from x:", x - NUC_end)
            print("regions_NUC_list:",regions_NUC_list)
            print("regions_MAD_list:",regions_MAD_list)
            print("x: ", x)"""

        # NOTE(review): never true -- MAD_start/NUC_start are initialised to min_position and
        # never set to None.
        if MAD_start is None and NUC_start is None:
            MAD_start, MAD_end = x, x
            NUC_start, NUC_end = x, x

        '''if iter % 2000 == 0:
            print("MAD_start: ", MAD_start, " | MAD_end: ", MAD_end, " | MAD_length: ", MAD_end - MAD_start)
            print("NUC_start: ", NUC_start, "NUC_end: ", NUC_end, " | NUC_length: ", NUC_end - NUC_start)'''

        # Short gap: x extends the current MAD and closes any open NUC
        if (x - MAD_end) <= MAD_dist_max:
            MAD_end = x
            # Record the open NUC if its width is in (NUC_min_width, NUC_max_width]
            if NUC_min_width  < (NUC_end - NUC_start) <= NUC_max_width :
                regions_NUC_list.append((NUC_start,NUC_end))
                '''if iter % 2000 == 0:
                    print("Appending NUC:", (NUC_end + NUC_start) / 2)'''
            # Reset NUC_start and NUC_end to x regardless of whether it was appended or not.
            NUC_start, NUC_end = x, x

        # Long gap (the complement of the branch above): close the MAD and extend/close the NUC
        elif (x - MAD_end) > MAD_dist_max:
            if (MAD_end - MAD_start) > 0:
                regions_MAD_list.append((MAD_start, MAD_end))
                '''if iter % 2000 == 0:
                    print("Appending MAD:", (MAD_start, MAD_end))'''
                MAD_start, MAD_end = x, x
            # Zero-length MAD: just move it to x
            if (MAD_end - MAD_start) == 0:
                MAD_start, MAD_end = x, x
            # Extend the NUC to x if the result stays within NUC_max_width
            if (x - NUC_start) <= NUC_max_width :
                NUC_end = x
            # Extending would exceed NUC_max_width: record the NUC if valid and restart it at x
            if (x - NUC_start) > NUC_max_width :
                if NUC_min_width  < (NUC_end - NUC_start) <= NUC_max_width :
                    regions_NUC_list.append((NUC_start,NUC_end))
                    '''if iter % 2000 == 0:
                        print("Appending NUC:", (NUC_end + NUC_start) / 2)'''
                    # Intended: start a MAD at the NUC end if x is close to it (overwritten below)
                    if (x-NUC_end) <= MAD_dist_max:
                        MAD_start, MAD_end = NUC_end, x
                    NUC_start, NUC_end = x, x
                MAD_start, MAD_end = x, x
    # Flush whatever MAD / NUC is still open at the end of the read
    if (MAD_end - MAD_start) > 0:
        regions_MAD_list.append((MAD_start, MAD_end))
    if NUC_min_width  < (NUC_end - NUC_start) <= NUC_max_width :
        regions_NUC_list.append((NUC_start,NUC_end))
    return read_id, regions_MAD_list, regions_NUC_list, group.iloc[0][metadata_cols]

def create_dataframe_from_results(results, kind='NUC', metadata_cols=metadata_cols):
    """Build a wide per-read table of MAD/NUC midpoints or regions, prefixed by metadata.

    Parameters
    ----------
    results : list of tuple
        ``(read_id, MAD_regions, NUC_regions, metadata)`` tuples as produced by
        ``calculate_midpoints_for_group``; ``metadata`` may be ``None``.
    kind : {'NUC', 'MAD', 'MAD_region', 'NUC_region'}
        ``'NUC'`` / ``'MAD'`` give region midpoints ``(start + end) / 2``;
        ``'MAD_region'`` / ``'NUC_region'`` give the raw ``(start, end)`` tuples.
    metadata_cols : list of str
        Metadata fields to pull from each result; must include ``'read_id'``,
        which is the join key.

    Returns
    -------
    pandas.DataFrame
        One row per result: the metadata columns, then one column per region
        (0, 1, ...), with shorter reads padded by pandas. Left-joined on the
        metadata rows.

    Raises
    ------
    ValueError
        If ``kind`` is not one of the four supported values.
    """
    # kind -> (position of the region list inside each result tuple, reduce to midpoints?)
    layouts = (
        ('NUC', (2, True)),
        ('MAD', (1, True)),
        ('MAD_region', (1, False)),
        ('NUC_region', (2, False)),
    )
    spec = next((layout for name, layout in layouts if kind == name), None)
    if spec is None:
        raise ValueError(f"Invalid kind: {kind}")
    slot, midpoints_only = spec

    # read_id -> list of midpoints or region tuples (a repeated read_id keeps the last one)
    rows_by_read = {}
    for res in results:
        regions = res[slot]
        if midpoints_only:
            rows_by_read[res[0]] = [(region[0] + region[1]) / 2 for region in regions]
        else:
            rows_by_read[res[0]] = regions

    regions_df = (
        pd.DataFrame.from_dict(rows_by_read, orient='index')
        .reset_index()
        .rename(columns={'index': 'read_id'})
    )

    # Metadata table with one row per result; a None metadata entry yields None values
    meta_df = pd.DataFrame({
        col: [None if res[3] is None else res[3][col] for res in results]
        for col in metadata_cols
    })

    return meta_df.merge(regions_df, on='read_id', how='left')

# Segment every read into MAD / NUC regions in parallel, one task per read_id.
print("Grouping df...")
# Sort plot_df by read_id then rel_pos, in place (later cells see the re-sorted plot_df);
# calculate_midpoints_for_group relies on rows being in rel_pos order within a read.
#print("Plot_df:")
#display(plot_df.head(10))
plot_df.sort_values(by=["read_id","rel_pos"], inplace=True)
plot_df.reset_index(inplace=True, drop=True)
# Keep only rows with mod_qual_bin == 1 (presumably the binarised m6A calls -- confirm), then group by read_id
plot_df_m6a_ony = plot_df[plot_df['mod_qual_bin'] == 1]

grouped = plot_df_m6a_ony.groupby('read_id')
print("Grouped_df...")
display(grouped.head(10))
print("Calculating nucleosome and MAD positions...")

# One argument tuple per read for starmap
grouped_data_with_constants = [(read_id,group,MAD_dist_max, NUC_max_width , NUC_min_width ,metadata_cols) for read_id,group in grouped]

#processes=multiprocessing.cpu_count()
# NOTE(review): this rebinds the global `results`, replacing the autocorrelation results
# computed earlier in the notebook.
with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
    results = pool.starmap(calculate_midpoints_for_group, grouped_data_with_constants)

print("creating result dfs...")
midpoint_NUC = create_dataframe_from_results(results, kind='NUC', metadata_cols=metadata_cols)
midpoint_MAD = create_dataframe_from_results(results, kind='MAD', metadata_cols=metadata_cols)
region_MAD = create_dataframe_from_results(results, kind='MAD_region', metadata_cols=metadata_cols)
region_NUC = create_dataframe_from_results(results, kind='NUC_region', metadata_cols=metadata_cols)

# drop duplicates based on read_id from all dfs (groupby yields one result per read_id, so
# this is presumably defensive)
midpoint_NUC.drop_duplicates(subset=['read_id'], keep='first', inplace=True)
midpoint_MAD.drop_duplicates(subset=['read_id'], keep='first', inplace=True)
region_MAD.drop_duplicates(subset=['read_id'], keep='first', inplace=True)
region_NUC.drop_duplicates(subset=['read_id'], keep='first', inplace=True)

# Preview each result table
print("midpoint_MAD columns:")
display(midpoint_MAD.head(10))
print("midpoint_NUC columns:")
display(midpoint_NUC.head(3))
print("region_MAD columns:")
display(region_MAD.head(10))
print("region_NUC columns:")
display(region_NUC.head(3))

In [None]:
## Calculate positive and negative controls
# Fit a KDE to per-base mod_qual of N2_fiber reads whose mean mod_qual is at or below the
# `percentile` quantile, and write it (1000 points on [0, 1]) to temp_files/ as a control.
from scipy.stats import gaussian_kde  # absent from the import cell at the top

# Define percentile
percentile = 0.95

# Restrict to one condition; .copy() so the new column is not written into a view of plot_df
filtered_df = plot_df.loc[plot_df['condition'] == 'N2_fiber'].copy()
# Mean mod_qual of each read, broadcast onto every row of that read
filtered_df['avg_mod_qual'] = filtered_df.groupby('read_id')['mod_qual'].transform('mean')
# NOTE(review): the quantile is taken over rows, so reads with more rows weigh more; use
# filtered_df.groupby('read_id')['mod_qual'].mean().quantile(...) for a per-read cutoff.
percentile_99 = filtered_df['avg_mod_qual'].quantile(percentile)
# NOTE(review): despite the variable names, `<=` keeps reads at or BELOW the cutoff --
# confirm this matches the control being generated (the output file is tagged "neg-ctrl").
high_methylation_read_ids = filtered_df[filtered_df['avg_mod_qual'] <= percentile_99]['read_id']
plot_df_filtered_99 = plot_df[plot_df['read_id'].isin(high_methylation_read_ids)]
# high_methylation_read_ids repeats each read once per row, so count distinct reads
# (len() of it previously reported the row count as the read count).
print("reads passing threshold:", high_methylation_read_ids.nunique()," | bases passing threshold:", len(plot_df_filtered_99))

# Evaluate the KDE at 1000 evenly spaced mod_qual values in [0, 1]; dividing by 1000
# approximately converts density to probability mass per point.
x_values = np.linspace(0, 1, 1000)
kde = gaussian_kde(plot_df_filtered_99["mod_qual"])
y_values = kde(x_values)/1000

# Create the KDE plot using Plotly
fig = go.Figure()
fig.add_trace(go.Scatter(x=x_values, y=y_values, mode='lines', name='KDE'))
fig.update_layout(title='KDE plot of mod_qual',
                  xaxis_title='mod_qual',
                  yaxis_title='Density')
# set theme to plotly white
fig.update_layout(template="plotly_white")
fig.show()

output_fn = f"temp_files/N2-intergenic_0p1-neg-ctrl-{str(percentile)}.txt"
# open() fails if temp_files/ does not exist yet
os.makedirs(os.path.dirname(output_fn), exist_ok=True)
# One value per line, in x_values order
with open(output_fn, 'w') as f:
    for item in y_values:
        f.write("%s\n" % item)

In [None]:
### MATRIX BASED NUCLEOSOME POSITIONS
# Settings for nucleosome calling. If an emission file name below is set to None,
# the loading code falls back to its hard-coded default array.
import multiprocessing

# Emission-profile inputs, one value per line. The NEG file name follows the pattern
# written by the KDE control cell above, presumably with percentile = 0.01.
emission_NEG_fn = "temp_files/N2-intergenic_0p1-neg-ctrl-0.01.txt"
emission_PGC_fn = "temp_files/N2-sdc2-intergenic_0p1-pos-ctrl-0.95.txt"

# Destination for the called nucleosome positions
output_file = "temp_files/nucleosomes.csv"
# Distance below which m6A marks are combined into NUC
# (presumably the ~147 bp nucleosome-core footprint; confirm)
NUC_width = 147
# Define emission_NEG_array, the negative-control emission profile. It is either the
# hard-coded default below (when emission_NEG_fn is None) or the values read from
# emission_NEG_fn, one number per line (the format written by the KDE control cell).
# NOTE(review): the default has 1001 entries, while the KDE control cell writes 1000;
# confirm which length downstream code expects. Provenance of the defaults is not recorded.
if emission_NEG_fn is None:
    print("Using default Negative Control array...")
    emission_NEG_array = np.array([0.00149399264776291,
        0.00152416739428112,
        0.00155428407897689,
        0.00158366982352413,
        0.00161263012922332,
        0.00164116736093195,
        0.00166881741930283,
        0.00169644480709635,
        0.00172243991332119,
        0.00174843501954603,
        0.00177314930158316,
        0.00179732634903045,
        0.0018208293765571,
        0.00184330135380909,
        0.00186562176698358,
        0.00188633353415362,
        0.00190704530132367,
        0.00192636133771002,
        0.00194527658490408,
        0.00196334324282506,
        0.0019804441422428,
        0.00199721573026683,
        0.00201223912200322,
        0.00202726251373961,
        0.00204062362531173,
        0.00205368981620955,
        0.00206583124459128,
        0.00207718528459883,
        0.00208815849360906,
        0.00209786334175532,
        0.00210756818990159,
        0.00211581913836465,
        0.00212394661339389,
        0.00213115648509963,
        0.00213778589208315,
        0.00214392113051369,
        0.00214892137536451,
        0.00215392162021533,
        0.00215755004456053,
        0.00216117036760352,
        0.00216398797335146,
        0.00216643632914367,
        0.00216847696309653,
        0.00216984018420171,
        0.00217113675922387,
        0.00217149832890489,
        0.0021718598985859,
        0.00217152301775043,
        0.00217096234398462,
        0.00216995779524117,
        0.00216840664063006,
        0.00216674524343967,
        0.00216440065280018,
        0.00216205606216068,
        0.00215914191711877,
        0.00215611090596126,
        0.00215274778308761,
        0.00214907876784382,
        0.00214528404933982,
        0.00214102027144815,
        0.00213675649355648,
        0.00213199079662317,
        0.00212717066161593,
        0.00212198067237151,
        0.00211653678978936,
        0.00211095201976393,
        0.00210501163039892,
        0.00209907124103391,
        0.00209269337479204,
        0.00208630408847105,
        0.00207963303317977,
        0.00207282026467108,
        0.00206586580125198,
        0.00205865519703331,
        0.00205142714861267,
        0.00204384512826049,
        0.00203626310790831,
        0.00202837761686597,
        0.00202038426090988,
        0.00201226433018618,
        0.00200397563310567,
        0.00199565703906218,
        0.00198711789365238,
        0.00197857874824259,
        0.00196987018727312,
        0.00196112193231988,
        0.0019522914291573,
        0.00194337899906746,
        0.00193444440342173,
        0.00192541639579044,
        0.00191638838815915,
        0.00190731276271764,
        0.00189823080231246,
        0.00188916972807923,
        0.0018801241937282,
        0.0018711050764496,
        0.00186215953560976,
        0.00185321399476991,
        0.00184442692402003,
        0.00183564732214734,
        0.00182701881316729,
        0.00181847317634955,
        0.00181003266704129,
        0.00180179935671974,
        0.00179357963248821,
        0.00178584221445295,
        0.00177810479641768,
        0.00177072403798445,
        0.00176348322228335,
        0.00175647482276216,
        0.00174980225906073,
        0.0017431934656268,
        0.00173715985499785,
        0.0017311262443689,
        0.00172565269607944,
        0.00172032748113237,
        0.00171539371441344,
        0.00171088162972888,
        0.00170655046389971,
        0.00170308894211409,
        0.00169962742032846,
        0.0016969800459273,
        0.0016944619012697,
        0.00169249675040448,
        0.00169097712103983,
        0.00168971428709704,
        0.00168924443741698,
        0.00168877458773692,
        0.00168933081369971,
        0.00168995787752069,
        0.0016913461472447,
        0.00169318876174207,
        0.00169548047964238,
        0.00169873971042928,
        0.0017020085047235,
        0.00170648947357022,
        0.00171097044241694,
        0.00171631727885466,
        0.00172203762680224,
        0.00172824495561917,
        0.00173521579213078,
        0.00174228730888025,
        0.00175051296494183,
        0.00175873862100341,
        0.00176801038609194,
        0.00177759279776968,
        0.00178783967310848,
        0.00179886081754385,
        0.00181006443736322,
        0.00182228102950662,
        0.00183449762165001,
        0.00184768975197826,
        0.00186106295667957,
        0.00187502883980271,
        0.00188951141693461,
        0.00190423112833026,
        0.00191976740640802,
        0.00193530368448579,
        0.00195183130849742,
        0.00196845001771239,
        0.00198567721797348,
        0.00200329895597738,
        0.00202114813398883,
        0.00203953448938654,
        0.00205792084478426,
        0.00207696436002779,
        0.00209601568311108,
        0.00211544651033297,
        0.00213505669974141,
        0.0021548321436499,
        0.00217488913226217,
        0.00219496718641124,
        0.00221537141661406,
        0.00223577564681688,
        0.002256299317718,
        0.00227686249114174,
        0.00229742175345622,
        0.00231797608361963,
        0.00233851210455439,
        0.00235892884624954,
        0.00237934558794468,
        0.00239954336575385,
        0.00241969433418121,
        0.00243964324843439,
        0.00245940163251569,
        0.00247903240997627,
        0.00249816977517,
        0.00251730714036373,
        0.00253575320236407,
        0.00255411926639662,
        0.00257200383160063,
        0.00258954974262857,
        0.0026068402979811,
        0.00262346728823345,
        0.00264009427848579,
        0.00265574328491529,
        0.00267136067974647,
        0.00268625812992732,
        0.00270078385794448,
        0.00271482710676313,
        0.00272797547796219,
        0.00274106811104646,
        0.00275286018811893,
        0.00276465226519139,
        0.00277549804233091,
        0.00278599718019019,
        0.00279593860426456,
        0.00280511833590122,
        0.00281414646296176,
        0.00282199050600528,
        0.00282983454904881,
        0.00283659926587619,
        0.00284310137291861,
        0.00284881712314759,
        0.00285373097591462,
        0.0028583866666396,
        0.00286191187936988,
        0.00286543709210015,
        0.00286784905596031,
        0.00287010446919811,
        0.00287166224046996,
        0.00287268836045052,
        0.00287340899169051,
        0.00287325249995569,
        0.00287309600822087,
        0.00287186670592455,
        0.00287057989213421,
        0.00286847444583565,
        0.00286590836228474,
        0.00286300801757572,
        0.00285943123904376,
        0.00285583510643775,
        0.00285135924319256,
        0.00284688337994738,
        0.00284181504428998,
        0.00283650744092764,
        0.00283089171678661,
        0.00282481979440526,
        0.00281868236584968,
        0.00281191307565515,
        0.00280514378546061,
        0.00279778159178072,
        0.00279025675624698,
        0.00278247440790956,
        0.00277440805368165,
        0.00276626310802112,
        0.00275772428059235,
        0.00274918545316358,
        0.0027402863069457,
        0.00273132710759631,
        0.00272216384304706,
        0.00271283222470792,
        0.00270342209370996,
        0.00269376162316871,
        0.00268410115262745,
        0.00267411397844846,
        0.00266410204357272,
        0.0026539405715072,
        0.00264368758456844,
        0.00263336631929775,
        0.00262289389696513,
        0.00261242107840481,
        0.00260174618371062,
        0.00259107128901644,
        0.00258026477300697,
        0.00256939988257235,
        0.00255846378400356,
        0.00254741324379054,
        0.00253634629093877,
        0.00252507504392566,
        0.00251380379691255,
        0.0025023930007641,
        0.00249093937623652,
        0.00247939968750435,
        0.00246775731272549,
        0.0024560860411578,
        0.00244424678292431,
        0.00243240752469082,
        0.00242039539324634,
        0.00240834973732403,
        0.00239618425532223,
        0.00238391181744946,
        0.00237157772447086,
        0.00235902398517773,
        0.00234647024588459,
        0.0023336939501436,
        0.00232089563806583,
        0.00230794559982686,
        0.00229489474029509,
        0.00228176844467743,
        0.00226845887481082,
        0.00225514930494422,
        0.0022415815317423,
        0.00222800912939691,
        0.00221424733004706,
        0.0022003935850584,
        0.00218642661016831,
        0.0021722617873322,
        0.00215808312729033,
        0.00214366577544926,
        0.00212924842360819,
        0.00211465082044989,
        0.00209999171421747,
        0.00208523301609069,
        0.00207034569704204,
        0.00205543141202063,
        0.00204033209893997,
        0.00202523278585931,
        0.00200996428509849,
        0.0019946581421026,
        0.00197925654419814,
        0.00196376279559351,
        0.00194824325308297,
        0.0019326202820674,
        0.00191699731105183,
        0.00190128344073455,
        0.00188555838281092,
        0.00186979031776118,
        0.00185399126441182,
        0.00183817992156937,
        0.00182233568886051,
        0.00180649145615164,
        0.00179063475186534,
        0.00177477756595309,
        0.00175894100514134,
        0.00174311537439783,
        0.0017273109446994,
        0.00171154686926291,
        0.0016957859228771,
        0.00168011067771277,
        0.00166443543254843,
        0.00164884368014338,
        0.00163328343443994,
        0.00161778151875516,
        0.00160236121032654,
        0.00158696118087024,
        0.0015717292825299,
        0.00155649738418957,
        0.00154144106989098,
        0.00152642908195586,
        0.00151152299297459,
        0.00149672746250112,
        0.0014819738214615,
        0.00146741095947384,
        0.00145284809748619,
        0.00143850054783584,
        0.00142418493215013,
        0.00141001546650868,
        0.00139596006919445,
        0.00138198396196036,
        0.00136824265170775,
        0.00135450134145514,
        0.00134105764081501,
        0.00132763184614742,
        0.00131439181724396,
        0.00130125900500067,
        0.00128822313563168,
        0.00127538874392315,
        0.001262559093001,
        0.00125002751902218,
        0.00123749594504335,
        0.00122518071851831,
        0.00121295538705716,
        0.00120087371641448,
        0.00118901001391015,
        0.00117717598865721,
        0.00116564928782962,
        0.00115412258700203,
        0.00114283794089724,
        0.00113162201887202,
        0.00112055184536474,
        0.00110964624622561,
        0.00109879029004084,
        0.00108819405126365,
        0.00107759781248645,
        0.00106726369661379,
        0.00105697537692674,
        0.00104688426775375,
        0.00103695975518832,
        0.0010271099969022,
        0.00101750648800675,
        0.00100790297911129,
        0.000998574902724826,
        0.000989269584532959,
        0.000980145224832513,
        0.000971134396414901,
        0.000962212227150721,
        0.000953491814293388,
        0.000944771401436055,
        0.000936335842153333,
        0.000927901402138094,
        0.000919699878608294,
        0.000911604527835387,
        0.000903616376577226,
        0.000895804854571337,
        0.000888011547514706,
        0.000880465963767979,
        0.000872920380021253,
        0.000865571953731993,
        0.000858286028389301,
        0.0008511140273685,
        0.000844081212149661,
        0.000837083102316501,
        0.000830296598932136,
        0.00082351009554777,
        0.000816966812248717,
        0.00081047274524928,
        0.000804099870408216,
        0.000797837747951584,
        0.000791622424112246,
        0.00078557969700825,
        0.000779536969904255,
        0.000773685769414869,
        0.000767854882354304,
        0.000762145512890653,
        0.000756518887317006,
        0.000750948457352519,
        0.000745518516410936,
        0.000740088575469351,
        0.000734887391392363,
        0.000729691708482459,
        0.000724613960722743,
        0.000719595007612682,
        0.000714636779849552,
        0.000709787393145405,
        0.000704945959474296,
        0.000700259132765475,
        0.000695572306056653,
        0.000691000669700496,
        0.00068646956328324,
        0.000682003915709478,
        0.000677624852333247,
        0.000673266799812099,
        0.000669060906547919,
        0.00066485501328374,
        0.000660756608283642,
        0.000656683068038609,
        0.000652673186708866,
        0.000648726221194235,
        0.000644802695524232,
        0.000640976717490456,
        0.00063715073945668,
        0.000633427003179422,
        0.000629716612618168,
        0.000626072256016885,
        0.000622476638767007,
        0.000618914414036679,
        0.000615444264268344,
        0.000611974114500009,
        0.000608597190210564,
        0.000605224460114442,
        0.000601911931742057,
        0.000598632141467172,
        0.000595382346985485,
        0.00059219115700149,
        0.000589002444417624,
        0.00058589565119799,
        0.000582788857978357,
        0.000579743349561872,
        0.000576721655620523,
        0.000573736587747257,
        0.000570804017238244,
        0.000567878586911705,
        0.000565016182426628,
        0.00056215377794155,
        0.000559344191342501,
        0.000556548428683992,
        0.00055378323086729,
        0.000551050701545474,
        0.0005483286120001,
        0.000545656023417182,
        0.000542983434834264,
        0.000540363805116811,
        0.000537752442733559,
        0.00053517517832031,
        0.000532625168270409,
        0.000530087240315178,
        0.000527586229871436,
        0.000525085219427695,
        0.000522627681576383,
        0.000520173047948365,
        0.000517745979727691,
        0.000515335227480211,
        0.000512937717693019,
        0.000510568480152332,
        0.000508199664723461,
        0.000505873567010415,
        0.000503547469297369,
        0.000501250589220506,
        0.000498966196010502,
        0.000496695162962353,
        0.000494444904965888,
        0.00049219731869335,
        0.000489979566671334,
        0.000487761814649318,
        0.000485568031671005,
        0.000483381287702995,
        0.0004812082906857,
        0.000479051186767545,
        0.000476899003524202,
        0.000474773728274767,
        0.000472648453025332,
        0.000470548777374678,
        0.000468453783132135,
        0.000466372655635534,
        0.000464303522368087,
        0.000462240098070154,
        0.000460196113797724,
        0.000458152129525293,
        0.000456130665525532,
        0.000454111218266608,
        0.000452106390033155,
        0.000450110963102256,
        0.000448123632972984,
        0.000446155248301112,
        0.000444186863629241,
        0.000442243842531968,
        0.000440301072062903,
        0.000438373921295375,
        0.000436454086652417,
        0.000434542734007763,
        0.000432645711476032,
        0.000430750112587146,
        0.000428875823852097,
        0.000427001535117047,
        0.000425144345785697,
        0.00042329275262283,
        0.000421452959951471,
        0.000419627930890632,
        0.000417806234963531,
        0.000416005890414868,
        0.000414205545866205,
        0.000412424037570675,
        0.000410646501849858,
        0.000408880780328978,
        0.000407126112245158,
        0.000405376199837444,
        0.000403644457701639,
        0.000401912715565835,
        0.000400201615548831,
        0.000398492854072518,
        0.000396800259167048,
        0.000395118943264493,
        0.000393444494351909,
        0.000391787717133798,
        0.000390130939915688,
        0.000388496551523359,
        0.000386862840212455,
        0.000385244372316799,
        0.000383633706524213,
        0.000382031144945843,
        0.000380443486116947,
        0.000378856815347661,
        0.000377292106448783,
        0.000375727397549905,
        0.000374183018515518,
        0.000372646012445319,
        0.000371119054035018,
        0.000369605708583288,
        0.000368094967065863,
        0.000366604174333278,
        0.000365113381600693,
        0.000363640651466014,
        0.000362172261615784,
        0.000360714926551738,
        0.000359268776842218,
        0.000357826772517435,
        0.000356402688345427,
        0.000354978604173419,
        0.000353578176775674,
        0.000352181016014909,
        0.000350796387518862,
        0.000349421233587207,
        0.000348051632342441,
        0.000346697812418305,
        0.000345343992494169,
        0.000344010261141992,
        0.00034267756315257,
        0.000341358284997707,
        0.000340046494128588,
        0.000338741586524437,
        0.000337450485940016,
        0.000336159986650359,
        0.000334894491201272,
        0.000333628995752185,
        0.000332377860385073,
        0.000331132469050751,
        0.000329895133393405,
        0.000328669628295482,
        0.000327446009242417,
        0.000326240172898006,
        0.000325034336553596,
        0.000323843802632659,
        0.000322657417876296,
        0.000321480409969967,
        0.000320313662797077,
        0.000319150848565898,
        0.000318007470383873,
        0.000316864092201848,
        0.000315736736668329,
        0.000314612008995157,
        0.000313497416614113,
        0.000312391120023778,
        0.000311289206730967,
        0.00031020112042547,
        0.000309113034119972,
        0.000308041703754552,
        0.000306971605443548,
        0.000305913023897305,
        0.000304861432085648,
        0.000303816465382589,
        0.000302786032511515,
        0.000301755666774007,
        0.000300742386528938,
        0.000299729106283869,
        0.000298727582240846,
        0.000297731224303511,
        0.000296741301278296,
        0.000295761634915037,
        0.000294783224857443,
        0.000293820019319691,
        0.000292856813781938,
        0.000291906837801576,
        0.000290960877936851,
        0.000290023596044622,
        0.00028909658679505,
        0.000288171901203492,
        0.000287260515496621,
        0.000286349129789751,
        0.000285450676845293,
        0.00028455469589738,
        0.00028366677093643,
        0.000282785980428877,
        0.000281908502735774,
        0.000281042693185597,
        0.000280176883635421,
        0.00027932567435724,
        0.00027847587497657,
        0.000277636298150485,
        0.000276803458917196,
        0.000275974807062331,
        0.000275156232831541,
        0.000274337658600752,
        0.000273532950404955,
        0.000272728462742518,
        0.000271933368354912,
        0.000271142793629978,
        0.000270357248373163,
        0.000269580417275012,
        0.000268804437281335,
        0.000268042603230979,
        0.000267280769180623,
        0.000266530390658594,
        0.000265783880626498,
        0.000265043132268183,
        0.000264309766054401,
        0.000263578076364993,
        0.000262857690514166,
        0.000262137304663338,
        0.000261427425761381,
        0.000260719853874449,
        0.000260018739978868,
        0.000259323811898758,
        0.000258631710753275,
        0.000257950807562433,
        0.000257269904371592,
        0.000256601265399393,
        0.000255934105664093,
        0.000255273919170365,
        0.000254618716906352,
        0.000253966734398879,
        0.000253323284246238,
        0.000252679834093597,
        0.000252047517582553,
        0.000251415607573636,
        0.000250791120036053,
        0.000250170532065499,
        0.000249554342610917,
        0.000248946453180823,
        0.000248339005770367,
        0.000247743004342121,
        0.000247147002913874,
        0.000246558633187281,
        0.000245973115091144,
        0.000245391893239115,
        0.00024481663393148,
        0.000244242474958107,
        0.000243677256200615,
        0.000243112037443122,
        0.000242554686957331,
        0.000241999298730009,
        0.000241449423858154,
        0.000240905260031805,
        0.00024036294523246,
        0.000239828941113632,
        0.000239294936994804,
        0.000238768967787017,
        0.000238244169586444,
        0.000237724495420592,
        0.000237208788823291,
        0.000236695372210134,
        0.000236188667012042,
        0.00023568196181395,
        0.000235183703228886,
        0.000234685934302548,
        0.000234194909884835,
        0.00023370774540129,
        0.000233223680435969,
        0.00023274600010699,
        0.000232268477733303,
        0.000231799765754802,
        0.000231331053776302,
        0.000230868747713796,
        0.00023040907834047,
        0.000229953056537195,
        0.000229502523969241,
        0.000229052823992294,
        0.00022861154042444,
        0.000228170256856586,
        0.000227737891724723,
        0.00022730803001349,
        0.000226882893428852,
        0.000226463050554506,
        0.000226044793367349,
        0.000225634712976775,
        0.000225224632586202,
        0.000224822999457194,
        0.000224422819412171,
        0.000224028095108463,
        0.000223637943422475,
        0.000223250173711301,
        0.000222870165268062,
        0.000222490156824824,
        0.000222121864192542,
        0.000221754512788202,
        0.000221393658863558,
        0.000221036847815168,
        0.000220683208351109,
        0.000220336720102671,
        0.000219990231854234,
        0.000219654013032453,
        0.000219317814346803,
        0.000218988656646756,
        0.00021866267939191,
        0.000218340543167296,
        0.000218024683265277,
        0.000217709699142918,
        0.00021740627044663,
        0.000217102841750341,
        0.000216806732872727,
        0.000216512919670892,
        0.00021622333370687,
        0.00021593887166747,
        0.000215655703589276,
        0.000215380299278321,
        0.000215104894967367,
        0.000214836774936073,
        0.000214570108340863,
        0.000214307931846458,
        0.000214049826600171,
        0.000213793815724045,
        0.000213545439779141,
        0.000213297063834237,
        0.000213055370049964,
        0.000212814369016686,
        0.00021257753598141,
        0.00021234351805298,
        0.000212111397471346,
        0.000211883975081779,
        0.000211656552692212,
        0.000211435212188537,
        0.000211214005486352,
        0.000210996865091232,
        0.000210781734070072,
        0.000210568949130645,
        0.00021036033358647,
        0.000210151971603062,
        0.000209948339503172,
        0.000209744707403282,
        0.00020954457722148,
        0.00020934566670508,
        0.000209148698349941,
        0.000208954278532204,
        0.000208760387808217,
        0.000208570259528677,
        0.000208380131249136,
        0.000208193566946129,
        0.000208007816531995,
        0.000207824449975262,
        0.000207643421077145,
        0.000207463150993792,
        0.000207285999639121,
        0.00020710884828445,
        0.000206935107171597,
        0.000206761802690145,
        0.000206590714329373,
        0.000206421248619673,
        0.000206252815287235,
        0.000206087200269366,
        0.000205921585251497,
        0.000205760013058788,
        0.000205598614130014,
        0.000205440217173112,
        0.000205283438749239,
        0.000205128027190769,
        0.000204975262952075,
        0.000204822626027799,
        0.000204673927893311,
        0.000204525229758824,
        0.000204379501751994,
        0.00020423491672112,
        0.000204092046723699,
        0.000203951615162467,
        0.000203811664116433,
        0.000203675865447009,
        0.000203540066777586,
        0.000203408108477148,
        0.000203277143375585,
        0.000203148220553568,
        0.000203021463536072,
        0.000202895446517695,
        0.000202772891050758,
        0.000202650335583821,
        0.000202531415746654,
        0.000202413053974553,
        0.000202297027843048,
        0.000202182853833558,
        0.000202069854713078,
        0.000201960408012907,
        0.000201850961312736,
        0.00020174599791707,
        0.000201641324068154,
        0.000201539371328492,
        0.000201439015761654,
        0.000201340085857401,
        0.000201244172446545,
        0.000201148313224282,
        0.000201057014875246,
        0.00020096571652621,
        0.000200877820309825,
        0.000200791364550719,
        0.000200707265618541,
        0.000200626801593565,
        0.000200546875878303,
        0.000200472808975743,
        0.000200398742073183,
        0.000200329639676169,
        0.000200261978991624,
        0.000200197605597393,
        0.000200137002918519,
        0.000200077621070956,
        0.000200024816452072,
        0.000199972011833189,
        0.000199926517229558,
        0.000199882339441608,
        0.000199844327192172,
        0.000199811606095788,
        0.00019978167545251,
        0.000199761141719694,
        0.000199740607986877,
        0.000199731621329111,
        0.000199723641969381,
        0.000199724098440952,
        0.000199729935458768,
        0.000199740333157121,
        0.000199761303342173,
        0.000199782273527224,
        0.000199819612513777,
        0.000199857080642749,
        0.000199909419283246,
        0.000199968660568062,
        0.000200035936954253,
        0.000200116675290393,
        0.000200198819625938,
        0.000200301329649831,
        0.000200403839673724,
        0.00020052387080789,
        0.000200649575876729,
        0.000200786254254552,
        0.00020093655301151,
        0.000201090408014908,
        0.000201266667370873,
        0.000201442926726838,
        0.00020164693178497,
        0.000201856708998635,
        0.000202082315666143,
        0.000202322616845521,
        0.000202569298123725,
        0.00020284006874703,
        0.000203110839370335,
        0.000203409986717346,
        0.000203712280179377,
        0.000204033785043592,
        0.000204368585215941,
        0.000204712803032829,
        0.000205081022763226,
        0.000205449242493624,
        0.000205858843705319,
        0.000206269610592543,
        0.000206704366648947,
        0.000207151294201218,
        0.000207610972814619,
        0.000208093899219027,
        0.000208578467946045,
        0.000209097882474504,
        0.000209617297002964,
        0.000210163855596826,
        0.000210720160065236,
        0.000211292334392631,
        0.000211885838392886,
        0.000212484853947511,
        0.000213125298019156,
        0.000213765742090802,
        0.000214436342965348,
        0.00021511409969143,
        0.000215810317775301,
        0.000216525069614443,
        0.000217246742629908,
        0.000217997955511329,
        0.00021874916839275,
        0.000219531803703943,
        0.000220318698500062,
        0.000221125776118586,
        0.000221947990853914,
        0.000222780702994833,
        0.000223642948881333,
        0.000224505194767833,
        0.000225395531828983,
        0.000226287253260136,
        0.000227196000586105,
        0.000228114166492185,
        0.000229039960001378,
        0.000229980919505062,
        0.000230922349020108,
        0.00023188178358581,
        0.000232841218151512,
        0.000233810417976696,
        0.00023478348655222,
        0.000235758799559412,
        0.000236737382099275,
        0.000237715437531063,
        0.000238688633026399,
        0.000239661828521735,
        0.00024062359010255,
        0.000241582287507762,
        0.000242529187859978,
        0.000243463279983227,
        0.000244391292284552,
        0.000245289674189373,
        0.000246188056094193,
        0.000247041563206147,
        0.000247887829158342,
        0.000248689811368131,
        0.000249455832666534,
        0.000250200947164118,
        0.00025088081221076,
        0.000251560677257402,
        0.000252145254830345,
        0.00025272304042403,
        0.000253226844667207,
        0.000253686120950204,
        0.00025410276062841,
        0.000254426712035328,
        0.000254749680232945,
        0.000254906158322218,
        0.000255062636411492,
        0.000255074786952388,
        0.000255024098251716,
        0.000254897104726594,
        0.000254649487918376,
        0.00025438527813475,
        0.000253925686065691,
        0.000253466093996631,
        0.000252832809555784,
        0.000252147373206868,
        0.000251352324908476,
        0.000250428540873678,
        0.000249463017922746,
        0.000248262239262703,
        0.00024706146060266,
        0.000245608930612002,
        0.000244108977696273,
        0.000242467896104711,
        0.000240702808066369,
        0.000238878433365831,
        0.000236847595464494,
        0.000234816757563156,
        0.00023254494552298,
        0.000230250429281286,
        0.000227799376482782,
        0.000225245992342934,
        0.000222609597455605,
        0.000219775296150444,
        0.000216940994845282,
        0.000213848306558107,
        0.000210752029562849,
        0.00020751151176366,
        0.000204202208609996,
        0.000200821681705679,
        0.000197318786336146,
        0.000193805789483209])
else:
    # The file holds one value per line. np.loadtxt's default whitespace splitting
    # already yields one value per row. Passing delimiter='\n' is rejected by NumPy >= 1.23
    # (newline control characters are not allowed). ndmin=1 keeps a single-value file
    # 1-D, so the [0:3] slice below still works.
    print("Loading Negative Control array from file:", emission_NEG_fn)
    emission_NEG_array = np.loadtxt(emission_NEG_fn, ndmin=1)
print("emission_NEG_array: ",emission_NEG_array[0:3])
if emission_PGC_fn is None:
    print("Using default Positive Control array...")
    mission_PGC_array = np.array([0.000218394723023875,
	0.000222978465201427,
	0.000227555235612708,
	0.000232042256740602,
	0.000236477039895504,
	0.000240859072758921,
	0.000245130489859589,
	0.000249399045315033,
	0.00025346156231431,
	0.000257524079313588,
	0.000261423437111359,
	0.000265254358688498,
	0.000268998994625546,
	0.000272611639162654,
	0.000276204832147616,
	0.000279591573431396,
	0.000282978314715176,
	0.000286185689607641,
	0.000289341558785724,
	0.000292388346455785,
	0.000295310991154374,
	0.000298191360734618,
	0.000300847306847099,
	0.000303503252959581,
	0.000305946590457699,
	0.000308352207072622,
	0.000310640154361108,
	0.000312827912186866,
	0.000314967507226035,
	0.00031694669569631,
	0.000318925884166584,
	0.000320722603463631,
	0.000322503826440235,
	0.000324170972597796,
	0.000325765955144589,
	0.000327300402987551,
	0.000328695816284937,
	0.000330091229582323,
	0.000331322025577324,
	0.000332551849422903,
	0.000333687359540695,
	0.000334779485343705,
	0.000335824760758471,
	0.000336792196205269,
	0.000337752168342873,
	0.000338607434642682,
	0.000339462700942492,
	0.000340242041926293,
	0.00034099705541866,
	0.000341705810617947,
	0.000342357601456322,
	0.000342998463121428,
	0.000343571594698672,
	0.000344144726275916,
	0.000344663695707856,
	0.000345171551680735,
	0.000345649015912825,
	0.000346098487752084,
	0.000346536881064369,
	0.000346933935058412,
	0.000347330989052454,
	0.000347685332563613,
	0.000348035043701356,
	0.000348354352914425,
	0.0003486527921594,
	0.000348939900488314,
	0.000349198408022205,
	0.000349456915556097,
	0.000349680405036328,
	0.00034990298039066,
	0.000350102932050512,
	0.000350291505322845,
	0.000350468591594087,
	0.000350624912901818,
	0.000350779797406728,
	0.000350905526787732,
	0.000351031256168735,
	0.0003511312685659,
	0.000351222140178359,
	0.000351301922414389,
	0.000351366918815182,
	0.000351429231461231,
	0.00035147174591655,
	0.000351514260371869,
	0.000351541249832706,
	0.000351564601794836,
	0.000351580303151066,
	0.000351588383786576,
	0.000351594400336041,
	0.00035159171823718,
	0.000351589036138319,
	0.000351582688323641,
	0.000351575852830583,
	0.00035157224358863,
	0.000351571034765586,
	0.000351573367225251,
	0.00035158556281305,
	0.000351597758400849,
	0.000351630877968515,
	0.000351664983707363,
	0.000351719192014218,
	0.00035178442627504,
	0.000351863877322186,
	0.000351971348663306,
	0.000352080714295238,
	0.000352257327250995,
	0.000352433940206751,
	0.000352661934264006,
	0.000352910088753456,
	0.000353192630472559,
	0.000353524860915848,
	0.000353866773759944,
	0.000354296018062288,
	0.000354725262364633,
	0.000355241714232557,
	0.000355781263153741,
	0.000356383393728729,
	0.000357052939499887,
	0.00035775241829841,
	0.000358595779717108,
	0.000359439141135807,
	0.000360421662612702,
	0.000361426272987725,
	0.000362527833131959,
	0.000363707501217466,
	0.000364933313022451,
	0.000366301616633184,
	0.000367669920243917,
	0.000369227091459414,
	0.000370797301651808,
	0.00037251128922475,
	0.000374311093921705,
	0.00037619847690961,
	0.000378274531647771,
	0.000380352507520218,
	0.000382673987163385,
	0.000384995466806552,
	0.000387495041689219,
	0.000390071441969319,
	0.000392750388190677,
	0.000395590109958127,
	0.000398451538494382,
	0.000401561800721822,
	0.000404672062949262,
	0.000408013982270395,
	0.000411424693165034,
	0.00041498755249337,
	0.000418727703428767,
	0.000422511009486048,
	0.000426533881796863,
	0.000430556754107679,
	0.000434816908899611,
	0.00043912110684349,
	0.000443573871068032,
	0.000448156154614176,
	0.000452799868440732,
	0.000457655116102447,
	0.000462510363764161,
	0.000467633555107455,
	0.000472781365207816,
	0.000478103726356244,
	0.000483539264151966,
	0.000489044531073141,
	0.000494714487178773,
	0.000500384443284405,
	0.000506270514427137,
	0.00051215915327328,
	0.000518184431539964,
	0.000524274288668231,
	0.000530431498930391,
	0.000536703458976223,
	0.000542986191924638,
	0.000549435731106588,
	0.000555885270288539,
	0.00056244211671957,
	0.000569034452787666,
	0.000575670102389663,
	0.000582360373041047,
	0.00058906029933304,
	0.000595823129283702,
	0.000602585959234365,
	0.000609387289704112,
	0.000616196850688627,
	0.000623017018977834,
	0.000629847189592377,
	0.000636674815598702,
	0.000643492602453251,
	0.000650310389307799,
	0.000657083288111348,
	0.000663850992446114,
	0.000670575539320763,
	0.000677269732115129,
	0.00068393694657477,
	0.000690534055361778,
	0.000697131164148786,
	0.000703612482952332,
	0.000710090059049341,
	0.000716474820420568,
	0.000722811659648513,
	0.000729081101137085,
	0.000735225536762686,
	0.000741361632973395,
	0.000747303142836627,
	0.00075324465269986,
	0.000759037728611016,
	0.000764776431657374,
	0.0007704243176846,
	0.000775948171208844,
	0.000781446480169015,
	0.000786745277275307,
	0.0007920440743816,
	0.000797155130133698,
	0.000802220506724191,
	0.000807143865374689,
	0.000811922399406534,
	0.000816652484961936,
	0.000821170417188913,
	0.000825688349415891,
	0.000829990924892564,
	0.000834263215826226,
	0.000838396723650937,
	0.000842424469175378,
	0.000846389693983552,
	0.000850175408553351,
	0.000853961123123149,
	0.00085752077684141,
	0.000861068311833609,
	0.000864436510307483,
	0.000867703797712943,
	0.000870894638391669,
	0.000873930776226943,
	0.000876962325502774,
	0.000879785303894813,
	0.000882608282286852,
	0.000885285577662803,
	0.000887904039513025,
	0.000890443689823824,
	0.00089286665314664,
	0.000895272124559082,
	0.000897508871919366,
	0.00089974561927965,
	0.000901813936052142,
	0.000903836049920251,
	0.000905779619472603,
	0.000907636564018275,
	0.000909468061593912,
	0.000911172024943254,
	0.000912875988292595,
	0.00091445560642861,
	0.000916014500362404,
	0.000917497882505659,
	0.00091891896742172,
	0.000920308710335504,
	0.000921598518504327,
	0.00092288832667315,
	0.000924033142749313,
	0.000925166969951085,
	0.000926226866763281,
	0.000927241519109932,
	0.00092821979320619,
	0.000929117531804726,
	0.000930015044059764,
	0.000930797121131362,
	0.000931579198202959,
	0.000932281330170592,
	0.000932948006258782,
	0.000933569173117553,
	0.000934117200143328,
	0.000934654305182335,
	0.000935055460228136,
	0.000935456615273937,
	0.000935763390724809,
	0.000936041200571098,
	0.00093626084924554,
	0.000936411103903786,
	0.000936541953841058,
	0.000936559997667074,
	0.000936578041493089,
	0.00093648115659902,
	0.000936361984126056,
	0.00093616416073995,
	0.000935896134131372,
	0.000935588264732996,
	0.000935138455395966,
	0.000934688646058935,
	0.000934096579935242,
	0.000933490441097169,
	0.000932787712683358,
	0.000932020800968407,
	0.000931205878566927,
	0.000930274312752543,
	0.000929342746938158,
	0.000928246243065466,
	0.000927146782135951,
	0.000925925208964196,
	0.000924644354558562,
	0.000923288912725867,
	0.000921803143399651,
	0.000920307974493325,
	0.000918650662830106,
	0.000916993351166887,
	0.0009152094354125,
	0.000913382321411783,
	0.000911482304643578,
	0.000909488135422295,
	0.000907473192949486,
	0.000905315714012363,
	0.00090315823507524,
	0.000900860641103677,
	0.000898531873212969,
	0.000896114147362088,
	0.000893610542865042,
	0.000891078966862419,
	0.000888435230607041,
	0.000885791494351664,
	0.00088303055955132,
	0.000880255200314673,
	0.000877408637360903,
	0.0008745107693719,
	0.000871582203438364,
	0.00086857148185471,
	0.000865560760271057,
	0.000862442235678448,
	0.000859319547961511,
	0.000856131612852902,
	0.000852909100533736,
	0.000849661175678016,
	0.000846364880369195,
	0.000843066372575779,
	0.000839707267287603,
	0.000836348161999428,
	0.000832951387900238,
	0.000829540399155383,
	0.00082611229807238,
	0.000822660255783891,
	0.000819204740267805,
	0.000815720428537459,
	0.000812236116807113,
	0.000808738418615214,
	0.000805237340997972,
	0.00080173458505071,
	0.000798230076926907,
	0.00079472662769909,
	0.000791228001051902,
	0.000787729374404713,
	0.000784243646654687,
	0.00078075983199952,
	0.000777289130977434,
	0.000773828664986046,
	0.000770377756776348,
	0.000766955151455066,
	0.000763532546133784,
	0.000760154654254514,
	0.000756779452603083,
	0.000753436291569518,
	0.000750111623114781,
	0.000746805389943305,
	0.000743537471065353,
	0.000740270531434388,
	0.000737065161957817,
	0.000733859792481245,
	0.000730702415869147,
	0.000727564980890488,
	0.000724461886173006,
	0.00072141089392076,
	0.000718367500589473,
	0.000715402798306326,
	0.00071243809602318,
	0.000709538448109523,
	0.0007066572704063,
	0.000703816785435171,
	0.000701022249340699,
	0.00069824205199385,
	0.000695536870777856,
	0.000692831689561863,
	0.000690204540288889,
	0.000687591024182971,
	0.000685038135088265,
	0.000682536461591478,
	0.000680058425289435,
	0.000677658252687724,
	0.000675258080086013,
	0.000672946596691343,
	0.000670642441430305,
	0.000668397370152596,
	0.000666189367488631,
	0.000664010649210049,
	0.000661898571680254,
	0.000659786494150459,
	0.000657769388226167,
	0.000655752655471249,
	0.00065381422052602,
	0.000651911476890052,
	0.000650044947545228,
	0.000648238087343326,
	0.000646437388905121,
	0.000644720490453201,
	0.000643003592001282,
	0.000641353380258228,
	0.000639724308889376,
	0.00063813372187888,
	0.000636590152714898,
	0.000635058279575077,
	0.000633597719693693,
	0.000632137159812308,
	0.000630758190380054,
	0.000629395731015154,
	0.000628073694572396,
	0.0006267885989274,
	0.000625519040278574,
	0.000624306783211952,
	0.000623094526145331,
	0.00062194557330668,
	0.000620803334552815,
	0.000619701090730542,
	0.000618626080299188,
	0.000617569494336162,
	0.000616558969543957,
	0.000615548444751751,
	0.000614612644878304,
	0.00061367864199689,
	0.000612783096373646,
	0.000611906722990546,
	0.000611050160545483,
	0.000610229106229797,
	0.000609410650973877,
	0.000608642721439824,
	0.00060787479190577,
	0.000607144622295878,
	0.000606427738585229,
	0.000605732412931554,
	0.000605065602707787,
	0.000604405772218546,
	0.000603796488516437,
	0.000603187204814327,
	0.000602614012411541,
	0.000602049168839442,
	0.000601505899700324,
	0.000600983953152558,
	0.000600470029609514,
	0.000599989494732583,
	0.000599508959855652,
	0.000599063783974608,
	0.000598623223537716,
	0.000598205754717613,
	0.0005978053297099,
	0.000597416736030119,
	0.000597060765497211,
	0.000596704794964303,
	0.000596382212131776,
	0.000596061131404379,
	0.000595761745640372,
	0.000595474157741653,
	0.000595197426671371,
	0.000594941907208094,
	0.000594687286349716,
	0.000594462379360029,
	0.000594237472370342,
	0.000594034773035221,
	0.000593840703305,
	0.000593659802593801,
	0.000593497777476533,
	0.00059333828213923,
	0.000593201116975081,
	0.000593063951810931,
	0.000592945168400905,
	0.000592831196017883,
	0.000592727614919201,
	0.000592635140294471,
	0.000592546110701244,
	0.000592473415976495,
	0.000592400721251746,
	0.000592344776058365,
	0.000592291445610832,
	0.000592248202997888,
	0.000592213023548508,
	0.000592181148339416,
	0.000592159369420095,
	0.000592137590500773,
	0.000592126800707829,
	0.000592116745052764,
	0.000592113027440863,
	0.000592113061318685,
	0.000592115808770647,
	0.00059212434962016,
	0.000592132963919072,
	0.000592149011297151,
	0.000592165058675231,
	0.00059218506429714,
	0.000592206761570742,
	0.00059222979406909,
	0.000592254902841916,
	0.000592280199499509,
	0.000592307594203675,
	0.00059233498890784,
	0.000592363361839817,
	0.000592392022048822,
	0.00059242084602359,
	0.000592449859330927,
	0.000592478776580631,
	0.000592507168553147,
	0.000592535560525664,
	0.000592562660181378,
	0.000592589523510639,
	0.000592615391265963,
	0.000592640397886167,
	0.000592664917198194,
	0.000592687777141864,
	0.000592710637085533,
	0.000592731357244762,
	0.000592751885781504,
	0.00059277094518019,
	0.00059278905979556,
	0.000592806361907693,
	0.000592821762868463,
	0.000592837163829234,
	0.000592850237690147,
	0.000592863288556002,
	0.000592875123394168,
	0.000592886388656147,
	0.000592897131039091,
	0.000592906990031666,
	0.000592916788553247,
	0.000592925681899632,
	0.000592934575246018,
	0.000592943115622039,
	0.000592951540480486,
	0.000592960081058767,
	0.000592968766414283,
	0.000592977606857003,
	0.000592987440756889,
	0.000592997274656774,
	0.000593008599023902,
	0.000593020237731278,
	0.000593033169367441,
	0.000593047310675612,
	0.000593062119970463,
	0.000593079481478572,
	0.000593096842986681,
	0.00059311775683079,
	0.000593139073118193,
	0.000593163828504855,
	0.000593190983263093,
	0.000593219902557034,
	0.000593253362754043,
	0.000593286822951052,
	0.00059332685047335,
	0.000593367076604277,
	0.000593412291057477,
	0.000593460058705214,
	0.000593510756564866,
	0.000593566842758649,
	0.000593623320565975,
	0.000593688502874333,
	0.000593753685182692,
	0.000593827851214409,
	0.00059390527534313,
	0.000593987627763663,
	0.000594076657224713,
	0.000594167065531452,
	0.000594268037198389,
	0.000594369008865326,
	0.000594480222438743,
	0.000594593897052795,
	0.000594714260684274,
	0.000594841392337322,
	0.00059497119236053,
	0.000595112527525574,
	0.000595253862690618,
	0.000595411641843845,
	0.000595571691659137,
	0.000595741117782164,
	0.000595917632522721,
	0.000596098527511738,
	0.000596291871627953,
	0.000596485215744169,
	0.000596695200324178,
	0.000596906040895123,
	0.000597128526129575,
	0.000597357508234215,
	0.000597592732321052,
	0.000597840477089824,
	0.000598088797854742,
	0.000598361070459372,
	0.000598633343064003,
	0.000598920079429399,
	0.000599212601299102,
	0.000599513517632457,
	0.000599826762067023,
	0.000600142035138305,
	0.000600476435355776,
	0.000600810835573247,
	0.000601162187075376,
	0.000601518134821354,
	0.000601884764312424,
	0.000602263082434395,
	0.000602646027435377,
	0.000603051837826823,
	0.000603457648218269,
	0.000603882833388092,
	0.000604311196198104,
	0.000604752038390955,
	0.000605203095025277,
	0.000605659637664174,
	0.0006061334857484,
	0.000606607333832627,
	0.000607102463829317,
	0.000607599158672513,
	0.000608110684005642,
	0.000608631210233605,
	0.000609160394462458,
	0.00060970857218076,
	0.00061025683868796,
	0.000610827701969662,
	0.000611398565251363,
	0.000611985103359496,
	0.000612578529560371,
	0.000613180600633589,
	0.000613796450843435,
	0.000614414000924118,
	0.000615052123800565,
	0.000615690246677013,
	0.000616346404411806,
	0.000617008037014311,
	0.000617681610837895,
	0.000618369320209508,
	0.000619060257335116,
	0.000619769669105301,
	0.000620479080875487,
	0.000621206608104756,
	0.000621937597962568,
	0.000622679976443195,
	0.000623432440789044,
	0.000624189636083787,
	0.000624963494366673,
	0.000625737352649559,
	0.000626532326089027,
	0.000627329338545763,
	0.000628141395567509,
	0.000628963368325284,
	0.000629791616267602,
	0.000630634966487383,
	0.000631478316707164,
	0.00063234278601301,
	0.000633207591208892,
	0.000634086951766792,
	0.000634973315774242,
	0.000635867616157656,
	0.000636775667267338,
	0.000637685091154005,
	0.000638617331540917,
	0.000639549571927829,
	0.000640500795643636,
	0.000641458429960038,
	0.000642425853336763,
	0.000643405818947027,
	0.000644388693738128,
	0.000645391183460628,
	0.000646393673183129,
	0.00064741478645544,
	0.00064843998890814,
	0.000649476884714488,
	0.00065052498105115,
	0.000651578325213818,
	0.000652652456881438,
	0.000653726588549059,
	0.000654824140865433,
	0.000655924518040757,
	0.000657038540734793,
	0.000658162316769386,
	0.000659292524636586,
	0.000660439776860197,
	0.000661587029083808,
	0.000662756983452414,
	0.000663927766702582,
	0.000665113998121805,
	0.000666308345594024,
	0.00066751207476046,
	0.000668733507236887,
	0.00066995590878459,
	0.000671203404177973,
	0.000672450899571356,
	0.000673715525707908,
	0.000674986552847955,
	0.000676267427286205,
	0.000677561968301867,
	0.000678859085408484,
	0.000680177133254108,
	0.000681495181099732,
	0.000682832050822501,
	0.000684173614509742,
	0.000685528725696515,
	0.000686897870149738,
	0.000688271695422548,
	0.000689666559217806,
	0.000691061423013064,
	0.000692477140231675,
	0.000693895896626334,
	0.000695328252400697,
	0.000696771138250316,
	0.00069822023839242,
	0.00069968755119173,
	0.00070115486399104,
	0.000702645606567474,
	0.000704137707391858,
	0.000705649025642396,
	0.000707171342173749,
	0.000708702736467397,
	0.000710252829864218,
	0.00071180339575871,
	0.000713380316523326,
	0.000714957237287942,
	0.00071655368033604,
	0.000718158158799163,
	0.000719773953098194,
	0.000721406776620874,
	0.000723042227410737,
	0.00072470423644495,
	0.000726366245479162,
	0.000728056991530223,
	0.000729755804112328,
	0.000731470166148265,
	0.000733201948734139,
	0.000734939036080134,
	0.000736703478092782,
	0.00073846792010543,
	0.000740261067766464,
	0.000742059153325729,
	0.000743876075764799,
	0.000745708785948454,
	0.000747549855339688,
	0.000749418161815614,
	0.000751286468291541,
	0.000753196847242768,
	0.000755110606223725,
	0.000757048307346835,
	0.000759000905802841,
	0.000760965451607507,
	0.000762956936019976,
	0.000764948420432445,
	0.000766979498383766,
	0.000769010653969476,
	0.000771069646335534,
	0.000773141212701806,
	0.000775228384091265,
	0.000777341054410367,
	0.000779457434111872,
	0.000781622757053423,
	0.000783788079994974,
	0.000785985910860532,
	0.00078819393701331,
	0.000790421502281194,
	0.000792672751324963,
	0.000794930246744883,
	0.000797225220421711,
	0.000799520194098538,
	0.000801852028630546,
	0.000804191218028024,
	0.000806554359270955,
	0.000808939218045397,
	0.000811336138235996,
	0.000813777027770949,
	0.000816217917305902,
	0.000818700617387576,
	0.00082118765203997,
	0.0008237026146183,
	0.000826236439992509,
	0.000828783933222733,
	0.000831365270667848,
	0.000833946608112963,
	0.00083657522461193,
	0.000839204881250082,
	0.000841868926497476,
	0.000844549964946513,
	0.000847253165795237,
	0.000849995753082534,
	0.000852741007488514,
	0.000855536013916083,
	0.000858331020343653,
	0.000861166027508755,
	0.000864014966329145,
	0.000866887954210278,
	0.000869792499682431,
	0.000872704127538948,
	0.000875666119017602,
	0.000878628110496256,
	0.000881641763355293,
	0.000884667213885331,
	0.000887730403543006,
	0.000890830600949875,
	0.000893943704542551,
	0.000897109852558889,
	0.000900276000575227,
	0.000903503089578598,
	0.000906737981180751,
	0.000910014263986748,
	0.000913320853572508,
	0.000916647499598504,
	0.000920028898242369,
	0.000923410296886234,
	0.000926873346205817,
	0.000930339894840074,
	0.000933869674649324,
	0.000937433545965573,
	0.000941027278840493,
	0.000944678847033066,
	0.000948333281424669,
	0.000952076388848754,
	0.000955819496272838,
	0.000959631558690689,
	0.000963470156634597,
	0.000967349859834764,
	0.000971288006527242,
	0.000975238125670542,
	0.000979291704860575,
	0.000983345284050607,
	0.000987499440167837,
	0.000991679607559342,
	0.000995915889907207,
	0.00100021168126121,
	0.00100452866883133,
	0.00100894480759017,
	0.00101336094634901,
	0.00101788565644797,
	0.00102242703212508,
	0.00102704108531194,
	0.00103171277048897,
	0.00103642272951668,
	0.00104124841420309,
	0.00104607409888951,
	0.00105105260344727,
	0.00105604097762172,
	0.00106112485716333,
	0.00106626479420305,
	0.00107145552644442,
	0.00107675373389904,
	0.00108205388358284,
	0.00108751750422307,
	0.0010929811248633,
	0.0010985663035624,
	0.00110420294974046,
	0.00110992229276355,
	0.00111576917818447,
	0.00112163425235141,
	0.00112769728775417,
	0.00113376032315694,
	0.00113998384159905,
	0.0011462539649642,
	0.00115262574783751,
	0.00115911414014664,
	0.00116563858159329,
	0.0011723572377683,
	0.0011790758939433,
	0.00118600055168519,
	0.00119296231825438,
	0.00120008808260684,
	0.00120735458687513,
	0.00121469099641726,
	0.0012222628134627,
	0.00122983463050815,
	0.00123768292890811,
	0.00124555534589431,
	0.0012536226889269,
	0.00126181435979035,
	0.00127010824587836,
	0.00127863908544147,
	0.00128716992500457,
	0.00129605834477527,
	0.00130494958569159,
	0.00131415786849641,
	0.00132351331687698,
	0.0013330373841918,
	0.00134284395443312,
	0.00135267989943836,
	0.00136294133344842,
	0.00137320276745847,
	0.00138383095663463,
	0.00139457791367547,
	0.00140555615760454,
	0.00141682148131154,
	0.00142816263080447,
	0.00143998148274867,
	0.00145180033469287,
	0.00146422508859409,
	0.00147677589302065,
	0.00148968451343801,
	0.00150292529511478,
	0.00151631538733128,
	0.00153026923149225,
	0.00154422307565322,
	0.0015588661124482,
	0.00157358555973087,
	0.00158879086911441,
	0.00160433242081279,
	0.00162012273267228,
	0.00163654703744233,
	0.00165297134221239,
	0.00167056129536625,
	0.00168818408368585,
	0.00170653269115108,
	0.00172524956083767,
	0.00174437577825596,
	0.00176424832026792,
	0.00178417680133417,
	0.00180529216320438,
	0.00182640752507459,
	0.00184850498483911,
	0.00187095505953909,
	0.00189401571114457,
	0.00191789700051755,
	0.00194200954848064,
	0.00196786039017938,
	0.00199371123187812,
	0.00202094638583134,
	0.00204851002065852,
	0.00207698294190708,
	0.00210636871540614,
	0.00213612102092552,
	0.00216743770022235,
	0.00219875437951919,
	0.00223186562287322,
	0.002265220129355,
	0.00229982604316085,
	0.00233537051245975,
	0.00237164959623964,
	0.00240999546852203,
	0.00244834134080443,
	0.00248894817437497,
	0.00252966643108909,
	0.00257194806075135,
	0.00261509453502322,
	0.00265906570390398,
	0.00270467661645154,
	0.00275035058369873,
	0.00279844003097806,
	0.0028465294782574,
	0.00289645928223499,
	0.00294711818927611,
	0.00299895420706446,
	0.0030525049582971,
	0.00310627467351356,
	0.00316206323666071,
	0.00321785179980787,
	0.00327526927706371,
	0.0033331232871328,
	0.00339186369933182,
	0.00345156649098618,
	0.00351153936333938,
	0.00357282849105249,
	0.0036341176187656,
	0.00369654540506517,
	0.00375915692945482,
	0.00382224091084585,
	0.00388570855412807,
	0.00394920891838826,
	0.00401281140376521,
	0.00407641388914217,
	0.00413964933120217,
	0.00420285861084546,
	0.00426546685055729,
	0.00432771333582761,
	0.00438946493209569,
	0.00445014068270667,
	0.00451080183169063,
	0.0045689904384967,
	0.00462717904530277,
	0.0046827504948057,
	0.0047371824516626,
	0.00479003628663365,
	0.00484039541377503,
	0.00489038093094737,
	0.00493596719073419,
	0.00498155345052102,
	0.00502294128211577,
	0.00506306851695057,
	0.00510037945365653,
	0.00513438273809201,
	0.00516724582301342,
	0.00519368232885554,
	0.00522011883469766,
	0.00523922479227554,
	0.00525694987913387,
	0.0052703464034534,
	0.00527993952171863,
	0.00528763827575224,
	0.00528873998943537,
	0.0052898417031185,
	0.00528293164700858,
	0.00527526673464406,
	0.00526218784906826,
	0.00524556973181389,
	0.00522594188059144,
	0.00519913850467686,
	0.00517233512876227,
	0.00513561008799325,
	0.00509874724632348,
	0.00505603020799238,
	0.00501052142847885,
	0.00496196266345832,
	0.0049081637637635,
	0.00485390439145447])
else:
    # read file into np.array. File has one element on each new line.
    print("Loading Positive Control array from file:", emission_PGC_fn)
    emission_PGC_array = np.loadtxt(emission_PGC_fn, delimiter='\n')
print("emission_PGC_array: ",emission_PGC_array[0:3])

def single_fiber_nuc(read_id, group, LINK_width,NUC_width, metadata_cols, emission_NEG_array, emission_PGC_array):
    """Call nucleosome and linker regions on a single read with a 148-state Viterbi decoder.

    Model, as implemented below:
      * state 0 is a linker state that emits from ``emission_PGC_array``;
      * states 1..147 form a forced chain (state j is only entered from state j-1) that emits
        from ``emission_NEG_array``;
      * state 147 returns to state 0.
    There are no explicit transition probabilities; only emission log-probabilities are
    summed. Positions with no call (no row for that rel_pos) add no emission term.

    Args:
        read_id: identifier of the read; returned unchanged as the first element.
        group: rows for this read with 'rel_pos' and 'mod_qual' columns plus the columns in
            ``metadata_cols``. 'rel_pos' is used directly as an array offset, so it is
            presumably integer-valued - TODO confirm dtype.
        LINK_width: not used by this variant (kept so the signature matches
            single_fiber_nuc_linker).
        NUC_width: minimum reported nucleosome width; shorter calls are widened to the left.
        metadata_cols: columns copied from the read's first row into the result.
        emission_NEG_array: emission probabilities for the nucleosome states, indexed by
            int(mod_qual * 1000). Assumes mod_qual is in [0, 1] and the array has at least
            1001 entries - TODO confirm.
        emission_PGC_array: emission probabilities for the linker state, indexed the same way.

    Returns:
        (read_id, mid_NUC_list, mid_LINK_list, regions_LINK_list, regions_NUC_list,
        group.iloc[0][metadata_cols]).
        Regions are (start, end) pairs in rel_pos coordinates, and midpoints are rounded to
        whole numbers. Linker regions are only reported between consecutive nucleosomes.
    """
    # Initialize lists to store information about linker and nucleosome regions and their midpoints
    regions_LINK_list = []
    regions_NUC_list = []
    mid_NUC_list = []
    mid_LINK_list = []

    # Calculate the total number of modified bases, total bases, and minimum base position in the read
    # (MOD_BASE_NUM is computed but not used below.)
    MOD_BASE_NUM = len(group['rel_pos'])
    BASE_NUM = max(group['rel_pos']) - min(group['rel_pos']) + 1
    BASE_MIN = min(group['rel_pos'])

    # Initialize the calling_vec with -1 (= no call at this position) and populate it with mod_qual values based on relative position.
    # Filled indices are 0..BASE_NUM-1; index BASE_NUM always stays -1.
    # NOTE(review): the DP below reads calling_vec[2..BASE_NUM], so the calls at indices 0 and 1
    # (the first two bases) are never used. Backtracking also maps DP row r to rel_pos
    # BASE_MIN + r - 2, while row r consumed the call at BASE_MIN + r. This looks like an index
    # offset left over from a 1-based reference implementation; verify before relying on
    # bp-level positions.
    calling_vec = np.full(BASE_NUM+1, -1.0)
    for i in range(len(group['rel_pos'])):
        calling_vec[group.iloc[i]['rel_pos'] - BASE_MIN] = group.iloc[i]['mod_qual']

    # Initialize probability matrix (prob_mat) and pointer matrix (ptr_mat)
    # prob_mat stores log probabilities, and ptr_mat stores the previous state index for backtracking.
    # A value of 0 in prob_mat is used below as an "unset" sentinel (real log-probabilities here are negative).
    prob_mat = np.zeros((BASE_NUM+1, 148))
    ptr_mat = np.full((BASE_NUM+1, 148), -1, dtype=int)

    # Initialization of first row in prob_mat and ptr_mat: uniform start over all 148 states
    initial_rate = 1 / 148.0
    log_initial_rate = np.log(initial_rate)
    prob_mat[1, :] = log_initial_rate
    ptr_mat[1, :] = 0

    # Dynamic Programming Step: Fill the prob_mat and ptr_mat
    for i in range(2, BASE_NUM+1):
        within_linker = 0.0
        back_frm_ncls = 0.0

        # Compute probabilities for staying within linker (0 -> 0) and going back to linker from nucleosome (147 -> 0)
        if calling_vec[i] == -1:
            # No call here: carry scores forward without an emission term
            within_linker = prob_mat[i-1, 0]
            if prob_mat[i-1, 147] != 0:
                back_frm_ncls = prob_mat[i-1, 147]
        else:
            # mod_qual is binned to 0.001 resolution to index the emission table
            k = int(calling_vec[i] * 1000)
            within_linker = np.log(emission_PGC_array[k]) + prob_mat[i-1, 0]
            if prob_mat[i-1, 147] != 0:
                back_frm_ncls = np.log(emission_PGC_array[k]) + prob_mat[i-1, 147]

        # Update first column of prob_mat and ptr_mat based on calculated probabilities (keep the better predecessor)
        if back_frm_ncls != 0 and back_frm_ncls > within_linker:
            prob_mat[i, 0] = back_frm_ncls
            ptr_mat[i, 0] = 147
        else:
            prob_mat[i, 0] = within_linker
            ptr_mat[i, 0] = 0

        # Update second column of prob_mat and ptr_mat: entering the nucleosome chain from the linker (0 -> 1)
        if calling_vec[i] == -1:
            prob_mat[i, 1] = prob_mat[i-1, 0]
        else:
            k = int(calling_vec[i] * 1000)
            prob_mat[i, 1] = np.log(emission_NEG_array[k]) + prob_mat[i-1, 0]
        ptr_mat[i, 1] = 0

        # Update the remaining columns of prob_mat and ptr_mat: advance along the nucleosome chain (j-1 -> j)
        for j in range(2, 148):
            if calling_vec[i] == -1:
                if prob_mat[i-1, j-1] != 0:
                    prob_mat[i, j] = prob_mat[i-1, j-1]
            else:
                k = int(calling_vec[i] * 1000)
                if prob_mat[i-1, j-1] != 0:
                    prob_mat[i, j] = np.log(emission_NEG_array[k]) + prob_mat[i-1, j-1]
            if prob_mat[i, j] != 0:
                ptr_mat[i, j] = j-1

    # Backtrack to identify most probable states.
    # backtrack_vec[r-2] ends up holding the chosen state at DP row r (rows 2..BASE_NUM).
    max_index = np.argmax(prob_mat[BASE_NUM, :])
    backtrack_vec = []
    for i in range(BASE_NUM, 1, -1):
        backtrack_vec.append(max_index)
        max_index = ptr_mat[i, max_index]
    backtrack_vec.reverse()

    # Identify nucleosome and linker regions based on backtracking results (any state > 0 counts as nucleosome)
    ncls_start = 0
    ncls_end = 0
    shift = min(group['rel_pos']) - 1
    InNucleosome = False
    for i, val in enumerate(backtrack_vec):
        if val > 0:
            if not InNucleosome:
                ncls_start = i + 1 + shift
                InNucleosome = True
        else:
            if InNucleosome:
                ncls_end = i + 1 + shift
                ncls_mid = round((ncls_end-ncls_start)/2 + ncls_start, 0)
                # Widen short calls to NUC_width by moving the start to the left
                if ncls_end - ncls_start < NUC_width:
                    ncls_start = ncls_end - NUC_width
                    ncls_mid = round((ncls_end-ncls_start)/2 + ncls_start, 0)
                regions_NUC_list.append((ncls_start, ncls_end))
                mid_NUC_list.append(ncls_mid)
                InNucleosome = False

    # Check if the nucleosome extends to the end of the read; if so report it as NUC_width wide from its start
    if InNucleosome:
        ncls_end = ncls_start + NUC_width
        ncls_mid = round((ncls_end-ncls_start)/2 + ncls_start, 0)
        regions_NUC_list.append((ncls_start, ncls_end))
        mid_NUC_list.append(ncls_mid)

    # Infer linker regions and their midpoints (only the gaps between consecutive nucleosomes)
    for i in range(len(regions_NUC_list)-1):
        regions_LINK_list.append((regions_NUC_list[i][1], regions_NUC_list[i+1][0]))
        mid_LINK_list.append(round((regions_NUC_list[i][1] + regions_NUC_list[i+1][0])/2, 0))

    return read_id, mid_NUC_list, mid_LINK_list, regions_LINK_list, regions_NUC_list, group.iloc[0][metadata_cols]

# Version that requires a linker run before each nucleosome (does not allow nucleosomes to be stuck next to each other)
def single_fiber_nuc_linker(read_id, group, LINK_width,NUC_width, metadata_cols, emission_NEG_array, emission_PGC_array):
    """Call nucleosome and linker regions on one read with a Viterbi decoder that forces linkers.

    Model, as implemented below, with ``width = LINK_width + NUC_width + 1`` states:
      * state 0: free linker state, emits from ``emission_PGC_array``; entered from state 0
        or from state width-1;
      * states 1..LINK_width: forced linker chain, emit from ``emission_PGC_array``;
      * states LINK_width+1..width-1: forced nucleosome chain, emit from ``emission_NEG_array``.
    Each chain state j is only entered from state j-1. Apart from a run already in progress at
    the start of the read, every nucleosome is therefore preceded by LINK_width forced-linker
    positions. No explicit transition probabilities are used; only emission log-probabilities
    are summed, and positions without a call add nothing.

    Args:
        read_id: identifier of the read; returned unchanged as the first element.
        group: rows for this read with 'rel_pos' and 'mod_qual' columns plus the columns in
            ``metadata_cols``. 'rel_pos' is used directly as an array offset, so it is
            presumably integer-valued - TODO confirm dtype.
        LINK_width: length of the forced linker chain.
        NUC_width: length of the forced nucleosome chain. It is also the minimum reported
            nucleosome width; shorter calls are widened to the left.
        metadata_cols: columns copied from the read's first row into the result.
        emission_NEG_array: emission probabilities for nucleosome states, indexed by
            int(mod_qual * 1000). Assumes mod_qual is in [0, 1] and the array has at least
            1001 entries - TODO confirm.
        emission_PGC_array: emission probabilities for linker states, indexed the same way.

    Returns:
        (read_id, mid_NUC_list, mid_LINK_list, regions_LINK_list, regions_NUC_list,
        group.iloc[0][metadata_cols], mod_qual_LINK_list, mod_qual_NUC_list).
        Regions are (start, end) pairs in rel_pos coordinates, and midpoints are rounded to
        whole numbers. Linker regions are only reported between consecutive nucleosomes.
        The mod_qual lists hold the mean mod_qual of each run of calls decoded in state 0
        (LINK) or in any state > 0 (NUC).
    """
    # Initialize lists to store information about linker and nucleosome regions and their midpoints
    regions_LINK_list = []
    regions_NUC_list = []
    mid_NUC_list = []
    mid_LINK_list = []
    mod_qual_LINK_list = []  # List to store mod_qual values for linker regions
    mod_qual_NUC_list = []  # List to store mod_qual values for nucleosome region
    width = LINK_width + NUC_width + 1  # total number of HMM states

    # Calculate the total number of bases and minimum base position in the read
    BASE_NUM = max(group['rel_pos']) - min(group['rel_pos']) + 1
    BASE_MIN = min(group['rel_pos'])

    # Initialize the calling_vec with -1 (= no call at this position) and populate it with mod_qual values based on relative position.
    # Filled indices are 0..BASE_NUM-1; index BASE_NUM always stays -1.
    # NOTE(review): the DP below reads calling_vec[2..BASE_NUM], so the first two calls are never
    # used. Reported coordinates map DP row r to rel_pos BASE_MIN + r - 2, while row r consumed
    # the call at BASE_MIN + r. Verify against the reference implementation.
    calling_vec = np.full(BASE_NUM+1, -1.0)
    for i in range(len(group['rel_pos'])):
        calling_vec[group.iloc[i]['rel_pos'] - BASE_MIN] = group.iloc[i]['mod_qual']


    # Initialize probability matrix (prob_mat) and pointer matrix (ptr_mat)
    # prob_mat stores log probabilities, and ptr_mat stores the previous state index for backtracking.
    # A value of 0 in prob_mat is used below as an "unset" sentinel (real log-probabilities here are negative).
    prob_mat = np.zeros((BASE_NUM+1, width))
    ptr_mat = np.full((BASE_NUM+1, width), -1, dtype=int)

    # Initialization of first row in prob_mat and ptr_mat: uniform start over all states
    initial_rate = 1 / width
    log_initial_rate = np.log(initial_rate)
    prob_mat[1, :] = log_initial_rate
    ptr_mat[1, :] = 0

    # Dynamic Programming Step: Fill the prob_mat and ptr_mat
    #high_mod_qual_indices = [i for i, val in enumerate(calling_vec) if val > 0.8]  # Store indices of high mod_qual

    #distance_threshold = 20
    for i in range(2, BASE_NUM+1):
        within_linker = 0.0
        back_frm_ncls = 0.0

        # Compute probabilities for staying within linker (0 -> 0) and going back to linker from nucleosome (width-1 -> 0)
        if calling_vec[i] == -1:
            # No call here: carry scores forward without an emission term
            within_linker = prob_mat[i-1, 0]
            if prob_mat[i-1, width-1] != 0:
                back_frm_ncls = prob_mat[i-1, width-1]
        else:
            # mod_qual is binned to 0.001 resolution to index the emission table
            k = int(calling_vec[i] * 1000)
            within_linker = np.log(emission_PGC_array[k]) + prob_mat[i-1, 0]
            if prob_mat[i-1, width-1] != 0:
                back_frm_ncls = np.log(emission_PGC_array[k]) + prob_mat[i-1, width-1]
        """# Special threshold condition for linker region
                if i in high_mod_qual_indices:
                    for idx in high_mod_qual_indices:
                        if abs(i - idx) < distance_threshold:  # Check distance criterion
                            within_linker += 100  # Boost the probability value
                            break"""

        # Update first column of prob_mat and ptr_mat based on calculated probabilities (keep the better predecessor)
        if back_frm_ncls != 0 and back_frm_ncls > within_linker:
            prob_mat[i, 0] = back_frm_ncls
            ptr_mat[i, 0] = width-1
        else:
            prob_mat[i, 0] = within_linker
            ptr_mat[i, 0] = 0

        # Forced linker chain, states 1..LINK_width: emit from emission_PGC_array (same table as state 0)
        for j in range(1, LINK_width+1):
            if calling_vec[i] == -1:
                prob_mat[i, j] = prob_mat[i-1, j-1]
            else:
                k = int(calling_vec[i] * 1000)
                prob_mat[i, j] = np.log(emission_PGC_array[k]) + prob_mat[i-1, j-1]
                """# Special threshold condition for linker region
                if i in high_mod_qual_indices:
                    for idx in high_mod_qual_indices:
                        if abs(i - idx) < distance_threshold:  # Check distance criterion
                            prob_mat[i, j] += 100  # Boost the probability value
                            break"""
            ptr_mat[i, j] = j - 1

        # Forced nucleosome chain, states LINK_width+1..width-1: emit from emission_NEG_array
        for j in range(LINK_width+1, width):
            if calling_vec[i] == -1:
                if prob_mat[i-1, j-1] != 0:
                    prob_mat[i, j] = prob_mat[i-1, j-1]
            else:
                k = int(calling_vec[i] * 1000)
                if prob_mat[i-1, j-1] != 0:
                    prob_mat[i, j] = np.log(emission_NEG_array[k]) + prob_mat[i-1, j-1]
            if prob_mat[i, j] != 0:
                ptr_mat[i, j] = j-1

    # Backtrack to identify most probable states, collecting the mean mod_qual of each run of
    # linker (state 0) / non-linker (state > 0) calls on the way.
    # NOTE(review):
    #  * the walk runs from the end of the read to the start, so mod_qual_LINK_list and
    #    mod_qual_NUC_list come out in reverse positional order (they are never reversed);
    #  * runs without any call produce no entry, so list lengths need not match the region lists;
    #  * forced linker states 1..LINK_width count as "nucleosome" here (max_index > 0);
    #  * this reads calling_vec[i-1] for row i, while the DP above consumed calling_vec[i].
    current_qual_LINK = []
    current_qual_NUC = []
    max_index = np.argmax(prob_mat[BASE_NUM, :])
    backtrack_vec = []
    for i in range(BASE_NUM, 1, -1):
        backtrack_vec.append(max_index)
        ### STORE MOD-qual values for nuc and link regions
        if max_index > 0:  # In nucleosome
            if calling_vec[i-1] != -1:
                current_qual_NUC.append(calling_vec[i-1])
        else:  # In linker
            if calling_vec[i-1] != -1:
                current_qual_LINK.append(calling_vec[i-1])

        # Close the pending nucleosome run when entering linker (or at the final step)
        if len(current_qual_NUC) > 0 and (max_index == 0 or i == 2):
            mod_qual_NUC_list.append(np.mean(current_qual_NUC))
            current_qual_NUC = []

        # Close the pending linker run when entering a non-linker state (or at the final step)
        if len(current_qual_LINK) > 0 and (max_index > 0 or i == 2):
            mod_qual_LINK_list.append(np.mean(current_qual_LINK))
            current_qual_LINK = []
        max_index = ptr_mat[i, max_index]
    backtrack_vec.reverse()

    # Identify nucleosome and linker regions based on backtracking results.
    # Any state > 0 counts as nucleosome here, including the forced linker states 1..LINK_width.
    ncls_start = 0
    ncls_end = 0
    shift = min(group['rel_pos']) - 1
    InNucleosome = False
    for i, val in enumerate(backtrack_vec):
        if val > 0:
            if not InNucleosome:
                ncls_start = i + 1 + shift
                InNucleosome = True
        else:
            if InNucleosome:
                ncls_end = i + 1 + shift
                ncls_mid = round((ncls_end-ncls_start)/2 + ncls_start, 0)
                # Widen short calls to NUC_width by moving the start to the left
                if ncls_end - ncls_start < NUC_width:
                    ncls_start = ncls_end - NUC_width
                    ncls_mid = round((ncls_end-ncls_start)/2 + ncls_start, 0)
                regions_NUC_list.append((ncls_start, ncls_end))
                mid_NUC_list.append(ncls_mid)
                InNucleosome = False

    # Check if the nucleosome extends to the end of the read; if so report it as NUC_width wide from its start
    if InNucleosome:
        ncls_end = ncls_start + NUC_width
        ncls_mid = round((ncls_end-ncls_start)/2 + ncls_start, 0)
        regions_NUC_list.append((ncls_start, ncls_end))
        mid_NUC_list.append(ncls_mid)

    # Infer linker regions and their midpoints (only the gaps between consecutive nucleosomes)
    for i in range(len(regions_NUC_list)-1):
        regions_LINK_list.append((regions_NUC_list[i][1], regions_NUC_list[i+1][0]))
        mid_LINK_list.append(round((regions_NUC_list[i][1] + regions_NUC_list[i+1][0])/2, 0))

    return read_id, mid_NUC_list, mid_LINK_list, regions_LINK_list, regions_NUC_list, group.iloc[0][metadata_cols], mod_qual_LINK_list, mod_qual_NUC_list

def create_dataframe_from_results(results, kind='NUC_mid', metadata_cols=metadata_cols):
    """Build a wide per-read DataFrame from nucleosome-calling results.

    Args:
        results: iterable of tuples as returned by single_fiber_nuc / single_fiber_nuc_linker.
            Element 0 is the read_id, elements 1-4 are the per-read lists, and element 5 holds
            the read's metadata (indexable by column name).
        kind: which list to spread into columns: 'NUC_mid', 'LINK_mid', 'LINK_region' or
            'NUC_region'. The previous default 'NUC' was not a valid kind and always raised;
            the default is now 'NUC_mid'.
        metadata_cols: metadata columns to prepend. Must include 'read_id', which is the
            merge key.

    Returns:
        DataFrame with the metadata columns followed by one column per list element
        (0, 1, ...). Shorter lists are padded with missing values.

    Raises:
        ValueError: if ``kind`` is not one of the supported values.
    """
    # Position of each payload inside the result tuples
    result_index = {
        'NUC_mid': 1,
        'LINK_mid': 2,
        'LINK_region': 3,
        'NUC_region': 4,
    }
    if kind not in result_index:
        raise ValueError(f"Invalid kind: {kind}")
    payload_idx = result_index[kind]

    df = pd.DataFrame.from_dict({res[0]: res[payload_idx] for res in results}, orient='index')
    # Move the read_id index into a regular column
    df = df.reset_index().rename(columns={'index': 'read_id'})

    # One metadata row per result, taken from the metadata stored at position 5
    metadata_df = pd.DataFrame({
        col: [res[5][col] if res[5] is not None else None for res in results]
        for col in metadata_cols
    })
    return pd.merge(metadata_df, df, on='read_id', how='left')

def create_dataframe_from_mod_qual(results, kind='mod_qual_NUC', metadata_cols=metadata_cols):
    """Build a wide per-read DataFrame of mean mod_qual values per region.

    Args:
        results: iterable of tuples as returned by single_fiber_nuc_linker. Element 0 is the
            read_id, element 5 the metadata, element 6 the linker mod_qual means and
            element 7 the nucleosome mod_qual means.
        kind: 'mod_qual_NUC' selects element 7; any other value selects element 6 (linker).
        metadata_cols: metadata columns to prepend. Must include 'read_id' (the merge key).

    Returns:
        DataFrame with the metadata columns followed by one column per mod_qual value.
    """
    payload_idx = 7 if kind == 'mod_qual_NUC' else 6

    per_read = {}
    for res in results:
        per_read[res[0]] = res[payload_idx]

    qual_df = (
        pd.DataFrame.from_dict(per_read, orient='index')
        .reset_index()
        .rename(columns={'index': 'read_id'})
    )
    metadata_df = pd.DataFrame({
        col: [None if res[5] is None else res[5][col] for res in results]
        for col in metadata_cols
    })
    return pd.merge(metadata_df, qual_df, on='read_id', how='left')

# Using multiprocessing to parallelize the calculations.
# `tqdm` is used below but its import in the imports cell is commented out, so import it here.
import tqdm


def _single_fiber_nuc_linker_star(args):
    """Call single_fiber_nuc_linker with one packed argument tuple (Pool.imap passes a single object)."""
    return single_fiber_nuc_linker(*args)


print("Grouping df...")
# Sort plot_df by read_id then by rel_pos so every read's calls are contiguous and ordered
plot_df.sort_values(by=["read_id","rel_pos"], inplace=True)
plot_df.reset_index(inplace=True, drop=True)

# Reads to process: currently all of them. Slice this array (e.g. [:2000]) to debug on a subset.
first_rows = plot_df['read_id'].unique()
grouped = plot_df[plot_df['read_id'].isin(first_rows)].groupby('read_id')

print("Calculating nucleosome positions...")
LINK_width = 1
NUC_width = 146
grouped_data_with_constants = [
    (read_id, group, LINK_width, NUC_width, metadata_cols, emission_NEG_array, emission_PGC_array)
    for read_id, group in grouped
]

n_processes = multiprocessing.cpu_count()
n_tasks = len(grouped_data_with_constants)
# ceil(n_tasks / (4 * n_processes)) is the chunk size Pool.starmap would choose itself; imap needs >= 1
chunksize = max(1, -(-n_tasks // (4 * n_processes)))
with multiprocessing.Pool(processes=n_processes) as pool:
    # Wrap the *result* iterator so the bar counts finished reads. imap yields results in input
    # order. Wrapping starmap's input only measured how fast tasks were queued.
    results = list(tqdm.tqdm(
        pool.imap(_single_fiber_nuc_linker_star, grouped_data_with_constants, chunksize=chunksize),
        total=n_tasks,
    ))

midpoint_NUC = create_dataframe_from_results(results, kind='NUC_mid', metadata_cols=metadata_cols)
midpoint_MAD = create_dataframe_from_results(results, kind='LINK_mid', metadata_cols=metadata_cols)
region_MAD = create_dataframe_from_results(results, kind='LINK_region', metadata_cols=metadata_cols)
region_NUC = create_dataframe_from_results(results, kind='NUC_region', metadata_cols=metadata_cols)
mod_qual_LINK = create_dataframe_from_mod_qual(results, kind='mod_qual_LINK', metadata_cols=metadata_cols)
mod_qual_NUC = create_dataframe_from_mod_qual(results, kind='mod_qual_NUC', metadata_cols=metadata_cols)

# Keep one row per read_id in every result table (in place, so the names above stay valid)
for result_df in (midpoint_NUC, midpoint_MAD, region_MAD, region_NUC, mod_qual_LINK, mod_qual_NUC):
    result_df.drop_duplicates(subset=['read_id'], keep='first', inplace=True)

nanotools.display_sample_rows(midpoint_NUC)
nanotools.display_sample_rows(region_NUC)

In [None]:
### Process to add per-read statistics
# Symmetric window around position 0 in which nucleosomes / MADs are considered
min_pos, max_pos = -bed_window, bed_window

def find_closest_tuple(row,min_pos,max_pos,metadata_cols=metadata_cols, min_size=35,max_bound=1000):
    """Return the (start, end) region in ``row`` closest to position 0.

    Only regions lying entirely within [min_pos, max_pos] are considered. A region's distance
    is min(|start|, |end|), and only regions with distance < max_bound can be returned.

    Regions at least ``min_size`` wide are preferred. Only when none of them qualifies is the
    closest region of any width returned. (Previously the fallback ran inside the scan, so an
    early undersized region could block a later, properly sized one.)

    Args:
        row: Series mapping column names to region tuples or missing values (None / NaN).
        min_pos, max_pos: window bounds in the same coordinates as the regions.
        metadata_cols: column names to skip.
        min_size: preferred minimum region width.
        max_bound: exclusive upper bound on the accepted distance.

    Returns:
        The chosen region tuple, or None if no region qualifies.
    """
    def _is_missing(value):
        # Scalar-safe missing check: catches None, float('nan'), np.float64('nan') and pd.NA,
        # not just the np.nan singleton.
        return pd.api.types.is_scalar(value) and pd.isna(value)

    def _closest(require_min_size):
        # Single pass over the row, optionally restricted to regions of at least min_size.
        best_distance = max_bound
        best_tuple = None
        for col_name, tup in row.items():
            if col_name in metadata_cols or _is_missing(tup):
                continue
            start, end = tup[0], tup[1]
            # The region must be entirely contained within the desired window
            if start < min_pos or end > max_pos:
                continue
            if require_min_size and abs(start - end) < min_size:
                continue
            distance = min(abs(start), abs(end))
            if distance < best_distance:
                best_distance = distance
                best_tuple = tup
        return best_tuple

    closest_tuple = _closest(require_min_size=True)
    if closest_tuple is None:
        closest_tuple = _closest(require_min_size=False)
    return closest_tuple

### Function to calculate percent_MAD
def calc_percent_region(group, min_pos, max_pos, metadata_cols):
    """Fraction of the read span inside [min_pos, max_pos] that is covered by region tuples.

    For every row of ``group``, the read span (rel_read_start, rel_read_end) is clipped to the
    window [min_pos, max_pos]. Every non-metadata cell holding a (start, end) region adds its
    overlap with that clipped span, and the summed overlap is divided by the clipped read
    length summed over all rows.

    Changes from the earlier version:
      * read lengths are summed across rows instead of keeping only the last row's length;
      * overlaps are clamped at 0, so regions outside the read span no longer subtract;
      * regions spanning the whole window are counted.

    Args:
        group: DataFrame with 'rel_read_start' and 'rel_read_end' columns, the columns in
            ``metadata_cols``, and region columns whose cells are (start, end) tuples or
            missing values (None / NaN).
        min_pos, max_pos: window bounds in the same relative coordinates.
        metadata_cols: columns that are not region cells and are skipped.

    Returns:
        The covered fraction as a float, or None when the clipped read length is not positive.
        A fraction outside [0, 1] (e.g. from overlapping regions) is printed as a diagnostic
        but still returned.
    """
    total_MAD = 0
    read_length = 0  # clipped read length, summed over all rows of the group
    row = None
    for _, row in group.iterrows():
        read_start = max(min_pos, row['rel_read_start'])
        read_end = min(max_pos, row['rel_read_end'])
        # A read lying entirely outside the window contributes no length (never a negative one)
        read_length += max(0, read_end - read_start)

        for col in group.columns:
            if col in metadata_cols:  # Skip metadata columns
                continue
            mad_tup = row[col]
            # Scalar-safe missing check (pd.notna on a list would return an array)
            if pd.api.types.is_scalar(mad_tup) and pd.isna(mad_tup):
                continue
            try:
                # Length of the region's overlap with the clipped read span
                overlap = min(read_end, mad_tup[1]) - max(read_start, mad_tup[0])
            except (TypeError, IndexError, KeyError) as err:
                print("col:", col)
                print("Warning: could not read region tuple:", mad_tup, err)
                continue
            if overlap > 0:
                total_MAD += overlap

    percent_MAD = total_MAD / read_length if read_length > 0 else None
    if percent_MAD is not None and (percent_MAD < 0 or percent_MAD > 1):
        print("row:")
        print(row)
        print("total_MAD:", total_MAD)
        print("read_length:", read_length)
        print("percent_MAD:", percent_MAD)
    return percent_MAD

def calculate_fiber_NRL_list(row, metadata_cols,sign_filter=None):
    """All positive pairwise distances between one read's nucleosome midpoints.

    The non-metadata, non-missing cells of ``row`` are flattened; a cell that
    holds a list contributes each of its elements. With ``sign_filter='+'``
    only values > 0 are kept, and with ``'-'`` only values < 0; any other
    value keeps everything. For every pair i < j (in column order),
    ``v[j] - v[i]`` is emitted when it is positive.

    Returns a list, which is empty when no values remain.
    """
    cells = row.drop(metadata_cols).dropna().values
    if len(cells) == 0:
        return []

    pieces = []
    for cell in cells:
        pieces.append(np.array(cell) if isinstance(cell, list) else np.array([cell]))
    flat = np.concatenate(pieces)
    if flat.size == 0:
        return []

    def _keep(value):
        if sign_filter == '+':
            return value > 0
        if sign_filter == '-':
            return value < 0
        return True

    kept = np.array([value for value in flat if _keep(value)])
    if kept.size == 0:
        return []

    distances = []
    for i, left in enumerate(kept):
        for right in kept[i + 1:]:
            if right > left:
                distances.append(right - left)
    return distances

def closest_nuc(row):
    """Return whichever flanking nucleosome midpoint is nearer to position 0.

    Reads ``row['smallest_positive_nuc_midpoint']`` and
    ``row['greatest_negative_nuc_midpoint']``. A missing value may be NaN or
    None, because the producing lambdas use ``default=None``. If only one side
    is present it is returned; if neither is, NaN is returned. On equal
    distance the negative midpoint is returned.
    """
    smallest = row['smallest_positive_nuc_midpoint']
    greatest = row['greatest_negative_nuc_midpoint']

    # pd.isna accepts None as well as NaN; np.isnan(None) raises TypeError.
    if pd.isna(smallest) and pd.isna(greatest):
        return np.nan

    if pd.isna(smallest):
        return greatest

    if pd.isna(greatest):
        return smallest

    return smallest if abs(smallest) < abs(greatest) else greatest

### Per-read nucleosome / MAD summary statistics
# Each statistic is computed into a small side table keyed by read_id and then
# merged into merged_df. midpoint_NUC and region_MAD themselves are not modified.
# NOTE(review): assumes the metadata columns come first in midpoint_NUC and every
# later column is a nucleosome midpoint - confirm against how it is built.
nuc_stats = midpoint_NUC[['read_id']].copy()
nuc_stats['smallest_positive_nuc_midpoint'] = midpoint_NUC.apply(
    lambda row: min([x for x in row.iloc[len(metadata_cols):] if (max_pos > x >= 0)], default=None), axis=1)
nuc_stats['greatest_negative_nuc_midpoint'] = midpoint_NUC.apply(
    lambda row: max([x for x in row.iloc[len(metadata_cols):] if (min_pos < x < 0)], default=None), axis=1)
nuc_stats['closest_nuc'] = nuc_stats.apply(closest_nuc, axis=1)
nuc_stats['inter_nuc_dist'] = nuc_stats['smallest_positive_nuc_midpoint'] - nuc_stats['greatest_negative_nuc_midpoint']

# closest_MAD_region: the in-window MAD interval nearest to position 0.
mad_stats = region_MAD[['read_id']].copy()
mad_stats['closest_MAD_region'] = region_MAD.apply(
    lambda row: find_closest_tuple(row, min_pos=min_pos, max_pos=max_pos, metadata_cols=metadata_cols, min_size=35, max_bound=max_pos), axis=1)
mad_stats['MAD_size'] = mad_stats['closest_MAD_region'].apply(lambda x: x[1] - x[0] if x is not None else None)
mad_stats['closest_MAD_midpoint'] = mad_stats.apply(
    lambda row: (row['closest_MAD_region'][0] + row['closest_MAD_region'][1]) / 2 if row['closest_MAD_region'] is not None else None, axis=1)

merged_df = pd.merge(
    plot_df,
    nuc_stats[['read_id', 'smallest_positive_nuc_midpoint', 'greatest_negative_nuc_midpoint', 'inter_nuc_dist', 'closest_nuc']],
    on='read_id', how='left')
merged_df = pd.merge(merged_df, mad_stats[['read_id', 'closest_MAD_region', 'MAD_size', 'closest_MAD_midpoint']], on='read_id', how='left')

### Calculate % MAD and % NUC within [min_calc_region, max_calc_region]
min_calc_region = -500
max_calc_region = 100
print("Status: Calculating percent_MAD...")
grouped_mad = region_MAD.groupby('read_id').apply(
    calc_percent_region, min_pos=min_calc_region, max_pos=max_calc_region, metadata_cols=metadata_cols
).reset_index(name='percent_MAD')
print("Status: Calculating percent_NUC...")
grouped_nuc = region_NUC.groupby('read_id').apply(
    calc_percent_region, min_pos=min_calc_region, max_pos=max_calc_region, metadata_cols=metadata_cols
).reset_index(name='percent_NUC')
merged_df = pd.merge(merged_df, grouped_mad, on='read_id', how='left')
merged_df = pd.merge(merged_df, grouped_nuc, on='read_id', how='left')
merged_df['percent_OTHER'] = 1 - merged_df['percent_MAD'] - merged_df['percent_NUC']

### Per-read fibre NRL lists: all midpoints, positive-only and negative-only.
# All three are computed from the unmodified midpoint_NUC. Attaching one list
# column to midpoint_NUC before computing the next would feed that list back in
# as extra "midpoints", because calculate_fiber_NRL_list flattens list-valued cells.
nrl_df = midpoint_NUC[['read_id']].copy()
for nrl_col, nrl_sign in (('fiber_NRL_list', None), ('fiber_NRL_list_pos', '+'), ('fiber_NRL_list_neg', '-')):
    nrl_df[nrl_col] = midpoint_NUC.apply(
        lambda row, sign=nrl_sign: calculate_fiber_NRL_list(row, metadata_cols, sign_filter=sign), axis=1)
merged_df = pd.merge(merged_df, nrl_df, on='read_id', how='left')

# Order reads by MAD fraction.
merged_df = merged_df.sort_values(by=['percent_MAD'], ascending=True).reset_index(drop=True)

# One row per read (first occurrence in merged_df), ordered by read_id.
grouped = merged_df.drop_duplicates(subset=['read_id']).sort_values(by=['read_id']).reset_index(drop=True)

# Per-read list of all non-missing nucleosome midpoints (integer-named columns).
# .copy() so that adding 'nucs_list' does not write into a slice of midpoint_NUC.
combined_nucs_df = midpoint_NUC[['read_id'] + [col for col in midpoint_NUC.columns if isinstance(col, int)]].copy()
combined_nucs_df['nucs_list'] = combined_nucs_df.apply(lambda row: [x for x in row.iloc[1:] if pd.notna(x)], axis=1)
grouped = pd.merge(grouped, combined_nucs_df[['read_id', 'nucs_list']], on='read_id', how='left')

print("Creating aligned nuc and mad lists...")
def align_nuc_list(row, column_name):
    """Shift ``row['nucs_list']`` so that it is relative to an anchor taken from ``row[column_name]``.

    The anchor is the cell value itself or, when the cell holds a
    ``(start, end)`` tuple, that tuple's midpoint. Returns
    ``[x - anchor for x in nucs_list]``. Returns ``[]`` when the anchor is
    missing (None / NaN) or not numeric.
    """
    anchor = row[column_name]
    if isinstance(anchor, tuple):
        anchor = (anchor[0] + anchor[1]) / 2.0

    # np.int64 is not a subclass of int, so numpy scalars are listed explicitly.
    if isinstance(anchor, (int, float, np.integer, np.floating)) and not np.isnan(anchor):
        return [x - anchor for x in row['nucs_list']]
    return []

def calculate_internuc(nuc_list):
    """Differences between neighbouring entries, ``nuc_list[k+1] - nuc_list[k]``, in list order."""
    return [later - earlier for earlier, later in zip(nuc_list, nuc_list[1:])]

def align_nuc_list_internuc(row):
    """Shift ``row['nucs_list']`` so that it is relative to the centre of the gap around 0.

    The centre is the mean of ``smallest_positive_nuc_midpoint`` and
    ``greatest_negative_nuc_midpoint``. Returns ``[]`` if either flank is
    missing (NaN or None).
    """
    smallest_positive = row['smallest_positive_nuc_midpoint']
    greatest_negative = row['greatest_negative_nuc_midpoint']

    # Either flank missing -> nothing to align to (pd.isna also accepts None).
    if pd.isna(smallest_positive) or pd.isna(greatest_negative):
        return []

    midpoint = (smallest_positive + greatest_negative) / 2.0
    return [x - midpoint for x in row['nucs_list']]

def largest_gap_size(nucs_list, nuc_width=147):
    """Widest spacing between consecutive (sorted) midpoints, minus ``nuc_width``.

    Returns NaN when fewer than two midpoints are given.
    """
    if len(nucs_list) >= 2:
        spacings = np.diff(sorted(nucs_list))
        return max(spacings) - nuc_width
    return np.nan

def largest_gap_pos(nucs_list, nuc_width=147):
    """Centre position of the widest gap between consecutive (sorted) midpoints.

    Returns NaN when fewer than two midpoints are given. When several gaps tie
    for widest, the first (leftmost) is used. ``nuc_width`` does not affect
    the result; it is accepted so the signature mirrors ``largest_gap_size``.
    """
    if len(nucs_list) < 2:
        return np.nan
    sorted_list = sorted(nucs_list)
    # Index of the widest consecutive gap (argmax returns the first maximum).
    widest = int(np.argmax(np.diff(sorted_list)))
    return (sorted_list[widest] + sorted_list[widest + 1]) / 2

# Only reads with a nucleosome list can be aligned. .copy() so the columns added
# below are not written into a filtered slice.
grouped = grouped.dropna(subset=['nucs_list']).copy()

grouped['inter_nuc_sub'] = grouped['nucs_list'].apply(calculate_internuc)
# Pass the width by keyword: the second positional parameter of Series.apply is
# `convert_dtype`, so a positional NUC_width would never reach the function.
grouped['largest_nfr_size'] = grouped['nucs_list'].apply(largest_gap_size, nuc_width=NUC_width)
grouped['largest_nfr_pos'] = grouped['nucs_list'].apply(largest_gap_pos, nuc_width=NUC_width)
grouped['nuc_list_internuc_aligned'] = grouped.apply(align_nuc_list_internuc, axis=1)
grouped['nucs_list_closest_aligned'] = grouped.apply(align_nuc_list, args=('closest_nuc',), axis=1)
grouped['nucs_list_largest_nfr_aligned'] = grouped.apply(align_nuc_list, args=('largest_nfr_pos',), axis=1)


# Attach each read's experiment id (looked up from the BAM files) if not already present.
if 'exp_id' not in grouped.columns:
    import pysam
    # new_bam_files[i] belongs to exp_ids[i].
    if len(new_bam_files) != len(exp_ids):
        raise ValueError("The lists new_bam_files and exp_ids must have the same length.")

    data = []
    print("Status: Looping through bam files to append exp_ids")
    for bam_file, exp_id in zip(new_bam_files, exp_ids):
        with pysam.AlignmentFile(bam_file, "rb") as bam:
            # Every alignment record returned by fetch() is visited.
            for read in bam.fetch():
                data.append({
                    "bam_file_name": bam_file,
                    "read_id": read.query_name,
                    "exp_id": exp_id,
                })

    # Explicit columns keep the merge below valid even when no reads were found.
    read_exp_ids_df = pd.DataFrame(data, columns=["bam_file_name", "read_id", "exp_id"])
    # One read can have several alignment records (e.g. supplementary alignments).
    # Keep one per read_id so the left merge does not duplicate rows of `grouped`.
    read_exp_ids_df = read_exp_ids_df.drop_duplicates(subset=['read_id'], keep='first')

    grouped = pd.merge(grouped, read_exp_ids_df[['read_id', 'exp_id']], on='read_id', how='left')

nanotools.display_sample_rows(grouped, 5)
nanotools.display_sample_rows(merged_df, 5)

In [None]:

### Calculate positioning statistics
# Function to compute the midpoint
# NOTE: disabled. The definition below is wrapped in a string literal, so it is
# never executed and `compute_midpoint` is not defined; kept for reference only.
"""def compute_midpoint(row):
    pos_values = [v for v in row[integer_columns] if v > 0]
    neg_values = [v for v in row[integer_columns] if v < 0]

    # If either list is empty, return NaN as the midpoint
    if not pos_values or not neg_values:
        return float('nan')

    least_pos = min(pos_values)
    greatest_neg = max(neg_values)

    return (least_pos + greatest_neg) / 2"""

def calculate_feature(args):
    """Cluster nucleosome midpoints across reads into consensus positions.

    ``args`` is a single tuple, so the function can be used with ``Pool.map``:
    ``(occ_cutoff, midpoint_NUC, NUC_max_width, metadata_cols, align_to,
    condition, chr_type, dtype)``.

    Rows of ``midpoint_NUC`` are restricted to the requested ``condition``,
    ``chr_type`` and ``type``, and to rows whose ``align_to`` value is present.
    Duplicate read_ids keep their first row. Every column that is neither a
    metadata column nor ``align_to`` is treated as a nucleosome midpoint.

    Each midpoint not yet assigned to a cluster seeds a new cluster. From every
    other read, the cluster takes up to two nearest midpoints that lie within
    ``NUC_max_width / 2 + 12`` of the seed and are not yet assigned.

    Returns one row per cluster with these columns:
    - ``mean_nuc_pos``, ``total_reads``, ``total_nucs``
    - ``percent_occ`` (= total_nucs / total_reads)
    - ``nucs_list``, ``subs_list`` (distances to the seed)
    - the three filter values
    - ``nuc_id``, labelled ``n-k`` ... ``n-1``, ``n``, ``n+1`` ... by the sign
      of ``mean_nuc_pos``

    Clusters with ``percent_occ < occ_cutoff`` are dropped; the rest are
    sorted by ``mean_nuc_pos``.
    """
    occ_cutoff, midpoint_NUC, NUC_max_width, metadata_cols, align_to, condition, chr_type, dtype = args

    # Filter the dataframe based on condition, chr_type, dtype and a present anchor
    print("Processing with following filters:", condition, chr_type, dtype, sep="\n")
    keep = (
        (midpoint_NUC['condition'] == condition)
        & (midpoint_NUC['chr_type'] == chr_type)
        & (midpoint_NUC['type'] == dtype)
        # Element-wise test; `series is not np.nan` would be a single always-True bool.
        & midpoint_NUC[align_to].notna()
    )
    midpoint_NUC_filtered = (
        midpoint_NUC.loc[keep]
        # Unique read_ids are required for the .at lookups below.
        .drop_duplicates(subset=['read_id'], keep='first')
        .set_index('read_id')
    )

    # Nucleosome columns: everything except metadata and the alignment anchor
    # (the anchor is a per-read coordinate, not a nucleosome midpoint).
    position_cols = midpoint_NUC_filtered.columns.difference(list(metadata_cols) + [align_to])
    positions = midpoint_NUC_filtered[position_cols]
    max_offset = NUC_max_width / 2 + 12
    total_reads = len(positions)

    results = {key: [] for key in ('mean_nuc_pos', 'total_reads', 'total_nucs',
                                   'percent_occ', 'nucs_list', 'subs_list')}
    used_nucleosomes = set()  # (read_id, position) pairs already assigned to a cluster

    print(f"Looping through  {total_reads} reads...")
    for n_seen, (idx, row) in enumerate(positions.iterrows(), start=1):
        if n_seen % 200 == 0:
            print("Processing read:", n_seen, sep='\n')
        for nuc in row.dropna():
            if (idx, nuc) in used_nucleosomes:
                continue

            temp_used_nucleosomes = {(idx, nuc)}
            temp_subs_list = []

            # Distances from every other read's midpoints to this seed: keep each
            # read's two nearest and blank out those farther than max_offset.
            temp_subtractions = (positions - nuc).abs().drop(index=idx)
            temp_subtractions = temp_subtractions.apply(lambda r: r.nsmallest(2), axis=1)
            temp_subtractions = temp_subtractions.where(temp_subtractions <= max_offset)
            non_nan_rows = temp_subtractions.dropna(how='all')

            for used_row_idx in non_nan_rows.index:
                for col in non_nan_rows.columns:
                    if pd.isna(non_nan_rows.at[used_row_idx, col]):
                        continue
                    original_nuc_value = positions.at[used_row_idx, col]
                    if (used_row_idx, original_nuc_value) not in used_nucleosomes:
                        temp_used_nucleosomes.add((used_row_idx, original_nuc_value))
                        temp_subs_list.append(abs(original_nuc_value - nuc))

            used_nucleosomes.update(temp_used_nucleosomes)

            nucs_list = [pos for _, pos in temp_used_nucleosomes]
            total_nucs = len(temp_used_nucleosomes)
            results['mean_nuc_pos'].append(np.mean(nucs_list))
            results['total_reads'].append(total_reads)
            results['total_nucs'].append(total_nucs)
            results['percent_occ'].append(total_nucs / total_reads if total_reads else 0)
            results['nucs_list'].append(nucs_list)
            results['subs_list'].append(temp_subs_list)

    results_df = pd.DataFrame(results)
    results_df['condition'] = condition
    results_df['chr_type'] = chr_type
    results_df['type'] = dtype
    # Drop low-occupancy clusters, then order by mean position.
    results_df = (
        results_df[results_df['percent_occ'] >= occ_cutoff]
        .sort_values('mean_nuc_pos')
        .reset_index(drop=True)
    )

    # nuc_id: negatives count down to n-1, positives count up from n+1, zero is 'n'.
    neg_count = int((results_df['mean_nuc_pos'] < 0).sum())
    pos_count = 1
    nuc_id_list = []
    for mean_pos in results_df['mean_nuc_pos']:
        if mean_pos < 0:
            nuc_id_list.append(f'n-{neg_count}')
            neg_count -= 1
        elif mean_pos > 0:
            nuc_id_list.append(f'n+{pos_count}')
            pos_count += 1
        else:
            nuc_id_list.append('n')

    results_df['nuc_id'] = nuc_id_list
    return results_df

def calculate_feature2(args):
    """Cluster nucleosome midpoints across reads and keep the best cluster per position bin.

    ``args`` is a single tuple, for use with ``Pool.map``:
    ``(bin_width, occ_cutoff, midpoint_NUC, NUC_max_width, metadata_cols,
    align_to, condition, chr_type, dtype)``.

    Steps:
    1. Restrict ``midpoint_NUC`` to the requested condition / chr_type / type.
    2. If ``align_to`` is given, drop rows where it is missing, subtract it
       from every integer-named column, then drop the ``align_to`` column.
    3. Keep the first row per read_id.
    4. Seed a cluster at every midpoint that is not within 5 of an earlier
       seed. From every other read, the cluster takes up to two nearest
       midpoints within ``NUC_max_width / 2 + 12`` that no earlier cluster holds.
    5. Drop clusters with ``percent_occ < occ_cutoff``.
    6. Tile ``[-bed_window, bed_window)`` into ``bin_width`` bins. Within the
       central half of each bin (judged by the seed ``std_nuc_pos``), keep the
       cluster(s) with the highest ``percent_occ``.
    7. Label the kept clusters ``n-k`` ... ``n-1``, ``n``, ``n+1`` ... by the
       sign of ``mean_nuc_pos``.

    NOTE: the bin range uses the notebook-level global ``bed_window``.

    Returns a DataFrame with these columns: mean_nuc_pos, std_nuc_pos,
    total_reads, total_nucs, percent_occ, nucs_list, subs_list, condition,
    chr_type, type, bin_start, bin_end, bin_pos and nuc_id.
    """
    bin_width, occ_cutoff, midpoint_NUC, NUC_max_width, metadata_cols, align_to, condition, chr_type, dtype = args

    # Filter the dataframe based on condition, chr_type, and dtype
    print("Processing with following filters:", condition, chr_type, dtype, sep="\n")
    keep = (
        (midpoint_NUC['condition'] == condition)
        & (midpoint_NUC['chr_type'] == chr_type)
        & (midpoint_NUC['type'] == dtype)
    )
    # Explicit copy: columns are rewritten below, and writing into a filtered
    # slice triggers pandas' SettingWithCopyWarning.
    midpoint_NUC_filtered = midpoint_NUC.loc[keep].copy()

    # Identify columns with integer names (the midpoint columns)
    integer_columns = [col for col in midpoint_NUC_filtered.columns if isinstance(col, int)]

    # Shift midpoints so that the per-read anchor sits at 0
    if align_to is not None:
        missing_anchor = midpoint_NUC_filtered[align_to].isna()
        print("Dropping this many rows where align_to column is np.nan:", int(missing_anchor.sum()))
        midpoint_NUC_filtered = midpoint_NUC_filtered.loc[~missing_anchor].copy()
        for col in integer_columns:
            midpoint_NUC_filtered[col] = midpoint_NUC_filtered[col] - midpoint_NUC_filtered[align_to]
        midpoint_NUC_filtered = midpoint_NUC_filtered.drop(columns=[align_to])

    # One row per read_id (also keeps the .at lookups below scalar)
    midpoint_NUC_filtered = (
        midpoint_NUC_filtered
        .drop_duplicates(subset=['read_id'], keep='first')
        .reset_index(drop=True)
        .set_index('read_id')
    )

    position_cols = midpoint_NUC_filtered.columns.difference(metadata_cols)
    positions = midpoint_NUC_filtered[position_cols]
    max_offset = NUC_max_width / 2 + 12
    total_reads = len(positions)

    results = {key: [] for key in ('mean_nuc_pos', 'std_nuc_pos', 'total_reads', 'total_nucs',
                                   'percent_occ', 'nucs_list', 'subs_list')}
    used_nucleosomes = set()  # (read_id, position) pairs already assigned to a cluster

    print(f"Looping through  {total_reads} reads...")
    for n_seen, (idx, row) in enumerate(positions.iterrows(), start=1):
        if n_seen % 200 == 0:
            print("Processing read:", n_seen, sep='\n')

        # Skip midpoints within 5 of a seed already used (evaluated once per read).
        nuc_positions = [x for x in row.dropna()
                         if not any(abs(x - seed) <= 5 for seed in results['std_nuc_pos'])]

        for nuc in nuc_positions:
            temp_used_nucleosomes = {(idx, nuc)}
            temp_subs_list = []

            # Distances from every other read's midpoints to this seed: keep each
            # read's two nearest and blank out those farther than max_offset.
            temp_subtractions = (positions - nuc).abs().drop(index=idx)
            temp_subtractions = temp_subtractions.apply(lambda r: r.nsmallest(2), axis=1)
            temp_subtractions = temp_subtractions.where(temp_subtractions <= max_offset)
            non_nan_rows = temp_subtractions.dropna(how='all')

            for used_row_idx in non_nan_rows.index:
                for col in non_nan_rows.columns:
                    distance = non_nan_rows.at[used_row_idx, col]
                    if isinstance(distance, pd.Series):
                        print("Warning: non_nan_rows.at[used_row_idx, col] is a series")
                        print("non_nan_rows.at[used_row_idx, col]:", distance)
                        print("non_nan_rows.at[used_row_idx, col].values:", distance.values)
                    if pd.isna(distance):
                        continue
                    original_nuc_value = positions.at[used_row_idx, col]
                    if (used_row_idx, original_nuc_value) not in used_nucleosomes:
                        temp_used_nucleosomes.add((used_row_idx, original_nuc_value))
                        temp_subs_list.append(abs(original_nuc_value - nuc))

            used_nucleosomes.update(temp_used_nucleosomes)

            nucs_list = [pos for _, pos in temp_used_nucleosomes]
            total_nucs = len(temp_used_nucleosomes)
            results['mean_nuc_pos'].append(np.mean(nucs_list))
            results['std_nuc_pos'].append(nuc)
            results['total_reads'].append(total_reads)
            results['total_nucs'].append(total_nucs)
            results['percent_occ'].append(total_nucs / total_reads if total_reads else 0)
            results['nucs_list'].append(nucs_list)
            results['subs_list'].append(temp_subs_list)

    results_df = pd.DataFrame(results)
    results_df['condition'] = condition
    results_df['chr_type'] = chr_type
    results_df['type'] = dtype
    # Drop low-occupancy clusters, order by mean position, one cluster per seed.
    results_df = (
        results_df[results_df['percent_occ'] >= occ_cutoff]
        .sort_values('mean_nuc_pos')
        .reset_index(drop=True)
        .drop_duplicates(subset=['std_nuc_pos'])
    )

    # Bin range comes from the notebook-level `bed_window`.
    bin_rows = []
    for bin_lo in np.arange(-bed_window, bed_window, bin_width):
        # Only the central half of each bin is used.
        lower_bound = bin_lo + bin_width / 4
        upper_bound = bin_lo + bin_width - bin_width / 4

        in_bin = results_df[(results_df['std_nuc_pos'] >= lower_bound) & (results_df['std_nuc_pos'] < upper_bound)]
        if in_bin.empty:
            continue

        # Keep the cluster(s) with the highest occupancy in this bin.
        best = in_bin[in_bin['percent_occ'] == in_bin['percent_occ'].max()].copy()
        best['bin_start'] = lower_bound
        best['bin_end'] = upper_bound
        best['bin_pos'] = lower_bound + (upper_bound - lower_bound) / 2
        bin_rows.append(best)

    if bin_rows:
        final_df = pd.concat(bin_rows)
    else:
        # pd.concat([]) raises ValueError; return an empty frame with the same columns instead.
        final_df = results_df.iloc[0:0].assign(bin_start=[], bin_end=[], bin_pos=[])
    final_df = final_df.sort_values(by=['std_nuc_pos']).reset_index(drop=True)

    # nuc_id: count negatives among the rows actually being labelled, so the
    # labels run n-k ... n-1 rather than starting from the number of clusters
    # in results_df.
    neg_count = int((final_df['mean_nuc_pos'] < 0).sum())
    pos_count = 1
    nuc_id_list = []
    for mean_pos in final_df['mean_nuc_pos']:
        if mean_pos < 0:
            nuc_id_list.append(f'n-{neg_count}')
            neg_count -= 1
        elif mean_pos > 0:
            nuc_id_list.append(f'n+{pos_count}')
            pos_count += 1
        else:
            nuc_id_list.append('n')

    final_df['nuc_id'] = nuc_id_list
    return final_df

if __name__ == '__main__':
    # (condition, chr_type, type) combinations to process; one worker task each.
    specific_comps = [
        ('N2_fiber', 'X', 'xol-1_TSS'),
        #('N2_fiber', 'X', 'intergenic_control'),
        ('SDC2_degron_fiber', 'X', 'xol-1_TSS'),
        #('SDC2_degron_fiber', 'X', 'intergenic_control'),
    ]
    # Occupancy cutoff: clusters with percent_occ below this are ignored.
    occ_cutoff = 0
    bin_width = 40

    # Column of `grouped` used to shift each read's midpoints (None = no shift).
    # Other options: "smallest_positive_nuc_midpoint", "greatest_negative_nuc_midpoint",
    # "closest_nuc", "inter_nuc_dist".
    align_to = "closest_MAD_midpoint"

    # midpoint_NUC plus the per-read align_to column (merged on read_id).
    if align_to is not None:
        print("Merging...")
        midpoint_NUC_merged = pd.merge(midpoint_NUC, grouped[['read_id', align_to]], on='read_id', how='left')
    else:
        midpoint_NUC_merged = midpoint_NUC.copy(deep=True)

    # Quick test subset:
    #midpoint_NUC_merged = midpoint_NUC_merged.sample(n=100)
    args_list = [(bin_width, occ_cutoff, midpoint_NUC_merged, NUC_width, metadata_cols, align_to, c, ch, d)
                 for c, ch, d in specific_comps]

    # One worker per task is enough; additional processes would sit idle.
    n_workers = max(1, min(len(args_list), multiprocessing.cpu_count()))
    with Pool(processes=n_workers) as pool:
        results = pool.map(calculate_feature2, args_list)

    combined_res_df = pd.concat(results)

    # Display a sample of the combined results
    nanotools.display_sample_rows(combined_res_df, 5)

In [None]:
### PLOT NRL STATISTICS
import seaborn as sns
importlib.reload(nanotools)
sns.set_style("white")

# Per-read table used for plotting (a copy, so `grouped` is left untouched).
grouped_plot = grouped.copy()
grouped_plot["rel_read_len"] = grouped_plot["rel_read_end"] - grouped_plot["rel_read_start"]
# Optional read filters / down-sampling (disabled):
#grouped_plot = grouped_plot[grouped_plot["rel_read_len"] >= (bed_window * 2 - 100)]
#sample_num = grouped_plot[['type', 'condition','read_id']].drop_duplicates().groupby(['type', 'condition']).size().min()
#grouped_plot = grouped_plot.groupby(['condition','type']).apply(lambda x: x.sample(sample_num)).reset_index(drop=True)

# Long-format table with one row per pairwise nucleosome distance ('dist') for
# each (type, condition, chr_type) combination. DataFrame.append does not
# exist in pandas >= 2.0, so the pieces are collected and concatenated once.
fiber_frames = []
for type_val, condition_val, chr_type_val in grouped_plot[['type', 'condition', 'chr_type']].drop_duplicates().values:
    print("condition_val:", condition_val, sep="\n")
    # Mask built from grouped_plot only, so it stays aligned if grouped_plot is filtered above.
    sub_df = grouped_plot[
        (grouped_plot['type'] == type_val)
        & (grouped_plot['condition'] == condition_val)
        & (grouped_plot['chr_type'] == chr_type_val)
    ]
    # One row per distance; reads without distances drop out.
    fiber_data = (
        sub_df['fiber_NRL_list'].explode().dropna().to_frame()
        .rename(columns={'fiber_NRL_list': 'dist'})
    )
    fiber_data['condition'] = f"{condition_val}"
    fiber_data['type'] = f"{type_val}"
    fiber_data['chr_type'] = f"{chr_type_val}"
    fiber_data = fiber_data.reset_index(drop=True)
    print("len(fiber_data):", len(fiber_data['condition']))
    print("len(sub_df):", len(sub_df['type']))
    if not fiber_data.empty:
        fiber_frames.append(fiber_data)

if fiber_frames:
    distplot_df = pd.concat(fiber_frames)
else:
    distplot_df = pd.DataFrame(columns=['dist', 'condition', 'type', 'chr_type'])

display(distplot_df.sample(n=min(20, len(distplot_df))))

# NOTE(review): type_selected[4] needs at least five entries; the configuration
# cell at the top of the notebook defines only two - confirm it is redefined.
nrl_plot_type = type_selected[4]
is_plot_type = distplot_df['type'] == nrl_plot_type
peaks1, fig1 = nanotools.plot_NRL_dist(distplot_df[(distplot_df['condition'] == "N2_fiber") & is_plot_type],
                                       "X", "#16415e", "N2", smoothing_val=0.2)
display(peaks1.head(10))
peaks2, fig2 = nanotools.plot_NRL_dist(distplot_df[(distplot_df['condition'] == "SDC2_degron_fiber") & is_plot_type],
                                       "X", "#16415e", "SDC2_degron", smoothing_val=0.2)
display(peaks2.head(10))

fig = nanotools.plot_NRL_dist_compare(distplot_df[is_plot_type],
                                      "X", "N2 and SDC2_degron", smoothing_val=0.2, norm_bool=False, hue='condition')

# savefig does not create missing directories.
os.makedirs('images', exist_ok=True)

fig.savefig('images/dpy27_sdc2_sdc3_fiber_comparison_NRL.png', dpi=300, bbox_inches='tight')
fig.savefig('images/dpy27_sdc2_sdc3_fiber_comparison_NRL.svg', bbox_inches='tight')

fig1.savefig('images/dpy27_sdc2_sdc3_fiber_NRL_N2.png', dpi=300, bbox_inches='tight')
fig1.savefig('images/dpy27_sdc2_sdc3_fiber_NRL_N2.svg', bbox_inches='tight')

fig2.savefig('images/dpy27_sdc2_sdc3_fiber_NRL_SDC2_degron.png', dpi=300, bbox_inches='tight')
fig2.savefig('images/dpy27_sdc2_sdc3_fiber_NRL_SDC2_degron.svg', bbox_inches='tight')

fig.show()

In [None]:
import seaborn as sns
importlib.reload(nanotools)
sns.set_style("white")

# Long-format table of aligned nucleosome positions, filled by the loop below.
distplot_df = pd.DataFrame()

# Column of `grouped` holding the per-read position list to plot. Columns built
# in the per-read statistics cell (note the mixed 'nuc_'/'nucs_' prefixes):
#   'nucs_list', 'nuc_list_internuc_aligned',
#   'nucs_list_closest_aligned', 'nucs_list_largest_nfr_aligned'
align_on = 'nucs_list_largest_nfr_aligned'

# Keep reads whose align_on value is a non-empty list (drops NaN / non-list
# values and failed alignments, which are []). .copy() so the column added
# below is not written into a slice of `grouped`.
grouped_plot = grouped[grouped[align_on].apply(lambda x: isinstance(x, list) and len(x) > 0)].copy()
display(grouped_plot.head(10))
grouped_plot["rel_read_len"] = grouped_plot["rel_read_end"] - grouped_plot["rel_read_start"]
# Optional filters (disabled):
#grouped_plot = grouped_plot[grouped_plot["rel_read_len"] >= (bed_window * 2 - 400)]
#sample_num = min(grouped_plot['condition'].value_counts())
#grouped_plot = grouped.groupby('condition').apply(lambda x: x.sample(sample_num)).reset_index(drop=True)
# print number of N2_fiber reads and number of SDC2_degron_fiber reads
print("N2_fiber reads:", len(grouped_plot[grouped_plot['condition'] == "N2_fiber"]))
print("SDC2_degron_fiber reads:", len(grouped_plot[grouped_plot['condition'] == "SDC2_degron_fiber"]))

# Restrict to a single region type for plotting.
grouped_plot = grouped_plot[grouped_plot['type'] == type_selected[0]].copy()
# Optional down-sampling to equalise read counts per condition (disabled):
# num_reads = min(grouped_plot['condition'].value_counts())
# grouped_plot = grouped_plot.groupby('condition').apply(lambda x: x.sample(num_reads)).reset_index(drop=True)

# Iterate over each unique combination of 'type' and 'condition'
for unique_comb in grouped_plot[['type', 'condition']].drop_duplicates().values:
    type_val, condition_val = unique_comb
    print("condition_val:", condition_val, sep="\n")
    # Filter data based on the unique combination of 'type' and 'condition'
    sub_df = grouped_plot[(grouped_plot['type'] == type_val) & (grouped_plot['condition'] == condition_val)]
    # explod sub_df['fiber_NRL_list'] and drop NaN values and convert to column in fiber_data called "dist"
    fiber_data = sub_df[align_on].explode().dropna().to_frame()
    fiber_data.rename(columns={align_on: 'dist'}, inplace=True)
    # Add a column called "type" with the value of type_val and condition_val
    fiber_data['condition'] = f"{condition_val}" #{type_val}_
    # reset index
    fiber_data.reset_index(inplace=True, drop=True)
    # print length of fiber_data
    print("len(fiber_data):", len(fiber_data['condition']))
    #display(fiber_data.head(10))
    # if distplot_df is empty, set distplot_df equal to fiber_data
    if distplot_df.empty:
        distplot_df = fiber_data
    # else append fiber_data to distplot_df
    else:
        distplot_df = distplot_df.append(fiber_data)

nanotools.display_sample_rows(distplot_df, 5)
#plot_title = concatenate "N2" and align_on
plot_title = "N2_" + align_on

#display(distplot_df.head(10))
peaks1,fig1 = nanotools.plot_NRL_dist(distplot_df[(distplot_df['condition'] == "N2_fiber") ],"X","#16415e",str("N2_"+align_on),smoothing_val=0.2)
display(peaks1.head(10))
peaks2,fig2 = nanotools.plot_NRL_dist(distplot_df[(distplot_df['condition'] == "SDC2_degron_fiber") ],"X","#16415e",str("SDC2_degron_"+align_on),smoothing_val=0.2)
display(peaks2.head(10))

fig = nanotools.plot_NRL_dist_compare(distplot_df,"X","N2 and SDC2_degron",smoothing_val=0.2,norm_bool=False,window=bed_window)
fig.savefig('images/dpy_27_aligned_on_internuc_KDE_COMB.png', dpi=300, bbox_inches='tight')
fig.savefig('images/dpy_27_aligned_on_internuc_KDE_COMB.svg', bbox_inches='tight')
#fig.show()

fig1.savefig('images/dpy_27_aligned_on_internuc_KDE_N2.png', dpi=300, bbox_inches='tight')
fig1.savefig('images/dpy_27_aligned_on_internuc_KDE_N2.svg', bbox_inches='tight')

fig2.savefig('images/dpy_27_aligned_on_internuc_KDE_SDC2_degron.png', dpi=300, bbox_inches='tight')
fig2.savefig('images/dpy_27_aligned_on_internuc_KDE_SDC2_degron.svg', bbox_inches='tight')

In [None]:
### PLOT PERCENT MAD, NUC, OTHER
# make_subplots is not part of the notebook's top-level imports cell; import it here so this
# cell also runs on a fresh kernel.
from plotly.subplots import make_subplots

# x-axis category "<condition>_<type>_<chr_type>". Built as a standalone Series so `grouped`
# itself is never mutated (a temporary column used to be added and dropped again).
condition_type = (
    grouped['condition'].astype(str) + "_"
    + grouped['type'].astype(str) + "_"
    + grouped['chr_type'].astype(str)
)

# One box-plot panel per metric.
metrics = ['percent_MAD', 'percent_NUC', 'percent_OTHER']
fig = make_subplots(rows=1, cols=3, subplot_titles=('Percent MAD', 'Percent NUC', 'Percent OTHER'))
for col, metric in enumerate(metrics, start=1):
    fig.add_trace(
        go.Box(y=grouped[metric], x=condition_type, name=metric, legendgroup=metric),
        row=1, col=col,
    )

fig.update_layout(
    title='Box Plots for Metrics by Condition and Type',
    template='plotly_white'
)

#fig.show()


def _pooled_int_columns(frame):
    """Stack every integer-named column of `frame` into one Series, dropping NaN.

    Returns an empty float Series when `frame` has no integer-named columns.
    """
    cols = [c for c in frame.columns if str(c).isdigit()]
    if not cols:
        return pd.Series(dtype=float)
    return pd.concat([frame[c] for c in cols]).dropna()


# Integer-named columns of the LINK table (kept as a global for later cells).
# NOTE(review): mod_qual_LINK / mod_qual_NUC presumably hold per-position modification
# qualities in integer-named columns — TODO confirm.
int_columns = [col for col in mod_qual_LINK.columns if str(col).isdigit()]

# Each table contributes its own integer columns, so NUC is not indexed with LINK-only names.
all_values_LINK = _pooled_int_columns(mod_qual_LINK)
all_values_NUC = _pooled_int_columns(mod_qual_NUC)

# Left: value distributions; right: number of pooled values per table.
fig = make_subplots(rows=1, cols=2, subplot_titles=("Boxplots", "Counts"))

fig.add_trace(go.Box(y=all_values_LINK, name='LINK', marker_color='blue'), row=1, col=1)
fig.add_trace(go.Box(y=all_values_NUC, name='NUC', marker_color='red'), row=1, col=1)

fig.add_trace(
    go.Bar(
        x=['LINK', 'NUC'],
        y=[len(all_values_LINK), len(all_values_NUC)],
        name='Counts',
        marker_color=['blue', 'red']
    ),
    row=1, col=2
)

fig.update_layout(
    title='Box Plots for NUCs and LINKs',
    template='plotly_white'
)
fig.show()

In [None]:
### Plot n - n+1 nucleosome positioning variance
import plotly.graph_objects as go

# combined_res_df: presumably the per-nucleosome results from calculate_feature for all
# combinations — TODO confirm against the producing cell.
df = combined_res_df.copy()

# One box trace per (condition, chr_type, type) combination.
data = []
for condition, chr_type, dtype in df[['condition', 'chr_type', 'type']].drop_duplicates().values:
    subset_df = df[(df['condition'] == condition) & (df['chr_type'] == chr_type) & (df['type'] == dtype)]

    subs_list = []
    nuc_ids = []
    # A single groupby replaces re-filtering subset_df for every nuc_id; sort=False keeps the
    # first-appearance order that .unique() gave.
    for nuc_id, nuc_rows in subset_df.groupby('nuc_id', sort=False):
        individual_subs_list = nuc_rows['subs_list'].explode().dropna()
        subs_list.extend(individual_subs_list)
        nuc_ids.extend([nuc_id] * len(individual_subs_list))

    trace_name = f"{condition}_{chr_type}_{dtype}"
    # x is nuc_id; offsetgroup places the combinations side by side within each nuc_id.
    data.append(go.Box(y=subs_list, name=trace_name, x=nuc_ids, offsetgroup=trace_name))

layout = go.Layout(
    title="Boxplot of subs_list for each nuc_id",
    xaxis_title="Nucleosome ID",
    yaxis_title="Subs List Value",
    template='plotly_white',
    boxmode='group'
)

fig = go.Figure(data=data, layout=layout)
fig.show()


In [None]:
### Plot single fiber nucleosome positioning and distribution
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

# Nucleosomes whose |mean_nuc_pos| exceeds this cutoff are excluded.
MAX_ABS_NUC_POS = 750

# combined_res_df: presumably the per-nucleosome results from calculate_feature for all
# combinations — TODO confirm against the producing cell.
df = combined_res_df.copy()
df = df[df['mean_nuc_pos'].abs() <= MAX_ABS_NUC_POS]

# Row 1: per-nucleosome offset boxes; row 2: percent occupancy bars; both share the x-axis.
fig = make_subplots(rows=2, cols=1, shared_xaxes=True)

# One palette color per combination. The old f'rgba({50+idx*500},...)' put the red channel above
# 255 for every idx >= 1, so all traces after the first came out nearly the same color.
palette = px.colors.qualitative.Plotly
color_map = {}

for idx, comb in enumerate(df[['condition', 'chr_type', 'type']].drop_duplicates().values):
    condition, chr_type, dtype = comb
    subset_df = df[(df['condition'] == condition) & (df['chr_type'] == chr_type) & (df['type'] == dtype)]

    subs_list = []
    mean_nuc_pos_list = []
    percent_occ_list = []

    # sort=False keeps nuc_ids in first-appearance order, as .unique() did; NaN nuc_ids are
    # skipped instead of failing on an empty selection.
    for nuc_id, nuc_subset in subset_df.groupby('nuc_id', sort=False):
        individual_subs_list = nuc_subset['subs_list'].explode().dropna()
        subs_list.extend(individual_subs_list)

        # Each nucleosome's offsets are drawn at its (first) mean position.
        mean_pos = nuc_subset['mean_nuc_pos'].unique()[0]
        #mean_pos = nuc_subset['std_nuc_pos'].unique()[0]
        mean_nuc_pos_list.extend([mean_pos] * len(individual_subs_list))
        percent_occ_list.append((mean_pos, nuc_subset['percent_occ'].unique()[0]))

    unique_color = palette[idx % len(palette)]
    color_map[tuple(comb)] = unique_color  # numpy row -> hashable tuple

    trace_name = f"{condition}_{chr_type}_{dtype}"
    fig.add_trace(
        go.Box(y=subs_list, name=trace_name, x=mean_nuc_pos_list,
               width=10, line=dict(color=unique_color), opacity=0.8),
        row=1, col=1,
    )

    if percent_occ_list:
        bar_x, bar_y = zip(*percent_occ_list)
        fig.add_trace(
            go.Bar(x=bar_x, y=bar_y, name=f"{trace_name}_percent_occ",
                   marker=dict(color=unique_color), width=10, opacity=0.8),
            row=2, col=1,
        )

fig.update_layout(
    title="Boxplot of subs_list and Barplot of percent_occ",
    xaxis_title="Mean Nucleosome Position",
    yaxis_title="Nucleosome offset",
    yaxis2_title="Percent Occupancy",
    template='plotly_white'
)

# Dashed vertical reference line at x=0 in both panels.
fig.add_shape(type="line", x0=0, y0=0, x1=0, y1=0.5, line=dict(color="Black", width=1, dash="dash"), row=2, col=1)
fig.add_shape(type="line", x0=0, y0=0, x1=0, y1=120, line=dict(color="Black", width=1, dash="dash"), row=1, col=1)

# x-axis tick marks every 250 bp.
fig.update_xaxes(dtick=250, row=1, col=1)
fig.update_xaxes(dtick=250, row=2, col=1)

fig.show()


In [None]:
# reimport nanotools
importlib.reload(nanotools)
import plotly.graph_objects as go
import pandas as pd
from plotly.subplots import make_subplots

# Region type analysed throughout this cell.
REGION_TYPE = "strong_rex"
colors = {"SDC2_degron_fiber": 'red', "N2_fiber": 'blue', "X": 'red', "Autosome": 'blue'}


def _add_p_values(figure, subplot_ids):
    """Best-effort p-value annotation (groups 0 vs 1) on each subplot id in `subplot_ids`.

    A failure on one subplot (e.g. it has fewer than two groups) is reported and that subplot is
    left unannotated, instead of being silently swallowed by a bare ``except``.
    """
    for subplot_id in subplot_ids:
        try:
            figure = nanotools.add_p_value_annotation(figure, [[0, 1]], subplot_id)
        except Exception as err:
            print(f"p-value annotation skipped for subplot {subplot_id}: {err!r}")
    return figure


### 1) Inter-nucleosome distance (N+1 - N) per condition, one column per chr_type
grouped_clean = grouped.copy(deep=True)
grouped_clean = grouped_clean.dropna(subset=['inter_nuc_dist'])
grouped_clean = grouped_clean[grouped_clean['type'] == REGION_TYPE]
# Sorting fixes the order in which conditions / chr_types are first seen, i.e. plotted.
grouped_clean = grouped_clean.sort_values(by=['condition', 'chr_type'])

chr_type_order = grouped_clean['chr_type'].unique().tolist()
condition_order = grouped_clean['condition'].unique().tolist()

fig = make_subplots(rows=1, cols=2, subplot_titles=chr_type_order)
for col, type_value in enumerate(chr_type_order, start=1):
    for x_pos, condition_value in enumerate(condition_order):
        df_filtered = grouped_clean[(grouped_clean['condition'] == condition_value) & (grouped_clean['chr_type'] == type_value)]
        mean_dist = df_filtered['inter_nuc_dist'].mean()
        fig.add_trace(go.Box(
            y=df_filtered['inter_nuc_dist'],
            name=condition_value,
            marker_color=colors.get(condition_value, 'black'),  # black if condition not in colors
            fillcolor='rgba(0,0,0,0)',
            boxmean=True
        ), row=1, col=col)
        # Label the mean just above the box's mean marker.
        fig.add_annotation(
            x=x_pos,
            y=mean_dist,
            text=f"{mean_dist:.2f}",
            showarrow=False,
            yshift=8,
            font=dict(size=11, color="black"),
            row=1, col=col
        )

fig.update_layout(
    title="Nucleosome N+1 - N Distance by Condition and Type",
    yaxis_title="Inter-nucleosome Distance (bp)",
    template="plotly_white",
    width=600
)

# y-range taken from the values actually plotted (raw inter_nuc_dist). It used to be computed on
# inter_nuc_dist - NUC_width while the traces plot the raw values, shifting the window off the data.
dist_range = [grouped_clean['inter_nuc_dist'].quantile(0.05) * 0.8,
              grouped_clean['inter_nuc_dist'].quantile(0.95) * 1.2]
fig.update_yaxes(range=dist_range, row=1, col=1)
fig.update_yaxes(range=dist_range, row=1, col=2)
fig = _add_p_values(fig, [1, 2])


### 2) Largest nucleosome-free region per read
# NOTE: this also removes reads without largest_nfr_size from grouped_clean for plots 3 and 4.
grouped_clean = grouped_clean.dropna(subset=['largest_nfr_size'])
nfr_chr_types = grouped_clean['chr_type'].unique().tolist()
nfr_conditions = grouped_clean['condition'].unique().tolist()

# Titles and column indices come from the same filtered frame, so they cannot disagree.
fig0 = make_subplots(rows=1, cols=2, subplot_titles=nfr_chr_types)
for col, type_value in enumerate(nfr_chr_types, start=1):
    for condition_value in nfr_conditions:
        df_filtered = grouped_clean[(grouped_clean['condition'] == condition_value) & (grouped_clean['chr_type'] == type_value)]
        fig0.add_trace(go.Box(
            y=df_filtered['largest_nfr_size'],
            name=condition_value,
            marker_color=colors.get(condition_value, 'black'),
            fillcolor='rgba(0,0,0,0)'
        ), row=1, col=col)

fig0.update_layout(
    title="Largest NFR Distance by Condition and Type",
    yaxis_title="Largest NFR size (bp)",
    template="plotly_white",
    width=600
)

nfr_range = [0, grouped_clean['largest_nfr_size'].quantile(0.95) * 1.1]
fig0.update_yaxes(range=nfr_range, row=1, col=1)
fig0.update_yaxes(range=nfr_range, row=1, col=2)
fig0 = _add_p_values(fig0, [1, 2])


### 3) Per-fiber N(i+1) - N(i) values, one row per value
grouped_exploded = grouped_clean.explode('inter_nuc_sub')
# Empty lists explode to NaN; drop them.
grouped_exploded = grouped_exploded.dropna(subset=['inter_nuc_sub'])
exploded_chr_types = grouped_exploded['chr_type'].unique().tolist()

# Columns are chr_types, so the titles are too (they used to be condition names).
fig2 = make_subplots(rows=1, cols=2, subplot_titles=exploded_chr_types)

# Conditions in reverse first-seen order.
conditions_selec = grouped_exploded['condition'].unique()[::-1]

for col, type_value in enumerate(exploded_chr_types, start=1):
    for condition in conditions_selec:
        df_filtered = grouped_exploded[(grouped_exploded['condition'] == condition) & (grouped_exploded['chr_type'] == type_value)]
        fig2.add_trace(go.Box(
            y=df_filtered['inter_nuc_sub'],
            name=condition,
            marker_color=colors.get(condition, 'black'),
            fillcolor='rgba(0,0,0,0)'
        ), row=1, col=col)

fig2.update_layout(
    title="Per-fiber nucleosome subtraction values",
    yaxis_title="Ni+1 - Ni (bp)",
    template="plotly_white",
    width=600
)

fig2.update_yaxes(range=[120, 250], row=1, col=1)
fig2.update_yaxes(range=[120, 250], row=1, col=2)
fig2 = _add_p_values(fig2, [1, 2])


### 4) percent_NUC (row 1) and percent_MAD (row 2), one column per condition
nanotools.display_sample_rows(grouped_clean)
melted_df = grouped_clean.melt(id_vars=['condition', 'chr_type'], value_vars=['percent_NUC', 'percent_MAD'], var_name='stat', value_name='percent')
melted_df = melted_df.dropna(subset=['percent']).reset_index(drop=True)
conditions_list = melted_df['condition'].unique()
condition_cols = conditions_list.tolist()
fig3 = make_subplots(rows=2, cols=2, subplot_titles=condition_cols, shared_yaxes=True)

for each_cond in conditions_list:
    col = condition_cols.index(each_cond) + 1
    for each_type in melted_df['chr_type'].unique():
        df_filtered = melted_df[(melted_df['condition'] == each_cond) & (melted_df['chr_type'] == each_type)]
        for row, stat in ((1, 'percent_NUC'), (2, 'percent_MAD')):
            fig3.add_trace(go.Box(y=df_filtered[df_filtered['stat'] == stat]['percent'],
                                  name=each_type,
                                  marker_color=colors.get(each_cond, 'black'),
                                  fillcolor='rgba(0,0,0,0)'  # transparent fill
                                  ), row=row, col=col)

# Left-column y-axes show percentages.
fig3.update_yaxes(title_text="Percentage", tickformat='.0%', row=1, col=1)
fig3.update_yaxes(title_text="Percentage", tickformat='.0%', row=2, col=1)

fig3.update_layout(
    title="Boxplot of % Nucleosome and % NFR",
    yaxis_title="% Occupied",
    template="plotly_white",
    width=600,
    height=800
)

# Row 1: 20%-110%; row 2: 0%-95%.
fig3.update_yaxes(range=[0.2, 1.1], row=1, col=1)
fig3.update_yaxes(range=[0.2, 1.1], row=1, col=2)
fig3.update_yaxes(range=[0, 0.95], row=2, col=1)
fig3.update_yaxes(range=[0, 0.95], row=2, col=2)
fig3 = _add_p_values(fig3, [1, 2, 3, 4])

# Show plots
fig.show()
fig0.show()
fig2.show()
fig3.show()
# write_image fails when the output directory is missing, so create it first.
os.makedirs("images_11_14_23", exist_ok=True)
fig.write_image("images_11_14_23/inter_nuc_dist_boxplot_strong_rex.png")
fig.write_image("images_11_14_23/inter_nuc_dist_boxplot_strong_rex.svg")

In [None]:
### Inter-nucleosome distance distributions for strong_rex regions
grouped_clean = grouped.copy(deep=True)
grouped_clean = grouped_clean.dropna(subset=['inter_nuc_dist'])
# keep only strong_rex regions
grouped_clean = grouped_clean[grouped_clean['type'].isin(["strong_rex"])]
#grouped_clean = grouped_clean[grouped_clean['chr_type'] == "Autosome"]
nanotools.display_sample_rows(grouped_clean)

colors = {"SDC2_degron_fiber": 'red', "N2_fiber": 'blue', "X": 'red', "Autosome": 'blue'}

### Plot dist of internuc dist
bin_width = 1
# Edges (bp, after subtracting NUC_width) from 147 - NUC_width up to, but excluding, 700.
bin_range = range(147-NUC_width, 700, bin_width)
# x positions: every edge except the first, shifted by half a bin (same for every group).
bin_centers = np.array(bin_range[1:]) + (bin_width / 2)
# A read counts as "open" when its raw inter_nuc_dist is at least this many bp.
OPEN_INTER_NUC_DIST = 190


def _fraction_at_least(group_df):
    """Fraction of reads whose adjusted inter_nuc_dist is >= each edge in bin_range[1:].

    inter_nuc_dist is clipped to [147, 600 + NUC_width] and then shifted down by NUC_width.
    The first edge is dropped so the result lines up with `bin_centers`.
    """
    # A local Series instead of a new column on the groupby slice (which raised
    # SettingWithCopyWarning).
    adjusted = group_df['inter_nuc_dist'].clip(lower=147, upper=600 + NUC_width) - NUC_width
    bin_counts = [(adjusted >= bin_edge).sum() for bin_edge in bin_range]
    return np.array(bin_counts[1:]) / group_df.shape[0]


grouped_clean_grouped = grouped_clean.groupby(['condition', 'type', 'chr_type'])

# One curve per (condition, type, chr_type); curves are kept for the difference plot below.
cumulative_percentage_plot_data = []
group_curves = []
for group_name, group_df in grouped_clean_grouped:
    bin_percentages = _fraction_at_least(group_df)
    group_curves.append(bin_percentages)
    cumulative_percentage_plot_data.append(go.Scatter(x=bin_centers,
                                                      y=bin_percentages,
                                                      mode='lines',
                                                      name=str(group_name[0]+' '+group_name[2]+' '+group_name[1])))

cumulative_percentage_layout = go.Layout(
    title='Cumulative Percentage Distribution of inter_nuc_dist',
    xaxis=dict(title='Inter-nucleosome distance (bp)'),
    yaxis=dict(title='Cumulative Percentage of Reads (%)'),
    template='plotly_white',
    width=600
)
cumulative_percentage_layout['yaxis']['tickformat'] = '.0%'
cumulative_percentage_fig = go.Figure(data=cumulative_percentage_plot_data, layout=cumulative_percentage_layout)

# One group per region (bed_start, chrom) and condition/type/chr_type.
grouped_genes = grouped_clean.groupby(['bed_start', 'chrom', 'condition', 'type', 'chr_type'])

### Plot difference plot: first minus second group, in groupby order
if len(group_curves) >= 2:
    first_group_data, second_group_data = group_curves[0], group_curves[1]
    difference = first_group_data - second_group_data
else:
    # Not enough groups to compare: plot an empty trace instead of failing on None - None.
    print(f"Difference plot is empty: need at least 2 groups, found {len(group_curves)}.")
    first_group_data = group_curves[0] if group_curves else None
    second_group_data = None
    difference = np.array([])

difference_plot_data = go.Scatter(x=bin_centers, y=difference, mode='lines', name='Difference between first and second group')
difference_layout = go.Layout(
    title='Difference in Cumulative Percentage Distribution between First and Second Group',
    xaxis=dict(title='Inter-nucleosome distance (bp)'),
    yaxis=dict(title='Difference in Cumulative Percentage (%)'),
    template='plotly_white',
    width=600,
    yaxis_tickformat='.0%'
)
difference_fig = go.Figure(data=[difference_plot_data], layout=difference_layout)

### Per-region fraction of open reads
box_plot_data = []
for group_name, group_df in grouped_genes:
    percent_open_reads = (group_df['inter_nuc_dist'] >= OPEN_INTER_NUC_DIST).sum() / group_df.shape[0]
    condition, chr_type = group_name[2], group_name[4]
    box_plot_data.append({'condition': condition, 'chr_type': chr_type, 'percent_open_reads': percent_open_reads})

box_plot_df = pd.DataFrame(box_plot_data)
nanotools.display_sample_rows(box_plot_df)

box_fig = go.Figure()
chr_types = box_plot_df['chr_type'].unique()
for chr_type in chr_types:
    filtered_df = box_plot_df[box_plot_df['chr_type'] == chr_type]
    box_fig.add_trace(go.Box(y=filtered_df['percent_open_reads'],
                             x=filtered_df['condition'],
                             name=chr_type,
                             boxpoints='all',
                             jitter=0.3))

box_fig.update_layout(
    title='Percentage of Reads in Open Configuration by Condition and Chr_Type',
    xaxis_title='Condition',
    yaxis_title='Percentage of Reads in Open Configuration (%)',
    template='plotly_white',
    boxmode='group',
    width=300
)

### Distribution of the N+1 nucleosome position, one box per condition
n1_fig = go.Figure()
# NOTE: despite its name, chr_types now holds the condition values (global name kept as before).
chr_types = grouped_clean['condition'].unique()
for condition_name in chr_types:
    filtered_df = grouped_clean[grouped_clean['condition'] == condition_name]
    n1_fig.add_trace(go.Box(y=filtered_df['smallest_positive_nuc_midpoint'],
                            name=condition_name,
                            marker_color=colors.get(condition_name, 'black'),
                            fillcolor='rgba(0,0,0,0)'))  # transparent fill

n1_fig.update_layout(
    title="Distribution of N+1 nucleosome position",
    xaxis_title="Type",
    yaxis_title="Smallest Positive Nucleotide Midpoint",
    template="plotly_white",
    width=300,
    showlegend=False
)

# Variance of the N+1 position per condition. `conditions` is also read by a later cell.
print("\nVariance of 'smallest_positive_nuc_midpoint' by 'condition':")
conditions = grouped_clean['condition'].unique()
for condition in conditions:
    subset_df = grouped_clean[grouped_clean['condition'] == condition]
    variance = subset_df['smallest_positive_nuc_midpoint'].var()
    print(f"Condition: {condition}, Variance: {variance:.2f}")

cumulative_percentage_fig.show()
difference_fig.show()
box_fig.show()
n1_fig.show()

# write_image fails when the output directory is missing, so create it first.
os.makedirs("images_11_14_23", exist_ok=True)
cumulative_percentage_fig.write_image("images_11_14_23/cumulative_percentage_dist_strong_rex.png")
cumulative_percentage_fig.write_image("images_11_14_23/cumulative_percentage_dist_strong_rex.svg")
difference_fig.write_image("images_11_14_23/difference_cumulative_percentage_dist_strong_rex.png")
difference_fig.write_image("images_11_14_23/difference_cumulative_percentage_dist_strong_rex.svg")


In [None]:
### Plot accessibility pileups based on single fiber alignments
print("plot_df")
display(plot_df.head(3))
print("grouped")
display(grouped.head(3))
print("coverage_df")
display(coverage_df.head(3))

# Attach per-read nucleosome landmarks from grouped to every plot_df row (left join on read_id).
print("Merging plot_df and grouped...")
merged_df_access = pd.merge(plot_df, grouped[['read_id', 'closest_nuc', 'closest_MAD_midpoint', 'smallest_positive_nuc_midpoint', 'greatest_negative_nuc_midpoint']], on='read_id', how='left')
# Disabled re-alignment options:
#merged_df_access.dropna(subset=['closest_nuc'], inplace=True)
#merged_df_access['rel_pos'] -= merged_df_access['closest_MAD_midpoint']
#merged_df_access['rel_pos'] -= ((merged_df_access['smallest_positive_nuc_midpoint']-merged_df_access['greatest_negative_nuc_midpoint'])/2+merged_df_access['greatest_negative_nuc_midpoint'])

# Per (condition, type, chr_type, rel_pos): fraction of calls with mod_qual_bin set.
group_merge = merged_df_access.groupby(['condition', 'type', 'chr_type', 'rel_pos'])['mod_qual_bin'].agg(['sum', 'count']).reset_index()
group_merge['raw_meth_frac'] = group_merge['sum'] / group_merge['count']

# Genome-wide m6A fraction per condition, used as the normaliser. Columns are selected with a
# list: tuple-style selection on a groupby was removed in pandas 2.0.
coverage_group = coverage_df.groupby(['condition'])[['total_m6a', 'total_A_m6a']].sum().reset_index()
coverage_group['condition_m6A_frac'] = coverage_group['total_m6a'] / coverage_group['total_A_m6a']

group_merge = pd.merge(group_merge, coverage_group[['condition', 'condition_m6A_frac']], on='condition', how='left')
group_merge['weighted_norm_mod_frac'] = group_merge['raw_meth_frac'] / group_merge['condition_m6A_frac']

# The plotting function expects the position column to be called rel_start.
group_merge.rename(columns={'rel_pos': 'rel_start'}, inplace=True)
# .sample(n) raises when n exceeds the row count.
display(group_merge.sample(n=min(10, len(group_merge))))

# NOTE(review): `conditions` and `plot_bedmethyl` are not defined in this cell; `conditions` is
# whatever an earlier cell last assigned — confirm it holds the intended condition list.
region_fig = plot_bedmethyl(group_merge, conditions, chr_types=["X"], types=["strong_rex"], strands=["all"], window_size=50, selection_indices=[0,1,2], bed_window=1500)

region_fig[0].show()
#region_fig[0].write_image("images/dpy27_sdc2_sdc3_aligned_nearest_nuc_n2vsSDC2.svg")
#region_fig[0].write_image("images/dpy27_sdc2_sdc3_aligned_nearest_nuc_n2vsSDC2.png",width=1600,height=1300)

In [None]:
## Downsampling for read plotting
n_read_ids = 100000  # max reads / condition for plotting


def downsample_group(group, max_reads=None, window=None, random_state=None):
    """Down-sample one group to at most `max_reads` unique read_ids.

    Reads spanning more than 3/4 of `window` (rel_read_end - rel_read_start) are preferred. If
    there are more such reads than `max_reads`, a random subset of them is kept. Otherwise all of
    them are kept and the rest is filled with randomly chosen other reads.

    Parameters
    ----------
    group : pd.DataFrame
        Rows of one group with 'read_id', 'rel_read_start' and 'rel_read_end' columns. Its
        `.name` (set by ``groupby().apply``) is printed when present.
    max_reads : int, optional
        Cap on unique read_ids; defaults to the module-level ``n_read_ids``.
    window : int or float, optional
        Window length for the long-read filter; defaults to the module-level ``bed_window``.
    random_state : int or numpy.random.Generator, optional
        Forwarded to ``Series.sample`` so the down-sampling can be reproduced. With None
        (default) sampling is unseeded, as before.

    Returns
    -------
    pd.DataFrame
        `group` itself when it has at most `max_reads` read_ids. Otherwise, every row whose
        read_id was sampled.
    """
    if max_reads is None:
        max_reads = n_read_ids
    print("\nProcessing group:", getattr(group, 'name', None))
    unique_read_ids = group['read_id'].unique()

    if len(unique_read_ids) <= max_reads:
        return group

    if window is None:
        window = bed_window
    long_read_ids = group[
        (group['rel_read_end'] - group['rel_read_start']) > (3/4 * window)
    ]['read_id'].unique()

    if len(long_read_ids) > max_reads:
        sampled_read_ids = pd.Series(long_read_ids).sample(n=max_reads, random_state=random_state).tolist()
    else:
        # Not enough long reads: keep them all and top up with other reads.
        sampled_read_ids = long_read_ids.tolist()
        remaining_ids = group[~group['read_id'].isin(sampled_read_ids)]['read_id'].unique()
        sampled_read_ids.extend(
            pd.Series(remaining_ids)
            .sample(n=max_reads - len(sampled_read_ids), random_state=random_state)
            .tolist()
        )

    return group[group['read_id'].isin(sampled_read_ids)]

def process_bam_files(bam_exp_pairs):
    """Collect one record per read yielded by ``fetch()`` for each BAM file.

    Parameters
    ----------
    bam_exp_pairs : iterable of (str, object)
        (BAM file path, experiment ID) pairs.

    Returns
    -------
    list of dict
        ``{"bam_file_name", "read_id", "exp_id"}`` records, in file then read order.
    """
    # `import pysam` is commented out in the notebook's imports cell; importing here makes the
    # function self-contained, including inside multiprocessing workers.
    import pysam

    data = []
    for bam_file, exp_id in bam_exp_pairs:
        with pysam.AlignmentFile(bam_file, "rb") as bam:
            # NOTE(review): per pysam docs, fetch() with no region needs a BAM index and yields
            # mapped reads only (until_eof=True would also return unmapped reads).
            data.extend(
                {"bam_file_name": bam_file, "read_id": read.query_name, "exp_id": exp_id}
                for read in bam.fetch()
            )
    return data

def parallel_process_bam_files(bam_files, exp_ids, num_processes=None):
    """Read (bam_file_name, read_id, exp_id) records from several BAM files in parallel.

    Parameters
    ----------
    bam_files : iterable of str
        BAM paths; the i-th file is labelled with the i-th entry of ``exp_ids``.
    exp_ids : iterable
        Experiment IDs, same length as ``bam_files``.
    num_processes : int, optional
        Worker count. Defaults to ``multiprocessing.cpu_count()`` and is clamped to
        [1, number of BAM files].

    Returns
    -------
    list of dict
        The records produced by ``process_bam_files``, in input-file order
        (``Pool.map`` preserves chunk order). Empty when no BAM files are given.

    Raises
    ------
    ValueError
        If ``bam_files`` and ``exp_ids`` differ in length. ``zip`` would otherwise
        silently drop the unmatched files.
    """
    bam_files = list(bam_files)
    exp_ids = list(exp_ids)
    if len(bam_files) != len(exp_ids):
        raise ValueError(f"Got {len(bam_files)} BAM files but {len(exp_ids)} exp_ids")

    paired_bam_exp = list(zip(bam_files, exp_ids))
    if not paired_bam_exp:
        # Nothing to read; this also avoids Pool(0) and a division by zero below.
        return []

    if num_processes is None:
        num_processes = multiprocessing.cpu_count()
    num_processes = max(1, min(num_processes, len(paired_bam_exp)))

    # Contiguous chunks of at least one (bam, exp_id) pair; Pool.map spreads them
    # over the workers.
    chunksize = max(1, len(paired_bam_exp) // num_processes)
    chunks = [paired_bam_exp[i:i + chunksize] for i in range(0, len(paired_bam_exp), chunksize)]

    with multiprocessing.Pool(num_processes) as pool:
        results = pool.map(process_bam_files, chunks)

    return [record for chunk_records in results for record in chunk_records]

# Downsample every (condition, chr_type, type) group to at most n_read_ids reads.
merged_df = plot_df.copy(deep=True)
down_sampled_plot_df = (
    merged_df.groupby(['condition', 'chr_type', 'type'])
    .apply(downsample_group)
    .reset_index(drop=True)
)
print("Number of groups:", len(down_sampled_plot_df.groupby(['condition', 'chr_type', 'type'])))

# Preferred path: order reads by their nearest nucleosome midpoints and pick the matching
# rows of `grouped`. If the nucleosome columns are missing (KeyError) or `grouped` is not
# defined (NameError), fall back to annotating reads with their experiment ID instead.
try:
    down_sampled_plot_df = down_sampled_plot_df.sort_values(
        by=['smallest_positive_nuc_midpoint', 'greatest_negative_nuc_midpoint'],
        ascending=[True, True],
    ).reset_index(drop=True)
    down_sampled_group_df = grouped[grouped['read_id'].isin(down_sampled_plot_df['read_id'])]
except (KeyError, NameError):
    # Only the expected failure modes are caught; anything else should surface.
    print("Missing nucs, proceeding without...")
    if 'exp_id' not in merged_df.columns:
        print("Missing exp_id column, adding it...")
        data = parallel_process_bam_files(new_bam_files, exp_ids)
        read_exp_ids_df = pd.DataFrame(data, columns=['bam_file_name', 'read_id', 'exp_id'])
        # One record exists per alignment, so a read can repeat. Keep a single exp_id per
        # read so the left merge below does not duplicate plot rows.
        read_exp_ids_df = read_exp_ids_df.drop_duplicates(subset=['read_id'])
        down_sampled_plot_df = pd.merge(down_sampled_plot_df, read_exp_ids_df[['read_id', 'exp_id']], on='read_id', how='left')
    # Otherwise exp_id is already present and was carried through downsampling. Merging
    # again would need a read_exp_ids_df that does not exist here, and would create
    # exp_id_x/exp_id_y columns.

print("Number of unique reads in down_sampled_plot_df:", len(down_sampled_plot_df['read_id'].unique()))
nanotools.display_sample_rows(down_sampled_plot_df)


In [None]:
def _has_nearby_call(calls, max_distance=10):
    """Mark rows that have at least one *other* row of the same read within
    ``max_distance`` (inclusive) in ``rel_pos``.

    Within each read, rows are sorted by position. A row has a neighbour exactly when
    the gap to its previous or next row is <= max_distance. This gives the same result
    as pairwise counting in O(n log n). Returns a boolean Series aligned to
    ``calls.index``; assumes that index is unique.
    """
    ordered = calls.sort_values(['read_id', 'rel_pos'])
    by_read = ordered.groupby('read_id')['rel_pos']
    gap_prev = by_read.diff()
    gap_next = by_read.diff(-1).abs()
    has_neighbor = (gap_prev <= max_distance) | (gap_next <= max_distance)
    return has_neighbor.reindex(calls.index, fill_value=False)


def _smooth_excluding_anchor(profile, column, window, min_periods=None):
    """Add ``<column>_smoothed`` to ``profile`` in place, leaving offset 0 out of the mean.

    The raw ``column`` value at ``rel_pos == 0`` is replaced by NaN before the centred
    rolling mean. The smoothed value at offset 0 is the original raw value. If no row
    sits at offset 0, the column is smoothed as-is. With ``min_periods=None`` (pandas
    default = window), every window that covers the NaN at offset 0 is NaN.
    """
    raw = profile[column].copy()
    at_anchor = profile['rel_pos'] == 0
    masked = raw.mask(at_anchor)
    smoothed = masked.rolling(window=window, center=True, min_periods=min_periods).mean().where(~at_anchor, raw)
    profile[column] = masked
    profile[f'{column}_smoothed'] = smoothed


def process_dataframe(down_sampled_plot_df, min_mod_qual, selec_conds, selec_type, selected_chr_type, num_read_ids, rel_pos_window, smoothing_window, smoothing_min_periods=None):
    """Align reads on a 5mC-derived anchor and profile m6A / 5mC around it.

    For each condition in ``selec_conds`` (restricted to ``type == selec_type`` and
    ``chr_type == selected_chr_type``):

      1. Keep up to ``num_read_ids`` reads. Within them, keep the 5mC candidate calls
         (mod_code 'm', query_kmer[1] in {'C', 'G'}, mod_qual > min_mod_qual) that
         have at least one other candidate on the same read within 10 bp.
      2. Anchor each read at the rounded midpoint of its first and last retained
         candidate. Re-express every row of that read within ``rel_pos_window`` of
         the anchor as an offset from it (new ``rel_pos``).
      3. Per offset, count A rows, C rows with query_kmer[1] in {'C', 'G'} ('GC_CC'),
         and the subsets with mod_qual > min_mod_qual ('m6A', '5mC'). Derive the
         fractions and smooth m6A_frac / 5mC_frac with a centred rolling mean of
         ``smoothing_window`` rows, leaving offset 0 out.

    Two plotly figures are then shown for all conditions: the smoothed fractions
    (offset 0 omitted) and the raw A / GC_CC counts. Conditions with no qualifying
    reads are reported and skipped.

    Parameters
    ----------
    down_sampled_plot_df : pd.DataFrame
        Per-call rows; uses condition, type, chr_type, read_id, rel_pos, mod_code,
        query_kmer, mod_qual and canonical_base.
    min_mod_qual : float
        Threshold on mod_qual for candidate and "modified" calls.
    selec_conds : list of str
        Conditions to process, in plotting order.
    selec_type, selected_chr_type : str
        Values of the 'type' and 'chr_type' columns to keep.
    num_read_ids : int or None
        Maximum reads per condition (first by order of appearance). None uses all reads.
    rel_pos_window : int
        Half-width of the window kept around each anchor.
    smoothing_window : int
        Rolling-mean window in rows, i.e. distinct offsets.
    smoothing_min_periods : int, optional
        Forwarded to ``Series.rolling``. None (default) keeps pandas' default.

    Returns
    -------
    pd.DataFrame
        Aligned rows of all conditions. ``read_id`` carries a ``_<k>`` suffix, where k
        is the 1-based position of the condition in ``selec_conds``.
    """
    profiles = []        # (condition, per-offset profile) pairs, in plotting order
    aligned_frames = []  # aligned rows per condition, concatenated into the return value

    for read_id_suffix, selec_cond in enumerate(selec_conds, start=1):
        down_sampled_plot_df_filtered = down_sampled_plot_df[
            (down_sampled_plot_df['condition'] == selec_cond)
            & (down_sampled_plot_df['type'] == selec_type)
            & (down_sampled_plot_df['chr_type'] == selected_chr_type)
        ]

        # Limit the read_IDs if requested. Previously num_read_ids=None left
        # mC_aligned_df undefined, or stale from the previous condition.
        if num_read_ids is not None:
            read_ids_to_keep = down_sampled_plot_df_filtered['read_id'].unique()[:num_read_ids]
            mC_aligned_df = down_sampled_plot_df_filtered[down_sampled_plot_df_filtered['read_id'].isin(read_ids_to_keep)]
        else:
            mC_aligned_df = down_sampled_plot_df_filtered

        # 1. Candidate 5mC calls.
        print(f"*** Filtering for 5mC candidates in condition {selec_cond} ***")
        mC_aligned_df = mC_aligned_df[
            (mC_aligned_df['mod_code'] == 'm') &
            (mC_aligned_df['query_kmer'].str[1].isin(['C', 'G'])) &
            (mC_aligned_df['mod_qual'] > min_mod_qual)
        ].copy()
        nanotools.display_sample_rows(mC_aligned_df)

        # Keep candidates with at least one other candidate on the same read within 10 bp.
        # NOTE(review): the previous row-wise version also tried to exclude reads with
        # nearby m6A ('a') calls. It searched this frame, which holds only 'm' rows, so
        # that clause never removed anything; the effective behaviour is kept. The log
        # text says "three" but the threshold is one neighbour — TODO confirm intent.
        print(f"*** Filtering for rows with three nearby rows meeting the same condition in condition {selec_cond} ***")
        mC_aligned_df = mC_aligned_df[_has_nearby_call(mC_aligned_df, max_distance=10)].copy()
        nanotools.display_sample_rows(mC_aligned_df)
        if mC_aligned_df.empty:
            print(f"No qualifying 5mC anchor reads for condition {selec_cond}; skipping.")
            continue

        # 2. One anchor per read: rounded midpoint between its first and last retained call.
        print(f"*** Shifting to midpoint and skipping clustered rows for condition {selec_cond} ***")
        by_read = mC_aligned_df.groupby('read_id')['rel_pos']
        midpoint = ((by_read.transform('min') + by_read.transform('max')) / 2).round()
        mC_aligned_df = mC_aligned_df.assign(rel_pos=midpoint).drop_duplicates(subset=['read_id'], keep='first')
        nanotools.display_sample_rows(mC_aligned_df)

        print(f"*** Aligning reads efficiently for condition {selec_cond} ***")
        mC_aligned_df['rel_pos_lower'] = mC_aligned_df['rel_pos'] - rel_pos_window
        mC_aligned_df['rel_pos_upper'] = mC_aligned_df['rel_pos'] + rel_pos_window
        # rel_pos_x = anchor position, rel_pos_y = position of each row of the read.
        mC_aligned_df = pd.merge(mC_aligned_df[['read_id', 'rel_pos', 'rel_pos_lower', 'rel_pos_upper']], down_sampled_plot_df_filtered, on='read_id')
        mC_aligned_df = mC_aligned_df[
            (mC_aligned_df['rel_pos_y'] >= mC_aligned_df['rel_pos_lower']) &
            (mC_aligned_df['rel_pos_y'] <= mC_aligned_df['rel_pos_upper'])
        ].copy()  # explicit copy: the columns assigned below must not hit a slice
        mC_aligned_df['rel_pos'] = mC_aligned_df['rel_pos_y'] - mC_aligned_df['rel_pos_x']
        mC_aligned_df = mC_aligned_df.drop(columns=['rel_pos_lower', 'rel_pos_upper', 'rel_pos_x', 'rel_pos_y'])
        nanotools.display_sample_rows(mC_aligned_df)
        if mC_aligned_df.empty:
            print(f"No rows within +/-{rel_pos_window} of the anchors for condition {selec_cond}; skipping.")
            continue

        # Make read_ids unique per condition.
        mC_aligned_df['read_id'] = mC_aligned_df['read_id'].astype(str) + '_' + str(read_id_suffix)
        aligned_frames.append(mC_aligned_df)

        # 3. Per-offset counts (boolean sums), fractions and smoothing.
        print(f"*** Grouping, calculating, and filtering for condition {selec_cond} ***")
        is_A = mC_aligned_df['canonical_base'] == 'A'
        is_GC_CC = (mC_aligned_df['canonical_base'] == 'C') & mC_aligned_df['query_kmer'].str[1].isin(['C', 'G'])
        is_confident = mC_aligned_df['mod_qual'] > min_mod_qual
        grouped_df = (
            pd.DataFrame({
                'rel_pos': mC_aligned_df['rel_pos'],
                'A': is_A,
                'GC_CC': is_GC_CC,
                'm6A': is_A & is_confident,
                '5mC': is_GC_CC & is_confident,
            })
            .groupby('rel_pos', as_index=False)
            .sum()
        )
        nanotools.display_sample_rows(grouped_df)

        base_total = grouped_df['A'] + grouped_df['GC_CC']
        grouped_df['A_frac'] = grouped_df['A'] / base_total
        grouped_df['GC_CC_frac'] = grouped_df['GC_CC'] / base_total
        grouped_df['m6A_frac'] = grouped_df['m6A'] / grouped_df['A']
        grouped_df['5mC_frac'] = grouped_df['5mC'] / grouped_df['GC_CC']

        _smooth_excluding_anchor(grouped_df, 'm6A_frac', smoothing_window, smoothing_min_periods)
        _smooth_excluding_anchor(grouped_df, '5mC_frac', smoothing_window, smoothing_min_periods)

        profiles.append((selec_cond, grouped_df))

    # Smoothed m6A_frac and 5mC_frac for all conditions (offset 0 omitted).
    fig1 = go.Figure()
    for cond, profile in profiles:
        off_anchor = profile[profile['rel_pos'] != 0]
        fig1.add_trace(go.Scatter(x=off_anchor['rel_pos'], y=off_anchor['m6A_frac_smoothed'], mode='lines', name=f'{cond} - m6A_frac'))
        fig1.add_trace(go.Scatter(x=off_anchor['rel_pos'], y=off_anchor['5mC_frac_smoothed'], mode='lines', name=f'{cond} - 5mC_frac'))
    fig1.update_layout(title='Methylation Fractions (Smoothed)', xaxis_title='Absolute Relative Position', yaxis_title='Fraction', template='plotly_white', showlegend=True)

    # Raw counts of A and GC_CC positions for all conditions.
    fig2 = go.Figure()
    for cond, profile in profiles:
        fig2.add_trace(go.Scatter(x=profile['rel_pos'], y=profile['A'], mode='markers+lines', name=f'{cond} - Count of A'))
        fig2.add_trace(go.Scatter(x=profile['rel_pos'], y=profile['GC_CC'], mode='markers+lines', name=f'{cond} - Count of GC_CC'))
    fig2.update_layout(title='Count of A and GC_CC', xaxis_title='Absolute Relative Position', yaxis_title='Count', template='plotly_white', showlegend=True)

    fig1.show(renderer='plotly_mimetype+notebook')
    fig2.show(renderer='plotly_mimetype+notebook')

    output_df = pd.concat(aligned_frames) if aligned_frames else pd.DataFrame()
    nanotools.display_sample_rows(output_df)
    return output_df

# Align reads on a 5mC-derived anchor for two conditions (region type "all_rex",
# chr_type "X") and plot their smoothed m6A / 5mC profiles. The aligned rows are
# kept in mC_aligned_df.
mC_aligned_df = process_dataframe(
    down_sampled_plot_df.copy(),
    0.8,            # min_mod_qual
    [
        "60_old_ama1_3xGNB_GFPHia5_mChMCVIPI",
        "54_mixed_sdc2_3xmCNB_mChMCVIPI_GFPHia5",
    ],              # selec_conds
    "all_rex",      # selec_type
    "X",            # selected_chr_type
    500,            # num_read_ids
    500,            # rel_pos_window
    50,             # smoothing_window
)


In [None]:
### Read plot (per-read m6A / 5mC calls) without nucleosomes
# Reload nanotools so edits to the module are picked up without restarting the kernel.
importlib.reload(nanotools)
# create_plot below draws the read plot; this version has no nucleosome track.
def _m6A_cluster_traces(m6A_calls, min_consecutive_mods, max_gap=10):
    """Split each read's m6A calls into runs and keep runs of >= min_consecutive_mods calls.

    A run is a maximal sequence of calls (sorted by rel_pos within a read) whose
    successive gaps are <= ``max_gap``. Returns
    ``(marker_x, marker_y, line_x, line_y, midpoints)``:
      - marker lists hold every call of each kept run;
      - line lists hold one polyline per run, each ended by ``None`` so plotly breaks
        the line between runs;
      - ``midpoints`` holds ``(first + last) / 2`` for each kept run.
    ``m6A_calls`` needs the columns read_id, rel_pos and read_count.
    """
    marker_x, marker_y, line_x, line_y, midpoints = [], [], [], [], []
    for _, read_calls in m6A_calls.groupby('read_id'):
        ordered = read_calls.sort_values(by='rel_pos')
        positions = ordered['rel_pos'].tolist()
        rows = ordered['read_count'].tolist()
        run_start = 0
        for i in range(1, len(positions) + 1):
            # A run ends at the read's last call or before a gap wider than max_gap.
            if i < len(positions) and positions[i] - positions[i - 1] <= max_gap:
                continue
            if i - run_start >= min_consecutive_mods:
                run_x, run_y = positions[run_start:i], rows[run_start:i]
                marker_x.extend(run_x)
                marker_y.extend(run_y)
                line_x.extend(run_x + [None])
                line_y.extend(run_y + [None])
                midpoints.append((run_x[0] + run_x[-1]) / 2)
            run_start = i
    return marker_x, marker_y, line_x, line_y, midpoints


def create_plot(plot_df, condition, chr_type, data_type, plot_window, min_prob, bw_selection, bigwig_df_cropped, max_read_ids=100, min_consecutive_mods=1):
    """Build a per-read m6A / 5mC "read plot" for one condition and region set.

    Figure layout: 4 rows sharing an x axis over [-plot_window, plot_window].
      1. One line of calls per read. m6A calls (mod_code 'a') are orange circles and
         5mC calls (mod_code 'm') are blue triangles. A thin black line spans each
         read's m6A calls. Grey bars and y tick labels mark each (bed_start, ref_strand)
         region.
      2. Optional ChIP track from ``bigwig_df_cropped`` when ``bw_selection`` is set.
      3. Histogram of m6A positions, or of m6A run midpoints when
         ``min_consecutive_mods > 1``.
      4. Histogram of 5mC positions.
    A dashed vertical line marks rel_pos 0.

    Parameters
    ----------
    plot_df : pd.DataFrame
        Per-call rows. Uses condition, chr_type, type, rel_pos, mod_qual, mod_code,
        canonical_base, query_kmer, read_id, bed_start, bed_end, ref_strand and chrom.
    condition, chr_type, data_type : str
        Values of the condition / chr_type / type columns to plot.
    plot_window : int
        Calls with -plot_window < rel_pos < plot_window are plotted.
    min_prob : float
        Only calls with mod_qual > min_prob are kept.
    bw_selection : str or None
        Condition to draw from ``bigwig_df_cropped`` (columns rel_start, value,
        condition, chr_type, type); None skips the ChIP track.
    bigwig_df_cropped : pd.DataFrame or None
    max_read_ids : int
        Number of reads plotted (the first ones after sorting by region).
    min_consecutive_mods : int
        1 (default) plots every m6A call. Larger values plot only runs of at least
        this many m6A calls spaced <= 10 bp apart.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    # make_subplots is not part of the notebook's import cell; import it where it is used.
    from plotly.subplots import make_subplots

    print("Creating dataframes...")
    plot_df_copy = plot_df[(plot_df['condition'] == condition) &
                           (plot_df['chr_type'] == chr_type) &
                           (plot_df['type'] == data_type) &
                           (plot_df['rel_pos'] > -plot_window) &
                           (plot_df['rel_pos'] < plot_window) &
                           (plot_df['mod_qual'] > min_prob)].reset_index(drop=True)
    nanotools.display_sample_rows(plot_df_copy)

    # Drop C calls whose query_kmer[1] is not 'G' and whose mod_qual < min_prob.
    # NOTE(review): every remaining row already has mod_qual > min_prob, so this mask
    # currently removes nothing. Kept as-is — TODO confirm the intended context filter.
    plot_df_copy = plot_df_copy[~((plot_df_copy['canonical_base'] == 'C') &
                                  (plot_df_copy['query_kmer'].str[1] != 'G') &
                                  (plot_df_copy['mod_qual'] < min_prob))]

    # Stack reads of the same region together, then keep the first max_read_ids reads.
    plot_df_copy = plot_df_copy.sort_values(by=['bed_start', 'bed_end', 'ref_strand'],
                                            ascending=[True, True, True]).reset_index(drop=True)
    plot_df_copy = plot_df_copy[plot_df_copy['read_id'].isin(plot_df_copy['read_id'].unique()[:max_read_ids])]

    # read_count = 1-based y position of each read, in order of first appearance.
    plot_df_copy_nodups = plot_df_copy.drop_duplicates(subset=['read_id'])[['read_id']].reset_index(drop=True)
    plot_df_copy_nodups['read_count'] = range(1, len(plot_df_copy_nodups) + 1)
    nanotools.display_sample_rows(plot_df_copy_nodups, 5)
    nanotools.display_sample_rows(plot_df_copy, 10)
    plot_df_copy = pd.merge(plot_df_copy, plot_df_copy_nodups[['read_id', 'read_count']], on='read_id', how='left')

    fig = make_subplots(rows=4,
                        cols=1,
                        shared_xaxes=True,
                        vertical_spacing=0.02,
                        specs=[[{}], [{}], [{}], [{}]],
                        row_heights=[0.65, 0.05, 0.15, 0.15])
    fig.update_xaxes(range=[-plot_window, plot_window])

    print("Number of rows where mod_qual < min_prob:", plot_df_copy[plot_df_copy['mod_qual'] < min_prob].shape[0])

    m6A_calls = plot_df_copy[plot_df_copy['mod_code'] == 'a']
    m5C_calls = plot_df_copy[plot_df_copy['mod_code'] == 'm']
    num_bins = 100
    transparent = 'rgba(255, 255, 255, 0)'
    min_prob_plus = min_prob - 0.00001

    def threshold_colorscale(color):
        # Transparent below min_prob, solid `color` from min_prob up to 1.
        return [[0, transparent], [min_prob_plus, transparent], [min_prob, color], [1, color]]

    # cmin/cmax pin the colour domain to mod_qual's 0-1 scale, so the min_prob stop
    # corresponds to mod_qual == min_prob. Without them plotly rescales the colorscale
    # to the data's own min/max, which left many above-threshold calls transparent.
    if min_consecutive_mods <= 1:
        fig.add_trace(go.Scatter(
            x=m6A_calls['rel_pos'],
            y=m6A_calls['read_count'],
            mode='markers',
            marker=dict(size=6, color=m6A_calls['mod_qual'], cmin=0, cmax=1,
                        colorscale=threshold_colorscale('#FF5733'))),
            row=1, col=1)
        m6A_hist_x = m6A_calls['rel_pos']
    else:
        marker_x, marker_y, line_x, line_y, m6A_hist_x = _m6A_cluster_traces(m6A_calls, min_consecutive_mods)
        fig.add_trace(go.Scatter(x=line_x, y=line_y, mode='lines', line=dict(color='#FF5733', width=4)), row=1, col=1)
        fig.add_trace(go.Scatter(x=marker_x, y=marker_y, mode='markers', marker=dict(size=4, color='#FF5733')), row=1, col=1)
    fig.add_trace(go.Histogram(x=m6A_hist_x, nbinsx=num_bins, marker_color='#FF5733'), row=3, col=1)

    fig.add_trace(go.Scatter(
        x=m5C_calls['rel_pos'],
        y=m5C_calls['read_count'],
        mode='markers',
        marker=dict(symbol='triangle-up', size=6, color=m5C_calls['mod_qual'], cmin=0, cmax=1,
                    colorscale=threshold_colorscale('#0047ab'))),
        row=1, col=1)
    fig.add_trace(go.Histogram(x=m5C_calls['rel_pos'], nbinsx=num_bins, marker_color='#87CEEB'), row=4, col=1)

    print("Adding m6a line traces...")
    # Thin black line from each read's first to last m6A call; reads without m6A get none.
    for read_count, read_calls in m6A_calls.groupby('read_count', sort=False):
        fig.add_trace(
            go.Scatter(x=[read_calls['rel_pos'].min(), read_calls['rel_pos'].max()],
                       y=[read_count, read_count],
                       mode='lines', line=dict(color='#000000', width=0.2), showlegend=False),
            row=1, col=1)

    fig.update_layout(template="simple_white", height=1200, width=1600)
    fig.update_yaxes(title_text="Read_ID", row=1, col=1)
    fig.update_yaxes(title_text="Localization Counts", row=3, col=1)
    fig.update_xaxes(title_text="Genomic location (bp)", row=3, col=1)
    fig.update_xaxes(range=[-plot_window, plot_window], row=3, col=1)

    # Dashed vertical line at rel_pos 0 across the whole figure.
    fig.add_shape(
        go.layout.Shape(type="line", x0=0, x1=0, y0=0, y1=1, yref="paper",
                        line=dict(color="grey", width=1, dash="dash")))

    if bw_selection is not None:
        # ChIP track limited to the plotted window. This previously used the notebook
        # global bed_window; points outside +/-plot_window are off-axis either way.
        track = bigwig_df_cropped[bigwig_df_cropped['rel_start'].between(-plot_window, plot_window)
                                  & (bigwig_df_cropped['condition'] == bw_selection)
                                  & (bigwig_df_cropped['chr_type'] == chr_type)
                                  & (bigwig_df_cropped['type'] == data_type)]
        fig.add_trace(
            go.Scatter(x=track['rel_start'].values,
                       y=track['value'].values,
                       mode='lines',
                       name=f"{bw_selection}_{chr_type}_{data_type}",
                       opacity=0.9,
                       line=dict(color='green', width=1.5)),
            row=2, col=1)

    # One grey bar and one y tick label per (bed_start, ref_strand) region.
    region_groups = list(plot_df_copy.groupby(['bed_start', 'ref_strand']))
    # Shades run from light (first region) to dark (last region) and stay within 0-255.
    greys = np.linspace(200, 40, num=max(len(region_groups), 1)).round().astype(int)
    tick_positions, tick_labels = [], []
    for ((bed_start, ref_strand), group), grey in zip(region_groups, greys):
        min_read_count = group['read_count'].min()
        max_read_count = group['read_count'].max()
        chrom_rightmost = group['chrom'].iloc[0].split('_')[-1]
        tick_positions.append((min_read_count + max_read_count) / 2)
        tick_labels.append(f"{chrom_rightmost}:{bed_start} ({ref_strand})")
        fig.add_shape(
            type='line',
            x0=-plot_window, y0=min_read_count,
            x1=-plot_window, y1=max_read_count,
            line=dict(color=f'rgb({grey}, {grey}, {grey})', width=10),
            row=1, col=1)

    fig.update_yaxes(tickmode='array', tickvals=tick_positions, ticktext=tick_labels, row=1, col=1)
    return fig

# Example read plot: one condition, region type and chromosome class, with no ChIP track.
#analysis_cond = ["N2_mixed_endogenous_R10","50_mixed_dpy27-3xGNB_GFP-Hia5_mcvipi_R10","54_old_MCVIPIsdc2_LMN1pAhia5_R10","87_old_GFPhia5dpy27_mCmcvipi_ama1_R10","54_mixed_sdc2_3xmCNB_mChMCVIPI_GFPHia5","N2_mixed_DPY27_dimelo_pAHia5_R10"]
selec_cond = "54_mixed_sdc2_3xmCNB_mChMCVIPI_GFPHia5"  # other options: "50_dpy27dimelo_mcvipi", "50_dpy27-3xGNB_GFP-Hia5_mcvipi"
selec_type = "all_rex"  # other option: "center_SDC2_chip_albretton"
selected_chr_type = "X"

# Output path for the PNG export (the write_image call below is commented out).
fig_title = (
    "/Data1/git/meyer-nanopore/scripts/Analysis/combined_bam_analysis/images/50"
    f"{selec_cond}_t_{selec_type}_r{n_read_ids}_b{bed_window}_chr{selected_chr_type}_1-12-2024.png"
)

N2_fig = create_plot(
    down_sampled_plot_df,
    selec_cond,
    selected_chr_type,
    selec_type,
    plot_window=500,
    min_prob=0.8,
    bw_selection=None,
    bigwig_df_cropped=None,
    max_read_ids=300,
)

print("saving image...")
print(fig_title)
N2_fig.show(renderer='plotly_mimetype+notebook')
#N2_fig.write_image(fig_title)
print("Done!")

In [None]:
# Reload nanotools so edits to the module are picked up without restarting the kernel.
importlib.reload(nanotools)
# find_peaks is presumably used further down in this cell (not visible here) — verify.
from scipy.signal import find_peaks
### READ PLOT + NUCLEOSOME PLOT
# NOTE(review): the create_plot defined next replaces the read-plot create_plot from the
# previous cell, and its signature is different. Consider giving one of them a distinct name.
def create_plot(plot_df, group_df, condition, chr_type, data_type, plot_window,plot_nucs=False, min_prob=0):
    """Read plot with m6A calls, a smoothed m6A signal and (optionally) nucleosome tracks.

    Row 1: one thin black line per read spanning the min..max ``rel_pos`` of its
    retained calls, the calls themselves as markers, and - when ``plot_nucs`` is
    truthy - a ``NUC_width``-wide blue segment centred on every nucleosome listed in
    ``group_df['nucs_list']``. Reads are stacked by ``smallest_positive_nuc_midpoint``
    (ascending) then ``greatest_negative_nuc_midpoint`` (descending).
    Row 2: centred rolling mean (25 positions) of the per-``rel_pos`` mean of
    ``mod_qual_bin``.
    Row 3 (only when ``plot_nucs``): histogram of nucleosome midpoints plus the
    coverage-normalised, Gaussian-smoothed midpoint density (secondary y axis) with
    its local maxima marked by dashed grey lines.

    Args:
        plot_df: per-call table with columns condition, chr_type, type, rel_pos,
            mod_qual, mod_qual_bin, read_id, smallest_positive_nuc_midpoint and
            greatest_negative_nuc_midpoint.
        group_df: per-read table with read_id and nucs_list (nucleosome positions,
            presumably on the same centre-relative axis as rel_pos - TODO confirm).
        condition, chr_type, data_type: values matched against the condition,
            chr_type and type columns of plot_df.
        plot_window: half-width in bp; calls with -plot_window < rel_pos < plot_window
            are kept and the x-axis spans [-plot_window, plot_window].
        plot_nucs: draw nucleosome segments, histogram and density when truthy.
        min_prob: keep only calls whose mod_qual is strictly greater than this.

    Returns:
        plotly ``Figure`` with three stacked subplots sharing the x-axis.

    Relies on notebook globals ``nanotools``, ``make_subplots``, ``go``,
    ``find_peaks`` and (when ``plot_nucs``) ``NUC_width``.
    """
    # Imported locally: the notebook's own import of this name is in a later cell.
    from scipy.ndimage import gaussian_filter1d

    print("Creating dataframes...")
    in_view = ((plot_df['condition'] == condition) &
               (plot_df['chr_type'] == chr_type) &
               (plot_df['type'] == data_type) &
               (plot_df['rel_pos'] > -plot_window) &
               (plot_df['rel_pos'] < plot_window) &
               (plot_df['mod_qual'] > min_prob))
    plot_df_copy = plot_df[in_view]

    # Drop calls whose read has no nucleosome midpoint on either side of the centre.
    no_flanking_nuc = (plot_df_copy['smallest_positive_nuc_midpoint'].isna() &
                       plot_df_copy['greatest_negative_nuc_midpoint'].isna())
    plot_df_copy = plot_df_copy[~no_flanking_nuc]
    plot_df_copy = plot_df_copy.sort_values(
        by=['smallest_positive_nuc_midpoint', 'greatest_negative_nuc_midpoint'],
        ascending=[True, False]).reset_index(drop=True)

    plot_df_copy_nodups = (plot_df_copy
                           .drop_duplicates(subset=['read_id'])[['read_id', 'smallest_positive_nuc_midpoint', 'greatest_negative_nuc_midpoint']]
                           .reset_index(drop=True))
    # 1-based rank of each read in the sorted order; this is its y coordinate in row 1.
    plot_df_copy_nodups['read_count'] = range(1, len(plot_df_copy_nodups) + 1)
    read_ranks = plot_df_copy_nodups[['read_id', 'read_count']]

    # Drop any pre-existing read_count so the merge cannot produce read_count_x/_y.
    plot_df_copy = plot_df_copy.drop(columns=['read_count'], errors='ignore').merge(
        read_ranks, on='read_id', how='left')

    # Nucleosome lists for the plotted reads only, tagged with the same read_count.
    nuc_reads_df = group_df[group_df['read_id'].isin(read_ranks['read_id'])]
    nuc_reads_df = (nuc_reads_df
                    .drop(columns=['read_count'], errors='ignore')
                    .merge(read_ranks, on='read_id', how='left')
                    .dropna(subset=['nucs_list']))
    nanotools.display_sample_rows(nuc_reads_df, 10)

    # Three rows sharing x: reads, m6A signal, nucleosome histogram/density.
    fig = make_subplots(rows=3,
                        cols=1,
                        shared_xaxes=True,
                        vertical_spacing=0.02,
                        specs=[[{}], [{}], [{"secondary_y": True}]],
                        row_heights=[0.7, 0.15, 0.15])
    fig.update_xaxes(range=[-plot_window, plot_window])

    # Per-position mean of mod_qual_bin (a fraction if it is 0/1 - assumed), then a
    # centred rolling mean over 25 consecutive rel_pos values.
    rolling_window_size = 25
    agg_df = plot_df_copy.groupby('rel_pos')['mod_qual_bin'].agg(['sum', 'count']).reset_index()
    agg_df['ratio'] = agg_df['sum'] / agg_df['count']
    agg_df['moving_avg'] = agg_df['ratio'].rolling(window=rolling_window_size, center=True).mean()
    agg_df = agg_df.dropna()

    # One slot per bp of the window; index = rel_pos + plot_window.
    genome_size = int(2 * plot_window)
    read_counts = np.zeros(genome_size)

    # Upper scatter plot of the retained calls
    fig.add_trace(
        go.Scatter(x=plot_df_copy['rel_pos'], y=plot_df_copy['read_count'], mode='markers',
                   marker=dict(size=2, color=plot_df_copy['mod_qual'], colorscale=[[0, '#FF5733'], [1, '#FF5733']])),
        row=1, col=1)

    print("Adding m6a line traces...")
    read_spans = plot_df_copy.groupby('read_count', sort=False)['rel_pos'].agg(['min', 'max'])
    read_line_x = []
    read_line_y = []
    for read_height, min_rel_pos, max_rel_pos in read_spans.itertuples():
        # Coverage: every bp between the read's first and last retained call counts once.
        lo = max(int(min_rel_pos + plot_window), 0)
        hi = min(int(max_rel_pos + plot_window) + 1, genome_size)
        if lo < hi:
            read_counts[lo:hi] += 1
        read_line_x.extend([min_rel_pos, max_rel_pos, None])
        read_line_y.extend([read_height, read_height, None])
    # One trace with None separators renders far faster than one trace per read.
    fig.add_trace(
        go.Scatter(x=read_line_x, y=read_line_y, mode='lines',
                   line=dict(color='#000000', width=0.2), showlegend=False),
        row=1, col=1)

    # Row 2: smoothed m6A signal; independent of the nucleosome option.
    fig.add_trace(
        go.Scatter(x=agg_df['rel_pos'], y=agg_df['moving_avg'], mode='lines',
                   line=dict(color='#FF5733', width=2),
                   line_shape='spline'),
        row=2, col=1)

    if not plot_nucs:
        print("Skipping nucleosome plotting...")
    else:
        print("Plotting nucleosomes...")
        midpoints_list = []
        x_coords = []
        y_coords = []
        # A NUC_width-wide segment per in-window nucleosome, drawn at its read's height.
        for read_height, nucs in zip(nuc_reads_df['read_count'], nuc_reads_df['nucs_list']):
            for nuc in nucs:
                if -plot_window <= nuc <= plot_window:
                    midpoints_list.append(nuc)
                    x_coords.extend([nuc - NUC_width / 2, nuc + NUC_width / 2, None])
                    y_coords.extend([read_height, read_height, None])
        fig.add_trace(
            go.Scatter(x=x_coords, y=y_coords, mode='lines',
                       line=dict(color='#33B8FF', width=2), opacity=0.75, showlegend=False),
            row=1, col=1)

        print("Plotting histogram...")
        hist_bins = int(round(2*plot_window/10)+1)  # roughly 10 bp per bin
        fig.add_trace(
            go.Histogram(x=midpoints_list, nbinsx=hist_bins,
                         marker=dict(color='#33B8FF', opacity=0.8)),
            row=3, col=1, secondary_y=False)

        print("Plotting gaussian smoothed plot...")
        nucleosome_array = np.zeros(genome_size)
        for midpoint in midpoints_list:
            position_index = int(midpoint + plot_window)  # shift so negative positions index from 0
            if 0 <= position_index < genome_size:
                nucleosome_array[position_index] += 1

        # Normalise midpoint counts by read coverage. `out` must be given: with only
        # `where`, np.divide leaves uncovered positions uninitialised (arbitrary values).
        nucleosome_normalized = np.divide(nucleosome_array, read_counts,
                                          out=np.zeros_like(nucleosome_array),
                                          where=read_counts != 0)
        # Gaussian smoothing with sigma = 10 bp.
        smoothed_nucleosome_array = gaussian_filter1d(nucleosome_normalized, 10)
        # Shift indices back to centre-relative coordinates.
        x_values = np.arange(-plot_window, plot_window, 1)

        print("Plotting peaks...")
        peaks, _ = find_peaks(smoothed_nucleosome_array)
        y_range = [smoothed_nucleosome_array.min(), smoothed_nucleosome_array.max()]
        peak_x = []
        peak_y = []
        for peak_pos in x_values[peaks]:
            peak_x.extend([peak_pos, peak_pos, None])
            peak_y.extend([y_range[0], y_range[1], None])
        if peak_x:
            fig.add_trace(
                go.Scatter(x=peak_x, y=peak_y, mode='lines',
                           line=dict(color='grey', width=0.5, dash='dash'), showlegend=False),
                row=3, col=1, secondary_y=True)

        fig.add_trace(
            go.Scatter(x=x_values, y=smoothed_nucleosome_array, mode='lines',
                       name='Smoothed Nucleosome Density',
                       line=dict(color='#007dfa', width=2)),
            row=3, col=1, secondary_y=True)

    # Optional bigwig overlays (MNase / GRO-seq) were prototyped via
    # nanotools.create_bigwig_trace(<path>, plot_df_copy); adding them needs a 4th subplot row.

    fig.update_layout(template="simple_white", height=800, width=1100)
    fig.update_yaxes(title_text="Read_ID", row=1, col=1)
    fig.update_yaxes(title_text="% m6A", row=2, col=1)
    fig.update_yaxes(title_text="Nucleosome Probability", row=3, col=1)
    fig.update_xaxes(title_text="Genomic location (bp)", row=3, col=1)

    # Rex line: dashed vertical line at rel_pos 0 spanning the full figure height.
    fig.add_shape(
        go.layout.Shape(
            type="line",
            x0=0,
            x1=0,
            y0=0,
            y1=1,
            yref="paper",
            line=dict(color="grey", width=1, dash="dash"),
        )
    )

    return fig

# Sample usage of the read + nucleosome create_plot defined above in this cell
selec_cond = "N2-DPY27_dimelo_pAHia5" #50_dpy27dimelo_mcvipi
selec_type = "weak_rex"
selected_chr_type = "X"
# Output path. NOTE(review): the literal "50" after "images/" and the hard-coded date
# suffix are baked into every file name — confirm both are intended.
fig_title = "/Data1/git/meyer-nanopore/scripts/Analysis/combined_bam_analysis/images/50"+selec_cond+"_t_"+selec_type+"_r"+str(n_read_ids)+"_b"+str(bed_window)+"_chr"+selected_chr_type+"_1-12-2024.png"

# down_sampled_plot_df, down_sampled_group_df, n_read_ids and bed_window come from earlier cells.
N2_fig = create_plot(down_sampled_plot_df, down_sampled_group_df,selec_cond, selected_chr_type, selec_type, int(round(bed_window,0)),plot_nucs=False, min_prob=0.9)

print("saving image...")
print(fig_title)
N2_fig.show(renderer='plotly_mimetype+notebook')
# write_image fails when the target directory is missing, so create it first.
os.makedirs(os.path.dirname(fig_title), exist_ok=True)
N2_fig.write_image(fig_title)
print("Done!")

In [None]:
import pyBigWig
from scipy.ndimage import gaussian_filter1d
from scipy.stats import spearmanr

def calculate_correlations(grouped_df, bigwig_paths, bin_size, smoothing_sigma=20):
    """Spearman-correlate per-region nucleosome profiles across experiments and bigwigs.

    Fiber-seq: for every (exp_id, chrom, bed_start, bed_end, condition) group of
    ``grouped_df``, the centre-relative nucleosome positions in ``nucs_list`` are
    counted in ``bin_size``-bp bins, expanded back to one value per bp across the
    region, Gaussian-smoothed (sigma ``smoothing_sigma``) and min-max normalised.

    Bigwig: each file in ``bigwig_paths`` is read over every distinct region, NaNs
    become 0 and the values are min-max normalised. The files are stored under
    condition "N2-MNase" with exp_id "MNase-seq-rep<i>" (1-based, list order).

    For every ordered pair of exp_ids, a Spearman correlation is computed on each
    region both cover, when the profiles have equal length and neither is all-zero.
    These per-region values are then averaged. An exp_id paired with itself gets 1.0.

    Args:
        grouped_df: DataFrame with exp_id, chrom, bed_start, bed_end, condition and
            nucs_list columns. "CHROMOSOME_" prefixes in chrom are rewritten to "chr"
            before the bigwigs are queried.
        bigwig_paths: bigwig file paths (opened with pyBigWig).
        bin_size: nucleosome-count bin width in bp (positive int).
        smoothing_sigma: Gaussian sigma in bp for the Fiber-seq profiles (default 20).

    Returns:
        Square DataFrame of mean correlations. Index and columns are
        "<condition>-<exp_id>"; NaN where no shared, non-flat region exists.
    """
    grouped_df_copy = grouped_df.copy(deep=True)
    grouped_df_copy['chrom'] = grouped_df_copy['chrom'].str.replace("CHROMOSOME_", "chr", regex=False)

    def min_max_normalize(array):
        """Scale to [0, 1]; empty input is returned as-is, flat input becomes all zeros."""
        array = np.asarray(array, dtype=float)
        if array.size == 0:
            return array
        lowest = array.min()
        span = array.max() - lowest
        if span == 0:
            # Avoid 0/0 -> NaN; a flat profile has no shape to correlate.
            return np.zeros_like(array)
        return (array - lowest) / span

    def adjust_positions(row):
        """Shift a row's centre-relative nucleosome positions to genomic coordinates."""
        nucs = row['nucs_list']
        if not isinstance(nucs, (list, tuple, np.ndarray)):
            # Missing (e.g. NaN) lists become empty and are filtered out below.
            return []
        mid_point = row['bed_start'] + (row['bed_end'] - row['bed_start']) // 2
        return [int(pos + mid_point - 1) for pos in nucs]

    grouped_df_copy['adjusted_nucs_list'] = grouped_df_copy.apply(adjust_positions, axis=1)
    # Keep only rows with at least one nucleosome.
    grouped_df_copy = grouped_df_copy[grouped_df_copy['adjusted_nucs_list'].map(lambda d: isinstance(d, list) and len(d) > 0)]
    grouped_df_copy = grouped_df_copy.reset_index(drop=True)

    # condition -> exp_id -> {(chrom, bed_start, bed_end): normalised per-bp profile}.
    # Keying by region guarantees that only profiles of the same region are compared.
    exp_data = {}

    def store(cond, exp_id, region, profile):
        exp_data.setdefault(cond, {}).setdefault(exp_id, {})[region] = profile

    for (exp_id, chrom, bed_start, bed_end, cond), group in grouped_df_copy.groupby(['exp_id', 'chrom', 'bed_start', 'bed_end', 'condition']):
        print(f"Processing {exp_id}, {chrom},bed:,{bed_start}-{bed_end}")
        region_len = int(bed_end - bed_start)
        n_bins = region_len // bin_size
        # nucs_list is centre-relative: shift so 0 is the region start.
        all_positions = np.concatenate(group['nucs_list'].values) + (bed_end - bed_start) // 2
        # Exclusive upper bound: a position at n_bins * bin_size (or in a trailing
        # partial bin) has no bin.
        all_positions = all_positions[(all_positions >= 0) & (all_positions < n_bins * bin_size)]
        bin_counts = np.bincount((all_positions // bin_size).astype(int), minlength=n_bins)
        # Expand bin counts back to one value per bp so lengths match the bigwig profiles.
        nucleosome_array = np.zeros(region_len)
        nucleosome_array[:n_bins * bin_size] = np.repeat(bin_counts, bin_size)
        smoothed_nucleosome_array = min_max_normalize(gaussian_filter1d(nucleosome_array, smoothing_sigma))
        store(cond, exp_id, (chrom, bed_start, bed_end), smoothed_nucleosome_array)

    chrom_bed_start_bed_end_list = grouped_df_copy[['chrom', 'bed_start', 'bed_end']].drop_duplicates().values.tolist()
    print("chrom_bed_start_bed_end_list:", chrom_bed_start_bed_end_list)

    bw_objects = []
    try:
        for path in bigwig_paths:
            bw_objects.append(pyBigWig.open(path))
        for chrom, bed_start, bed_end in chrom_bed_start_bed_end_list:
            print(f"Processing {chrom},bed:,{bed_start}-{bed_end}")
            for i, bw in enumerate(bw_objects):
                bigwig_values = np.nan_to_num(bw.values(chrom, bed_start, bed_end))
                store("N2-MNase", "MNase-seq-rep" + str(i + 1), (chrom, bed_start, bed_end),
                      min_max_normalize(bigwig_values))
    finally:
        # Close every file that was opened, even if reading failed.
        for bw in bw_objects:
            bw.close()

    # Each exp_id is attributed to the first condition it was stored under.
    exp_condition = {}
    for cond, exps in exp_data.items():
        for exp_id in exps:
            exp_condition.setdefault(exp_id, cond)

    pairwise_correlations = []
    for exp_id1, condition1 in exp_condition.items():
        profiles1 = exp_data[condition1][exp_id1]
        for exp_id2, condition2 in exp_condition.items():
            profiles2 = exp_data[condition2][exp_id2]
            if exp_id1 == exp_id2:
                correlation = 1.0
            else:
                per_region = []
                for region, array1 in profiles1.items():
                    array2 = profiles2.get(region)
                    if array2 is None or len(array1) != len(array2):
                        continue
                    if array1.any() and array2.any():
                        per_region.append(spearmanr(array1, array2)[0])
                finite = [c for c in per_region if np.isfinite(c)]
                correlation = float(np.mean(finite)) if finite else np.nan
            # One row per ordered pair keeps the pivot below free of duplicates.
            pairwise_correlations.append((exp_id1, exp_id2, condition1, condition2, correlation))

    pairwise_correlation_df = pd.DataFrame(pairwise_correlations, columns=['exp_id1', 'exp_id2', 'condition1', 'condition2', 'correlation'])
    pairwise_correlation_df['exp_condition1'] = pairwise_correlation_df['condition1'].astype(str) + '-' + pairwise_correlation_df['exp_id1'].astype(str)
    pairwise_correlation_df['exp_condition2'] = pairwise_correlation_df['condition2'].astype(str) + '-' + pairwise_correlation_df['exp_id2'].astype(str)

    pivot_df = pairwise_correlation_df.pivot(index='exp_condition1', columns='exp_condition2', values='correlation')
    display(pivot_df.head(3))

    return pivot_df

# Define bigwig paths and bin size
# Bigwig tracks to compare against (file names indicate MNase-seq smoothed dyads on ce11).
# calculate_correlations labels them "MNase-seq-rep1", "MNase-seq-rep2", ... in list order.
# NOTE(review): both file names carry accession GSM2098437 — confirm the rep2 path.
bigwig_paths = ["/Data1/reference/lieb_mnase_2017/GSM2098437_RT_rep1_MNaseTC_30m_smoothDyads_ce11.bw", "/Data1/reference/lieb_mnase_2017/GSM2098437_RT_rep2_MNaseTC_30m_smoothDyads_ce11.bw"]  # Replace with your actual bigwig file paths

bin_size = 1  # nucleosome-count bin width in bp (1 = per-bp profile)

# `grouped` is created in an earlier cell (not visible here); calculate_correlations needs
# its exp_id, chrom, bed_start, bed_end, condition and nucs_list columns.
nanotools.display_sample_rows(grouped,10)

# Calculate correlations
correlation_df = calculate_correlations(grouped, bigwig_paths, bin_size)

# Display a sample of the correlation matrix (rows/columns are "<condition>-<exp_id>")
nanotools.display_sample_rows(correlation_df)

def plot_heatmap(correlation_matrix, title='Heatmap of Spearman Correlation Coefficients'):
    """Show a square correlation DataFrame as a plotly heatmap.

    Args:
        correlation_matrix: DataFrame whose index and columns label the compared
            samples (for calculate_correlations output: "<condition>-<exp_id>").
        title: figure title. The default names Spearman because
            calculate_correlations computes ``spearmanr``; the previous "Pearson"
            label was incorrect.

    Returns:
        None. The figure is displayed, not returned, so a bare call as the last line
        of a cell does not render it twice.
    """
    heatmap = go.Figure(data=go.Heatmap(
        z=correlation_matrix.values,
        x=correlation_matrix.columns,
        y=correlation_matrix.index,
        colorscale='Viridis'
    ))

    heatmap.update_layout(
        title=title,
        xaxis_title="exp_id (X-axis)",
        yaxis_title="exp_id (Y-axis)",
        template="simple_white"
    )

    heatmap.show()

# correlation_df is the pivot returned by calculate_correlations: rows and columns are
# "<condition>-<exp_id>" labels and cells hold the correlation values.
plot_heatmap(correlation_df)


In [None]:
### STANDARD READ PLOT
def create_plot(plot_df, condition, chr_type, data_type, bed_window):
    """Standard read plot: m6A calls per read, smoothed m6A signal, inferred nucleosomes.

    Rows of ``plot_df`` matching ``condition``/``chr_type``/``data_type`` are stacked one
    read per y level. Reads are ordered by the first occurrence of each read_id after
    sorting on smallest_positive_nuc_midpoint, greatest_negative_nuc_midpoint and rel_pos.
    Consecutive non-zero calls on a read are joined by:
      * a pink nucleosome segment when NUC_min_width < gap < NUC_max_width; its midpoint
        feeds the row-3 histogram;
      * a translucent blue "methylase accessible" segment when gap < MAD_dist_max.

    Args:
        plot_df: per-call table with condition, chr_type, type, read_id, rel_pos,
            mod_qual, smallest_positive_nuc_midpoint and greatest_negative_nuc_midpoint.
        condition, chr_type, data_type: filter values for condition / chr_type / type.
        bed_window: half-width (bp) of the shared x-axis range.

    Returns:
        plotly Figure with 3 rows: reads; 50-position rolling mean of the per-position
        mean mod_qual; probability-normalised histogram of nucleosome midpoints.

    Relies on notebook globals make_subplots, go, NUC_min_width, NUC_max_width and
    MAD_dist_max.
    """
    plot_df_copy = plot_df[(plot_df['condition'] == condition) &
                           (plot_df['chr_type'] == chr_type) &
                           (plot_df['type'] == data_type)]
    plot_df_copy = plot_df_copy.sort_values(
        by=['smallest_positive_nuc_midpoint', 'greatest_negative_nuc_midpoint', 'rel_pos']
    ).reset_index(drop=True)

    # y position of each read = 1-based rank of its first appearance. De-duplicating on
    # read_id alone keeps the lookup index unique, so .map cannot fail on repeated ids.
    ordered_read_ids = plot_df_copy['read_id'].drop_duplicates()
    read_rank = pd.Series(np.arange(1, len(ordered_read_ids) + 1), index=ordered_read_ids.to_numpy())
    plot_df_copy['read_count'] = plot_df_copy['read_id'].map(read_rank)

    fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.02, row_heights=[0.5, 0.2, 0.2])
    fig.update_xaxes(range=[-bed_window, bed_window])

    # Per-position mean mod_qual, then a centred rolling mean over 50 rel_pos values.
    agg_df = plot_df_copy.groupby('rel_pos')['mod_qual'].agg(['sum', 'count']).reset_index()
    agg_df['ratio'] = agg_df['sum'] / agg_df['count']
    agg_df['moving_avg'] = agg_df['ratio'].rolling(window=50, center=True).mean()
    agg_df = agg_df.dropna()

    # One black line per read between its first and last call; a single trace with
    # None separators is much cheaper to render than one trace per read.
    read_spans = plot_df_copy.groupby('read_count')['rel_pos'].agg(['min', 'max'])
    read_line_x = []
    read_line_y = []
    for read_height, min_rel_pos, max_rel_pos in read_spans.itertuples():
        read_line_x.extend([min_rel_pos, max_rel_pos, None])
        read_line_y.extend([read_height, read_height, None])
    fig.add_trace(
        go.Scatter(x=read_line_x, y=read_line_y, mode='lines',
                   line=dict(color='#000000', width=0.2), showlegend=False),
        row=1, col=1)

    # Only non-zero calls are drawn and used for nucleosome calling.
    plot_df_dropped = plot_df_copy[plot_df_copy['mod_qual'] != 0]

    fig.add_trace(
        go.Scatter(x=plot_df_dropped['rel_pos'], y=plot_df_dropped['read_count'], mode='markers',
                   marker=dict(size=2, color=plot_df_dropped['mod_qual'], colorscale=[[0, '#FFFFFF'], [1, '#0000FF']])),
        row=1, col=1)

    fig.add_trace(
        go.Scatter(x=agg_df['rel_pos'], y=agg_df['moving_avg'], mode='lines',
                   line=dict(color='#0000FF', width=1),
                   line_shape='spline'),
        row=2, col=1)

    # Nucleosome / accessible-segment calling from gaps between consecutive calls.
    # After sorting, no call lies strictly between neighbours, so every gap qualifies.
    midpoints_dict = {}
    nuc_x, nuc_y = [], []
    mad_x, mad_y = [], []
    for read_count, group in plot_df_dropped.groupby('read_count'):
        positions = np.sort(group['rel_pos'].to_numpy())
        left, right = positions[:-1], positions[1:]
        gaps = right - left
        is_nuc = (gaps > NUC_min_width) & (gaps < NUC_max_width)
        is_mad = gaps < MAD_dist_max
        for x1, x2 in zip(left[is_nuc], right[is_nuc]):
            nuc_x.extend([x1, x2, None])
            nuc_y.extend([read_count, read_count, None])
        for x1, x2 in zip(left[is_mad], right[is_mad]):
            mad_x.extend([x1, x2, None])
            mad_y.extend([read_count, read_count, None])
        midpoints_dict[read_count] = ((left[is_nuc] + right[is_nuc]) / 2).tolist()

    if nuc_x:
        fig.add_trace(
            go.Scatter(x=nuc_x, y=nuc_y, mode='lines',
                       line=dict(color='#FF9999', width=1), showlegend=False),
            row=1, col=1)
    if mad_x:
        fig.add_trace(
            go.Scatter(x=mad_x, y=mad_y, mode='lines',
                       line=dict(color='#0000FF', width=1), showlegend=False, opacity=0.5),
            row=1, col=1)

    midpoint_series = [midpoint for sublist in midpoints_dict.values() for midpoint in sublist]
    fig.add_trace(
        go.Histogram(x=midpoint_series, histnorm='probability', nbinsx=100, marker=dict(color='#FF9999')),
        row=3, col=1)

    fig.update_layout(template="simple_white", height=800, width=1100)
    fig.update_yaxes(title_text="Read_ID", row=1, col=1)
    fig.update_yaxes(title_text="% m6A", row=2, col=1)
    fig.update_yaxes(title_text="Nucleosome Probability", row=3, col=1)
    fig.update_xaxes(title_text="Genomic location (bp)", row=3, col=1)

    # Rex line: dashed vertical line at rel_pos 0 spanning the full figure height.
    fig.add_shape(
        go.layout.Shape(
            type="line",
            x0=0,
            x1=0,
            y0=0,
            y1=1,
            yref="paper",
            line=dict(color="grey", width=1, dash="dash"),
        )
    )

    return fig

# Sample usage of the five-argument create_plot defined just above. This cell rebinds the
# name create_plot, shadowing the earlier definitions in this notebook.
# down_sampled_plot_df and bed_window come from earlier cells; "images/" is relative to the
# kernel's working directory.
N2_fig = create_plot(down_sampled_plot_df, "N2_fiber", "X", "center_DPY27_chip_albretton;SDC2_ol1000;SDC3_ol1000", bed_window)
N2_fig.write_image("images/N2_fiber_sdc3_sdc2_dpy27-fibers.png",width=1600,height=1300)
#N2_fig.show(renderer='plotly_mimetype+notebook')
#SDC2_fig = create_plot(down_sampled_plot_df, "SDC2_degron_fiber", "X", "TSS_q4", bed_window)
#SDC2_fig.show(renderer='plotly_mimetype+notebook')

In [None]:
# Save the current N2_fig as SVG under the relative "images/" folder.
# NOTE(review): this writes whatever N2_fig currently holds. If the previous cell ran last,
# that is the "N2_fiber" figure, yet the file name says "SDC2-degron" — confirm the label.
N2_fig.write_image("images/SDC2-degron_fiber_sdc3_sdc2_dpy27-fibers.svg")
#SDC2_fig.write_image("images/SDC2_degron_fiber_X_TSS_q4.svg")

In [None]:
### Plotting nucleosome offset
# One row per key of midpoints_dict, one column per nucleosome midpoint. Midpoints appear to be
# relative to the region centre (negative = one side, positive = the other) — confirm upstream.
midpoints_df = pd.DataFrame.from_dict(midpoints_dict, orient='index')

print("midpoints_df:")
display(midpoints_df.head(100))

# Nearest midpoint on each side of position 0 for every row.
# A midpoint exactly equal to 0 counts for neither side.
least_positive_per_row = midpoints_df[midpoints_df > 0].min(axis=1)
least_negative_per_row = midpoints_df[midpoints_df < 0].max(axis=1)

# Distance between the two midpoints flanking position 0, per row.
differences = least_positive_per_row - least_negative_per_row

fig_diff = go.Figure()
fig_diff.add_trace(go.Histogram(x=differences, marker=dict(color='#FF9999'), nbinsx=100))
fig_diff.update_layout(title="Distribution of Nucleosome Position Differences",
                       xaxis_title="Difference between Least Positive and Least Negative Nucleosome Position",
                       yaxis_title="Frequency",
                       template="simple_white")
fig_diff.show()

# Second-largest negative midpoint per row (NaN when a row has fewer than two negatives).
least_negative_per_row_sorted = midpoints_df[midpoints_df < 0].apply(lambda x: sorted(x.dropna()), axis=1)
second_least_negative_per_row = least_negative_per_row_sorted.apply(lambda x: x[-2] if len(x) > 1 else np.nan)

# Spacing between the nearest and second-nearest negative midpoints; rows lacking a second
# negative midpoint are dropped.
diff_second_least = (least_negative_per_row - second_least_negative_per_row).dropna()

fig_diff_second = go.Figure()
fig_diff_second.add_trace(go.Histogram(x=diff_second_least, marker=dict(color='#FF9999'), nbinsx=100))
fig_diff_second.update_layout(title="Distribution of Differences between Least Negative and Second Least Negative Nucleosome Positions",
                              xaxis_title="Difference between Least Negative and Second Least Negative Nucleosome Position",
                              yaxis_title="Frequency",
                              template="simple_white")
fig_diff_second.show()

# Mean nearest midpoint on each side; these anchor the idealised grid below.
avg_least_positive = least_positive_per_row.mean()
avg_least_negative = least_negative_per_row.mean()
print("avg_least_positive:", avg_least_positive)
print("avg_least_negative:", avg_least_negative)

# Idealised grid: range_n slots on each side of position 0, grid_spacing bp apart, anchored at
# the mean nearest midpoints. Slots run from the most negative to the most positive, so
# column k of rearranged_df corresponds to slot k.
range_n = 9
grid_spacing = 160  # bp between neighbouring slots
assigned_pos = [avg_least_positive + (grid_spacing * i) for i in range(range_n)]
assigned_neg = [avg_least_negative - ((range_n - 1) * grid_spacing) + (grid_spacing * i) for i in range(range_n)]

assigned_col = assigned_neg + assigned_pos
print("assigned_col:")
print(assigned_col)
assigned_arr = np.asarray(assigned_col, dtype=float)  # hoisted: reused for every midpoint

# Snap every midpoint to its nearest grid slot. A float frame keeps downstream arithmetic numeric.
# All 2*range_n slot columns are kept, even empty ones, so the fixed n-/n+ slot labels used in
# the next cell stay aligned with the columns (dropping an empty column made them mismatch).
rearranged_df = pd.DataFrame(index=midpoints_df.index, columns=range(len(assigned_col)), dtype=float)
for idx, row in midpoints_df.iterrows():
    for value in row.dropna().values:
        closest_column = int(np.argmin(np.abs(assigned_arr - value)))
        # NOTE(review): if two midpoints of one row snap to the same slot, the later one wins.
        rearranged_df.at[idx, closest_column] = value

display(rearranged_df.head(100))

In [None]:

# Marks an empty grid slot lying between a read's first and last occupied slots.
GAP_SENTINEL = 100000

# Mean position of each slot across rows, printed for reference. Gap sentinels are masked so
# re-running this cell (after the fill below has already happened) gives the same means.
mean_list = [
    rearranged_df[col].mask(rearranged_df[col] == GAP_SENTINEL).mean()
    for col in rearranged_df.columns
]
print("mean_list:")
print(mean_list)

# Fill empty slots strictly inside each row's first..last occupied slot with GAP_SENTINEL.
# Slots before the first / after the last occupied slot stay NaN.
for index, row in rearranged_df.iterrows():
    non_nan_indices = row.dropna().index.tolist()
    if non_nan_indices:
        start, end = non_nan_indices[0], non_nan_indices[-1]
        rearranged_df.loc[index, start+1:end] = rearranged_df.loc[index, start+1:end].fillna(GAP_SENTINEL)

# Slot k -> label: slots 0..range_n-1 are n-range_n..n-1, slots range_n.. are n+1..n+range_n.
# range_n is defined in the grid-building cell.
slot_labels = {
    k: (f"n-{range_n - k}" if k < range_n else f"n+{k - range_n + 1}")
    for k in range(2 * range_n)
}

# Fraction of ALL rows (not just rows spanning the slot) in which each slot is an internal gap.
percent_100000 = (rearranged_df == GAP_SENTINEL).sum() / len(rearranged_df)
print("percent_100000:")
print(percent_100000)
fig_gaps = go.Figure(data=go.Bar(x=[slot_labels[c] for c in percent_100000.index], y=percent_100000.values))
fig_gaps.update_layout(title="Fraction of reads with an internal gap at each nucleosome slot",
                       xaxis_title="Nucleosome slot",
                       yaxis_title="Fraction of reads",
                       template="simple_white")
fig_gaps.show()

print("Rearranged df:")
display(rearranged_df.head(100))

# For every occupied (non-gap) cell: mean |difference| to every value in the WHOLE table that lies
# within 160 bp of it (the cell itself contributes a 0). NaNs compare False and are excluded;
# gap sentinels are excluded as long as midpoints stay far below GAP_SENTINEL.
# NOTE(review): this pools across all reads rather than only the current read — confirm intended.
all_values = rearranged_df.to_numpy(dtype=float)
mean_diff_df = pd.DataFrame(index=rearranged_df.index, columns=rearranged_df.columns, dtype=float)
for row_pos in range(all_values.shape[0]):
    for col_pos in range(all_values.shape[1]):
        current_value = all_values[row_pos, col_pos]
        if np.isnan(current_value) or current_value == GAP_SENTINEL:
            continue
        abs_diffs = np.abs(all_values - current_value)
        valid_diffs = abs_diffs[abs_diffs < 160]
        mean_diff_df.iat[row_pos, col_pos] = valid_diffs.mean() if valid_diffs.size else np.nan

print("mean_diff_df:")
display(mean_diff_df.head(10))

# Relabel slot columns as n-range_n..n-1, n+1..n+range_n.
rearranged_df_abs_diff = mean_diff_df.abs().rename(columns=slot_labels)

# If you want to combine n- and n+ nucleosomes, set to 1: n-k is stacked under n+k.
combine_nucleosomes = 0
if combine_nucleosomes == 1:
    cols_n_minus = [slot_labels[k] for k in range(range_n)]
    cols_n_plus = [slot_labels[k] for k in range(range_n, 2 * range_n)]

    n_plus_np = rearranged_df_abs_diff[cols_n_plus].to_numpy()
    # Reverse the n- columns so n-1 lines up with n+1, n-2 with n+2, ...
    flipped_n_minus_np = rearranged_df_abs_diff[cols_n_minus[::-1]].to_numpy()

    result_np = np.vstack([n_plus_np, flipped_n_minus_np])
    result_df = pd.DataFrame(result_np, columns=cols_n_plus)
else:
    result_df = rearranged_df_abs_diff

print("result_df:")
display(result_df.head(10))

# One box (with all points) per slot column.
data = [
    go.Box(
        y=result_df[col].dropna(),
        name=str(col),
        boxpoints='all',  # Show all points
        jitter=0.3,       # Add some jitter for visibility
        pointpos=-1.8,    # Position of the points
    )
    for col in result_df.columns
]

layout = go.Layout(
    title="Box Plot of Rearranged DataFrame",
    xaxis=dict(title="Index"),
    yaxis=dict(title="Values")
)
fig = go.Figure(data=data, layout=layout)
fig.update_layout(template="simple_white")

# Label each box with its mean.
for i, col in enumerate(result_df.columns):
    col_mean = result_df[col].dropna().mean()
    print(col_mean)
    fig.add_annotation(
        x=i,
        y=col_mean,
        text=f"{col_mean:.2f}",
        showarrow=False,
        font=dict(size=10, color="black"),
    )

fig.show()

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

# Function to plot data for each Cluster_ID
def plot_cluster(df_cluster, window=50):
    """Show a two-panel figure of per-read modification calls for one cluster.

    Top panel: a grey line per read spanning its min..max ``rel_pos`` at height ``read_count``,
    overlaid with a marker for every row whose ``mod_qual`` is non-zero, coloured white->blue by
    ``mod_qual``. Bottom panel: centred rolling mean of the per-position mean ``mod_qual``.

    Parameters
    ----------
    df_cluster : pandas.DataFrame
        Rows of a single cluster; must have ``read_id``, ``rel_pos``, ``mod_qual`` and
        ``read_count`` columns.
    window : int, default 50
        Rolling-window length (in positions) for the bottom-panel smoothing.

    Returns
    -------
    None
        The figure is displayed with ``fig.show()``.
    """
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.02, row_heights=[0.7, 0.3])

    # Per-position mean mod_qual (sum / count), then a centred rolling mean. Edge positions
    # without a full window are NaN and dropped.
    agg_df = df_cluster.groupby('rel_pos')['mod_qual'].agg(['sum', 'count']).reset_index()
    agg_df['ratio'] = agg_df['sum'] / agg_df['count']
    agg_df['moving_avg'] = agg_df['ratio'].rolling(window=window, center=True).mean()
    agg_df = agg_df.dropna()

    # Read spans: one grouped pass instead of re-filtering df_cluster for every read. All segments
    # share one trace separated by None gaps; a trace per read makes plotly very slow on big clusters.
    read_heights = df_cluster.drop_duplicates('read_id').set_index('read_id')['read_count']
    spans = df_cluster.groupby('read_id')['rel_pos'].agg(['min', 'max'])
    seg_x, seg_y = [], []
    for read_id, min_rel_pos, max_rel_pos in spans.itertuples():
        height = read_heights.loc[read_id]
        seg_x += [min_rel_pos, max_rel_pos, None]
        seg_y += [height, height, None]
    if seg_x:
        fig.add_trace(
            go.Scatter(x=seg_x, y=seg_y, mode='lines', line=dict(color='gray', width=0.1), showlegend=False),
            row=1, col=1
        )

    # Only non-zero calls are drawn as markers.
    df_cluster_dropped = df_cluster[df_cluster['mod_qual'] != 0]
    scatter_trace = go.Scatter(x=df_cluster_dropped['rel_pos'], y=df_cluster_dropped['read_count'], mode='markers',
                               marker=dict(size=3, color=df_cluster_dropped['mod_qual'], colorscale=[[0, '#FFFFFF'], [1, '#0000FF']]))
    fig.add_trace(scatter_trace, row=1, col=1)

    # Lower line plot for the moving average of the ratio.
    line_trace = go.Scatter(x=agg_df['rel_pos'], y=agg_df['moving_avg'], mode='lines',
                            line=dict(color='#0000FF', width=1),
                            line_shape='spline')
    fig.add_trace(line_trace, row=2, col=1)

    fig.update_layout(template="simple_white", height=800)
    fig.show()

# Draw one figure per Cluster_ID. plot_df must already exist (it is defined outside this cell).
for cluster_key, cluster_rows in plot_df.groupby('Cluster_ID'):
    print(f"Plotting for Cluster_ID: {cluster_key}")
    plot_cluster(cluster_rows)

In [None]:

### Achieve the same as above using modbampy
# NOTE This is unable to define methylation threshold
# NOTE(review): everything below is disabled code kept inside string literals. The nested
# description of the count rows uses """ so it does not terminate the enclosing ''' string;
# a nested ''' would end it early and make this cell a SyntaxError.
# print("Kicking off loop for:",new_bed_files,"and",new_bam_files)
'''import modbampy
importlib.reload(modbampy)
from modbampy import ModBam

# Assuming bedwindow is already defined in your code
columns = [int(i) for i in range(1, 2*bed_window + 1)]
index_desc = ['a', 'c', 'g', 't', 'A', 'C', 'G', 'T', 'd', 'D', 'm', 'M', 'f', 'F', 'n', 'N','Unk1','Unk2']
"""A, C, G, T are the usual DNA bases,
D indicates deletion counts,
M modified base counts,
F filtered counts - bases in reads with a modified-base record but which were filtered according to the thresholds provided.
N no call base counts."""
counts_df = pd.DataFrame(columns=columns,index=index_desc)
#Set all values to 0
counts_df = counts_df.fillna(0)

# Load the BED file
print("Kicking off loop for:",new_bed_files,"and",new_bam_files)
for bed_file in new_bed_files:
    print("Starting on bed_file:",bed_file)
    regions = pysam.TabixFile(bed_file)
    # Iterate over the regions in the BED file
    for region in regions.fetch(multiple_iterators=True):
        #print("Region:",region)
        # Split the region string into the chromosome, start, and end positions
        chromosome, start, end, strand, region_type, chr_type = region.split()
        start = int(start)
        end = int(end)
        for bam_file in new_bam_files[0]:
            with ModBam(bam_file) as bam:
                positions, counts = bam.pileup(chromosome, start, end,threshold=0.1,mod_base="a")
            positions_reset = positions - start
            #convert positions from np array to list
            positions_list = positions_reset.tolist()

            # Create temp_counts_df from numpy array counts
            #print("counts shape",counts.shape)
            #print("counts",counts)

            #print("counts shape",counts.T.shape)
            #print("counts.T",counts.T)
            temp_counts_df = pd.DataFrame(counts.T,index=index_desc,columns=positions_list)
            # set row index to be the same as counts_df
            #print("print(temp_counts_df.shape)",temp_counts_df.shape)
            #print("print(counts_df.shape)",counts_df.shape)
            #print("Temp_count_df:",temp_counts_df)
            counts_df = counts_df.astype(float)
            temp_counts_df = temp_counts_df.astype(float)

            # Sum temp_counts_df to counts_df
            counts_df = counts_df.add(temp_counts_df, fill_value=0)


print("COUNTS_DF:",counts_df)
print("shape of counts_df:",counts_df.shape)
'''

'''counts_df_plot = counts_df.copy()
# Merge rows A and a, C and c, G and g, T and t, M and m, D and d, F and f, N and n, Unk1 and Unk2
counts_df_plot.loc['A'] = counts_df_plot.loc['A'] + counts_df_plot.loc['a']
counts_df_plot.loc['C'] = counts_df_plot.loc['C'] + counts_df_plot.loc['c']
counts_df_plot.loc['G'] = counts_df_plot.loc['G'] + counts_df_plot.loc['g']
counts_df_plot.loc['T'] = counts_df_plot.loc['T'] + counts_df_plot.loc['t']
counts_df_plot.loc['M'] = counts_df_plot.loc['M'] + counts_df_plot.loc['m']
counts_df_plot.loc['D'] = counts_df_plot.loc['D'] + counts_df_plot.loc['d']
counts_df_plot.loc['F'] = counts_df_plot.loc['F'] + counts_df_plot.loc['f']
counts_df_plot.loc['N'] = counts_df_plot.loc['N'] + counts_df_plot.loc['n']
counts_df_plot.loc['Unk1'] = counts_df_plot.loc['Unk1'] + counts_df_plot.loc['Unk2']
#drop merged rows
counts_df_plot = counts_df_plot.drop(['a','c','g','t','m','d','f','n','Unk2'])
# Add m6A_frac row
counts_df_plot.loc['m6A_frac'] = counts_df_plot.loc['M'] / (counts_df_plot.loc['A'])
print(counts_df_plot.index)
#Sort dataframe by columns in ascending order
counts_df_plot = counts_df_plot.sort_index(axis=1)'''

'''# Plot m6A_frac in a line plot using plotly
# Plotting with Plotly
# Set plotly renderer to notebook

# Compute the moving average for smoothing
window_size = 25  # This defines the number of data points to use for each average value
m6A_data = counts_df_plot.loc['m6A_frac']
smoothed_data = m6A_data.rolling(window=window_size).mean()

# Plotting with Plotly
fig = go.Figure()

# Original data
fig.add_trace(go.Scatter(x=m6A_data.index, y=m6A_data.values,
                         mode='lines',
                         name='Original'))

# Smoothed data
fig.add_trace(go.Scatter(x=smoothed_data.index, y=smoothed_data.values,
                         mode='lines',
                         name=f'Smoothed (window size: {window_size})'))

fig.update_layout(title='m6A Fraction vs Genomic Position',
                  xaxis_title='Genomic Position',
                  yaxis_title='m6A Fraction')

# Set to plotly white theme
fig.update_layout(template="plotly_white")

fig.show()'''

In [None]:
### Extract m6A frac by region
importlib.reload(nanotools)
# One starmap task per (BAM file, condition, fraction). Each task is expected to return a
# DataFrame of per-region m6A fractions (or None) — TODO confirm against nanotools.
args_list = [(bam_file, condition, bam_frac, file_prefix, selection, m6A_thresh, output_stem, new_bed_files)
             for bam_file, condition, bam_frac in zip(new_bam_files, conditions, bam_fracs)]
print("Args list:", args_list)
result_list = []
result_df = pd.DataFrame()
if __name__ == "__main__":
    # No point in starting more workers than tasks; Pool needs at least one process.
    n_workers = max(1, min(10, len(args_list)))
    with Pool(processes=n_workers) as pool:
        result_list = pool.starmap(nanotools.extract_m6A_per_region_parellized, args_list)
    # Concatenate once: growing result_df inside a loop re-copies it for every result.
    # None results are skipped (pd.concat would drop them, but raises if all are None).
    frames = [result for result in result_list if result is not None]
    if frames:
        result_df = pd.concat(frames)
    print("Program finished!")

print(result_df)

In [None]:
### Build dataframe for plotting
def reindex_df(df, weight_col, scale=100000):
    """Expand ``df`` so each row is repeated in proportion to its weight.

    Row *i* is repeated ``ceil(df[weight_col][i] / scale)`` times, so statistics and plots on the
    expanded frame are weighted by ``weight_col``. NaN or negative weights give 0 repeats.
    The input frame is not modified.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to expand.
    weight_col : str
        Name of the numeric weight column (e.g. ``'total_bases'``).
    scale : float, default 100000
        Weight units represented by one repeated row.

    Returns
    -------
    pandas.DataFrame
        Expanded copy of ``df`` with a fresh 0..n-1 index.
    """
    expanded = df.reset_index(drop=True)  # copy, so the caller's index is left alone
    # ceil must apply AFTER dividing: ceil(x)/scale is truncated by repeat and drops rows < scale.
    repeats = (
        np.ceil(expanded[weight_col].astype(float) / scale)
        .fillna(0)
        .clip(lower=0)
        .astype(int)
        .to_numpy()
    )
    expanded = expanded.reindex(expanded.index.repeat(repeats))
    return expanded.reset_index(drop=True)

# Disabled cache guard (kept for reference): read the weighted combined-regions CSV back instead
# of rebuilding it when it already exists.
# if os.path.exists(output_stem  + file_prefix + "weighted_combined_regions_"  + str(m6A_thresh) +".csv"):
#     weighted_combined_regions = pd.read_csv(output_stem + file_prefix + "weighted_combined_regions_"  + str(m6A_thresh) +".csv")
#     print("File: ",
#           output_stem +  file_prefix+"combined_regions_"  + str(m6A_thresh) +".csv",
#           "already exists! Imported directly:")
#     print(weighted_combined_regions)
#
# else:
print("Building combined regions file...")

# One per-region m6A CSV per (region type, condition); presumably written by
# nanotools.extract_m6A_per_region_parellized. zip() with bam_fracs mirrors the extraction cell.
filenames = [
    output_stem + file_prefix + "m6A_frac_" + each_cond + "_" + str(m6A_thresh) + "_" + each_type + ".csv"
    for each_type in selection
    for each_cond, _each_frac in zip(conditions, bam_fracs)
]
df_list = [pd.read_csv(filename) for filename in filenames]
combined_regions = pd.concat(df_list)

# Repeat rows in proportion to total_bases so downstream plots are weighted by sequenced bases.
weighted_combined_regions = reindex_df(combined_regions, 'total_bases')

# Per-condition mean m6A/A over autosomal rows only, so every region (X included) is normalised to
# the autosomal baseline the column name refers to. Conditions with no autosomal rows get NaN.
autosome_rows = weighted_combined_regions['chr_type'] == 'Autosome'
autosome_mean_by_condition = (
    weighted_combined_regions.loc[autosome_rows]
    .groupby('condition')['m6A_frac']
    .mean()
)
weighted_combined_regions['mean_autosome_m6A_frac'] = weighted_combined_regions['condition'].map(autosome_mean_by_condition)
weighted_combined_regions['norm_m6A_frac'] = (
    weighted_combined_regions['m6A_frac'] / weighted_combined_regions['mean_autosome_m6A_frac']
)

# Save the weighted frame (the printed path is the one actually written).
weighted_csv_path = output_stem + file_prefix + "weighted_combined_regions_" + str(m6A_thresh) + ".csv"
print("Weighted combined:", weighted_combined_regions.head())
print("Outputting file:", weighted_csv_path)
weighted_combined_regions.to_csv(weighted_csv_path, index=False, mode='w')

# Median m6A/A per condition, timepoint (condition_min) and chromosome type.
chromosome_m6A_frac = (
    weighted_combined_regions
    .groupby(['condition', 'condition_min', 'chr_type'])['m6A_frac']
    .median()
    .reset_index()
)
# Genotype = part of the condition name before the first "-".
chromosome_m6A_frac['genotype'] = chromosome_m6A_frac['condition'].str.split('-').str[0]

# Order timepoints within each genotype/chr_type so diff() compares consecutive timepoints.
chromosome_m6A_frac = (
    chromosome_m6A_frac
    .sort_values(by=['genotype', 'chr_type', 'condition_min'])
    .reset_index(drop=True)
)

m6A_by_series = chromosome_m6A_frac.groupby(['genotype', 'chr_type'])['m6A_frac']
# Increase from the previous timepoint; the first timepoint of each series gets 0.
# (Assign rather than fillna(inplace=True) on a column, which is chained assignment.)
chromosome_m6A_frac['m6A_frac_diff'] = m6A_by_series.diff().fillna(0)

# m6A_frac at the first timepoint of each genotype/chr_type series, broadcast to every row.
first_m6A_frac = m6A_by_series.transform(lambda x: x.iloc[0])
print("First-timepoint m6A_frac per genotype/chr_type:", first_m6A_frac)

chromosome_m6A_frac['norm_m6A_frac_diff'] = chromosome_m6A_frac['m6A_frac_diff'] / first_m6A_frac
chromosome_m6A_frac['m6A_frac_diff_from_first'] = chromosome_m6A_frac['m6A_frac'] - first_m6A_frac

print(chromosome_m6A_frac)

In [None]:
# Distribution (box plots with all points) of per-region m6A/A for each condition, coloured first
# by chromosome and then by chromosome type.
px.defaults.template = "plotly_white"

# Not used in this cell.
considered_samples = [0]

#plot_title = "AID::SDC-2 + Auxin; 2uM Hia5 Timecourse; m6A thresh = 75%"
threshold_pct = round(m6A_thresh/254*100-1)
plot_title = "Mean m6A/A across entire chromosomes; m6A Threshold = " + str(threshold_pct) + "%"

# Same y (m6A_frac, not norm_m6A_frac) for both figures; only the colour grouping changes.
for colour_column in ("chromosome", "chr_type"):
    fig = px.box(result_df, x="condition", y="m6A_frac", color=colour_column, title=plot_title, points="all")
    fig.update_layout(plot_bgcolor='white')
    fig.show()

In [None]:
# Per-sample box plots of base-weighted region m6A/A, autosomes vs X, one subplot per sample.

# Indices into `conditions` to plot.
considered_samples = [0,1,2]

#plot_title = "AID::SDC-2 + Auxin; 2uM Hia5 Timecourse; m6A thresh = 75%"
# m6A_thresh/254 is a fraction; scale it to a percentage before appending "%".
plot_title = ("Mean m6A/A on X Chromosome 200kb Regions;<br>3min 2uM Hia5 treatment; m6A thresh = "
              + str(round(m6A_thresh/254*100-1)) + "%")

# marker_colors[0] -> X, marker_colors[1] -> Autosome.
marker_colors = ["#c45746", "#16415e"]

plotly_conditions = conditions
#plotly_conditions = ["N2<br>No-Met","N2<br>3-min","N2<br>10-min","N2<br>30-min", "N2<br>120-min",
#"#021+Aux<br>No-Met","#021+Aux<br>3-min","#021+Aux<br>10-min","#021+Aux<br>30-min", "#021+Aux<br>120-min"]

fig = make_subplots(rows=1, cols=len(considered_samples),
                    y_title="m6A/A",
                    shared_yaxes=True,
                    subplot_titles=[plotly_conditions[i] for i in considered_samples])

print("weighted_combined_regions ", weighted_combined_regions.head())

# (chr_type, x-label suffix, colour). The trailing space on the autosome x label keeps the two
# boxes in one subplot from overlapping.
chr_type_styles = [("Autosome", " ", marker_colors[1]), ("X", "", marker_colors[0])]

for subplot_col, sample_idx in enumerate(considered_samples, start=1):
    tube_df = weighted_combined_regions.loc[weighted_combined_regions['condition'] == conditions[sample_idx]]
    for chr_type, x_suffix, colour in chr_type_styles:
        df_plot = tube_df.loc[tube_df['chr_type'] == chr_type]
        fig.add_trace(go.Box(x=df_plot['condition'] + x_suffix, y=df_plot['m6A_frac'],
                             name=chr_type, marker_color=colour),
                      row=1, col=subplot_col)

# Transparent boxes, small points, no x tick labels (the subplot titles name each sample).
fig.update_traces(fillcolor='rgba(0,0,0,0)')
fig.update_layout(height=600, width=1000, template="plotly_white", title=plot_title)
fig.update_xaxes(showticklabels=False)
fig.update_annotations(font_size=12)
fig.update_traces(marker=dict(size=3))
# p-value annotations (add_p_value_annotation, one per subplot) are currently disabled.

# Move every subplot title below its panel. Assumes make_subplots lists the subplot-title
# annotations first and the y_title annotation after them; only the titles are moved.
for title_annotation in fig.layout.annotations[:len(considered_samples)]:
    title_annotation.update(y=-0.1)

# NOTE(review): in d3-format "1%" means width 1 with the default precision; ".0%" may be the
# intended tick format — confirm against the rendered axis.
fig.update_yaxes(tickformat="1%")
fig.show()
fig.write_image(output_stem + "combined_regions_"  + str(m6A_thresh) +".svg")

In [None]:
# Bar chart of scaled aligned-base totals ("Coverage") per condition, one subplot per condition.
# Colours are cycled when there are more conditions than colours.
marker_colors = ["#fde725","#a0da39","#4ac16d"]#,"#1fa187","#277f8e","#365c8d","#46327e","#440154","#c45746","#16415e"]

plotly_conditions = conditions

fig = make_subplots(rows=1, cols=len(conditions),
                    y_title="Coverage",
                    shared_yaxes=True,
                    subplot_titles=(plotly_conditions))

# Average read length is reported as 0 when there are no reads (int(total/0) would crash).
all_bases = combined_regions['total_bases'].sum()
all_reads = combined_regions['overlapping_reads'].sum()
print("Total MB aligned for ALL conditons: ", int(all_bases/1000000),
      " | across ", int(all_reads), " reads with avg. length of: ",
      int(all_bases/all_reads) if all_reads else 0)

for i, cond in enumerate(conditions):
    tube_df = combined_regions.loc[combined_regions['condition'] == cond]
    tube_bases = tube_df['total_bases'].sum()
    tube_reads = tube_df['overlapping_reads'].sum()
    coverage = [tube_bases/100000000*3.125]  # 3.125 is the scaling factor for adenosines in c elegans genome.
    print("Total MB aligned for ", cond,
          ": ", int(tube_bases/1000000),
          " | across ", int(tube_reads),
          " reads with avg. length of: ",
          int(tube_bases/tube_reads) if tube_reads else 0)
    # Exactly one bar per subplot: one x label paired with the single coverage value.
    trace0 = go.Bar(x=[cond + " "], y=coverage,
                    name=plotly_conditions[i], marker_color=marker_colors[i % len(marker_colors)])
    fig.add_trace(trace0, row=1, col=i + 1)

fig.update_layout(height=800, template="plotly_white")
fig.update_xaxes(showticklabels=False)
#fig.update_yaxes(range=[0.7, 1.3])
fig.show()