diff --git a/bio2zarr/vcf.py b/bio2zarr/vcf.py
index cc6b0447..6274b4f2 100644
--- a/bio2zarr/vcf.py
+++ b/bio2zarr/vcf.py
@@ -1286,7 +1286,7 @@ def summary_table(self):
 
 @dataclasses.dataclass
 class EncodingWork:
-    func: callable
+    func: callable = dataclasses.field(repr=False)
     start: int
     stop: int
     columns: list[str]
@@ -1319,12 +1319,12 @@ def init_array(self, variable):
     def get_array(self, name):
         return self.root["wip_" + name]
 
-    def finalise_array(self, variable):
-        source = self.path / ("wip_" + variable.name)
-        dest = self.path / variable.name
+    def finalise_array(self, variable_name):
+        source = self.path / ("wip_" + variable_name)
+        dest = self.path / variable_name
         # Atomic swap
         os.rename(source, dest)
-        logger.debug(f"Finalised {variable.name}")
+        logger.info(f"Finalised {variable_name}")
 
     def encode_array_slice(self, column, start, stop):
         source_col = self.pcvcf.columns[column.vcf_field]
@@ -1471,8 +1471,8 @@ def init(self):
             self.init_array(column)
 
     def finalise(self):
-        for column in self.schema.columns.values():
-            self.finalise_array(column)
+        # for column in self.schema.columns.values():
+        #     self.finalise_array(column)
         zarr.consolidate_metadata(self.path)
 
     def encode(
@@ -1536,13 +1536,17 @@ def encode(
         work.append(
             EncodingWork(self.encode_alleles_slice, start, stop, ["variant_allele"])
         )
-        work.append(EncodingWork(self.encode_id_slice, start, stop, ["variant_id"]))
+        work.append(
+            EncodingWork(
+                self.encode_id_slice, start, stop, ["variant_id", "variant_id_mask"]
+            )
+        )
         work.append(
             EncodingWork(
                 functools.partial(self.encode_filters_slice, filter_id_map),
                 start,
                 stop,
-                ["variant_filters"],
+                ["variant_filter"],
             )
         )
         work.append(
@@ -1550,7 +1554,7 @@ def encode(
                 functools.partial(self.encode_contig_slice, contig_id_map),
                 start,
                 stop,
-                ["variant_contig_id"],
+                ["variant_contig"],
             )
         )
         if "call_genotype" in self.schema.columns:
@@ -1567,6 +1571,7 @@ def encode(
                     self.encode_genotypes_slice, start, stop, variables, gt_memory
                 )
             )
 
+        # Fail early if we can't fit a particular column into memory
         for wp in work:
             if wp.memory >= max_memory:
@@ -1581,31 +1586,47 @@ def encode(
             units="B",
             show=show_progress,
         )
-        # TODO add a map of slices completed to column here, so that we can
-        # finalise the arrays as they get completed. We'll have to service
-        # the futures more, though, not just when we exceed the memory budget
         used_memory = 0
+        max_queued = 4 * max(1, worker_processes)
+        encoded_slices = collections.Counter()
+
         with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
             future = pwm.submit(self.encode_samples)
-            future_to_memory_use = {future: 0}
-            for wp in work:
-                while used_memory + wp.memory >= max_memory:
-                    logger.info(
-                        f"Memory budget {display_size(max_memory)} exceeded: "
-                        f"used={display_size(used_memory)} needed={display_size(wp.memory)}"
-                    )
-                    futures = pwm.wait_for_completed()
-                    released_mem = sum(
-                        future_to_memory_use.pop(future) for future in futures
-                    )
-                    logger.info(
-                        f"{len(futures)} completed, released {display_size(released_mem)}"
+            future_to_work = {future: EncodingWork(None, 0, 0, [])}
+
+            def service_completed_futures():
+                nonlocal used_memory
+
+                completed = pwm.wait_for_completed()
+                for future in completed:
+                    wp_done = future_to_work.pop(future)
+                    used_memory -= wp_done.memory
+                    logger.debug(
+                        f"Complete {wp_done}: used mem={display_size(used_memory)}"
                     )
-                    used_memory -= released_mem
+                    for column in wp_done.columns:
+                        encoded_slices[column] += 1
+                        if encoded_slices[column] == len(slices):
+                            # Do this synchronously for simplicity. Should be
+                            # fine as the workers will probably be busy with
+                            # large encode tasks most of the time.
+                            self.finalise_array(column)
+
+            for wp in work:
+                if (
+                    used_memory + wp.memory > max_memory
+                    or len(future_to_work) > max_queued
+                ):
+                    service_completed_futures()
                 future = pwm.submit(wp.func, wp.start, wp.stop)
                 used_memory += wp.memory
-                future_to_memory_use[future] = wp.memory
+                logger.debug(f"Submit {wp}: used mem={display_size(used_memory)}")
+                future_to_work[future] = wp
+
+            logger.debug("All work submitted")
+            while len(future_to_work) > 0:
+                service_completed_futures()
 
 
 def mkschema(if_path, out):
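
For review purposes, here is a minimal, self-contained sketch (not bio2zarr code) of the scheduling pattern this patch introduces: submit slice-encoding work under a memory budget and a bounded queue of outstanding futures, count completed slices per column, and finalise each column as soon as its last slice finishes. `Work`, `encode_slice`, `finalise`, and `run` are hypothetical stand-ins, and the stdlib `concurrent.futures` API stands in for `core.ParallelWorkManager`.

```python
# Illustrative sketch only; names and the executor API are stand-ins for the
# real bio2zarr machinery.
import collections
import dataclasses
from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, wait


@dataclasses.dataclass
class Work:
    start: int
    stop: int
    columns: list
    memory: int = 0


def encode_slice(start, stop):
    # Stand-in for the real per-slice array encode step.
    return sum(range(start, stop))


def finalise(column):
    # Stand-in for the wip_* -> final rename done by finalise_array().
    print(f"finalised {column}")


def run(work, num_slices, max_memory=1000, workers=2):
    max_queued = 4 * max(1, workers)
    used_memory = 0
    encoded_slices = collections.Counter()
    future_to_work = {}

    with ProcessPoolExecutor(workers) as executor:

        def service_completed():
            nonlocal used_memory
            # Block until at least one outstanding future finishes.
            done, _ = wait(list(future_to_work), return_when=FIRST_COMPLETED)
            for future in done:
                wp = future_to_work.pop(future)
                future.result()  # propagate any worker exception
                used_memory -= wp.memory
                for column in wp.columns:
                    encoded_slices[column] += 1
                    if encoded_slices[column] == num_slices:
                        # Last slice of this column has been encoded.
                        finalise(column)

        for wp in work:
            # Wait for capacity in both the memory budget and the queue.
            while future_to_work and (
                used_memory + wp.memory > max_memory
                or len(future_to_work) > max_queued
            ):
                service_completed()
            future_to_work[executor.submit(encode_slice, wp.start, wp.stop)] = wp
            used_memory += wp.memory

        # Drain the remaining futures so every column gets finalised.
        while future_to_work:
            service_completed()


if __name__ == "__main__":
    slices = [(0, 10), (10, 20), (20, 30)]
    work = [Work(a, b, ["variant_position"], memory=100) for a, b in slices]
    run(work, num_slices=len(slices))
```

The point of the design, as in the patch itself, is that arrays no longer wait for a single finalise() pass at the end: each one is swapped into place as soon as its last slice completes, while submission stays bounded by both the memory budget and the queue length.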