In [1]:
import hashlib
import json

import allel
import bed_reader
import numpy as np
import pandas as pd
import xarray
import zarr
from Bio import SeqIO

In [6]:
# Configuration: reference genome, contigs to analyse, sample metadata and genotype store.

# Reference genome; used to find contigs long enough to analyse.
ref_path = '/scratch/user/uqtdenni/afar_production_bunya/reference/VectorBase-54_AfarautiFAR1_Genome.fasta'
# Contigs must be strictly longer than this many bp to be kept.
MIN_CONTIG_LENGTH = 100_000

# Contig name -> sequence length, read from the reference FASTA.
contig_lengths = {record.id: len(record.seq) for record in SeqIO.parse(ref_path, "fasta")}

# Contigs longer than MIN_CONTIG_LENGTH, ordered longest first.
filtered_contigs = {
    name: length
    for name, length in sorted(contig_lengths.items(), key=lambda item: item[1], reverse=True)
    if length > MIN_CONTIG_LENGTH
}

# Final cleaned sample metadata (staged, species-confirmed per the file name).
df_samples = pd.read_csv('/scratch/user/uqtdenni/far_hin_1.x/work/metadata_development_20250702/metadata-staged-speciesconfirmed-20251011.txt',sep='\t')

# Per-contig genotype store; the literal `{contig}` placeholder is filled with str.format
# at load time. The staged store is used here; the unstaged equivalent lives under
# .../curation/uq-beebe-001/combined_zarr/{contig}.zarr
zarr_base_path = "/scratch/user/uqtdenni/afar_production_bunya/curation/uq-beebe-001/staged_zarr/{contig}.zarr"

In [15]:
def hash_params(*args, **kwargs):
    """Return an MD5 hex digest that identifies a set of analysis parameters.

    Positional and keyword arguments are packed as ``{"args": ..., "kwargs": ...}``,
    serialised to JSON with sorted keys (so keyword order does not affect the
    result), encoded to bytes and hashed. Every argument must be JSON-serialisable.
    """
    payload = json.dumps({"args": args, "kwargs": kwargs}, sort_keys=True)
    return hashlib.md5(payload.encode()).hexdigest()

# Define helper functions
def load_genotype_array(contig, df_samples=df_samples, sample_query = None, n_snps=None):
    """Load the filter-passing genotype calls for one contig.

    Parameters
    ----------
    contig : str
        Contig name, substituted into ``zarr_base_path``.
    df_samples : pandas.DataFrame
        Sample metadata; only used to evaluate ``sample_query``. Its rows are assumed
        to be in the same order as the samples in the zarr store — TODO confirm.
    sample_query : str, optional
        ``DataFrame.eval`` expression selecting samples. If None, all samples are kept.
    n_snps : int, optional
        If given, randomly subsample this many sites (kept in genomic order).

    Returns
    -------
    allel.GenotypeChunkedArray restricted to sites whose ``variants/filter_pass`` is
    True (and to the queried samples), or its random subset when ``n_snps`` is given.
    """
    z = zarr.open(zarr_base_path.format(contig=contig))

    # Site-level QC mask. The per-contig store keeps it under variants/.
    filter_mask = z["variants/filter_pass"][:]

    gt = allel.GenotypeChunkedArray(z["calldata/GT"])
    gt = gt.compress(filter_mask, axis=0)  # filter variants

    # Optionally restrict to the samples selected by the metadata query.
    if sample_query is not None:
        bool_query = np.asarray(df_samples.eval(sample_query))
        gt = gt.compress(bool_query, axis=1)

    if n_snps is not None:
        gt = select_random_elements_sorted(gt, n_snps)

    return gt


def select_random_elements_sorted(g, x, replace=False, seed=None):
    """Draw ``x`` rows of ``g`` at random and return them in ascending row order.

    Parameters
    ----------
    g : array-like
        Object with a ``shape`` whose first axis accepts integer-array indexing.
    x : int
        Number of rows to draw.
    replace : bool, default False
        Sample with replacement; rows may then repeat.
    seed : int or None, default None
        Seed for ``numpy.random.default_rng``. Fix it for reproducible draws.

    Returns
    -------
    ``g`` indexed by the sorted row indices, i.e. whatever ``g[index_array]`` returns.

    Raises
    ------
    ValueError
        If sampling without replacement and ``x`` exceeds the number of rows.
    """
    generator = np.random.default_rng(seed)
    total_rows = g.shape[0]

    # Guard before drawing so the error message is explicit.
    if not replace and x > total_rows:
        raise ValueError(f"Cannot select {x} rows without replacement from {total_rows} total rows.")

    ordered_rows = np.sort(generator.choice(total_rows, size=x, replace=replace))
    return g[ordered_rows]

def compute_ac(contig, is_biallelic=True, is_segregating=True, min_minor_ac=1, n_snps=None, sample_query=None, to_alt = None):
    """Load a contig's genotypes, filter sites, and return allele counts or alt dosages.

    Parameters
    ----------
    contig : str
        Contig loaded via ``load_genotype_array``.
    is_biallelic : bool
        Keep only sites with exactly two observed alleles.
    is_segregating : bool
        Keep only sites with more than one observed allele.
    min_minor_ac : int, float or None
        int: minimum minor allele count. float: minimum minor allele frequency
        (minor count / called alleles; sites with no called alleles fail).
        None disables the filter.
    n_snps : int, optional
        Randomly subsample this many of the retained sites (kept in genomic order).
    sample_query : str, optional
        Passed to ``load_genotype_array`` to select samples.
    to_alt : bool, optional
        If truthy, return per-call alt allele counts (``to_n_alt``). Otherwise
        (None or False) return allele counts of the retained sites.
    """
    g = load_genotype_array(contig=contig, sample_query=sample_query)
    ac = g.count_alleles()

    # `[:]` materialises chunked results (no-op view for in-memory ones), so plain
    # numpy arithmetic and boolean combination work below.
    masks = []

    if is_biallelic:
        masks.append(np.asarray(ac.is_biallelic()[:]))

    if is_segregating:
        masks.append(np.asarray(ac.is_segregating()[:]))

    if min_minor_ac is not None:
        ac_arr = np.asarray(ac[:])
        an = ac_arr.sum(axis=1)
        # Minor allele count = all called alleles except those of the most common allele.
        # (Summing the ALT columns would count a common ALT as "minor".)
        ac_minor = an - ac_arr.max(axis=1)
        if isinstance(min_minor_ac, float):
            # Sites with an == 0 give NaN, which fails the >= test.
            with np.errstate(divide="ignore", invalid="ignore"):
                loc_minor_mask = (ac_minor / an) >= min_minor_ac
        else:
            loc_minor_mask = ac_minor >= min_minor_ac
        masks.append(loc_minor_mask)

    # Apply all site filters at once. With every filter disabled, keep all sites.
    gt = g.compress(np.logical_and.reduce(masks), axis=0) if masks else g

    if n_snps is not None:
        gt = select_random_elements_sorted(gt, n_snps)

    if to_alt:
        return gt.to_n_alt()
    return gt.count_alleles()
    



In [12]:
# Parameters for the single-contig export below.
contig = "KI915040"                          # contig to export
sample_query = 'species_pca == "hinesorum"'  # DataFrame.eval expression over df_samples
n_snps = 10**5                               # number of sites to subsample

In [19]:
# Load the contig's genotypes and build a site mask. A site is kept when it:
#   - passes variants/filter_pass
#   - has exactly the alleles {0, 1} observed
#   - has no '*' ALT allele
#   - has exactly one non-empty ALT entry
x = zarr.open(zarr_base_path.format(contig=contig))
g = allel.GenotypeArray(x['calldata/GT'])

# Allele counts over all samples in the store.
ac = g.count_alleles()

# Read variant arrays into memory once. ALT feeds two of the filters below.
alt_all = x['variants/ALT'][:]
filter_mask = x['variants/filter_pass'][:]

# Exactly two alleles observed, and they are REF (0) and the first ALT (1).
flt = ac.is_biallelic() & (ac.max_allele() == 1)
# True where NO ALT column is the VCF '*' allele (named for what it keeps).
no_asterisk = ~np.any(alt_all == '*', axis=1)

initial_mask = flt & filter_mask & no_asterisk

# Among initially passing sites, keep those listing exactly one non-empty ALT allele.
truly_biallelic_subset = np.sum(alt_all[initial_mask] != '', axis=1) == 1

# Scatter the subset result back to full-length coordinates.
final_mask = np.zeros(len(initial_mask), dtype=bool)
final_mask[initial_mask] = truly_biallelic_subset

print(f"Original sites: {len(initial_mask)}")
print(f"After initial filters: {np.sum(initial_mask)}")
print(f"After truly biallelic filter: {np.sum(final_mask)}")

Original sites: 7175601
After initial filters: 3761669
After truly biallelic filter: 3660986


In [21]:
# Optionally restrict the genotypes to the samples matching `sample_query`.
# Currently disabled, so every sample in the store is exported. If enabled, `iid` in
# the export cell (built from df_samples.index) must be subset with the same mask.
SUBSET_TO_QUERY = False
if SUBSET_TO_QUERY:
    far_inds = np.asarray(df_samples.eval(sample_query))
    g_f = g.compress(far_inds, axis=1)
else:
    g_f = g

In [23]:
# Randomly subsample n_snps of the filter-passing sites (seeded, kept in genomic order)
# and write them as a PLINK .bed/.bim/.fam set.
EXPORT_SEED = 1234

# Draw the site indices once and reuse them for every per-site array. This gives the
# same indices as separate seeded draws, without relying on those draws staying in sync.
n_pass = int(np.sum(final_mask))
sel = select_random_elements_sorted(np.arange(n_pass), n_snps, seed=EXPORT_SEED)

gt = g_f.compress(final_mask, axis=0)[sel]
c = x['variants/CHROM'][:][final_mask][sel]
pos = x['variants/POS'][:][final_mask][sel]
ref = x['variants/REF'][:][final_mask][sel]
alt = x['variants/ALT'][:][final_mask][sel, 0]

# Alt-allele dosage, samples x sites. Missing calls are written as -127, bed_reader's
# missing value for integer input. to_n_alt's default fill=0 would silently turn them
# into homozygous-REF calls.
gn = gt.to_n_alt(fill=-127).T

# NOTE(review): these are df_samples' row-index labels, not sample names, and are assumed
# to be in the same order as the samples in the zarr store — confirm.
iid = list(df_samples.index)
assert len(iid) == gn.shape[0], "sample metadata and genotype matrix disagree on sample count"

properties = {
        "iid": iid,
        # All sites are put on chromosome 1: per the original note, ADMIXTURE fails on text chromosome ids.
        "chromosome": np.ones(len(pos), dtype=int),
        "bp_position": pos.astype(int),
        # `gn` counts ALT alleles and count_A1=True tells bed_reader the values count
        # allele_1, so ALT must be allele_1 (PLINK A1) and REF allele_2.
        "allele_1": alt,
        "allele_2": ref,
    }

bed_reader.to_bed(
    filepath='all_samples_1e5.bed',
    val=gn,
    properties=properties,
    count_A1=True,
)