# Tool results comparison on IMG/VR4 (subsampled)
Following up on the prep work in [`subsample_prep.ipynb`](./subsample_prep.ipynb) and [`scaling_cheack_imgvr4_subsamples.ipynb`](scaling_cheack_imgvr4_subsamples.ipynb), here we compare the outputs of the different aligners and search tools.
Unlike the similar notebook for simulated data, here we do not have a "ground truth": we cannot tell whether a reported spacer-protospacer pair is a true positive or a spurious match.

We use the 5% sample, the largest sample size for which all tools finished (sassy and indelfree_bruteforce timed out on the 10% sample).

In [1]:
# %load_ext autoreload
# %autoreload 2
import os
os.chdir('/clusterfs/jgi/scratch/science/metagen/neri/code/blits/spacer_bench/')
from src.bench.utils.functions import *
from src.bench.commands.generate_scripts import load_tool_configs
import matplotlib.pyplot as plt
import upsetplot as up
import numpy as np
import polars as pl
import json
pl.Config(tbl_rows=50)

# Disable FutureWarning deprecation messages
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

from bench import *
from bench.utils.functions import *

with open('notebooks/tool_styles.json', 'r') as fh:
    TOOL_STYLES = json.load(fh)
MAX_MISMATCHES = 3
base_dir = "/clusterfs/jgi/scratch/science/metagen/neri/code/blits/spacer_bench/results/real_data/subsamples/fraction_0.001/"
spacers_file = "/clusterfs/jgi/scratch/science/metagen/neri/code/blits/spacer_bench/imgvr4_data/spacers/All_CRISPR_spacers_nr_clean.fna"
contigs_file="/clusterfs/jgi/scratch/science/metagen/neri/code/blits/spacer_bench/results/real_data/subsamples/fraction_0.001/subsampled_data/subsampled_contigs.fa"
threads = 12
spacers = read_fasta(spacers_file)
spacer_lendf = pl.DataFrame({"spacer_id": spacers.keys(), "length": [len(seq) for seq in spacers.values()]})
tools = load_tool_configs(
    results_dir=base_dir,
    contigs_file=contigs_file,
    spacers_file=spacers_file,
    threads=threads)

Next, we read all tool results, removing entries for unmapped contigs and matches with more than 3 mismatches.

In [2]:
# results_df = read_results(
#     tools=tools,
#     max_mismatches=MAX_MISMATCHES+1, # tool-reported, not validated for the scaling tests
#     spacer_lendf=spacer_lendf,
#     ref_file=contigs_file,
#     threads=18,
#     memory_limit="150GB",
#     output_parquet='results/real_data/subsamples_analysis/alignments_fraction_0.001.parquet'
# )

## DuckDB-Based Workflow (Memory Efficient)
Instead of loading the full dataset into memory and then validating, we'll use a streaming approach:
1. Extract unique regions via DuckDB (streaming)
2. Validate sequences in batches (controlled memory)
3. Join back and filter via DuckDB (streaming)

This avoids OOM errors by never loading the full dataset.
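The three steps can be illustrated with a minimal pure-Python sketch. This is not the `recalculate_mismatches_streaming` implementation: the row layout, the `count_mismatches` callback, and the tiny batch size are assumptions for illustration, and plain Python sets and dicts stand in for DuckDB.

```python
from itertools import islice
from typing import Callable, Iterable, Iterator


def batched(items: Iterable, size: int) -> Iterator[list]:
    """Yield successive lists of at most `size` items."""
    it = iter(items)
    while batch := list(islice(it, size)):
        yield batch


def validate_in_batches(rows, count_mismatches: Callable[[tuple], int],
                        max_mismatches: int = 3, batch_size: int = 2):
    """rows: (tool, spacer_id, contig_id, strand, start, end) tuples (hypothetical layout)."""
    # Step 1: deduplicate regions, so a region reported by several tools is aligned once.
    unique_regions = sorted({row[1:] for row in rows})
    # Step 2: recalculate mismatches one bounded batch at a time.
    # (Kept in a dict here for brevity; an out-of-core version would write each batch to disk.)
    recalculated = {}
    for batch in batched(unique_regions, batch_size):
        for region in batch:
            recalculated[region] = count_mismatches(region)
    # Step 3: join the recalculated counts back onto every tool row and filter.
    return [(row, recalculated[row[1:]]) for row in rows
            if recalculated[row[1:]] <= max_mismatches]


# Made-up rows: the sp1/c1 region was reported by two tools, sp2/c2 by one.
rows = [
    ("toolA", "sp1", "c1", "+", 10, 40),
    ("toolB", "sp1", "c1", "+", 10, 40),
    ("toolA", "sp2", "c2", "-", 5, 35),
]
calls = []


def fake_count(region):
    calls.append(region)
    return {"sp1": 1, "sp2": 5}[region[0]]  # pretend recalculated mismatch counts


kept = validate_in_batches(rows, fake_count, max_mismatches=3, batch_size=1)
print(len(calls), len(kept))
```

Deduplicating first is what keeps the expensive alignment step proportional to the number of unique regions rather than the number of tool rows.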

In [3]:
from src.bench.utils.functions import recalculate_mismatches_streaming

# This replaces all the manual steps of:
# 1. Extracting unique regions
# 2. Populating spacer sequences  
# 3. Populating contig sequences
# 4. Running test_alignment_polars
# 5. Joining back and filtering

# Instead, everything is done in batches without loading the full dataset
recalculate_mismatches_streaming(
    parquet_path='results/real_data/subsamples_analysis/alignments_fraction_0.001.parquet',
    spacers_file=spacers_file,
    contigs_file=contigs_file,
    output_parquet='results/real_data/subsamples_analysis/alignments_fraction_0.001_validated.parquet',
    max_mismatches=3,  # Only keep alignments with ≤3 mismatches after recalculation
    batch_size=10000000,  # Process 10M unique regions at a time (adjust based on memory)
    threads=18,
    memory_limit="150GB",
    ignore_region_strands=True
)


[Streaming Mismatch Recalculation]
  Input: results/real_data/subsamples_analysis/alignments_fraction_0.001.parquet
  Output: results/real_data/subsamples_analysis/alignments_fraction_0.001_validated.parquet
  Batch size: 10,000,000
  Memory limit: 150GB

[Step 1/4] Extracting unique regions via DuckDB...



  Found 975,312,127 unique regions

[Step 2/4] Populating sequences in batches...
  Processing batch 1/98 (offset=0)...
Initial pldf shape: (10000000, 5)
Unique entries in minipldf: (488005, 1)
After filtering nulls: (488005, 1)
Actual number of sequences in file: 3882812

Processing chunk 500000/3882812
Number of sequences in chunk: 500000
Joining with nascent df
Null count in seqcol after chunk: 425074

Processing chunk 1000000/3882812
Number of sequences in chunk: 500000
Joining with nascent df
Null count in seqcol after chunk: 361447

Processing chunk 1500000/3882812
Number of sequences in chunk: 500000
Joining with nascent df
Null count in seqcol after chunk: 303129

Processing chunk 2000000/3882812
Number of sequences in chunk: 500000
Joining with nascent df
Null count in seqcol after chunk: 240745

Processing chunk 2500000/3882812
Number of sequences in chunk: 500000
Joining with nascent df
Null count in seqcol after chunk: 177635

Processing chunk 3000000/3882812
Number of sequ

parasail_nw_trace: s2Len must be > 0
parasail_result_is_trace: missing result

  Removed temp directory: results/real_data/subsamples_analysis/mismatch_recalc_temp


AttributeError: 'Result' object has no traceback

### Load Results for Analysis (DuckDB)
Now we set up a DuckDB connection to query the validated results without loading them into memory.

In [None]:
import duckdb

# Create DuckDB connection for querying without loading into memory
con = duckdb.connect(database=":memory:")
con.execute("SET threads TO 12;")
con.execute("SET memory_limit = '100GB';")

# Get tool list
tools_list = con.execute("""
    SELECT DISTINCT tool 
    FROM read_parquet('results/real_data/subsamples_analysis/alignments_fraction_0.001_validated.parquet')
    ORDER BY tool
""").pl()['tool'].to_list()

print(f"Tools: {tools_list}")

# Create a view with renamed columns for consistency
# 'mismatches' (tool-reported) → 'tool_reported_mismatches'
# 'alignment_test' (parasail-validated) → 'mismatches'
con.execute("""
    CREATE VIEW tools_results AS
    SELECT 
        spacer_id,
        contig_id,
        strand,
        start,
        "end",  -- quoted: END is a reserved SQL keyword
        tool,
        spacer_length,
        mismatches as tool_reported_mismatches,
        alignment_test as mismatches,
        spacer_seq,
        contig_seq
    FROM read_parquet('results/real_data/subsamples_analysis/alignments_fraction_0.001_validated.parquet')
""")

print("Created view: tools_results (with recalculated mismatches)")

### Summary Statistics (DuckDB)

In [None]:
from src.bench.utils.functions import get_summary_stats_duckdb

summary_stats = get_summary_stats_duckdb(
    'results/real_data/subsamples_analysis/alignments_fraction_0.001_validated.parquet',
    threads=12,
    memory_limit="100GB"
)

print(summary_stats)

### Matched Contigs Summary (DuckDB)

In [None]:
matched_contigs = con.execute("""
    SELECT 
        contig_id,
        COUNT(DISTINCT spacer_id) as n_spacers,
        LIST(DISTINCT tool) as tools,
        COUNT(DISTINCT tool) as n_tools
    FROM tools_results
    GROUP BY contig_id
    ORDER BY n_spacers DESC
    LIMIT 10
""").pl()

print(matched_contigs)

### Upset Plots (DuckDB version)
Generate upset plots, using DuckDB to build the contig-tool mappings without loading the full dataset.

In [None]:
plot_mismatches = [0, 1, 2, 3]

for n_mismatches in plot_mismatches:
    print(f"\nProcessing upset plot for {n_mismatches} mismatches...")
    
    # Use DuckDB to create contig-tools mapping (streaming, no memory load)
    contig_tool_table = con.execute(f"""
        SELECT 
            contig_id,
            LIST(DISTINCT tool) as tools
        FROM tools_results
        WHERE mismatches = {n_mismatches}
        GROUP BY contig_id
        ORDER BY contig_id
    """).pl()
    
    # Create upset plot from the small aggregated result
    test_upset = up.from_memberships(contig_tool_table['tools'])
    print(f"n mismatches: {n_mismatches}")
    
    plt.figure(figsize=(10, 6))
    up.plot(test_upset, subset_size='count', sort_by='cardinality')
    plt.title(f'Matches with == {n_mismatches} mismatches')
    plt.savefig(f'results/real_data/plots/upset_{n_mismatches}.pdf')
    plt.show()

### Tool Comparison Matrix (DuckDB version)
Generate tool comparison matrices using streaming DuckDB queries.

In [None]:
from src.bench.utils.functions import create_tool_comparison_matrix_duckdb

all_charts = []

for n_mismatches in [0, 1, 2, 3]:
    print(f"\nCreating matrix for {n_mismatches} mismatches...")
    
    matrix = create_tool_comparison_matrix_duckdb(
        parquet_path='results/real_data/subsamples_analysis/alignments_fraction_0.001_validated.parquet',
        tools_list=tools_list,
        n_mismatches=n_mismatches,
        output_csv=f'results/real_data/results/matrix_{n_mismatches}.tsv',
        threads=12,
        memory_limit="100GB"
    )
    
    print(matrix)
    
    # Create heatmap
    heatmap_filename = f'results/real_data/plots/matrix_{n_mismatches}'
    chart = plot_matrix(matrix, f"Matrix for {n_mismatches} mismatches", heatmap_filename)
    all_charts.append(chart)

### Spacer Counts for Recall Analysis (DuckDB version)

In [None]:
from src.bench.utils.functions import create_spacer_counts_with_tools_duckdb

# Create spacer counts for recall vs occurrences analysis
spacer_counts_with_tools = create_spacer_counts_with_tools_duckdb(
    parquet_path='results/real_data/subsamples_analysis/alignments_fraction_0.001_validated.parquet',
    tools_list=tools_list,
    mismatches=3,
    exact_or_max="max",
    threads=12,
    memory_limit="100GB"
)

print(f"Created spacer counts: {spacer_counts_with_tools.height:,} rows")
print(spacer_counts_with_tools.head())

# Save for later use
spacer_counts_with_tools.write_parquet('results/real_data/results/spacer_counts_max3.parquet')

### Distribution Data (DuckDB version)
Get aggregated distribution data for plotting without loading the full dataset.

In [None]:
# Get distributions using DuckDB (small aggregated results)
length_distribution = con.execute("""
    SELECT 
        spacer_length,
        COUNT(*) as count
    FROM (
        SELECT DISTINCT spacer_id, spacer_length
        FROM tools_results
    )
    GROUP BY spacer_length
    ORDER BY spacer_length
""").pl()

mismatch_distribution = con.execute("""
    SELECT 
        mismatches,
        COUNT(*) as count
    FROM (
        SELECT DISTINCT spacer_id, contig_id, strand, start, "end", mismatches
        FROM tools_results
    )
    GROUP BY mismatches
    ORDER BY mismatches
""").pl()

occurrence_distribution = con.execute("""
    SELECT 
        n_occurrences,
        COUNT(*) as count
    FROM (
        SELECT 
            spacer_id,
            COUNT(DISTINCT contig_id) as n_occurrences
        FROM tools_results
        WHERE mismatches <= 3
        GROUP BY spacer_id
    )
    GROUP BY n_occurrences
    ORDER BY n_occurrences
""").pl()

print(f"Length distribution: {length_distribution.height} unique lengths")
print(f"Mismatch distribution: {mismatch_distribution.height} mismatch levels")
print(f"Occurrence distribution: {occurrence_distribution.height} occurrence levels")

### Tool Performance by Mismatches (DuckDB version)

In [None]:
# Calculate total possible matches and per-tool matches using DuckDB
total_matches = con.execute("""
    SELECT 
        mismatches,
        COUNT(*) as total_possible
    FROM (
        SELECT DISTINCT spacer_id, contig_id, mismatches
        FROM tools_results
    )
    GROUP BY mismatches
    ORDER BY mismatches
""").pl()

tool_matches = con.execute("""
    SELECT 
        mismatches,
        tool,
        COUNT(*) as tool_matches
    FROM (
        SELECT DISTINCT spacer_id, contig_id, tool, mismatches
        FROM tools_results
    )
    GROUP BY mismatches, tool
    ORDER BY mismatches, tool
""").pl()

# Create all combinations
all_combinations = pl.DataFrame({
    'mismatches': np.repeat(range(4), len(tools_list)),
    'tool': tools_list * 4
})

# Calculate recall
mismatch_performance = all_combinations\
    .join(total_matches, on='mismatches')\
    .join(tool_matches, on=['mismatches', 'tool'], how='left')\
    .with_columns([
        pl.col('tool_matches').fill_null(0),
        (pl.col('tool_matches') / pl.col('total_possible')).alias('recall')
    ])

print(mismatch_performance)

# Save results
mismatch_performance.write_csv('results/real_data/results/tool_recall_by_mismatches.tsv', separator='\t')

### Cleanup
Close the DuckDB connection when analysis is complete.

In [None]:
# Close DuckDB connection
con.close()
print("DuckDB connection closed")
print("\nMemory-efficient workflow complete!")
print("All analyses performed without loading full dataset into memory.")

## Filter alignments (streaming)
Since the combined parquet file is large and causes memory issues when loaded eagerly with Polars, we filter it to tool-reported mismatches <= 3 using Polars' streaming mode (`scan_parquet` then `sink_parquet`) and save a smaller filtered file.

In [None]:
# # Use Polars streaming mode: processes the data in chunks without loading everything into memory, because it's too big (raw sassy TSV output is ~4.7 TB...)
# pl.scan_parquet('results/real_data/subsamples_analysis/alignments_fraction_0.001.parquet') \
#     .filter(pl.col('mismatches') <= 3) \
#     .sink_parquet(
#         'results/real_data/subsamples_analysis/alignments_fraction_0.001_maxmis_3.parquet',
#         compression='snappy'
#     )
# no need to rerun - next cell will read the filtered parquet

Now we can load the filtered results with Polars (a much smaller file, so memory-friendly).

To reduce memory usage and the time needed to access matched contigs/spacers, we index the spacers FASTA file, and also create and index a FASTA file containing only the matched contigs.
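Extracting only the matched contigs can be sketched with the standard library. This is a minimal sketch under two assumptions: the FASTA is plain text (not compressed), and a record's ID is the first whitespace-delimited token of its header. `write_fasta_subset` is a hypothetical helper, not a function from `bench`.

```python
import io
from typing import Iterable, TextIO


def write_fasta_subset(src: TextIO, dst: TextIO, wanted_ids: Iterable[str]) -> int:
    """Copy only the records whose ID is in wanted_ids; return the number of records written."""
    wanted = set(wanted_ids)
    keep = False
    n_written = 0
    for line in src:
        if line.startswith(">"):
            # Assumed convention: the record ID is the first token after '>'.
            fields = line[1:].split()
            keep = bool(fields) and fields[0] in wanted
            if keep:
                n_written += 1
        if keep:
            dst.write(line)  # header and sequence lines of kept records
    return n_written


fasta = ">c1 some description\nACGT\nAC\n>c2\nGGGG\n>c3\nTTTT\n"
out = io.StringIO()
n = write_fasta_subset(io.StringIO(fasta), out, ["c1", "c3"])
print(n)
print(out.getvalue(), end="")
```

Because the input is read line by line, memory use depends on the set of wanted IDs, not on the size of the contig file.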

In [None]:
tools_results = pl.scan_parquet('results/real_data/subsamples_analysis/alignments_fraction_0.001_maxmis_3.parquet')

In [None]:
tools_results.collect_schema()

Next, we index the contigs FASTA file for faster access.
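The idea behind an index like the one `pyfastx index` builds is to record where each record starts, so a single contig can be read with one seek instead of a full scan. The sketch below is a simplified stdlib stand-in (byte offsets only, no line-length bookkeeping, no compression support, non-empty headers assumed); it does not reproduce pyfastx's index format, and `build_offset_index`/`fetch` are hypothetical names.

```python
import os
import tempfile


def build_offset_index(path: str) -> dict:
    """Map record ID -> (byte offset of its first sequence line, byte length of its sequence block)."""
    index = {}
    current_id, seq_start = None, 0
    with open(path, "rb") as fh:
        while True:
            pos = fh.tell()
            line = fh.readline()
            if not line or line.startswith(b">"):
                if current_id is not None:  # close the previous record
                    index[current_id] = (seq_start, pos - seq_start)
                if not line:  # end of file
                    break
                current_id = line[1:].split()[0].decode()
                seq_start = fh.tell()
    return index


def fetch(path: str, index: dict, record_id: str) -> str:
    """Read one record's sequence by seeking straight to its offset."""
    offset, length = index[record_id]
    with open(path, "rb") as fh:
        fh.seek(offset)
        block = fh.read(length)
    return block.decode().replace("\r", "").replace("\n", "")


with tempfile.NamedTemporaryFile("w", suffix=".fa", delete=False) as tmp:
    tmp.write(">c1 desc\nACGT\nAC\n>c2\nGGGG\n")
idx = build_offset_index(tmp.name)
seq_c1 = fetch(tmp.name, idx, "c1")
seq_c2 = fetch(tmp.name, idx, "c2")
os.remove(tmp.name)
print(seq_c1, seq_c2)
```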

In [None]:
%%bash
pyfastx index results/real_data/subsamples/fraction_0.1/subsampled_data/subsampled_contigs.fa

### Summary
Reminder: the pre-combined results can be loaded with:  
`tools_results = pl.read_parquet('results/real_data/results/tools_results.parquet')`


Now we compute some summary statistics per tool.


In [None]:
tools_results = tools_results.collect()

In [None]:
summary_stats = tools_results.group_by('tool').agg(
    pl.col('mismatches').mean().alias('mean_mismatches'),
    pl.col('spacer_id').n_unique().alias('n_spacers'),
    pl.col('contig_id').n_unique().alias('n_contigs'),
    pl.col('strand').value_counts().alias('strand_counts'),
)
summary_stats

#### Summary of the matched contigs


In [None]:
matched_contigs = tools_results.group_by('contig_id').agg(
    pl.col('spacer_id').n_unique().alias('n_spacers'),
    pl.col('tool').unique().alias('tools'),
    pl.col('tool').n_unique().alias('n_tools'),
)
matched_contigs.sort('n_spacers',descending=True).head(10)

### Let's closely examine all the contigs that were only detected by a single tool with 0 mismatches.

In [None]:
one_tool_only_n0 = tools_results.filter(pl.col('mismatches') == 0).group_by('contig_id').agg(
    pl.col('spacer_id').n_unique().alias('n_spacers'),
    pl.col('tool').unique().alias('tools'),
    pl.col('tool').n_unique().alias('n_tools'),
    ).filter(pl.col('n_tools') == 1)
results_n0 = tools_results.filter(pl.col('mismatches') < 1).filter(pl.col('contig_id').is_in(one_tool_only_n0['contig_id']))
results_n0

### Validation/correction of the results
To avoid relying on the tool-reported mismatches, we recalculate the mismatches in a consistent way.   
To do this, we use the parasail library on the set of unique (spacer, aligned contig region) pairs.
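To make "recalculating mismatches consistently" concrete, here is a minimal sketch that swaps in a plain unit-cost edit distance (Levenshtein, computed by dynamic programming) for parasail's scored alignment: every substitution, insertion, or deletion counts as one mismatch. This is an assumption for illustration; `test_alignment_polars` may score alignments differently (e.g., gap penalties or end-gap handling).

```python
def edit_distance(spacer: str, region: str) -> int:
    """Unit-cost global edit distance (substitutions, insertions, deletions)."""
    prev = list(range(len(region) + 1))  # row for an empty spacer prefix
    for i, a in enumerate(spacer, start=1):
        curr = [i]
        for j, b in enumerate(region, start=1):
            curr.append(min(
                prev[j] + 1,             # deletion: base only in the spacer
                curr[j - 1] + 1,         # insertion: base only in the region
                prev[j - 1] + (a != b),  # match (0) or substitution (1)
            ))
        prev = curr
    return prev[-1]


print(edit_distance("ACGTACGT", "ACGTACGT"))  # identical
print(edit_distance("ACGTACGT", "ACGAACGT"))  # one substitution
print(edit_distance("ACGTACGT", "ACGTCGT"))   # one deletion
```

Applying one scoring scheme to every tool's hits is what makes the recalculated counts comparable across tools, whatever each tool reports internally.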

In [None]:
unique_regions = tools_results.select(["spacer_id","contig_id","strand","start","end"]).unique()
unique_regions.write_parquet('results/real_data/results/unique_regions.parquet')

First we'll populate the unique regions with the spacer sequences.


In [None]:
unique_regions = populate_pldf_withseqs_needletail(
    seqfile=spacers_file, pldf=unique_regions, chunk_size=2000000,
    reverse_by_strand_col=False, trim_to_region=False, idcol="spacer_id", seqcol="spacer_seq")
unique_regions.write_parquet('results/real_data/results/unique_regions_with_spacer_seqs.parquet')
unique_regions

Next, we populate the unique regions with the contig sequences; this will take a while.
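With `trim_to_region=True` and `reverse_by_strand_col=True`, each row's contig sequence is presumably cut down to the aligned region and reverse-complemented for minus-strand hits. A minimal sketch of that per-row operation follows; the 0-based, end-exclusive coordinates and the boolean minus-strand flag are assumptions (the actual conventions of `populate_pldf_withseqs_needletail` are not shown here), and `extract_region` is a hypothetical name.

```python
_COMPLEMENT = str.maketrans("ACGTNacgtn", "TGCANtgcan")


def reverse_complement(seq: str) -> str:
    """Complement each base, then reverse the sequence."""
    return seq.translate(_COMPLEMENT)[::-1]


def extract_region(contig_seq: str, start: int, end: int, minus_strand: bool) -> str:
    """Slice [start, end) from the contig; reverse-complement it for minus-strand hits."""
    region = contig_seq[start:end]
    return reverse_complement(region) if minus_strand else region


contig = "TTTACGGATTT"
print(extract_region(contig, 3, 8, minus_strand=False))
print(extract_region(contig, 3, 8, minus_strand=True))
```

An off-by-one in either convention shifts every region by a base, which would show up as inflated recalculated mismatch counts, so it is worth checking against a few known hits.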

In [None]:
contigs_file  = "results/real_data/results/matched_contigs.fna"

In [None]:
unique_regions = populate_pldf_withseqs_needletail(
    seqfile=contigs_file, trim_to_region=True, reverse_by_strand_col=True, chunk_size=200000,
    pldf=unique_regions, idcol="contig_id", start_col="start", end_col="end",
    strand_col="strand", seqcol="contig_seq")
unique_regions.write_parquet('results/real_data/results/unique_regions_with_contig_seqs.parquet')
unique_regions
### Quick check: how many contigs were detected only by each tool with n mismatches

Next, we use parasail to recalculate the mismatches between the spacer and contig sequences.

In [None]:
unique_regions = pl.read_parquet('results/real_data/results/unique_regions_with_contig_seqs.parquet')

In [None]:
test3 = unique_regions
test4 = test_alignment_polars(
    results=test3, 
    return_deviations=False,
    ignore_region_strands=True 
)
test4

In [None]:
test4.write_parquet('results/real_data/results/unique_regions_mm_recalced.parquet')

Next, we merge the recalculated mismatches with the original results.

In [None]:
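# Also carry over alignment_test (the parasail-recalculated mismatch count), used in the deviation check
tools_results = tools_results.join(
    test4.select(["spacer_id", "contig_id", "strand", "start", "end", "spacer_seq", "contig_seq", "alignment_test"]),
    on=['spacer_id', 'contig_id', 'strand', 'start', 'end'], how='left')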
tools_results.write_parquet('results/real_data/results/tools_results_mm_recalced.parquet')
tools_results

Next, we compare the tool-reported mismatches with the recalculated ones (`alignment_test`) and inspect the rows where they disagree.

In [None]:
tools_results = tools_results.with_columns((
    pl.col("alignment_test") -  pl.col("mismatches")  ).alias("deviation")
)

deviated_rows = tools_results.filter(pl.col("deviation") != 0)# .filter(pl.col("tool") == "vsearch")
deviated_rows = deviated_rows.sort("deviation", descending=False)
deviated_rows

Could this be a parsing issue with some tools' output? Let's look at the frequency of deviations per tool.


In [None]:
deviation_counts = deviated_rows.group_by("tool").agg(
    pl.col("deviation").count().alias("deviation_count"),
    pl.col("deviation").mean().alias("mean_deviation")
    ).sort("deviation_count", descending=True)
print(deviation_counts)
deviation_counts.write_csv('results/real_data/results/deviation_counts.csv')
for tool in deviation_counts['tool']:
    # Sorted by recalculated mismatches: tail() holds the worst cases, head() the best
    tmp = deviated_rows.filter(pl.col('tool') == tool).sort("alignment_test", descending=False)
    print(f"{tool} recalculated mismatches (deviated rows): min {tmp['alignment_test'].min()} max {tmp['alignment_test'].max()} mean {tmp['alignment_test'].mean()} std {tmp['alignment_test'].std()}")
    print(f"worst 10 rows for {tool}:")
    for row in tmp.tail(10).iter_rows(named=True):
        print(prettify_alignment(row['spacer_seq'], row['contig_seq'], None,None, None))
        print(f"mismatches: {row['mismatches']} recalc: {row['alignment_test']}")
        print("\n")
    print(f"best 10 rows for {tool}:")
    for row in tmp.head(10).iter_rows(named=True):
        print(prettify_alignment(row['spacer_seq'], row['contig_seq'], None,None, None))
        print(f"mismatches: {row['mismatches']} recalc: {row['alignment_test']}")
        print("\n")


## Overall results
Next, we try to answer the question: "Which is the single best tool?"  
For that, we use two metrics to define the "best tool":
1.  Has the highest number of unique spacer-contig pairs.  
2.  For every spacer, has the highest fraction of identified occurrences (regardless of the number of unique contigs it was found in).
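Both metrics can be illustrated on made-up hits. In this sketch, a spacer's occurrences are the contigs found for it by any tool; using that union as the denominator is an assumption that mirrors the recall calculations in this notebook, since there is no ground truth. Metric 2 averages the per-spacer fractions, so every spacer weighs the same regardless of how many contigs it hits.

```python
from collections import defaultdict

# Made-up hits: (tool, spacer_id, contig_id); the duplicate report is counted once.
hits = [
    ("toolA", "sp1", "c1"), ("toolA", "sp1", "c2"),
    ("toolA", "sp1", "c3"), ("toolA", "sp1", "c4"),
    ("toolB", "sp1", "c1"), ("toolB", "sp2", "c5"), ("toolB", "sp2", "c5"),
]
tools = sorted({tool for tool, _, _ in hits})

# Metric 1: number of unique spacer-contig pairs per tool.
pairs_by_tool = defaultdict(set)
for tool, spacer, contig in hits:
    pairs_by_tool[tool].add((spacer, contig))
metric1 = {tool: len(pairs_by_tool[tool]) for tool in tools}

# Metric 2: per-spacer fraction of occurrences found, averaged over spacers.
occurrences = defaultdict(set)  # spacer -> contigs found by any tool
for _, spacer, contig in hits:
    occurrences[spacer].add(contig)
metric2 = {}
for tool in tools:
    fractions = [
        len({c for s, c in pairs_by_tool[tool] if s == spacer}) / len(contigs)
        for spacer, contigs in occurrences.items()
    ]
    metric2[tool] = sum(fractions) / len(fractions)

print(metric1)
print(metric2)
```

On this toy input the two metrics disagree: toolA finds more pairs because it recovers every copy of one frequent spacer, while toolB covers more of the spacers.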

### 1. Tool comparison matrices - unique spacer-contig pairs 

In [None]:
tools_results = pl.read_parquet('results/real_data/results/tools_results_mm_recalced.parquet')
tools_results = tools_results.filter(pl.col('alignment_test') < 4)
tools_results = tools_results.rename({'mismatches': 'tool_reported_mismatches'}).rename({'alignment_test': 'mismatches'})
tools_list = tools_results['tool'].unique().to_list() # might want to consider removing certain tools as they aren't very informative

In [None]:
plot_mismatches = [0,1,2,3]
tools_list

### Upset plots
These let us examine the set intersections of the tools' results.
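An UpSet plot essentially counts how many contigs share each exact combination of tools. A stdlib sketch of that counting on made-up contig-to-tools memberships (the same kind of input passed to `up.from_memberships` below); `most_common()` plays the role of `sort_by='cardinality'`:

```python
from collections import Counter

# Made-up memberships: contig_id -> set of tools that matched it.
memberships = {
    "c1": {"toolA", "toolB"},
    "c2": {"toolA", "toolB"},
    "c3": {"toolA"},
    "c4": {"toolB", "toolC"},
}

# Each bar of the UpSet plot corresponds to one exact tool combination.
combo_counts = Counter(frozenset(tools) for tools in memberships.values())
for combo, count in combo_counts.most_common():
    print(sorted(combo), count)
```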


In [None]:
plot_mismatches = [0,1,2,3]
for n_mismatches in plot_mismatches: 
    # first we create a table where each row is a contig and a 2nd column is a list of tools that matched the contig
    nmism_tools_results = tools_results.filter(pl.col('mismatches') == n_mismatches)
    contig_tool_table = nmism_tools_results.group_by('contig_id').agg(pl.col('tool').unique().alias('tools')).sort('contig_id')
    test_upset = up.from_memberships(contig_tool_table['tools']) # need to disable future deprecation warnings
    print("n mismatches: ", n_mismatches)
    plt.figure(figsize=(10, 6))
    up.plot(test_upset, subset_size='count',sort_by='cardinality')
    plt.title(f'Matches with == {n_mismatches} mismatches') #≤
    plt.savefig(f'results/real_data/plots/upset_{n_mismatches}.pdf')
    plt.show()

We also want them as one figure for the supplementary material (I couldn't figure out how to merge upset plots directly, and I really don't want to do this manually in Inkscape, so I'm using svgutils).

In [None]:
# First create individual plots and save them
for i, n_mismatches in enumerate(plot_mismatches):
    # Create the data for the upset plot
    nmism_tools_results = tools_results.filter(pl.col('mismatches') == n_mismatches)
    contig_tool_table = nmism_tools_results.group_by('contig_id').agg(pl.col('tool').unique().alias('tools')).sort('contig_id')
    test_upset = up.from_memberships(contig_tool_table['tools'])
    
    # Create individual plot
    plt.figure(figsize=(12, 10))
    up.plot(test_upset, subset_size='count', sort_by='cardinality')
    plt.title(f'Matches with == {n_mismatches} mismatches')
    
    # Save individual plot
    plt.savefig(f'results/real_data/plots/upset_{n_mismatches}.svg', bbox_inches='tight')
    plt.close()


In [None]:
# from svgutils.compose import Figure, Panel, SVG, Text
# height = f"2200"
# width = f"7500"
# fig = Figure(width, height,
#         SVG("results/real_data/plots/upset_0.svg",fix_mpl=True),
#         SVG("results/real_data/plots/upset_1.svg",fix_mpl=True),
#         SVG("results/real_data/plots/upset_2.svg",fix_mpl=True),
#         SVG("results/real_data/plots/upset_3.svg",fix_mpl=True)
#         ).tile(1, 4)
# fig.save("results/real_data/plots/upset_combined.svg")
# I can't figure out how to make this look good, so I'll just point to the plots separately

#### Matrices
Comparing the tools pairwise gives more specific insights.  
In the following tables, `cell(i,x)` is the number of unique spacer-contig alignments (contig, spacer, strand, start, end) that are reported by tool `i` but not by tool `x`.   
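Each cell is the size of a set difference between two tools' alignment keys. The same logic on made-up keys, using plain Python sets instead of the Polars anti-joins used in the next cell:

```python
# Made-up alignment keys: (contig_id, spacer_id, strand, start, end).
tool_pairs = {
    "toolA": {("c1", "sp1", "+", 0, 30), ("c2", "sp2", "-", 5, 35)},
    "toolB": {("c1", "sp1", "+", 0, 30)},
    "toolC": {("c3", "sp3", "+", 10, 40)},
}
tools = sorted(tool_pairs)

# matrix[i][j] = number of keys reported by tools[i] but not by tools[j].
matrix = [
    [0 if x == y else len(tool_pairs[x] - tool_pairs[y]) for y in tools]
    for x in tools
]
for name, row in zip(tools, matrix):
    print(name, row)
```

The matrix is not symmetric: a large value in row `i` and a small one in the mirrored cell means tool `i` largely contains the other tool's hits.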

In [None]:
all_charts = []
for n_mismatches in [0,1,2,3]:
    print(f"n_mismatches: {n_mismatches}")
    
    # Filter for current mismatch level
    tmp = tools_results.filter(pl.col('mismatches') == n_mismatches)
    
    # Create empty matrix
    matrix = pl.DataFrame(data=np.zeros((len(tools_list), len(tools_list)), dtype=int), schema=tools_list)
    
    # Get unique pairs for each tool
    tool_pairs = {}
    for tool in tools_list:
        tool_pairs[tool] = tmp.filter(pl.col('tool') == tool).select(["contig_id","spacer_id","strand","start","end"]).unique()
    
    # Fill matrix with counts
    for i, tool_x in enumerate(tools_list):
        for j, tool_y in enumerate(tools_list):
            if tool_x == tool_y:
                continue
            # Count pairs in x but not in y
            unique_pairs = tool_pairs[tool_x].join(tool_pairs[tool_y], on=['contig_id','spacer_id','strand','start','end'], how='anti')
            matrix[i,j] = unique_pairs.height
    
    # Convert to DataFrame for better visualization
    matrix = matrix.with_columns(pl.Series(name="tool1", values=tools_list, dtype=pl.Utf8))
    print(f"Matrix for {n_mismatches} mismatches:")
    print(matrix)
    heatmap_filename = f'results/real_data/plots/matrix_{n_mismatches}'
    chart = plot_matrix(matrix, f"Matrix for {n_mismatches} mismatches", heatmap_filename)
    all_charts.append(chart)
    # plot_matrix(matrix, f"Matrix for {n_mismatches} mismatches (% of tool's total matches)", heatmap_filename+"_percent", as_percent=True)  # as percentages
    matrix.write_csv(f'results/real_data/results/matrix_{n_mismatches}.tsv',separator='\t')
    print("\n")

Let's combine all the matrix charts into one figure (SVG and PDF) for the manuscript.

In [None]:
import altair as alt

row_1 = alt.hconcat(all_charts[0],all_charts[1])
row_2 = alt.hconcat(all_charts[2],all_charts[3])
chart = alt.vconcat(row_1,row_2)
chart.save("results/real_data/plots/matrix_combined.svg")
chart.save("results/real_data/plots/matrix_combined.pdf",format="pdf")
chart

Summary of the matrices: a single matrix for alignments with <= 3 mismatches.

In [None]:
n_mismatches = 3
# Filter for current mismatch level
tmp = tools_results.filter(pl.col('mismatches') <= n_mismatches)

# Create empty matrix
matrix = pl.DataFrame(data=np.zeros((len(tools_list), len(tools_list)), dtype=int), schema=tools_list)

# Get unique pairs for each tool
tool_pairs = {}
for tool in tools_list:
    tool_pairs[tool] = tmp.filter(pl.col('tool') == tool).select(["contig_id","spacer_id","strand","start","end"]).unique()

# Fill matrix with counts
for i, tool_x in enumerate(tools_list):
    for j, tool_y in enumerate(tools_list):
        if tool_x == tool_y:
            continue
        # Count pairs in x but not in y
        unique_pairs = tool_pairs[tool_x].join(tool_pairs[tool_y], on=['contig_id','spacer_id','strand','start','end'], how='anti')
        matrix[i,j] = unique_pairs.height

# Convert to DataFrame for better visualization
matrix = matrix.with_columns(pl.Series(name="tool1", values=tools_list, dtype=pl.Utf8))
print(f"Matrix for {n_mismatches} mismatches:")
print(matrix)
heatmap_filename = f'results/real_data/plots/matrix__less_or_equal_{n_mismatches}_mismatches'
chart = plot_matrix(matrix, f"Matrix <= {n_mismatches} mismatches", heatmap_filename)
# plot_matrix(matrix, f"Matrix for {n_mismatches} mismatches (% of tool's total matches)", heatmap_filename+"_percent", as_percent=True)  # as percentages
matrix.write_csv(f'results/real_data/results/matrix_less_or_equal_{n_mismatches}_mismatches.tsv',separator='\t')
print("\n")
chart

### 2. Tool comparison - spacer-contig pairs as a function of the number of occurrences 
Next, we check whether an increase in the number of occurrences (i.e., more copies of the same spacer's target in the reference contigs) corresponds to a lower true positive rate (in the sense that the tools find a smaller fraction of those occurrences). The exact effect is expected to be tool-specific.  

Next, for each spacer, we add one column per tool with the fraction of that spacer's occurrences identified by the tool.
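To relate these fractions to the number of occurrences, spacers can be grouped into log-spaced bins of `n_occurrences` and the fractions averaged within each bin, which is what `plot_on_axis` below does with `np.logspace`. A stdlib sketch on made-up per-spacer values; `log_bins` is a hypothetical stand-in for `np.logspace(0, max_exp, n_bins)`:

```python
# Made-up (n_occurrences, fraction found by one tool) pairs, one per spacer.
spacers = [(1, 1.0), (2, 1.0), (3, 0.7), (15, 0.8), (40, 0.5), (900, 0.25)]


def log_bins(max_exp: float, n_bins: int) -> list:
    """n_bins edges evenly spaced in log10 from 10**0 to 10**max_exp."""
    step = max_exp / (n_bins - 1)
    return [10 ** (k * step) for k in range(n_bins)]


edges = log_bins(3, 4)  # edges at 1, 10, 100, 1000
bin_means = {}
for lo, hi in zip(edges, edges[1:]):
    in_bin = [frac for n, frac in spacers if lo <= n < hi]
    if in_bin:  # empty bins are skipped, as in plot_on_axis
        bin_means[(lo, hi)] = sum(in_bin) / len(in_bin)

for (lo, hi), mean in bin_means.items():
    print(f"[{lo:g}, {hi:g}): mean fraction {mean:.2f}")
```

Log-spaced bins keep the sparse high-occurrence spacers from being lumped into a single bin while still giving the many low-occurrence spacers fine resolution.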


In [None]:
def create_spacer_counts_with_tools(recalc_only, tools_list, mismatches=3, exact_or_max="exact"):
    # First get total occurrences per spacer across all tools
    if exact_or_max == "max":
        spacer_counts = recalc_only.filter(pl.col('mismatches') <= mismatches)
    else:
        spacer_counts = recalc_only.filter(pl.col('mismatches') == mismatches)
    
    spacer_counts = spacer_counts.select(["spacer_id", "contig_id"])\
        .unique()\
        .group_by('spacer_id')\
        .agg(pl.count('contig_id').alias('n_occurrences'))

    # Calculate matches per tool and spacer without joining to spacer_counts yet
    if exact_or_max == "max":
        tool_matches = recalc_only.filter(pl.col('mismatches') <= mismatches)
    else:
        tool_matches = recalc_only.filter(pl.col('mismatches') == mismatches)
    
    tool_matches = tool_matches.select(['spacer_id', 'tool', 'contig_id'])\
        .unique()\
        .group_by(['spacer_id', 'tool'])\
        .agg(pl.count('contig_id').alias('tool_matches'))

    # Create a cross join of all spacers with all tools
    all_combinations = spacer_counts.select('spacer_id', 'n_occurrences')\
        .join(
            pl.DataFrame({'tool': tools_list}),
            how='cross'
        )

    # Join the actual matches and calculate fractions
    complete_fractions = all_combinations\
        .join(
            tool_matches,
            on=['spacer_id', 'tool'],
            how='left'
        )\
        .with_columns([
            pl.col('tool_matches').fill_null(0),
            (pl.col('tool_matches') / pl.col('n_occurrences')).alias('fraction')
        ])

    # Pivot to get tools as columns
    spacer_counts_with_tools = complete_fractions\
        .pivot(
            index=['spacer_id', 'n_occurrences'],
            on='tool',
            values='fraction'
        )\
        .fill_null(0)
    
    return spacer_counts_with_tools

Recall plots vs occurrences

In [None]:
def plot_combined_recall_vs_occurrences(recalc_only, tools_list, n_high_occ_bins=3, 
                         output_prefix='results/real_data/plots/recall_vs_occurrences', 
                         max_bin=3, n_bins=150, color_dict=None, marker_dict=None, exact_or_max="exact",
                         plot_mismatches=[1,3]):
    
    # Create color and marker dictionaries for consistent styling
    import matplotlib.colors as mcolors
    if color_dict is None:
        color_dict = dict(zip(tools_list, mcolors.TABLEAU_COLORS))
    if marker_dict is None:
        marker_dict = dict(zip(tools_list, ['o', 's', '^', 'D', 'v', '<', '>', 'p','x']))

    # Create the combined figure with one subplot per mismatch level
    fig, axes = plt.subplots(len(plot_mismatches), 1, figsize=(15, 36))
    
    for i, mismatches in enumerate(plot_mismatches):
        # Create a new figure for the single plot
        fig_single, ax_single = plt.subplots(figsize=(15, 12))
        
        # Plot on both the combined and single figures
        plot_on_axis(axes[i], recalc_only, tools_list, n_high_occ_bins, n_bins, max_bin, 
                    mismatches=mismatches, exact_or_max=exact_or_max, color_dict=color_dict, 
                    marker_dict=marker_dict, output_prefix=output_prefix)
        plot_on_axis(ax_single, recalc_only, tools_list, n_high_occ_bins, n_bins, max_bin, 
                    mismatches=mismatches, exact_or_max=exact_or_max, color_dict=color_dict, 
                    marker_dict=marker_dict, output_prefix=output_prefix)
        
        # Set titles
        if exact_or_max == "exact":
            title = f'Recall vs number of occurrences (mismatches == {mismatches})'
        else:
            title = f'Recall vs number of occurrences (mismatches ≤ {mismatches})'
        axes[i].set_title(title)
        ax_single.set_title(title)
        
        # Save single plot
        plt.figure(fig_single.number)
        plt.tight_layout()
        fig_single.savefig(f'{output_prefix}_{exact_or_max}_nm_{mismatches}.pdf', bbox_inches='tight',format='pdf')
        fig_single.savefig(f'{output_prefix}_{exact_or_max}_nm_{mismatches}.svg', bbox_inches='tight',format='svg')
        plt.close(fig_single)
    
    # Save combined plot
    fig.tight_layout()
    plot_mismatches_str = "_".join([str(mismatch) for mismatch in plot_mismatches])
    fig.savefig(f'{output_prefix}_{exact_or_max}_nm_{plot_mismatches_str}_combined.pdf', bbox_inches='tight')
    fig.savefig(f'{output_prefix}_{exact_or_max}_nm_{plot_mismatches_str}_combined.svg', bbox_inches='tight')
    plt.show()

def plot_on_axis(ax: plt.Axes, recalc_only, tools_list, n_high_occ_bins, n_bins, max_bin, 
                 mismatches, exact_or_max, color_dict, marker_dict,output_prefix):
    
    spacer_counts_with_tools = create_spacer_counts_with_tools(recalc_only, tools_list, 
                                                             mismatches=mismatches, 
                                                             exact_or_max=exact_or_max)

    # Create log-spaced bin edges for the number of occurrences, from 10**0 to 10**max_bin
    bins = np.logspace(0, max_bin, n_bins)

    # Calculate mean fraction for each tool within each bin
    bin_stats = []
    for i in range(len(bins)-1):
        mask = (spacer_counts_with_tools['n_occurrences'] >= bins[i]) & \
               (spacer_counts_with_tools['n_occurrences'] < bins[i+1])
        bin_data = spacer_counts_with_tools.filter(mask)
        if bin_data.height > 0:
            stats = {
                'bin_start': bins[i],
                'bin_end': bins[i+1],
                'n_spacers': bin_data.height
            }
            for tool in tools_list:
                stats[tool] = bin_data[tool].mean()
            bin_stats.append(stats)

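    # Add coarser bins for high-occurrence spacers (10**3 to 10**4, last bin open-ended);
    # these assume max_bin <= 3 so they do not overlap the bins above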
    if n_high_occ_bins > 0:
        high_occ_edges = np.logspace(3, 4, n_high_occ_bins + 1)
        for i in range(n_high_occ_bins):
            bin_start = high_occ_edges[i]
            bin_end = high_occ_edges[i + 1]
            
            if i == n_high_occ_bins - 1:
                high_occ_mask = (spacer_counts_with_tools['n_occurrences'] >= bin_start)
            else:
                high_occ_mask = (spacer_counts_with_tools['n_occurrences'] >= bin_start) & \
                               (spacer_counts_with_tools['n_occurrences'] < bin_end)
            
            high_occ_data = spacer_counts_with_tools.filter(high_occ_mask)
            if high_occ_data.height > 0:
                high_occ_stats = {
                    'bin_start': bin_start,
                    'bin_end': bin_end,
                    'n_spacers': high_occ_data.height
                }
                for tool in tools_list:
                    high_occ_stats[tool] = high_occ_data[tool].mean()
                bin_stats.append(high_occ_stats)

    # Plot on the provided axis
    for tool in tools_list:
        x = [(stat['bin_start'] + stat['bin_end'])/2 for stat in bin_stats]
        y = [stat[tool] for stat in bin_stats]
        ax.plot(x, y, label=tool, markersize=4, linewidth=1, 
                color=color_dict[tool], markerfacecolor=color_dict[tool],
                marker=marker_dict[tool])

    ax.set_xscale('log')
    ax.set_xlabel('Number of occurrences (log scale)')
    ax.set_ylabel('Mean Detection Fraction')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(True, which="both", ls="-", alpha=0.2)
    ax.grid(True, which="major", ls="-", alpha=0.5)
    ax.minorticks_on()
    ax.set_ylim(0, 1.05)
    ax.set_xlim(1, 10**4)
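The binning in `plot_on_axis` can be illustrated in isolation. Below is a minimal NumPy sketch, assuming each spacer comes with an occurrence count and a per-tool detection value between 0 and 1 (the output of `create_spacer_counts_with_tools` is not shown here, so this input format is an assumption). The helper name `binned_detection_fraction` is hypothetical. One design difference: it places each point at the geometric midpoint `sqrt(lo * hi)` of its bin, which is the visual center on a log-scaled axis, whereas the notebook uses the arithmetic midpoint `(lo + hi) / 2`, which shifts points toward the right edge of each bin.

```python
import numpy as np

def binned_detection_fraction(n_occurrences, detected, max_exp=3, n_bins=10):
    """Mean detection fraction per log-spaced occurrence bin (hypothetical helper).

    n_occurrences: number of occurrences per spacer
    detected: per-spacer detection value for one tool, between 0 and 1
    Returns (geometric_midpoint, mean_fraction, n_spacers) for every non-empty bin.
    """
    occ = np.asarray(n_occurrences, dtype=float)
    det = np.asarray(detected, dtype=float)
    edges = np.logspace(0, max_exp, n_bins)  # 10**0 .. 10**max_exp, as in plot_on_axis
    out = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (occ >= lo) & (occ < hi)  # half-open bins [lo, hi)
        if mask.any():
            out.append((float(np.sqrt(lo * hi)), float(det[mask].mean()), int(mask.sum())))
    return out

# Toy example: five spacers, one tool
stats = binned_detection_fraction([1, 1, 5, 50, 500], [1, 0, 0.5, 1, 0.25], max_exp=3, n_bins=4)
```

With `n_bins=4` the edges are 1, 10, 100 and 1000, so the first three spacers share one bin with a mean detection fraction of 0.5.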
    

In [None]:
n_mismatches = [0, 1, 2, 3]
plot_combined_recall_vs_occurrences(tools_results,
                                    tools_list,
                                    n_high_occ_bins=3,
                                    output_prefix='results/real_data/plots/recall_vs_occurrences',
                                    plot_mismatches=n_mismatches,
                                    exact_or_max="exact")

For reference, the tools included in the comparison:

In [None]:
tools_list

Similar plots, but for at most 1 and at most 3 mismatches (cumulative, `exact_or_max="max"`), combined into a two-panel figure for the main text:

In [None]:
# Call the plotting function combined
plot_combined_recall_vs_occurrences(tools_results, tools_list, n_high_occ_bins=3, 
                                  output_prefix='results/real_data/plots/recall_vs_occurrences',
                                  exact_or_max="max",
                                  plot_mismatches=[1,3])
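The difference between `exact_or_max="exact"` and `"max"` follows the plot titles (`mismatches == k` versus `mismatches ≤ k`). A toy sketch of the two filters, assuming hits are plain dicts with a `mismatches` field; `filter_hits` is a hypothetical helper, not the notebook's `create_spacer_counts_with_tools`:

```python
def filter_hits(hits, k, exact_or_max="exact"):
    """Keep hits with exactly k mismatches ("exact") or at most k mismatches ("max")."""
    if exact_or_max == "exact":
        return [h for h in hits if h["mismatches"] == k]
    if exact_or_max == "max":
        return [h for h in hits if h["mismatches"] <= k]
    raise ValueError(f"exact_or_max must be 'exact' or 'max', got {exact_or_max!r}")

hits = [{"mismatches": 0}, {"mismatches": 1}, {"mismatches": 3}, {"mismatches": 1}]
exact_1 = filter_hits(hits, 1, "exact")  # only the two 1-mismatch hits
max_1 = filter_hits(hits, 1, "max")      # the 0- and 1-mismatch hits
```

With `"max"` the panels are cumulative, so the ≤3 panel also contains every hit already counted in the ≤1 panel.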

Next, we'll look at the performance of the tools as a function of the number of mismatches.

In [None]:
# Count the unique spacer-contig pairs found by any tool at each mismatch level (this union is the denominator)
import altair as alt
total_matches = tools_results\
    .select(['spacer_id', 'contig_id', 'mismatches'])\
    .unique()\
    .group_by(['mismatches'])\
    .agg(pl.len().alias('total_possible'))

# Calculate matches per tool at each mismatch level
tool_matches = tools_results\
    .select(['spacer_id', 'contig_id', 'tool', 'mismatches'])\
    .unique()\
    .group_by(['mismatches', 'tool'])\
    .agg(pl.len().alias('tool_matches'))

# Create all combinations of mismatches (0-3) and tools
all_combinations = pl.DataFrame({
    'mismatches': np.repeat(range(4), len(tools_list)),
    'tool': tools_list * 4
})

# Calculate fractions
mismatch_performance = all_combinations\
    .join(
        total_matches,
        on='mismatches'
    )\
    .join(
        tool_matches,
        on=['mismatches', 'tool'],
        how='left'
    )\
    .with_columns(pl.col('tool_matches').fill_null(0))\
    .with_columns(
        # computed in a second step so the division sees the zero-filled counts
        (pl.col('tool_matches') / pl.col('total_possible')).alias('recall')
    )

mismatch_performance.write_csv('results/real_data/results/tool_recall_by_mismatches.tsv',separator='\t')

metrics = {
    # 'precision': 'true_positives / (true_positives + false_positives)',
    # no ground truth here: "recall" is relative to the union of matches found by all tools
    'recall': 'tool matches / all unique matches (union of tools)',
    # 'f1_score': ' 2 * (precision * recall) / (precision + recall)'
}

charts = []

for metric in metrics.keys():
    # color comes from the encoding below; a mark-level color must be a literal color, not a field
    base_chart = alt.Chart(mismatch_performance).mark_trail().encode(
        x=alt.X("mismatches:Q", title="Number of Mismatches"),
        y=alt.Y(f"{metric}:Q", 
                title=metrics[metric],
                scale=alt.Scale(domain=[0, 1.05])
                ),
        color=alt.Color("tool:N",
                        legend=None
                       ),  # Hide color legend
        shape=alt.Shape("tool:N",     # Shape legend will show both shape and color
                       legend=alt.Legend(
                           title="Tool",
                           orient="right",
                           symbolFillColor="tool:N",  # Use the color encoding for fill
                           symbolStrokeColor="tool:N" # Use the color encoding for stroke
                       )),
        tooltip=['tool', 'mismatches', metric]
    ).properties(
        width=300,
        height=300,
        title=metric.title()
    )
    
    charts.append(base_chart)

# Combine charts horizontally
combined_chart = alt.hconcat(*charts).configure_axis(
    grid=True,
    gridOpacity=0.9
).configure_view(step=1,
    strokeWidth=0.1
).configure_title(
    fontSize=16,
    anchor='middle'
)
combined_chart.save('./results/real_data/plots/tool_performance_by_mismatches.html')
combined_chart.save('./results/real_data/plots/tool_performance_by_mismatches.json',format='json')

import json
# Save the chart specification as JSON with the correct format
chart_json = combined_chart.to_json(format="vega")

# Parse the JSON
chart_spec = json.loads(chart_json)

# Replace the legends in the compiled Vega specification.
# The spec is nested, so traverse it recursively and swap every "legends" entry for one custom legend.
def update_legends_in_spec(spec):
    # Custom legend configuration; "shape" and "fill" reference scale names in the compiled spec
    new_legend = {
        "orient": "right",
        "symbolSize": 190,
        "symbolOpacity": 1,
        "symbolFillColor": "shape",
        "symbolStrokeColor": "tool:N",
        "title": "Tool",
        "shape": "shape",
        "fill": "color",
        "offset": 0,
        "encode": {
            "symbols": {
                "update": {
                    "fillOpacity": {"value": 0.9}
                }
            }
        },
    }
    if isinstance(spec, dict):
        if 'legends' in spec:
            spec['legends'] = [new_legend]
        for value in spec.values():
            update_legends_in_spec(value)
    elif isinstance(spec, list):
        for item in spec:
            update_legends_in_spec(item)
    return spec

new_chart_spec = update_legends_in_spec(chart_spec)

# Write the modified Vega specification into a standalone HTML page rendered with vega-embed
def change_chart_spec(new_chart_spec, output_path):
    html_content = f"""<!DOCTYPE html>
<html>
<head>
<script src="https://cdn.jsdelivr.net/npm/vega@5"></script>
<script src="https://cdn.jsdelivr.net/npm/vega-lite@5"></script>
<script src="https://cdn.jsdelivr.net/npm/vega-embed@6"></script>
</head>
<body>
<div id="vis"></div>
<script type="text/javascript">
    var spec = {json.dumps(new_chart_spec)};
    vegaEmbed('#vis', spec);
</script>
</body>
</html>
"""
    with open(output_path, 'w') as f:
        f.write(html_content)

change_chart_spec(new_chart_spec, './results/real_data/plots/tool_performance_by_mismatches.html')
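Without ground truth, the "recall" in this cell is each tool's share of the union of spacer-contig pairs reported by any tool at a given mismatch level. A pure-Python sketch of that calculation, assuming hits arrive as `(tool, spacer_id, contig_id, mismatches)` tuples (`recall_vs_union` is a hypothetical name). A tool with no hits at a level gets 0 rather than a missing value, and levels that no tool reached are dropped, as with the inner join on `total_matches`:

```python
from collections import defaultdict

def recall_vs_union(hits, tools, max_mismatches=3):
    """Per-tool fraction of the union of spacer-contig pairs, per mismatch level."""
    union = defaultdict(set)     # mismatches -> pairs found by any tool
    per_tool = defaultdict(set)  # (mismatches, tool) -> pairs found by that tool
    for tool, spacer_id, contig_id, mm in hits:
        union[mm].add((spacer_id, contig_id))
        per_tool[(mm, tool)].add((spacer_id, contig_id))
    rows = []
    for mm in range(max_mismatches + 1):
        total = len(union[mm])
        if total == 0:
            continue  # no tool reported anything at this level
        for tool in tools:
            found = len(per_tool[(mm, tool)])  # missing combination counts as 0
            rows.append({"mismatches": mm, "tool": tool, "tool_matches": found,
                         "total_possible": total, "recall": found / total})
    return rows

hits = [("a", "s1", "c1", 0), ("b", "s1", "c1", 0),
        ("a", "s2", "c1", 1), ("b", "s3", "c2", 1), ("a", "s3", "c2", 1)]
rows = recall_vs_union(hits, ["a", "b", "c"])
```

Here tool `c` reported nothing, so it appears with a recall of 0 at both levels instead of disappearing from the chart.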


In [None]:
combined_chart
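As an alternative to post-processing the compiled legend, Vega-Lite can merge the color and shape legends into one when both channels encode the same field with the same title. A minimal sketch that builds such a Vega-Lite spec as a plain dict (`mismatch_recall_spec` is hypothetical). A line mark with a point overlay stands in for the trail mark here, on the assumption that the point overlay is what draws the shapes; this has not been checked against the figure above:

```python
import json

def mismatch_recall_spec(rows):
    """Build a Vega-Lite spec (as a dict) with one line per tool and a single merged legend."""
    return {
        "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
        "data": {"values": rows},
        "mark": {"type": "line", "point": True},  # line with a point overlay for the shapes
        "encoding": {
            "x": {"field": "mismatches", "type": "quantitative", "title": "Number of Mismatches"},
            "y": {"field": "recall", "type": "quantitative", "scale": {"domain": [0, 1.05]}},
            # Same field and same title on color and shape, so one combined legend is expected
            "color": {"field": "tool", "type": "nominal", "title": "Tool"},
            "shape": {"field": "tool", "type": "nominal", "title": "Tool"},
            "tooltip": [
                {"field": "tool", "type": "nominal"},
                {"field": "mismatches", "type": "quantitative"},
                {"field": "recall", "type": "quantitative"},
            ],
        },
        "width": 300,
        "height": 300,
    }

spec = mismatch_recall_spec([{"tool": "a", "mismatches": 0, "recall": 1.0}])
spec_json = json.dumps(spec)  # ready to embed with vega-embed
```

In the notebook, the rows could come from `mismatch_performance.to_dicts()`.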

In [None]:
spacer_counts_with_tools = create_spacer_counts_with_tools(tools_results, tools_list, mismatches=3,exact_or_max="max")

Now we chart the distribution of the spacer lengths, number of mismatches, and number of occurrences.

In [None]:
# alt.data_transformers.enable("vegafusion")
def plot_spacer_distributions(tools_results, spacer_counts_with_tools, output_prefix='results/real_data/plots'):
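    """Create a three-panel figure of spacer-length, occurrence, and mismatch distributions with Altair.
    Values are pre-aggregated with Polars value_counts so the charts stay small
    (VegaFusion can optionally be enabled with alt.data_transformers.enable("vegafusion")).
    """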
    # Pre-aggregate in Polars to reduce data size
    spacer_df = tools_results.select([
        'spacer_id', 'contig_id', 'strand', 'start', 'end', 'spacer_length', 'mismatches'
    ]).unique()

    occurrence_df = spacer_counts_with_tools['n_occurrences'].value_counts()
    mismatches_df=spacer_df['mismatches'].value_counts()
    length_df=spacer_df['spacer_length'].value_counts()

   
    # Base chart for occurrence distribution
    base = alt.Chart(occurrence_df).encode(
        tooltip=['n_occurrences:Q', 'count:Q']
    )

    chart3 = base.mark_bar(opacity=0.5).encode(
        alt.X('n_occurrences:Q',
                scale=alt.Scale(type='log'),
                title='Number of Occurrences'),
        alt.Y('count:Q', title='Count (Linear Scale)')
    ).properties(
        title='B.',
    )
    
    # Counts in panels A and C are unique spacer-contig matches, not distinct spacers
    chart1 = alt.Chart(length_df).mark_bar().encode(
        x=alt.X('spacer_length:Q', title='Spacer Length (bp)',
                axis=alt.Axis(grid=True, tickCount=100)),
        y=alt.Y('count:Q',
                axis=alt.Axis(title='Number of Matches', grid=True, ticks=False, gridDash=[2, 2]),
                # scale=alt.Scale(domain=[16, 108])
                ),
    ).properties(
        title='A.',
    )

    chart2 = alt.Chart(mismatches_df).mark_bar(width=13).encode(
        x=alt.X('mismatches:Q', title='Number of Mismatches', axis=alt.Axis(grid=True, tickCount=6)),
        y=alt.Y('count:Q', axis=alt.Axis(title='Number of Matches', grid=True, ticks=False, gridDash=[2, 2])),
    ).properties(
        title='C.',
    )

    combined_chart = chart1 | chart3 | chart2
    combined_chart = combined_chart.configure_title(anchor='start')
    return combined_chart
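The histograms in panels A and C count rows of the deduplicated hit table, so a spacer that matches many contigs contributes once per match. A small pure-Python sketch of the difference between per-match and per-spacer counts, assuming hits are dicts with `spacer_id`, `contig_id`, `strand`, `start`, `end` and `spacer_length` (the notebook's `.unique()` also includes `mismatches`); `length_histograms` is a hypothetical helper:

```python
from collections import Counter

HIT_KEY = ("spacer_id", "contig_id", "strand", "start", "end")

def length_histograms(hits):
    """Return (per_match, per_spacer) spacer-length histograms as sorted (length, count) lists."""
    # Deduplicate hits reported by several tools at the same coordinates
    unique_hits = {tuple(h[k] for k in HIT_KEY): h["spacer_length"] for h in hits}
    per_match = Counter(unique_hits.values())
    # Collapse to one entry per spacer_id (key[0]) before counting
    per_spacer = Counter({key[0]: length for key, length in unique_hits.items()}.values())
    return sorted(per_match.items()), sorted(per_spacer.items())

hit = {"spacer_id": "s1", "contig_id": "c1", "strand": "+", "start": 10, "end": 40, "spacer_length": 30}
hits = [hit, dict(hit),  # same match reported by two tools
        {**hit, "contig_id": "c2", "strand": "-", "start": 5, "end": 35},
        {"spacer_id": "s2", "contig_id": "c1", "strand": "+", "start": 100, "end": 135, "spacer_length": 35}]
per_match, per_spacer = length_histograms(hits)
```

Spacer `s1` appears twice in the per-match histogram (two contigs) but only once in the per-spacer one.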


In [None]:
combined_chart = plot_spacer_distributions(spacer_counts_with_tools=spacer_counts_with_tools, tools_results=tools_results)


In [None]:
combined_chart

In [None]:

combined_chart.save('./results/real_data/plots/spacer_distributions.html')
combined_chart.save('./results/real_data/plots/spacer_distributions.svg',format='svg')
combined_chart.save('./results/real_data/plots/spacer_distributions.pdf',format='pdf')



For safekeeping, also print some summary statistics.

In [None]:
# Print some summary statistics.
# Note: tools_results has one row per tool per hit, so a hit reported by several tools is counted once per tool.
print("\nSummary Statistics:")
print("\nSpacer Lengths:")
print(f"Mean: {tools_results['spacer_length'].mean():.2f}")
print(f"Median: {tools_results['spacer_length'].median():.2f}")
print(f"Std: {tools_results['spacer_length'].std():.2f}")
print(f"Min: {tools_results['spacer_length'].min()}")
print(f"Max: {tools_results['spacer_length'].max()}")

print("\nMismatches:")
print(f"Mean: {tools_results['mismatches'].mean():.2f}")
print(f"Median: {tools_results['mismatches'].median():.2f}")
print(f"Std: {tools_results['mismatches'].std():.2f}")
print(f"Min: {tools_results['mismatches'].min()}")
print(f"Max: {tools_results['mismatches'].max()}")

print("\nOccurrences:")
print(f"Mean: {spacer_counts_with_tools['n_occurrences'].mean():.2f}")
print(f"Median: {spacer_counts_with_tools['n_occurrences'].median():.2f}")
print(f"Std: {spacer_counts_with_tools['n_occurrences'].std():.2f}")
print(f"Min: {spacer_counts_with_tools['n_occurrences'].min()}")
print(f"Max: {spacer_counts_with_tools['n_occurrences'].max()}")
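Because `tools_results` holds one row per tool per hit, the spacer-length and mismatch statistics above weight each hit by the number of tools that reported it. A stdlib sketch (hypothetical `summarize` helper) showing how that weighting moves the mean. `statistics.stdev` is the sample standard deviation, which should correspond to the Polars default `ddof=1`:

```python
import statistics

def summarize(values):
    """Mean, median, sample std, min and max of a list of numbers."""
    return {
        "mean": statistics.fmean(values),
        "median": statistics.median(values),
        "std": statistics.stdev(values),  # sample standard deviation (n - 1 denominator)
        "min": min(values),
        "max": max(values),
    }

# A 30 bp spacer hit reported by three tools and a 40 bp hit reported by one tool
per_tool_rows = summarize([30, 30, 30, 40])  # mean 32.5
unique_hits = summarize([30, 40])            # mean 35.0
```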

Finally, we'll tar and gzip the plots folder inside the results folder for easy download and sharing.

In [None]:
%%bash
tar -czvf results/real_data/plots.tar.gz results/real_data/plots 
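If `bash` is unavailable, the same archive can be created from Python with the standard-library `tarfile` module. A minimal sketch (`tar_gz_directory` is a hypothetical name). Unlike the shell command, which stores the full relative path `results/real_data/plots/...`, this version stores entries under the directory's own name (`plots/...`):

```python
import tarfile
from pathlib import Path

def tar_gz_directory(src_dir, archive_path):
    """Write src_dir (recursively) into a gzip-compressed tar archive."""
    src = Path(src_dir)
    with tarfile.open(str(archive_path), "w:gz") as tar:
        tar.add(str(src), arcname=src.name)  # add() recurses into subdirectories by default
    return archive_path

# Usage in the notebook's working directory:
# tar_gz_directory("results/real_data/plots", "results/real_data/plots.tar.gz")
```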