From bdebd92115e73e9d424999c53a431491e3fc7107 Mon Sep 17 00:00:00 2001
From: Jerome Kelleher
Date: Sat, 11 May 2024 22:51:39 +0100
Subject: [PATCH 1/4] Add JsonDataclass for asdict/asjson

---
 bio2zarr/core.py |  9 +++++++++
 bio2zarr/vcf.py  | 50 +++++++-----------------------------------
 2 files changed, 16 insertions(+), 43 deletions(-)

diff --git a/bio2zarr/core.py b/bio2zarr/core.py
index 16a1bcb4..60c25375 100644
--- a/bio2zarr/core.py
+++ b/bio2zarr/core.py
@@ -1,6 +1,7 @@
 import concurrent.futures as cf
 import contextlib
 import dataclasses
+import json
 import logging
 import multiprocessing
 import os
@@ -277,3 +278,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         self._update_progress()
         self.progress_bar.close()
         return False
+
+
+class JsonDataclass:
+    def asdict(self):
+        return dataclasses.asdict(self)
+
+    def asjson(self):
+        return json.dumps(self.asdict(), indent=4)
diff --git a/bio2zarr/vcf.py b/bio2zarr/vcf.py
index 23bb566a..9ed3f451 100644
--- a/bio2zarr/vcf.py
+++ b/bio2zarr/vcf.py
@@ -50,7 +50,7 @@ def display_size(n):
 
 
 @dataclasses.dataclass
-class VcfFieldSummary:
+class VcfFieldSummary(core.JsonDataclass):
     num_chunks: int = 0
     compressed_size: int = 0
     uncompressed_size: int = 0
@@ -67,9 +67,6 @@ def update(self, other):
         self.min_value = min(self.min_value, other.min_value)
         self.max_value = max(self.max_value, other.max_value)
 
-    def asdict(self):
-        return dataclasses.asdict(self)
-
     @staticmethod
     def fromdict(d):
         return VcfFieldSummary(**d)
@@ -168,7 +165,7 @@ class Filter:
 
 
 @dataclasses.dataclass
-class IcfMetadata:
+class IcfMetadata(core.JsonDataclass):
     samples: list
     contigs: list
     filters: list
@@ -226,12 +223,6 @@ def fromdict(d):
         d["contigs"] = [Contig(**cd) for cd in d["contigs"]]
         return IcfMetadata(**d)
 
-    def asdict(self):
-        return dataclasses.asdict(self)
-
-    def asjson(self):
-        return json.dumps(self.asdict(), indent=4)
-
 
 def fixed_vcf_field_definitions():
     def make_field_def(name, vcf_type, vcf_number):
@@ -933,17 +924,11 @@ def num_fields(self):
 
 
 @dataclasses.dataclass
-class IcfPartitionMetadata:
+class IcfPartitionMetadata(core.JsonDataclass):
     num_records: int
     last_position: int
     field_summaries: dict
 
-    def asdict(self):
-        return dataclasses.asdict(self)
-
-    def asjson(self):
-        return json.dumps(self.asdict(), indent=4)
-
     @staticmethod
     def fromdict(d):
         md = IcfPartitionMetadata(**d)
@@ -987,17 +972,11 @@ def check_field_clobbering(icf_metadata):
 
 
 @dataclasses.dataclass
-class IcfWriteSummary:
+class IcfWriteSummary(core.JsonDataclass):
     num_partitions: int
     num_samples: int
     num_variants: int
 
-    def asdict(self):
-        return dataclasses.asdict(self)
-
-    def asjson(self):
-        return json.dumps(self.asdict(), indent=4)
-
 
 class IntermediateColumnarFormatWriter:
     def __init__(self, path):
@@ -1409,7 +1388,7 @@ def variant_chunk_nbytes(self):
 
 
 @dataclasses.dataclass
-class VcfZarrSchema:
+class VcfZarrSchema(core.JsonDataclass):
     format_version: str
     samples_chunk_size: int
     variants_chunk_size: int
@@ -1421,12 +1400,6 @@ class VcfZarrSchema:
     def field_map(self):
         return {field.name: field for field in self.fields}
 
-    def asdict(self):
-        return dataclasses.asdict(self)
-
-    def asjson(self):
-        return json.dumps(self.asdict(), indent=4)
-
     @staticmethod
     def fromdict(d):
         if d["format_version"] != ZARR_SCHEMA_FORMAT_VERSION:
@@ -1645,7 +1618,7 @@ def generate_partitions(num_records, chunk_size, num_partitions, max_chunks=None
 
 
 @dataclasses.dataclass
-class VcfZarrWriterMetadata:
+class VcfZarrWriterMetadata(core.JsonDataclass):
     format_version: str
     icf_path: str
     schema: VcfZarrSchema
@@ -1653,9 +1626,6 @@ class VcfZarrWriterMetadata:
     partitions: list
     provenance: dict
 
-    def asdict(self):
-        return dataclasses.asdict(self)
-
     @staticmethod
     def fromdict(d):
         if d["format_version"] != VZW_METADATA_FORMAT_VERSION:
@@ -1670,19 +1640,13 @@ def fromdict(d):
 
 
 @dataclasses.dataclass
-class VcfZarrWriteSummary:
+class VcfZarrWriteSummary(core.JsonDataclass):
     num_partitions: int
     num_samples: int
     num_variants: int
     num_chunks: int
     max_encoding_memory: str
 
-    def asdict(self):
-        return dataclasses.asdict(self)
-
-    def asjson(self):
-        return json.dumps(self.asdict(), indent=4)
-
 
 class VcfZarrWriter:
     def __init__(self, path):

From 0aef4e0f1aeba9a22da1d09f697610a24ce2e445 Mon Sep 17 00:00:00 2001
From: Jerome Kelleher
Date: Sat, 11 May 2024 23:10:37 +0100
Subject: [PATCH 2/4] Move verification code to own file.

---
 bio2zarr/constants.py      |  18 +++
 bio2zarr/vcf.py            | 286 +++----------------------------------
 bio2zarr/verification.py   | 231 ++++++++++++++++++++++++++++++
 tests/test_vcf_examples.py |  14 +-
 validation.py              |   4 +-
 5 files changed, 278 insertions(+), 275 deletions(-)
 create mode 100644 bio2zarr/constants.py
 create mode 100644 bio2zarr/verification.py

diff --git a/bio2zarr/constants.py b/bio2zarr/constants.py
new file mode 100644
index 00000000..23ca5395
--- /dev/null
+++ b/bio2zarr/constants.py
@@ -0,0 +1,18 @@
+import numpy as np
+
+INT_MISSING = -1
+INT_FILL = -2
+STR_MISSING = "."
+STR_FILL = ""
+
+FLOAT32_MISSING, FLOAT32_FILL = np.array([0x7F800001, 0x7F800002], dtype=np.int32).view(
+    np.float32
+)
+FLOAT32_MISSING_AS_INT32, FLOAT32_FILL_AS_INT32 = np.array(
+    [0x7F800001, 0x7F800002], dtype=np.int32
+)
+
+
+MIN_INT_VALUE = np.iinfo(np.int32).min + 2
+VCF_INT_MISSING = np.iinfo(np.int32).min
+VCF_INT_FILL = np.iinfo(np.int32).min + 1
diff --git a/bio2zarr/vcf.py b/bio2zarr/vcf.py
index 9ed3f451..52b6c065 100644
--- a/bio2zarr/vcf.py
+++ b/bio2zarr/vcf.py
@@ -13,30 +13,15 @@
 import tempfile
 from typing import Any
 
-import cyvcf2
 import humanfriendly
 import numcodecs
 import numpy as np
-import numpy.testing as nt
-import tqdm
 import zarr
 
-from . import core, provenance, vcf_utils
+from . import constants, core, provenance, vcf_utils
 
 logger = logging.getLogger(__name__)
 
-INT_MISSING = -1
-INT_FILL = -2
-STR_MISSING = "."
-STR_FILL = "" - -FLOAT32_MISSING, FLOAT32_FILL = np.array([0x7F800001, 0x7F800002], dtype=np.int32).view( - np.float32 -) -FLOAT32_MISSING_AS_INT32, FLOAT32_FILL_AS_INT32 = np.array( - [0x7F800001, 0x7F800002], dtype=np.int32 -) - def display_number(x): ret = "n/a" @@ -377,7 +362,7 @@ def sanitise_value_bool(buff, j, value): def sanitise_value_float_scalar(buff, j, value): x = value if value is None: - x = [FLOAT32_MISSING] + x = [constants.FLOAT32_MISSING] buff[j] = x[0] @@ -385,7 +370,7 @@ def sanitise_value_int_scalar(buff, j, value): x = value if value is None: # print("MISSING", INT_MISSING, INT_FILL) - x = [INT_MISSING] + x = [constants.INT_MISSING] else: x = sanitise_int_array(value, ndmin=1, dtype=np.int32) buff[j] = x[0] @@ -436,33 +421,35 @@ def drop_empty_second_dim(value): def sanitise_value_float_1d(buff, j, value): if value is None: - buff[j] = FLOAT32_MISSING + buff[j] = constants.FLOAT32_MISSING else: value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False) # numpy will map None values to Nan, but we need a # specific NaN - value[np.isnan(value)] = FLOAT32_MISSING + value[np.isnan(value)] = constants.FLOAT32_MISSING value = drop_empty_second_dim(value) - buff[j] = FLOAT32_FILL + buff[j] = constants.FLOAT32_FILL buff[j, : value.shape[0]] = value def sanitise_value_float_2d(buff, j, value): if value is None: - buff[j] = FLOAT32_MISSING + buff[j] = constants.FLOAT32_MISSING else: # print("value = ", value) value = np.array(value, ndmin=2, dtype=buff.dtype, copy=False) - buff[j] = FLOAT32_FILL + buff[j] = constants.FLOAT32_FILL buff[j, :, : value.shape[1]] = value def sanitise_int_array(value, ndmin, dtype): if isinstance(value, tuple): - value = [VCF_INT_MISSING if x is None else x for x in value] # NEEDS TEST + value = [ + constants.VCF_INT_MISSING if x is None else x for x in value + ] # NEEDS TEST value = np.array(value, ndmin=ndmin, copy=False) - value[value == VCF_INT_MISSING] = -1 - value[value == VCF_INT_FILL] = -2 + value[value == constants.VCF_INT_MISSING] = -1 + value[value == constants.VCF_INT_FILL] = -2 # TODO watch out for clipping here! 
return value.astype(dtype) @@ -486,15 +473,11 @@ def sanitise_value_int_2d(buff, j, value): buff[j, :, : value.shape[1]] = value -MIN_INT_VALUE = np.iinfo(np.int32).min + 2 -VCF_INT_MISSING = np.iinfo(np.int32).min -VCF_INT_FILL = np.iinfo(np.int32).min + 1 - missing_value_map = { - "Integer": -1, - "Float": FLOAT32_MISSING, - "String": ".", - "Character": ".", + "Integer": constants.INT_MISSING, + "Float": constants.FLOAT32_MISSING, + "String": constants.STR_MISSING, + "Character": constants.STR_MISSING, "Flag": False, } @@ -538,17 +521,12 @@ def transform_and_update_bounds(self, vcf_value): return value -MIN_INT_VALUE = np.iinfo(np.int32).min + 2 -VCF_INT_MISSING = np.iinfo(np.int32).min -VCF_INT_FILL = np.iinfo(np.int32).min + 1 - - class IntegerValueTransformer(VcfValueTransformer): def update_bounds(self, value): summary = self.field.summary # Mask out missing and fill values # print(value) - a = value[value >= MIN_INT_VALUE] + a = value[value >= constants.MIN_INT_VALUE] if a.size > 0: summary.max_value = int(max(summary.max_value, np.max(a))) summary.min_value = int(min(summary.min_value, np.min(a))) @@ -1935,7 +1913,7 @@ def encode_alleles_partition(self, partition_index): alt_col.iter_values(partition.start, partition.stop), ): j = alleles.next_buffer_row() - alleles.buff[j, :] = STR_FILL + alleles.buff[j, :] = constants.STR_FILL alleles.buff[j, 0] = ref[0] alleles.buff[j, 1 : 1 + len(alt)] = alt alleles.flush() @@ -1958,7 +1936,7 @@ def encode_id_partition(self, partition_index): vid.buff[j] = value[0] vid_mask.buff[j] = False else: - vid.buff[j] = STR_MISSING + vid.buff[j] = constants.STR_MISSING vid_mask.buff[j] = True vid.flush() vid_mask.flush() @@ -2253,227 +2231,3 @@ def convert( worker_processes=worker_processes, show_progress=show_progress, ) - - -def assert_all_missing_float(a): - v = np.array(a, dtype=np.float32).view(np.int32) - nt.assert_equal(v, FLOAT32_MISSING_AS_INT32) - - -def assert_all_fill_float(a): - v = np.array(a, dtype=np.float32).view(np.int32) - nt.assert_equal(v, FLOAT32_FILL_AS_INT32) - - -def assert_all_missing_int(a): - v = np.array(a, dtype=int) - nt.assert_equal(v, -1) - - -def assert_all_fill_int(a): - v = np.array(a, dtype=int) - nt.assert_equal(v, -2) - - -def assert_all_missing_string(a): - nt.assert_equal(a, ".") - - -def assert_all_fill_string(a): - nt.assert_equal(a, "") - - -def assert_all_fill(zarr_val, vcf_type): - if vcf_type == "Integer": - assert_all_fill_int(zarr_val) - elif vcf_type in ("String", "Character"): - assert_all_fill_string(zarr_val) - elif vcf_type == "Float": - assert_all_fill_float(zarr_val) - else: # pragma: no cover - assert False # noqa PT015 - - -def assert_all_missing(zarr_val, vcf_type): - if vcf_type == "Integer": - assert_all_missing_int(zarr_val) - elif vcf_type in ("String", "Character"): - assert_all_missing_string(zarr_val) - elif vcf_type == "Flag": - assert zarr_val == False # noqa 712 - elif vcf_type == "Float": - assert_all_missing_float(zarr_val) - else: # pragma: no cover - assert False # noqa PT015 - - -def assert_info_val_missing(zarr_val, vcf_type): - assert_all_missing(zarr_val, vcf_type) - - -def assert_format_val_missing(zarr_val, vcf_type): - assert_info_val_missing(zarr_val, vcf_type) - - -# Note: checking exact equality may prove problematic here -# but we should be deterministically storing what cyvcf2 -# provides, which should compare equal. 
- - -def assert_info_val_equal(vcf_val, zarr_val, vcf_type): - assert vcf_val is not None - if vcf_type in ("String", "Character"): - split = list(vcf_val.split(",")) - k = len(split) - if isinstance(zarr_val, str): - assert k == 1 - # Scalar - assert vcf_val == zarr_val - else: - nt.assert_equal(split, zarr_val[:k]) - assert_all_fill(zarr_val[k:], vcf_type) - - elif isinstance(vcf_val, tuple): - vcf_missing_value_map = { - "Integer": -1, - "Float": FLOAT32_MISSING, - } - v = [vcf_missing_value_map[vcf_type] if x is None else x for x in vcf_val] - missing = np.array([j for j, x in enumerate(vcf_val) if x is None], dtype=int) - a = np.array(v) - k = len(a) - # We are checking for int missing twice here, but it's necessary to have - # a separate check for floats because different NaNs compare equal - nt.assert_equal(a, zarr_val[:k]) - assert_all_missing(zarr_val[missing], vcf_type) - if k < len(zarr_val): - assert_all_fill(zarr_val[k:], vcf_type) - else: - # Scalar - zarr_val = np.array(zarr_val, ndmin=1) - assert len(zarr_val.shape) == 1 - assert vcf_val == zarr_val[0] - if len(zarr_val) > 1: - assert_all_fill(zarr_val[1:], vcf_type) - - -def assert_format_val_equal(vcf_val, zarr_val, vcf_type): - assert vcf_val is not None - assert isinstance(vcf_val, np.ndarray) - if vcf_type in ("String", "Character"): - assert len(vcf_val) == len(zarr_val) - for v, z in zip(vcf_val, zarr_val): - split = list(v.split(",")) - # Note: deliberately duplicating logic here between this and the - # INFO col above to make sure all combinations are covered by tests - k = len(split) - if k == 1: - assert v == z - else: - nt.assert_equal(split, z[:k]) - assert_all_fill(z[k:], vcf_type) - else: - assert vcf_val.shape[0] == zarr_val.shape[0] - if len(vcf_val.shape) == len(zarr_val.shape) + 1: - assert vcf_val.shape[-1] == 1 - vcf_val = vcf_val[..., 0] - assert len(vcf_val.shape) <= 2 - assert len(vcf_val.shape) == len(zarr_val.shape) - if len(vcf_val.shape) == 2: - k = vcf_val.shape[1] - if zarr_val.shape[1] != k: - assert_all_fill(zarr_val[:, k:], vcf_type) - zarr_val = zarr_val[:, :k] - assert vcf_val.shape == zarr_val.shape - if vcf_type == "Integer": - vcf_val[vcf_val == VCF_INT_MISSING] = INT_MISSING - vcf_val[vcf_val == VCF_INT_FILL] = INT_FILL - elif vcf_type == "Float": - nt.assert_equal(vcf_val.view(np.int32), zarr_val.view(np.int32)) - - nt.assert_equal(vcf_val, zarr_val) - - -# TODO rename to "verify" -def validate(vcf_path, zarr_path, show_progress=False): - store = zarr.DirectoryStore(zarr_path) - - root = zarr.group(store=store) - pos = root["variant_position"][:] - allele = root["variant_allele"][:] - chrom = root["contig_id"][:][root["variant_contig"][:]] - vid = root["variant_id"][:] - call_genotype = None - if "call_genotype" in root: - call_genotype = iter(root["call_genotype"]) - - vcf = cyvcf2.VCF(vcf_path) - format_headers = {} - info_headers = {} - for h in vcf.header_iter(): - if h["HeaderType"] == "FORMAT": - format_headers[h["ID"]] = h - if h["HeaderType"] == "INFO": - info_headers[h["ID"]] = h - - format_fields = {} - info_fields = {} - for colname in root.keys(): - if colname.startswith("call") and not colname.startswith("call_genotype"): - vcf_name = colname.split("_", 1)[1] - vcf_type = format_headers[vcf_name]["Type"] - format_fields[vcf_name] = vcf_type, iter(root[colname]) - if colname.startswith("variant"): - name = colname.split("_", 1)[1] - if name.isupper(): - vcf_type = info_headers[name]["Type"] - info_fields[name] = vcf_type, iter(root[colname]) - - first_pos = next(vcf).POS - 
start_index = np.searchsorted(pos, first_pos) - assert pos[start_index] == first_pos - vcf = cyvcf2.VCF(vcf_path) - if show_progress: - iterator = tqdm.tqdm(vcf, desc=" Verify", total=vcf.num_records) # NEEDS TEST - else: - iterator = vcf - for j, row in enumerate(iterator, start_index): - assert chrom[j] == row.CHROM - assert pos[j] == row.POS - assert vid[j] == ("." if row.ID is None else row.ID) - assert allele[j, 0] == row.REF - k = len(row.ALT) - nt.assert_array_equal(allele[j, 1 : k + 1], row.ALT) - assert np.all(allele[j, k + 1 :] == "") - # TODO FILTERS - - if call_genotype is None: - val = None - try: - val = row.format("GT") - except KeyError: - pass - assert val is None - else: - gt = row.genotype.array() - gt_zarr = next(call_genotype) - gt_vcf = gt[:, :-1] - # NOTE cyvcf2 remaps genotypes automatically - # into the same missing/pad encoding that sgkit uses. - nt.assert_array_equal(gt_zarr, gt_vcf) - - for name, (vcf_type, zarr_iter) in info_fields.items(): - vcf_val = row.INFO.get(name, None) - zarr_val = next(zarr_iter) - if vcf_val is None: - assert_info_val_missing(zarr_val, vcf_type) - else: - assert_info_val_equal(vcf_val, zarr_val, vcf_type) - - for name, (vcf_type, zarr_iter) in format_fields.items(): - vcf_val = row.format(name) - zarr_val = next(zarr_iter) - if vcf_val is None: - assert_format_val_missing(zarr_val, vcf_type) - else: - assert_format_val_equal(vcf_val, zarr_val, vcf_type) diff --git a/bio2zarr/verification.py b/bio2zarr/verification.py new file mode 100644 index 00000000..2656e594 --- /dev/null +++ b/bio2zarr/verification.py @@ -0,0 +1,231 @@ +import cyvcf2 +import numpy as np +import numpy.testing as nt +import tqdm +import zarr + +from . import constants + + +def assert_all_missing_float(a): + v = np.array(a, dtype=np.float32).view(np.int32) + nt.assert_equal(v, constants.FLOAT32_MISSING_AS_INT32) + + +def assert_all_fill_float(a): + v = np.array(a, dtype=np.float32).view(np.int32) + nt.assert_equal(v, constants.FLOAT32_FILL_AS_INT32) + + +def assert_all_missing_int(a): + v = np.array(a, dtype=int) + nt.assert_equal(v, constants.INT_MISSING) + + +def assert_all_fill_int(a): + v = np.array(a, dtype=int) + nt.assert_equal(v, constants.INT_FILL) + + +def assert_all_missing_string(a): + nt.assert_equal(a, constants.STR_MISSING) + + +def assert_all_fill_string(a): + nt.assert_equal(a, constants.STR_FILL) + + +def assert_all_fill(zarr_val, vcf_type): + if vcf_type == "Integer": + assert_all_fill_int(zarr_val) + elif vcf_type in ("String", "Character"): + assert_all_fill_string(zarr_val) + elif vcf_type == "Float": + assert_all_fill_float(zarr_val) + else: # pragma: no cover + assert False # noqa PT015 + + +def assert_all_missing(zarr_val, vcf_type): + if vcf_type == "Integer": + assert_all_missing_int(zarr_val) + elif vcf_type in ("String", "Character"): + assert_all_missing_string(zarr_val) + elif vcf_type == "Flag": + assert zarr_val == False # noqa 712 + elif vcf_type == "Float": + assert_all_missing_float(zarr_val) + else: # pragma: no cover + assert False # noqa PT015 + + +def assert_info_val_missing(zarr_val, vcf_type): + assert_all_missing(zarr_val, vcf_type) + + +def assert_format_val_missing(zarr_val, vcf_type): + assert_info_val_missing(zarr_val, vcf_type) + + +# Note: checking exact equality may prove problematic here +# but we should be deterministically storing what cyvcf2 +# provides, which should compare equal. 
+ + +def assert_info_val_equal(vcf_val, zarr_val, vcf_type): + assert vcf_val is not None + if vcf_type in ("String", "Character"): + split = list(vcf_val.split(",")) + k = len(split) + if isinstance(zarr_val, str): + assert k == 1 + # Scalar + assert vcf_val == zarr_val + else: + nt.assert_equal(split, zarr_val[:k]) + assert_all_fill(zarr_val[k:], vcf_type) + + elif isinstance(vcf_val, tuple): + vcf_missing_value_map = { + "Integer": constants.INT_MISSING, + "Float": constants.FLOAT32_MISSING, + } + v = [vcf_missing_value_map[vcf_type] if x is None else x for x in vcf_val] + missing = np.array([j for j, x in enumerate(vcf_val) if x is None], dtype=int) + a = np.array(v) + k = len(a) + # We are checking for int missing twice here, but it's necessary to have + # a separate check for floats because different NaNs compare equal + nt.assert_equal(a, zarr_val[:k]) + assert_all_missing(zarr_val[missing], vcf_type) + if k < len(zarr_val): + assert_all_fill(zarr_val[k:], vcf_type) + else: + # Scalar + zarr_val = np.array(zarr_val, ndmin=1) + assert len(zarr_val.shape) == 1 + assert vcf_val == zarr_val[0] + if len(zarr_val) > 1: + assert_all_fill(zarr_val[1:], vcf_type) + + +def assert_format_val_equal(vcf_val, zarr_val, vcf_type): + assert vcf_val is not None + assert isinstance(vcf_val, np.ndarray) + if vcf_type in ("String", "Character"): + assert len(vcf_val) == len(zarr_val) + for v, z in zip(vcf_val, zarr_val): + split = list(v.split(",")) + # Note: deliberately duplicating logic here between this and the + # INFO col above to make sure all combinations are covered by tests + k = len(split) + if k == 1: + assert v == z + else: + nt.assert_equal(split, z[:k]) + assert_all_fill(z[k:], vcf_type) + else: + assert vcf_val.shape[0] == zarr_val.shape[0] + if len(vcf_val.shape) == len(zarr_val.shape) + 1: + assert vcf_val.shape[-1] == 1 + vcf_val = vcf_val[..., 0] + assert len(vcf_val.shape) <= 2 + assert len(vcf_val.shape) == len(zarr_val.shape) + if len(vcf_val.shape) == 2: + k = vcf_val.shape[1] + if zarr_val.shape[1] != k: + assert_all_fill(zarr_val[:, k:], vcf_type) + zarr_val = zarr_val[:, :k] + assert vcf_val.shape == zarr_val.shape + if vcf_type == "Integer": + vcf_val[vcf_val == constants.VCF_INT_MISSING] = constants.INT_MISSING + vcf_val[vcf_val == constants.VCF_INT_FILL] = constants.INT_FILL + elif vcf_type == "Float": + nt.assert_equal(vcf_val.view(np.int32), zarr_val.view(np.int32)) + + nt.assert_equal(vcf_val, zarr_val) + + +# TODO rename to "verify" +def validate(vcf_path, zarr_path, show_progress=False): + store = zarr.DirectoryStore(zarr_path) + + root = zarr.group(store=store) + pos = root["variant_position"][:] + allele = root["variant_allele"][:] + chrom = root["contig_id"][:][root["variant_contig"][:]] + vid = root["variant_id"][:] + call_genotype = None + if "call_genotype" in root: + call_genotype = iter(root["call_genotype"]) + + vcf = cyvcf2.VCF(vcf_path) + format_headers = {} + info_headers = {} + for h in vcf.header_iter(): + if h["HeaderType"] == "FORMAT": + format_headers[h["ID"]] = h + if h["HeaderType"] == "INFO": + info_headers[h["ID"]] = h + + format_fields = {} + info_fields = {} + for colname in root.keys(): + if colname.startswith("call") and not colname.startswith("call_genotype"): + vcf_name = colname.split("_", 1)[1] + vcf_type = format_headers[vcf_name]["Type"] + format_fields[vcf_name] = vcf_type, iter(root[colname]) + if colname.startswith("variant"): + name = colname.split("_", 1)[1] + if name.isupper(): + vcf_type = info_headers[name]["Type"] + 
info_fields[name] = vcf_type, iter(root[colname]) + + first_pos = next(vcf).POS + start_index = np.searchsorted(pos, first_pos) + assert pos[start_index] == first_pos + vcf = cyvcf2.VCF(vcf_path) + if show_progress: + iterator = tqdm.tqdm(vcf, desc=" Verify", total=vcf.num_records) # NEEDS TEST + else: + iterator = vcf + for j, row in enumerate(iterator, start_index): + assert chrom[j] == row.CHROM + assert pos[j] == row.POS + assert vid[j] == ("." if row.ID is None else row.ID) + assert allele[j, 0] == row.REF + k = len(row.ALT) + nt.assert_array_equal(allele[j, 1 : k + 1], row.ALT) + assert np.all(allele[j, k + 1 :] == "") + # TODO FILTERS + + if call_genotype is None: + val = None + try: + val = row.format("GT") + except KeyError: + pass + assert val is None + else: + gt = row.genotype.array() + gt_zarr = next(call_genotype) + gt_vcf = gt[:, :-1] + # NOTE cyvcf2 remaps genotypes automatically + # into the same missing/pad encoding that sgkit uses. + nt.assert_array_equal(gt_zarr, gt_vcf) + + for name, (vcf_type, zarr_iter) in info_fields.items(): + vcf_val = row.INFO.get(name, None) + zarr_val = next(zarr_iter) + if vcf_val is None: + assert_info_val_missing(zarr_val, vcf_type) + else: + assert_info_val_equal(vcf_val, zarr_val, vcf_type) + + for name, (vcf_type, zarr_iter) in format_fields.items(): + vcf_val = row.format(name) + zarr_val = next(zarr_iter) + if vcf_val is None: + assert_format_val_missing(zarr_val, vcf_type) + else: + assert_format_val_equal(vcf_val, zarr_val, vcf_type) diff --git a/tests/test_vcf_examples.py b/tests/test_vcf_examples.py index e6a73121..27d22468 100644 --- a/tests/test_vcf_examples.py +++ b/tests/test_vcf_examples.py @@ -8,7 +8,7 @@ import sgkit as sg import xarray.testing as xt -from bio2zarr import provenance, vcf +from bio2zarr import constants, provenance, vcf, verification class TestSmallExample: @@ -81,8 +81,8 @@ def test_int_info_fields(self, ds): ) def test_float_info_fields(self, ds): - missing = vcf.FLOAT32_MISSING - fill = vcf.FLOAT32_FILL + missing = constants.FLOAT32_MISSING + fill = constants.FLOAT32_FILL variant_AF = np.array( [ [missing, missing], @@ -137,7 +137,7 @@ def test_flag_info_fields(self, ds): ) def test_allele(self, ds): - fill = vcf.STR_FILL + fill = constants.STR_FILL nt.assert_array_equal( ds["variant_allele"].values.tolist(), [ @@ -844,7 +844,7 @@ def test_by_validating(name, tmp_path): path = f"tests/data/vcf/{name}" out = tmp_path / "test.zarr" vcf.convert([path], out, worker_processes=0) - vcf.validate(path, out) + verification.validate(path, out) @pytest.mark.parametrize( @@ -862,7 +862,7 @@ def test_by_validating_split(source, suffix, files, tmp_path): split_files = [f"{source_path}.{suffix}/{f}" for f in files] out = tmp_path / "test.zarr" vcf.convert(split_files, out, worker_processes=0) - vcf.validate(source_path, out) + verification.validate(source_path, out) def test_split_explode(tmp_path): @@ -891,7 +891,7 @@ def test_split_explode(tmp_path): "min_value": 10, } vcf.encode(out, tmp_path / "test.zarr") - vcf.validate("tests/data/vcf/sample.vcf.gz", tmp_path / "test.zarr") + verification.validate("tests/data/vcf/sample.vcf.gz", tmp_path / "test.zarr") def test_missing_filter(tmp_path): diff --git a/validation.py b/validation.py index 19baac46..05a2b578 100644 --- a/validation.py +++ b/validation.py @@ -6,7 +6,7 @@ import click -from bio2zarr import vcf +from bio2zarr import vcf, verification # TODO add support here for split vcfs. 
Perhaps simplest to take a # directory provided as input as indicating this, and then having @@ -67,7 +67,7 @@ def cli(vcfs, worker_processes, force): show_progress=True, ) - vcf.validate(source_file, zarr, show_progress=True) + verification.validate(source_file, zarr, show_progress=True) if __name__ == "__main__": From d08b24d4b279bce5e30993cffbcf7eba2a433602 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Sat, 11 May 2024 23:22:43 +0100 Subject: [PATCH 3/4] Move display functions into core --- bio2zarr/core.py | 13 +++++++++++++ bio2zarr/vcf.py | 41 ++++++++++++++++------------------------- 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/bio2zarr/core.py b/bio2zarr/core.py index 60c25375..b94e7eff 100644 --- a/bio2zarr/core.py +++ b/bio2zarr/core.py @@ -3,12 +3,14 @@ import dataclasses import json import logging +import math import multiprocessing import os import os.path import threading import time +import humanfriendly import numcodecs import numpy as np import tqdm @@ -19,6 +21,17 @@ numcodecs.blosc.use_threads = False +def display_number(x): + ret = "n/a" + if math.isfinite(x): + ret = f"{x: 0.2g}" + return ret + + +def display_size(n): + return humanfriendly.format_size(n, binary=True) + + def min_int_dtype(min_value, max_value): if min_value > max_value: raise ValueError("min_value must be <= max_value") diff --git a/bio2zarr/vcf.py b/bio2zarr/vcf.py index 52b6c065..ea763b3d 100644 --- a/bio2zarr/vcf.py +++ b/bio2zarr/vcf.py @@ -23,17 +23,6 @@ logger = logging.getLogger(__name__) -def display_number(x): - ret = "n/a" - if math.isfinite(x): - ret = f"{x: 0.2g}" - return ret - - -def display_size(n): - return humanfriendly.format_size(n, binary=True) - - @dataclasses.dataclass class VcfFieldSummary(core.JsonDataclass): num_chunks: int = 0 @@ -874,11 +863,11 @@ def summary_table(self): "name": name, "type": col.vcf_field.vcf_type, "chunks": summary.num_chunks, - "size": display_size(summary.uncompressed_size), - "compressed": display_size(summary.compressed_size), + "size": core.display_size(summary.uncompressed_size), + "compressed": core.display_size(summary.compressed_size), "max_n": summary.max_number, - "min_val": display_number(summary.min_value), - "max_val": display_number(summary.max_value), + "min_val": core.display_number(summary.min_value), + "max_val": core.display_number(summary.max_value), } data.append(d) @@ -1546,12 +1535,12 @@ def summary_table(self): d = { "name": array.name, "dtype": str(array.dtype), - "stored": display_size(stored), - "size": display_size(array.nbytes), - "ratio": display_number(array.nbytes / stored), + "stored": core.display_size(stored), + "size": core.display_size(array.nbytes), + "ratio": core.display_number(array.nbytes / stored), "nchunks": str(array.nchunks), - "chunk_size": display_size(array.nbytes / array.nchunks), - "avg_chunk_stored": display_size(int(stored / array.nchunks)), + "chunk_size": core.display_size(array.nbytes / array.nchunks), + "avg_chunk_stored": core.display_size(int(stored / array.nchunks)), "shape": str(array.shape), "chunk_shape": str(array.chunks), "compressor": str(array.compressor), @@ -1567,7 +1556,7 @@ def parse_max_memory(max_memory): return 2**63 if isinstance(max_memory, str): max_memory = humanfriendly.parse_size(max_memory) - logger.info(f"Set memory budget to {display_size(max_memory)}") + logger.info(f"Set memory budget to {core.display_size(max_memory)}") return max_memory @@ -1721,7 +1710,7 @@ def init( num_samples=self.icf.num_samples, num_partitions=self.num_partitions, 
             num_chunks=total_chunks,
-            max_encoding_memory=display_size(self.get_max_encoding_memory()),
+            max_encoding_memory=core.display_size(self.get_max_encoding_memory()),
         )
 
     def encode_samples(self, root):
@@ -2082,7 +2071,7 @@ def encode_all_partitions(
         per_worker_memory = self.get_max_encoding_memory()
         logger.info(
             f"Encoding Zarr over {num_partitions} partitions with "
-            f"{worker_processes} workers and {display_size(per_worker_memory)} "
+            f"{worker_processes} workers and {core.display_size(per_worker_memory)} "
            "per worker"
         )
         # Each partition requires per_worker_memory bytes, so to prevent more that
@@ -2091,12 +2080,14 @@
         if max_num_workers < worker_processes:
             logger.warning(
                 f"Limiting number of workers to {max_num_workers} to "
-                f"keep within specified memory budget of {display_size(max_memory)}"
+                "keep within specified memory budget of "
+                f"{core.display_size(max_memory)}"
             )
         if max_num_workers <= 0:
             raise ValueError(
                 f"Insufficient memory to encode a partition:"
-                f"{display_size(per_worker_memory)} > {display_size(max_memory)}"
+                f"{core.display_size(per_worker_memory)} > "
+                f"{core.display_size(max_memory)}"
             )
         num_workers = min(max_num_workers, worker_processes)
 

From 49e75c713786f05d1696e7cc7d2dcb37e5f89dfb Mon Sep 17 00:00:00 2001
From: Jerome Kelleher
Date: Sat, 11 May 2024 23:45:54 +0100
Subject: [PATCH 4/4] Move ICF code into icf.py

Not quite a clean break between modules, but not bad and a good step
in the right direction.
---
 bio2zarr/cli.py            |   12 +-
 bio2zarr/icf.py            | 1220 +++++++++++++++++++++++++++++++++++
 bio2zarr/vcf.py            | 1236 +-----------------------------------
 tests/test_cli.py          |   34 +-
 tests/test_icf.py          |  111 ++--
 tests/test_vcf.py          |   49 +-
 tests/test_vcf_examples.py |   19 +-
 validation.py              |    4 +-
 8 files changed, 1350 insertions(+), 1335 deletions(-)
 create mode 100644 bio2zarr/icf.py

diff --git a/bio2zarr/cli.py b/bio2zarr/cli.py
index 2dd04c97..40c257e5 100644
--- a/bio2zarr/cli.py
+++ b/bio2zarr/cli.py
@@ -8,7 +8,7 @@
 import numcodecs
 import tabulate
 
-from . import plink, provenance, vcf, vcf_utils
+from . import icf, plink, provenance, vcf, vcf_utils
 
 logger = logging.getLogger(__name__)
 
@@ -167,7 +167,7 @@ def check_overwrite_dir(path, force):
 def get_compressor(cname):
     if cname is None:
         return None
-    config = vcf.ICF_DEFAULT_COMPRESSOR.get_config()
+    config = icf.ICF_DEFAULT_COMPRESSOR.get_config()
     config["cname"] = cname
     return numcodecs.get_codec(config)
 
@@ -198,7 +198,7 @@ def explode(
     """
     setup_logging(verbose)
     check_overwrite_dir(icf_path, force)
-    vcf.explode(
+    icf.explode(
         icf_path,
         vcfs,
         worker_processes=worker_processes,
@@ -235,7 +235,7 @@ def dexplode_init(
     """
     setup_logging(verbose)
     check_overwrite_dir(icf_path, force)
-    work_summary = vcf.explode_init(
+    work_summary = icf.explode_init(
         icf_path,
         vcfs,
         target_num_partitions=num_partitions,
@@ -263,7 +263,7 @@ def dexplode_partition(icf_path, partition, verbose, one_based):
     setup_logging(verbose)
     if one_based:
         partition -= 1
-    vcf.explode_partition(icf_path, partition)
+    icf.explode_partition(icf_path, partition)
 
 
 @click.command
@@ -274,7 +274,7 @@ def dexplode_finalise(icf_path, verbose):
     Final step for distributed conversion of VCF(s) to intermediate columnar format.
""" setup_logging(verbose) - vcf.explode_finalise(icf_path) + icf.explode_finalise(icf_path) @click.command diff --git a/bio2zarr/icf.py b/bio2zarr/icf.py new file mode 100644 index 00000000..57d6d5e5 --- /dev/null +++ b/bio2zarr/icf.py @@ -0,0 +1,1220 @@ +import collections +import contextlib +import dataclasses +import json +import logging +import math +import pathlib +import pickle +import shutil +import sys +from typing import Any + +import numcodecs +import numpy as np + +from . import constants, core, provenance, vcf_utils + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class VcfFieldSummary(core.JsonDataclass): + num_chunks: int = 0 + compressed_size: int = 0 + uncompressed_size: int = 0 + max_number: int = 0 # Corresponds to VCF Number field, depends on context + # Only defined for numeric fields + max_value: Any = -math.inf + min_value: Any = math.inf + + def update(self, other): + self.num_chunks += other.num_chunks + self.compressed_size += other.compressed_size + self.uncompressed_size += other.uncompressed_size + self.max_number = max(self.max_number, other.max_number) + self.min_value = min(self.min_value, other.min_value) + self.max_value = max(self.max_value, other.max_value) + + @staticmethod + def fromdict(d): + return VcfFieldSummary(**d) + + +@dataclasses.dataclass +class VcfField: + category: str + name: str + vcf_number: str + vcf_type: str + description: str + summary: VcfFieldSummary + + @staticmethod + def from_header(definition): + category = definition["HeaderType"] + name = definition["ID"] + vcf_number = definition["Number"] + vcf_type = definition["Type"] + return VcfField( + category=category, + name=name, + vcf_number=vcf_number, + vcf_type=vcf_type, + description=definition["Description"].strip('"'), + summary=VcfFieldSummary(), + ) + + @staticmethod + def fromdict(d): + f = VcfField(**d) + f.summary = VcfFieldSummary(**d["summary"]) + return f + + @property + def full_name(self): + if self.category == "fixed": + return self.name + return f"{self.category}/{self.name}" + + def smallest_dtype(self): + """ + Returns the smallest dtype suitable for this field based + on type, and values. + """ + s = self.summary + if self.vcf_type == "Float": + ret = "f4" + elif self.vcf_type == "Integer": + if not math.isfinite(s.max_value): + # All missing values; use i1. 
Note we should have some API to + # check more explicitly for missingness: + # https://github.com/sgkit-dev/bio2zarr/issues/131 + ret = "i1" + else: + ret = core.min_int_dtype(s.min_value, s.max_value) + elif self.vcf_type == "Flag": + ret = "bool" + elif self.vcf_type == "Character": + ret = "U1" + else: + assert self.vcf_type == "String" + ret = "O" + return ret + + +@dataclasses.dataclass +class VcfPartition: + vcf_path: str + region: str + num_records: int = -1 + + +ICF_METADATA_FORMAT_VERSION = "0.3" +ICF_DEFAULT_COMPRESSOR = numcodecs.Blosc( + cname="zstd", clevel=7, shuffle=numcodecs.Blosc.NOSHUFFLE +) + + +@dataclasses.dataclass +class Contig: + id: str + length: int = None + + +@dataclasses.dataclass +class Sample: + id: str + + +@dataclasses.dataclass +class Filter: + id: str + description: str = "" + + +@dataclasses.dataclass +class IcfMetadata(core.JsonDataclass): + samples: list + contigs: list + filters: list + fields: list + partitions: list = None + format_version: str = None + compressor: dict = None + column_chunk_size: int = None + provenance: dict = None + num_records: int = -1 + + @property + def info_fields(self): + fields = [] + for field in self.fields: + if field.category == "INFO": + fields.append(field) + return fields + + @property + def format_fields(self): + fields = [] + for field in self.fields: + if field.category == "FORMAT": + fields.append(field) + return fields + + @property + def num_contigs(self): + return len(self.contigs) + + @property + def num_filters(self): + return len(self.filters) + + @property + def num_samples(self): + return len(self.samples) + + @staticmethod + def fromdict(d): + if d["format_version"] != ICF_METADATA_FORMAT_VERSION: + raise ValueError( + "Intermediate columnar metadata format version mismatch: " + f"{d['format_version']} != {ICF_METADATA_FORMAT_VERSION}" + ) + partitions = [VcfPartition(**pd) for pd in d["partitions"]] + for p in partitions: + p.region = vcf_utils.Region(**p.region) + d = d.copy() + d["partitions"] = partitions + d["fields"] = [VcfField.fromdict(fd) for fd in d["fields"]] + d["samples"] = [Sample(**sd) for sd in d["samples"]] + d["filters"] = [Filter(**fd) for fd in d["filters"]] + d["contigs"] = [Contig(**cd) for cd in d["contigs"]] + return IcfMetadata(**d) + + +def fixed_vcf_field_definitions(): + def make_field_def(name, vcf_type, vcf_number): + return VcfField( + category="fixed", + name=name, + vcf_type=vcf_type, + vcf_number=vcf_number, + description="", + summary=VcfFieldSummary(), + ) + + fields = [ + make_field_def("CHROM", "String", "1"), + make_field_def("POS", "Integer", "1"), + make_field_def("QUAL", "Float", "1"), + make_field_def("ID", "String", "."), + make_field_def("FILTERS", "String", "."), + make_field_def("REF", "String", "1"), + make_field_def("ALT", "String", "."), + ] + return fields + + +def scan_vcf(path, target_num_partitions): + with vcf_utils.IndexedVcf(path) as indexed_vcf: + vcf = indexed_vcf.vcf + filters = [] + pass_index = -1 + for h in vcf.header_iter(): + if h["HeaderType"] == "FILTER" and isinstance(h["ID"], str): + try: + description = h["Description"].strip('"') + except KeyError: + description = "" + if h["ID"] == "PASS": + pass_index = len(filters) + filters.append(Filter(h["ID"], description)) + + # Ensure PASS is the first filter if present + if pass_index > 0: + pass_filter = filters.pop(pass_index) + filters.insert(0, pass_filter) + + fields = fixed_vcf_field_definitions() + for h in vcf.header_iter(): + if h["HeaderType"] in ["INFO", "FORMAT"]: + field = 
VcfField.from_header(h) + if field.name == "GT": + field.vcf_type = "Integer" + field.vcf_number = "." + fields.append(field) + + try: + contig_lengths = vcf.seqlens + except AttributeError: + contig_lengths = [None for _ in vcf.seqnames] + + metadata = IcfMetadata( + samples=[Sample(sample_id) for sample_id in vcf.samples], + contigs=[ + Contig(contig_id, length) + for contig_id, length in zip(vcf.seqnames, contig_lengths) + ], + filters=filters, + fields=fields, + partitions=[], + num_records=sum(indexed_vcf.contig_record_counts().values()), + ) + + regions = indexed_vcf.partition_into_regions(num_parts=target_num_partitions) + logger.info( + f"Split {path} into {len(regions)} regions (target={target_num_partitions})" + ) + for region in regions: + metadata.partitions.append( + VcfPartition( + # TODO should this be fully resolving the path? Otherwise it's all + # relative to the original WD + vcf_path=str(path), + region=region, + ) + ) + core.update_progress(1) + return metadata, vcf.raw_header + + +def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1): + logger.info( + f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}" + f" partitions." + ) + # An easy mistake to make is to pass the same file twice. Check this early on. + for path, count in collections.Counter(paths).items(): + if not path.exists(): # NEEDS TEST + raise FileNotFoundError(path) + if count > 1: + raise ValueError(f"Duplicate path provided: {path}") + + progress_config = core.ProgressConfig( + total=len(paths), + units="files", + title="Scan", + show=show_progress, + ) + with core.ParallelWorkManager(worker_processes, progress_config) as pwm: + for path in paths: + pwm.submit(scan_vcf, path, max(1, target_num_partitions // len(paths))) + results = list(pwm.results_as_completed()) + + # Sort to make the ordering deterministic + results.sort(key=lambda t: t[0].partitions[0].vcf_path) + # We just take the first header, assuming the others + # are compatible. 
+ all_partitions = [] + total_records = 0 + for metadata, _ in results: + for partition in metadata.partitions: + logger.debug(f"Scanned partition {partition}") + all_partitions.append(partition) + total_records += metadata.num_records + metadata.num_records = 0 + metadata.partitions = [] + + icf_metadata, header = results[0] + for metadata, _ in results[1:]: + if metadata != icf_metadata: + raise ValueError("Incompatible VCF chunks") + + # Note: this will be infinity here if any of the chunks has an index + # that doesn't keep track of the number of records per-contig + icf_metadata.num_records = total_records + + # Sort by contig (in the order they appear in the header) first, + # then by start coordinate + contig_index_map = {contig.id: j for j, contig in enumerate(metadata.contigs)} + all_partitions.sort( + key=lambda x: (contig_index_map[x.region.contig], x.region.start) + ) + icf_metadata.partitions = all_partitions + logger.info(f"Scan complete, resulting in {len(all_partitions)} partitions.") + return icf_metadata, header + + +def sanitise_value_bool(buff, j, value): + x = True + if value is None: + x = False + buff[j] = x + + +def sanitise_value_float_scalar(buff, j, value): + x = value + if value is None: + x = [constants.FLOAT32_MISSING] + buff[j] = x[0] + + +def sanitise_value_int_scalar(buff, j, value): + x = value + if value is None: + # print("MISSING", INT_MISSING, INT_FILL) + x = [constants.INT_MISSING] + else: + x = sanitise_int_array(value, ndmin=1, dtype=np.int32) + buff[j] = x[0] + + +def sanitise_value_string_scalar(buff, j, value): + if value is None: + buff[j] = "." + else: + buff[j] = value[0] + + +def sanitise_value_string_1d(buff, j, value): + if value is None: + buff[j] = "." + else: + # value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False) + # FIXME failure isn't coming from here, it seems to be from an + # incorrectly detected dimension in the zarr array + # The dimesions look all wrong, and the dtype should be Object + # not str + value = drop_empty_second_dim(value) + buff[j] = "" + buff[j, : value.shape[0]] = value + + +def sanitise_value_string_2d(buff, j, value): + if value is None: + buff[j] = "." 
+ else: + # print(buff.shape, value.dtype, value) + # assert value.ndim == 2 + buff[j] = "" + if value.ndim == 2: + buff[j, :, : value.shape[1]] = value + else: + # TODO check if this is still necessary + for k, val in enumerate(value): + buff[j, k, : len(val)] = val + + +def drop_empty_second_dim(value): + assert len(value.shape) == 1 or value.shape[1] == 1 + if len(value.shape) == 2 and value.shape[1] == 1: + value = value[..., 0] + return value + + +def sanitise_value_float_1d(buff, j, value): + if value is None: + buff[j] = constants.FLOAT32_MISSING + else: + value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False) + # numpy will map None values to Nan, but we need a + # specific NaN + value[np.isnan(value)] = constants.FLOAT32_MISSING + value = drop_empty_second_dim(value) + buff[j] = constants.FLOAT32_FILL + buff[j, : value.shape[0]] = value + + +def sanitise_value_float_2d(buff, j, value): + if value is None: + buff[j] = constants.FLOAT32_MISSING + else: + # print("value = ", value) + value = np.array(value, ndmin=2, dtype=buff.dtype, copy=False) + buff[j] = constants.FLOAT32_FILL + buff[j, :, : value.shape[1]] = value + + +def sanitise_int_array(value, ndmin, dtype): + if isinstance(value, tuple): + value = [ + constants.VCF_INT_MISSING if x is None else x for x in value + ] # NEEDS TEST + value = np.array(value, ndmin=ndmin, copy=False) + value[value == constants.VCF_INT_MISSING] = -1 + value[value == constants.VCF_INT_FILL] = -2 + # TODO watch out for clipping here! + return value.astype(dtype) + + +def sanitise_value_int_1d(buff, j, value): + if value is None: + buff[j] = -1 + else: + value = sanitise_int_array(value, 1, buff.dtype) + value = drop_empty_second_dim(value) + buff[j] = -2 + buff[j, : value.shape[0]] = value + + +def sanitise_value_int_2d(buff, j, value): + if value is None: + buff[j] = -1 + else: + value = sanitise_int_array(value, 2, buff.dtype) + buff[j] = -2 + buff[j, :, : value.shape[1]] = value + + +missing_value_map = { + "Integer": constants.INT_MISSING, + "Float": constants.FLOAT32_MISSING, + "String": constants.STR_MISSING, + "Character": constants.STR_MISSING, + "Flag": False, +} + + +class VcfValueTransformer: + """ + Transform VCF values into the stored intermediate format used + in the IntermediateColumnarFormat, and update field summaries. 
+ """ + + def __init__(self, field, num_samples): + self.field = field + self.num_samples = num_samples + self.dimension = 1 + if field.category == "FORMAT": + self.dimension = 2 + self.missing = missing_value_map[field.vcf_type] + + @staticmethod + def factory(field, num_samples): + if field.vcf_type in ("Integer", "Flag"): + return IntegerValueTransformer(field, num_samples) + if field.vcf_type == "Float": + return FloatValueTransformer(field, num_samples) + if field.name in ["REF", "FILTERS", "ALT", "ID", "CHROM"]: + return SplitStringValueTransformer(field, num_samples) + return StringValueTransformer(field, num_samples) + + def transform(self, vcf_value): + if isinstance(vcf_value, tuple): + vcf_value = [self.missing if v is None else v for v in vcf_value] + value = np.array(vcf_value, ndmin=self.dimension, copy=False) + return value + + def transform_and_update_bounds(self, vcf_value): + if vcf_value is None: + return None + value = self.transform(vcf_value) + self.update_bounds(value) + # print(self.field.full_name, "T", vcf_value, "->", value) + return value + + +class IntegerValueTransformer(VcfValueTransformer): + def update_bounds(self, value): + summary = self.field.summary + # Mask out missing and fill values + # print(value) + a = value[value >= constants.MIN_INT_VALUE] + if a.size > 0: + summary.max_value = int(max(summary.max_value, np.max(a))) + summary.min_value = int(min(summary.min_value, np.min(a))) + number = value.shape[-1] + summary.max_number = max(summary.max_number, number) + + +class FloatValueTransformer(VcfValueTransformer): + def update_bounds(self, value): + summary = self.field.summary + summary.max_value = float(max(summary.max_value, np.max(value))) + summary.min_value = float(min(summary.min_value, np.min(value))) + number = value.shape[-1] + summary.max_number = max(summary.max_number, number) + + +class StringValueTransformer(VcfValueTransformer): + def update_bounds(self, value): + summary = self.field.summary + number = value.shape[-1] + # TODO would be nice to report string lengths, but not + # really necessary. + summary.max_number = max(summary.max_number, number) + + def transform(self, vcf_value): + # print("transform", vcf_value) + if self.dimension == 1: + value = np.array(list(vcf_value.split(","))) + else: + # TODO can we make this faster?? 
+ value = np.array([v.split(",") for v in vcf_value], dtype="O") + # print("HERE", vcf_value, value) + # for v in vcf_value: + # print("\t", type(v), len(v), v.split(",")) + # print("S: ", self.dimension, ":", value.shape, value) + return value + + +class SplitStringValueTransformer(StringValueTransformer): + def transform(self, vcf_value): + if vcf_value is None: + return self.missing_value # NEEDS TEST + assert self.dimension == 1 + return np.array(vcf_value, ndmin=1, dtype="str") + + +def get_vcf_field_path(base_path, vcf_field): + if vcf_field.category == "fixed": + return base_path / vcf_field.name + return base_path / vcf_field.category / vcf_field.name + + +class IntermediateColumnarFormatField: + def __init__(self, icf, vcf_field): + self.vcf_field = vcf_field + self.path = get_vcf_field_path(icf.path, vcf_field) + self.compressor = icf.compressor + self.num_partitions = icf.num_partitions + self.num_records = icf.num_records + self.partition_record_index = icf.partition_record_index + # A map of partition id to the cumulative number of records + # in chunks within that partition + self._chunk_record_index = {} + + @property + def name(self): + return self.vcf_field.full_name + + def partition_path(self, partition_id): + return self.path / f"p{partition_id}" + + def __repr__(self): + partition_chunks = [self.num_chunks(j) for j in range(self.num_partitions)] + return ( + f"IntermediateColumnarFormatField(name={self.name}, " + f"partition_chunks={partition_chunks}, " + f"path={self.path})" + ) + + def num_chunks(self, partition_id): + return len(self.chunk_record_index(partition_id)) - 1 + + def chunk_record_index(self, partition_id): + if partition_id not in self._chunk_record_index: + index_path = self.partition_path(partition_id) / "chunk_index" + with open(index_path, "rb") as f: + a = pickle.load(f) + assert len(a) > 1 + assert a[0] == 0 + self._chunk_record_index[partition_id] = a + return self._chunk_record_index[partition_id] + + def read_chunk(self, path): + with open(path, "rb") as f: + pkl = self.compressor.decode(f.read()) + return pickle.loads(pkl) + + def chunk_num_records(self, partition_id): + return np.diff(self.chunk_record_index(partition_id)) + + def chunks(self, partition_id, start_chunk=0): + partition_path = self.partition_path(partition_id) + chunk_cumulative_records = self.chunk_record_index(partition_id) + chunk_num_records = np.diff(chunk_cumulative_records) + for count, cumulative in zip( + chunk_num_records[start_chunk:], chunk_cumulative_records[start_chunk + 1 :] + ): + path = partition_path / f"{cumulative}" + chunk = self.read_chunk(path) + if len(chunk) != count: + raise ValueError(f"Corruption detected in chunk: {path}") + yield chunk + + def iter_values(self, start=None, stop=None): + start = 0 if start is None else start + stop = self.num_records if stop is None else stop + start_partition = ( + np.searchsorted(self.partition_record_index, start, side="right") - 1 + ) + offset = self.partition_record_index[start_partition] + assert offset <= start + chunk_offset = start - offset + + chunk_record_index = self.chunk_record_index(start_partition) + start_chunk = ( + np.searchsorted(chunk_record_index, chunk_offset, side="right") - 1 + ) + record_id = offset + chunk_record_index[start_chunk] + assert record_id <= start + logger.debug( + f"Read {self.vcf_field.full_name} slice [{start}:{stop}]:" + f"p_start={start_partition}, c_start={start_chunk}, r_start={record_id}" + ) + for chunk in self.chunks(start_partition, start_chunk): + for record in 
chunk:
+                if record_id == stop:
+                    return
+                if record_id >= start:
+                    yield record
+                record_id += 1
+        assert record_id > start
+        for partition_id in range(start_partition + 1, self.num_partitions):
+            for chunk in self.chunks(partition_id):
+                for record in chunk:
+                    if record_id == stop:
+                        return
+                    yield record
+                    record_id += 1
+
+    # Note: this involves some computation, so it should arguably be a method,
+    # but we make it a property for consistency with xarray etc.
+    @property
+    def values(self):
+        ret = [None] * self.num_records
+        j = 0
+        for partition_id in range(self.num_partitions):
+            for chunk in self.chunks(partition_id):
+                for record in chunk:
+                    ret[j] = record
+                    j += 1
+        assert j == self.num_records
+        return ret
+
+    def sanitiser_factory(self, shape):
+        """
+        Return a function that sanitises values from this column
+        and writes them into a buffer of the specified shape.
+        """
+        assert len(shape) <= 3
+        if self.vcf_field.vcf_type == "Flag":
+            assert len(shape) == 1
+            return sanitise_value_bool
+        elif self.vcf_field.vcf_type == "Float":
+            if len(shape) == 1:
+                return sanitise_value_float_scalar
+            elif len(shape) == 2:
+                return sanitise_value_float_1d
+            else:
+                return sanitise_value_float_2d
+        elif self.vcf_field.vcf_type == "Integer":
+            if len(shape) == 1:
+                return sanitise_value_int_scalar
+            elif len(shape) == 2:
+                return sanitise_value_int_1d
+            else:
+                return sanitise_value_int_2d
+        else:
+            assert self.vcf_field.vcf_type in ("String", "Character")
+            if len(shape) == 1:
+                return sanitise_value_string_scalar
+            elif len(shape) == 2:
+                return sanitise_value_string_1d
+            else:
+                return sanitise_value_string_2d
+
+
+@dataclasses.dataclass
+class IcfFieldWriter:
+    vcf_field: VcfField
+    path: pathlib.Path
+    transformer: VcfValueTransformer
+    compressor: Any
+    max_buffered_bytes: int
+    buff: list[Any] = dataclasses.field(default_factory=list)
+    buffered_bytes: int = 0
+    chunk_index: list[int] = dataclasses.field(default_factory=lambda: [0])
+    num_records: int = 0
+
+    def append(self, val):
+        val = self.transformer.transform_and_update_bounds(val)
+        assert val is None or isinstance(val, np.ndarray)
+        self.buff.append(val)
+        val_bytes = sys.getsizeof(val)
+        self.buffered_bytes += val_bytes
+        self.num_records += 1
+        if self.buffered_bytes >= self.max_buffered_bytes:
+            logger.debug(
+                f"Flush {self.path} buffered={self.buffered_bytes} "
+                f"max={self.max_buffered_bytes}"
+            )
+            self.write_chunk()
+            self.buff.clear()
+            self.buffered_bytes = 0
+
+    def write_chunk(self):
+        # Update index
+        self.chunk_index.append(self.num_records)
+        path = self.path / f"{self.num_records}"
+        logger.debug(f"Start write: {path}")
+        pkl = pickle.dumps(self.buff)
+        compressed = self.compressor.encode(pkl)
+        with open(path, "wb") as f:
+            f.write(compressed)
+
+        # Update the summary
+        self.vcf_field.summary.num_chunks += 1
+        self.vcf_field.summary.compressed_size += len(compressed)
+        self.vcf_field.summary.uncompressed_size += self.buffered_bytes
+        logger.debug(f"Finish write: {path}")
+
+    def flush(self):
+        logger.debug(
+            f"Flush {self.path} records={len(self.buff)} buffered={self.buffered_bytes}"
+        )
+        if len(self.buff) > 0:
+            self.write_chunk()
+        with open(self.path / "chunk_index", "wb") as f:
+            a = np.array(self.chunk_index, dtype=int)
+            pickle.dump(a, f)
+
+
+class IcfPartitionWriter(contextlib.AbstractContextManager):
+    """
+    Writes the data for an IntermediateColumnarFormat partition.
+    """
+
+    def __init__(
+        self,
+        icf_metadata,
+        out_path,
+        partition_index,
+    ):
+        self.partition_index = partition_index
+        # chunk_size is in megabytes
+        max_buffered_bytes = icf_metadata.column_chunk_size * 2**20
+        assert max_buffered_bytes > 0
+        compressor = numcodecs.get_codec(icf_metadata.compressor)
+
+        self.field_writers = {}
+        num_samples = len(icf_metadata.samples)
+        for vcf_field in icf_metadata.fields:
+            field_path = get_vcf_field_path(out_path, vcf_field)
+            field_partition_path = field_path / f"p{partition_index}"
+            # Should be robust to running explode_partition twice.
+            field_partition_path.mkdir(exist_ok=True)
+            transformer = VcfValueTransformer.factory(vcf_field, num_samples)
+            self.field_writers[vcf_field.full_name] = IcfFieldWriter(
+                vcf_field,
+                field_partition_path,
+                transformer,
+                compressor,
+                max_buffered_bytes,
+            )
+
+    @property
+    def field_summaries(self):
+        return {
+            name: field.vcf_field.summary for name, field in self.field_writers.items()
+        }
+
+    def append(self, name, value):
+        self.field_writers[name].append(value)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type is None:
+            for field in self.field_writers.values():
+                field.flush()
+        return False
+
+
+class IntermediateColumnarFormat(collections.abc.Mapping):
+    def __init__(self, path):
+        self.path = pathlib.Path(path)
+        # TODO raise a more informative error here telling people this
+        # directory is either a WIP or the wrong format.
+        with open(self.path / "metadata.json") as f:
+            self.metadata = IcfMetadata.fromdict(json.load(f))
+        with open(self.path / "header.txt") as f:
+            self.vcf_header = f.read()
+        self.compressor = numcodecs.get_codec(self.metadata.compressor)
+        self.fields = {}
+        partition_num_records = [
+            partition.num_records for partition in self.metadata.partitions
+        ]
+        # Allow us to find which partition a given record is in
+        self.partition_record_index = np.cumsum([0, *partition_num_records])
+        for field in self.metadata.fields:
+            self.fields[field.full_name] = IntermediateColumnarFormatField(self, field)
+        logger.info(
+            f"Loaded IntermediateColumnarFormat(partitions={self.num_partitions}, "
+            f"records={self.num_records}, fields={self.num_fields})"
+        )
+
+    def __repr__(self):
+        return (
+            f"IntermediateColumnarFormat(fields={len(self)}, "
+            f"partitions={self.num_partitions}, "
+            f"records={self.num_records}, path={self.path})"
+        )
+
+    def __getitem__(self, key):
+        return self.fields[key]
+
+    def __iter__(self):
+        return iter(self.fields)
+
+    def __len__(self):
+        return len(self.fields)
+
+    def summary_table(self):
+        data = []
+        for name, col in self.fields.items():
+            summary = col.vcf_field.summary
+            d = {
+                "name": name,
+                "type": col.vcf_field.vcf_type,
+                "chunks": summary.num_chunks,
+                "size": core.display_size(summary.uncompressed_size),
+                "compressed": core.display_size(summary.compressed_size),
+                "max_n": summary.max_number,
+                "min_val": core.display_number(summary.min_value),
+                "max_val": core.display_number(summary.max_value),
+            }
+
+            data.append(d)
+        return data
+
+    @property
+    def num_records(self):
+        return self.metadata.num_records
+
+    @property
+    def num_partitions(self):
+        return len(self.metadata.partitions)
+
+    @property
+    def num_samples(self):
+        return len(self.metadata.samples)
+
+    @property
+    def num_fields(self):
+        return len(self.fields)
+
+
+@dataclasses.dataclass
+class IcfPartitionMetadata(core.JsonDataclass):
+    num_records: int
+    last_position: int
+    field_summaries: dict
+
+    @staticmethod
+    def fromdict(d):
+        md = IcfPartitionMetadata(**d)
+        for k, v in md.field_summaries.items():
+            md.field_summaries[k] = VcfFieldSummary.fromdict(v)
+        return md
+
+
+def check_overlapping_partitions(partitions):
+    for i in range(1, len(partitions)):
+        prev_region = partitions[i - 1].region
+        current_region = partitions[i].region
+        if prev_region.contig == current_region.contig:
+            assert prev_region.end is not None
+            # Regions are *inclusive*
+            if prev_region.end >= current_region.start:
+                raise ValueError(
+                    f"Overlapping VCF regions in partitions {i - 1} and {i}: "
+                    f"{prev_region} and {current_region}"
+                )
+
+
+def check_field_clobbering(icf_metadata):
+    info_field_names = set(field.name for field in icf_metadata.info_fields)
+    fixed_variant_fields = set(
+        ["contig", "id", "id_mask", "position", "allele", "filter", "quality"]
+    )
+    intersection = info_field_names & fixed_variant_fields
+    if len(intersection) > 0:
+        raise ValueError(
+            f"INFO field name(s) clashing with VCF Zarr spec: {intersection}"
+        )
+
+    format_field_names = set(field.name for field in icf_metadata.format_fields)
+    fixed_variant_fields = set(["genotype", "genotype_phased", "genotype_mask"])
+    intersection = format_field_names & fixed_variant_fields
+    if len(intersection) > 0:
+        raise ValueError(
+            f"FORMAT field name(s) clashing with VCF Zarr spec: {intersection}"
+        )
+
+
+@dataclasses.dataclass
+class IcfWriteSummary(core.JsonDataclass):
+    num_partitions: int
+    num_samples: int
+    num_variants: int
+
+
+class IntermediateColumnarFormatWriter:
+    def __init__(self, path):
+        self.path = pathlib.Path(path)
+        self.wip_path = self.path / "wip"
+        self.metadata = None
+
+    @property
+    def num_partitions(self):
+        return len(self.metadata.partitions)
+
+    def init(
+        self,
+        vcfs,
+        *,
+        column_chunk_size=16,
+        worker_processes=1,
+        target_num_partitions=None,
+        show_progress=False,
+        compressor=None,
+    ):
+        if self.path.exists():
+            raise ValueError("ICF path already exists")
+        if compressor is None:
+            compressor = ICF_DEFAULT_COMPRESSOR
+        vcfs = [pathlib.Path(vcf) for vcf in vcfs]
+        target_num_partitions = max(target_num_partitions, len(vcfs))
+
+        # TODO move scan_vcfs into this class
+        icf_metadata, header = scan_vcfs(
+            vcfs,
+            worker_processes=worker_processes,
+            show_progress=show_progress,
+            target_num_partitions=target_num_partitions,
+        )
+        check_field_clobbering(icf_metadata)
+        self.metadata = icf_metadata
+        self.metadata.format_version = ICF_METADATA_FORMAT_VERSION
+        self.metadata.compressor = compressor.get_config()
+        self.metadata.column_chunk_size = column_chunk_size
+        # Bare minimum here for provenance - would be nice to include versions of key
+        # dependencies as well.
+        self.metadata.provenance = {"source": f"bio2zarr-{provenance.__version__}"}
+
+        self.mkdirs()
+
+        # Note: this is needed for the current version of the vcfzarr spec, but it's
+        # probably going to be dropped.
+        # https://github.com/pystatgen/vcf-zarr-spec/issues/15
+        # May be useful to keep lying around still, though.
+        logger.info("Writing VCF header")
+        with open(self.path / "header.txt", "w") as f:
+            f.write(header)
+
+        logger.info("Writing WIP metadata")
+        with open(self.wip_path / "metadata.json", "w") as f:
+            json.dump(self.metadata.asdict(), f, indent=4)
+        return IcfWriteSummary(
+            num_partitions=self.num_partitions,
+            num_variants=icf_metadata.num_records,
+            num_samples=icf_metadata.num_samples,
+        )
+
+    def mkdirs(self):
+        num_dirs = len(self.metadata.fields)
+        logger.info(f"Creating {num_dirs} field directories")
+        self.path.mkdir()
+        self.wip_path.mkdir()
+        for field in self.metadata.fields:
+            col_path = get_vcf_field_path(self.path, field)
+            col_path.mkdir(parents=True)
+
+    def load_partition_summaries(self):
+        summaries = []
+        not_found = []
+        for j in range(self.num_partitions):
+            try:
+                with open(self.wip_path / f"p{j}.json") as f:
+                    summaries.append(IcfPartitionMetadata.fromdict(json.load(f)))
+            except FileNotFoundError:
+                not_found.append(j)
+        if len(not_found) > 0:
+            raise FileNotFoundError(
+                f"Partition metadata not found for {len(not_found)}"
+                f" partitions: {not_found}"
+            )
+        return summaries
+
+    def load_metadata(self):
+        if self.metadata is None:
+            with open(self.wip_path / "metadata.json") as f:
+                self.metadata = IcfMetadata.fromdict(json.load(f))
+
+    def process_partition(self, partition_index):
+        self.load_metadata()
+        summary_path = self.wip_path / f"p{partition_index}.json"
+        # If someone is rewriting a summary path (for whatever reason), make sure it
+        # doesn't look like it's already been completed.
+        # NOTE to do this properly we probably need to take a lock on this file - but
+        # this simple approach will catch the vast majority of problems.
+        if summary_path.exists():
+            summary_path.unlink()
+
+        partition = self.metadata.partitions[partition_index]
+        logger.info(
+            f"Start p{partition_index} {partition.vcf_path}__{partition.region}"
+        )
+        info_fields = self.metadata.info_fields
+        format_fields = []
+        has_gt = False
+        for field in self.metadata.format_fields:
+            if field.name == "GT":
+                has_gt = True
+            else:
+                format_fields.append(field)
+
+        last_position = None
+        with IcfPartitionWriter(
+            self.metadata,
+            self.path,
+            partition_index,
+        ) as tcw:
+            with vcf_utils.IndexedVcf(partition.vcf_path) as ivcf:
+                num_records = 0
+                for variant in ivcf.variants(partition.region):
+                    num_records += 1
+                    last_position = variant.POS
+                    tcw.append("CHROM", variant.CHROM)
+                    tcw.append("POS", variant.POS)
+                    tcw.append("QUAL", variant.QUAL)
+                    tcw.append("ID", variant.ID)
+                    tcw.append("FILTERS", variant.FILTERS)
+                    tcw.append("REF", variant.REF)
+                    tcw.append("ALT", variant.ALT)
+                    for field in info_fields:
+                        tcw.append(field.full_name, variant.INFO.get(field.name, None))
+                    if has_gt:
+                        tcw.append("FORMAT/GT", variant.genotype.array())
+                    for field in format_fields:
+                        val = variant.format(field.name)
+                        tcw.append(field.full_name, val)
+                    # Note: an issue with updating the progress per variant here like
+                    # this is that we get a significant pause at the end of the counter
+                    # while all the "small" fields get flushed. Possibly not much to be
+                    # done about it.
+                    core.update_progress(1)
+            logger.info(
+                f"Finished reading VCF for partition {partition_index}, "
+                f"flushing buffers"
+            )
+
+        partition_metadata = IcfPartitionMetadata(
+            num_records=num_records,
+            last_position=last_position,
+            field_summaries=tcw.field_summaries,
+        )
+        with open(summary_path, "w") as f:
+            f.write(partition_metadata.asjson())
+        logger.info(
+            f"Finish p{partition_index} {partition.vcf_path}__{partition.region} "
+            f"{num_records} records last_pos={last_position}"
+        )
+
+    def explode(self, *, worker_processes=1, show_progress=False):
+        self.load_metadata()
+        num_records = self.metadata.num_records
+        if np.isinf(num_records):
+            logger.warning(
+                "Total records unknown, cannot show progress; "
+                "reindex VCFs with bcftools index to fix"
+            )
+            num_records = None
+        num_fields = len(self.metadata.fields)
+        num_samples = len(self.metadata.samples)
+        logger.info(
+            f"Exploding fields={num_fields} samples={num_samples}; "
+            f"partitions={self.num_partitions} "
+            f"variants={'unknown' if num_records is None else num_records}"
+        )
+        progress_config = core.ProgressConfig(
+            total=num_records,
+            units="vars",
+            title="Explode",
+            show=show_progress,
+        )
+        with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
+            for j in range(self.num_partitions):
+                pwm.submit(self.process_partition, j)
+
+    def explode_partition(self, partition):
+        self.load_metadata()
+        if partition < 0 or partition >= self.num_partitions:
+            raise ValueError("Partition index not in the valid range")
+        self.process_partition(partition)
+
+    def finalise(self):
+        self.load_metadata()
+        partition_summaries = self.load_partition_summaries()
+        total_records = 0
+        for index, summary in enumerate(partition_summaries):
+            partition_records = summary.num_records
+            self.metadata.partitions[index].num_records = partition_records
+            self.metadata.partitions[index].region.end = summary.last_position
+            total_records += partition_records
+        if not np.isinf(self.metadata.num_records):
+            # Note: this is just telling us that there's a bug in the
+            # index-based record counting code, but it doesn't actually
+            # matter much. We may want to just make this a warning if
+            # we hit regular problems.
+            assert total_records == self.metadata.num_records
+        self.metadata.num_records = total_records
+
+        check_overlapping_partitions(self.metadata.partitions)
+
+        for field in self.metadata.fields:
+            for summary in partition_summaries:
+                field.summary.update(summary.field_summaries[field.full_name])
+
+        logger.info("Finalising metadata")
+        with open(self.path / "metadata.json", "w") as f:
+            f.write(self.metadata.asjson())
+
+        logger.debug("Removing WIP directory")
+        shutil.rmtree(self.wip_path)
+
+
+def explode(
+    icf_path,
+    vcfs,
+    *,
+    column_chunk_size=16,
+    worker_processes=1,
+    show_progress=False,
+    compressor=None,
+):
+    writer = IntermediateColumnarFormatWriter(icf_path)
+    writer.init(
+        vcfs,
+        # Heuristic to get reasonable worker utilisation with lumpy partition sizing
+        target_num_partitions=max(1, worker_processes * 4),
+        worker_processes=worker_processes,
+        show_progress=show_progress,
+        column_chunk_size=column_chunk_size,
+        compressor=compressor,
+    )
+    writer.explode(worker_processes=worker_processes, show_progress=show_progress)
+    writer.finalise()
+    return IntermediateColumnarFormat(icf_path)
+
+
+def explode_init(
+    icf_path,
+    vcfs,
+    *,
+    column_chunk_size=16,
+    target_num_partitions=1,
+    worker_processes=1,
+    show_progress=False,
+    compressor=None,
+):
+    writer = IntermediateColumnarFormatWriter(icf_path)
+    return writer.init(
+        vcfs,
+        target_num_partitions=target_num_partitions,
+        worker_processes=worker_processes,
+        show_progress=show_progress,
+        column_chunk_size=column_chunk_size,
+        compressor=compressor,
+    )
+
+
+def explode_partition(icf_path, partition):
+    writer = IntermediateColumnarFormatWriter(icf_path)
+    writer.explode_partition(partition)
+
+
+def explode_finalise(icf_path):
+    writer = IntermediateColumnarFormatWriter(icf_path)
+    writer.finalise()
diff --git a/bio2zarr/vcf.py b/bio2zarr/vcf.py
index ea763b3d..95672f40 100644
--- a/bio2zarr/vcf.py
+++ b/bio2zarr/vcf.py
@@ -1,1235 +1,27 @@
-import collections
-import contextlib
 import dataclasses
 import json
 import logging
-import math
 import os
 import os.path
 import pathlib
-import pickle
 import shutil
-import sys
 import tempfile
-from typing import Any
 
 import humanfriendly
 import numcodecs
 import numpy as np
 import zarr
 
-from . import constants, core, provenance, vcf_utils
+from . 
import constants, core, icf, provenance logger = logging.getLogger(__name__) -@dataclasses.dataclass -class VcfFieldSummary(core.JsonDataclass): - num_chunks: int = 0 - compressed_size: int = 0 - uncompressed_size: int = 0 - max_number: int = 0 # Corresponds to VCF Number field, depends on context - # Only defined for numeric fields - max_value: Any = -math.inf - min_value: Any = math.inf - - def update(self, other): - self.num_chunks += other.num_chunks - self.compressed_size += other.compressed_size - self.uncompressed_size += other.uncompressed_size - self.max_number = max(self.max_number, other.max_number) - self.min_value = min(self.min_value, other.min_value) - self.max_value = max(self.max_value, other.max_value) - - @staticmethod - def fromdict(d): - return VcfFieldSummary(**d) - - -@dataclasses.dataclass -class VcfField: - category: str - name: str - vcf_number: str - vcf_type: str - description: str - summary: VcfFieldSummary - - @staticmethod - def from_header(definition): - category = definition["HeaderType"] - name = definition["ID"] - vcf_number = definition["Number"] - vcf_type = definition["Type"] - return VcfField( - category=category, - name=name, - vcf_number=vcf_number, - vcf_type=vcf_type, - description=definition["Description"].strip('"'), - summary=VcfFieldSummary(), - ) - - @staticmethod - def fromdict(d): - f = VcfField(**d) - f.summary = VcfFieldSummary(**d["summary"]) - return f - - @property - def full_name(self): - if self.category == "fixed": - return self.name - return f"{self.category}/{self.name}" - - def smallest_dtype(self): - """ - Returns the smallest dtype suitable for this field based - on type, and values. - """ - s = self.summary - if self.vcf_type == "Float": - ret = "f4" - elif self.vcf_type == "Integer": - if not math.isfinite(s.max_value): - # All missing values; use i1. 
Note we should have some API to - # check more explicitly for missingness: - # https://github.com/sgkit-dev/bio2zarr/issues/131 - ret = "i1" - else: - ret = core.min_int_dtype(s.min_value, s.max_value) - elif self.vcf_type == "Flag": - ret = "bool" - elif self.vcf_type == "Character": - ret = "U1" - else: - assert self.vcf_type == "String" - ret = "O" - return ret - - -@dataclasses.dataclass -class VcfPartition: - vcf_path: str - region: str - num_records: int = -1 - - -ICF_METADATA_FORMAT_VERSION = "0.3" -ICF_DEFAULT_COMPRESSOR = numcodecs.Blosc( - cname="zstd", clevel=7, shuffle=numcodecs.Blosc.NOSHUFFLE -) - - -@dataclasses.dataclass -class Contig: - id: str - length: int = None - - -@dataclasses.dataclass -class Sample: - id: str - - -@dataclasses.dataclass -class Filter: - id: str - description: str = "" - - -@dataclasses.dataclass -class IcfMetadata(core.JsonDataclass): - samples: list - contigs: list - filters: list - fields: list - partitions: list = None - format_version: str = None - compressor: dict = None - column_chunk_size: int = None - provenance: dict = None - num_records: int = -1 - - @property - def info_fields(self): - fields = [] - for field in self.fields: - if field.category == "INFO": - fields.append(field) - return fields - - @property - def format_fields(self): - fields = [] - for field in self.fields: - if field.category == "FORMAT": - fields.append(field) - return fields - - @property - def num_contigs(self): - return len(self.contigs) - - @property - def num_filters(self): - return len(self.filters) - - @property - def num_samples(self): - return len(self.samples) - - @staticmethod - def fromdict(d): - if d["format_version"] != ICF_METADATA_FORMAT_VERSION: - raise ValueError( - "Intermediate columnar metadata format version mismatch: " - f"{d['format_version']} != {ICF_METADATA_FORMAT_VERSION}" - ) - partitions = [VcfPartition(**pd) for pd in d["partitions"]] - for p in partitions: - p.region = vcf_utils.Region(**p.region) - d = d.copy() - d["partitions"] = partitions - d["fields"] = [VcfField.fromdict(fd) for fd in d["fields"]] - d["samples"] = [Sample(**sd) for sd in d["samples"]] - d["filters"] = [Filter(**fd) for fd in d["filters"]] - d["contigs"] = [Contig(**cd) for cd in d["contigs"]] - return IcfMetadata(**d) - - -def fixed_vcf_field_definitions(): - def make_field_def(name, vcf_type, vcf_number): - return VcfField( - category="fixed", - name=name, - vcf_type=vcf_type, - vcf_number=vcf_number, - description="", - summary=VcfFieldSummary(), - ) - - fields = [ - make_field_def("CHROM", "String", "1"), - make_field_def("POS", "Integer", "1"), - make_field_def("QUAL", "Float", "1"), - make_field_def("ID", "String", "."), - make_field_def("FILTERS", "String", "."), - make_field_def("REF", "String", "1"), - make_field_def("ALT", "String", "."), - ] - return fields - - -def scan_vcf(path, target_num_partitions): - with vcf_utils.IndexedVcf(path) as indexed_vcf: - vcf = indexed_vcf.vcf - filters = [] - pass_index = -1 - for h in vcf.header_iter(): - if h["HeaderType"] == "FILTER" and isinstance(h["ID"], str): - try: - description = h["Description"].strip('"') - except KeyError: - description = "" - if h["ID"] == "PASS": - pass_index = len(filters) - filters.append(Filter(h["ID"], description)) - - # Ensure PASS is the first filter if present - if pass_index > 0: - pass_filter = filters.pop(pass_index) - filters.insert(0, pass_filter) - - fields = fixed_vcf_field_definitions() - for h in vcf.header_iter(): - if h["HeaderType"] in ["INFO", "FORMAT"]: - field = 
VcfField.from_header(h) - if field.name == "GT": - field.vcf_type = "Integer" - field.vcf_number = "." - fields.append(field) - - try: - contig_lengths = vcf.seqlens - except AttributeError: - contig_lengths = [None for _ in vcf.seqnames] - - metadata = IcfMetadata( - samples=[Sample(sample_id) for sample_id in vcf.samples], - contigs=[ - Contig(contig_id, length) - for contig_id, length in zip(vcf.seqnames, contig_lengths) - ], - filters=filters, - fields=fields, - partitions=[], - num_records=sum(indexed_vcf.contig_record_counts().values()), - ) - - regions = indexed_vcf.partition_into_regions(num_parts=target_num_partitions) - logger.info( - f"Split {path} into {len(regions)} regions (target={target_num_partitions})" - ) - for region in regions: - metadata.partitions.append( - VcfPartition( - # TODO should this be fully resolving the path? Otherwise it's all - # relative to the original WD - vcf_path=str(path), - region=region, - ) - ) - core.update_progress(1) - return metadata, vcf.raw_header - - -def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1): - logger.info( - f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}" - f" partitions." - ) - # An easy mistake to make is to pass the same file twice. Check this early on. - for path, count in collections.Counter(paths).items(): - if not path.exists(): # NEEDS TEST - raise FileNotFoundError(path) - if count > 1: - raise ValueError(f"Duplicate path provided: {path}") - - progress_config = core.ProgressConfig( - total=len(paths), - units="files", - title="Scan", - show=show_progress, - ) - with core.ParallelWorkManager(worker_processes, progress_config) as pwm: - for path in paths: - pwm.submit(scan_vcf, path, max(1, target_num_partitions // len(paths))) - results = list(pwm.results_as_completed()) - - # Sort to make the ordering deterministic - results.sort(key=lambda t: t[0].partitions[0].vcf_path) - # We just take the first header, assuming the others - # are compatible. 
- all_partitions = [] - total_records = 0 - for metadata, _ in results: - for partition in metadata.partitions: - logger.debug(f"Scanned partition {partition}") - all_partitions.append(partition) - total_records += metadata.num_records - metadata.num_records = 0 - metadata.partitions = [] - - icf_metadata, header = results[0] - for metadata, _ in results[1:]: - if metadata != icf_metadata: - raise ValueError("Incompatible VCF chunks") - - # Note: this will be infinity here if any of the chunks has an index - # that doesn't keep track of the number of records per-contig - icf_metadata.num_records = total_records - - # Sort by contig (in the order they appear in the header) first, - # then by start coordinate - contig_index_map = {contig.id: j for j, contig in enumerate(metadata.contigs)} - all_partitions.sort( - key=lambda x: (contig_index_map[x.region.contig], x.region.start) - ) - icf_metadata.partitions = all_partitions - logger.info(f"Scan complete, resulting in {len(all_partitions)} partitions.") - return icf_metadata, header - - -def sanitise_value_bool(buff, j, value): - x = True - if value is None: - x = False - buff[j] = x - - -def sanitise_value_float_scalar(buff, j, value): - x = value - if value is None: - x = [constants.FLOAT32_MISSING] - buff[j] = x[0] - - -def sanitise_value_int_scalar(buff, j, value): - x = value - if value is None: - # print("MISSING", INT_MISSING, INT_FILL) - x = [constants.INT_MISSING] - else: - x = sanitise_int_array(value, ndmin=1, dtype=np.int32) - buff[j] = x[0] - - -def sanitise_value_string_scalar(buff, j, value): - if value is None: - buff[j] = "." - else: - buff[j] = value[0] - - -def sanitise_value_string_1d(buff, j, value): - if value is None: - buff[j] = "." - else: - # value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False) - # FIXME failure isn't coming from here, it seems to be from an - # incorrectly detected dimension in the zarr array - # The dimesions look all wrong, and the dtype should be Object - # not str - value = drop_empty_second_dim(value) - buff[j] = "" - buff[j, : value.shape[0]] = value - - -def sanitise_value_string_2d(buff, j, value): - if value is None: - buff[j] = "." 
- else: - # print(buff.shape, value.dtype, value) - # assert value.ndim == 2 - buff[j] = "" - if value.ndim == 2: - buff[j, :, : value.shape[1]] = value - else: - # TODO check if this is still necessary - for k, val in enumerate(value): - buff[j, k, : len(val)] = val - - -def drop_empty_second_dim(value): - assert len(value.shape) == 1 or value.shape[1] == 1 - if len(value.shape) == 2 and value.shape[1] == 1: - value = value[..., 0] - return value - - -def sanitise_value_float_1d(buff, j, value): - if value is None: - buff[j] = constants.FLOAT32_MISSING - else: - value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False) - # numpy will map None values to Nan, but we need a - # specific NaN - value[np.isnan(value)] = constants.FLOAT32_MISSING - value = drop_empty_second_dim(value) - buff[j] = constants.FLOAT32_FILL - buff[j, : value.shape[0]] = value - - -def sanitise_value_float_2d(buff, j, value): - if value is None: - buff[j] = constants.FLOAT32_MISSING - else: - # print("value = ", value) - value = np.array(value, ndmin=2, dtype=buff.dtype, copy=False) - buff[j] = constants.FLOAT32_FILL - buff[j, :, : value.shape[1]] = value - - -def sanitise_int_array(value, ndmin, dtype): - if isinstance(value, tuple): - value = [ - constants.VCF_INT_MISSING if x is None else x for x in value - ] # NEEDS TEST - value = np.array(value, ndmin=ndmin, copy=False) - value[value == constants.VCF_INT_MISSING] = -1 - value[value == constants.VCF_INT_FILL] = -2 - # TODO watch out for clipping here! - return value.astype(dtype) - - -def sanitise_value_int_1d(buff, j, value): - if value is None: - buff[j] = -1 - else: - value = sanitise_int_array(value, 1, buff.dtype) - value = drop_empty_second_dim(value) - buff[j] = -2 - buff[j, : value.shape[0]] = value - - -def sanitise_value_int_2d(buff, j, value): - if value is None: - buff[j] = -1 - else: - value = sanitise_int_array(value, 2, buff.dtype) - buff[j] = -2 - buff[j, :, : value.shape[1]] = value - - -missing_value_map = { - "Integer": constants.INT_MISSING, - "Float": constants.FLOAT32_MISSING, - "String": constants.STR_MISSING, - "Character": constants.STR_MISSING, - "Flag": False, -} - - -class VcfValueTransformer: - """ - Transform VCF values into the stored intermediate format used - in the IntermediateColumnarFormat, and update field summaries. 
- """ - - def __init__(self, field, num_samples): - self.field = field - self.num_samples = num_samples - self.dimension = 1 - if field.category == "FORMAT": - self.dimension = 2 - self.missing = missing_value_map[field.vcf_type] - - @staticmethod - def factory(field, num_samples): - if field.vcf_type in ("Integer", "Flag"): - return IntegerValueTransformer(field, num_samples) - if field.vcf_type == "Float": - return FloatValueTransformer(field, num_samples) - if field.name in ["REF", "FILTERS", "ALT", "ID", "CHROM"]: - return SplitStringValueTransformer(field, num_samples) - return StringValueTransformer(field, num_samples) - - def transform(self, vcf_value): - if isinstance(vcf_value, tuple): - vcf_value = [self.missing if v is None else v for v in vcf_value] - value = np.array(vcf_value, ndmin=self.dimension, copy=False) - return value - - def transform_and_update_bounds(self, vcf_value): - if vcf_value is None: - return None - value = self.transform(vcf_value) - self.update_bounds(value) - # print(self.field.full_name, "T", vcf_value, "->", value) - return value - - -class IntegerValueTransformer(VcfValueTransformer): - def update_bounds(self, value): - summary = self.field.summary - # Mask out missing and fill values - # print(value) - a = value[value >= constants.MIN_INT_VALUE] - if a.size > 0: - summary.max_value = int(max(summary.max_value, np.max(a))) - summary.min_value = int(min(summary.min_value, np.min(a))) - number = value.shape[-1] - summary.max_number = max(summary.max_number, number) - - -class FloatValueTransformer(VcfValueTransformer): - def update_bounds(self, value): - summary = self.field.summary - summary.max_value = float(max(summary.max_value, np.max(value))) - summary.min_value = float(min(summary.min_value, np.min(value))) - number = value.shape[-1] - summary.max_number = max(summary.max_number, number) - - -class StringValueTransformer(VcfValueTransformer): - def update_bounds(self, value): - summary = self.field.summary - number = value.shape[-1] - # TODO would be nice to report string lengths, but not - # really necessary. - summary.max_number = max(summary.max_number, number) - - def transform(self, vcf_value): - # print("transform", vcf_value) - if self.dimension == 1: - value = np.array(list(vcf_value.split(","))) - else: - # TODO can we make this faster?? 
- value = np.array([v.split(",") for v in vcf_value], dtype="O") - # print("HERE", vcf_value, value) - # for v in vcf_value: - # print("\t", type(v), len(v), v.split(",")) - # print("S: ", self.dimension, ":", value.shape, value) - return value - - -class SplitStringValueTransformer(StringValueTransformer): - def transform(self, vcf_value): - if vcf_value is None: - return self.missing_value # NEEDS TEST - assert self.dimension == 1 - return np.array(vcf_value, ndmin=1, dtype="str") - - -def get_vcf_field_path(base_path, vcf_field): - if vcf_field.category == "fixed": - return base_path / vcf_field.name - return base_path / vcf_field.category / vcf_field.name - - -class IntermediateColumnarFormatField: - def __init__(self, icf, vcf_field): - self.vcf_field = vcf_field - self.path = get_vcf_field_path(icf.path, vcf_field) - self.compressor = icf.compressor - self.num_partitions = icf.num_partitions - self.num_records = icf.num_records - self.partition_record_index = icf.partition_record_index - # A map of partition id to the cumulative number of records - # in chunks within that partition - self._chunk_record_index = {} - - @property - def name(self): - return self.vcf_field.full_name - - def partition_path(self, partition_id): - return self.path / f"p{partition_id}" - - def __repr__(self): - partition_chunks = [self.num_chunks(j) for j in range(self.num_partitions)] - return ( - f"IntermediateColumnarFormatField(name={self.name}, " - f"partition_chunks={partition_chunks}, " - f"path={self.path})" - ) - - def num_chunks(self, partition_id): - return len(self.chunk_record_index(partition_id)) - 1 - - def chunk_record_index(self, partition_id): - if partition_id not in self._chunk_record_index: - index_path = self.partition_path(partition_id) / "chunk_index" - with open(index_path, "rb") as f: - a = pickle.load(f) - assert len(a) > 1 - assert a[0] == 0 - self._chunk_record_index[partition_id] = a - return self._chunk_record_index[partition_id] - - def read_chunk(self, path): - with open(path, "rb") as f: - pkl = self.compressor.decode(f.read()) - return pickle.loads(pkl) - - def chunk_num_records(self, partition_id): - return np.diff(self.chunk_record_index(partition_id)) - - def chunks(self, partition_id, start_chunk=0): - partition_path = self.partition_path(partition_id) - chunk_cumulative_records = self.chunk_record_index(partition_id) - chunk_num_records = np.diff(chunk_cumulative_records) - for count, cumulative in zip( - chunk_num_records[start_chunk:], chunk_cumulative_records[start_chunk + 1 :] - ): - path = partition_path / f"{cumulative}" - chunk = self.read_chunk(path) - if len(chunk) != count: - raise ValueError(f"Corruption detected in chunk: {path}") - yield chunk - - def iter_values(self, start=None, stop=None): - start = 0 if start is None else start - stop = self.num_records if stop is None else stop - start_partition = ( - np.searchsorted(self.partition_record_index, start, side="right") - 1 - ) - offset = self.partition_record_index[start_partition] - assert offset <= start - chunk_offset = start - offset - - chunk_record_index = self.chunk_record_index(start_partition) - start_chunk = ( - np.searchsorted(chunk_record_index, chunk_offset, side="right") - 1 - ) - record_id = offset + chunk_record_index[start_chunk] - assert record_id <= start - logger.debug( - f"Read {self.vcf_field.full_name} slice [{start}:{stop}]:" - f"p_start={start_partition}, c_start={start_chunk}, r_start={record_id}" - ) - for chunk in self.chunks(start_partition, start_chunk): - for record in 
chunk: - if record_id == stop: - return - if record_id >= start: - yield record - record_id += 1 - assert record_id > start - for partition_id in range(start_partition + 1, self.num_partitions): - for chunk in self.chunks(partition_id): - for record in chunk: - if record_id == stop: - return - yield record - record_id += 1 - - # Note: this involves some computation so should arguably be a method, - # but making a property for consistency with xarray etc - @property - def values(self): - ret = [None] * self.num_records - j = 0 - for partition_id in range(self.num_partitions): - for chunk in self.chunks(partition_id): - for record in chunk: - ret[j] = record - j += 1 - assert j == self.num_records - return ret - - def sanitiser_factory(self, shape): - """ - Return a function that sanitised values from this column - and writes into a buffer of the specified shape. - """ - assert len(shape) <= 3 - if self.vcf_field.vcf_type == "Flag": - assert len(shape) == 1 - return sanitise_value_bool - elif self.vcf_field.vcf_type == "Float": - if len(shape) == 1: - return sanitise_value_float_scalar - elif len(shape) == 2: - return sanitise_value_float_1d - else: - return sanitise_value_float_2d - elif self.vcf_field.vcf_type == "Integer": - if len(shape) == 1: - return sanitise_value_int_scalar - elif len(shape) == 2: - return sanitise_value_int_1d - else: - return sanitise_value_int_2d - else: - assert self.vcf_field.vcf_type in ("String", "Character") - if len(shape) == 1: - return sanitise_value_string_scalar - elif len(shape) == 2: - return sanitise_value_string_1d - else: - return sanitise_value_string_2d - - -@dataclasses.dataclass -class IcfFieldWriter: - vcf_field: VcfField - path: pathlib.Path - transformer: VcfValueTransformer - compressor: Any - max_buffered_bytes: int - buff: list[Any] = dataclasses.field(default_factory=list) - buffered_bytes: int = 0 - chunk_index: list[int] = dataclasses.field(default_factory=lambda: [0]) - num_records: int = 0 - - def append(self, val): - val = self.transformer.transform_and_update_bounds(val) - assert val is None or isinstance(val, np.ndarray) - self.buff.append(val) - val_bytes = sys.getsizeof(val) - self.buffered_bytes += val_bytes - self.num_records += 1 - if self.buffered_bytes >= self.max_buffered_bytes: - logger.debug( - f"Flush {self.path} buffered={self.buffered_bytes} " - f"max={self.max_buffered_bytes}" - ) - self.write_chunk() - self.buff.clear() - self.buffered_bytes = 0 - - def write_chunk(self): - # Update index - self.chunk_index.append(self.num_records) - path = self.path / f"{self.num_records}" - logger.debug(f"Start write: {path}") - pkl = pickle.dumps(self.buff) - compressed = self.compressor.encode(pkl) - with open(path, "wb") as f: - f.write(compressed) - - # Update the summary - self.vcf_field.summary.num_chunks += 1 - self.vcf_field.summary.compressed_size += len(compressed) - self.vcf_field.summary.uncompressed_size += self.buffered_bytes - logger.debug(f"Finish write: {path}") - - def flush(self): - logger.debug( - f"Flush {self.path} records={len(self.buff)} buffered={self.buffered_bytes}" - ) - if len(self.buff) > 0: - self.write_chunk() - with open(self.path / "chunk_index", "wb") as f: - a = np.array(self.chunk_index, dtype=int) - pickle.dump(a, f) - - -class IcfPartitionWriter(contextlib.AbstractContextManager): - """ - Writes the data for a IntermediateColumnarFormat partition. 
- """ - - def __init__( - self, - icf_metadata, - out_path, - partition_index, - ): - self.partition_index = partition_index - # chunk_size is in megabytes - max_buffered_bytes = icf_metadata.column_chunk_size * 2**20 - assert max_buffered_bytes > 0 - compressor = numcodecs.get_codec(icf_metadata.compressor) - - self.field_writers = {} - num_samples = len(icf_metadata.samples) - for vcf_field in icf_metadata.fields: - field_path = get_vcf_field_path(out_path, vcf_field) - field_partition_path = field_path / f"p{partition_index}" - # Should be robust to running explode_partition twice. - field_partition_path.mkdir(exist_ok=True) - transformer = VcfValueTransformer.factory(vcf_field, num_samples) - self.field_writers[vcf_field.full_name] = IcfFieldWriter( - vcf_field, - field_partition_path, - transformer, - compressor, - max_buffered_bytes, - ) - - @property - def field_summaries(self): - return { - name: field.vcf_field.summary for name, field in self.field_writers.items() - } - - def append(self, name, value): - self.field_writers[name].append(value) - - def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type is None: - for field in self.field_writers.values(): - field.flush() - return False - - -class IntermediateColumnarFormat(collections.abc.Mapping): - def __init__(self, path): - self.path = pathlib.Path(path) - # TODO raise a more informative error here telling people this - # directory is either a WIP or the wrong format. - with open(self.path / "metadata.json") as f: - self.metadata = IcfMetadata.fromdict(json.load(f)) - with open(self.path / "header.txt") as f: - self.vcf_header = f.read() - self.compressor = numcodecs.get_codec(self.metadata.compressor) - self.fields = {} - partition_num_records = [ - partition.num_records for partition in self.metadata.partitions - ] - # Allow us to find which partition a given record is in - self.partition_record_index = np.cumsum([0, *partition_num_records]) - for field in self.metadata.fields: - self.fields[field.full_name] = IntermediateColumnarFormatField(self, field) - logger.info( - f"Loaded IntermediateColumnarFormat(partitions={self.num_partitions}, " - f"records={self.num_records}, fields={self.num_fields})" - ) - - def __repr__(self): - return ( - f"IntermediateColumnarFormat(fields={len(self)}, " - f"partitions={self.num_partitions}, " - f"records={self.num_records}, path={self.path})" - ) - - def __getitem__(self, key): - return self.fields[key] - - def __iter__(self): - return iter(self.fields) - - def __len__(self): - return len(self.fields) - - def summary_table(self): - data = [] - for name, col in self.fields.items(): - summary = col.vcf_field.summary - d = { - "name": name, - "type": col.vcf_field.vcf_type, - "chunks": summary.num_chunks, - "size": core.display_size(summary.uncompressed_size), - "compressed": core.display_size(summary.compressed_size), - "max_n": summary.max_number, - "min_val": core.display_number(summary.min_value), - "max_val": core.display_number(summary.max_value), - } - - data.append(d) - return data - - @property - def num_records(self): - return self.metadata.num_records - - @property - def num_partitions(self): - return len(self.metadata.partitions) - - @property - def num_samples(self): - return len(self.metadata.samples) - - @property - def num_fields(self): - return len(self.fields) - - -@dataclasses.dataclass -class IcfPartitionMetadata(core.JsonDataclass): - num_records: int - last_position: int - field_summaries: dict - - @staticmethod - def fromdict(d): - md = IcfPartitionMetadata(**d) - 
for k, v in md.field_summaries.items(): - md.field_summaries[k] = VcfFieldSummary.fromdict(v) - return md - - -def check_overlapping_partitions(partitions): - for i in range(1, len(partitions)): - prev_region = partitions[i - 1].region - current_region = partitions[i].region - if prev_region.contig == current_region.contig: - assert prev_region.end is not None - # Regions are *inclusive* - if prev_region.end >= current_region.start: - raise ValueError( - f"Overlapping VCF regions in partitions {i - 1} and {i}: " - f"{prev_region} and {current_region}" - ) - - -def check_field_clobbering(icf_metadata): - info_field_names = set(field.name for field in icf_metadata.info_fields) - fixed_variant_fields = set( - ["contig", "id", "id_mask", "position", "allele", "filter", "quality"] - ) - intersection = info_field_names & fixed_variant_fields - if len(intersection) > 0: - raise ValueError( - f"INFO field name(s) clashing with VCF Zarr spec: {intersection}" - ) - - format_field_names = set(field.name for field in icf_metadata.format_fields) - fixed_variant_fields = set(["genotype", "genotype_phased", "genotype_mask"]) - intersection = format_field_names & fixed_variant_fields - if len(intersection) > 0: - raise ValueError( - f"FORMAT field name(s) clashing with VCF Zarr spec: {intersection}" - ) - - -@dataclasses.dataclass -class IcfWriteSummary(core.JsonDataclass): - num_partitions: int - num_samples: int - num_variants: int - - -class IntermediateColumnarFormatWriter: - def __init__(self, path): - self.path = pathlib.Path(path) - self.wip_path = self.path / "wip" - self.metadata = None - - @property - def num_partitions(self): - return len(self.metadata.partitions) - - def init( - self, - vcfs, - *, - column_chunk_size=16, - worker_processes=1, - target_num_partitions=None, - show_progress=False, - compressor=None, - ): - if self.path.exists(): - raise ValueError("ICF path already exists") - if compressor is None: - compressor = ICF_DEFAULT_COMPRESSOR - vcfs = [pathlib.Path(vcf) for vcf in vcfs] - target_num_partitions = max(target_num_partitions, len(vcfs)) - - # TODO move scan_vcfs into this class - icf_metadata, header = scan_vcfs( - vcfs, - worker_processes=worker_processes, - show_progress=show_progress, - target_num_partitions=target_num_partitions, - ) - check_field_clobbering(icf_metadata) - self.metadata = icf_metadata - self.metadata.format_version = ICF_METADATA_FORMAT_VERSION - self.metadata.compressor = compressor.get_config() - self.metadata.column_chunk_size = column_chunk_size - # Bare minimum here for provenance - would be nice to include versions of key - # dependencies as well. - self.metadata.provenance = {"source": f"bio2zarr-{provenance.__version__}"} - - self.mkdirs() - - # Note: this is needed for the current version of the vcfzarr spec, but it's - # probably going to be dropped. - # https://github.com/pystatgen/vcf-zarr-spec/issues/15 - # May be useful to keep lying around still though? 
- logger.info("Writing VCF header") - with open(self.path / "header.txt", "w") as f: - f.write(header) - - logger.info("Writing WIP metadata") - with open(self.wip_path / "metadata.json", "w") as f: - json.dump(self.metadata.asdict(), f, indent=4) - return IcfWriteSummary( - num_partitions=self.num_partitions, - num_variants=icf_metadata.num_records, - num_samples=icf_metadata.num_samples, - ) - - def mkdirs(self): - num_dirs = len(self.metadata.fields) - logger.info(f"Creating {num_dirs} field directories") - self.path.mkdir() - self.wip_path.mkdir() - for field in self.metadata.fields: - col_path = get_vcf_field_path(self.path, field) - col_path.mkdir(parents=True) - - def load_partition_summaries(self): - summaries = [] - not_found = [] - for j in range(self.num_partitions): - try: - with open(self.wip_path / f"p{j}.json") as f: - summaries.append(IcfPartitionMetadata.fromdict(json.load(f))) - except FileNotFoundError: - not_found.append(j) - if len(not_found) > 0: - raise FileNotFoundError( - f"Partition metadata not found for {len(not_found)}" - f" partitions: {not_found}" - ) - return summaries - - def load_metadata(self): - if self.metadata is None: - with open(self.wip_path / "metadata.json") as f: - self.metadata = IcfMetadata.fromdict(json.load(f)) - - def process_partition(self, partition_index): - self.load_metadata() - summary_path = self.wip_path / f"p{partition_index}.json" - # If someone is rewriting a summary path (for whatever reason), make sure it - # doesn't look like it's already been completed. - # NOTE to do this properly we probably need to take a lock on this file - but - # this simple approach will catch the vast majority of problems. - if summary_path.exists(): - summary_path.unlink() - - partition = self.metadata.partitions[partition_index] - logger.info( - f"Start p{partition_index} {partition.vcf_path}__{partition.region}" - ) - info_fields = self.metadata.info_fields - format_fields = [] - has_gt = False - for field in self.metadata.format_fields: - if field.name == "GT": - has_gt = True - else: - format_fields.append(field) - - last_position = None - with IcfPartitionWriter( - self.metadata, - self.path, - partition_index, - ) as tcw: - with vcf_utils.IndexedVcf(partition.vcf_path) as ivcf: - num_records = 0 - for variant in ivcf.variants(partition.region): - num_records += 1 - last_position = variant.POS - tcw.append("CHROM", variant.CHROM) - tcw.append("POS", variant.POS) - tcw.append("QUAL", variant.QUAL) - tcw.append("ID", variant.ID) - tcw.append("FILTERS", variant.FILTERS) - tcw.append("REF", variant.REF) - tcw.append("ALT", variant.ALT) - for field in info_fields: - tcw.append(field.full_name, variant.INFO.get(field.name, None)) - if has_gt: - tcw.append("FORMAT/GT", variant.genotype.array()) - for field in format_fields: - val = variant.format(field.name) - tcw.append(field.full_name, val) - # Note: an issue with updating the progress per variant here like - # this is that we get a significant pause at the end of the counter - # while all the "small" fields get flushed. Possibly not much to be - # done about it. 
- core.update_progress(1) - logger.info( - f"Finished reading VCF for partition {partition_index}, " - f"flushing buffers" - ) - - partition_metadata = IcfPartitionMetadata( - num_records=num_records, - last_position=last_position, - field_summaries=tcw.field_summaries, - ) - with open(summary_path, "w") as f: - f.write(partition_metadata.asjson()) - logger.info( - f"Finish p{partition_index} {partition.vcf_path}__{partition.region} " - f"{num_records} records last_pos={last_position}" - ) - - def explode(self, *, worker_processes=1, show_progress=False): - self.load_metadata() - num_records = self.metadata.num_records - if np.isinf(num_records): - logger.warning( - "Total records unknown, cannot show progress; " - "reindex VCFs with bcftools index to fix" - ) - num_records = None - num_fields = len(self.metadata.fields) - num_samples = len(self.metadata.samples) - logger.info( - f"Exploding fields={num_fields} samples={num_samples}; " - f"partitions={self.num_partitions} " - f"variants={'unknown' if num_records is None else num_records}" - ) - progress_config = core.ProgressConfig( - total=num_records, - units="vars", - title="Explode", - show=show_progress, - ) - with core.ParallelWorkManager(worker_processes, progress_config) as pwm: - for j in range(self.num_partitions): - pwm.submit(self.process_partition, j) - - def explode_partition(self, partition): - self.load_metadata() - if partition < 0 or partition >= self.num_partitions: - raise ValueError("Partition index not in the valid range") - self.process_partition(partition) - - def finalise(self): - self.load_metadata() - partition_summaries = self.load_partition_summaries() - total_records = 0 - for index, summary in enumerate(partition_summaries): - partition_records = summary.num_records - self.metadata.partitions[index].num_records = partition_records - self.metadata.partitions[index].region.end = summary.last_position - total_records += partition_records - if not np.isinf(self.metadata.num_records): - # Note: this is just telling us that there's a bug in the - # index based record counting code, but it doesn't actually - # matter much. We may want to just make this a warning if - # we hit regular problems. 
- assert total_records == self.metadata.num_records - self.metadata.num_records = total_records - - check_overlapping_partitions(self.metadata.partitions) - - for field in self.metadata.fields: - for summary in partition_summaries: - field.summary.update(summary.field_summaries[field.full_name]) - - logger.info("Finalising metadata") - with open(self.path / "metadata.json", "w") as f: - f.write(self.metadata.asjson()) - - logger.debug("Removing WIP directory") - shutil.rmtree(self.wip_path) - - -def explode( - icf_path, - vcfs, - *, - column_chunk_size=16, - worker_processes=1, - show_progress=False, - compressor=None, -): - writer = IntermediateColumnarFormatWriter(icf_path) - writer.init( - vcfs, - # Heuristic to get reasonable worker utilisation with lumpy partition sizing - target_num_partitions=max(1, worker_processes * 4), - worker_processes=worker_processes, - show_progress=show_progress, - column_chunk_size=column_chunk_size, - compressor=compressor, - ) - writer.explode(worker_processes=worker_processes, show_progress=show_progress) - writer.finalise() - return IntermediateColumnarFormat(icf_path) - - -def explode_init( - icf_path, - vcfs, - *, - column_chunk_size=16, - target_num_partitions=1, - worker_processes=1, - show_progress=False, - compressor=None, -): - writer = IntermediateColumnarFormatWriter(icf_path) - return writer.init( - vcfs, - target_num_partitions=target_num_partitions, - worker_processes=worker_processes, - show_progress=show_progress, - column_chunk_size=column_chunk_size, - compressor=compressor, - ) - - -def explode_partition(icf_path, partition): - writer = IntermediateColumnarFormatWriter(icf_path) - writer.explode_partition(partition) - - -def explode_finalise(icf_path): - writer = IntermediateColumnarFormatWriter(icf_path) - writer.finalise() - - def inspect(path): path = pathlib.Path(path) # TODO add support for the Zarr format also if (path / "metadata.json").exists(): - obj = IntermediateColumnarFormat(path) + obj = icf.IntermediateColumnarFormat(path) elif (path / ".zmetadata").exists(): obj = VcfZarr(path) else: @@ -1375,9 +167,9 @@ def fromdict(d): f"{d['format_version']} != {ZARR_SCHEMA_FORMAT_VERSION}" ) ret = VcfZarrSchema(**d) - ret.samples = [Sample(**sd) for sd in d["samples"]] - ret.contigs = [Contig(**sd) for sd in d["contigs"]] - ret.filters = [Filter(**sd) for sd in d["filters"]] + ret.samples = [icf.Sample(**sd) for sd in d["samples"]] + ret.contigs = [icf.Contig(**sd) for sd in d["contigs"]] + ret.filters = [icf.Filter(**sd) for sd in d["filters"]] ret.fields = [ZarrColumnSpec(**sd) for sd in d["fields"]] return ret @@ -1789,7 +581,7 @@ def load_metadata(self): if self.metadata is None: with open(self.wip_path / "metadata.json") as f: self.metadata = VcfZarrWriterMetadata.fromdict(json.load(f)) - self.icf = IntermediateColumnarFormat(self.metadata.icf_path) + self.icf = icf.IntermediateColumnarFormat(self.metadata.icf_path) def partition_path(self, partition_index): return self.partitions_path / f"p{partition_index}" @@ -1874,9 +666,9 @@ def encode_genotypes_partition(self, partition_index): source_col = self.icf.fields["FORMAT/GT"] for value in source_col.iter_values(partition.start, partition.stop): j = gt.next_buffer_row() - sanitise_value_int_2d(gt.buff, j, value[:, :-1]) + icf.sanitise_value_int_2d(gt.buff, j, value[:, :-1]) j = gt_phased.next_buffer_row() - sanitise_value_int_1d(gt_phased.buff, j, value[:, -1]) + icf.sanitise_value_int_1d(gt_phased.buff, j, value[:, -1]) # TODO check is this the correct semantics when we are 
padding # with mixed ploidies? j = gt_mask.next_buffer_row() @@ -2108,8 +900,8 @@ def encode_all_partitions( def mkschema(if_path, out): - icf = IntermediateColumnarFormat(if_path) - spec = VcfZarrSchema.generate(icf) + store = icf.IntermediateColumnarFormat(if_path) + spec = VcfZarrSchema.generate(store) out.write(spec.asjson()) @@ -2160,10 +952,10 @@ def encode_init( worker_processes=1, show_progress=False, ): - icf = IntermediateColumnarFormat(icf_path) + icf_store = icf.IntermediateColumnarFormat(icf_path) if schema_path is None: schema = VcfZarrSchema.generate( - icf, + icf_store, variants_chunk_size=variants_chunk_size, samples_chunk_size=samples_chunk_size, ) @@ -2178,7 +970,7 @@ def encode_init( zarr_path = pathlib.Path(zarr_path) vzw = VcfZarrWriter(zarr_path) return vzw.init( - icf, + icf_store, target_num_partitions=target_num_partitions, schema=schema, dimension_separator=dimension_separator, @@ -2208,7 +1000,7 @@ def convert( ): with tempfile.TemporaryDirectory(prefix="vcf2zarr") as tmp: if_dir = pathlib.Path(tmp) / "icf" - explode( + icf.explode( if_dir, vcfs, worker_processes=worker_processes, diff --git a/tests/test_cli.py b/tests/test_cli.py index 29c10243..e894d9a2 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -69,7 +69,7 @@ def asjson(self): class TestWithMocks: vcf_path = "tests/data/vcf/sample.vcf.gz" - @mock.patch("bio2zarr.vcf.explode") + @mock.patch("bio2zarr.icf.explode") def test_vcf_explode(self, mocked, tmp_path): icf_path = tmp_path / "icf" runner = ct.CliRunner(mix_stderr=False) @@ -84,7 +84,7 @@ def test_vcf_explode(self, mocked, tmp_path): ) @pytest.mark.parametrize("compressor", ["lz4", "zstd"]) - @mock.patch("bio2zarr.vcf.explode") + @mock.patch("bio2zarr.icf.explode") def test_vcf_explode_compressor(self, mocked, tmp_path, compressor): icf_path = tmp_path / "icf" runner = ct.CliRunner(mix_stderr=False) @@ -107,7 +107,7 @@ def test_vcf_explode_compressor(self, mocked, tmp_path, compressor): ) @pytest.mark.parametrize("compressor", ["lz4", "zstd"]) - @mock.patch("bio2zarr.vcf.explode_init") + @mock.patch("bio2zarr.icf.explode_init") def test_vcf_dexplode_init_compressor(self, mocked, tmp_path, compressor): icf_path = tmp_path / "icf" runner = ct.CliRunner(mix_stderr=False) @@ -131,7 +131,7 @@ def test_vcf_dexplode_init_compressor(self, mocked, tmp_path, compressor): ) @pytest.mark.parametrize("compressor", ["LZ4", "asdf"]) - @mock.patch("bio2zarr.vcf.explode") + @mock.patch("bio2zarr.icf.explode") def test_vcf_explode_bad_compressor(self, mocked, tmp_path, compressor): runner = ct.CliRunner(mix_stderr=False) icf_path = tmp_path / "icf" @@ -144,7 +144,7 @@ def test_vcf_explode_bad_compressor(self, mocked, tmp_path, compressor): assert "Invalid value for '-C'" in result.stderr mocked.assert_not_called() - @mock.patch("bio2zarr.vcf.explode") + @mock.patch("bio2zarr.icf.explode") def test_vcf_explode_multiple_vcfs(self, mocked, tmp_path): icf_path = tmp_path / "icf" runner = ct.CliRunner(mix_stderr=False) @@ -161,7 +161,7 @@ def test_vcf_explode_multiple_vcfs(self, mocked, tmp_path): ) @pytest.mark.parametrize("response", ["y", "Y", "yes"]) - @mock.patch("bio2zarr.vcf.explode") + @mock.patch("bio2zarr.icf.explode") def test_vcf_explode_overwrite_icf_confirm_yes(self, mocked, tmp_path, response): icf_path = tmp_path / "icf" icf_path.mkdir() @@ -201,7 +201,7 @@ def test_vcf_encode_overwrite_zarr_confirm_yes(self, mocked, tmp_path, response) ) @pytest.mark.parametrize("force_arg", ["-f", "--force"]) - @mock.patch("bio2zarr.vcf.explode") + 
+    @mock.patch("bio2zarr.icf.explode")
     def test_vcf_explode_overwrite_icf_force(self, mocked, tmp_path, force_arg):
         icf_path = tmp_path / "icf"
         icf_path.mkdir()
@@ -240,7 +240,7 @@ def test_vcf_encode_overwrite_icf_force(self, mocked, tmp_path, force_arg):
             **DEFAULT_ENCODE_ARGS,
         )
 
-    @mock.patch("bio2zarr.vcf.explode")
+    @mock.patch("bio2zarr.icf.explode")
     def test_vcf_explode_missing_vcf(self, mocked, tmp_path):
         icf_path = tmp_path / "icf"
         runner = ct.CliRunner(mix_stderr=False)
@@ -253,7 +253,7 @@ def test_vcf_explode_missing_vcf(self, mocked, tmp_path):
         mocked.assert_not_called()
 
     @pytest.mark.parametrize("response", ["n", "N", "No"])
-    @mock.patch("bio2zarr.vcf.explode")
+    @mock.patch("bio2zarr.icf.explode")
     def test_vcf_explode_overwrite_icf_confirm_no(self, mocked, tmp_path, response):
         icf_path = tmp_path / "icf"
         icf_path.mkdir()
@@ -268,7 +268,7 @@ def test_vcf_explode_overwrite_icf_confirm_no(self, mocked, tmp_path, response):
         assert "Aborted" in result.stderr
         mocked.assert_not_called()
 
-    @mock.patch("bio2zarr.vcf.explode")
+    @mock.patch("bio2zarr.icf.explode")
    def test_vcf_explode_missing_and_existing_vcf(self, mocked, tmp_path):
         icf_path = tmp_path / "icf"
         runner = ct.CliRunner(mix_stderr=False)
@@ -282,7 +282,7 @@ def test_vcf_explode_missing_and_existing_vcf(self, mocked, tmp_path):
         assert "'no_such_file' does not exist" in result.stderr
         mocked.assert_not_called()
 
-    @mock.patch("bio2zarr.vcf.explode_init", return_value=FakeWorkSummary(5))
+    @mock.patch("bio2zarr.icf.explode_init", return_value=FakeWorkSummary(5))
     def test_vcf_dexplode_init(self, mocked, tmp_path):
         runner = ct.CliRunner(mix_stderr=False)
         icf_path = tmp_path / "icf"
@@ -302,7 +302,7 @@ def test_vcf_dexplode_init(self, mocked, tmp_path):
         )
 
     @pytest.mark.parametrize("num_partitions", ["-- -1", "0", "asdf", "1.112"])
-    @mock.patch("bio2zarr.vcf.explode_init", return_value=5)
+    @mock.patch("bio2zarr.icf.explode_init", return_value=5)
     def test_vcf_dexplode_init_bad_num_partitions(
         self, mocked, tmp_path, num_partitions
     ):
@@ -317,7 +317,7 @@ def test_vcf_dexplode_init_bad_num_partitions(
         assert "Invalid value for 'NUM_PARTITIONS'" in result.stderr
         mocked.assert_not_called()
 
-    @mock.patch("bio2zarr.vcf.explode_partition")
+    @mock.patch("bio2zarr.icf.explode_partition")
     def test_vcf_dexplode_partition(self, mocked, tmp_path):
         runner = ct.CliRunner(mix_stderr=False)
         icf_path = tmp_path / "icf"
@@ -334,7 +334,7 @@ def test_vcf_dexplode_partition(self, mocked, tmp_path):
             str(icf_path), 1, **DEFAULT_DEXPLODE_PARTITION_ARGS
         )
 
-    @mock.patch("bio2zarr.vcf.explode_partition")
+    @mock.patch("bio2zarr.icf.explode_partition")
     def test_vcf_dexplode_partition_one_based(self, mocked, tmp_path):
         runner = ct.CliRunner(mix_stderr=False)
         icf_path = tmp_path / "icf"
@@ -351,7 +351,7 @@ def test_vcf_dexplode_partition_one_based(self, mocked, tmp_path):
             str(icf_path), 0, **DEFAULT_DEXPLODE_PARTITION_ARGS
         )
 
-    @mock.patch("bio2zarr.vcf.explode_partition")
+    @mock.patch("bio2zarr.icf.explode_partition")
     def test_vcf_dexplode_partition_missing_dir(self, mocked, tmp_path):
         runner = ct.CliRunner(mix_stderr=False)
         icf_path = tmp_path / "icf"
@@ -366,7 +366,7 @@ def test_vcf_dexplode_partition_missing_dir(self, mocked, tmp_path):
         mocked.assert_not_called()
 
     @pytest.mark.parametrize("partition", ["-- -1", "asdf", "1.112"])
-    @mock.patch("bio2zarr.vcf.explode_partition")
+    @mock.patch("bio2zarr.icf.explode_partition")
     def test_vcf_dexplode_partition_bad_partition(self, mocked, tmp_path, partition):
         runner = ct.CliRunner(mix_stderr=False)
         icf_path = tmp_path / "icf"
@@ -381,7 +381,7 @@ def test_vcf_dexplode_partition_bad_partition(self, mocked, tmp_path, partition)
         assert len(result.stdout) == 0
         mocked.assert_not_called()
 
-    @mock.patch("bio2zarr.vcf.explode_finalise")
+    @mock.patch("bio2zarr.icf.explode_finalise")
     def test_vcf_dexplode_finalise(self, mocked, tmp_path):
         runner = ct.CliRunner(mix_stderr=False)
         result = runner.invoke(
diff --git a/tests/test_icf.py b/tests/test_icf.py
index 35c24f0e..a5e0792a 100644
--- a/tests/test_icf.py
+++ b/tests/test_icf.py
@@ -6,7 +6,8 @@
 import numpy.testing as nt
 import pytest
 
-from bio2zarr import provenance, vcf
+from bio2zarr import icf as icf_mod
+from bio2zarr import provenance, vcf, vcf_utils
 
 
 class TestSmallExample:
@@ -24,10 +25,10 @@ class TestSmallExample:
     @pytest.fixture(scope="class")
     def icf(self, tmp_path_factory):
         out = tmp_path_factory.mktemp("data") / "example.exploded"
-        return vcf.explode(out, [self.data_path])
+        return icf_mod.explode(out, [self.data_path])
 
     def test_format_version(self, icf):
-        assert icf.metadata.format_version == vcf.ICF_METADATA_FORMAT_VERSION
+        assert icf.metadata.format_version == icf_mod.ICF_METADATA_FORMAT_VERSION
 
     def test_provenance(self, icf):
         assert icf.metadata.provenance == {
@@ -105,7 +106,7 @@ class TestIcfWriterExample:
     def test_init_paths(self, tmp_path):
         icf_path = tmp_path / "x.icf"
         assert not icf_path.exists()
-        summary = vcf.explode_init(icf_path, [self.data_path])
+        summary = icf_mod.explode_init(icf_path, [self.data_path])
         assert summary.num_partitions == 3
         assert icf_path.exists()
         wip_path = icf_path / "wip"
@@ -118,50 +119,50 @@ def test_init_paths(self, tmp_path):
     def test_finalise_paths(self, tmp_path):
         icf_path = tmp_path / "x.icf"
         wip_path = icf_path / "wip"
-        summary = vcf.explode_init(icf_path, [self.data_path])
+        summary = icf_mod.explode_init(icf_path, [self.data_path])
         assert icf_path.exists()
         for j in range(summary.num_partitions):
-            vcf.explode_partition(icf_path, j)
+            icf_mod.explode_partition(icf_path, j)
         assert wip_path.exists()
-        vcf.explode_finalise(icf_path)
+        icf_mod.explode_finalise(icf_path)
         assert icf_path.exists()
         assert not wip_path.exists()
 
     def test_finalise_no_partitions_fails(self, tmp_path):
         icf_path = tmp_path / "x.icf"
-        vcf.explode_init(icf_path, [self.data_path])
+        icf_mod.explode_init(icf_path, [self.data_path])
         with pytest.raises(FileNotFoundError, match="3 partitions: \\[0, 1, 2\\]"):
-            vcf.explode_finalise(icf_path)
+            icf_mod.explode_finalise(icf_path)
 
     @pytest.mark.parametrize("partition", [0, 1, 2])
     def test_finalise_missing_partition_fails(self, tmp_path, partition):
         icf_path = tmp_path / "x.icf"
-        vcf.explode_init(icf_path, [self.data_path])
+        icf_mod.explode_init(icf_path, [self.data_path])
         for j in range(3):
             if j != partition:
-                vcf.explode_partition(icf_path, j)
+                icf_mod.explode_partition(icf_path, j)
         with pytest.raises(FileNotFoundError, match=f"1 partitions: \\[{partition}\\]"):
-            vcf.explode_finalise(icf_path)
+            icf_mod.explode_finalise(icf_path)
 
     @pytest.mark.parametrize("partition", [0, 1, 2])
     def test_explode_partition(self, tmp_path, partition):
         icf_path = tmp_path / "x.icf"
-        vcf.explode_init(icf_path, [self.data_path])
+        icf_mod.explode_init(icf_path, [self.data_path])
         summary_file = icf_path / "wip" / f"p{partition}.json"
         assert not summary_file.exists()
-        vcf.explode_partition(icf_path, partition)
+        icf_mod.explode_partition(icf_path, partition)
         assert summary_file.exists()
 
     def test_double_explode_partition(self, tmp_path):
         partition = 1
         icf_path = tmp_path / "x.icf"
-        vcf.explode_init(icf_path, [self.data_path])
+        icf_mod.explode_init(icf_path, [self.data_path])
         summary_file = icf_path / "wip" / f"p{partition}.json"
         assert not summary_file.exists()
-        vcf.explode_partition(icf_path, partition)
+        icf_mod.explode_partition(icf_path, partition)
         with open(summary_file) as f:
             s1 = f.read()
-        vcf.explode_partition(icf_path, partition)
+        icf_mod.explode_partition(icf_path, partition)
         with open(summary_file) as f:
             s2 = f.read()
         assert s1 == s2
@@ -169,19 +170,19 @@ def test_double_explode_partition(self, tmp_path):
     @pytest.mark.parametrize("partition", [-1, 3, 100])
     def test_explode_partition_out_of_range(self, tmp_path, partition):
         icf_path = tmp_path / "x.icf"
-        vcf.explode_init(icf_path, [self.data_path])
+        icf_mod.explode_init(icf_path, [self.data_path])
         with pytest.raises(ValueError, match="Partition index not in the valid range"):
-            vcf.explode_partition(icf_path, partition)
+            icf_mod.explode_partition(icf_path, partition)
 
     def test_explode_same_file_twice(self, tmp_path):
         icf_path = tmp_path / "x.icf"
         with pytest.raises(ValueError, match="Duplicate path provided"):
-            vcf.explode(icf_path, [self.data_path, self.data_path])
+            icf_mod.explode(icf_path, [self.data_path, self.data_path])
 
     def test_explode_same_data_twice(self, tmp_path):
         icf_path = tmp_path / "x.icf"
         with pytest.raises(ValueError, match="Overlapping VCF regions"):
-            vcf.explode(icf_path, [self.data_path, "tests/data/vcf/sample.bcf"])
+            icf_mod.explode(icf_path, [self.data_path, "tests/data/vcf/sample.bcf"])
 
 
 class TestGeneratedFieldsExample:
@@ -197,7 +198,7 @@ def icf(self, tmp_path_factory):
         # df = sgkit.load_dataset("tmp/fields.vcf.sg")
         # print(df["variant_IC2"])
         # print(df["variant_IC2"].values)
-        return vcf.explode(out, [self.data_path])
+        return icf_mod.explode(out, [self.data_path])
 
     @pytest.fixture(scope="class")
     def schema(self, icf):
@@ -265,16 +266,16 @@ class TestInitProperties:
 
     def run_explode(self, tmp_path, **kwargs):
         icf_path = tmp_path / "icf"
-        vcf.explode(icf_path, [self.data_path], **kwargs)
-        return vcf.IntermediateColumnarFormat(icf_path)
+        icf_mod.explode(icf_path, [self.data_path], **kwargs)
+        return icf_mod.IntermediateColumnarFormat(icf_path)
 
     def run_dexplode(self, tmp_path, **kwargs):
         icf_path = tmp_path / "icf"
-        summary = vcf.explode_init(icf_path, [self.data_path], **kwargs)
+        summary = icf_mod.explode_init(icf_path, [self.data_path], **kwargs)
         for j in range(summary.num_partitions):
-            vcf.explode_partition(icf_path, j)
-        vcf.explode_finalise(icf_path)
-        return vcf.IntermediateColumnarFormat(icf_path)
+            icf_mod.explode_partition(icf_path, j)
+        icf_mod.explode_finalise(icf_path)
+        return icf_mod.IntermediateColumnarFormat(icf_path)
 
     @pytest.mark.parametrize(
         "compressor",
@@ -290,12 +291,12 @@ def test_compressor_explode(self, tmp_path, compressor):
 
     def test_default_compressor_explode(self, tmp_path):
         icf = self.run_explode(tmp_path)
-        assert icf.metadata.compressor == vcf.ICF_DEFAULT_COMPRESSOR.get_config()
+        assert icf.metadata.compressor == icf_mod.ICF_DEFAULT_COMPRESSOR.get_config()
         assert icf.metadata.compressor["cname"] == "zstd"
 
     def test_default_compressor_dexplode(self, tmp_path):
         icf = self.run_dexplode(tmp_path)
-        assert icf.metadata.compressor == vcf.ICF_DEFAULT_COMPRESSOR.get_config()
+        assert icf.metadata.compressor == icf_mod.ICF_DEFAULT_COMPRESSOR.get_config()
         assert icf.metadata.compressor["cname"] == "zstd"
 
     @pytest.mark.parametrize(
@@ -326,40 +327,40 @@ class TestCorruptionDetection:
 
     def test_missing_field(self, tmp_path):
         icf_path = tmp_path / "icf"
-        vcf.explode(icf_path, [self.data_path])
+        icf_mod.explode(icf_path, [self.data_path])
         shutil.rmtree(icf_path / "POS")
-        icf = vcf.IntermediateColumnarFormat(icf_path)
+        icf = icf_mod.IntermediateColumnarFormat(icf_path)
         with pytest.raises(FileNotFoundError):
             icf["POS"].values  # noqa B018
 
     def test_missing_chunk_index(self, tmp_path):
         icf_path = tmp_path / "icf"
-        vcf.explode(icf_path, [self.data_path])
+        icf_mod.explode(icf_path, [self.data_path])
         chunk_index_path = icf_path / "POS" / "p0" / "chunk_index"
         assert chunk_index_path.exists()
         chunk_index_path.unlink()
-        icf = vcf.IntermediateColumnarFormat(icf_path)
+        icf = icf_mod.IntermediateColumnarFormat(icf_path)
         with pytest.raises(FileNotFoundError):
             icf["POS"].values  # noqa B018
 
     def test_missing_chunk_file(self, tmp_path):
         icf_path = tmp_path / "icf"
-        vcf.explode(icf_path, [self.data_path])
+        icf_mod.explode(icf_path, [self.data_path])
         chunk_file = icf_path / "POS" / "p0" / "2"
         assert chunk_file.exists()
         chunk_file.unlink()
-        icf = vcf.IntermediateColumnarFormat(icf_path)
+        icf = icf_mod.IntermediateColumnarFormat(icf_path)
         with pytest.raises(FileNotFoundError):
             icf["POS"].values  # noqa B018
 
     def test_empty_chunk_file(self, tmp_path):
         icf_path = tmp_path / "icf"
-        vcf.explode(icf_path, [self.data_path])
+        icf_mod.explode(icf_path, [self.data_path])
         chunk_file = icf_path / "POS" / "p0" / "2"
         assert chunk_file.exists()
         with open(chunk_file, "w") as _:
             pass
-        icf = vcf.IntermediateColumnarFormat(icf_path)
+        icf = icf_mod.IntermediateColumnarFormat(icf_path)
         with pytest.raises(RuntimeError, match="blosc"):
             icf["POS"].values  # noqa B018
@@ -367,21 +368,21 @@ def test_empty_chunk_file(self, tmp_path):
     @pytest.mark.parametrize("length", [10, 100, 190, 194])
     def test_truncated_chunk_file(self, tmp_path, length):
         icf_path = tmp_path / "icf"
-        vcf.explode(icf_path, [self.data_path])
+        icf_mod.explode(icf_path, [self.data_path])
         chunk_file = icf_path / "POS" / "p0" / "2"
         with open(chunk_file, "rb") as f:
             buff = f.read(length)
         assert len(buff) == length
         with open(chunk_file, "wb") as f:
             f.write(buff)
-        icf = vcf.IntermediateColumnarFormat(icf_path)
+        icf = icf_mod.IntermediateColumnarFormat(icf_path)
         # Either Blosc or pickling errors happen here
         with pytest.raises((RuntimeError, pickle.UnpicklingError)):
             icf["POS"].values  # noqa B018
 
     def test_chunk_incorrect_length(self, tmp_path):
         icf_path = tmp_path / "icf"
-        vcf.explode(icf_path, [self.data_path])
+        icf_mod.explode(icf_path, [self.data_path])
         chunk_file = icf_path / "POS" / "p0" / "2"
         compressor = numcodecs.Blosc(cname="zstd")
         with open(chunk_file, "rb") as f:
@@ -392,7 +393,7 @@ def test_chunk_incorrect_length(self, tmp_path):
             pkl = pickle.dumps(x[0])
         with open(chunk_file, "wb") as f:
             f.write(compressor.encode(pkl))
-        icf = vcf.IntermediateColumnarFormat(icf_path)
+        icf = icf_mod.IntermediateColumnarFormat(icf_path)
         with pytest.raises(ValueError, match="Corruption detected"):
             icf["POS"].values  # noqa B018
         with pytest.raises(ValueError, match="Corruption detected"):
@@ -405,7 +406,7 @@ class TestSlicing:
     @pytest.fixture(scope="class")
     def icf(self, tmp_path_factory):
         out = tmp_path_factory.mktemp("data") / "example.exploded"
-        return vcf.explode(
+        return icf_mod.explode(
             out, [self.data_path], column_chunk_size=0.0125, worker_processes=0
         )
 
@@ -468,3 +469,27 @@ def test_slice(self, icf, start, stop):
         pos = np.array(col.values)
         pos_slice = np.array(list(col.iter_values(start, stop)))
         nt.assert_array_equal(pos[start:stop], pos_slice)
+
+
+@pytest.mark.parametrize(
+    "regions",
+    [
+        # Overlapping partitions
+        [("1", 100, 200), ("1", 150, 250)],
+        # Overlap by one position
+        [("1", 100, 201), ("1", 200, 300)],
+        # End coord is *inclusive*
+        [("1", 100, 201), ("1", 201, 300)],
+        # Contained overlap
+        [("1", 100, 300), ("1", 150, 250)],
+        # Exactly equal
+        [("1", 100, 200), ("1", 100, 200)],
+    ],
+)
+def test_check_overlap(regions):
+    partitions = [
+        icf_mod.VcfPartition("", region=vcf_utils.Region(contig, start, end))
+        for contig, start, end in regions
+    ]
+    with pytest.raises(ValueError, match="Overlapping VCF regions"):
+        icf_mod.check_overlapping_partitions(partitions)
diff --git a/tests/test_vcf.py b/tests/test_vcf.py
index 1284582b..706cc359 100644
--- a/tests/test_vcf.py
+++ b/tests/test_vcf.py
@@ -6,7 +6,8 @@
 import xarray.testing as xt
 import zarr
 
-from bio2zarr import core, vcf, vcf_utils
+from bio2zarr import core, vcf
+from bio2zarr import icf as icf_mod
 
 
 @pytest.fixture(scope="module")
@@ -17,7 +18,7 @@ def vcf_file():
 @pytest.fixture(scope="module")
 def icf_path(vcf_file, tmp_path_factory):
     out = tmp_path_factory.mktemp("data") / "example.exploded"
-    vcf.explode(out, [vcf_file])
+    icf_mod.explode(out, [vcf_file])
     return out
 
 
@@ -97,7 +98,7 @@ def test_exploded_metadata_mismatch(self, tmpdir, icf_path, version):
         with pytest.raises(
             ValueError, match="Intermediate columnar metadata format version mismatch"
         ):
-            vcf.IcfMetadata.fromdict(d)
+            icf_mod.IcfMetadata.fromdict(d)
 
     @pytest.mark.parametrize("version", ["0.0", "1.0", "xxxxx", 0.1])
     def test_encode_metadata_mismatch(self, tmpdir, icf_path, version):
@@ -138,29 +139,29 @@ def assert_json_round_trip(self, schema):
         assert schema == schema2
 
     def test_generated_no_changes(self, icf_path):
-        icf = vcf.IntermediateColumnarFormat(icf_path)
+        icf = icf_mod.IntermediateColumnarFormat(icf_path)
         self.assert_json_round_trip(vcf.VcfZarrSchema.generate(icf))
 
     def test_generated_no_fields(self, icf_path):
-        icf = vcf.IntermediateColumnarFormat(icf_path)
+        icf = icf_mod.IntermediateColumnarFormat(icf_path)
         schema = vcf.VcfZarrSchema.generate(icf)
         schema.fields.clear()
         self.assert_json_round_trip(schema)
 
     def test_generated_no_samples(self, icf_path):
-        icf = vcf.IntermediateColumnarFormat(icf_path)
+        icf = icf_mod.IntermediateColumnarFormat(icf_path)
         schema = vcf.VcfZarrSchema.generate(icf)
         schema.samples.clear()
         self.assert_json_round_trip(schema)
 
     def test_generated_change_dtype(self, icf_path):
-        icf = vcf.IntermediateColumnarFormat(icf_path)
+        icf = icf_mod.IntermediateColumnarFormat(icf_path)
         schema = vcf.VcfZarrSchema.generate(icf)
         schema.field_map()["variant_position"].dtype = "i8"
         self.assert_json_round_trip(schema)
 
     def test_generated_change_compressor(self, icf_path):
-        icf = vcf.IntermediateColumnarFormat(icf_path)
+        icf = icf_mod.IntermediateColumnarFormat(icf_path)
         schema = vcf.VcfZarrSchema.generate(icf)
         schema.field_map()["variant_position"].compressor = {"cname": "FAKE"}
         self.assert_json_round_trip(schema)
@@ -172,7 +173,7 @@ class TestSchemaEncode:
     )
     def test_codec(self, tmp_path, icf_path, cname, clevel, shuffle):
         zarr_path = tmp_path / "zarr"
-        icf = vcf.IntermediateColumnarFormat(icf_path)
+        icf = icf_mod.IntermediateColumnarFormat(icf_path)
         schema = vcf.VcfZarrSchema.generate(icf)
         for var in schema.fields:
             var.compressor["cname"] = cname
@@ -192,7 +193,7 @@ def test_codec(self, tmp_path, icf_path, cname, clevel, shuffle):
     @pytest.mark.parametrize("dtype", ["i4", "i8"])
     def test_genotype_dtype(self, tmp_path, icf_path, dtype):
         zarr_path = tmp_path / "zarr"
-        icf = vcf.IntermediateColumnarFormat(icf_path)
+        icf = icf_mod.IntermediateColumnarFormat(icf_path)
         schema = vcf.VcfZarrSchema.generate(icf)
         schema.field_map()["call_genotype"].dtype = dtype
         schema_path = tmp_path / "schema"
@@ -331,30 +332,6 @@ def test_call_GQ(self, schema):
         }
 
 
-@pytest.mark.parametrize(
-    "regions",
-    [
-        # Overlapping partitions
-        [("1", 100, 200), ("1", 150, 250)],
-        # Overlap by one position
-        [("1", 100, 201), ("1", 200, 300)],
-        # End coord is *inclusive*
-        [("1", 100, 201), ("1", 201, 300)],
-        # Contained overlap
-        [("1", 100, 300), ("1", 150, 250)],
-        # Exactly equal
-        [("1", 100, 200), ("1", 100, 200)],
-    ],
-)
-def test_check_overlap(regions):
-    partitions = [
-        vcf.VcfPartition("", region=vcf_utils.Region(contig, start, end))
-        for contig, start, end in regions
-    ]
-    with pytest.raises(ValueError, match="Overlapping VCF regions"):
-        vcf.check_overlapping_partitions(partitions)
-
-
 class TestVcfDescriptions:
     @pytest.mark.parametrize(
         ("field", "description"),
@@ -545,7 +522,7 @@ def test_variant_fields(self, tmp_path, field):
         vcf_file = tmp_path / "test.vcf"
         self.generate_vcf(vcf_file, info_field=field)
         with pytest.raises(ValueError, match=f"INFO field name.*{field}"):
-            vcf.explode(tmp_path / "x.icf", [tmp_path / "test.vcf.gz"])
+            icf_mod.explode(tmp_path / "x.icf", [tmp_path / "test.vcf.gz"])
 
     @pytest.mark.parametrize(
         "field",
@@ -559,7 +536,7 @@ def test_call_fields(self, tmp_path, field):
         vcf_file = tmp_path / "test.vcf"
         self.generate_vcf(vcf_file, format_field=field)
         with pytest.raises(ValueError, match=f"FORMAT field name.*{field}"):
-            vcf.explode(tmp_path / "x.icf", [tmp_path / "test.vcf.gz"])
+            icf_mod.explode(tmp_path / "x.icf", [tmp_path / "test.vcf.gz"])
 
 
 class TestBadSchemaChanges:
diff --git a/tests/test_vcf_examples.py b/tests/test_vcf_examples.py
index 27d22468..78a48066 100644
--- a/tests/test_vcf_examples.py
+++ b/tests/test_vcf_examples.py
@@ -9,6 +9,7 @@
 import xarray.testing as xt
 
 from bio2zarr import constants, provenance, vcf, verification
+from bio2zarr import icf as icf_mod
 
 
 class TestSmallExample:
@@ -304,7 +305,7 @@ def test_split(self, ds, tmp_path, worker_processes, rotate):
     @pytest.mark.parametrize("worker_processes", [0, 1, 2])
     def test_full_pipeline(self, ds, tmp_path, worker_processes):
         exploded = tmp_path / "example.exploded"
-        vcf.explode(
+        icf_mod.explode(
             exploded,
             [self.data_path],
             worker_processes=worker_processes,
@@ -323,7 +324,7 @@ def test_max_variant_chunks(
         self, ds, tmp_path, max_variant_chunks, variants_chunk_size
     ):
         exploded = tmp_path / "example.exploded"
-        vcf.explode(exploded, [self.data_path])
+        icf_mod.explode(exploded, [self.data_path])
         out = tmp_path / "example.zarr"
         vcf.encode(
             exploded,
@@ -818,14 +819,14 @@ class TestSplitFileErrors:
     def test_entirely_incompatible(self, tmp_path):
         path = "tests/data/vcf/"
         with pytest.raises(ValueError, match="Incompatible"):
-            vcf.explode_init(
+            icf_mod.explode_init(
                 tmp_path / "if", [path + "sample.vcf.gz", path + "1kg_2020_chrM.bcf"]
             )
 
     def test_duplicate_paths(self, tmp_path):
         path = "tests/data/vcf/"
         with pytest.raises(ValueError, match="Duplicate"):
-            vcf.explode_init(tmp_path / "if", [path + "sample.vcf.gz"] * 2)
+            icf_mod.explode_init(tmp_path / "if", [path + "sample.vcf.gz"] * 2)
 
 
 @pytest.mark.parametrize(
     "regions",
@@ -872,16 +873,16 @@ def test_split_explode(tmp_path):
         "tests/data/vcf/sample.vcf.gz.3.split/X.vcf.gz",
     ]
     out = tmp_path / "test.explode"
-    work_summary = vcf.explode_init(out, paths, target_num_partitions=15)
+    work_summary = icf_mod.explode_init(out, paths, target_num_partitions=15)
     assert work_summary.num_partitions == 3
 
     with pytest.raises(FileNotFoundError):
-        pcvcf = vcf.IntermediateColumnarFormat(out)
+        pcvcf = icf_mod.IntermediateColumnarFormat(out)
 
     for j in range(work_summary.num_partitions):
-        vcf.explode_partition(out, j)
-    vcf.explode_finalise(out)
-    pcvcf = vcf.IntermediateColumnarFormat(out)
+        icf_mod.explode_partition(out, j)
+    icf_mod.explode_finalise(out)
+    pcvcf = icf_mod.IntermediateColumnarFormat(out)
     assert pcvcf.fields["POS"].vcf_field.summary.asdict() == {
         "num_chunks": 3,
         "compressed_size": 587,
diff --git a/validation.py b/validation.py
index 05a2b578..bb846ced 100644
--- a/validation.py
+++ b/validation.py
@@ -6,7 +6,7 @@
 
 import click
 
-from bio2zarr import vcf, verification
+from bio2zarr import icf, vcf, verification
 
 # TODO add support here for split vcfs. Perhaps simplest to take a
 # directory provided as input as indicating this, and then having
@@ -44,7 +44,7 @@ def cli(vcfs, worker_processes, force):
     if force and exploded.exists():
         shutil.rmtree(exploded)
     if not exploded.exists():
-        vcf.explode(
+        icf.explode(
             exploded,
             files,
             worker_processes=worker_processes,