In [1]:
import zarr
import xarray as xr
import sgkit as sg
import pandas as pd
import numcodecs
from Bio import SeqIO
from fsspec.implementations.zip import ZipFileSystem
import dask
import dask.array as da
import fsspec
import numpy as np
import os

In [2]:
# Inputs and outputs for this run
ref_path = (
    '/QRISdata/Q6151/dennistpw/far_hin_1.x/data/reference/'
    'VectorBase-54_AfarautiFAR1_Genome.fasta'
)
# NOTE(review): the merge cell writes to a hard-coded merged_zarr/ directory
# rather than this path — confirm which output location is intended.
output_dir_path = '/scratch/user/uqtdenni/snp_genotypes_combined.zarr'

In [3]:
# Functions
def get_contigs(ref_path, length_threshold=1e5):
    """Return the contigs longer than ``length_threshold``, largest first.

    Parameters
    ----------
    ref_path : str
        Path to the reference genome FASTA file.
    length_threshold : int or float, default 1e5
        Contigs whose length is not strictly greater than this are dropped.

    Returns
    -------
    dict
        Mapping of FASTA record id -> sequence length, ordered by descending
        length (equal lengths keep their FASTA order).
    """
    # Record id -> length; a repeated id keeps its first position but the
    # last length seen, exactly like repeated dict assignment.
    lengths_by_id = dict(
        (rec.id, len(rec.seq)) for rec in SeqIO.parse(ref_path, "fasta")
    )

    by_length_desc = sorted(
        lengths_by_id.items(), key=lambda pair: pair[1], reverse=True
    )

    kept = {}
    for contig_id, contig_len in by_length_desc:
        if contig_len > length_threshold:
            kept[contig_id] = contig_len
    return kept

def load_samples(fofn_path='/home/uqtdenni/far_hin_1.x/metadata/fofn.tsv'):
    """Return the sorted, de-duplicated sample names from the sample manifest.

    Parameters
    ----------
    fofn_path : str, optional
        Tab-separated manifest containing a ``zarr_name`` column. Defaults to
        the project manifest.

    Returns
    -------
    list
        Unique, non-missing ``zarr_name`` values in sorted order.
    """
    fofn = pd.read_table(fofn_path)

    # dropna() skips rows with no zarr_name; unique() collapses samples that
    # appear on more than one manifest row.
    names = fofn['zarr_name'].dropna().unique().tolist()

    # Sort so the sample order (and hence the merged 'samples' axis) does not
    # depend on row order in the manifest.
    return sorted(names)

#called_contigs = get_contigs(ref_path)
#samples = load_samples()

In [5]:
# Contig id -> length for contigs above get_contigs' default threshold, longest first
called_contigs = get_contigs(ref_path=ref_path)
# Sample names taken from the manifest's zarr_name column
samples = load_samples()

In [None]:
%%time
def open_sample_contig(sample, contig, zarr_dir="/scratch/user/uqtdenni/snp_genotypes_zarr"):
    """Open one contig group from a per-sample zipped Zarr store.

    Parameters
    ----------
    sample : str
        Sample name. The archive read is ``{zarr_dir}/{sample}.zarr.zip`` and
        the contig group is ``{sample}.zarr/{contig}`` inside it.
    contig : str
        Contig (group) name within the sample store.
    zarr_dir : str, optional
        Directory holding the per-sample ``.zarr.zip`` archives.

    Returns
    -------
    xarray.Dataset
        Result of ``xr.open_zarr(..., consolidated=False)`` on the contig group.
    """
    zip_file = f"{zarr_dir}/{sample}.zarr.zip"

    # The filesystem is intentionally not closed: open_zarr returns lazily
    # loaded arrays that still read through this mapper later.
    zip_fs = ZipFileSystem(zip_file)

    # Map the full path to one contig
    contig_store = zip_fs.get_mapper(f"{sample}.zarr/{contig}")

    return xr.open_zarr(contig_store, consolidated=False)


def rechunk_dataset(ds, max_chunk_size=100_000):
    """
    Rechunk all data variables in an xarray.Dataset so that
    chunks along the 'variants' dimension are <= max_chunk_size,
    while every other dimension keeps its current leading chunk size.

    Args:
        ds (xr.Dataset): Input dataset. It is not modified.
        max_chunk_size (int): max chunk size for variants dimension

    Returns:
        xr.Dataset: New dataset with the rechunked data variables. Variables
        whose ``.chunks`` is None (not dask-backed) are passed through as-is.
    """
    rechunked_vars = {}
    for var in ds.data_vars:
        arr = ds[var]
        chunks = arr.chunks

        # If no chunks info (not dask array), skip rechunk
        if chunks is None:
            rechunked_vars[var] = arr
            continue

        target = {}
        for dim_idx, dim in enumerate(arr.dims):
            if dim == "variants":
                # Cap the variants chunk, but never exceed the dimension size.
                target[dim] = min(max_chunk_size, arr.sizes[dim])
            else:
                dim_chunks = chunks[dim_idx]
                # Keep the current leading chunk size; fall back to the whole
                # dimension when no chunk is recorded for it.
                target[dim] = dim_chunks[0] if dim_chunks else arr.sizes[dim]

        rechunked_vars[var] = arr.chunk(target)

    # assign() returns a new Dataset and leaves `ds` untouched
    # (Dataset.update would mutate the caller's dataset in place).
    return ds.assign(rechunked_vars)

# Merge the per-sample Zarr stores contig by contig and write one Zarr store per
# contig. Only KI915067 is processed here; `called_contigs` holds the full list.
import shutil

# Dictionary to hold merged contig arrays
merged_contigs = {}

# Root directory for the per-contig merged stores.
merged_root = "/scratch/user/uqtdenni/merged_zarr"

# Variables not carried into the merged output.
vars_to_drop = [
    'filter_id', 'variant_id', 'variant_id_mask', 'call_genotype_mask',
    'call_genotype_phased', 'variant_quality', 'contig_length',
    'contig_id', 'variant_filter', 'variant_contig',
]

# Variables whose values are set to -1 at indel sites — adjust as needed.
genotype_vars = ['call_genotype', 'variant_allele', 'call_AD', 'call_GQ', 'variant_MQ']

# Target chunk sizes for the written store (applied only to dims a variable has).
new_chunks = {'variants': 10000, 'samples': 10}

for contig in ['KI915067']:
    print(f"Processing contig: {contig}")
    arrays = [open_sample_contig(sample, contig) for sample in samples]

    # Combine samples (dim='samples') across datasets.
    # NOTE(review): xr.concat defaults to data_vars="all", so variables without a
    # 'samples' dimension also gain one. If every sample store shares the same
    # variant sites, data_vars="minimal", compat="override" would avoid that —
    # TODO confirm against how the per-sample stores were produced.
    merged_ds = xr.concat(arrays, dim="samples")

    # Fix dims if needed
    merged_ds = merged_ds.transpose("variants", "samples", ...)

    merged_ds = merged_ds.drop_vars(vars_to_drop)

    # A site counts as an indel if any of its alleles is longer than one base.
    # Reducing over every dim other than variants/samples keeps the mask from
    # broadcasting an extra allele axis onto arrays like call_genotype.
    allele_lengths = merged_ds['variant_allele'].astype(str).str.len()
    is_indel = allele_lengths > 1
    allele_dims = [d for d in is_indel.dims if d not in ("variants", "samples")]
    if allele_dims:
        is_indel = is_indel.any(dim=allele_dims)

    for var in genotype_vars:
        if var in merged_ds:
            # Set values to -1 at indel sites, keep others as is
            merged_ds[var] = merged_ds[var].where(~is_indel, other=-1)

    # Rechunk every data variable. Build the new variables first, then assign,
    # rather than mutating the dataset while iterating it. The loop name must
    # not be `da`, which would shadow the `dask.array as da` import.
    rechunked = {
        name: arr.chunk({dim: size for dim, size in new_chunks.items() if dim in arr.dims})
        for name, arr in merged_ds.data_vars.items()
    }
    merged_ds = merged_ds.assign(rechunked)

    # Convert object dtype variables to string dtype
    for var in ['sample_id', 'variant_allele']:
        if var in merged_ds:
            merged_ds[var] = merged_ds[var].astype(str)

    # Save the dataset, replacing any previous output for this contig.
    output_dir = f"{merged_root}/{contig}.zarr"
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)

    # Empty per-variable encodings — presumably to discard encoding inherited
    # from the source stores (e.g. their chunk sizes) that could clash with the
    # new dask chunks.
    merged_ds.to_zarr(output_dir, encoding={k: {} for k in merged_ds.data_vars})
    print(f"Saved merged dataset for {contig} to {output_dir}")

Processing contig: KI915067
