In [1]:
import pandas as pd
import numpy as np
import anjl
import numba
import sgkit as sg
import allel

In [16]:
# Contig (scaffold) accession -> length lookup.
scaflens = dict(
    CM023248=93706023,
    CM023249=88747589,
    CM023250=22713616,
)

# Per-sample metadata table; later cells read its 'sample_id' and 'pop_code' columns.
df_samples = pd.read_csv(
    filepath_or_buffer='/Users/dennistpw/Projects/AsGARD/metadata/cease_combinedmetadata_noqc.20250212.csv'
)

# Colour palette: population code -> hex colour used when plotting.
pop_code_cols = dict(
    APA='#ff7f00',  # orange
    SAE='#6a3d9a',  # dark purple
    SAR='#cab2d6',  # lighter purple
    IRS='#c27a88',  # placeholder, colour not finalised
    IRH='#c57fc9',  # placeholder, colour not finalised
    INB='#96172e',  # dark red
    INM='#f03e5e',  # light red
    DJI='#507d2a',  # sap green
    ETB='#007272',  # dark teal
    ETS='#33a02c',  # green
    ETW='#a6cee3',  # light blue
    SUD='#fccf86',  # pale orange
    YEM='#CC7722',  # ochre
)

# Contig accession -> integer index, numbered in listing order starting at 0.
contigs_dict = {
    accession: position
    for position, accession in enumerate(('CM023248', 'CM023249', 'CM023250'))
}

In [14]:
#define stevegen500 functions
def select_random_genos(
                    ds, 
                    n_snps=100_000):
    """Randomly thin a dataset to ``n_snps`` variants, keeping their original order.

    Parameters
    ----------
    ds
        Dataset-like object exposing ``call_genotype`` (whose first axis length
        is taken as the number of variants) and ``isel(variants=...)``.
    n_snps : int
        Number of variants to draw, without replacement.

    Returns
    -------
    The result of ``ds.isel(variants=...)`` for the drawn positions.

    Raises
    ------
    ValueError
        Raised by ``np.random.choice`` when ``n_snps`` exceeds the number of
        variants available.
    """
    n_variants = ds.call_genotype.shape[0]
    keep_indices = np.random.choice(n_variants, n_snps, replace=False)
    # Sort so the thinned dataset keeps the variants in their original order.
    keep_indices.sort()
    # Index with the positions themselves. Applying `~` to an integer array is a
    # bitwise NOT (-(k + 1)), which would pick mirrored positions counted from
    # the end of the axis and reverse the variant order.
    thinned_callset = ds.isel(variants=keep_indices)
    return thinned_callset

def load_geno_ds(chrom, 
                 df_samples,
                sample_query=None, 
                n_snps=None, 
                sample_list=None, 
                start=None, 
                end=None, 
                min_minor_ac=None,
                acconly=True
                ):
    """Load one contig's genotype dataset and subset it by samples and variants.

    The steps are applied in this order: sample selection, allele-count filter,
    accessibility filter, then either random thinning or a positional range.

    Parameters
    ----------
    chrom : str
        Contig name, interpolated into the zarr store path.
    df_samples : pandas.DataFrame
        Sample metadata. Presumably its rows are in the same order as the
        ``samples`` dimension of the stored dataset -- TODO confirm, since rows
        are matched to samples purely by position.
    sample_query : str, optional
        Expression passed to ``df_samples.query``. Takes precedence over
        ``sample_list``.
    n_snps : int, optional
        If truthy, thin to this many variants via ``select_random_genos``.
        Takes precedence over ``start``/``end``.
    sample_list : list-like, optional
        Values matched against ``df_samples['sample_id']``. An empty list means
        no subsetting.
    start, end : int, optional
        Bounds of a ``variant_position`` slice. The slice is applied when either
        bound is given and ``n_snps`` is not set.
    min_minor_ac : int, optional
        If truthy, keep variants whose count of allele index 1 is at least this
        value.
    acconly : bool
        If true, keep only variants flagged by the ``is_accessible`` variable.

    Returns
    -------
    tuple
        ``(df_samples, ds_analysis)``: the metadata rows that were kept and the
        subsetted dataset.
    """
    # Load the dataset for this contig.
    ds = sg.load_dataset(f'/Users/dennistpw/Projects/AsGARD/data/variants_combined_cohorts/combined_cohorts.{chrom}.zarr')

    # Build a boolean row mask for the requested samples (None = keep all).
    if sample_query:
        selected_labels = df_samples.query(sample_query).index
        keep_rows = df_samples.index.isin(selected_labels)
    elif sample_list is not None and len(sample_list) > 0:
        # An explicit None/length check also accepts arrays and Series, whose
        # truth value is ambiguous.
        keep_rows = df_samples['sample_id'].isin(sample_list).to_numpy()
    else:
        keep_rows = None

    if keep_rows is not None:
        # `isel` is positional, so convert the mask to integer positions rather
        # than passing index labels, which only coincide with positions for a
        # default RangeIndex.
        sample_positions = np.flatnonzero(keep_rows)
        df_samples = df_samples.iloc[sample_positions]
        ds = ds.isel(samples=sample_positions)

    # If a minimum allele count is specified, drop variants below it.
    if min_minor_ac:
        print(f'subsetting to segregating sites')
        ac = allel.GenotypeArray(ds['call_genotype']).count_alleles()
        # NOTE(review): this thresholds the count of allele index 1, not
        # min(ref, alt); confirm that is the intended "minor" allele count.
        macbool = ac[:,1] >= min_minor_ac
        print(f'selected {np.sum(macbool)} sites with a min mac >= {min_minor_ac}')
        ds_analysis = ds.sel(variants=(macbool))
    else:
        ds_analysis = ds

    # Keep accessible sites only.
    if acconly:
        print('subsetting to accessible sites only')
        accmask = ds_analysis['is_accessible'].compute()
        ds_analysis = ds_analysis.sel(variants=(accmask))

    # Random thinning wins over a positional range when both are requested.
    if n_snps:
        ds_analysis = select_random_genos(ds_analysis, n_snps=n_snps)
    elif start is not None or end is not None:
        # Testing against None keeps start=0 and end-only ranges from being
        # silently ignored.
        print(f"subsetting haps to range {chrom}:{start}-{end}")
        # NOTE(review): the contig index is hard-coded to 0; verify it matches
        # `variant_contig` in every per-contig store.
        ds_analysis = ds_analysis.set_index(variants=("variant_contig", "variant_position")).sel(variants=(0, slice(start,end)))

    # Return the kept metadata and the completed dataset.
    return (df_samples, ds_analysis)


# Define helper functions
@numba.njit
def square_to_condensed(i, j, n):
    """Map an off-diagonal (i, j) cell of an n x n square matrix to its
    position in the condensed (flattened) form. The pair is unordered."""

    assert i != j, "no diagonal elements in condensed matrix"
    # Order the pair explicitly: `lo` is the smaller index, `hi` the larger.
    if i < j:
        lo, hi = i, j
    else:
        lo, hi = j, i
    row_offset = n * lo - lo * (lo + 1) // 2
    return row_offset + hi - 1 - lo

@numba.njit
def genotype_cityblock(x, y):
    """Return the cityblock distance between ``x`` and ``y``: the sum of
    absolute element-wise differences over the first axis of ``x``."""
    total = np.float32(0)

    # Walk the sites, adding each absolute difference to the running total.
    for site in range(x.shape[0]):
        total += np.fabs(x[site] - y[site])

    return total

@numba.njit(parallel=True)
def biallelic_diplotype_pdist(X):
    """Compute cityblock distances between every pair of rows of ``X``.

    Returns a float32 vector with one entry per unordered row pair, placed at
    the position given by ``square_to_condensed``.
    """
    n_rows = X.shape[0]
    result = np.zeros((n_rows * (n_rows - 1)) // 2, dtype=np.float32)

    # Outer loop: first member of the pair.
    for first in range(n_rows):
        row_a = X[first, :]

        # Inner loop (parallel range): second member, always after the first.
        for second in numba.prange(first + 1, n_rows):
            slot = square_to_condensed(first, second, n_rows)
            result[slot] = genotype_cityblock(row_a, X[second, :])

    return result


def get_pdist( 
    chrom,
    df_samples,
    sample_query=None,
    min_minor_ac=1,
    n_snps=100_000,
    sample_list=None):   
    """Compute, save and return condensed pairwise genotype distances for a contig.

    Loads genotypes through ``load_geno_ds`` (accessible sites only), converts
    them with ``to_n_alt``, and passes the transposed array to
    ``biallelic_diplotype_pdist``. The result is also written to a ``.npy`` file
    named after ``chrom``, ``n_snps`` and ``min_minor_ac``.
    """
    # Only the dataset is needed here; the returned sample table is discarded.
    _, geno_ds = load_geno_ds(
        chrom=chrom,
        df_samples=df_samples,
        sample_query=sample_query,
        sample_list=sample_list,
        min_minor_ac=min_minor_ac,
        n_snps=n_snps,
        acconly=True,
    )

    # Genotype calls -> alternate-allele counts, transposed and made contiguous
    # for the numba kernel.
    alt_counts = allel.GenotypeArray(geno_ds["call_genotype"]).to_n_alt()
    kernel_input = np.ascontiguousarray(alt_counts.T)

    dist = biallelic_diplotype_pdist(kernel_input)

    out_path = f"/Users/dennistpw/Projects/AsGARD/data/{chrom}.{n_snps}.{min_minor_ac}.dist.npy"
    np.save(out_path, dist)
    return dist

In [10]:
from scipy.spatial.distance import squareform  # type: ignore


def plot_njt(chrom, filename, n_snps):
    """Build, plot and save a neighbour-joining tree for one contig.

    Uses the notebook-level ``df_samples`` and ``pop_code_cols`` for the
    distance computation and leaf colouring. The figure is written as an SVG
    named ``filename`` and shown. Returns the value from ``anjl.canonical_nj``.
    """
    # Condensed pairwise distances, expanded to a square matrix.
    condensed = get_pdist(chrom=chrom, df_samples=df_samples, n_snps=n_snps)
    square = squareform(condensed)

    # Neighbour-joining tree construction.
    Z = anjl.canonical_nj(
        D=square,
        progress_options={"desc": "Construct neighbour-joining tree", "leave": False},
    )

    # Plot with leaves coloured by population code.
    fig = anjl.plot(
        Z=Z,
        leaf_data=df_samples,
        color="pop_code",
        color_discrete_map=pop_code_cols,
    )

    # Save to disk, then display.
    fig.write_image(f"/Users/dennistpw/Projects/AsGARD/figures/{filename}.svg")
    fig.show()
    return Z


In [17]:
# NJ tree for contig CM023248; the result is kept for inspection in the next cell.
z3 = plot_njt(chrom='CM023248', filename='njt_chr2', n_snps=100_000)

subsetting to segregating sites
selected 8730863 sites with a min mac > 1
subsetting to accessible sites only


In [18]:
# Show the value returned by plot_njt for CM023248 as a table.
pd.DataFrame(data=z3)

Unnamed: 0,0,1,2,3,4
0,27.0,36.0,674.471741,629.528259,2.0
1,518.0,519.0,2627.079346,2635.920654,2.0
2,386.0,391.0,255.929611,270.070374,2.0
3,516.0,517.0,2656.447754,2748.552246,2.0
4,514.0,552.0,3071.343018,498.656891,3.0
...,...,...,...,...,...
545,1094.0,1095.0,61.976585,53.060783,94.0
546,1054.0,1068.0,334.324493,255.190628,28.0
547,1096.0,1097.0,45.354263,36.320297,122.0
548,1090.0,1098.0,58.901215,17.303772,532.0


In [24]:
# NJ tree for contig CM023249.
plot_njt(chrom='CM023249', filename='njt_chr3', n_snps=100_000)

subsetting to segregating sites
selected 8714133 sites with a min mac > 1
subsetting to accessible sites only


array([[3.1700000e+02, 3.3100000e+02, 3.7880875e+02, 3.7419125e+02,
        2.0000000e+00],
       [3.8600000e+02, 3.9100000e+02, 2.7561313e+02, 2.8738687e+02,
        2.0000000e+00],
       [5.1400000e+02, 5.1700000e+02, 3.0557449e+03, 2.9312551e+03,
        2.0000000e+00],
       ...,
       [1.0950000e+03, 1.0960000e+03, 5.8687050e+01, 4.4551704e+01,
        4.9400000e+02],
       [1.0970000e+03, 1.0980000e+03, 1.9367561e+01, 5.7248611e+00,
        5.4500000e+02],
       [1.0310000e+03, 1.0990000e+03, 2.7202359e+02, 2.7202359e+02,
        5.5100000e+02]], dtype=float32)

In [23]:
# NJ tree for contig CM023250.
plot_njt(chrom='CM023250', filename='njt_chrX', n_snps=100_000)

subsetting to segregating sites
selected 1363056 sites with a min mac > 1
subsetting to accessible sites only


array([[5.18000000e+02, 5.19000000e+02, 1.16446631e+03, 1.16453369e+03,
        2.00000000e+00],
       [5.14000000e+02, 5.15000000e+02, 3.54996069e+03, 3.63403931e+03,
        2.00000000e+00],
       [2.07000000e+02, 2.09000000e+02, 2.44087753e+02, 2.25912247e+02,
        2.00000000e+00],
       ...,
       [1.08000000e+03, 1.09700000e+03, 1.14754105e+02, 1.19995575e+01,
        2.80000000e+01],
       [1.09300000e+03, 1.09800000e+03, 3.33803558e+01, 9.85740662e+00,
        3.50000000e+01],
       [1.09600000e+03, 1.09900000e+03, 8.38593292e+00, 8.38593292e+00,
        5.51000000e+02]], dtype=float32)