In [None]:
# Set up environment

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import sgkit as sg
import json
import hashlib
import itertools as it


import allel; print('scikit-allel', allel.__version__)

# px config: image-export options (format, file name, size and scale factor)
config = dict(
    toImageButtonOptions=dict(
        format='png',  # one of png, svg, jpeg, webp
        filename='custom_image',
        height=500,
        width=700,
        scale=6,  # multiply title/legend/axis/canvas sizes by this factor
    )
)

# Palette: one hex colour per population code
pop_code_cols = dict(
    SAE='#6a3d9a',  # dark purple
    SAR='#cab2d6',  # lighter purple
    INB='#96172e',  # dark red
    INM='#f03e5e',  # light red
    APA='#ff7f00',  # orange
    IRN='#C2907A',  # not sure yet
    DJI='#507d2a',  # sap green
    ETW='#a6cee3',  # cerulean
    ETB='#007272',  # cobalt turquoise
    ETS='#33a02c',  # green
    SUD='#fccf86',  # labelled "ochre" by the author
    YEM='#CC7722',  # labelled "pinkish" by the author
)

# Load the combined per-sample metadata table (no filtering is applied here)
df_samples = pd.read_csv(
    '/Users/dennistpw/Projects/AsGARD/metadata/cease_combinedmetadata.20250212.csv'
)


In [None]:
# Set up population dictionary: population code -> df_samples row labels

pop_code_series = df_samples['pop_code']
pop_dict = {}

# Visit each distinct value of 'pop_code' in order of first appearance
for pop in pop_code_series.unique():
    members = df_samples.index[pop_code_series == pop].tolist()
    # Only keep populations represented by at least 5 rows
    if len(members) >= 5:
        pop_dict[pop] = members

In [None]:
# Helper functions
def hash_params(*args, **kwargs):
    """Return an MD5 hex digest identifying a set of analysis parameters.

    The positional and keyword arguments are serialised to JSON with sorted
    keys, so the digest does not depend on keyword order. All arguments must
    be JSON-serialisable.
    """
    payload = json.dumps({'args': args, 'kwargs': kwargs}, sort_keys=True)
    return hashlib.md5(payload.encode()).hexdigest()

#function for getting allele counts
def snp_allele_counts(
        ds=None,
        sample_query=None,
        sample_list=None,
        min_ac=1
        ):
    """Count alleles at every variant of `ds`, optionally for a sample subset.

    Parameters
    ----------
    ds : dataset exposing `call_genotype` and a `samples` dimension.
    sample_query : str, optional
        Expression evaluated with `df_samples.eval`; the resulting boolean
        values select samples. Takes precedence over `sample_list`.
    sample_list : list, optional
        Sample IDs matched against `df_samples['sample_id']`.
    min_ac : int
        Currently unused (placeholder for a future allele-count/MAF filter).

    Returns
    -------
    The result of `allel.GenotypeArray(...).count_alleles()`.

    Notes
    -----
    Reads the notebook-level `df_samples`. Assumes its rows are in the same
    order as the `samples` dimension of `ds` -- TODO confirm.
    If neither selector is given (or both are falsy), all samples are used.
    """
    sample_mask = None
    if sample_query:
        sample_mask = df_samples.eval(sample_query).values
    elif sample_list:
        sample_mask = df_samples['sample_id'].isin(sample_list)

    if sample_mask is not None:
        ds = ds.isel(samples=sample_mask)

    genotypes = allel.GenotypeArray(ds.call_genotype.values)

    # TODO: need to make maf filter (min_ac is not applied yet)

    return genotypes.count_alleles()

# Function to run Fst genome scan
def do_fst_scan(
          chrom=None,
          sample_query_a=None,
          sample_query_b=None,
          sample_list_a=None,
          sample_list_b=None,
          winsize=10000,
          analysis_name='fst',
          results_dir=None
    ):
    """Windowed Hudson Fst between two cohorts along one chromosome, with caching.

    Results are cached as `{results_dir}/{key}-fst.csv`, where `key` is the
    hash of all arguments (see `hash_params`). If that file exists it is
    loaded and returned without recomputation.

    Parameters
    ----------
    chrom : str
        Contig name; selects which `combined_cohorts.{chrom}.zarr` to load.
    sample_query_a, sample_query_b : str, optional
        Expressions evaluated against the notebook-level `df_samples` to
        select cohort A / cohort B (see `snp_allele_counts`).
    sample_list_a, sample_list_b : list, optional
        Sample IDs used for a cohort when its query is not given.
    winsize : int
        Window size passed to `allel.moving_hudson_fst` and
        `allel.moving_statistic` (a number of variants, per scikit-allel).
    analysis_name : str
        Free-text label; only affects the cache key.
    results_dir : str
        Directory holding the cached CSV files.

    Returns
    -------
    pandas.DataFrame
        Freshly computed results have columns `chrom`, `midpos` (mean variant
        position per window) and `fst` (clipped to [0, 1]). Cached results
        are returned exactly as stored in the CSV.
    """
    # construct a key to save the results under
    results_key = hash_params(
        chrom=chrom,
        sample_query_a=sample_query_a,
        sample_query_b=sample_query_b,
        sample_list_a=sample_list_a,
        sample_list_b=sample_list_b,
        winsize=winsize,
        analysis_name=analysis_name,
        results_dir=results_dir
        )

    # define paths for results files
    data_path = f'{results_dir}/{results_key}-fst.csv'

    try:
        # try to load previously generated results
        return pd.read_csv(data_path)
    except FileNotFoundError:
        # no previous results available, need to run analysis
        print(f'running analysis: {results_key}')

    print('setting up inputs')

    #load ds
    ds = sg.load_dataset(f'/Users/dennistpw/Projects/AsGARD/data/variants_combined_cohorts/combined_cohorts.{chrom}.zarr')

    # Get accessible only
    print("Subsetting to accessible sites only")
    accmask = ds['is_accessible'].compute()
    ds = ds.sel(variants=(accmask))

    #get allele counts for a
    if sample_query_a:
        ac1 = snp_allele_counts(ds, sample_query=sample_query_a)
    else:
        ac1 = snp_allele_counts(ds, sample_list=sample_list_a)

    #get ac for b
    # The selector for cohort B must depend on B's own arguments. Keying this
    # on sample_query_a (as before) silently used all samples for B whenever
    # A was given as a query and B as a list, or vice versa.
    if sample_query_b:
        ac2 = snp_allele_counts(ds, sample_query=sample_query_b)
    else:
        ac2 = snp_allele_counts(ds, sample_list=sample_list_b)

    #get pos
    pos = ds.variant_position.values

    print("computing Fst")
    fst = allel.moving_hudson_fst(ac1, ac2, size=winsize)
    # Sometimes Fst can be very slightly below zero, clip for simplicity.
    fst = np.clip(fst, a_min=0, a_max=1)
    # Same windowing as above, so each Fst value gets a matching position
    x = allel.moving_statistic(pos, statistic=np.mean, size=winsize)

    # save results
    fstdf = pd.DataFrame(
            {'chrom': chrom,
             'midpos': x,
             'fst': fst})

    fstdf.to_csv(data_path, index=False)
    print(f'saved results: {results_key}')

    return fstdf

# Function to calculate genomewide average Fst

def do_fst_av(
          chrom=None,
          sample_query_a=None,
          sample_query_b=None,
          sample_list_a=None,
          sample_list_b=None,
          winsize=10000,
          analysis_name='fst_average',
          results_dir=None
    ):
    """Average Hudson Fst (with standard error) between two cohorts, with caching.

    Results are cached as `{results_dir}/{key}-fst.csv`, where `key` is the
    hash of all arguments (see `hash_params`). If that file exists it is
    loaded and returned without recomputation.

    Parameters
    ----------
    chrom : str
        Contig name; selects which `combined_cohorts.{chrom}.zarr` to load.
    sample_query_a, sample_query_b : str, optional
        Expressions evaluated against the notebook-level `df_samples` to
        select cohort A / cohort B (see `snp_allele_counts`).
    sample_list_a, sample_list_b : list, optional
        Sample IDs used for a cohort when its query is not given.
    winsize : int
        Passed as `blen` to `allel.average_hudson_fst` (block length, in
        number of variants, per scikit-allel).
    analysis_name : str
        Free-text label; only affects the cache key.
    results_dir : str
        Directory holding the cached CSV files.

    Returns
    -------
    pandas.DataFrame
        One row with columns `fst` and `se` (the first two values returned by
        `allel.average_hudson_fst`). Cached results are returned as stored.
    """
    # construct a key to save the results under
    results_key = hash_params(
        chrom=chrom,
        sample_query_a=sample_query_a,
        sample_query_b=sample_query_b,
        sample_list_a=sample_list_a,
        sample_list_b=sample_list_b,
        winsize=winsize,
        analysis_name=analysis_name,
        results_dir=results_dir
        )

    # define paths for results files
    data_path = f'{results_dir}/{results_key}-fst.csv'

    try:
        # try to load previously generated results
        return pd.read_csv(data_path)
    except FileNotFoundError:
        # no previous results available, need to run analysis
        print(f'running analysis: {results_key}')

    print('setting up inputs')

    #load ds
    # NOTE(review): unlike do_fst_scan, no `is_accessible` filter is applied
    # here -- confirm that is intended.
    ds = sg.load_dataset(f'/Users/dennistpw/Projects/AsGARD/data/variants_combined_cohorts/combined_cohorts.{chrom}.zarr')

    #get allele counts for a
    if sample_query_a:
        ac1 = snp_allele_counts(ds, sample_query=sample_query_a)
    else:
        ac1 = snp_allele_counts(ds, sample_list=sample_list_a)

    #get ac for b
    # The selector for cohort B must depend on B's own arguments. Keying this
    # on sample_query_a (as before) silently used all samples for B whenever
    # A was given as a query and B as a list, or vice versa.
    if sample_query_b:
        ac2 = snp_allele_counts(ds, sample_query=sample_query_b)
    else:
        ac2 = snp_allele_counts(ds, sample_list=sample_list_b)

    #get pos
    # NOTE(review): `pos` is not used by the computation below; it is kept so
    # the progress message that follows remains accurate.
    pos = ds.variant_position.values

    print("allele counts and pos loaded")

    print("computing average Hudson Fst")
    fst = allel.average_hudson_fst(ac1, ac2, blen=winsize)

    # Keep only the estimate and its standard error
    fdict = {'fst': fst[0], 'se': fst[1]}
    fstdf = pd.DataFrame([fdict])

    fstdf.to_csv(data_path, index=False)
    print(f'saved results: {results_key}')

    return fstdf



In [None]:
# Do genomescan (lots of output)
# Runs (or loads from cache) a windowed Fst scan for every pair of population
# codes on each of the three listed contigs, then stacks the results.

fstlist = []
winsize = 10000
output = '/Users/dennistpw/Projects/AsGARD/data/fst_20240712'

for chrom in ['CM023248', 'CM023249', 'CM023250']:
    for popa, popb in it.combinations(df_samples.pop_code.unique(), 2):

        fst_df = do_fst_scan(
            chrom=chrom,
            sample_query_a=f'pop_code == "{popa}"',
            sample_query_b=f'pop_code == "{popb}"',
            # Pass the window size explicitly: it is part of analysis_name, so
            # it must also be the size actually used. At 10000 (the function
            # default) the cache key is unchanged.
            winsize=winsize,
            results_dir=output,
            analysis_name=f'{chrom}.{winsize}.{popa}.{popb}'
        )

        # Tag each window with the pair it belongs to
        fst_df['popa'] = popa
        fst_df['popb'] = popb

        fstlist.append(fst_df)


fst_bigdf = pd.concat(fstlist)

# Single label per comparison, used as the row variable when plotting
fst_bigdf['comp'] = fst_bigdf['popa'] + '_' + fst_bigdf['popb']

from matplotlib.gridspec import GridSpec

# Relative panel width per contig, proportional to its largest window position
max_values = fst_bigdf.groupby('chrom')['midpos'].max()
total_max = max_values.sum()
column_widths = max_values.div(total_max).values
col_var_levels = max_values.index

# Regions to shade on the scan, keyed by contig.
# x_min/x_max are positions on that contig; y_min/y_max span the Fst axis.
cyp6_region = {
    'CM023248': dict(x_min=67473117, x_max=67501071, y_min=0, y_max=1),
}
ace1_region = {
    'CM023248': dict(x_min=60916071, x_max=60917000, y_min=0, y_max=1),
}
vgsc_region = {
    'CM023249': dict(x_min=42817709, x_max=42817800, y_min=0, y_max=1),
}
gste_region = {
    'CM023249': dict(x_min=70572788, x_max=70584603, y_min=0, y_max=1),
}
# NOTE(review): x_min == x_max here, so the shaded band has zero width
carboxylesterase_region = {
    'CM023249': dict(x_min=18816120, x_max=18816120, y_min=0, y_max=1),
}
cyp9_region = {
    'CM023250': dict(x_min=9721225, x_max=9722225, y_min=0, y_max=1),
}
# NOTE(review): x_min == x_max here, so the shaded band has zero width
diagk = {
    'CM023250': dict(x_min=4578144, x_max=4578144, y_min=0, y_max=1),
}

In [None]:
# Plot genome scan: one row per comparison, one column per contig

# Row variable: every population-pair label present in the results
row_var_levels = fst_bigdf['comp'].unique()

# Which regions get shaded on which contig (in drawing order)
regions_by_chrom = {
    'CM023248': (cyp6_region, ace1_region),
    'CM023249': (vgsc_region, gste_region, carboxylesterase_region),
    'CM023250': (cyp9_region, diagk),
}

# Figure plus a grid with one extra narrow column for the row labels
n_rows = len(row_var_levels)
n_chroms = len(col_var_levels)
fig = plt.figure(figsize=(sum(column_widths) * 14, n_rows * 2))  # Plots half as high
gs = GridSpec(n_rows, n_chroms + 1, width_ratios=list(column_widths) + [0.1])

for row_idx, row_val in enumerate(row_var_levels):
    is_row = fst_bigdf['comp'] == row_val

    for col_idx, col_val in enumerate(col_var_levels):
        ax = fig.add_subplot(gs[row_idx, col_idx])
        subset = fst_bigdf[is_row & (fst_bigdf['chrom'] == col_val)]

        # Fst trace with a light fill underneath
        sns.lineplot(data=subset, x='midpos', y='fst', ax=ax)
        ax.fill_between(x=subset['midpos'], y1=subset['fst'], color='blue', alpha=0.1)

        # No frame around any panel
        for side in ('top', 'right', 'left', 'bottom'):
            ax.spines[side].set_visible(False)

        # Only the first column keeps its y-axis
        if col_idx > 0:
            ax.set_ylabel('')
            ax.yaxis.set_visible(False)

        # Shade the regions registered for this contig
        for region in regions_by_chrom.get(col_val, ()):
            box = region[col_val]
            ax.fill_betweenx([box['y_min'], box['y_max']],
                             box['x_min'], box['x_max'],
                             color='gray', alpha=0.3)

        # Drop the x-axis label (ticks stay visible)
        ax.set_xlabel('')

        ax.set_ylim(0, 0.8)

    # Row title, written vertically in the extra right-hand column
    row_ax = fig.add_subplot(gs[row_idx, -1])
    row_ax.text(0.5, 0.5, row_val, va='center', ha='center', fontsize=12, rotation=90, transform=row_ax.transAxes)
    row_ax.axis('off')

plt.subplots_adjust(wspace=0.01, hspace=0.01)  # Reduce column and row padding
plt.tight_layout()
plt.show()
fig.savefig('../figures/fstscan.svg')


In [None]:
# Calculate average pairwise Fst between cohorts
# (one fst/se row per contig and population pair, collected in fstlist)

fstlist = []
output = '/Users/dennistpw/Projects/AsGARD/data/fst_20240712'

cohort_pairs = list(it.combinations(df_samples.pop_code.unique(), 2))

for chrom, (popa, popb) in it.product(['CM023248', 'CM023249', 'CM023250'], cohort_pairs):
    # analysis_name is left at the function default
    hud_fst_df = do_fst_av(
        chrom=chrom,
        sample_query_a=f'pop_code == "{popa}"',
        sample_query_b=f'pop_code == "{popb}"',
        results_dir=output,
    ).assign(popa=popa, popb=popb, chrom=chrom)

    fstlist.append(hud_fst_df)

In [None]:
# Plot heatmap of average pairwise Fst between cohorts

# Stack the per-pair results and keep a single contig for the figure
fst_hud_df = pd.concat(fstlist)
sub = fst_hud_df[fst_hud_df['chrom'] == 'CM023248']

# Mirror every pair (A, B) -> (B, A) so the pivoted matrix is symmetric.
# Swapping by column name does not depend on the column order.
sub2 = sub.rename(columns={'popa': 'popb', 'popb': 'popa'})
subc = pd.concat([sub, sub2])
pivot_df = subc.pivot(index="popa", columns="popb", values="fst")

# Cohort order: the palette's key order, reversed.
# (A hand-written order list previously defined here was overwritten before
# use and has been removed.)
order = list(pop_code_cols.keys())
order.reverse()

#reorder df
# Raises KeyError if a palette code is missing from the results
reordered_df = pivot_df.loc[order, order]

# Hide the upper triangle and the diagonal; only the lower triangle is drawn
mask = np.triu(np.ones_like(reordered_df, dtype=bool))
sns.set_style('white')

# Plot the heatmap with the mask (a single figure; a duplicated plt.figure
# call used to leave an empty extra figure behind)
plt.figure(figsize=(8, 6))
hm = sns.heatmap(
    reordered_df,
    mask=mask,
    cmap='viridis',  # Use viridis color palette
    annot=True,
    cbar=True,
    fmt=".2f",
    annot_kws={"size": 12.5}
)

# Customize axis labels
hm.set_xlabel('Population A')
hm.set_ylabel('Population B')
plt.savefig('../figures/fst_heatmap.svg')
plt.savefig('../figures/fst_heatmap.png')

In [None]:
# Now get pairwise Fst between all sample locations for spatial analysis

# Now, fst between all invasive locations for analysis and modelling in R
# Candidate locations: more than 5 samples, after excluding four countries
loc_groups = df_samples.query('country != "Pakistan" & country != "Afghanistan" & country != "SaudiArabia" & country != "India"').groupby('location').count()
locs = loc_groups['sample_id'][loc_groups['sample_id'] > 5].index.tolist()

#cheap manual reorder
# This hand-written list replaces the computed one above.
# NOTE(review): it is not checked against the computed list, and appears to
# include a location from an excluded country ('Nangarhar') -- confirm intended.
locs = ['Nangarhar','DjiboutiCity','AdenCity','Dubti','Jiga','Modjo','Babile','KebriDehar','Danan','PortSudan','Haiya','AlShukria','EastElglabat','ElZedab','SouthShandi','GeziraIslang','Arkaweet','ElSalamaniaWest','ElMeaileg','Wafara','AlGalaa','Agaja']

# Average Fst for every pair of locations, per contig, to infer potential isoBD
fstlist = []
winsize = 10000
output = '/Users/dennistpw/Projects/AsGARD/data/fst_byloc/'

for chrom in ['CM023248', 'CM023249', 'CM023250']:
    for popa, popb in it.combinations(locs, 2):

        hud_fst_df = do_fst_av(
            chrom=chrom,
            sample_query_a=f'location == "{popa}"',
            sample_query_b=f'location == "{popb}"',
            # Pass the block size explicitly: it is part of analysis_name, so
            # it must also be the size actually used. At 10000 (the function
            # default) the cache key is unchanged.
            winsize=winsize,
            results_dir=output,
            analysis_name=f'{chrom}.{winsize}.{popa}.{popb}.hudfst'
        )

        hud_fst_df['loca'] = popa
        hud_fst_df['locb'] = popb
        hud_fst_df['chrom'] = chrom

        fstlist.append(hud_fst_df)

