From 57e6eba1300ef0d56cff495a616ee1ed20aa4184 Mon Sep 17 00:00:00 2001 From: Daniel McDonald Date: Mon, 16 Nov 2020 13:47:17 -0800 Subject: [PATCH] BUG: remove memory hog from storing barcode stats in memory (#122) --- q2_demux/_demux.py | 70 ++++++++++++++++++------------------ q2_demux/_format.py | 15 ++++---- q2_demux/tests/test_demux.py | 64 ++++++++++++++++++--------------- 3 files changed, 77 insertions(+), 72 deletions(-) diff --git a/q2_demux/_demux.py b/q2_demux/_demux.py index 529d02d..6fcb40d 100644 --- a/q2_demux/_demux.py +++ b/q2_demux/_demux.py @@ -13,7 +13,6 @@ import collections.abc import random import resource -import pandas as pd import skbio import psutil @@ -24,11 +23,32 @@ SingleLanePerSamplePairedEndFastqDirFmt, FastqManifestFormat, YamlFormat) from ._ecc import GolayDecoder +from ._format import ErrorCorrectionDetailsFmt FastqHeader = collections.namedtuple('FastqHeader', ['id', 'description']) +class ECDetails: + COLUMNS = ['id', + 'sample', + 'barcode-sequence-id', + 'barcode-uncorrected', + 'barcode-corrected', + 'barcode-errors'] + + def __init__(self, fmt): + self._fp = open(str(fmt), 'w') + self._write_header() + + def write(self, parts): + self._fp.write('\t'.join([str(part) for part in parts])) + self._fp.write('\n') + + def _write_header(self): + self.write(self.COLUMNS) + + def _read_fastq_seqs(filepath): # This function is adapted from @jairideout's SO post: # http://stackoverflow.com/a/39302117/3424666 @@ -254,7 +274,7 @@ def emp_single(seqs: BarcodeSequenceFastqIterator, rev_comp_mapping_barcodes: bool = False, ignore_description_mismatch: bool = False ) -> (SingleLanePerSampleSingleEndFastqDirFmt, - pd.DataFrame): + ErrorCorrectionDetailsFmt): seqs.ignore_description_mismatch = ignore_description_mismatch result = SingleLanePerSampleSingleEndFastqDirFmt() barcode_map, barcode_len = _make_barcode_map( @@ -271,7 +291,9 @@ def emp_single(seqs: BarcodeSequenceFastqIterator, manifest_fh.write('# joined reads\n') per_sample_fastqs = {} - ec_details = [] + + ec_details_fmt = ErrorCorrectionDetailsFmt() + ec_details = ECDetails(ec_details_fmt) for i, (barcode_record, sequence_record) in enumerate(seqs, start=1): barcode_read = barcode_record[1] @@ -296,12 +318,12 @@ def emp_single(seqs: BarcodeSequenceFastqIterator, sample_id = barcode_map.get(barcode_read) record = [ - i, + f'record-{i}', sample_id, barcode_record[0], raw_barcode_read, ] - ec_details.append(record + golay_stats) + ec_details.write(record + golay_stats) if sample_id is None: continue @@ -329,7 +351,6 @@ def emp_single(seqs: BarcodeSequenceFastqIterator, fastq_lines = '\n'.join(sequence_record) + '\n' fastq_lines = fastq_lines.encode('utf-8') per_sample_fastqs[sample_id].write(fastq_lines) - barcode_count = str(i) # last value here should be our largest record no. if len(per_sample_fastqs) == 0: raise ValueError('No sequences were mapped to samples. Check that ' @@ -347,18 +368,7 @@ def emp_single(seqs: BarcodeSequenceFastqIterator, _write_metadata_yaml(result) - columns = ['id', - 'sample', - 'barcode-sequence-id', - 'barcode-uncorrected', - 'barcode-corrected', - 'barcode-errors'] - details = pd.DataFrame(ec_details, columns=columns) - details['id'] = details['id'].apply(lambda x: 'record-%s' % - str(x).zfill(len(barcode_count))) - details = details.set_index('id') - - return result, details + return result, ec_details_fmt def emp_paired(seqs: BarcodePairedSequenceFastqIterator, @@ -368,7 +378,7 @@ def emp_paired(seqs: BarcodePairedSequenceFastqIterator, rev_comp_mapping_barcodes: bool = False, ignore_description_mismatch: bool = False ) -> (SingleLanePerSamplePairedEndFastqDirFmt, - pd.DataFrame): + ErrorCorrectionDetailsFmt): seqs.ignore_description_mismatch = ignore_description_mismatch result = SingleLanePerSamplePairedEndFastqDirFmt() barcode_map, barcode_len = _make_barcode_map( @@ -382,7 +392,9 @@ def emp_paired(seqs: BarcodePairedSequenceFastqIterator, manifest_fh.write('sample-id,filename,direction\n') per_sample_fastqs = {} - ec_details = [] + + ec_details_fmt = ErrorCorrectionDetailsFmt() + ec_details = ECDetails(ec_details_fmt) for i, record in enumerate(seqs, start=1): barcode_record, forward_record, reverse_record = record @@ -408,12 +420,12 @@ def emp_paired(seqs: BarcodePairedSequenceFastqIterator, sample_id = barcode_map.get(barcode_read) record = [ - i, + f'record-{i}', sample_id, barcode_record[0], raw_barcode_read, ] - ec_details.append(record + golay_stats) + ec_details.write(record + golay_stats) if sample_id is None: continue @@ -450,7 +462,6 @@ def emp_paired(seqs: BarcodePairedSequenceFastqIterator, fwd, rev = per_sample_fastqs[sample_id] fwd.write(('\n'.join(forward_record) + '\n').encode('utf-8')) rev.write(('\n'.join(reverse_record) + '\n').encode('utf-8')) - barcode_count = str(i) # last value here should be our largest record no. if len(per_sample_fastqs) == 0: raise ValueError('No sequences were mapped to samples. Check that ' @@ -469,15 +480,4 @@ def emp_paired(seqs: BarcodePairedSequenceFastqIterator, _write_metadata_yaml(result) - columns = ['id', - 'sample', - 'barcode-sequence-id', - 'barcode-uncorrected', - 'barcode-corrected', - 'barcode-errors'] - details = pd.DataFrame(ec_details, columns=columns) - details['id'] = details['id'].apply(lambda x: 'record-%s' % - str(x).zfill(len(barcode_count))) - details = details.set_index('id') - - return result, details + return result, ec_details_fmt diff --git a/q2_demux/_format.py b/q2_demux/_format.py index 744e5bb..1915e66 100644 --- a/q2_demux/_format.py +++ b/q2_demux/_format.py @@ -9,7 +9,6 @@ from q2_types.per_sample_sequences import FastqGzFormat import qiime2.plugin.model as model from qiime2.plugin import ValidationError -import qiime2 # TODO: deprecate this and alias it @@ -71,16 +70,14 @@ class ErrorCorrectionDetailsFmt(model.TextFileFormat): } def _validate_(self, level): - try: - md = qiime2.Metadata.load(str(self)) - except qiime2.metadata.MetadataFileError as md_exc: - raise ValidationError(md_exc) from md_exc + line = open(str(self)).readline() + if len(line.strip()) == 0: + raise ValidationError("Failed to locate header.") + header = set(line.strip().split('\t')) for column in sorted(self.METADATA_COLUMNS): - try: - md.get_column(column) - except ValueError as md_exc: - raise ValidationError(md_exc) from md_exc + if column not in header: + raise ValidationError(f"{column} is not a column") ErrorCorrectionDetailsDirFmt = model.SingleFileDirectoryFormat( diff --git a/q2_demux/tests/test_demux.py b/q2_demux/tests/test_demux.py index dcffc11..8edcaa7 100644 --- a/q2_demux/tests/test_demux.py +++ b/q2_demux/tests/test_demux.py @@ -310,23 +310,24 @@ def test_valid_ecc_no_golay(self): _, ecc = emp_single(self.bsi, self.barcode_map, golay_error_correction=False) exp_errors = pd.DataFrame([ - ['sample1', '@s1/2 abc/2', 'AAAA', None, None], - ['sample3', '@s2/2 abc/2', 'TTAA', None, None], - ['sample2', '@s3/2 abc/2', 'AACC', None, None], - ['sample3', '@s4/2 abc/2', 'TTAA', None, None], - ['sample2', '@s5/2 abc/2', 'AACC', None, None], - ['sample1', '@s6/2 abc/2', 'AAAA', None, None], - ['sample5', '@s7/2 abc/2', 'CGGC', None, None], - ['sample4', '@s8/2 abc/2', 'GGAA', None, None], - ['sample5', '@s9/2 abc/2', 'CGGC', None, None], - ['sample5', '@s10/2 abc/2', 'CGGC', None, None], - ['sample4', '@s11/2 abc/2', 'GGAA', None, None] + ['sample1', '@s1/2 abc/2', 'AAAA', 'None', 'None'], + ['sample3', '@s2/2 abc/2', 'TTAA', 'None', 'None'], + ['sample2', '@s3/2 abc/2', 'AACC', 'None', 'None'], + ['sample3', '@s4/2 abc/2', 'TTAA', 'None', 'None'], + ['sample2', '@s5/2 abc/2', 'AACC', 'None', 'None'], + ['sample1', '@s6/2 abc/2', 'AAAA', 'None', 'None'], + ['sample5', '@s7/2 abc/2', 'CGGC', 'None', 'None'], + ['sample4', '@s8/2 abc/2', 'GGAA', 'None', 'None'], + ['sample5', '@s9/2 abc/2', 'CGGC', 'None', 'None'], + ['sample5', '@s10/2 abc/2', 'CGGC', 'None', 'None'], + ['sample4', '@s11/2 abc/2', 'GGAA', 'None', 'None'] ], columns=['sample', 'barcode-sequence-id', 'barcode-uncorrected', 'barcode-corrected', 'barcode-errors'], - index=pd.Index(['record-%02d' % i for i in range(1, 12)], + index=pd.Index(['record-%d' % i for i in range(1, 12)], name='id')) + ecc = qiime2.Metadata.load(str(ecc)).to_dataframe() pdt.assert_frame_equal(ecc, exp_errors) def test_valid_with_barcode_errors(self): @@ -374,7 +375,7 @@ def test_valid_with_barcode_errors(self): ['sample2', '@s5/2 abc/2', 'ACACACTATGGC', 'ACACACTATGGC', 0], ['sample1', '@s6/2 abc/2', 'ACGATGCGACCA', 'ACGATGCGACCA', 0], ['sample5', '@s7/2 abc/2', 'CATTGTATCAAC', 'CATCGTATCAAC', 1], - [None, '@s8/2 abc/2', 'CTAACGCAGGGG', None, 4], + ['None', '@s8/2 abc/2', 'CTAACGCAGGGG', 'None', 4], ['sample5', '@s9/2 abc/2', 'CATCGTATCAAC', 'CATCGTATCAAC', 0], ['sample5', '@s10/2 abc/2', 'CATCGTATCAAC', 'CATCGTATCAAC', 0], ['sample4', '@s11/2 abc/2', 'CTAACGCAGTCA', 'CTAACGCAGTCA', 0] @@ -382,8 +383,11 @@ def test_valid_with_barcode_errors(self): columns=['sample', 'barcode-sequence-id', 'barcode-uncorrected', 'barcode-corrected', 'barcode-errors'], - index=pd.Index(['record-%02d' % i for i in range(1, 12)], + index=pd.Index(['record-%d' % i for i in range(1, 12)], name='id')) + exp_errors['barcode-errors'] = \ + exp_errors['barcode-errors'].astype(float) + error_detail = qiime2.Metadata.load(str(error_detail)).to_dataframe() pdt.assert_frame_equal(error_detail, exp_errors) @mock.patch('q2_demux._demux.OPEN_FH_LIMIT', 3) @@ -722,7 +726,7 @@ def check_valid(self, *args, **kwargs): ['sample2', '@s5/2 abc/2', 'ACACACTATGGC', 'ACACACTATGGC', 0], ['sample1', '@s6/2 abc/2', 'ACGATGCGACCA', 'ACGATGCGACCA', 0], ['sample5', '@s7/2 abc/2', 'CATTGTATCAAC', 'CATCGTATCAAC', 1], - [None, '@s8/2 abc/2', 'CTAACGCAGGGG', None, 4], + ['None', '@s8/2 abc/2', 'CTAACGCAGGGG', 'None', 4], ['sample5', '@s9/2 abc/2', 'CATCGTATCAAC', 'CATCGTATCAAC', 0], ['sample5', '@s10/2 abc/2', 'CATCGTATCAAC', 'CATCGTATCAAC', 0], ['sample4', '@s11/2 abc/2', 'CTAACGCAGTCA', 'CTAACGCAGTCA', 0] @@ -730,8 +734,11 @@ def check_valid(self, *args, **kwargs): columns=['sample', 'barcode-sequence-id', 'barcode-uncorrected', 'barcode-corrected', 'barcode-errors'], - index=pd.Index(['record-%02d' % i for i in range(1, 12)], + index=pd.Index(['record-%d' % i for i in range(1, 12)], name='id')) + exp_errors['barcode-errors'] = \ + exp_errors['barcode-errors'].astype(float) + ecc = qiime2.Metadata.load(str(ecc)).to_dataframe() pdt.assert_frame_equal(ecc, exp_errors) def test_valid(self): @@ -742,23 +749,24 @@ def test_valid_ecc_no_golay(self): _, ecc = emp_paired(self.bpsi, self.barcode_map, golay_error_correction=False) exp_errors = pd.DataFrame([ - ['sample1', '@s1/2 abc/2', 'AAAA', None, None], - ['sample3', '@s2/2 abc/2', 'TTAA', None, None], - ['sample2', '@s3/2 abc/2', 'AACC', None, None], - ['sample3', '@s4/2 abc/2', 'TTAA', None, None], - ['sample2', '@s5/2 abc/2', 'AACC', None, None], - ['sample1', '@s6/2 abc/2', 'AAAA', None, None], - ['sample5', '@s7/2 abc/2', 'CGGC', None, None], - ['sample4', '@s8/2 abc/2', 'GGAA', None, None], - ['sample5', '@s9/2 abc/2', 'CGGC', None, None], - ['sample5', '@s10/2 abc/2', 'CGGC', None, None], - ['sample4', '@s11/2 abc/2', 'GGAA', None, None] + ['sample1', '@s1/2 abc/2', 'AAAA', 'None', 'None'], + ['sample3', '@s2/2 abc/2', 'TTAA', 'None', 'None'], + ['sample2', '@s3/2 abc/2', 'AACC', 'None', 'None'], + ['sample3', '@s4/2 abc/2', 'TTAA', 'None', 'None'], + ['sample2', '@s5/2 abc/2', 'AACC', 'None', 'None'], + ['sample1', '@s6/2 abc/2', 'AAAA', 'None', 'None'], + ['sample5', '@s7/2 abc/2', 'CGGC', 'None', 'None'], + ['sample4', '@s8/2 abc/2', 'GGAA', 'None', 'None'], + ['sample5', '@s9/2 abc/2', 'CGGC', 'None', 'None'], + ['sample5', '@s10/2 abc/2', 'CGGC', 'None', 'None'], + ['sample4', '@s11/2 abc/2', 'GGAA', 'None', 'None'] ], columns=['sample', 'barcode-sequence-id', 'barcode-uncorrected', 'barcode-corrected', 'barcode-errors'], - index=pd.Index(['record-%02d' % i for i in range(1, 12)], + index=pd.Index(['record-%d' % i for i in range(1, 12)], name='id')) + ecc = qiime2.Metadata.load(str(ecc)).to_dataframe() pdt.assert_frame_equal(ecc, exp_errors) def test_valid_with_barcode_errors(self):