# Imports and setup

### Auto-re-import python modules, useful for editing local files

In [1]:
%load_ext autoreload
%autoreload 2

## Imports

In [2]:
import itertools
import warnings
from collections import Counter

import matplotlib as mpl
import numpy as np
import pandas as pd
import polars as pl
import seaborn as sns
from tqdm import tqdm

pd.options.display.max_columns = 100

mpl.rcParams["figure.max_open_warning"] = 0

In [3]:
assert pl.__version__ == "1.9.0"

# Read in data

## read in unfiltered data

In [89]:
pq = "s3://seanome-kmerseek/scope-benchmark/analysis-outputs/2024-10-09__hp_k20-60/00_cleaned_multisearch_results/scope40.multisearch.hp.k20.pq"
multisearch_unfiltered = pd.read_parquet(pq)

In [90]:
lineage_cols

['family', 'superfamily', 'fold', 'class']

### Set SCOP lineage column names

In [11]:
lineage_cols = ["family", "superfamily", "fold", "class"]
query_scop_cols = [f"query_{x}" for x in lineage_cols]
match_scop_cols = [f"match_{x}" for x in lineage_cols]

same_scop_cols = [f"same_{x}" for x in lineage_cols]

### Make query metadata

In [92]:
query_metadata = pd.DataFrame(
    multisearch_unfiltered[query_scop_cols].values,
    index=multisearch_unfiltered["query_scop_id"].values,
    columns=query_scop_cols,
)
query_metadata = query_metadata.sort_index()
print(query_metadata.shape)
query_metadata = query_metadata.loc[~query_metadata.index.duplicated()]
print(query_metadata.shape)
query_metadata.head()

(3471675, 4)
(15177, 4)


Unnamed: 0,query_family,query_superfamily,query_fold,query_class
d12asa_,d.104.1.1,d.104.1,d.104,d
d16vpa_,d.180.1.1,d.180.1,d.180,d
d1914a1,d.49.1.1,d.49.1,d.49,d
d1914a2,d.49.1.1,d.49.1,d.49,d
d1a04a1,a.4.6.2,a.4.6,a.4,a


## Filter to remove self matches

In [93]:
multisearch = multisearch_unfiltered.query(
    "intersect_hashes > 1 and query_md5 != match_md5"
)

### remove unfiltered metadata

In [94]:
del multisearch_unfiltered


## Count sensitivity to first false positive

In [95]:
multisearch.columns

Index(['query_name', 'query_md5', 'match_name', 'match_md5', 'containment',
       'max_containment', 'jaccard', 'intersect_hashes', 'prob_overlap',
       'prob_overlap_adjusted', 'containment_adjusted',
       'containment_adjusted_log10', 'tf_idf_score', 'query_scop_id',
       'query_scop_lineage', 'query_scop_lineage_fixed', 'query_family',
       'query_superfamily', 'query_fold', 'query_class', 'match_scop_id',
       'match_scop_lineage', 'match_scop_lineage_fixed', 'match_family',
       'match_superfamily', 'match_fold', 'match_class', 'same_family',
       'same_superfamily', 'same_fold', 'same_class'],
      dtype='object')

In [96]:
cols = [sourmash_col] + list(reversed(same_scop_cols))
cols

['containment', 'same_class', 'same_fold', 'same_superfamily', 'same_family']

In [97]:
multisearch = multisearch.sort_values(by=cols, ascending=False)

In [98]:
sourmash_col = "containment"

groupby = ["query_scop_id"] + query_scop_cols

for query, df in multisearch.groupby(groupby):
    print("query:", query, df.shape)
    # display(df.head())

    # Sort the values so the biggest one is first
    df = df.sort_values(sourmash_col, ascending=False)

    # display(df)
    break

  for query, df in multisearch.groupby(groupby):


query: ('d12asa_', 'd.104.1.1', 'd.104.1', 'd.104', 'd') (203, 31)


In [99]:
def sum_until_first_zero(series):
    found_false_positive = False

    total = 0
    while not found_false_positive:
        for v in series:
            if v:
                print(f"v: {v}")
                total += 1
            else:
                found_false_positive = True
    return total


# Index of minimum value, will return the first thing
first_zeros = df[same_scop_cols].idxmin()
count_until_first_fp = df[same_scop_cols].apply(
    lambda x: x.loc[: first_zeros[x.name]].sum()
)
count_until_first_fp

same_family         0
same_superfamily    0
same_fold           0
same_class          1
dtype: int64

In [100]:
df[[sourmash_col] + same_scop_cols].head()

Unnamed: 0,containment,same_family,same_superfamily,same_fold,same_class
893134,0.038961,False,False,False,True
1772250,0.025974,False,False,False,False
1459459,0.025974,False,False,False,True
2513908,0.022727,False,False,False,False
459483,0.019481,False,False,False,True


### Count number of groups per sample

In [104]:
def count_scop_lineage(df, col):
    return Counter(df[col])


n_groups_per_scop_lineage = {
    lineage: pd.Series(
        count_scop_lineage(query_metadata, f"query_{lineage}"), name=f"n_{lineage}"
    )
    for lineage in lineage_cols
}
n_groups_per_scop_lineage.keys()
# n_groups_per_scop_lineage

dict_keys(['family', 'superfamily', 'fold', 'class'])

In [105]:
n_groups_per_scop_lineage["class"]

d    3653
a    2644
c    4463
b    3059
f     332
g     722
e     304
Name: n_class, dtype: int64

### Add number of groups per sample

In [106]:
for lineage, series in n_groups_per_scop_lineage.items():
    on = f"query_{lineage}"
    multisearch = multisearch.join(series, on=on)
multisearch.head()

Unnamed: 0,query_name,query_md5,match_name,match_md5,containment,max_containment,jaccard,intersect_hashes,prob_overlap,prob_overlap_adjusted,containment_adjusted,containment_adjusted_log10,tf_idf_score,query_scop_id,query_scop_lineage,query_scop_lineage_fixed,query_family,query_superfamily,query_fold,query_class,match_scop_id,match_scop_lineage,match_scop_lineage_fixed,match_family,match_superfamily,match_fold,match_class,same_family,same_superfamily,same_fold,same_class,n_family,n_superfamily,n_fold,n_class
945139,d2kq5a_ a.298.1.1 (A:) PthA {Xanthomonas axono...,d435bfe0fd1203ad9dde6bf11e6b2bab,d3v6ta_ a.298.1.1 (A:) dHax3 {Xanthomonas [Tax...,4eb3271d07bf470d7edc17f3d2632960,1.0,1.0,0.182857,32.0,8.295542e-10,0.191081,5.233393,0.718783,8.769065,d2kq5a_,a.298.1.1,a.298.1.1,a.298.1.1,a.298.1,a.298,a,d3v6ta_,a.298.1.1,a.298.1.1,a.298.1.1,a.298.1,a.298,a,True,True,True,True,3,3,3,2644
2048779,d1y8xb1 c.111.1.2 (B:349-440) UBA3 {Human (Hom...,a885b33673a0ed0500344799540c4cdb,d1ngvb_ c.111.1.2 (B:) automated matches {Homo...,759286aed279d3e9c2db8f07ff4d31f8,1.0,1.0,0.177616,73.0,2.789917e-10,0.064263,15.560978,1.192037,8.926526,d1y8xb1,c.111.1.2,c.111.1.2,c.111.1.2,c.111.1,c.111,c,d1ngvb_,c.111.1.2,c.111.1.2,c.111.1.2,c.111.1,c.111,c,True,True,True,True,3,4,4,4463
2195396,d3j9ca1 b.179.1.1 (A:174-225) PA14 {Bacillus a...,d5e1cc5e0bae9ee26fd79b2cc0402ffe,d3tewa2 b.179.1.1 (A:15-225) PA14 {Bacillus an...,c09894ddcd3b37e7268e52140edcd93d,1.0,1.0,0.171875,33.0,1.274173e-10,0.029349,34.072178,1.5324,8.946405,d3j9ca1,b.179.1.1,b.179.1.1,b.179.1.1,b.179.1,b.179,b,d3tewa2,b.179.1.1,b.179.1.1,b.179.1.1,b.179.1,b.179,b,True,True,True,True,3,6,6,3059
2867133,d3h3ba2 g.3.9.0 (A:166-192) automated matches ...,b988c74c89b65c4fe06b1579d5306191,d5my6a2 g.3.9.0 (A:188-344) automated matches ...,9dcdbbd95a7ddb417a15e9e8aed167bf,1.0,1.0,0.057971,8.0,1.294698e-11,0.002982,335.320091,2.52546,9.27868,d3h3ba2,g.3.9.0,g.3.9.0,g.3.9.0,g.3.9,g.3,g,d5my6a2,g.3.9.0,g.3.9.0,g.3.9.0,g.3.9,g.3,g,True,True,True,True,3,7,139,722
3276864,d1t1ra3 c.2.1.3 (A:275-300) 1-deoxy-D-xylulose...,0ec5455412a302af86d27ae54b11af17,"d2egha2 c.2.1.3 (A:1-125,A:275-300) 1-deoxy-D-...",5c42cf08e107247f0e5e46b921b15703,1.0,1.0,0.052632,7.0,4.815646e-11,0.011092,90.151631,1.954974,8.6436,d1t1ra3,c.2.1.3,c.2.1.3,c.2.1.3,c.2.1,c.2,c,d2egha2,c.2.1.3,c.2.1.3,c.2.1.3,c.2.1,c.2,c,True,True,True,True,30,372,372,4463


### Write function for sensitive to first FP

In [107]:
n_scop_cols = "n_family	n_superfamily	n_fold	n_class".split()
n_scop_cols

['n_family', 'n_superfamily', 'n_fold', 'n_class']

In [114]:
def sensitivity_until_first_false_positive(
    df, same_scop_cols=same_scop_cols, n_scop_cols=n_scop_cols
):
    # Index of minimum value, will return the first thing
    first_zeros = df[same_scop_cols].idxmin()
    count_until_first_fp = df[same_scop_cols].apply(
        lambda x: x.loc[: first_zeros[x.name]].sum()
    )

    # Subtract 1 to ignore self-matches
    n_per_lineage = df[n_scop_cols] - 1
    # Take the first row since all the values are the same for the query
    n_per_lineage = n_per_lineage.values[0]

    sensitivity_until_first_fp = count_until_first_fp / n_per_lineage
    sensitivity_until_first_fp.index = sensitivity_until_first_fp.index.str.replace(
        "same", "sensitivity"
    )
    sensitivity_until_first_fp = sensitivity_until_first_fp.fillna(0)
    return sensitivity_until_first_fp


multisearch = multisearch.sort_values(sourmash_col, ascending=False)
multisearch_sensitivity = multisearch.groupby("query_scop_id").apply(
    sensitivity_until_first_false_positive
)

  multisearch_sensitivity = multisearch.groupby("query_scop_id").apply(


In [115]:
multisearch_sensitivity

Unnamed: 0_level_0,sensitivity_family,sensitivity_superfamily,sensitivity_fold,sensitivity_class
query_scop_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
d12asa_,0.0,0.0,0.0,0.000548
d16vpa_,,,,0.000000
d1914a1,0.0,0.0,0.0,0.000274
d1914a2,0.0,0.0,0.0,0.000000
d1a04a1,0.0,0.0,0.0,0.000000
...,...,...,...,...
g2dbx.1,0.0,0.0,0.0,0.000000
g2vt1.1,0.0,0.0,0.0,0.000000
g3bzy.1,0.0,0.0,0.0,0.000000
g3n55.1,,0.0,0.0,0.000000


In [122]:
multisearch_sensitivity.query(
    "sensitivity_family > .9 and sensitivity_superfamily > .5"
).sort_index()

Unnamed: 0_level_0,sensitivity_family,sensitivity_superfamily,sensitivity_fold,sensitivity_class
query_scop_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
d1k91a1,1.0,1.0,1.0,0.001635
d1ouva1,1.0,1.0,0.005618,0.000378
d2ekna_,1.0,1.0,0.002132,0.000548
d2i8da1,1.0,1.0,0.05,0.000274
d3jqka_,1.0,1.0,0.002132,0.000274
d7dkzl_,1.0,1.0,1.0,0.009063
d7lx0l_,1.0,1.0,1.0,0.003021


In [112]:
# %debug

## Count number of same scop per column

In [15]:
same_scop_counts = multisearch.group_by(["query_scop_id"] + query_scop_cols).agg(
    pl.col(same_scop_cols).sum()
)

# Subtract 1 for self-matches
# Doing this after the fact makes sure that we have ALL 15,177 samples for
# each analyses, otherwise they'd be unobserved
# same_scop_counts -= 1

# same_scop_bool = same_scop_counts > 0
# same_scop_bool.columns = same_scop_bool.columns + "_bool"
# same_scop_counts = pd.concat([same_scop_counts, same_scop_bool], axis=1)
same_scop_counts.head()

query_scop_id,query_family,query_superfamily,query_fold,query_class,same_family,same_superfamily,same_fold,same_class
str,cat,cat,cat,cat,u64,u64,u64,u64
"""d2e9ja1""","""b.1.18.10""","""b.1.18""","""b.1""","""b""",0,0,2,7
"""d2qxza_""","""b.80.1.0""","""b.80.1""","""b.80""","""b""",0,2,2,33
"""d6pzda1""","""b.68.1.1""","""b.68.1""","""b.68""","""b""",0,0,0,29
"""d2cw9a1""","""d.17.4.13""","""d.17.4""","""d.17""","""d""",0,0,0,30
"""d1vj7a1""","""a.211.1.1""","""a.211.1""","""a.211""","""a""",0,0,0,10
