Skip to content

Commit

Permalink
Merge pull request #150 from dib-lab/fix/jaccard
Browse files Browse the repository at this point in the history
[MRG] Fix Jaccard calculation to be intersection over union
  • Loading branch information
ctb committed Apr 18, 2017
2 parents c0e3476 + d650b22 commit 1033479
Show file tree
Hide file tree
Showing 11 changed files with 147 additions and 23 deletions.
6 changes: 3 additions & 3 deletions doc/api-example.rst
Expand Up @@ -19,10 +19,10 @@ Create two estimators using 3-mers, and add the sequences:
>>> E2 = sourmash_lib.MinHash(n=20, ksize=3)
>>> E2.add_sequence(seq2)

One of the 3-mers (out of 4) overlaps, so Jaccard index is 1/4:
One of the 3-mers (out of 7) overlaps, so Jaccard index is 1/7:

>>> E1.jaccard(E2)
0.25
>>> round(E1.jaccard(E2), 2)
0.14

and of course the estimators match themselves:

Expand Down
2 changes: 1 addition & 1 deletion sourmash_lib/_minhash.pxd
Expand Up @@ -46,7 +46,7 @@ cdef extern from "kmer_min_hash.hh":
void add_hash(HashIntoType) except +ValueError
void add_word(string word) except +ValueError
void add_sequence(const char *, bool) except +ValueError
void merge(const KmerMinAbundance&) except +ValueError
void merge_abund(const KmerMinAbundance&) except +ValueError
void merge(const KmerMinHash&) except +ValueError
unsigned int count_common(const KmerMinAbundance&) except +ValueError
unsigned long size()
Expand Down
58 changes: 55 additions & 3 deletions sourmash_lib/_minhash.pyx
Expand Up @@ -223,9 +223,61 @@ cdef class MinHash(object):

return n

def downsample_n(self, new_num):
    """Return a new MinHash created with ``new_num`` and filled from this one.

    The new object is constructed with this object's ksize, is_protein,
    track_abundance, seed and max_hash; this object is then merged into it.
    ``self`` is not reassigned here.

    Raises ValueError if ``new_num`` is larger than ``self.num``.
    """
    if new_num > self.num:
        raise ValueError('new sample n is higher than current sample n')

    downsampled = MinHash(new_num,
                          deref(self._this).ksize,
                          deref(self._this).is_protein,
                          self.track_abundance,
                          deref(self._this).seed,
                          deref(self._this).max_hash)
    # Filling the smaller sketch is delegated entirely to merge().
    downsampled.merge(self)
    return downsampled

# compare(other): returns n / size, where n is the number of hashes found in
# self.get_mins(), in other.get_mins() AND in a freshly built sketch that both
# objects were merged into, and size is that merged sketch's size() (floored
# at 1 so the division never sees zero).
# Raises TypeError when self.num != other.num.
# NOTE(review): this capture is rendered diff text -- indentation is stripped
# and removed/added lines appear together.
def compare(self, MinHash other):
# NOTE(review): the next two lines appear to be the pre-change body (n from
# count_common, size from self only); both names are reassigned further down.
n = self.count_common(other)
size = max(deref(self._this).size(), 1)
cdef KmerMinAbundance *mh = NULL;
cdef KmerMinAbundance *other_mh = NULL;
cdef KmerMinAbundance *cmh = NULL;

# Sketches built with different num are rejected outright.
if self.num != other.num:
err = 'must have same num: {} != {}'.format(self.num,
other.num)
raise TypeError(err)
else:
num = self.num

# Abundance path is taken only when BOTH objects track abundance.
if self.track_abundance and other.track_abundance:
# NOTE(review): allocated with `new`; no matching `del` is visible in this
# block -- confirm the object is released somewhere.
combined_mh = new KmerMinAbundance(num,
deref(self._this).ksize,
deref(self._this).is_protein,
deref(self._this).seed,
deref(self._this).max_hash)

# Cast the wrapped objects (and the new sketch) to KmerMinAbundance* so that
# merge_abund can be called on them.
mh = <KmerMinAbundance*>address(deref(self._this))
other_mh = <KmerMinAbundance*>address(deref(other._this))
cmh = <KmerMinAbundance*>combined_mh

cmh.merge_abund(deref(mh))
cmh.merge_abund(deref(other_mh))

# Keep only hashes present in self, in other, and in the merged sketch.
common = set(self.get_mins())
common.intersection_update(other.get_mins())
# presumably cmh.mins holds (hash, abundance) pairs, hence `.first` -- verify
common.intersection_update([it.first for it in cmh.mins])
n = len(common)
else:
# NOTE(review): same `new` without a visible `del` as in the branch above.
combined_mh = new KmerMinHash(num,
deref(self._this).ksize,
deref(self._this).is_protein,
deref(self._this).seed,
deref(self._this).max_hash)
combined_mh.merge(deref(self._this))
combined_mh.merge(deref(other._this))

# Same three-way intersection, using the merged sketch's mins directly.
common = set(self.get_mins())
common.intersection_update(other.get_mins())
common.intersection_update(combined_mh.mins)
n = len(common)

# Denominator is the merged sketch's size, not self's.
size = max(combined_mh.size(), 1)
return n / size

def jaccard(self, MinHash other):
Expand Down Expand Up @@ -271,7 +323,7 @@ cdef class MinHash(object):
cdef KmerMinAbundance *mh = <KmerMinAbundance*>address(deref(self._this))
cdef KmerMinAbundance *other_mh = <KmerMinAbundance*>address(deref(other._this))
if self.track_abundance:
mh.merge(deref(other_mh))
deref(mh).merge_abund(deref(other_mh))
else:
deref(self._this).merge(deref(other._this))

Expand Down
2 changes: 1 addition & 1 deletion sourmash_lib/kmer_min_hash.hh
Expand Up @@ -355,7 +355,7 @@ class KmerMinAbundance: public KmerMinHash {
}
}

virtual void merge(const KmerMinAbundance& other) {
virtual void merge_abund(const KmerMinAbundance& other) {
if (ksize != other.ksize) {
throw minhash_exception("different ksizes cannot be merged");
}
Expand Down
Binary file modified tests/test-data/genome-s10.fa.gz
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
36 changes: 30 additions & 6 deletions tests/test__minhash.py
Expand Up @@ -311,15 +311,20 @@ def test_mh_asymmetric(track_abundance):
for i in range(0, 40, 2):
a.add_hash(i)

b = MinHash(10, 10, track_abundance=track_abundance) # different size: 10
# different size: 10
b = MinHash(10, 10, track_abundance=track_abundance)
for i in range(0, 80, 4):
b.add_hash(i)

assert a.count_common(b) == 10
assert b.count_common(a) == 10

with pytest.raises(TypeError):
a.compare(b)

a = a.downsample_n(10)
assert a.compare(b) == 0.5
assert b.compare(a) == 1.0
assert b.compare(a) == 0.5


def test_mh_merge(track_abundance):
Expand Down Expand Up @@ -376,7 +381,8 @@ def test_mh_asymmetric_merge(track_abundance):
for i in range(0, 40, 2):
a.add_hash(i)

b = MinHash(10, 10, track_abundance=track_abundance) # different size: 10
# different size: 10
b = MinHash(10, 10, track_abundance=track_abundance)
for i in range(0, 80, 4):
b.add_hash(i)

Expand All @@ -388,8 +394,17 @@ def test_mh_asymmetric_merge(track_abundance):
assert len(c) == len(a)
assert len(d) == len(b)

# can't compare different sizes without downsampling
with pytest.raises(TypeError):
d.compare(a)

a = a.downsample_n(d.num)
print(a.get_mins())
print(d.get_mins())
assert d.compare(a) == 1.0
assert c.compare(b) == 0.5

c = c.downsample_n(b.num)
assert c.compare(b) == 1.0


def test_mh_inplace_concat_asymmetric(track_abundance):
Expand All @@ -398,7 +413,8 @@ def test_mh_inplace_concat_asymmetric(track_abundance):
for i in range(0, 40, 2):
a.add_hash(i)

b = MinHash(10, 10, track_abundance=track_abundance) # different size: 10
# different size: 10
b = MinHash(10, 10, track_abundance=track_abundance)
for i in range(0, 80, 4):
b.add_hash(i)

Expand All @@ -413,7 +429,15 @@ def test_mh_inplace_concat_asymmetric(track_abundance):
assert len(c) == len(a)
assert len(d) == len(b)

assert d.compare(a) == 1.0
try:
d.compare(a)
except TypeError as exc:
assert 'must have same num' in str(exc)

a = a.downsample_n(d.num)
assert d.compare(a) == 1.0 # see: d += a, above.

c = c.downsample_n(b.num)
assert c.compare(b) == 0.5


Expand Down
47 changes: 43 additions & 4 deletions tests/test_estimators.py
Expand Up @@ -7,6 +7,7 @@

import pytest
from sourmash_lib import MinHash
from . import sourmash_tst_utils as utils

# below, 'track_abundance' is toggled to both True and False by py.test --
# see conftest.py.
Expand All @@ -21,8 +22,11 @@ def test_jaccard_1(track_abundance):
for i in [1, 2, 3, 4, 6]:
E2.add_hash(i)

assert round(E1.jaccard(E2), 2) == 4 / 5.0
assert round(E2.jaccard(E1), 2) == 4 / 5.0
# here the union is [1, 2, 3, 4, 5]
# and the intersection is [1, 2, 3, 4] => 4/5.

assert round(E1.jaccard(E2), 2) == round(4 / 5.0, 2)
assert round(E2.jaccard(E1), 2) == round(4 / 5.0, 2)


def test_jaccard_2_difflen(track_abundance):
Expand All @@ -34,8 +38,9 @@ def test_jaccard_2_difflen(track_abundance):
for i in [1, 2, 3, 4]:
E2.add_hash(i)

print(E1.jaccard(E2))
assert round(E1.jaccard(E2), 2) == 4 / 5.0
assert round(E2.jaccard(E1), 2) == 4 / 4.0
assert round(E2.jaccard(E1), 2) == 4 / 5.0


def test_common_1(track_abundance):
Expand Down Expand Up @@ -153,7 +158,7 @@ def test_abund_similarity():
assert round(E1.similarity(E2), 2) == 0.5

assert round(E1.similarity(E1, ignore_abundance=True)) == 1.0
assert round(E1.similarity(E2, ignore_abundance=True), 2) == 1.0
assert round(E1.similarity(E2, ignore_abundance=True), 2) == 0.5


def test_abund_similarity_zero():
Expand All @@ -164,3 +169,37 @@ def test_abund_similarity_zero():
E1.add_hash(i)

assert E1.similarity(E2) == 0.0


####

def test_jaccard_on_real_data():
    """Check compare() on two stored signatures, in both directions.

    The first signature of each test-data file is loaded and its estimator
    compared at the stored size, then again after downsampling both
    estimators to 1000, 100 and 10.
    """
    from sourmash_lib.signature import load_signatures

    def first_estimator(relpath):
        # Resolve the test-data path and take the first signature's estimator.
        sigfile = utils.get_test_data(relpath)
        return list(load_signatures(sigfile))[0].estimator

    afile = 'n10000/GCF_000005845.2_ASM584v2_genomic.fna.gz.sig.gz'
    bfile = 'n10000/GCF_000006945.1_ASM694v1_genomic.fna.gz.sig.gz'
    mh1 = first_estimator(afile)
    mh2 = first_estimator(bfile)

    # Comparison at the size the signatures were stored with.
    assert mh1.compare(mh2) == 0.0183
    assert mh2.compare(mh1) == 0.0183

    # Each step downsamples the result of the previous step.
    for new_num, expected in ((1000, 0.011), (100, 0.01), (10, 0.0)):
        mh1 = mh1.downsample_n(new_num)
        mh2 = mh2.downsample_n(new_num)
        assert mh1.compare(mh2) == expected
        assert mh2.compare(mh1) == expected
19 changes: 14 additions & 5 deletions tests/test_sourmash.py
Expand Up @@ -441,7 +441,7 @@ def test_do_sourmash_check_protein_comparisons():
print(name2, name4, round(sig2_aa.similarity(sig2_trans), 3))

assert round(sig1_aa.similarity(sig1_trans), 3) == 0.0
assert round(sig2_aa.similarity(sig1_trans), 3) == 0.273
assert round(sig2_aa.similarity(sig1_trans), 3) == 0.166
assert round(sig1_aa.similarity(sig2_trans), 3) == 0.174
assert round(sig2_aa.similarity(sig2_trans), 3) == 0.0

Expand Down Expand Up @@ -594,7 +594,7 @@ def test_search():
in_directory=location)
print(status, out, err)
assert '1 matches' in err
assert '0.958' in out
assert '0.930' in out


def test_search_gzip():
Expand All @@ -619,7 +619,7 @@ def test_search_gzip():
in_directory=location)
print(status, out, err)
assert '1 matches' in err
assert '0.958' in out
assert '0.930' in out


def test_search_2():
Expand All @@ -640,7 +640,8 @@ def test_search_2():
in_directory=location)
print(status, out, err)
assert '2 matches' in err
assert '0.958' in out
assert '0.930' in out
assert '0.896' in out


def test_mash_csv_to_sig():
Expand All @@ -654,7 +655,8 @@ def test_mash_csv_to_sig():
in_directory=location)

status, out, err = utils.runscript('sourmash',
['compute', '-k', '31', testdata2],
['compute', '-k', '31',
'-n', '970', testdata2],
in_directory=location)

status, out, err = utils.runscript('sourmash',
Expand All @@ -663,6 +665,7 @@ def test_mash_csv_to_sig():
in_directory=location)
print(status, out, err)
assert '1 matches; showing 3:' in err
assert 'short.fa \t 1.000 \t xxx.sig' in out


def test_do_sourmash_sbt_index_bad_args():
Expand Down Expand Up @@ -1092,6 +1095,12 @@ def test_sbt_categorize():
'--ksize', '21', '--dna', '--csv', 'out.csv']
status, out, err = utils.runscript('sourmash', args,
in_directory=location)

print(out)
print(err)

# mash dist genome-s10.fa.gz genome-s10+s11.fa.gz
# yields 521/1000 ==> ~0.5
assert 'for s10+s11, found: 0.50 genome-s10.fa.gz' in err

out_csv = open(os.path.join(location, 'out.csv')).read()
Expand Down

0 comments on commit 1033479

Please sign in to comment.