Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Count-Min Sketches as internal nodes in SBTs #505

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 13 additions & 13 deletions sourmash_lib/commands.py
Expand Up @@ -466,7 +466,7 @@ def dump(args):


def sbt_combine(args):
from sourmash_lib.sbt import SBT, GraphFactory
from sourmash_lib.sbt import SBT
from sourmash_lib.sbtmh import SigLeaf

parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -495,8 +495,8 @@ def sbt_combine(args):


def index(args):
from sourmash_lib.sbt import SBT, GraphFactory
from sourmash_lib.sbtmh import search_minhashes, SigLeaf
from sourmash_lib.sbt import SBT, GraphFactory, CountgraphFactory
from sourmash_lib.sbtmh import SigLeaf

parser = argparse.ArgumentParser()
parser.add_argument('sbt_name', help='name to save SBT into')
Expand All @@ -511,7 +511,9 @@ def index(args):
parser.add_argument('--append', action='store_true', default=False,
help='add signatures to an existing SBT.')
parser.add_argument('-x', '--bf-size', type=float, default=1e5,
help='Bloom filter size used for internal nodes.')
help='Size used for internal nodes.')
parser.add_argument('-c', '--countgraph', action='store_true',
help='Use a Count-Min Sketch for internal nodes.')

sourmash_args.add_moltype_args(parser)

Expand All @@ -522,15 +524,17 @@ def index(args):
if args.append:
tree = SBT.load(args.sbt_name, leaf_loader=SigLeaf.load)
else:
factory = GraphFactory(1, args.bf_size, 4)
if args.countgraph:
factory = CountgraphFactory(1, args.bf_size, 4)
else:
factory = GraphFactory(1, args.bf_size, 4)
tree = SBT(factory)

if args.traverse_directory:
inp_files = list(sourmash_args.traverse_find_sigs(args.signatures))
else:
inp_files = list(args.signatures)


notify('loading {} files into SBT', len(inp_files))

n = 0
Expand Down Expand Up @@ -567,9 +571,7 @@ def index(args):


def search(args):
from sourmash_lib.sbt import SBT, GraphFactory
from sourmash_lib.sbtmh import search_minhashes, SigLeaf
from sourmash_lib.sbtmh import SearchMinHashesFindBest
from sourmash_lib.sbtmh import search_minhashes, SearchMinHashesFindBest

parser = argparse.ArgumentParser()
parser.add_argument('query', help='query signature')
Expand Down Expand Up @@ -713,7 +715,7 @@ def search(args):


def categorize(args):
from sourmash_lib.sbt import SBT, GraphFactory
from sourmash_lib.sbt import SBT
from sourmash_lib.sbtmh import search_minhashes, SigLeaf
from sourmash_lib.sbtmh import SearchMinHashesFindBest

Expand Down Expand Up @@ -791,8 +793,6 @@ def categorize(args):


def gather(args):
from sourmash_lib.sbt import SBT, GraphFactory
from sourmash_lib.sbtmh import search_minhashes, SigLeaf
from sourmash_lib.sbtmh import SearchMinHashesFindBestIgnoreMaxHash

parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -1036,7 +1036,7 @@ def format_bp(bp):

def watch(args):
"Build a signature from raw FASTA/FASTQ coming in on stdin, search."
from sourmash_lib.sbt import SBT, GraphFactory
from sourmash_lib.sbt import SBT
from sourmash_lib.sbtmh import search_minhashes, SigLeaf
from sourmash_lib.sbtmh import SearchMinHashesFindBest

Expand Down
71 changes: 60 additions & 11 deletions sourmash_lib/sbt.py
Expand Up @@ -70,6 +70,15 @@ def create_nodegraph():
return create_nodegraph


def CountgraphFactory(ksize, starting_size, n_tables):
"Build new countgraphs (Count-Min sketches) of a specific (fixed) size."

def create_graph():
return khmer.Countgraph(ksize, starting_size, n_tables)

return create_graph


class SBT(object):

def __init__(self, factory, d=2):
Expand Down Expand Up @@ -165,7 +174,7 @@ def child(self, parent, pos):
return NodePos(cd, self.nodes[cd])

def save(self, tag):
version = 2
version = 3
basetag = os.path.basename(tag)
dirprefix = os.path.dirname(tag)
dirname = os.path.join(dirprefix, '.sbt.' + basetag)
Expand All @@ -176,6 +185,9 @@ def save(self, tag):
info = {}
info['d'] = self.d
info['version'] = version
info['factory'] = 'nodegraph'
if self.factory.__name__ == 'create_graph':
info['factory'] = 'countgraph'

structure = {}
for i, node in iter(self):
Expand Down Expand Up @@ -210,16 +222,9 @@ def load(cls, sbt_name, leaf_loader=None):
loaders = {
1: cls._load_v1,
2: cls._load_v2,
3: cls._load_v3,
}

# @CTB hack: check to make sure khmer Nodegraph supports the
# correct methods.
x = khmer.Nodegraph(1, 1, 1)
try:
x.count(10)
except TypeError:
raise Exception("khmer version is too old; need >= 2.1.")

if leaf_loader is None:
leaf_loader = Leaf.load

Expand Down Expand Up @@ -295,6 +300,41 @@ def _load_v2(cls, info, leaf_loader, dirname):

return tree

@classmethod
def _load_v3(cls, info, leaf_loader, dirname):
nodes = {int(k): v for (k, v) in info['nodes'].items()}

if nodes[0] is None:
raise ValueError("Empty tree!")

sbt_nodes = defaultdict(lambda: None)

if info['factory'] == 'countgraph':
sample_cg = os.path.join(dirname, nodes[0]['filename'])
k, size, ntables = khmer.extract_countgraph_info(sample_cg)[:3]
factory = CountgraphFactory(k, size, ntables)
else: # Defaults to nodegraph
sample_bf = os.path.join(dirname, nodes[0]['filename'])
k, size, ntables = khmer.extract_nodegraph_info(sample_bf)[:3]
factory = GraphFactory(k, size, ntables)

for k, node in nodes.items():
if node is None:
continue

if 'internal' in node['filename']:
node['factory'] = factory
sbt_node = Node.load(node, dirname)
else:
sbt_node = leaf_loader(node, dirname)

sbt_nodes[k] = sbt_node

tree = cls(factory, d=info['d'])
tree.nodes = sbt_nodes

return tree

def print_dot(self):
print("""
digraph G {
Expand Down Expand Up @@ -397,7 +437,13 @@ def data(self):
if self._filename is None:
self._data = self._factory()
else:
self._data = khmer.load_nodegraph(self._filename)
# TODO: add a khmer.load() method to khmer,
# detecting if the file is from a nodegraph or countgraph
# automatically...
if self._factory.__name__ == 'create_graph':
self._data = khmer.load_countgraph(self._filename)
else:
self._data = khmer.load_nodegraph(self._filename)
return self._data

@data.setter
Expand Down Expand Up @@ -433,7 +479,10 @@ def __str__(self):
def data(self):
if self._data is None:
# TODO: what if self._filename is None?
self._data = khmer.load_nodegraph(self._filename)
if self._factory.__name__ == 'create_graph':
self._data = khmer.load_countgraph(self._filename)
else:
self._data = khmer.load_nodegraph(self._filename)
return self._data

@data.setter
Expand Down
6 changes: 6 additions & 0 deletions sourmash_lib/sbtmh.py
Expand Up @@ -22,7 +22,13 @@ def save(self, filename):
signature.save_signatures([self.data], fp)

def update(self, parent):
# TODO: check if track_abundance is set!
for v in self.data.minhash.get_mins():
# TODO: if tracking abundance, count up to
# max(abundance, current_value_at_internal_node)
# if not tracking abundance,
# behave like a nodegraph (only count if hash not at internal
# node yet)
parent.data.count(v)

@property
Expand Down
30 changes: 30 additions & 0 deletions tests/test_sourmash.py
Expand Up @@ -22,6 +22,7 @@
from sourmash_lib import signature
from sourmash_lib import VERSION


def test_run_sourmash():
status, out, err = utils.runscript('sourmash', [], fail_ok=True)
assert status != 0 # no args provided, ok ;)
Expand All @@ -32,6 +33,7 @@ def test_run_sourmash_badcmd():
assert status != 0 # bad arg!
assert "Unrecognized command" in err


def test_sourmash_info():
status, out, err = utils.runscript('sourmash', ['info'], fail_ok=False)

Expand Down Expand Up @@ -1106,6 +1108,34 @@ def test_do_sourmash_index_bad_args():
assert status != 0


def test_do_sourmash_scmst_search():
with utils.TempDirectory() as location:
testdata1 = utils.get_test_data('short.fa')
testdata2 = utils.get_test_data('short2.fa')
status, out, err = utils.runscript('sourmash',
['compute', testdata1, testdata2],
in_directory=location)

status, out, err = utils.runscript('sourmash',
['index', '-k', '31',
'--countgraph',
'zzz',
'short.fa.sig',
'short2.fa.sig'],
in_directory=location)

assert os.path.exists(os.path.join(location, 'zzz.sbt.json'))

status, out, err = utils.runscript('sourmash',
['search', 'short.fa.sig',
'zzz'],
in_directory=location)
print(out)

assert 'short.fa' in out
assert 'short2.fa' in out


def test_do_sourmash_sbt_search():
with utils.TempDirectory() as location:
testdata1 = utils.get_test_data('short.fa')
Expand Down