In [29]:
import os
import polars as pl
from dataclasses import dataclass

# Root directory of the tf-binding project; every data path below is built from it.
# Defaults to the original shared-storage location. Set the TF_BINDING_PROJECT_PATH
# environment variable to run this notebook against a checkout somewhere else.
PROJECT_PATH = os.environ.get("TF_BINDING_PROJECT_PATH") or "/data1/datasets_1/human_cistrome/chip-atlas/peak_calls/tfbinding_scripts/tf-binding"

@dataclass
class SampleConfig:
    """Describes one sample to evaluate: which processed predictions and which ground-truth file to load."""
    # Run label (e.g. "FOX-Contrasting"); first component of the processed parquet file name
    label: str
    # Sample identifier (e.g. "A549", presumably a cell line); second component of the parquet file name
    sample: str
    # Path to the ground-truth file, which is read as headerless tab-separated text
    ground_truth_file: str

# Samples to process; each entry names its prediction run and its ground-truth file
sample_configs = [
    SampleConfig(
        label="FOX-Contrasting",
        sample="A549",
        ground_truth_file=PROJECT_PATH + "/data/data_splits/validation_FOXA1_FOXA2.csv",
    ),
]


In [30]:

def enrich_chip_data(chip_df, reference_df):
    """
    Attach the 'probabilities' and 'predicted' columns of a reference dataframe
    to a chip dataframe, matching rows on their chr/start/end coordinates.

    The join is a left join, so every row of chip_df is kept.

    Args:
        chip_df: The chip data dataframe that receives the extra columns
        reference_df: The dataframe supplying 'probabilities' and 'predicted'

    Returns:
        The chip dataframe with the two columns added
    """
    coordinate_keys = ['chr', 'start', 'end']

    # Only the join keys and the two value columns are taken from the reference
    reference_subset = reference_df.select(coordinate_keys + ['probabilities', 'predicted'])

    return chip_df.join(reference_subset, on=coordinate_keys, how='left')


dfs = []
for sample_config in sample_configs:
    # Processed predictions are stored as "<label>_<sample>_processed.parquet"
    parquet_path = f"{PROJECT_PATH}/data/processed_results/{sample_config.label}_{sample_config.sample}_processed.parquet"
    df = pl.read_parquet(
        parquet_path,
        columns=["chr_name", "start", "end", "cell_line", "targets", "predicted", "weights", "probabilities"],
    ).rename({"chr_name": "chr"})  # align the chromosome column name with the ground-truth file

    # The ground-truth file has no header row; name its leading columns explicitly
    chip_data = pl.read_csv(
        sample_config.ground_truth_file,
        has_header=False,
        separator="\t",
        new_columns=["chr", "start", "end", "count", "targets"],
    )

    # Bring the model outputs onto the ground-truth rows and keep the result
    enriched_chip_data = enrich_chip_data(chip_data, df)
    dfs.append(enriched_chip_data)

dfs[0]

chr,start,end,count,targets,column_6,probabilities,predicted
str,i64,i64,f64,f64,str,f64,f64
"""chr2""",739972,740939,0.0,0.0,"""A549""",0.526855,1.0
"""chr7""",635940,636669,0.0,0.0,"""A549""",0.10302,0.0
"""chr19""",42242543,42243532,0.0,0.0,"""A549""",0.136167,0.0
"""chr7""",112391260,112392219,0.0,0.0,"""A549""",0.021645,0.0
"""chr8""",109333630,109334843,0.0,0.0,"""A549""",0.045493,0.0
…,…,…,…,…,…,…,…
"""chr15""",78303575,78303790,0.0,0.0,"""A549""",0.015832,0.0
"""chr2""",56812631,56813654,1.0,1.0,"""A549""",0.067692,0.0
"""chr15""",87043753,87044630,0.0,0.0,"""A549""",0.033558,0.0
"""chr22""",40405618,40406300,0.0,0.0,"""A549""",0.018706,0.0


In [31]:
# Work on the first enriched frame (sample_configs currently holds a single sample)
df = dfs[0]

In [40]:
# Find the threshold that maximizes the number of correct predictions
thresholds = [t/100 for t in range(1, 100)]  # Test thresholds from 0.01 to 0.99


def _predicted_at(threshold):
    """Expression: 1 where 'probabilities' is greater than threshold, else 0 (Int64)."""
    return pl.col('probabilities').gt(threshold).cast(pl.Int64)


def _correct_at(threshold):
    """Boolean expression: the thresholded prediction equals 'targets'."""
    return pl.col('targets') == _predicted_at(threshold)


best_threshold = 0
max_correct = 0

for threshold in thresholds:
    # filter() keeps only the rows where the predicate is true, so the height of
    # the filtered frame is the number of correct predictions at this threshold
    correct_count = df.filter(_correct_at(threshold)).height

    # Strict '>' means the lowest threshold wins when several tie
    if correct_count > max_correct:
        max_correct = correct_count
        best_threshold = threshold

total_rows = df.height
# Guard the ratio so an empty frame reports 0% instead of raising ZeroDivisionError
accuracy = max_correct / total_rows if total_rows else 0.0

print(f"Best threshold: {best_threshold:.2f}")
print(f"Maximum correct predictions: {max_correct} out of {total_rows} ({accuracy:.2%})")

# Apply the best threshold to the dataframe: 'predicted' is overwritten with the
# thresholded value and 'correct' records whether it matches 'targets'
df = df.with_columns([
    _predicted_at(best_threshold).alias('predicted'),
    _correct_at(best_threshold).alias('correct'),
])
df['correct'].value_counts()


Best threshold: 0.29
Maximum correct predictions: 7967 out of 12004 (66.37%)


correct,count
bool,u32
False,4037
True,7967


In [37]:
# Inspect the frame; the recorded output includes the thresholded 'predicted' and 'correct' columns
df

chr,start,end,count,targets,column_6,probabilities,predicted,correct
str,i64,i64,f64,f64,str,f64,i64,bool
"""chr2""",739972,740939,0.0,0.0,"""A549""",0.526855,1,false
"""chr7""",635940,636669,0.0,0.0,"""A549""",0.10302,0,true
"""chr19""",42242543,42243532,0.0,0.0,"""A549""",0.136167,0,true
"""chr7""",112391260,112392219,0.0,0.0,"""A549""",0.021645,0,true
"""chr8""",109333630,109334843,0.0,0.0,"""A549""",0.045493,0,true
…,…,…,…,…,…,…,…,…
"""chr15""",78303575,78303790,0.0,0.0,"""A549""",0.015832,0,true
"""chr2""",56812631,56813654,1.0,1.0,"""A549""",0.067692,0,false
"""chr15""",87043753,87044630,0.0,0.0,"""A549""",0.033558,0,true
"""chr22""",40405618,40406300,0.0,0.0,"""A549""",0.018706,0,true
