/
sketchcomparison.py
190 lines (162 loc) · 7.64 KB
/
sketchcomparison.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""
Sketch Comparison Classes
"""
import numpy as np
from dataclasses import dataclass
from .signature import MinHash
@dataclass
class BaseMinHashComparison:
    """Class for standard comparison between two MinHashes.

    Subclasses are expected to call ``check_compatibility_and_downsample``
    (e.g. from ``__post_init__``). That call sets ``mh1_cmp`` / ``mh2_cmp``,
    the prepared versions of ``mh1`` / ``mh2`` that every comparison
    property and estimate method below reads.
    """
    mh1: MinHash
    mh2: MinHash
    ignore_abundance: bool = False  # optionally ignore abundances

    def downsample_and_handle_ignore_abundance(self, cmp_num=None, cmp_scaled=None):
        """
        Downsample and/or flatten minhashes for comparison.

        Sets ``self.mh1_cmp`` and ``self.mh2_cmp``:
        - first, ``mh1`` / ``mh2`` are flattened if ``ignore_abundance`` is set;
        - then they are downsampled to ``cmp_scaled``, or to ``cmp_num``.
        ``cmp_scaled`` wins when both are given.

        Raises:
            ValueError: if neither ``cmp_scaled`` nor ``cmp_num`` is given.
        """
        if self.ignore_abundance:
            self.mh1_cmp = self.mh1.flatten()
            self.mh2_cmp = self.mh2.flatten()
        else:
            self.mh1_cmp = self.mh1
            self.mh2_cmp = self.mh2
        if cmp_scaled is not None:
            self.mh1_cmp = self.mh1_cmp.downsample(scaled=cmp_scaled)
            self.mh2_cmp = self.mh2_cmp.downsample(scaled=cmp_scaled)
        elif cmp_num is not None:
            self.mh1_cmp = self.mh1_cmp.downsample(num=cmp_num)
            self.mh2_cmp = self.mh2_cmp.downsample(num=cmp_num)
        else:
            raise ValueError("Error: must pass in a comparison scaled or num value.")

    def check_compatibility_and_downsample(self, cmp_num=None, cmp_scaled=None):
        """
        Validate that the two sketches can be compared, then prepare them.

        On success, ``mh1_cmp`` / ``mh2_cmp`` are set (see
        ``downsample_and_handle_ignore_abundance``). ``self.ksize`` and
        ``self.moltype`` are recorded from ``mh1``.

        Raises:
            TypeError: if the sketches are not both num or both scaled, or if
                the prepared sketches are not compatible.
            ValueError: from the downsampling step when neither ``cmp_num``
                nor ``cmp_scaled`` is given.
        """
        both_num = self.mh1.num and self.mh2.num
        both_scaled = self.mh1.scaled and self.mh2.scaled
        if not (both_num or both_scaled):
            raise TypeError("Error: Both sketches must be 'num' or 'scaled'.")
        # need to downsample first because is_compatible checks scaled (though does not check num)
        self.downsample_and_handle_ignore_abundance(cmp_num=cmp_num, cmp_scaled=cmp_scaled)
        if not self.mh1_cmp.is_compatible(self.mh2_cmp):
            raise TypeError("Error: Cannot compare incompatible sketches.")
        self.ksize = self.mh1.ksize
        self.moltype = self.mh1.moltype

    @property
    def intersect_mh(self):
        """Intersection of the flattened prepared sketches (recomputed on each access)."""
        return self.mh1_cmp.flatten().intersection(self.mh2_cmp.flatten())

    @property
    def jaccard(self):
        """Jaccard similarity of the prepared sketches."""
        return self.mh1_cmp.jaccard(self.mh2_cmp)

    def estimate_jaccard_ani(self, jaccard=None):
        """
        Estimate ANI from Jaccard similarity and store the results.

        Sets:
            jaccard_ani: ``ani`` from ``mh1_cmp.jaccard_ani(mh2_cmp)``.
            jaccard_ani_untrustworthy: the result's ``je_exceeds_threshold``.
            potential_false_negative: True if the result reports
                ``p_exceeds_threshold``. Otherwise a previously set value is
                kept, and it is set to False if it was never set.

        Args:
            jaccard: optional precomputed Jaccard value, passed through.
        """
        jinfo = self.mh1_cmp.jaccard_ani(self.mh2_cmp, jaccard=jaccard)
        # propagate params
        self.jaccard_ani = jinfo.ani
        if jinfo.p_exceeds_threshold:
            self.potential_false_negative = True
        elif not hasattr(self, "potential_false_negative"):
            # Only some subclasses initialize this flag. Make sure it always
            # exists so later reads don't raise AttributeError.
            self.potential_false_negative = False
        self.jaccard_ani_untrustworthy = jinfo.je_exceeds_threshold

    @property
    def angular_similarity(self):
        """Angular similarity of the prepared sketches."""
        # Note: this currently throws TypeError if self.ignore_abundance.
        return self.mh1_cmp.angular_similarity(self.mh2_cmp)

    @property
    def cosine_similarity(self):
        """Alias that returns ``angular_similarity``."""
        return self.angular_similarity
@dataclass
class NumMinHashComparison(BaseMinHashComparison):
    """Compare two num (fixed-size) MinHashes.

    ``cmp_num`` is the num that both sketches are downsampled to. When it is
    left as None, it is filled in with the smaller of the two sketches'
    ``num`` values.
    """
    cmp_num: int = None

    def __post_init__(self):
        "Choose the comparison num (if not supplied) and prepare both sketches."
        chosen_num = self.cmp_num
        if chosen_num is None:
            # remember the num this comparison was actually performed at
            chosen_num = min(self.mh1.num, self.mh2.num)
            self.cmp_num = chosen_num
        self.check_compatibility_and_downsample(cmp_num=chosen_num)
@dataclass
class FracMinHashComparison(BaseMinHashComparison):
    """Class for standard comparison between two scaled minhashes.

    Fields:
        cmp_scaled: scaled value both sketches are downsampled to. Defaults
            to the larger of the two sketches' ``scaled`` values.
        threshold_bp: minimum estimated shared bp for ``pass_threshold``.
        estimate_ani_ci: when True, the ``estimate_*_ani`` methods also
            record ``*_low`` / ``*_high`` ANI bounds.
        ani_confidence: confidence passed to the ANI estimators.
    """
    cmp_scaled: int = None  # optionally force scaled value for this comparison
    threshold_bp: int = 0
    estimate_ani_ci: bool = False
    ani_confidence: float = 0.95

    def __post_init__(self):
        "Initialize ScaledComparison using values from provided FracMinHashes"
        if self.cmp_scaled is None:
            # comparison scaled defaults to maximum scaled between the two sigs
            self.cmp_scaled = max(self.mh1.scaled, self.mh2.scaled)
        self.check_compatibility_and_downsample(cmp_scaled=self.cmp_scaled)
        # Flipped to True by any estimate_*_ani call whose result reports
        # p_exceeds_threshold. It is never reset to False afterwards.
        self.potential_false_negative = False

    @property
    def pass_threshold(self):
        """True if the estimated shared bp is at least ``threshold_bp``."""
        return self.intersect_bp >= self.threshold_bp

    @property
    def intersect_bp(self):
        """Estimated shared base pairs: each shared hash stands for ``cmp_scaled`` bp.

        An empty intersection yields 0, so it cannot pass a positive threshold.
        """
        return len(self.intersect_mh) * self.cmp_scaled

    @property
    def mh1_containment(self):
        """Containment of mh1 in mh2: ``mh1_cmp.contained_by(mh2_cmp)``."""
        return self.mh1_cmp.contained_by(self.mh2_cmp)

    def estimate_mh1_containment_ani(self, containment=None):
        """
        Estimate ANI via ``mh1_cmp.containment_ani(mh2_cmp)`` and store it.

        Sets ``mh1_containment_ani``. Also sets ``mh1_containment_ani_low`` /
        ``mh1_containment_ani_high`` when ``estimate_ani_ci`` is set. Sets
        ``potential_false_negative`` to True if the result reports
        ``p_exceeds_threshold``.

        Args:
            containment: optional precomputed containment value, passed through.
        """
        # build result once
        m1_cani = self.mh1_cmp.containment_ani(self.mh2_cmp,
                                               containment=containment,
                                               confidence=self.ani_confidence,
                                               estimate_ci=self.estimate_ani_ci)
        # propagate params
        self.mh1_containment_ani = m1_cani.ani
        if m1_cani.p_exceeds_threshold:
            # only update if True
            self.potential_false_negative = True
        if self.estimate_ani_ci:
            self.mh1_containment_ani_low = m1_cani.ani_low
            self.mh1_containment_ani_high = m1_cani.ani_high

    @property
    def mh2_containment(self):
        """Containment of mh2 in mh1: ``mh2_cmp.contained_by(mh1_cmp)``."""
        return self.mh2_cmp.contained_by(self.mh1_cmp)

    def estimate_mh2_containment_ani(self, containment=None):
        """
        Estimate ANI via ``mh2_cmp.containment_ani(mh1_cmp)`` and store it.

        Sets ``mh2_containment_ani``. Also sets ``mh2_containment_ani_low`` /
        ``mh2_containment_ani_high`` when ``estimate_ani_ci`` is set. Sets
        ``potential_false_negative`` to True if the result reports
        ``p_exceeds_threshold``.

        Args:
            containment: optional precomputed containment value, passed through.
        """
        m2_cani = self.mh2_cmp.containment_ani(self.mh1_cmp,
                                               containment=containment,
                                               confidence=self.ani_confidence,
                                               estimate_ci=self.estimate_ani_ci)
        self.mh2_containment_ani = m2_cani.ani
        if m2_cani.p_exceeds_threshold:
            self.potential_false_negative = True
        if self.estimate_ani_ci:
            self.mh2_containment_ani_low = m2_cani.ani_low
            self.mh2_containment_ani_high = m2_cani.ani_high

    @property
    def max_containment(self):
        """Maximum containment of the prepared sketches."""
        return self.mh1_cmp.max_containment(self.mh2_cmp)

    def estimate_max_containment_ani(self, max_containment=None):
        """
        Estimate ANI via ``mh1_cmp.max_containment_ani(mh2_cmp)`` and store it.

        Sets ``max_containment_ani``. Also sets ``max_containment_ani_low`` /
        ``max_containment_ani_high`` when ``estimate_ani_ci`` is set. Sets
        ``potential_false_negative`` to True if the result reports
        ``p_exceeds_threshold``.

        Args:
            max_containment: optional precomputed value, passed through.
        """
        mc_ani_info = self.mh1_cmp.max_containment_ani(self.mh2_cmp,
                                                       max_containment=max_containment,
                                                       confidence=self.ani_confidence,
                                                       estimate_ci=self.estimate_ani_ci)
        # propagate params
        self.max_containment_ani = mc_ani_info.ani
        if mc_ani_info.p_exceeds_threshold:
            self.potential_false_negative = True
        if self.estimate_ani_ci:
            self.max_containment_ani_low = mc_ani_info.ani_low
            self.max_containment_ani_high = mc_ani_info.ani_high

    @property
    def avg_containment(self):
        """Average containment of the prepared sketches."""
        return self.mh1_cmp.avg_containment(self.mh2_cmp)

    @property
    def avg_containment_ani(self):
        "Returns single average_containment_ani value."
        return self.mh1_cmp.avg_containment_ani(self.mh2_cmp)

    def estimate_all_containment_ani(self):
        """
        Estimate all containment ANI values.

        Runs both directional estimates, then sets ``max_containment_ani`` to
        the larger of the two. It is set to None if either estimate is None.
        """
        self.estimate_mh1_containment_ani()
        self.estimate_mh2_containment_ani()
        # NOTE(review): the ANI values are not guaranteed to be numeric.
        # max() raises TypeError when comparing None, so propagate None instead.
        if self.mh1_containment_ani is None or self.mh2_containment_ani is None:
            self.max_containment_ani = None
        else:
            self.max_containment_ani = max(self.mh1_containment_ani,
                                           self.mh2_containment_ani)

    def weighted_intersection(self, from_mh=None, from_abundD=None):
        """
        Return the intersection sketch, optionally carrying abundances.

        Args:
            from_mh: sketch whose ``hashes`` supply abundances. Used only if
                it tracks abundance; it then takes precedence over
                ``from_abundD``.
            from_abundD: mapping of hash -> abundance.

        Returns:
            If abundances are available, a new abundance-tracking sketch over
            the intersection hashes. A hash missing from the mapping gets
            abundance 1. Otherwise, the plain intersection sketch.
        """
        # intersect_mh is a property that recomputes on each access; do it once.
        intersect_mh = self.intersect_mh
        # if from_mh is provided, it takes precedence over from_abund dict
        if from_mh is not None and from_mh.track_abundance:
            from_abundD = from_mh.hashes
        if from_abundD:
            # map abundances to all intersection hashes.
            abund_mh = intersect_mh.copy_and_clear()
            abund_mh.track_abundance = True
            # this sets any hash not present in abundD to 1. Is that desired? Or should we return 0?
            abunds = {k: from_abundD.get(k, 1) for k in intersect_mh.hashes}
            abund_mh.set_abundances(abunds)
            return abund_mh
        # if no abundances are passed in, return intersect_mh
        # future note: do we want to return 1 as abundance instead?
        return intersect_mh