Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Add a MultiIndex class that wraps multiple Index classes. #1374

Merged
merged 24 commits into from
Mar 26, 2021
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
92e5fdc
add an IndexOfIndexes class
ctb Mar 6, 2021
5c71e11
rename to MultiIndex
ctb Mar 7, 2021
85efdaf
switch to using MultiIndex for loading from a directory
ctb Mar 7, 2021
04f9de1
some more MultiIndex tests
ctb Mar 7, 2021
201a89a
add test of MultiIndex.signatures
ctb Mar 7, 2021
07d2c32
add docstring for MultiIndex
ctb Mar 7, 2021
61d15c3
stop special-casing SIGLISTs
ctb Mar 7, 2021
16f9ee2
fix test to match more informative error message
ctb Mar 7, 2021
c6bf314
switch to using LinearIndex.load for stdin, too
ctb Mar 7, 2021
dd0f3b8
add __len__ to MultiIndex
ctb Mar 8, 2021
9211a74
add check_csv to check for appropriate filename loading info
ctb Mar 8, 2021
75069ff
add comment
ctb Mar 8, 2021
d2294fb
Merge branch 'latest' of github.com:dib-lab/sourmash into add/multi_i…
ctb Mar 9, 2021
9f39623
fix databases load
ctb Mar 9, 2021
ac63cf8
more tests needed
ctb Mar 9, 2021
d5059eb
Merge branch 'latest' into add/multi_index
ctb Mar 9, 2021
3e06dbf
Merge branch 'latest' of github.com:dib-lab/sourmash into add/multi_i…
ctb Mar 9, 2021
5590d70
add tests for incompatible signatures
ctb Mar 9, 2021
14891bd
add filter to LinearIndex and MultiIndex
ctb Mar 9, 2021
40395ff
clean up sourmash_args some more
ctb Mar 9, 2021
8c51452
Merge branch 'latest' of github.com:dib-lab/sourmash into add/multi_i…
ctb Mar 9, 2021
fbf3bb9
Merge branch 'latest' into add/multi_index
ctb Mar 12, 2021
dd52be6
Merge branch 'latest' of github.com:dib-lab/sourmash into add/multi_i…
ctb Mar 24, 2021
d99828d
Merge branch 'latest' into add/multi_index
ctb Mar 26, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 93 additions & 6 deletions src/sourmash/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def select(self, ksize=None, moltype=None):
""

class LinearIndex(Index):
"An Index for a collection of signatures. Can load from a .sig file."
def __init__(self, _signatures=None, filename=None):
self._signatures = []
if _signatures:
Expand Down Expand Up @@ -155,11 +156,97 @@ def load(cls, location):
return lidx

def select(self, ksize=None, moltype=None):
def select_sigs(siglist, ksize, moltype):
for ss in siglist:
if (ksize is None or ss.minhash.ksize == ksize) and \
(moltype is None or ss.minhash.moltype == moltype):
yield ss
def select_sigs(ss, ksize=ksize, moltype=moltype):
if (ksize is None or ss.minhash.ksize == ksize) and \
(moltype is None or ss.minhash.moltype == moltype):
return True

return self.filter(select_sigs)

def filter(self, filter_fn):
siglist = []
for ss in self._signatures:
if filter_fn(ss):
siglist.append(ss)

siglist=select_sigs(self._signatures, ksize, moltype)
return LinearIndex(siglist, self.filename)


class MultiIndex(Index):
"""An Index class that wraps other Index classes.

The MultiIndex constructor takes two arguments: a list of Index
objects, and a matching list of sources (filenames, etc.) If the
source is not None, then it will be used to override the 'filename'
in the triple that is returned by search and gather.

One specific use for this is when loading signatures from a directory;
MultiIndex will properly record which files provided which signatures.
"""
def __init__(self, index_list, source_list):
self.index_list = list(index_list)
self.source_list = list(source_list)
assert len(index_list) == len(source_list)

def signatures(self):
for idx in self.index_list:
for ss in idx.signatures():
yield ss

def __len__(self):
return sum([ len(idx) for idx in self.index_list ])

def insert(self, *args):
raise NotImplementedError

@classmethod
def load(self, *args):
raise NotImplementedError

def save(self, *args):
raise NotImplementedError

def select(self, ksize=None, moltype=None):
new_idx_list = []
new_src_list = []
for idx, src in zip(self.index_list, self.source_list):
idx = idx.select(ksize=ksize, moltype=moltype)
new_idx_list.append(idx)
new_src_list.append(src)

return MultiIndex(new_idx_list, new_src_list)

def filter(self, filter_fn):
new_idx_list = []
new_src_list = []
for idx, src in zip(self.index_list, self.source_list):
idx = idx.filter(filter_fn)
new_idx_list.append(idx)
new_src_list.append(src)

return MultiIndex(new_idx_list, new_src_list)

def search(self, query, *args, **kwargs):
# do the actual search:
matches = []
for idx, src in zip(self.index_list, self.source_list):
for (score, ss, filename) in idx.search(query, *args, **kwargs):
best_src = src or filename # override if src provided
matches.append((score, ss, best_src))

# sort!
matches.sort(key=lambda x: -x[0])
return matches

def gather(self, query, *args, **kwargs):
"Return the match with the best Jaccard containment in the Index."
# actually do search!
results = []
for idx, src in zip(self.index_list, self.source_list):
for (score, ss, filename) in idx.gather(query, *args, **kwargs):
best_src = src or filename # override if src provided
results.append((score, ss, best_src))

results.sort(reverse=True, key=lambda x: (x[0], x[1].md5sum()))

return results
153 changes: 52 additions & 101 deletions src/sourmash/sourmash_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from . import signature
from .logging import notify, error

from .index import LinearIndex
from .index import LinearIndex, MultiIndex
from . import signature as sig
from .sbt import SBT
from .sbtmh import SigLeaf
Expand Down Expand Up @@ -181,16 +181,7 @@ def traverse_find_sigs(filenames, yield_all_files=False):
yield fullname


def filter_compatible_signatures(query, siglist, force=False):
for ss in siglist:
if check_signatures_are_compatible(query, ss):
yield ss
else:
if not force:
raise ValueError("incompatible signature")


def check_signatures_are_compatible(query, subject):
def _check_signatures_are_compatible(query, subject):
# is one scaled, and the other not? cannot do search
if query.minhash.scaled and not subject.minhash.scaled or \
not query.minhash.scaled and subject.minhash.scaled:
Expand Down Expand Up @@ -275,20 +266,7 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None)
sys.exit(-1)

# are we collecting signatures from a directory/path?
# NOTE: error messages about loading will now be attributed to
# directory, not individual file.
if os.path.isdir(filename):
assert dbtype == DatabaseType.SIGLIST

siglist = _select_sigs(db, moltype=query_moltype, ksize=query_ksize)
siglist = filter_compatible_signatures(query, siglist, True)
linear = LinearIndex(siglist, filename=filename)
databases.append(linear)

n_signatures += len(linear)

# SBT
elif dbtype == DatabaseType.SBT:
if dbtype == DatabaseType.SBT:
if not check_tree_is_compatible(filename, db, query,
is_similarity_query):
sys.exit(-1)
Expand All @@ -301,7 +279,6 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None)
elif dbtype == DatabaseType.LCA:
if not check_lca_db_is_compatible(filename, db, query):
sys.exit(-1)
query_scaled = query.minhash.scaled

notify('loaded LCA {}', filename, end='\r')
n_databases += 1
Expand All @@ -310,26 +287,19 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None)

# signature file
elif dbtype == DatabaseType.SIGLIST:
siglist = _select_sigs(db, moltype=query_moltype, ksize=query_ksize)
try:
# CTB: it's not clear to me that filter_compatible_signatures
# should fail here, on incompatible signatures; but that's
# what we have it doing currently. Revisit.
siglist = filter_compatible_signatures(query, siglist, False)
siglist = list(siglist)
except ValueError:
siglist = []

if not siglist:
notify("no compatible signatures found in '{}'", filename)
db = db.select(moltype=query_moltype, ksize=query_ksize)
siglist = db.signatures()
filter_fn = lambda s: _check_signatures_are_compatible(query, s)
db = db.filter(filter_fn)

if not db:
notify(f"no compatible signatures found in '{filename}'")
sys.exit(-1)

linear = LinearIndex(siglist, filename=filename)
databases.append(linear)
databases.append(db)

notify('loaded {} signatures from {}', len(linear),
filename, end='\r')
n_signatures += len(linear)
notify(f'loaded {len(db)} signatures from {filename}', end='\r')
n_signatures += len(db)

# unknown!?
else:
Expand Down Expand Up @@ -374,56 +344,58 @@ def _load_database(filename, traverse_yield_all, *, cache_size=None):

# special case stdin
if not loaded and filename == '-':
db = signature.load_signatures(sys.stdin, do_raise=True)
db = list(db)
loaded = True
db = LinearIndex.load(sys.stdin)
dbtype = DatabaseType.SIGLIST
loaded = True

# load signatures from directory
# load signatures from directory, using MultiIndex to preserve source.
if not loaded and os.path.isdir(filename):
all_sigs = []
index_list = []
source_list = []
for thisfile in traverse_find_sigs([filename], traverse_yield_all):
try:
with open(thisfile, 'rt') as fp:
x = signature.load_signatures(fp, do_raise=True)
siglist = list(x)
all_sigs.extend(siglist)
idx = LinearIndex.load(thisfile)
index_list.append(idx)
source_list.append(thisfile)
except (IOError, sourmash.exceptions.SourmashError):
if traverse_yield_all:
continue
else:
raise

loaded=True
db = all_sigs
dbtype = DatabaseType.SIGLIST

# load signatures from single file
try:
# CTB: could make this a generator, with some trickery; but for
# now, just force into list.
with open(filename, 'rt') as fp:
db = signature.load_signatures(fp, do_raise=True)
db = list(db)
if index_list:
loaded=True
db = MultiIndex(index_list, source_list)
dbtype = DatabaseType.SIGLIST

loaded = True
dbtype = DatabaseType.SIGLIST
except Exception as exc:
pass
# load signatures from single signature file
if not loaded:
try:
with open(filename, 'rt') as fp:
db = LinearIndex.load(filename)
dbtype = DatabaseType.SIGLIST
loaded = True
except Exception as exc:
pass

# try load signatures from single file (list of signature paths)
# use MultiIndex to preserve source filenames.
if not loaded:
try:
db = []
with open(filename, 'rt') as fp:
for line in fp:
line = line.strip()
if line:
sigs = load_file_as_signatures(line)
db += list(sigs)
idx_list = []
src_list = []

loaded = True
file_list = load_file_list_of_signatures(filename)
for fname in file_list:
idx = load_file_as_index(fname)
src = fname

idx_list.append(idx)
src_list.append(src)

db = MultiIndex(idx_list, src_list)
dbtype = DatabaseType.SIGLIST
loaded = True
except Exception as exc:
pass

Expand Down Expand Up @@ -461,19 +433,11 @@ def _load_database(filename, traverse_yield_all, *, cache_size=None):
raise OSError("Error while reading signatures from '{}' - got sequences instead! Is this a FASTA/FASTQ file?".format(filename))

if not loaded:
raise OSError("Error while reading signatures from '{}'.".format(filename))
raise OSError(f"Error while reading signatures from '{filename}'.")

return db, dbtype


# note: dup from index.py internal function.
def _select_sigs(siglist, ksize, moltype):
for ss in siglist:
if (ksize is None or ss.minhash.ksize == ksize) and \
(moltype is None or ss.minhash.moltype == moltype):
yield ss


def load_file_as_index(filename, yield_all_files=False):
"""Load 'filename' as a database; generic database loader.

Expand All @@ -488,14 +452,7 @@ def load_file_as_index(filename, yield_all_files=False):
attempt to load all files.
"""
db, dbtype = _load_database(filename, yield_all_files)
if dbtype in (DatabaseType.LCA, DatabaseType.SBT):
return db # already an index!
elif dbtype == DatabaseType.SIGLIST:
# turn siglist into a LinearIndex
idx = LinearIndex(db, filename)
return idx
else:
assert 0 # unknown enum!?
return db


def load_file_as_signatures(filename, select_moltype=None, ksize=None,
Expand All @@ -519,21 +476,15 @@ def load_file_as_signatures(filename, select_moltype=None, ksize=None,
progress.notify(filename)

db, dbtype = _load_database(filename, yield_all_files)

loader = None
if dbtype in (DatabaseType.LCA, DatabaseType.SBT):
db = db.select(moltype=select_moltype, ksize=ksize)
loader = db.signatures()
elif dbtype == DatabaseType.SIGLIST:
loader = _select_sigs(db, moltype=select_moltype, ksize=ksize)
else:
assert 0 # unknown enum!?
db = db.select(moltype=select_moltype, ksize=ksize)
loader = db.signatures()

if progress:
return progress.start_file(filename, loader)
else:
return loader


def load_file_list_of_signatures(filename):
"Load a list-of-files text file."
try:
Expand Down
Loading