Skip to content

Commit

Permalink
[MRG] Add a MultiIndex class that wraps multiple Index classes. (#1374
Browse files Browse the repository at this point in the history
)

* add an IndexOfIndexes class

* rename to MultiIndex

* switch to using MultiIndex for loading from a directory

* some more MultiIndex tests

* add test of MultiIndex.signatures

* add docstring for MultiIndex

* stop special-casing SIGLISTs

* fix test to match more informative error message

* switch to using LinearIndex.load for stdin, too

* add __len__ to MultiIndex

* add check_csv to check for appropriate filename loading info

* add comment

* fix databases load

* more tests needed

* add tests for incompatible signatures

* add filter to LinearIndex and MultiIndex

* clean up sourmash_args some more
  • Loading branch information
ctb committed Mar 26, 2021
1 parent 202bdc5 commit 7ed6291
Show file tree
Hide file tree
Showing 4 changed files with 329 additions and 109 deletions.
99 changes: 93 additions & 6 deletions src/sourmash/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def select(self, ksize=None, moltype=None):
""

class LinearIndex(Index):
"An Index for a collection of signatures. Can load from a .sig file."
def __init__(self, _signatures=None, filename=None):
self._signatures = []
if _signatures:
Expand Down Expand Up @@ -155,11 +156,97 @@ def load(cls, location):
return lidx

def select(self, ksize=None, moltype=None):
def select_sigs(siglist, ksize, moltype):
for ss in siglist:
if (ksize is None or ss.minhash.ksize == ksize) and \
(moltype is None or ss.minhash.moltype == moltype):
yield ss
def select_sigs(ss, ksize=ksize, moltype=moltype):
if (ksize is None or ss.minhash.ksize == ksize) and \
(moltype is None or ss.minhash.moltype == moltype):
return True

return self.filter(select_sigs)

def filter(self, filter_fn):
siglist = []
for ss in self._signatures:
if filter_fn(ss):
siglist.append(ss)

siglist=select_sigs(self._signatures, ksize, moltype)
return LinearIndex(siglist, self.filename)


class MultiIndex(Index):
"""An Index class that wraps other Index classes.
The MultiIndex constructor takes two arguments: a list of Index
objects, and a matching list of sources (filenames, etc.) If the
source is not None, then it will be used to override the 'filename'
in the triple that is returned by search and gather.
One specific use for this is when loading signatures from a directory;
MultiIndex will properly record which files provided which signatures.
"""
def __init__(self, index_list, source_list):
self.index_list = list(index_list)
self.source_list = list(source_list)
assert len(index_list) == len(source_list)

def signatures(self):
for idx in self.index_list:
for ss in idx.signatures():
yield ss

def __len__(self):
return sum([ len(idx) for idx in self.index_list ])

def insert(self, *args):
raise NotImplementedError

@classmethod
def load(self, *args):
raise NotImplementedError

def save(self, *args):
raise NotImplementedError

def select(self, ksize=None, moltype=None):
new_idx_list = []
new_src_list = []
for idx, src in zip(self.index_list, self.source_list):
idx = idx.select(ksize=ksize, moltype=moltype)
new_idx_list.append(idx)
new_src_list.append(src)

return MultiIndex(new_idx_list, new_src_list)

def filter(self, filter_fn):
new_idx_list = []
new_src_list = []
for idx, src in zip(self.index_list, self.source_list):
idx = idx.filter(filter_fn)
new_idx_list.append(idx)
new_src_list.append(src)

return MultiIndex(new_idx_list, new_src_list)

def search(self, query, *args, **kwargs):
# do the actual search:
matches = []
for idx, src in zip(self.index_list, self.source_list):
for (score, ss, filename) in idx.search(query, *args, **kwargs):
best_src = src or filename # override if src provided
matches.append((score, ss, best_src))

# sort!
matches.sort(key=lambda x: -x[0])
return matches

def gather(self, query, *args, **kwargs):
"Return the match with the best Jaccard containment in the Index."
# actually do search!
results = []
for idx, src in zip(self.index_list, self.source_list):
for (score, ss, filename) in idx.gather(query, *args, **kwargs):
best_src = src or filename # override if src provided
results.append((score, ss, best_src))

results.sort(reverse=True, key=lambda x: (x[0], x[1].md5sum()))

return results
153 changes: 52 additions & 101 deletions src/sourmash/sourmash_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from . import signature
from .logging import notify, error

from .index import LinearIndex
from .index import LinearIndex, MultiIndex
from . import signature as sig
from .sbt import SBT
from .sbtmh import SigLeaf
Expand Down Expand Up @@ -181,16 +181,7 @@ def traverse_find_sigs(filenames, yield_all_files=False):
yield fullname


def filter_compatible_signatures(query, siglist, force=False):
for ss in siglist:
if check_signatures_are_compatible(query, ss):
yield ss
else:
if not force:
raise ValueError("incompatible signature")


def check_signatures_are_compatible(query, subject):
def _check_signatures_are_compatible(query, subject):
# is one scaled, and the other not? cannot do search
if query.minhash.scaled and not subject.minhash.scaled or \
not query.minhash.scaled and subject.minhash.scaled:
Expand Down Expand Up @@ -275,20 +266,7 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None)
sys.exit(-1)

# are we collecting signatures from a directory/path?
# NOTE: error messages about loading will now be attributed to
# directory, not individual file.
if os.path.isdir(filename):
assert dbtype == DatabaseType.SIGLIST

siglist = _select_sigs(db, moltype=query_moltype, ksize=query_ksize)
siglist = filter_compatible_signatures(query, siglist, True)
linear = LinearIndex(siglist, filename=filename)
databases.append(linear)

n_signatures += len(linear)

# SBT
elif dbtype == DatabaseType.SBT:
if dbtype == DatabaseType.SBT:
if not check_tree_is_compatible(filename, db, query,
is_similarity_query):
sys.exit(-1)
Expand All @@ -301,7 +279,6 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None)
elif dbtype == DatabaseType.LCA:
if not check_lca_db_is_compatible(filename, db, query):
sys.exit(-1)
query_scaled = query.minhash.scaled

notify('loaded LCA {}', filename, end='\r')
n_databases += 1
Expand All @@ -310,26 +287,19 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None)

# signature file
elif dbtype == DatabaseType.SIGLIST:
siglist = _select_sigs(db, moltype=query_moltype, ksize=query_ksize)
try:
# CTB: it's not clear to me that filter_compatible_signatures
# should fail here, on incompatible signatures; but that's
# what we have it doing currently. Revisit.
siglist = filter_compatible_signatures(query, siglist, False)
siglist = list(siglist)
except ValueError:
siglist = []

if not siglist:
notify("no compatible signatures found in '{}'", filename)
db = db.select(moltype=query_moltype, ksize=query_ksize)
siglist = db.signatures()
filter_fn = lambda s: _check_signatures_are_compatible(query, s)
db = db.filter(filter_fn)

if not db:
notify(f"no compatible signatures found in '{filename}'")
sys.exit(-1)

linear = LinearIndex(siglist, filename=filename)
databases.append(linear)
databases.append(db)

notify('loaded {} signatures from {}', len(linear),
filename, end='\r')
n_signatures += len(linear)
notify(f'loaded {len(db)} signatures from {filename}', end='\r')
n_signatures += len(db)

# unknown!?
else:
Expand Down Expand Up @@ -374,56 +344,58 @@ def _load_database(filename, traverse_yield_all, *, cache_size=None):

# special case stdin
if not loaded and filename == '-':
db = signature.load_signatures(sys.stdin, do_raise=True)
db = list(db)
loaded = True
db = LinearIndex.load(sys.stdin)
dbtype = DatabaseType.SIGLIST
loaded = True

# load signatures from directory
# load signatures from directory, using MultiIndex to preserve source.
if not loaded and os.path.isdir(filename):
all_sigs = []
index_list = []
source_list = []
for thisfile in traverse_find_sigs([filename], traverse_yield_all):
try:
with open(thisfile, 'rt') as fp:
x = signature.load_signatures(fp, do_raise=True)
siglist = list(x)
all_sigs.extend(siglist)
idx = LinearIndex.load(thisfile)
index_list.append(idx)
source_list.append(thisfile)
except (IOError, sourmash.exceptions.SourmashError):
if traverse_yield_all:
continue
else:
raise

loaded=True
db = all_sigs
dbtype = DatabaseType.SIGLIST

# load signatures from single file
try:
# CTB: could make this a generator, with some trickery; but for
# now, just force into list.
with open(filename, 'rt') as fp:
db = signature.load_signatures(fp, do_raise=True)
db = list(db)
if index_list:
loaded=True
db = MultiIndex(index_list, source_list)
dbtype = DatabaseType.SIGLIST

loaded = True
dbtype = DatabaseType.SIGLIST
except Exception as exc:
pass
# load signatures from single signature file
if not loaded:
try:
with open(filename, 'rt') as fp:
db = LinearIndex.load(filename)
dbtype = DatabaseType.SIGLIST
loaded = True
except Exception as exc:
pass

# try load signatures from single file (list of signature paths)
# use MultiIndex to preserve source filenames.
if not loaded:
try:
db = []
with open(filename, 'rt') as fp:
for line in fp:
line = line.strip()
if line:
sigs = load_file_as_signatures(line)
db += list(sigs)
idx_list = []
src_list = []

loaded = True
file_list = load_file_list_of_signatures(filename)
for fname in file_list:
idx = load_file_as_index(fname)
src = fname

idx_list.append(idx)
src_list.append(src)

db = MultiIndex(idx_list, src_list)
dbtype = DatabaseType.SIGLIST
loaded = True
except Exception as exc:
pass

Expand Down Expand Up @@ -461,19 +433,11 @@ def _load_database(filename, traverse_yield_all, *, cache_size=None):
raise OSError("Error while reading signatures from '{}' - got sequences instead! Is this a FASTA/FASTQ file?".format(filename))

if not loaded:
raise OSError("Error while reading signatures from '{}'.".format(filename))
raise OSError(f"Error while reading signatures from '{filename}'.")

return db, dbtype


# note: dup from index.py internal function.
def _select_sigs(siglist, ksize, moltype):
for ss in siglist:
if (ksize is None or ss.minhash.ksize == ksize) and \
(moltype is None or ss.minhash.moltype == moltype):
yield ss


def load_file_as_index(filename, yield_all_files=False):
"""Load 'filename' as a database; generic database loader.
Expand All @@ -488,14 +452,7 @@ def load_file_as_index(filename, yield_all_files=False):
attempt to load all files.
"""
db, dbtype = _load_database(filename, yield_all_files)
if dbtype in (DatabaseType.LCA, DatabaseType.SBT):
return db # already an index!
elif dbtype == DatabaseType.SIGLIST:
# turn siglist into a LinearIndex
idx = LinearIndex(db, filename)
return idx
else:
assert 0 # unknown enum!?
return db


def load_file_as_signatures(filename, select_moltype=None, ksize=None,
Expand All @@ -519,21 +476,15 @@ def load_file_as_signatures(filename, select_moltype=None, ksize=None,
progress.notify(filename)

db, dbtype = _load_database(filename, yield_all_files)

loader = None
if dbtype in (DatabaseType.LCA, DatabaseType.SBT):
db = db.select(moltype=select_moltype, ksize=ksize)
loader = db.signatures()
elif dbtype == DatabaseType.SIGLIST:
loader = _select_sigs(db, moltype=select_moltype, ksize=ksize)
else:
assert 0 # unknown enum!?
db = db.select(moltype=select_moltype, ksize=ksize)
loader = db.signatures()

if progress:
return progress.start_file(filename, loader)
else:
return loader


def load_file_list_of_signatures(filename):
"Load a list-of-files text file."
try:
Expand Down
Loading

0 comments on commit 7ed6291

Please sign in to comment.