In [None]:
import pyarrow as pa
import os
import sys
import pandas as pd
import numpy as np
from pathlib import Path
import polars as pl
from typing import Union


notebook_dir = os.path.dirname(os.path.abspath("__file__"))
sys.path.append(notebook_dir)
from src.inference.aws_inference import process_jsonl_files


project_path = "/data1/datasets_1/human_cistrome/chip-atlas/peak_calls/tfbinding_scripts/tf-binding"
jaspar_file = "/data1/datasets_1/human_cistrome/chip-atlas/peak_calls/tfbinding_scripts/tf-binding/src/inference/interpretability/motif.jaspar"  # Update this path
model = "AR"
sample = "LuCaP_78"
# ground_truth_file = "/data1/datasets_1/human_prostate_PDX/processed/external_data/ChIP_atlas/AR/SRX8406456.05.bed"
# ground_truth_file = "/data1/projects/human_cistrome/aligned_chip_data/merged_cell_lines/22Rv1/bam/22Rv1_merge.sorted.nodup.shifted.bam"

df = pl.read_parquet(project_path + "/data/processed_results/" + model + "_" + sample + "_processed.parquet")
df = df.rename({"chr_name": "chr"})
df


In [None]:
import tempfile
from src.utils.generate_training_peaks import run_bedtools_command
def intersect_bed_files(main_df: pl.DataFrame, intersect_df: pl.DataFrame, region_type: str = None) -> pl.DataFrame:
    """
    Intersect two BED files using bedtools and return the original DataFrame with overlap flags.
    
    Args:
        main_df: Primary Polars DataFrame with BED data
        intersect_df: Secondary Polars DataFrame to intersect with
        region_type: Optional region type label to add to results
        
    Returns:
        Original DataFrame with additional column indicating overlaps
    """
    with tempfile.NamedTemporaryFile(delete=False, mode='w') as main_file, \
         tempfile.NamedTemporaryFile(delete=False, mode='w') as intersect_file, \
         tempfile.NamedTemporaryFile(delete=False, mode='w') as result_file:
        
        main_path = main_file.name
        intersect_path = intersect_file.name
        result_path = result_file.name

        # Write DataFrames to temporary files
        main_df.write_csv(main_path, separator="\t", include_header=False)
        intersect_df.write_csv(intersect_path, separator="\t", include_header=False)

        # Run bedtools intersect with -c flag to count overlaps
        command = f"bedtools intersect -a {main_path} -b {intersect_path} -c > {result_path}"
        run_bedtools_command(command)

        # Read results back into Polars DataFrame
        result_df = pl.read_csv(
            result_path,
            separator="\t",
            has_header=False,
            new_columns=[*main_df.columns, "overlap_count"]
        )

    # Clean up temporary files
    os.remove(main_path)
    os.remove(intersect_path) 
    os.remove(result_path)

    # Add boolean overlap column
    result_df = result_df.with_columns(
        pl.col("overlap_count").gt(0).alias("overlaps_ground_truth")
    ).drop("overlap_count")

    return result_df

ground_truth_file = "/data1/datasets_1/human_prostate_PDX/processed/external_data/ChIP_atlas/AR/SRX8406455.10.bed"

df_ground_truth = pl.read_csv(ground_truth_file, 
                             separator="\t", 
                             has_header=False,
                             new_columns=["chr", "start", "end"],
                             columns=[0,1,2])

intersected_df = intersect_bed_files(df[["chr", "start", "end"]], df_ground_truth)




intersected_df


chr,start,end,overlaps_ground_truth
str,i64,i64,bool
"""chr5""",10039,11783,true
"""chr2""",739950,740882,false
"""chr10""",10010,10428,true
"""chr21""",46184427,46184968,false
"""chr12""",133264943,133265299,true
…,…,…,…
"""chr17""",8110181,8110532,false
"""chr16""",2474208,2474999,false
"""chr3""",171703978,171704435,false
"""chr7""",101017134,101017572,false


In [3]:
# add overlaps ground truth to df from intersected_df
ground_truth_df = df.join(intersected_df, on=["chr", "start", "end"], how="left")
# add overlaps_ground_truth to df under targets, 1 if overlaps_ground_truth is true, 0 otherwise
ground_truth_df = ground_truth_df.with_columns(pl.when(pl.col("overlaps_ground_truth")).then(1).otherwise(0).alias("targets"))
ground_truth_df

chr,start,end,cell_line,targets,predicted,weights,probabilities,linear_512_output,attributions,overlaps_ground_truth
str,i64,i64,str,i32,f64,f64,f64,list[list[f64]],list[list[f64]],bool
"""chr5""",10039,11783,"""SRR12455442""",1,1.0,-1.0,0.666859,"[[0.584581, 0.677384, … -0.000385]]","[[-0.00093, 0.0, … 0.001799], [0.0, -0.00078, … 0.002326], … [0.0, 0.0, … 0.010341]]",true
"""chr2""",739950,740882,"""SRR12455442""",0,0.0,-1.0,0.069024,"[[0.104039, 0.11396, … -2.424644]]","[[0.0, 0.0, … 0.000124], [0.0, 0.0, … 0.000094], … [0.0, 0.0, … 0.01256]]",false
"""chr10""",10010,10428,"""SRR12455442""",1,1.0,-1.0,0.702575,"[[0.506404, 0.561959, … 0.109321]]","[[0.001685, 0.0, … -0.002936], [0.001869, 0.0, … -0.003507], … [0.0, 0.0, … 0.000278]]",true
"""chr21""",46184427,46184968,"""SRR12455442""",0,1.0,-1.0,0.554604,"[[0.11399, 0.118262, … 0.11204]]","[[0.0, 0.0, … -0.00003], [0.0, 0.0, … 0.000005], … [0.0, 0.0, … -0.000189]]",false
"""chr12""",133264943,133265299,"""SRR12455442""",1,1.0,-1.0,0.686865,"[[0.103133, 0.130933, … 0.188699]]","[[0.0, 0.0, … 0.000012], [0.000376, 0.0, … -0.000288], … [-0.006379, 0.0, … -0.000818]]",true
…,…,…,…,…,…,…,…,…,…,…
"""chr17""",8110181,8110532,"""SRR12455442""",0,0.0,-1.0,0.204443,"[[-1.413845, -1.43306, … 0.104353]]","[[0.0, 0.0, … 0.042807], [0.0, 0.0, … 0.061726], … [0.0, 0.0, … -0.000001]]",false
"""chr16""",2474208,2474999,"""SRR12455442""",0,0.0,-1.0,0.498441,"[[0.040351, 0.105451, … 0.097793]]","[[0.0, 0.0, … -0.000901], [0.0, 0.0, … -0.001582], … [0.0, 0.0, … 0.000171]]",false
"""chr3""",171703978,171704435,"""SRR12455442""",0,0.0,-1.0,0.100252,"[[-1.576183, -1.940581, … 0.097735]]","[[0.0, 0.0, … -0.017684], [0.0, 0.0, … -0.017785], … [0.0, 0.0, … -0.002764]]",false
"""chr7""",101017134,101017572,"""SRR12455442""",0,0.0,-1.0,0.21604,"[[-0.021217, 0.111424, … -1.994048]]","[[0.0, 0.0, … -0.00581], [0.0, 0.001947, … -0.006327], … [0.0, 0.0, … 0.022207]]",false


In [None]:
# get number of targets == predicted vs not predicted   
# get dataframe whre targets == 1
df_targets_1 = ground_truth_df.filter(pl.col("targets") == 1)
# get number of 1s in targets
print("Number of CHIP hits in ATAC peaks", df_targets_1["targets"].sum())
# get number of predicted in df_targets_1
print("Number of predicted in df_targets_1", df_targets_1["predicted"].sum())

Number of CHIP hits in ATAC peaks 1511
Number of predicted in df_targets_1 163.0


In [20]:
threshold = 0.5

# get how many 1s in targets
print("Number of CHIP hits in ATAC peaks", df["targets"].sum())
# get ATAC peaks with probability >= threshold
df_positive = ground_truth_df.filter(pl.col("probabilities") >= threshold)
# get number of 1s in targets
print("Number of CHIP hits in ATAC peaks with probability >= threshold:", df_positive["targets"].sum())
# get length of df_positive
print("Number of ATAC peaks with probability >= threshold:", len(df_positive))
# get ground truth positives
df_ground_truth_positive = df.filter(pl.col("targets") == 1)
print("Number of CHIP hits in ground truth:", len(df_ground_truth_positive))
# get ground truth negatives
df_ground_truth_negative = df.filter(pl.col("targets") == 0)
print("Number of Negatives in ground truth:", len(df_ground_truth_negative))


Number of CHIP hits in ATAC peaks 1511
Number of CHIP hits in ATAC peaks with probability >= threshold: 163
Number of ATAC peaks with probability >= threshold: 5727
Number of CHIP hits in ground truth: 1511
Number of Negatives in ground truth: 51769


In [4]:
# calculate precision, recall, f1 score
precision = df_positive["targets"].sum() / len(df_positive)
recall = df_ground_truth_positive["targets"].sum() / len(df_ground_truth_positive)
f1_score = 2 * precision * recall / (precision + recall)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1_score)

Precision: 0.0
Recall: 1.0
F1 Score: 0.0


In [5]:
# generate precision recall curve
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt

# Calculate precision and recall values
precision, recall, thresholds = precision_recall_curve(df_ground_truth_positive["targets"], df_positive["probabilities"])

# plot precision recall curve
plt.figure(figsize=(10, 5))
plt.plot(recall, precision, marker='o')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.show()

ValueError: Found input variables with inconsistent numbers of samples: [3625, 1]

In [9]:
# new model

# Number of CHIP hits in ATAC peaks 1511
# Number of CHIP hits in ATAC peaks with probability >= threshold: 726
# Number of ATAC peaks with probability >= threshold: 7881
# Number of CHIP hits in ground truth: 1511
# Number of Negatives in ground truth: 51769



# old model
# new model

# Number of CHIP hits in ATAC peaks 1511
# Number of CHIP hits in ATAC peaks with probability >= threshold: 726
# Number of ATAC peaks with probability >= threshold: 7881
# Number of CHIP hits in ground truth: 1511
# Number of Negatives in ground truth: 51769



