Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Apply downsampling magic to sbt_search #183

Merged
merged 13 commits into from May 16, 2017
36 changes: 33 additions & 3 deletions sourmash_lib/_minhash.pyx
Expand Up @@ -10,6 +10,7 @@ from libc.stdint cimport uint32_t

from ._minhash cimport KmerMinHash, KmerMinAbundance, _hash_murmur
import math
import copy


# default MurmurHash seed
Expand Down Expand Up @@ -207,7 +208,14 @@ cdef class MinHash(object):
def max_hash(self):
mm = deref(self._this).max_hash
if mm == 18446744073709551615:
return 0
mm = 0

# legacy - check cardinality => estimate
if mm == 0:
if self.hll:
genome_size = self.hll.estimate_cardinality()
mm = max(self.get_mins())

return mm

def add_hash(self, uint64_t h):
Expand Down Expand Up @@ -248,9 +256,20 @@ cdef class MinHash(object):

return a

def downsample_max_hash(self, *others):
max_hashes = [ x.max_hash for x in others ]
new_max_hash = min(self.max_hash, *max_hashes)
new_scaled = int(get_minhash_max_hash() / new_max_hash)

return self.downsample_scaled(new_scaled)

def downsample_scaled(self, new_num):
old_scaled = int(get_minhash_max_hash() / self.max_hash)
max_hash = self.max_hash

if max_hash is None:
raise ValueError('no max_hash available - cannot downsample')

old_scaled = int(get_minhash_max_hash() / self.max_hash)
if old_scaled > new_num:
raise ValueError('new scaled is lower than current sample scaled')

Expand Down Expand Up @@ -336,6 +355,10 @@ cdef class MinHash(object):
if not self.track_abundance or ignore_abundance:
return self.jaccard(other)
else:
# can we merge? if not, raise exception.
aa = copy.copy(self)
aa.merge(other)

a = self.get_mins(with_abundance=True)
b = other.get_mins(with_abundance=True)

Expand All @@ -345,6 +368,12 @@ cdef class MinHash(object):
distance = 2*math.acos(prod) / math.pi
return 1.0 - distance

def containment(self, other):
"""\
Calculate containment of self by other.
"""
return self.count_common(other) / len(self.get_mins())

def similarity_ignore_maxhash(self, MinHash other):
a = set(self.get_mins())

Expand All @@ -367,7 +396,8 @@ cdef class MinHash(object):
cpdef set_abundances(self, dict values):
if self.track_abundance:
for k, v in values.items():
(<KmerMinAbundance*>address(deref(self._this))).mins[k] = v
if not self.max_hash or k < self.max_hash:
(<KmerMinAbundance*>address(deref(self._this))).mins[k] = v
else:
raise RuntimeError("Use track_abundance=True when constructing "
"the MinHash to use set_abundances.")
Expand Down
31 changes: 19 additions & 12 deletions sourmash_lib/commands.py
Expand Up @@ -27,6 +27,7 @@ def search(args):
parser.add_argument('-k', '--ksize', default=DEFAULT_K, type=int)
parser.add_argument('-f', '--force', action='store_true')
parser.add_argument('--save-matches', type=argparse.FileType('wt'))
parser.add_argument('--containment', action='store_true')

sourmash_args.add_moltype_args(parser)

Expand Down Expand Up @@ -62,7 +63,10 @@ def search(args):
# compute query x db
distances = []
for (x, filename) in against:
distance = query.similarity(x)
if args.containment:
distance = query.containment(x)
else:
distance = query.similarity(x)
if distance >= args.threshold:
distances.append((distance, x, filename))

Expand Down Expand Up @@ -622,12 +626,19 @@ def sbt_search(args):

results = []
for leaf in tree.find(search_fn, query, args.threshold):
results.append((query.similarity(leaf.data), leaf.data))
#results.append((leaf.data.similarity(ss), leaf.data))
results.append((query.similarity(leaf.data, downsample=True),
leaf.data))

results.sort(key=lambda x: -x[0]) # reverse sort on similarity

if args.best_only:
notify("(truncated search because of --best-only; only trust top result")

notify("similarity match")
notify("---------- -----")
for (similarity, query) in results:
print('{:.2f} {}'.format(similarity, query.name()))
pct = '{:.1f}%'.format(similarity*100)
notify('{:>6} {}', pct, query.name())

if args.save_matches:
outname = args.save_matches.name
Expand Down Expand Up @@ -798,21 +809,17 @@ def build_new_signature(mins):
# figure out what the resolution of the banding on the genome is,
# based either on an explicit --scaled parameter, or on genome
# cardinality (deprecated)
if best_leaf.minhash.max_hash:
R_genome = sourmash_lib.MAX_HASH / \
float(best_leaf.minhash.max_hash)
elif best_leaf.minhash.hll:
genome_size = best_leaf.minhash.hll.estimate_cardinality()
genome_max_hash = max(found_mins)
R_genome = float(genome_size) / float(genome_max_hash)
else:
if not best_leaf.minhash.max_hash:
error('Best hash match in sbt_gather has no cardinality')
error('Please prepare database of sequences with --scaled')
sys.exit(-1)

R_genome = sourmash_lib.MAX_HASH / float(best_leaf.minhash.max_hash)

# pick the highest R / lowest resolution
R_comparison = max(R_metagenome, R_genome)

# CTB: these could probably be replaced by minhash.downsample_scaled.
new_max_hash = sourmash_lib.MAX_HASH / float(R_comparison)
query_mins = set([ i for i in query_mins if i < new_max_hash ])
found_mins = set([ i for i in found_mins if i < new_max_hash ])
Expand Down
30 changes: 25 additions & 5 deletions sourmash_lib/sbtmh.py
Expand Up @@ -2,7 +2,7 @@
from __future__ import division

from .sbt import Leaf
from . import MinHash
from . import _minhash, MinHash


class SigLeaf(Leaf):
Expand Down Expand Up @@ -38,11 +38,21 @@ def data(self, new_data):
self._data = new_data


def search_minhashes(node, sig, threshold, results=None):
def search_minhashes(node, sig, threshold, results=None, downsample=True):
mins = sig.minhash.get_mins()

if isinstance(node, SigLeaf):
matches = node.data.minhash.count_common(sig.minhash)
try:
matches = node.data.minhash.count_common(sig.minhash)
except Exception as e:
if 'mismatch in max_hash' in str(e) and downsample:
xx = sig.minhash.downsample_max_hash(node.data.minhash)
yy = node.data.minhash.downsample_max_hash(sig.minhash)

matches = yy.count_common(xx)
else:
raise

else: # Node or Leaf, Nodegraph by minhash comparison
matches = sum(1 for value in mins if node.data.get(value))

Expand All @@ -55,14 +65,24 @@ def search_minhashes(node, sig, threshold, results=None):


class SearchMinHashesFindBest(object):
def __init__(self):
def __init__(self, downsample=True):
self.best_match = 0.
self.downsample = downsample

def search(self, node, sig, threshold, results=None):
mins = sig.minhash.get_mins()

if isinstance(node, SigLeaf):
matches = node.data.minhash.count_common(sig.minhash)
try:
matches = node.data.minhash.count_common(sig.minhash)
except Exception as e:
if 'mismatch in max_hash' in str(e) and self.downsample:
xx = sig.minhash.downsample_max_hash(node.data.minhash)
yy = node.data.minhash.downsample_max_hash(sig.minhash)

matches = yy.count_common(xx)
else:
raise
else: # Node or Leaf, Nodegraph by minhash comparison
matches = sum(1 for value in mins if node.data.get(value))

Expand Down
16 changes: 14 additions & 2 deletions sourmash_lib/signature.py
Expand Up @@ -101,14 +101,26 @@ def _save(self):
return self.d.get('email'), self.d.get('name'), \
self.d.get('filename'), sketch

def similarity(self, other, ignore_abundance=False):
def similarity(self, other, ignore_abundance=False, downsample=False):
"Compute similarity with the other MinHash signature."
return self.minhash.similarity(other.minhash, ignore_abundance)
try:
return self.minhash.similarity(other.minhash, ignore_abundance)
except ValueError as e:
if 'mismatch in max_hash' in str(e) and downsample:
xx = self.minhash.downsample_max_hash(other.minhash)
yy = other.minhash.downsample_max_hash(self.minhash)
return xx.similarity(yy, ignore_abundance)
else:
raise

def jaccard(self, other):
"Compute Jaccard similarity with the other MinHash signature."
return self.minhash.similarity(other.minhash, True)

def containment(self, other):
"Compute containment by the other signature. Note: ignores abundance."
return self.minhash.containment(other.minhash)


def _guess_open(filename):
"""
Expand Down
28 changes: 26 additions & 2 deletions tests/test_signature.py
Expand Up @@ -36,7 +36,7 @@ def test_roundtrip_empty(track_abundance):

def test_roundtrip_max_hash(track_abundance):
e = sourmash_lib.MinHash(n=1, ksize=20, track_abundance=track_abundance,
max_hash=10)
max_hash=10)
e.add_hash(5)
sig = SourmashSignature('titus@idyll.org', e)
s = save_signatures([sig])
Expand All @@ -52,7 +52,7 @@ def test_roundtrip_max_hash(track_abundance):

def test_roundtrip_seed(track_abundance):
e = sourmash_lib.MinHash(n=1, ksize=20, track_abundance=track_abundance,
seed=10)
seed=10)
e.add_hash(5)
sig = SourmashSignature('titus@idyll.org', e)
s = save_signatures([sig])
Expand Down Expand Up @@ -80,6 +80,30 @@ def test_roundtrip_empty_email(track_abundance):
assert sig2.similarity(sig) == 1.0


def test_similarity_downsample(track_abundance):
e = sourmash_lib.MinHash(n=0, ksize=20, track_abundance=track_abundance,
max_hash=2**63)
f = sourmash_lib.MinHash(n=0, ksize=20, track_abundance=track_abundance,
max_hash=2**2)

e.add_hash(1)
e.add_hash(5)
assert len(e.get_mins()) == 2

f.add_hash(1)
f.add_hash(5) # should be discarded due to max_hash
assert len(f.get_mins()) == 1

ee = SourmashSignature('', e)
ff = SourmashSignature('', f)

with pytest.raises(ValueError): # mismatch in max_hash
ee.similarity(ff)

x = ee.similarity(ff, downsample=True)
assert round(x, 1) == 1.0


def test_md5(track_abundance):
e = sourmash_lib.MinHash(n=1, ksize=20, track_abundance=track_abundance)
e.add_hash(5)
Expand Down