bio2zarr/vcf.py (77 changes: 49 additions & 28 deletions)
@@ -1286,7 +1286,7 @@ def summary_table(self):
 
 @dataclasses.dataclass
 class EncodingWork:
-    func: callable
+    func: callable = dataclasses.field(repr=False)
     start: int
     stop: int
     columns: list[str]
@@ -1319,12 +1319,12 @@ def init_array(self, variable):
     def get_array(self, name):
         return self.root["wip_" + name]
 
-    def finalise_array(self, variable):
-        source = self.path / ("wip_" + variable.name)
-        dest = self.path / variable.name
+    def finalise_array(self, variable_name):
+        source = self.path / ("wip_" + variable_name)
+        dest = self.path / variable_name
         # Atomic swap
         os.rename(source, dest)
-        logger.debug(f"Finalised {variable.name}")
+        logger.info(f"Finalised {variable_name}")
 
     def encode_array_slice(self, column, start, stop):
         source_col = self.pcvcf.columns[column.vcf_field]
@@ -1471,8 +1471,8 @@ def init(self):
             self.init_array(column)
 
     def finalise(self):
-        for column in self.schema.columns.values():
-            self.finalise_array(column)
+        # for column in self.schema.columns.values():
+        #     self.finalise_array(column)
         zarr.consolidate_metadata(self.path)
 
     def encode(
@@ -1536,21 +1536,25 @@ def encode(
         work.append(
             EncodingWork(self.encode_alleles_slice, start, stop, ["variant_allele"])
         )
-        work.append(EncodingWork(self.encode_id_slice, start, stop, ["variant_id"]))
+        work.append(
+            EncodingWork(
+                self.encode_id_slice, start, stop, ["variant_id", "variant_id_mask"]
+            )
+        )
         work.append(
             EncodingWork(
                 functools.partial(self.encode_filters_slice, filter_id_map),
                 start,
                 stop,
-                ["variant_filters"],
+                ["variant_filter"],
             )
         )
         work.append(
             EncodingWork(
                 functools.partial(self.encode_contig_slice, contig_id_map),
                 start,
                 stop,
-                ["variant_contig_id"],
+                ["variant_contig"],
             )
         )
         if "call_genotype" in self.schema.columns:
@@ -1567,6 +1571,7 @@ def encode(
                     self.encode_genotypes_slice, start, stop, variables, gt_memory
                 )
             )
+
         # Fail early if we can't fit a particular column into memory
         for wp in work:
             if wp.memory >= max_memory:
@@ -1581,31 +1586,47 @@ def encode(
             units="B",
             show=show_progress,
         )
-        # TODO add a map of slices completed to column here, so that we can
-        # finalise the arrays as they get completed. We'll have to service
-        # the futures more, though, not just when we exceed the memory budget
 
         used_memory = 0
+        max_queued = 4 * max(1, worker_processes)
+        encoded_slices = collections.Counter()
+
         with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
             future = pwm.submit(self.encode_samples)
-            future_to_memory_use = {future: 0}
-            for wp in work:
-                while used_memory + wp.memory >= max_memory:
-                    logger.info(
-                        f"Memory budget {display_size(max_memory)} exceeded: "
-                        f"used={display_size(used_memory)} needed={display_size(wp.memory)}"
-                    )
-                    futures = pwm.wait_for_completed()
-                    released_mem = sum(
-                        future_to_memory_use.pop(future) for future in futures
-                    )
-                    logger.info(
-                        f"{len(futures)} completed, released {display_size(released_mem)}"
-                    )
-                    used_memory -= released_mem
+            future_to_work = {future: EncodingWork(None, 0, 0, [])}
+
+            def service_completed_futures():
+                nonlocal used_memory
+
+                completed = pwm.wait_for_completed()
+                for future in completed:
+                    wp_done = future_to_work.pop(future)
+                    used_memory -= wp_done.memory
+                    logger.debug(
+                        f"Complete {wp_done}: used mem={display_size(used_memory)}"
+                    )
+                    for column in wp_done.columns:
+                        encoded_slices[column] += 1
+                        if encoded_slices[column] == len(slices):
+                            # Do this syncronously for simplicity. Should be
+                            # fine as the workers will probably be busy with
+                            # large encode tasks most of the time.
+                            self.finalise_array(column)
+
+            for wp in work:
+                if (
+                    used_memory + wp.memory > max_memory
+                    or len(future_to_work) > max_queued
+                ):
+                    service_completed_futures()
                 future = pwm.submit(wp.func, wp.start, wp.stop)
                 used_memory += wp.memory
-                future_to_memory_use[future] = wp.memory
                 logger.debug(f"Submit {wp}: used mem={display_size(used_memory)}")
+                future_to_work[future] = wp
+
+            logger.debug("All work submitted")
+            while len(future_to_work) > 0:
+                service_completed_futures()
 
 
 def mkschema(if_path, out):
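The heart of this change is the reworked scheduling loop in encode(): each EncodingWork item carries a memory estimate, submission is throttled both by a memory budget and by a queue-depth cap (max_queued), and a column's array is finalised (renamed from its "wip_" prefix) as soon as all of its slices have been encoded, instead of in a single finalise() pass at the end. Below is a minimal, self-contained sketch of that pattern, not bio2zarr code: the names Work, run_bounded and finalise are hypothetical, ThreadPoolExecutor stands in for bio2zarr's ParallelWorkManager, and concurrent.futures.wait stands in for its wait_for_completed().

# Illustrative sketch only: bounded-memory scheduling with per-column
# finalisation, in the spirit of the encode() loop in this diff.
import collections
import concurrent.futures
import dataclasses
from typing import Callable


@dataclasses.dataclass
class Work:
    func: Callable      # encodes one slice of one column
    column: str         # column this slice belongs to
    memory: int = 0     # estimated peak bytes while encoding


def run_bounded(work, num_slices, max_memory, workers=4):
    used = 0
    max_queued = 4 * workers
    done_slices = collections.Counter()
    future_to_work = {}

    def finalise(column):
        # Placeholder for the atomic "wip_<name>" -> "<name>" rename.
        print(f"finalised {column}")

    def service_completed():
        # Block until at least one outstanding future finishes, then release
        # its memory budget and finalise any column whose slices are all done.
        nonlocal used
        finished, _ = concurrent.futures.wait(
            future_to_work, return_when=concurrent.futures.FIRST_COMPLETED
        )
        for fut in finished:
            wp = future_to_work.pop(fut)
            fut.result()  # re-raise any worker exception
            used -= wp.memory
            done_slices[wp.column] += 1
            if done_slices[wp.column] == num_slices:
                finalise(wp.column)

    with concurrent.futures.ThreadPoolExecutor(workers) as pool:
        for wp in work:
            # Throttle on the memory budget and on queue depth; the
            # "and future_to_work" guard avoids spinning when a single
            # item is larger than the whole budget.
            while (
                used + wp.memory > max_memory or len(future_to_work) > max_queued
            ) and future_to_work:
                service_completed()
            fut = pool.submit(wp.func)
            used += wp.memory
            future_to_work[fut] = wp
        # Drain remaining futures so every column gets finalised.
        while future_to_work:
            service_completed()


if __name__ == "__main__":
    work = [
        Work(func=lambda c=c, i=i: (c, i), column=c, memory=10)
        for c in ["variant_id", "variant_allele"]
        for i in range(3)
    ]
    run_bounded(work, num_slices=3, max_memory=25)

The sketch keeps the same design choice as the PR: finalisation runs synchronously in the driver between submissions, which is cheap (a rename) while the workers stay busy with the large encode tasks.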