Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion bio2zarr/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,17 +423,25 @@ def dencode_finalise(zarr_path, verbose):
@click.command(name="convert")
@vcfs
@new_zarr_path
@force
@variants_chunk_size
@samples_chunk_size
@verbose
@worker_processes
def convert_vcf(
vcfs, zarr_path, variants_chunk_size, samples_chunk_size, verbose, worker_processes
vcfs,
zarr_path,
force,
variants_chunk_size,
samples_chunk_size,
verbose,
worker_processes,
):
"""
Convert input VCF(s) directly to vcfzarr (not recommended for large files).
"""
setup_logging(verbose)
check_overwrite_dir(zarr_path, force)
vcf.convert(
vcfs,
zarr_path,
Expand Down
55 changes: 28 additions & 27 deletions bio2zarr/vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,6 +1311,7 @@ def __post_init__(self):
self.shape = tuple(self.shape)
self.chunks = tuple(self.chunks)
self.dimensions = tuple(self.dimensions)
self.filters = tuple(self.filters)

@staticmethod
def new(**kwargs):
Expand Down Expand Up @@ -1396,27 +1397,29 @@ def variant_chunk_nbytes(self):
for size in self.shape[1:]:
chunk_items *= size
dt = np.dtype(self.dtype)
if dt.kind == "O":
if dt.kind == "O" and "samples" in self.dimensions:
logger.warning(
f"Field {self.name} is a string; max memory usage may "
"be a significant underestimate"
)
return chunk_items * dt.itemsize


ZARR_SCHEMA_FORMAT_VERSION = "0.3"
ZARR_SCHEMA_FORMAT_VERSION = "0.4"


@dataclasses.dataclass
class VcfZarrSchema:
format_version: str
samples_chunk_size: int
variants_chunk_size: int
dimensions: list
samples: list
contigs: list
filters: list
fields: dict
fields: list

def field_map(self):
    """Return a mapping from field name to its ZarrColumnSpec."""
    mapping = {}
    for spec in self.fields:
        mapping[spec.name] = spec
    return mapping

def asdict(self):
return dataclasses.asdict(self)
Expand All @@ -1435,9 +1438,7 @@ def fromdict(d):
ret.samples = [Sample(**sd) for sd in d["samples"]]
ret.contigs = [Contig(**sd) for sd in d["contigs"]]
ret.filters = [Filter(**sd) for sd in d["filters"]]
ret.fields = {
key: ZarrColumnSpec(**value) for key, value in d["fields"].items()
}
ret.fields = [ZarrColumnSpec(**sd) for sd in d["fields"]]
return ret

@staticmethod
Expand Down Expand Up @@ -1572,8 +1573,7 @@ def fixed_field_spec(
format_version=ZARR_SCHEMA_FORMAT_VERSION,
samples_chunk_size=samples_chunk_size,
variants_chunk_size=variants_chunk_size,
fields={col.name: col for col in colspecs},
dimensions=["variants", "samples", "ploidy", "alleles", "filters"],
fields=colspecs,
samples=icf.metadata.samples,
contigs=icf.metadata.contigs,
filters=icf.metadata.filters,
Expand Down Expand Up @@ -1701,6 +1701,12 @@ def schema(self):
def num_partitions(self):
return len(self.metadata.partitions)

def has_genotypes(self):
    """Return True if the schema defines a call_genotype field."""
    return any(field.name == "call_genotype" for field in self.schema.fields)

#######################
# init
#######################
Expand Down Expand Up @@ -1760,7 +1766,7 @@ def init(
root = zarr.group(store=store)

total_chunks = 0
for field in self.schema.fields.values():
for field in self.schema.fields:
a = self.init_array(root, field, partitions[-1].stop)
total_chunks += a.nchunks

Expand All @@ -1778,9 +1784,7 @@ def init(

def encode_samples(self, root):
if self.schema.samples != self.icf.metadata.samples:
raise ValueError(
"Subsetting or reordering samples not supported currently"
) # NEEDS TEST
raise ValueError("Subsetting or reordering samples not supported currently")
array = root.array(
"sample_id",
[sample.id for sample in self.schema.samples],
Expand Down Expand Up @@ -1880,10 +1884,10 @@ def encode_partition(self, partition_index):
self.encode_filters_partition(partition_index)
self.encode_contig_partition(partition_index)
self.encode_alleles_partition(partition_index)
for col in self.schema.fields.values():
for col in self.schema.fields:
if col.vcf_field is not None:
self.encode_array_partition(col, partition_index)
if "call_genotype" in self.schema.fields:
if self.has_genotypes():
self.encode_genotypes_partition(partition_index)

final_path = self.partition_path(partition_index)
Expand Down Expand Up @@ -2100,8 +2104,8 @@ def finalise(self, show_progress=False):
# for multiple workers, or making a standard wrapper for tqdm
# that allows us to have a consistent look and feel.
with core.ParallelWorkManager(0, progress_config) as pwm:
for name in self.schema.fields:
pwm.submit(self.finalise_array, name)
for field in self.schema.fields:
pwm.submit(self.finalise_array, field.name)
logger.debug(f"Removing {self.wip_path}")
shutil.rmtree(self.wip_path)
logger.info("Consolidating Zarr metadata")
Expand All @@ -2116,17 +2120,14 @@ def get_max_encoding_memory(self):
Return the approximate maximum memory used to encode a variant chunk.
"""
max_encoding_mem = 0
for col in self.schema.fields.values():
for col in self.schema.fields:
max_encoding_mem = max(max_encoding_mem, col.variant_chunk_nbytes)
gt_mem = 0
if "call_genotype" in self.schema.fields:
encoded_together = [
"call_genotype",
"call_genotype_phased",
"call_genotype_mask",
]
if self.has_genotypes:
gt_mem = sum(
self.schema.fields[col].variant_chunk_nbytes for col in encoded_together
field.variant_chunk_nbytes
for field in self.schema.fields
if field.name.startswith("call_genotype")
)
return max(max_encoding_mem, gt_mem)

Expand Down Expand Up @@ -2158,7 +2159,7 @@ def encode_all_partitions(
num_workers = min(max_num_workers, worker_processes)

total_bytes = 0
for col in self.schema.fields.values():
for col in self.schema.fields:
# Open the array definition to get the total size
total_bytes += zarr.open(self.arrays_path / col.name).nbytes

Expand Down Expand Up @@ -2273,7 +2274,7 @@ def convert(
# TODO add arguments to control location of tmpdir
):
with tempfile.TemporaryDirectory(prefix="vcf2zarr") as tmp:
if_dir = pathlib.Path(tmp) / "if"
if_dir = pathlib.Path(tmp) / "icf"
explode(
if_dir,
vcfs,
Expand Down
52 changes: 42 additions & 10 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@

DEFAULT_DENCODE_FINALISE_ARGS = dict(show_progress=True)

DEFAULT_CONVERT_ARGS = dict(
variants_chunk_size=None,
samples_chunk_size=None,
show_progress=True,
worker_processes=1,
)


@dataclasses.dataclass
class FakeWorkSummary:
Expand Down Expand Up @@ -508,11 +515,24 @@ def test_convert_vcf(self, mocked):
mocked.assert_called_once_with(
(self.vcf_path,),
"zarr_path",
variants_chunk_size=None,
samples_chunk_size=None,
worker_processes=1,
show_progress=True,
**DEFAULT_CONVERT_ARGS,
)

@pytest.mark.parametrize("response", ["n", "N", "No"])
@mock.patch("bio2zarr.vcf.convert")
def test_vcf_convert_overwrite_zarr_confirm_no(self, mocked, tmp_path, response):
zarr_path = tmp_path / "zarr"
zarr_path.mkdir()
runner = ct.CliRunner(mix_stderr=False)
result = runner.invoke(
cli.vcf2zarr,
f"convert {self.vcf_path} {zarr_path}",
catch_exceptions=False,
input=response,
)
assert result.exit_code == 1
assert "Aborted" in result.stderr
mocked.assert_not_called()

@mock.patch("bio2zarr.plink.convert")
def test_convert_plink(self, mocked):
Expand All @@ -523,13 +543,25 @@ def test_convert_plink(self, mocked):
assert result.exit_code == 0
assert len(result.stdout) == 0
assert len(result.stderr) == 0
mocked.assert_called_once_with("in", "out", **DEFAULT_CONVERT_ARGS)

@pytest.mark.parametrize("response", ["y", "Y", "yes"])
@mock.patch("bio2zarr.vcf.convert")
def test_vcf_convert_overwrite_zarr_confirm_yes(self, mocked, tmp_path, response):
zarr_path = tmp_path / "zarr"
zarr_path.mkdir()
runner = ct.CliRunner(mix_stderr=False)
result = runner.invoke(
cli.vcf2zarr,
f"convert {self.vcf_path} {zarr_path}",
catch_exceptions=False,
input=response,
)
assert result.exit_code == 0
assert f"Do you want to overwrite {zarr_path}" in result.stdout
assert len(result.stderr) == 0
mocked.assert_called_once_with(
"in",
"out",
worker_processes=1,
samples_chunk_size=None,
variants_chunk_size=None,
show_progress=True,
(self.vcf_path,), str(zarr_path), **DEFAULT_CONVERT_ARGS
)


Expand Down
2 changes: 1 addition & 1 deletion tests/test_icf.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def schema(self, icf):
],
)
def test_info_schemas(self, schema, name, dtype, shape, dimensions):
v = schema.fields[name]
v = schema.field_map()[name]
assert v.dtype == dtype
assert tuple(v.shape) == shape
assert v.dimensions == dimensions
Expand Down
Loading