# Imports and setup

### Automatically reload Python modules (useful when editing local files)

In [8]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Imports

In [9]:
import itertools
import logging
import os
import warnings
from collections import Counter

import matplotlib as mpl
import numpy as np
import pandas as pd
import polars as pl
import seaborn as sns
from tqdm import tqdm

logging.basicConfig(level=logging.INFO)

pl.Config.set_verbose()
# mpl.rcParams["figure.max_open_warning"] = 0

from notifications import notify, notify_done
from scop_constants import n_scop_cols, same_scop_cols
from sourmash_constants import sourmash_score_cols

# Read in data

In [10]:
# Cleaned, filtered multisearch results for moltype "hp" at k=19; used below
# for a quick preview and to read the parquet schema.
pq = "/home/ec2-user/data/seanome-kmerseek/scope-benchmark/analysis-outputs/hp/00_cleaned_multisearch_results/scope40.multisearch.hp.k19.filtered.pq"

In [11]:
# Load only the first 1000 rows as a quick preview of the columns and values.
multisearch_head = pl.read_parquet(pq, n_rows=1000)
multisearch_head

query_name,query_md5,match_name,match_md5,containment,max_containment,jaccard,intersect_hashes,prob_overlap,prob_overlap_adjusted,containment_adjusted,containment_adjusted_log10,tf_idf_score,query_family,query_superfamily,query_fold,query_class,n_family,n_superfamily,n_fold,n_class,query_scop_id,match_family,match_superfamily,match_fold,match_class,match_scop_id,same_family,same_superfamily,same_fold,same_class,ksize,moltype,log10_prob_overlap_adjusted,log10_containment,log10_max_containment,log10_tf_idf_score,log10_jaccard
str,str,str,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,cat,cat,cat,cat,i64,i64,i64,i64,str,cat,cat,cat,cat,str,bool,bool,bool,bool,i32,str,f64,f64,f64,f64,f64
"""d5icea1 a.4.5.0 (A:3-104) auto…","""421e96bce1ca11c5806b2d7bd54e56…","""d1ji7a_ a.60.1.1 (A:) Etv6 tra…","""883fd41e87af891d5c2d9b86c9322e…",0.035714,0.050847,0.021429,3.0,1.4040e-11,0.003234,11.043209,1.043095,0.314312,"""a.4.5.0""","""a.4.5""","""a.4""","""a""",75,252,425,2644,"""d5icea1""","""a.60.1.1""","""a.60.1""","""a.60""","""a""","""d1ji7a_""",false,false,false,true,19,"""hp""",-2.490253,-1.447158,-1.293731,-0.50264,-1.669007
"""d2cgqa1 a.28.1.0 (A:1-73) auto…","""3f70e3825f0095d77bb2dffff700eb…","""d1ji7a_ a.60.1.1 (A:) Etv6 tra…","""883fd41e87af891d5c2d9b86c9322e…",0.036364,0.036364,0.017857,2.0,4.5241e-11,0.010421,3.489516,0.542765,0.293295,"""a.28.1.0""","""a.28.1""","""a.28""","""a""",10,16,24,2644,"""d2cgqa1""","""a.60.1.1""","""a.60.1""","""a.60""","""a""","""d1ji7a_""",false,false,false,true,19,"""hp""",-1.982098,-1.439333,-1.439333,-0.532695,-1.748188
"""d3cw2c2 a.60.14.1 (C:85-175) E…","""4e396c3d2826c868d44b7a6d5edad5…","""d1ji7a_ a.60.1.1 (A:) Etv6 tra…","""883fd41e87af891d5c2d9b86c9322e…",0.054795,0.067797,0.03125,4.0,7.8781e-11,0.018147,3.019546,0.479942,0.449557,"""a.60.14.1""","""a.60.14""","""a.60""","""a""",2,3,91,2644,"""d3cw2c2""","""a.60.1.1""","""a.60.1""","""a.60""","""a""","""d1ji7a_""",false,false,true,true,19,"""hp""",-1.741204,-1.261263,-1.168792,-0.347215,-1.50515
"""d1u7ka_ a.73.1.1 (A:) AKV caps…","""f716da72651c514cfe72e7d0f39c2b…","""d1ji7a_ a.60.1.1 (A:) Etv6 tra…","""883fd41e87af891d5c2d9b86c9322e…",0.035398,0.067797,0.02381,4.0,1.4196e-11,0.00327,10.825202,1.034436,0.314387,"""a.73.1.1""","""a.73.1""","""a.73""","""a""",5,8,8,2644,"""d1u7ka_""","""a.60.1.1""","""a.60.1""","""a.60""","""a""","""d1ji7a_""",false,false,false,true,19,"""hp""",-2.485454,-1.451018,-1.168792,-0.502536,-1.623249
"""d3bgea1 a.80.1.2 (A:251-434) U…","""639dbe524cd1ba75742ff8fccc62d3…","""d1ji7a_ a.60.1.1 (A:) Etv6 tra…","""883fd41e87af891d5c2d9b86c9322e…",0.012048,0.033898,0.008969,2.0,8.1121e-12,0.001869,6.447842,0.809414,0.106626,"""a.80.1.2""","""a.80.1""","""a.80""","""a""",3,13,13,2644,"""d3bgea1""","""a.60.1.1""","""a.60.1""","""a.60""","""a""","""d1ji7a_""",false,false,false,true,19,"""hp""",-2.728492,-1.919078,-1.469822,-0.972138,-2.047275
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""d4e6pa_ c.2.1.2 (A:) automated…","""205a7c8dc115031f5e4103eba74fc5…","""d3dqya_ b.33.1.0 (A:) automate…","""b5cd595a3d65f01f347e0b0b384e8b…",0.008403,0.022727,0.006173,2.0,1.5600e-11,0.003593,2.338562,0.368949,0.071899,"""c.2.1.2""","""c.2.1""","""c.2""","""c""",64,372,372,4463,"""d4e6pa_""","""b.33.1.0""","""b.33.1""","""b.33""","""b""","""d3dqya_""",false,false,false,false,19,"""hp""",-2.444496,-2.075547,-1.643453,-1.143274,-2.209515
"""d1y81a1 c.2.1.8 (A:6-121) Hypo…","""688e649de5b1005d5d50d73eb95ecb…","""d3dqya_ b.33.1.0 (A:) automate…","""b5cd595a3d65f01f347e0b0b384e8b…",0.020408,0.022727,0.01087,2.0,2.5584e-11,0.005893,3.463027,0.539456,0.170001,"""c.2.1.8""","""c.2.1""","""c.2""","""c""",5,372,372,4463,"""d1y81a1""","""b.33.1.0""","""b.33.1""","""b.33""","""b""","""d3dqya_""",false,false,false,false,19,"""hp""",-2.229652,-1.690196,-1.643453,-0.769549,-1.963788
"""d2aefa1 c.2.1.9 (A:116-244) Po…","""507188243d769461ea2905877e3be7…","""d3dqya_ b.33.1.0 (A:) automate…","""b5cd595a3d65f01f347e0b0b384e8b…",0.018018,0.022727,0.010152,2.0,1.5288e-11,0.003522,5.116545,0.708977,0.154021,"""c.2.1.9""","""c.2.1""","""c.2""","""c""",4,372,372,4463,"""d2aefa1""","""b.33.1.0""","""b.33.1""","""b.33""","""b""","""d3dqya_""",false,false,false,false,19,"""hp""",-2.45327,-1.744293,-1.643453,-0.81242,-1.993436
"""d2z1na_ c.2.1.0 (A:) automated…","""6c898f22db65a18dc0454a119dd782…","""d3dqya_ b.33.1.0 (A:) automate…","""b5cd595a3d65f01f347e0b0b384e8b…",0.012397,0.034091,0.009174,3.0,3.0265e-11,0.006971,1.778279,0.25,0.10456,"""c.2.1.0""","""c.2.1""","""c.2""","""c""",196,372,372,4463,"""d2z1na_""","""b.33.1.0""","""b.33.1""","""b.33""","""b""","""d3dqya_""",false,false,false,false,19,"""hp""",-2.156694,-1.906694,-1.467361,-0.980634,-2.037426


### Read one of the files for schema

In [12]:
# Read just the parquet schema (column name -> dtype) from the file.
multisearch_schema = pl.read_parquet_schema(pq)
multisearch_schema

{'query_name': String,
 'query_md5': String,
 'match_name': String,
 'match_md5': String,
 'containment': Float64,
 'max_containment': Float64,
 'jaccard': Float64,
 'intersect_hashes': Float64,
 'prob_overlap': Float64,
 'prob_overlap_adjusted': Float64,
 'containment_adjusted': Float64,
 'containment_adjusted_log10': Float64,
 'tf_idf_score': Float64,
 'query_family': Categorical(ordering='physical'),
 'query_superfamily': Categorical(ordering='physical'),
 'query_fold': Categorical(ordering='physical'),
 'query_class': Categorical(ordering='physical'),
 'n_family': Int64,
 'n_superfamily': Int64,
 'n_fold': Int64,
 'n_class': Int64,
 'query_scop_id': String,
 'match_family': Categorical(ordering='physical'),
 'match_superfamily': Categorical(ordering='physical'),
 'match_fold': Categorical(ordering='physical'),
 'match_class': Categorical(ordering='physical'),
 'match_scop_id': String,
 'same_family': Boolean,
 'same_superfamily': Boolean,
 'same_fold': Boolean,
 'same_class': Boole

## Test initial function

In [13]:
# dfs = []

# for sourmash_col in sourmash_score_cols:
#     df = (
#         multisearch.sort(sourmash_col, descending=True)
#         # .head(1000)
#         .group_by("query_scop_id").agg(
#             sensitivity_until_first_false_positive(same_scop_cols, n_scop_cols)
#         )
#     ).fill_nan(0)
#     tidy = tidify_sensitivity(df)
#     tidy = tidy.with_columns(pl.lit(sourmash_col).alias("sourmash_score"))

#     dfs.append(tidy)

# sensitivity = pl.concat(dfs)
# sensitivity = sensitivity.with_columns(pl.lit("hp").alias("moltype"))
# sensitivity = sensitivity.with_columns(pl.lit(19).alias("ksize"))
# sensitivity

In [14]:
# dict(
#     [
#         ("query_scop_id", str),
#         ("lineage", str),
#         ("sensitivity", float),
#         ("sensitivity_rank", int),
#     ]
# )

### Function to compute sensitivity

In [15]:
# def compute_sensitivity(multisearch, sourmash_cols, moltype, ksize):
#     dfs = []

#     for sourmash_col in tqdm(sourmash_score_cols):
#         df = (
#             multisearch.sort(sourmash_col, descending=True)
#             # .head(1000)
#             .group_by("query_scop_id").agg(
#                 sensitivity_until_first_false_positive(same_scop_cols, n_scop_cols)
#             )
#         ).fill_nan(0)
#         tidy = tidify_sensitivity(df)
#         tidy = tidy.with_columns(pl.lit(sourmash_col).alias("sourmash_score"))
#         dfs.append(tidy)

#     sensitivity = pl.concat(dfs)
#     # print("sensitivity.shape:", sensitivity.shape)
#     sensitivity = sensitivity.with_columns(pl.lit(moltype).alias("moltype"))
#     sensitivity = sensitivity.with_columns(pl.lit(ksize).alias("ksize"))
#     return sensitivity

## Iterate over all ksizes and moltypes

In [16]:
from sensitivity_until_first_false_positive import MultisearchSensitivityCalculator

In [None]:
# from sensitivity_outdir


# Per-moltype benchmark configuration.
#   ksizes          -- k-mer sizes to process, in the order they should be run
#   pipeline_outdir -- pipeline output directory (not read by the loop below)
#   analysis_outdir -- parent directory holding the cleaned inputs and the
#                      sensitivity outputs
moltype_info = {
    # "protein": dict(
    #     ksizes=range(5, 21),
    #     pipeline_outdir="s3://seanome-kmerseek/scope-benchmark/pipeline-outputs/2024-10-08__protein_k5-20",
    #     analysis_outdir="s3://seanome-kmerseek/scope-benchmark/analysis-outputs/protein",
    # ),
    # "dayhoff": dict(
    #     ksizes=range(5, 21),
    #     pipeline_outdir="s3://seanome-kmerseek/scope-benchmark/pipeline-outputs/2024-10-09__dayhoff_k5-20",
    #     analysis_outdir="s3://seanome-kmerseek/scope-benchmark/analysis-outputs/dayhoff",
    # ),
    "hp": dict(
        # Stored as a list (15, 14, ..., 10) rather than a bare `reversed(...)`
        # iterator: an iterator is exhausted after one pass, so any second
        # loop over this config would silently see no ksizes at all.
        ksizes=list(reversed(range(10, 16))),
        # pipeline_outdir="s3://seanome-kmerseek/scope-benchmark/pipeline-outputs/2024-10-09__hp_k20-60",
        # NOTE(review): the directory name says k20-60 but the ksizes above are
        # 10-15 -- confirm this is the intended pipeline run.
        pipeline_outdir="/home/ec2-user/data/seanome-kmerseek/scope-benchmark/pipeline-outputs/2024-10-09__hp_k20-60",
        # analysis_outdir="s3://seanome-kmerseek/scope-benchmark/analysis-outputs/hp",
        analysis_outdir="/home/ec2-user/data/seanome-kmerseek/scope-benchmark/analysis-outputs/hp",
    ),
}

# Filename pattern for the cleaned multisearch parquet files; currently only
# referenced by commented-out code further down in this cell.
basename_template = r"scope40.multisearch.{moltype}.k{ksize}.filtered.pq"

# For every configured moltype, run the sensitivity calculation once per ksize.
for moltype, info in moltype_info.items():
    ksizes = info["ksizes"]
    cleaned_outdir = os.path.join(
        info["analysis_outdir"], "00_cleaned_multisearch_results"
    )
    sensitivity_outdir = os.path.join(
        info["analysis_outdir"], "01_sensitivity_until_first_false_positive"
    )
    # Pure-Python replacement for the `! mkdir` shell escape: also creates
    # missing parent directories, is a no-op when the directory already exists
    # (no check-then-create race), and never passes the path through a shell.
    os.makedirs(sensitivity_outdir, exist_ok=True)
    for ksize in ksizes:

        # NOTE(review): presumably reads the cleaned parquet for this
        # moltype/ksize from `cleaned_outdir` and writes results under
        # `sensitivity_outdir` -- confirm against the calculator module.
        msc = MultisearchSensitivityCalculator(moltype, ksize, cleaned_outdir, sensitivity_outdir)
        msc.calculate_sensitivity()

        # multisearch_sensitivity
        # notify(f"--- moltype: {moltype}, ksize: {ksize} --")

        # pq_out = os.path.join(
        #     sensitivity_outdir, 
        #     f"scope40.multisearch.{moltype}.{ksize}.sensitivity_to_first_fp.pq"
        # )
        # notify(f'pq out: {pq_out}')
        # if os.path.exists(pq_out):
        #     continue
        
        # basename = basename_template.format(moltype=moltype, ksize=ksize)
        # pq = os.path.join(cleaned_outdir, basename)
        # notify(f"Reading {pq} ...")
        # multisearch = pl.scan_parquet(pq, schema=multisearch_schema, parallel='row_groups')
        # notify_done()
        # sensitivity = compute_sensitivity(
        #     multisearch, sourmash_score_cols, moltype, ksize
        # )

        # notify(f'Writing {pq_out} ... ')
        # sensitivity.sink_parquet(pq_out, row_group_size=1000)
        # # sensitivity.collect().write_parquet(pq_out)
        # notify_done()

INFO:notifications:--- moltype: hp, ksize: 15 --
INFO:notifications:pq out: /home/ec2-user/data/seanome-kmerseek/scope-benchmark/analysis-outputs/hp/01_sensitivity_until_first_false_positive/scope40.multisearch.hp.15.sensitivity_to_first_fp.pq
INFO:notifications:Reading /home/ec2-user/data/seanome-kmerseek/scope-benchmark/analysis-outputs/hp/00_cleaned_multisearch_results/scope40.multisearch.hp.k15.filtered.pq ...
INFO:notifications:Done.
  0%|                                                                                                                                 | 0/8 [00:00<?, ?it/s]INFO:notifications:Writing 'containment' sensitivity dataframe to /tmp/sensitivityaog6__mn.containment.parquet ...
keys/aggregates are not partitionable: running default HASH AGGREGATION
keys/aggregates are not partitionable: running default HASH AGGREGATION


In [None]:
# sensitivity.show_graph()

In [14]:
# sensitivity.comm

In [15]:
# NOTE(review): `sensitivity` is only assigned in commented-out code in this
# notebook, so this cell raises NameError on a fresh kernel -- remove it or
# load the sensitivity results explicitly first.
sensitivity.shape

NameError: name 'sensitivity' is not defined

In [None]:
# g = sns.catplot(
#     data=sensitivity,
#     col="sourmash_score",
#     hue="ksize",
#     y="variable",
#     x="sensitivity",
#     col_wrap=4, height=3
# )

In [None]:
# %debug

In [None]:
# NOTE(review): `cv` is not defined anywhere in the visible notebook; this
# looks like a stray keystroke -- confirm and delete the cell.
cv