Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Update sbt_gather output formats #175

Merged
merged 18 commits into from May 16, 2017
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions doc/support.rst
@@ -0,0 +1,9 @@
Support
=======

Please ask questions and file bug reports `on the GitHub issue
tracker for sourmash, dib-lab/sourmash/issues
<https://github.com/dib-lab/sourmash/issues>`__.

You can also ask questions of Titus at `@ctitusbrown
<https://twitter.com/ctitusbrown/>`__ on Twitter.
8 changes: 8 additions & 0 deletions sourmash_lib/__init__.py
Expand Up @@ -15,3 +15,11 @@

DEFAULT_SEED = get_minhash_default_seed()
MAX_HASH = get_minhash_max_hash()

def scaled_to_max_hash(scaled):
    """Convert a 'scaled' downsampling factor into a max_hash value.

    A value of 1 or less yields 0, which here signals "no maximum set";
    otherwise MAX_HASH is divided by the factor and rounded to the
    nearest integer.
    """
    if scaled <= 1:
        return 0
    return int(round(MAX_HASH / float(scaled), 0))
120 changes: 63 additions & 57 deletions sourmash_lib/commands.py
Expand Up @@ -229,8 +229,7 @@ def save_siglist(siglist, output_fp, filename=None):
sig.save_signatures(siglist, fp)

if args.track_abundance:
print('Tracking abundance of input k-mers.',
file=sys.stderr)
notify('Tracking abundance of input k-mers.')

if not args.merge:
if args.output:
Expand All @@ -240,8 +239,7 @@ def save_siglist(siglist, output_fp, filename=None):
sigfile = os.path.basename(filename) + '.sig'
if not args.output and os.path.exists(sigfile) and not \
args.force:
print('skipping', filename, '- already done',
file=sys.stderr)
notify('skipping {} - already done', filename)
continue

if args.singleton:
Expand All @@ -254,20 +252,20 @@ def save_siglist(siglist, output_fp, filename=None):

siglist += build_siglist(args.email, Elist, filename,
name=record.name)
print('calculated {} signatures for {} sequences in {}'.\
format(len(siglist), n + 1, filename), file=sys.stderr)

notify('calculated {} signatures for {} sequences in {}'.\
format(len(siglist), n + 1, filename))
else:
# make minhashes for the whole file
Elist = make_minhashes()

# consume & calculate signatures
print('... reading sequences from', filename,
file=sys.stderr)
notify('... reading sequences from {}', filename)
name = None
for n, record in enumerate(screed.open(filename)):
if n % 10000 == 0:
if n:
print('...', filename, n, file=sys.stderr)
notify('...{} {}', filename, n)
elif args.name_from_first:
name = record.name

Expand All @@ -280,9 +278,9 @@ def save_siglist(siglist, output_fp, filename=None):
siglist += sigs
else:
siglist = sigs
print('calculated {} signatures for {} sequences in {}'.\
format(len(siglist), n + 1, filename),
file=sys.stderr)

notify('calculated {} signatures for {} sequences in {}'.\
format(len(siglist), n + 1, filename))

if not args.output:
save_siglist(siglist, args.output, sigfile)
Expand All @@ -295,18 +293,17 @@ def save_siglist(siglist, output_fp, filename=None):

for filename in args.filenames:
# consume & calculate signatures
print('... reading sequences from', filename,
file=sys.stderr)
notify('... reading sequences from', filename)
for n, record in enumerate(screed.open(filename)):
if n % 10000 == 0 and n:
print('...', filename, n, file=sys.stderr)
notify('...', filename, n)

add_seq(Elist, record.sequence,
args.input_is_protein, args.check_sequence)

siglist = build_siglist(args.email, Elist, filename,
name=args.merge)
print('calculated {} signatures for {} sequences taken from {}'.\
notify('calculated {} signatures for {} sequences taken from {}'.\
format(len(siglist), n + 1, " ".join(args.filenames)))
# at end, save!
save_siglist(siglist, args.output)
Expand All @@ -327,7 +324,7 @@ def compare(args):
# load in the various signatures
siglist = []
for filename in args.signatures:
print('loading', filename, file=sys.stderr)
notify('loading {}', filename)
loaded = sig.load_signatures(filename, select_ksize=args.ksize)
loaded = list(loaded)
if not loaded:
Expand All @@ -349,7 +346,8 @@ def compare(args):
for j, E2 in enumerate(siglist):
D[i][j] = E.similarity(E2, args.ignore_abundance)

print('%d-%20s\t%s' % (i, E.name(), D[i, :, ],))
if len(siglist) < 30:
print('%d-%20s\t%s' % (i, E.name(), D[i, :, ],))
labeltext.append(E.name())

notify('min similarity in matrix: {}', numpy.min(D))
Expand All @@ -368,6 +366,8 @@ def compare(args):

def plot(args):
"Produce a clustering and plot."
import matplotlib as mpl
mpl.use('Agg')
import numpy
import scipy
import pylab
Expand Down Expand Up @@ -697,6 +697,7 @@ def categorize(args):
if loader.skipped_nosig:
notify('skipped/nosig: {}', loader.skipped_nosig)


def sbt_gather(args):
from sourmash_lib.sbt import SBT, GraphFactory
from sourmash_lib.sbtmh import search_minhashes, SigLeaf
Expand All @@ -708,8 +709,8 @@ def sbt_gather(args):
parser.add_argument('-k', '--ksize', type=int, default=DEFAULT_K)
parser.add_argument('--threshold', default=0.05, type=float)
parser.add_argument('-o', '--output', type=argparse.FileType('wt'))
parser.add_argument('--csv', type=argparse.FileType('wt'))
parser.add_argument('--save-matches', type=argparse.FileType('wt'))
parser.add_argument('--threshold-bp', type=float, default=5e4)

sourmash_args.add_moltype_args(parser)

Expand All @@ -734,7 +735,6 @@ def sbt_gather(args):
error('query signature needs to be created with --scaled')
sys.exit(-1)

notify('query signature has max_hash: {}', query.minhash.max_hash)
orig_query = query
orig_mins = orig_query.minhash.get_hashes()

Expand Down Expand Up @@ -767,6 +767,19 @@ def build_new_signature(mins):
e.add_many(mins)
return sig.SourmashSignature('', e)

# helper for human-readable reporting of base-pair counts
def format_bp(bp):
    "Pretty-print a base-pair count using bp / kbp / Mbp / Gbp units."
    bp = float(bp)
    if bp < 500:
        return '{:.0f} bp '.format(bp)
    if bp <= 500e3:   # NOTE(review): '<=' here but '<' elsewhere — preserved as-is
        return '{:.1f} kbp'.format(round(bp / 1e3, 1))
    if bp < 500e6:
        return '{:.1f} Mbp'.format(round(bp / 1e6, 1))
    if bp < 500e9:
        return '{:.1f} Gbp'.format(round(bp / 1e9, 1))
    return '???'

# construct a new query that doesn't have the max_hash attribute set.
new_mins = query.minhash.get_hashes()
query = build_new_signature(new_mins)
Expand Down Expand Up @@ -808,72 +821,65 @@ def build_new_signature(mins):
# calculate intersection:
intersect_mins = query_mins.intersection(found_mins)
intersect_orig_mins = orig_mins.intersection(found_mins)
intersect_bp = R_comparison * len(intersect_orig_mins)
sum_found += len(intersect_mins)

if len(intersect_mins) < 5: # hard cutoff for now
notify('found only {} hashes in common.', len(intersect_mins))
notify('this is below a sane threshold => exiting.')
if intersect_bp < args.threshold_bp: # hard cutoff for now
notify('found less than {} in common. => exiting',
format_bp(intersect_bp))
break

# calculate fractions wrt first denominator - genome size
genome_n_mins = len(found_mins)
f_genome = len(intersect_mins) / float(genome_n_mins)
f_orig_query = len(intersect_orig_mins) / float(genome_n_mins)
f_orig_query = len(intersect_orig_mins) / float(len(orig_mins))

# calculate fractions wrt second denominator - metagenome size
query_n_mins = len(orig_query.minhash.get_hashes())
f_query = len(intersect_mins) / float(query_n_mins)

if not len(found): # first result? print header.
notify("")
notify("overlap p_query p_match ")
notify("--------- ------- --------")

# print interim result & save in a list for later use
notify('found: {:.2f} {:.2f} {}', f_genome, f_query,
best_leaf.name())
found.append((f_genome, best_leaf, f_query))
pct_query = '{:.1f}%'.format(f_orig_query*100)
pct_genome = '{:.1f}%'.format(f_genome*100)

notify('{:9} {:>6} {:>6} {}',
format_bp(intersect_bp), pct_query, pct_genome,
best_leaf.name()[:40])
found.append((intersect_bp, f_orig_query, best_leaf, f_genome))

# construct a new query, minus the previous one.
query_mins -= set(found_mins)
query = build_new_signature(query_mins)

# basic reporting
notify('found {} matches total', len(found))
notify('the recovered matches hit {:.1f}% of the query',
100. * sum_found / len(orig_query.minhash.get_hashes()))
notify('\nfound {} matches total;', len(found))

sum_found /= len(orig_query.minhash.get_hashes())
notify('the recovered matches hit {:.1f}% of the query', sum_found * 100)
notify('')

if not found:
sys.exit(0)

# sort by fraction of genome (first key) - change this?
found.sort(key=lambda x: x[0])
found.reverse()

notify('Composition:')
for (frac, leaf_sketch, genome_fraction) in found:
notify('{:.2f} {:.2f} {}', frac, genome_fraction, leaf_sketch.name())

if args.output:
print('Composition:', file=args.output)
for (frac, leaf_sketch, genome_fraction) in found:
print('{:.2f} {:.2f} {}'.format(frac, genome_fraction,
leaf_sketch.name()),
file=args.output)

if args.csv:
fieldnames = ['fraction', 'name', 'sketch_kmers', 'genome_fraction']
w = csv.DictWriter(args.csv, fieldnames=fieldnames)

fieldnames = ['intersect_bp', 'f_orig_query', 'f_found_genome', 'name']
w = csv.DictWriter(args.output, fieldnames=fieldnames)
w.writeheader()
for (frac, leaf_sketch, genome_fraction) in found:
cardinality = 0
if leaf_sketch.minhash.hll:
cardinality = leaf_sketch.minhash.hll.estimate_cardinality()
w.writerow(dict(fraction=frac, name=leaf_sketch.name(),
sketch_kmers=cardinality,
genome_fraction=genome_fraction))
for (intersect_bp, f_genome, leaf, f_orig_query) in found:
w.writerow(dict(intersect_bp=intersect_bp,
f_orig_query=f_orig_query, name=leaf.name(),
f_found_genome=f_genome,))

if args.save_matches:
outname = args.save_matches.name
notify('saving all matches to "{}"', outname)
sig.save_signatures([ ss for (f, ss) in found ],
args.save_matches)
sig.save_signatures([ ss for (_, _, ss, _) in found ],
args.save_matches)


def watch(args):
Expand Down
47 changes: 41 additions & 6 deletions tests/test_sourmash.py
Expand Up @@ -1049,14 +1049,14 @@ def test_sbt_gather():

status, out, err = utils.runscript('sourmash',
['sbt_gather', 'zzz',
'query.fa.sig', '--csv',
'foo.csv'],
'query.fa.sig', '-o',
'foo.csv', '--threshold-bp=1'],
in_directory=location)

print(out)
print(err)

assert 'found: 1.00 1.00 ' in err
assert '0.9 kbp 100.0% 100.0%' in err


def test_sbt_gather_file_output():
Expand Down Expand Up @@ -1085,13 +1085,17 @@ def test_sbt_gather_file_output():
status, out, err = utils.runscript('sourmash',
['sbt_gather', 'zzz',
'query.fa.sig',
'--threshold-bp=500',
'-o', 'foo.out'],
in_directory=location)

print(out)
print(err)
assert '0.9 kbp 100.0% 100.0%' in err
with open(os.path.join(location, 'foo.out')) as f:
output = f.read()
print(output)
assert '1.00 1.00 ' in output
print((output,))
assert '910.0,1.0,1.0' in output


def test_sbt_gather_metagenome():
Expand All @@ -1117,8 +1121,39 @@ def test_sbt_gather_metagenome():
print(out)
print(err)

assert 'found 11 matches total' in err
assert 'found 12 matches total' in err
assert 'the recovered matches hit 100.0% of the query' in err
assert '4.9 Mbp 33.2% 100.0% NC_003198.1 Salmonella enterica subsp.' in err
assert '4.7 Mbp 32.1% 1.5% NC_011294.1 Salmonella enterica subsp' in err


def test_sbt_gather_save_matches():
    # Build an SBT index over all GCF gather signatures, run sbt_gather on
    # the combined query with --save-matches, and verify both the reported
    # summary and that the saved-signatures file was written.
    with utils.TempDirectory() as location:
        sig_pattern = utils.get_test_data('gather/GCF*.sig')
        all_sigs = glob.glob(sig_pattern)

        combined_query = utils.get_test_data('gather/combined.sig')

        index_cmd = ['sbt_index', 'gcf_all', '-k', '21'] + all_sigs

        status, out, err = utils.runscript('sourmash', index_cmd,
                                           in_directory=location)

        assert os.path.exists(os.path.join(location, 'gcf_all.sbt.json'))

        gather_cmd = ['sbt_gather', 'gcf_all', combined_query, '-k', '21',
                      '--save-matches', 'save.sigs']
        status, out, err = utils.runscript('sourmash', gather_cmd,
                                           in_directory=location)

        print(out)
        print(err)

        assert 'found 12 matches total' in err
        assert 'the recovered matches hit 100.0% of the query' in err
        assert os.path.exists(os.path.join(location, 'save.sigs'))


def test_sbt_gather_error_no_cardinality_query():
Expand Down