# Imports and setup

### Automatically re-import Python modules (useful when editing local files)

In [4]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Imports

In [6]:
import itertools
import logging
import os
import warnings
from collections import Counter

import matplotlib as mpl
import numpy as np
import pandas as pd
import polars as pl
import seaborn as sns
from tqdm import tqdm

logging.basicConfig(level=logging.INFO)

pl.Config.set_verbose()
# mpl.rcParams["figure.max_open_warning"] = 0

from notifications import notify, notify_done
from scop_constants import n_scop_cols, same_scop_cols
from sensitivity_until_first_false_positive import (
    sensitivity_until_first_false_positive,
    tidify_sensitivity,
)
from sourmash_constants import sourmash_score_cols

INFO:aiobotocore.credentials:Found credentials in shared credentials file: ~/.aws/credentials


# Read in data

In [7]:
# Cleaned hp k19 multisearch results; only used to obtain the parquet schema.
pq = (
    "/home/ec2-user/data/seanome-kmerseek/scope-benchmark/analysis-outputs/hp"
    "/00_cleaned_multisearch_results"
    "/scope40.multisearch.hp.k19.filtered.pq"
)

### Read the schema from one file (reused when scanning every ksize below)

In [8]:
# Fetch only the column names/dtypes (no row data) of the file at `pq`.
# This schema is later passed explicitly to `pl.scan_parquet` for every ksize.
multisearch_schema = pl.read_parquet_schema(source=pq)
multisearch_schema

{'query_name': String,
 'query_md5': String,
 'match_name': String,
 'match_md5': String,
 'containment': Float64,
 'max_containment': Float64,
 'jaccard': Float64,
 'intersect_hashes': Float64,
 'prob_overlap': Float64,
 'prob_overlap_adjusted': Float64,
 'containment_adjusted': Float64,
 'containment_adjusted_log10': Float64,
 'tf_idf_score': Float64,
 'query_family': Categorical(ordering='physical'),
 'query_superfamily': Categorical(ordering='physical'),
 'query_fold': Categorical(ordering='physical'),
 'query_class': Categorical(ordering='physical'),
 'n_family': Int64,
 'n_superfamily': Int64,
 'n_fold': Int64,
 'n_class': Int64,
 'query_scop_id': String,
 'match_family': Categorical(ordering='physical'),
 'match_superfamily': Categorical(ordering='physical'),
 'match_fold': Categorical(ordering='physical'),
 'match_class': Categorical(ordering='physical'),
 'match_scop_id': String,
 'same_family': Boolean,
 'same_superfamily': Boolean,
 'same_fold': Boolean,
 'same_class': Boole

## Test initial function

In [9]:
# dfs = []

# for sourmash_col in sourmash_score_cols:
#     df = (
#         multisearch.sort(sourmash_col, descending=True)
#         # .head(1000)
#         .group_by("query_scop_id").agg(
#             sensitivity_until_first_false_positive(same_scop_cols, n_scop_cols)
#         )
#     ).fill_nan(0)
#     tidy = tidify_sensitivity(df)
#     tidy = tidy.with_columns(pl.lit(sourmash_col).alias("sourmash_score"))

#     dfs.append(tidy)

# sensitivity = pl.concat(dfs)
# sensitivity = sensitivity.with_columns(pl.lit("hp").alias("moltype"))
# sensitivity = sensitivity.with_columns(pl.lit(19).alias("ksize"))
# sensitivity

In [10]:
# dict(
#     [
#         ("query_scop_id", str),
#         ("lineage", str),
#         ("sensitivity", float),
#         ("sensitivity_rank", int),
#     ]
# )

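### Function to compute sensitivity (now imported as `compute_sensitivity` from `sensitivity_until_first_false_positive` in the next cell)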

In [11]:
# def compute_sensitivity(multisearch, sourmash_cols, moltype, ksize):
#     dfs = []

#     for sourmash_col in tqdm(sourmash_score_cols):
#         df = (
#             multisearch.sort(sourmash_col, descending=True)
#             # .head(1000)
#             .group_by("query_scop_id").agg(
#                 sensitivity_until_first_false_positive(same_scop_cols, n_scop_cols)
#             )
#         ).fill_nan(0)
#         tidy = tidify_sensitivity(df)
#         tidy = tidy.with_columns(pl.lit(sourmash_col).alias("sourmash_score"))
#         dfs.append(tidy)

#     sensitivity = pl.concat(dfs)
#     # print("sensitivity.shape:", sensitivity.shape)
#     sensitivity = sensitivity.with_columns(pl.lit(moltype).alias("moltype"))
#     sensitivity = sensitivity.with_columns(pl.lit(ksize).alias("ksize"))
#     return sensitivity

In [12]:
from sensitivity_until_first_false_positive import compute_sensitivity

## Iterate over all ksizes and moltypes

In [None]:

# Per-moltype settings. Only "hp" is active; the other alphabets are kept
# commented out so they can be re-enabled.
moltype_info = {
    # "protein": dict(
    #     ksizes=range(5, 21),
    #     pipeline_outdir="s3://seanome-kmerseek/scope-benchmark/pipeline-outputs/2024-10-08__protein_k5-20",
    #     analysis_outdir="s3://seanome-kmerseek/scope-benchmark/analysis-outputs/protein",
    # ),
    # "dayhoff": dict(
    #     ksizes=range(5, 21),
    #     pipeline_outdir="s3://seanome-kmerseek/scope-benchmark/pipeline-outputs/2024-10-09__dayhoff_k5-20",
    #     analysis_outdir="s3://seanome-kmerseek/scope-benchmark/analysis-outputs/dayhoff",
    # ),
    "hp": dict(
        # Largest ksize first: 15, 14, ..., 10. A range (not a one-shot
        # iterator) so the setting can safely be iterated more than once.
        ksizes=range(15, 9, -1),
        # NOTE(review): this directory is named "k20-60" but ksizes above are
        # 10-15; pipeline_outdir is not read by the loop below — confirm.
        # pipeline_outdir="s3://seanome-kmerseek/scope-benchmark/pipeline-outputs/2024-10-09__hp_k20-60",
        pipeline_outdir="/home/ec2-user/data/seanome-kmerseek/scope-benchmark/pipeline-outputs/2024-10-09__hp_k20-60",
        # analysis_outdir="s3://seanome-kmerseek/scope-benchmark/analysis-outputs/hp",
        analysis_outdir="/home/ec2-user/data/seanome-kmerseek/scope-benchmark/analysis-outputs/hp",
    ),
}

# Input: cleaned multisearch results for one moltype/ksize.
basename_template = "scope40.multisearch.{moltype}.k{ksize}.filtered.pq"
# Output: sensitivity-until-first-false-positive table for one moltype/ksize.
# NOTE(review): unlike the input name this has no "k" before the ksize
# (".hp.15." vs ".hp.k15."); kept unchanged so outputs that were already
# written are still detected and skipped below.
output_template = "scope40.multisearch.{moltype}.{ksize}.sensitivity_to_first_fp.pq"

for moltype, info in moltype_info.items():
    cleaned_outdir = os.path.join(
        info["analysis_outdir"], "00_cleaned_multisearch_results"
    )
    sensitivity_outdir = os.path.join(
        info["analysis_outdir"], "01_sensitivity_until_first_false_positive"
    )
    # Pure Python (no shell), creates missing parents, no-op if it exists.
    os.makedirs(sensitivity_outdir, exist_ok=True)

    for ksize in info["ksizes"]:
        notify(f"--- moltype: {moltype}, ksize: {ksize} --")

        pq_out = os.path.join(
            sensitivity_outdir,
            output_template.format(moltype=moltype, ksize=ksize),
        )
        notify(f'pq out: {pq_out}')
        # Cache: an existing output means this ksize is already finished.
        if os.path.exists(pq_out):
            continue

        # Named `pq_in` (not `pq`) so this loop does not overwrite the `pq`
        # global that the schema cell reads.
        pq_in = os.path.join(
            cleaned_outdir, basename_template.format(moltype=moltype, ksize=ksize)
        )
        notify(f"Reading {pq_in} ...")
        # Assumes every ksize file shares the schema read from the k19 file.
        multisearch = pl.scan_parquet(pq_in, schema=multisearch_schema, parallel='row_groups')
        notify_done()
        sensitivity = compute_sensitivity(
            multisearch, sourmash_score_cols, moltype, ksize
        )

        # Sink to a temporary file and rename only once the write has
        # returned. An interrupted (multi-hour) run therefore never leaves a
        # partial file at `pq_out` that the existence check above would
        # mistake for a finished result. os.replace is atomic on the same
        # filesystem, and the temp file lives in the same directory.
        pq_tmp = f"{pq_out}.tmp"
        notify(f'Writing {pq_out} ... ')
        sensitivity.sink_parquet(pq_tmp, row_group_size=1000)
        os.replace(pq_tmp, pq_out)
        # sensitivity.collect().write_parquet(pq_out)
        notify_done()

INFO:notifications:--- moltype: hp, ksize: 15 --
INFO:notifications:pq out: /home/ec2-user/data/seanome-kmerseek/scope-benchmark/analysis-outputs/hp/01_sensitivity_until_first_false_positive/scope40.multisearch.hp.15.sensitivity_to_first_fp.pq
INFO:notifications:Reading /home/ec2-user/data/seanome-kmerseek/scope-benchmark/analysis-outputs/hp/00_cleaned_multisearch_results/scope40.multisearch.hp.k15.filtered.pq ...
INFO:notifications:Done.
  0%|                                                                                                                                 | 0/8 [00:00<?, ?it/s]INFO:notifications:Writing 'containment' sensitivity dataframe to /tmp/tmpesfo2d44.parquet ...
keys/aggregates are not partitionable: running default HASH AGGREGATION
keys/aggregates are not partitionable: running default HASH AGGREGATION
INFO:notifications:Done.
 12%|██████████████▋                                                                                                      | 1/8 [17:15<2:

In [None]:
# sensitivity.show_graph()

In [None]:
# `sensitivity.comm` was an incomplete attribute access (AttributeError on a
# LazyFrame). Inspect the lazy result's column names/dtypes instead; this
# resolves the schema without executing the query.
sensitivity.collect_schema()

In [None]:
# `sensitivity` is a polars LazyFrame (it is written with `sink_parquet`), and
# LazyFrames have no `.shape`. Calling `.collect()` would re-run the whole
# expensive computation, so report (rows, columns) of the parquet file the
# loop last wrote or skipped (`pq_out`) instead.
(
    pl.scan_parquet(pq_out).select(pl.len()).collect().item(),
    len(pl.read_parquet_schema(pq_out)),
)

In [None]:
# g = sns.catplot(
#     data=sensitivity,
#     col="sourmash_score",
#     hue="ksize",
#     y="variable",
#     x="sensitivity",
#     col_wrap=4, height=3
# )

In [None]:
# %debug

In [None]:
# Stray `cv` token removed: it is not defined in any visible cell and raised
# NameError, which stops "Restart Kernel -> Run All" at this cell.