diff --git a/bio2zarr/cli.py b/bio2zarr/cli.py index 7e7aabbd..2dd04c97 100644 --- a/bio2zarr/cli.py +++ b/bio2zarr/cli.py @@ -423,17 +423,25 @@ def dencode_finalise(zarr_path, verbose): @click.command(name="convert") @vcfs @new_zarr_path +@force @variants_chunk_size @samples_chunk_size @verbose @worker_processes def convert_vcf( - vcfs, zarr_path, variants_chunk_size, samples_chunk_size, verbose, worker_processes + vcfs, + zarr_path, + force, + variants_chunk_size, + samples_chunk_size, + verbose, + worker_processes, ): """ Convert input VCF(s) directly to vcfzarr (not recommended for large files). """ setup_logging(verbose) + check_overwrite_dir(zarr_path, force) vcf.convert( vcfs, zarr_path, diff --git a/bio2zarr/vcf.py b/bio2zarr/vcf.py index 4f1b9475..23bb566a 100644 --- a/bio2zarr/vcf.py +++ b/bio2zarr/vcf.py @@ -1311,6 +1311,7 @@ def __post_init__(self): self.shape = tuple(self.shape) self.chunks = tuple(self.chunks) self.dimensions = tuple(self.dimensions) + self.filters = tuple(self.filters) @staticmethod def new(**kwargs): @@ -1396,7 +1397,7 @@ def variant_chunk_nbytes(self): for size in self.shape[1:]: chunk_items *= size dt = np.dtype(self.dtype) - if dt.kind == "O": + if dt.kind == "O" and "samples" in self.dimensions: logger.warning( f"Field {self.name} is a string; max memory usage may " "be a significant underestimate" @@ -1404,7 +1405,7 @@ def variant_chunk_nbytes(self): return chunk_items * dt.itemsize -ZARR_SCHEMA_FORMAT_VERSION = "0.3" +ZARR_SCHEMA_FORMAT_VERSION = "0.4" @dataclasses.dataclass @@ -1412,11 +1413,13 @@ class VcfZarrSchema: format_version: str samples_chunk_size: int variants_chunk_size: int - dimensions: list samples: list contigs: list filters: list - fields: dict + fields: list + + def field_map(self): + return {field.name: field for field in self.fields} def asdict(self): return dataclasses.asdict(self) @@ -1435,9 +1438,7 @@ def fromdict(d): ret.samples = [Sample(**sd) for sd in d["samples"]] ret.contigs = [Contig(**sd) for sd 
in d["contigs"]] ret.filters = [Filter(**sd) for sd in d["filters"]] - ret.fields = { - key: ZarrColumnSpec(**value) for key, value in d["fields"].items() - } + ret.fields = [ZarrColumnSpec(**sd) for sd in d["fields"]] return ret @staticmethod @@ -1572,8 +1573,7 @@ def fixed_field_spec( format_version=ZARR_SCHEMA_FORMAT_VERSION, samples_chunk_size=samples_chunk_size, variants_chunk_size=variants_chunk_size, - fields={col.name: col for col in colspecs}, - dimensions=["variants", "samples", "ploidy", "alleles", "filters"], + fields=colspecs, samples=icf.metadata.samples, contigs=icf.metadata.contigs, filters=icf.metadata.filters, @@ -1701,6 +1701,12 @@ def schema(self): def num_partitions(self): return len(self.metadata.partitions) + def has_genotypes(self): + for field in self.schema.fields: + if field.name == "call_genotype": + return True + return False + ####################### # init ####################### @@ -1760,7 +1766,7 @@ def init( root = zarr.group(store=store) total_chunks = 0 - for field in self.schema.fields.values(): + for field in self.schema.fields: a = self.init_array(root, field, partitions[-1].stop) total_chunks += a.nchunks @@ -1778,9 +1784,7 @@ def init( def encode_samples(self, root): if self.schema.samples != self.icf.metadata.samples: - raise ValueError( - "Subsetting or reordering samples not supported currently" - ) # NEEDS TEST + raise ValueError("Subsetting or reordering samples not supported currently") array = root.array( "sample_id", [sample.id for sample in self.schema.samples], @@ -1880,10 +1884,10 @@ def encode_partition(self, partition_index): self.encode_filters_partition(partition_index) self.encode_contig_partition(partition_index) self.encode_alleles_partition(partition_index) - for col in self.schema.fields.values(): + for col in self.schema.fields: if col.vcf_field is not None: self.encode_array_partition(col, partition_index) - if "call_genotype" in self.schema.fields: + if self.has_genotypes(): 
self.encode_genotypes_partition(partition_index) final_path = self.partition_path(partition_index) @@ -2100,8 +2104,8 @@ def finalise(self, show_progress=False): # for multiple workers, or making a standard wrapper for tqdm # that allows us to have a consistent look and feel. with core.ParallelWorkManager(0, progress_config) as pwm: - for name in self.schema.fields: - pwm.submit(self.finalise_array, name) + for field in self.schema.fields: + pwm.submit(self.finalise_array, field.name) logger.debug(f"Removing {self.wip_path}") shutil.rmtree(self.wip_path) logger.info("Consolidating Zarr metadata") @@ -2116,17 +2120,14 @@ def get_max_encoding_memory(self): Return the approximate maximum memory used to encode a variant chunk. """ max_encoding_mem = 0 - for col in self.schema.fields.values(): + for col in self.schema.fields: max_encoding_mem = max(max_encoding_mem, col.variant_chunk_nbytes) gt_mem = 0 - if "call_genotype" in self.schema.fields: - encoded_together = [ - "call_genotype", - "call_genotype_phased", - "call_genotype_mask", - ] + if self.has_genotypes(): gt_mem = sum( - self.schema.fields[col].variant_chunk_nbytes for col in encoded_together + field.variant_chunk_nbytes + for field in self.schema.fields + if field.name.startswith("call_genotype") ) return max(max_encoding_mem, gt_mem) @@ -2158,7 +2159,7 @@ def encode_all_partitions( num_workers = min(max_num_workers, worker_processes) total_bytes = 0 - for col in self.schema.fields.values(): + for col in self.schema.fields: # Open the array definition to get the total size total_bytes += zarr.open(self.arrays_path / col.name).nbytes @@ -2273,7 +2274,7 @@ def convert( # TODO add arguments to control location of tmpdir ): with tempfile.TemporaryDirectory(prefix="vcf2zarr") as tmp: - if_dir = pathlib.Path(tmp) / "if" + if_dir = pathlib.Path(tmp) / "icf" explode( if_dir, vcfs, diff --git a/tests/test_cli.py b/tests/test_cli.py index 1e07b0ca..29c10243 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -47,6 
+47,13 @@ DEFAULT_DENCODE_FINALISE_ARGS = dict(show_progress=True) +DEFAULT_CONVERT_ARGS = dict( + variants_chunk_size=None, + samples_chunk_size=None, + show_progress=True, + worker_processes=1, +) + @dataclasses.dataclass class FakeWorkSummary: @@ -508,11 +515,24 @@ def test_convert_vcf(self, mocked): mocked.assert_called_once_with( (self.vcf_path,), "zarr_path", - variants_chunk_size=None, - samples_chunk_size=None, - worker_processes=1, - show_progress=True, + **DEFAULT_CONVERT_ARGS, + ) + + @pytest.mark.parametrize("response", ["n", "N", "No"]) + @mock.patch("bio2zarr.vcf.convert") + def test_vcf_convert_overwrite_zarr_confirm_no(self, mocked, tmp_path, response): + zarr_path = tmp_path / "zarr" + zarr_path.mkdir() + runner = ct.CliRunner(mix_stderr=False) + result = runner.invoke( + cli.vcf2zarr, + f"convert {self.vcf_path} {zarr_path}", + catch_exceptions=False, + input=response, ) + assert result.exit_code == 1 + assert "Aborted" in result.stderr + mocked.assert_not_called() @mock.patch("bio2zarr.plink.convert") def test_convert_plink(self, mocked): @@ -523,13 +543,25 @@ def test_convert_plink(self, mocked): assert result.exit_code == 0 assert len(result.stdout) == 0 assert len(result.stderr) == 0 + mocked.assert_called_once_with("in", "out", **DEFAULT_CONVERT_ARGS) + + @pytest.mark.parametrize("response", ["y", "Y", "yes"]) + @mock.patch("bio2zarr.vcf.convert") + def test_vcf_convert_overwrite_zarr_confirm_yes(self, mocked, tmp_path, response): + zarr_path = tmp_path / "zarr" + zarr_path.mkdir() + runner = ct.CliRunner(mix_stderr=False) + result = runner.invoke( + cli.vcf2zarr, + f"convert {self.vcf_path} {zarr_path}", + catch_exceptions=False, + input=response, + ) + assert result.exit_code == 0 + assert f"Do you want to overwrite {zarr_path}" in result.stdout + assert len(result.stderr) == 0 mocked.assert_called_once_with( - "in", - "out", - worker_processes=1, - samples_chunk_size=None, - variants_chunk_size=None, - show_progress=True, + 
(self.vcf_path,), str(zarr_path), **DEFAULT_CONVERT_ARGS ) diff --git a/tests/test_icf.py b/tests/test_icf.py index 29e1eae0..35c24f0e 100644 --- a/tests/test_icf.py +++ b/tests/test_icf.py @@ -228,7 +228,7 @@ def schema(self, icf): ], ) def test_info_schemas(self, schema, name, dtype, shape, dimensions): - v = schema.fields[name] + v = schema.field_map()[name] assert v.dtype == dtype assert tuple(v.shape) == shape assert v.dimensions == dimensions diff --git a/tests/test_vcf.py b/tests/test_vcf.py index 2cdf6c5f..1284582b 100644 --- a/tests/test_vcf.py +++ b/tests/test_vcf.py @@ -32,7 +32,7 @@ def schema_path(icf_path, tmp_path_factory): @pytest.fixture(scope="module") def schema(schema_path): with open(schema_path) as f: - return json.load(f) + return vcf.VcfZarrSchema.fromjson(f.read()) @pytest.fixture(scope="module") @@ -83,7 +83,7 @@ def test_not_enough_memory_for_two( class TestJsonVersions: @pytest.mark.parametrize("version", ["0.1", "1.0", "xxxxx", 0.2]) def test_zarr_schema_mismatch(self, schema, version): - d = dict(schema) + d = schema.asdict() d["format_version"] = version with pytest.raises(ValueError, match="Zarr schema format version mismatch"): vcf.VcfZarrSchema.fromdict(d) @@ -156,13 +156,13 @@ def test_generated_no_samples(self, icf_path): def test_generated_change_dtype(self, icf_path): icf = vcf.IntermediateColumnarFormat(icf_path) schema = vcf.VcfZarrSchema.generate(icf) - schema.fields["variant_position"].dtype = "i8" + schema.field_map()["variant_position"].dtype = "i8" self.assert_json_round_trip(schema) def test_generated_change_compressor(self, icf_path): icf = vcf.IntermediateColumnarFormat(icf_path) schema = vcf.VcfZarrSchema.generate(icf) - schema.fields["variant_position"].compressor = {"cname": "FAKE"} + schema.field_map()["variant_position"].compressor = {"cname": "FAKE"} self.assert_json_round_trip(schema) @@ -174,7 +174,7 @@ def test_codec(self, tmp_path, icf_path, cname, clevel, shuffle): zarr_path = tmp_path / "zarr" icf = 
vcf.IntermediateColumnarFormat(icf_path) schema = vcf.VcfZarrSchema.generate(icf) - for var in schema.fields.values(): + for var in schema.fields: var.compressor["cname"] = cname var.compressor["clevel"] = clevel var.compressor["shuffle"] = shuffle @@ -183,7 +183,7 @@ def test_codec(self, tmp_path, icf_path, cname, clevel, shuffle): f.write(schema.asjson()) vcf.encode(icf_path, zarr_path, schema_path=schema_path) root = zarr.open(zarr_path) - for var in schema.fields.values(): + for var in schema.fields: a = root[var.name] assert a.compressor.cname == cname assert a.compressor.clevel == clevel @@ -194,7 +194,7 @@ def test_genotype_dtype(self, tmp_path, icf_path, dtype): zarr_path = tmp_path / "zarr" icf = vcf.IntermediateColumnarFormat(icf_path) schema = vcf.VcfZarrSchema.generate(icf) - schema.fields["call_genotype"].dtype = dtype + schema.field_map()["call_genotype"].dtype = dtype schema_path = tmp_path / "schema" with open(schema_path, "w") as f: f.write(schema.asjson()) @@ -203,47 +203,45 @@ def test_genotype_dtype(self, tmp_path, icf_path, dtype): assert root["call_genotype"].dtype == dtype +def get_field_dict(a_schema, name): + d = a_schema.asdict() + for field in d["fields"]: + if field["name"] == name: + return field + + class TestDefaultSchema: def test_format_version(self, schema): - assert schema["format_version"] == vcf.ZARR_SCHEMA_FORMAT_VERSION + assert schema.format_version == vcf.ZARR_SCHEMA_FORMAT_VERSION def test_chunk_size(self, schema): - assert schema["samples_chunk_size"] == 1000 - assert schema["variants_chunk_size"] == 10000 - - def test_dimensions(self, schema): - assert schema["dimensions"] == [ - "variants", - "samples", - "ploidy", - "alleles", - "filters", - ] + assert schema.samples_chunk_size == 1000 + assert schema.variants_chunk_size == 10000 def test_samples(self, schema): - assert schema["samples"] == [ + assert schema.asdict()["samples"] == [ {"id": s} for s in ["NA00001", "NA00002", "NA00003"] ] def test_contigs(self, schema): - 
assert schema["contigs"] == [ + assert schema.asdict()["contigs"] == [ {"id": s, "length": None} for s in ["19", "20", "X"] ] def test_filters(self, schema): - assert schema["filters"] == [ + assert schema.asdict()["filters"] == [ {"id": "PASS", "description": "All filters passed"}, {"id": "s50", "description": "Less than 50% of samples have data"}, {"id": "q10", "description": "Quality below 10"}, ] def test_variant_contig(self, schema): - assert schema["fields"]["variant_contig"] == { + assert get_field_dict(schema, "variant_contig") == { "name": "variant_contig", "dtype": "i1", - "shape": [9], - "chunks": [10000], - "dimensions": ["variants"], + "shape": (9,), + "chunks": (10000,), + "dimensions": ("variants",), "description": "", "vcf_field": None, "compressor": { @@ -253,16 +251,16 @@ def test_variant_contig(self, schema): "shuffle": 0, "blocksize": 0, }, - "filters": [], + "filters": tuple(), } def test_call_genotype(self, schema): - assert schema["fields"]["call_genotype"] == { + assert get_field_dict(schema, "call_genotype") == { "name": "call_genotype", "dtype": "i1", - "shape": [9, 3, 2], - "chunks": [10000, 1000], - "dimensions": ["variants", "samples", "ploidy"], + "shape": (9, 3, 2), + "chunks": (10000, 1000), + "dimensions": ("variants", "samples", "ploidy"), "description": "", "vcf_field": None, "compressor": { @@ -272,16 +270,16 @@ def test_call_genotype(self, schema): "shuffle": 2, "blocksize": 0, }, - "filters": [], + "filters": tuple(), } def test_call_genotype_mask(self, schema): - assert schema["fields"]["call_genotype_mask"] == { + assert get_field_dict(schema, "call_genotype_mask") == { "name": "call_genotype_mask", "dtype": "bool", - "shape": [9, 3, 2], - "chunks": [10000, 1000], - "dimensions": ["variants", "samples", "ploidy"], + "shape": (9, 3, 2), + "chunks": (10000, 1000), + "dimensions": ("variants", "samples", "ploidy"), "description": "", "vcf_field": None, "compressor": { @@ -291,16 +289,16 @@ def test_call_genotype_mask(self, 
schema): "shuffle": 2, "blocksize": 0, }, - "filters": [], + "filters": tuple(), } def test_call_genotype_phased(self, schema): - assert schema["fields"]["call_genotype_mask"] == { + assert get_field_dict(schema, "call_genotype_mask") == { "name": "call_genotype_mask", "dtype": "bool", - "shape": [9, 3, 2], - "chunks": [10000, 1000], - "dimensions": ["variants", "samples", "ploidy"], + "shape": (9, 3, 2), + "chunks": (10000, 1000), + "dimensions": ("variants", "samples", "ploidy"), "description": "", "vcf_field": None, "compressor": { @@ -310,16 +308,16 @@ def test_call_genotype_phased(self, schema): "shuffle": 2, "blocksize": 0, }, - "filters": [], + "filters": tuple(), } def test_call_GQ(self, schema): - assert schema["fields"]["call_GQ"] == { + assert get_field_dict(schema, "call_GQ") == { "name": "call_GQ", "dtype": "i1", - "shape": [9, 3], - "chunks": [10000, 1000], - "dimensions": ["variants", "samples"], + "shape": (9, 3), + "chunks": (10000, 1000), + "dimensions": ("variants", "samples"), "description": "Genotype Quality", "vcf_field": "FORMAT/GQ", "compressor": { @@ -329,7 +327,7 @@ def test_call_GQ(self, schema): "shuffle": 0, "blocksize": 0, }, - "filters": [], + "filters": tuple(), } @@ -379,7 +377,7 @@ class TestVcfDescriptions: ], ) def test_fields(self, schema, field, description): - assert schema["fields"][field]["description"] == description + assert schema.field_map()[field].description == description # This information is not in the schema yet, # https://github.com/sgkit-dev/bio2zarr/issues/123 @@ -562,3 +560,30 @@ def test_call_fields(self, tmp_path, field): self.generate_vcf(vcf_file, format_field=field) with pytest.raises(ValueError, match=f"FORMAT field name.*{field}"): vcf.explode(tmp_path / "x.icf", [tmp_path / "test.vcf.gz"]) + + +class TestBadSchemaChanges: + # [{'id': 'NA00001'}, {'id': 'NA00002'}, {'id': 'NA00003'}], + @pytest.mark.parametrize( + "samples", + [ + [], + [{"id": "NA00001"}, {"id": "NA00003"}], + [{"id": "NA00001"}, {"id": 
"NA00002"}, {"id": "NA00004"}], + [ + {"id": "NA00001"}, + {"id": "NA00002"}, + {"id": "NA00003"}, + {"id": "NA00004"}, + ], + [{"id": "NA00001"}, {"id": "NA00003"}, {"id": "NA00002"}], + ], + ) + def test_removed_samples(self, tmp_path, schema, icf_path, samples): + d = schema.asdict() + d["samples"] = samples + schema_path = tmp_path / "schema.json" + with open(schema_path, "w") as f: + json.dump(d, f) + with pytest.raises(ValueError, match="Subsetting or reordering samples"): + vcf.encode(icf_path, tmp_path / "z", schema_path=schema_path)