Skip to content

Commit

Permalink
Added option (-i/--input-fasta-genome-name) to specify genome names along with path to fasta
Browse files: browse the repository at this point in the history
  • Loading branch information
peterk87 committed Feb 21, 2017
1 parent 6f44cc3 commit e4babd6
Showing 1 changed file with 27 additions and 6 deletions.
33 changes: 27 additions & 6 deletions sistr/sistr_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,13 @@ def init_parser():

parser.add_argument('fastas',
metavar='F',
nargs='+',
nargs='*',
help='Input genome FASTA file')
parser.add_argument('-i',
'--input-fasta-genome-name',
nargs=2,
action='append',
help='fasta file path to genome name pair')
parser.add_argument('-f',
'--output-format',
default='json',
Expand Down Expand Up @@ -154,11 +159,12 @@ def infer_o_antigen(prediction):
prediction.o_antigen = counter_o_antigens.most_common(1)[0][0]


def sistr_predict(input_fasta, tmp_dir, keep_tmp, args):
def sistr_predict(input_fasta, genome_name, tmp_dir, keep_tmp, args):
blast_runner = None
try:
assert os.path.exists(input_fasta), "Input fasta file '%s' must exist!" % input_fasta
genome_name = genome_name_from_fasta_path(input_fasta)
if genome_name is None or genome_name == '':
genome_name = genome_name_from_fasta_path(input_fasta)
dtnow = datetime.now()
genome_tmp_dir = os.path.join(tmp_dir, dtnow.strftime("%Y%m%d%H%M%S") + '-' + 'SISTR' + '-' + genome_name)
blast_runner = BlastRunner(input_fasta, genome_tmp_dir)
Expand Down Expand Up @@ -278,8 +284,23 @@ def main():
args = parser.parse_args()
init_console_logger(args.verbose)
input_fastas = args.fastas
if len(input_fastas) == 0:
paths_names = args.input_fasta_genome_name
if len(input_fastas) == 0 and paths_names and len(paths_names) == 0:
raise Exception('No FASTA files specified!')
if paths_names is None:
genome_names = [genome_name_from_fasta_path(x) for x in input_fastas]
else:
if len(input_fastas) == 0 and len(paths_names) > 0:
input_fastas = [x for x,y in paths_names]
genome_names = [y for x,y in paths_names]
elif len(input_fastas) > 0 and len(paths_names) > 0:
tmp = input_fastas
input_fastas = [x for x,y in paths_names] + tmp
genome_names = [y for x,y in paths_names] + [genome_name_from_fasta_path(x) for x in tmp]
else:
raise Exception('Unhandled fasta input args: input_fastas="{}" | input_fasta_genome_name="{}"'.format(
input_fastas,
paths_names))

tmp_dir = args.tmp_dir
keep_tmp = args.keep_tmp
Expand All @@ -291,12 +312,12 @@ def main():

if n_threads == 1:
logging.info('Serial single threaded run mode on %s genomes', len(input_fastas))
outputs = [sistr_predict(input_fasta, tmp_dir, keep_tmp, args) for input_fasta in input_fastas]
outputs = [sistr_predict(input_fasta, genome_name, tmp_dir, keep_tmp, args) for input_fasta, genome_name in zip(input_fastas, genome_names)]
else:
logging.info('Initializing thread pool with %s threads', n_threads)
pool = Pool(processes=n_threads)
logging.info('Running SISTR analysis asynchronously on %s genomes', len(input_fastas))
res = [pool.apply_async(sistr_predict, (input_fasta, tmp_dir, keep_tmp, args)) for input_fasta in input_fastas]
res = [pool.apply_async(sistr_predict, (input_fasta, genome_name, tmp_dir, keep_tmp, args)) for input_fasta, genome_name in zip(input_fastas, genome_names)]

logging.info('Getting SISTR analysis results')
outputs = [x.get() for x in res]
Expand Down

0 comments on commit e4babd6

Please sign in to comment.