In [None]:
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import sgkit as sg
import json
import numba
import scipy
from scipy.spatial.distance import squareform
from scipy.spatial import distance
import scipy.cluster.hierarchy as sch
import hashlib
import itertools as it
import plotly.express as px
from IPython.display import Image
import xarray
import dask
sns.set_style('white')
sns.set_style('ticks')
sns.set_context('notebook')
import zarr
import allel; print('scikit-allel', allel.__version__)
import os
from multiprocessing.pool import ThreadPool
dask.config.set(pool=ThreadPool(20))
from dask.diagnostics import ProgressBar
# quieten dask warnings about large chunks
dask.config.set(**{'array.slicing.split_large_chunks': True})

#load and filter metadata
#df_samples[~df_samples['sample_id'].isin(missing_indv)]
#metadata = df_samples[~df_samples['sample_id'].isin(missing_indv)]
#metadata.reset_index(inplace=True)

#chrom dict
# Contig accession -> length (presumably in bp) for the three chromosomes analysed.
scaflens = dict(
    CM023248=93706023,
    CM023249=88747589,
    CM023250=22713616,
)

# Per-sample metadata table. Later cells rely on its 'sample_id' and 'analysis_pop'
# columns. Presumably QC-passed samples only, judging by 'qcpass' in the filename.
# NOTE(review): absolute local path -- consider a configurable DATA_DIR.
df_samples = pd.read_csv('/Users/dennistpw/Projects/AsGARD/metadata/cease_combinedmetadata_qcpass.20240914.csv')

# Hex plotting colour for each analysis population. Key order is also the
# order used for legends.
analysis_popcols = dict(
    saudi_e='#6a3d9a',
    saudi_r='#cab2d6',
    india_m='#a6cee3',
    india_b='#1f78b4',
    afgh_pak='#ff7f00',
    djibouti='#e31a1c',
    ethiopia_n='#33a02c',
    ethiopia_so='#b2df8a',
    yemen='#fdbf6f',
    sudan='#fb9a99',
)



In [2]:
#define stevegen1000 functions
def select_random_genos(
        ds,
        numgenos=100_000,
        seed=None):
    """Randomly thin a dataset to ``numgenos`` variants, keeping genomic order.

    Parameters
    ----------
    ds : dataset
        Object exposing ``call_genotype`` (variants along axis 0) and
        ``isel(variants=...)``, e.g. an sgkit/xarray dataset.
    numgenos : int or None, default 100_000
        Number of variants to keep. If the dataset has fewer variants (or
        ``numgenos`` is None), every variant is kept.
    seed : int or None, default None
        Seed for a local random generator, for reproducible thinning. When
        None, NumPy's global random state is used.

    Returns
    -------
    ``ds.isel(variants=keep_indices)``, where ``keep_indices`` is a sorted array
    of unique, non-negative variant indices.
    """
    n_variants = ds.call_genotype.shape[0]
    n_keep = n_variants if numgenos is None else min(int(numgenos), n_variants)
    rng = np.random if seed is None else np.random.default_rng(seed)
    keep_indices = np.sort(rng.choice(n_variants, n_keep, replace=False))
    # Index with the positions themselves. The old ``~keep_indices`` bit-flipped
    # them into negative indices (-i-1), selecting a mirrored set of variants in
    # reverse genomic order.
    return ds.isel(variants=keep_indices)

def load_hap_ds(chrom,
                sample_query=None,
                numgenos=None,
                sample_list=None,
                start=None,
                end=None,
                min_minor_ac=0,
                df_samples=df_samples):
    """Load the phased variant dataset for one chromosome and apply optional filters.

    Filters are applied in this order:
    1. samples (by query or by explicit list)
    2. random thinning
    3. genomic region
    4. minimum minor-allele count
    5. accessible sites

    Parameters
    ----------
    chrom : str
        Contig accession used to build the zarr path, e.g. 'CM023248'.
    sample_query : str, optional
        ``DataFrame.eval`` expression over ``df_samples`` selecting samples,
        e.g. 'analysis_pop == "yemen"'. Takes precedence over ``sample_list``.
    numgenos : int, optional
        If set, randomly keep this many variants (via ``select_random_genos``).
    sample_list : list of str, optional
        Sample IDs, matched against ``df_samples['sample_id']``, to keep.
    start, end : int, optional
        If ``start`` is given, keep positions in [start, end] on contig index 0
        of the store. ``end=None`` means to the end of the contig.
    min_minor_ac : int, default 0
        Keep sites whose minor-allele count is >= this value. The minor-allele
        count is the number of called alleles minus the count of the most common
        allele, so ``1`` keeps segregating sites only.
    df_samples : DataFrame
        Sample metadata. Its rows are used as a positional mask on the dataset's
        samples. NOTE(review): this assumes row order matches the zarr sample
        order -- confirm.

    Returns
    -------
    (df_samples_subset, ds_analysis)
        The metadata rows kept and the filtered dataset.
    """
    ds = sg.load_dataset(f'/Users/dennistpw/Projects/AsGARD/data/variants_combined_cohorts/combined_cohorts.phased.{chrom}.zarr')

    # Sample subsetting: build one positional boolean mask and apply it to both
    # the metadata and the dataset.
    if sample_query:
        loc_samples = df_samples.eval(sample_query).to_numpy()
    elif sample_list:
        loc_samples = df_samples['sample_id'].isin(sample_list).to_numpy()
    else:
        loc_samples = None
    if loc_samples is not None:
        df_samples = df_samples.loc[loc_samples, :]
        ds = ds.isel(samples=loc_samples)

    # Optional random thinning. The requested count is forwarded; it was
    # previously dropped, so the default was always used.
    ds_analysis = select_random_genos(ds, numgenos=numgenos) if numgenos else ds

    # Optional region. Uses `is not None` so that start=0 is honoured.
    if start is not None:
        print(f"subsetting haps to range {chrom}:{start}-{end}")
        ds_analysis = (
            ds_analysis
            .set_index(variants=("variant_contig", "variant_position"))
            .sel(variants=(0, slice(start, end)))
        )

    # Minor-allele-count filter. Minor count = called alleles minus the most
    # common allele's count. The previous code used the count of allele 1
    # (ALT), which let sites fixed for ALT through as "segregating".
    print(f'subsetting to segregating sites')
    ac = np.asarray(allel.GenotypeArray(ds_analysis['call_genotype']).count_alleles())
    minor_ac = ac.sum(axis=1) - ac.max(axis=1, initial=0)
    macbool = minor_ac >= min_minor_ac
    print(f'selected {np.sum(macbool)} sites with a min mac >= {min_minor_ac}')
    ds_analysis = ds_analysis.isel(variants=macbool)

    # Keep accessible sites only (boolean mask -> positional selection).
    print('subsetting to accessible sites only')
    accmask = ds_analysis['is_accessible'].compute().values
    ds_analysis = ds_analysis.isel(variants=accmask)

    return df_samples, ds_analysis
            
#hashing function
def hash_params(*args, **kwargs):
    """Return an MD5 hex digest that identifies a set of analysis parameters.

    Positional and keyword arguments are serialised to JSON with sorted keys,
    so the order in which keywords are passed does not affect the digest.
    Every value must be JSON-serialisable.
    """
    payload = json.dumps({'kwargs': kwargs, 'args': args}, sort_keys=True)
    return hashlib.md5(payload.encode()).hexdigest()

#takes ds, returns gwss of chrom and samples if specified
def do_gwss(
        chrom,
        analysis_name=None,
        sample_query=None,
        sample_list=None,
        numgenos=None,
        winsize=100,
        results_dir=None,
):
    """Run, or load from cache, a windowed Garud's H scan for one chromosome.

    Parameters
    ----------
    chrom : str
        Contig accession to scan.
    analysis_name : str, optional
        Free-text label; only used as part of the cache key.
    sample_query, sample_list, numgenos :
        Passed through to ``load_hap_ds`` to choose samples and thin variants.
    winsize : int, default 100
        Window size in number of SNPs for ``allel.moving_garud_h``.
    results_dir : str, optional
        Directory for cached results, created if missing. If None, results are
        neither read from nor written to disk.

    Returns
    -------
    DataFrame with columns chrom, midpos, h1, h12, h123, h1_h2. ``midpos`` is the
    mean position of the SNPs in each window. NOTE: ``h1_h2`` holds H2/H1 (the
    fourth output of ``moving_garud_h``); the name is kept so that previously
    cached CSVs stay consistent.
    """
    # The key must stay identical so previously cached CSVs are still found.
    results_key = hash_params(
        chrom=chrom,
        sample_list=sample_list,
        analysis_name=analysis_name,
        sample_query=sample_query,
        numgenos=numgenos,
        winsize=winsize,
    )

    data_path = None
    if results_dir is not None:
        data_path = f'{results_dir}/{results_key}-data.csv'
        try:
            # re-use previously generated results if present
            return pd.read_csv(data_path)
        except FileNotFoundError:
            pass
    print(f'running analysis: {results_key}')

    # Forward every parameter that is part of the cache key. Previously
    # sample_list and numgenos were hashed but silently ignored here.
    print('setting up inputs')
    _, ds_analysis = load_hap_ds(chrom=chrom,
                                 sample_query=sample_query,
                                 sample_list=sample_list,
                                 numgenos=numgenos)

    ht = allel.GenotypeArray(ds_analysis['call_genotype']).to_haplotypes()
    pos = ds_analysis['variant_position'].values
    h1, h12, h123, h2_h1 = allel.moving_garud_h(ht, size=winsize)
    midpos = allel.moving_statistic(pos, statistic=np.mean, size=winsize)
    print(f"completed scan for {chrom}, window size {winsize}")

    h123df = pd.DataFrame({'chrom': chrom, 'midpos': midpos, 'h1': h1,
                           'h12': h12, 'h123': h123, 'h1_h2': h2_h1})

    if data_path is not None:
        os.makedirs(results_dir, exist_ok=True)
        h123df.to_csv(data_path, index=False)
        print(f'saved results: {results_key}')

    return h123df



In [None]:
winsize = 1000
sels = []
# One H-statistic scan per (population, chromosome) pair. do_gwss re-uses cached
# CSVs where available; each result is tagged with its population.
for analysis_pop, chrom in it.product(df_samples.analysis_pop.unique(), list(scaflens)):
    selscan = do_gwss(
        chrom=chrom,
        sample_query=f'analysis_pop == "{analysis_pop}"',
        winsize=winsize,
        results_dir='/Users/dennistpw/Projects/AsGARD/data/selection_20240923',
        analysis_name=f"{analysis_pop}.{chrom}.{winsize}.ud2",
    )
    selscan['pop'] = analysis_pop
    print(f"completed selscan for {chrom}, {analysis_pop}, {winsize}")
    sels.append(selscan)

In [6]:
#prepare for plotting

#import useful libs
from matplotlib.ticker import FuncFormatter
from matplotlib.gridspec import GridSpec

# Stack all per-(population, chromosome) scans into one long table.
selscans_df = pd.concat(sels)

# Per-population y ceiling: each figure row is scaled to its own max H1.
max_h1_dict = selscans_df.groupby('pop')['h1'].max().to_dict()

# Panel widths proportional to each chromosome's extent (largest window midpoint).
max_values = selscans_df.groupby('chrom')['midpos'].max()
total_max = max_values.sum()
column_widths = max_values.div(total_max).values
col_var_levels = max_values.index


def _region(chrom, x_min, x_max):
    """Build a one-chromosome highlight spec spanning the full y range [0, 1]."""
    return {chrom: {'x_min': x_min, 'x_max': x_max, 'y_min': 0, 'y_max': 1}}


# Gene regions highlighted on the H1 figure.
cyp6_region = _region('CM023248', 67473117, 67501071)
ace1_region = _region('CM023248', 60916071, 60917000)
vgsc_region = _region('CM023249', 42817709, 42817800)
gste_region = _region('CM023249', 70572788, 70584603)
rdl_region = _region('CM023249', 8345440, 8348441)
cyp9_region = _region('CM023250', 9721225, 9722225)
carboxyl_cluster = _region('CM023249', 18784698, 18801820)
COEJHE5E_region = _region('CM023249', 26291697, 26310369)
# NOTE(review): x_min == x_max, so the shaded span has zero width and is likely
# invisible; only the label will show.
diagk = _region('CM023250', 4578144, 4578144)

# Population order for the figure rows (top to bottom).
rowvars_reordered = [
    'saudi_e', 'saudi_r', 'afgh_pak', 'india_b', 'india_m',
    'djibouti', 'ethiopia_n', 'ethiopia_so', 'sudan', 'yemen',
]

In [None]:
# Populations present in the scan results (kept for reference). The figure
# itself is laid out from `rowvars_reordered`, which is what gets plotted.
row_var_levels = selscans_df['pop'].unique()
n_rows = len(rowvars_reordered)
last_row = n_rows - 1

# Chromosome accession -> label drawn beneath the bottom row of panels.
chrom_labels = {'CM023248': '2', 'CM023249': '3', 'CM023250': 'X'}

# Loci highlighted on the top row, as (label, region) pairs. Each region dict maps
# one chromosome to its x/y extent.
ir_annotations = [
    ('Cyp6', cyp6_region),
    ('Ace1', ace1_region),
    ('Vgsc', vgsc_region),
    ('Gste', gste_region),
    ('Rdl', rdl_region),
    ('Coeae', carboxyl_cluster),
    ('Coejh', COEJHE5E_region),
    ('Cyp9k1', cyp9_region),
    ('Diagk', diagk),
]


def format_func(value, tick_number):
    """Tick formatter: show y values to 2 decimal places."""
    return f"{value:.2f}"


def _annotate_ir_regions(ax, chrom):
    """Shade and label each locus in `ir_annotations` that lies on `chrom`."""
    for label, region in ir_annotations:
        if chrom not in region:
            continue
        bounds = region[chrom]
        ax.fill_betweenx([bounds['y_min'], bounds['y_max']],
                         bounds['x_min'], bounds['x_max'],
                         color='gray', alpha=0.5)
        ax.text(bounds['x_min'], 1, label, fontsize=18, rotation=90)


# Figure size: width scales with total relative chromosome width; height is
# 1.5 in per population row. The extra narrow column holds row labels.
fig = plt.figure(figsize=(sum(column_widths) * 25, n_rows * 1.5))
gs = GridSpec(n_rows, len(col_var_levels) + 1, width_ratios=list(column_widths) + [0.1])

for row_idx, row_val in enumerate(rowvars_reordered):
    row_colour = analysis_popcols.get(row_val, 'black')  # default colour if none set
    is_bottom = row_idx == last_row

    for col_idx, col_val in enumerate(col_var_levels):
        ax = fig.add_subplot(gs[row_idx, col_idx])
        subset = selscans_df[(selscans_df['pop'] == row_val) & (selscans_df['chrom'] == col_val)]
        sns.lineplot(data=subset, x='midpos', y='h1', ax=ax, linewidth=0.8, color=row_colour)
        ax.fill_between(x=subset['midpos'], y1=subset['h1'], color=row_colour, alpha=0.1)

        # y max is each population's own max H1. This emphasises sweep signals but
        # makes rows non-comparable -- it must be stated in the figure legend.
        ymax = max_h1_dict[row_val]
        ax.set_ylim(0, ymax)
        ax.yaxis.set_major_formatter(FuncFormatter(format_func))
        ax.set_yticks([0, ymax])
        ax.set_ylabel('')

        # Minimal chrome: no spines, x ticks only on the bottom row, y axis only
        # on the first column.
        for side in ('top', 'right', 'left', 'bottom'):
            ax.spines[side].set_visible(False)
        if not is_bottom:
            ax.xaxis.set_tick_params(which='both', bottom=False, top=False, labelbottom=False)
        if col_idx > 0:
            ax.yaxis.set_visible(False)

        # Replace the x-axis title with the chromosome name on the bottom row.
        # Labels are keyed by chromosome, not by column position.
        ax.set_xlabel('')
        if is_bottom and col_val in chrom_labels:
            ax.set_xlabel(chrom_labels[col_val], fontsize=24, fontweight='bold')

        if row_idx == 0:
            _annotate_ir_regions(ax, col_val)

    # Population label: one per row, in the narrow right-hand column. This was
    # previously re-created once per chromosome, stacking three copies.
    row_ax = fig.add_subplot(gs[row_idx, -1])
    row_ax.text(0.5, 0.5, row_val, va='center', ha='center', fontsize=18,
                rotation=0, transform=row_ax.transAxes)
    row_ax.axis('off')

# Shared y-axis label.
fig.text(0.08, 0.5, 'H1', va='center', rotation='vertical', fontweight='bold', fontsize=18)

# Layout applied once: top margin, no gap between chromosomes, small row gap.
fig.subplots_adjust(top=0.85, wspace=0, hspace=0.2)

# Save as SVG, then display.
fig.savefig('../figures/h12.svg')
plt.show()

## Haplotype clustering
Next, we look at haplotype-clustering dendrograms for the candidate loci highlighted above, to see whether regions that appear to be under selection are shared between populations.

In [1]:
#define some clustering and plotting functions
#get dist
@numba.njit(parallel=True)
def pdist_abs_hamming(X):
    """Pairwise absolute Hamming distances between the rows of a 2-D array.

    Parameters
    ----------
    X : 2-D array, shape (n_obs, n_ftr)
        Observations in rows, features in columns.

    Returns
    -------
    out : ndarray of int32, shape (n_obs, n_obs)
        Symmetric matrix. ``out[i, j]`` is the number of feature positions at
        which rows ``i`` and ``j`` differ; the diagonal stays 0.

    Notes
    -----
    Not called anywhere in the code shown here. ``plot_dendrogram`` uses
    ``allel.pairwise_distance`` instead.
    """
    n_obs = X.shape[0]
    n_ftr = X.shape[1]
    out = np.zeros((n_obs, n_obs), dtype=np.int32)
    for i in range(n_obs):
        x = X[i]
        # Only the upper triangle (j > i) is computed and then mirrored. Parallel
        # iterations over j write distinct cells, so they never collide.
        for j in numba.prange(i + 1, n_obs):
            y = X[j]
            d = 0
            for k in range(n_ftr):
                if x[k] != y[k]:
                    d += 1
            out[i, j] = d
            out[j, i] = d
    return out

def plot_dendrogram(h, ax, df_samples):
    """Draw a single-linkage haplotype dendrogram on ``ax``.

    Parameters
    ----------
    h : array, shape (n_variants, n_haplotypes)
        Haplotype array, e.g. from ``GenotypeArray.to_haplotypes()``. Columns
        (haplotypes) are clustered.
    ax : matplotlib Axes
        Axes to draw the dendrogram on.
    df_samples : DataFrame
        Sample metadata, one row per sample, in the same order as the samples
        the haplotypes came from. Each row is repeated twice, i.e. samples are
        assumed diploid. NOTE(review): confirm the ploidy assumption.

    Returns
    -------
    Z : ndarray
        Linkage matrix from ``scipy.cluster.hierarchy.linkage``.
    dend : dict
        Output of ``scipy.cluster.hierarchy.dendrogram``. ``dend['leaves']``
        gives haplotype indices in plotted order.
    leaf_data : DataFrame
        One metadata row per haplotype, in haplotype (not leaf) order.
    """
    # Hamming distance between haplotypes (a fraction of sites), scaled to an
    # absolute number of differing SNPs.
    dist = allel.pairwise_distance(h, 'hamming') * h.shape[0]

    # Single-linkage hierarchical clustering.
    Z = sch.linkage(dist, method='single')

    # One metadata row per haplotype: repeat each sample row in place. This
    # assumes haplotypes are laid out sample by sample (s0h0, s0h1, s1h0, ...).
    # Positional repetition keeps the original column dtypes.
    leaf_data = df_samples.iloc[np.repeat(np.arange(len(df_samples)), 2)].reset_index(drop=True)

    # Plot the dendrogram on the supplied axes (previously it went to plt.gca()).
    dend = sch.dendrogram(
        Z,
        count_sort=False,
        distance_sort=False,
        no_plot=False,
        color_threshold=0,
        above_threshold_color='k',
        ax=ax,
    )

    ax.get_xaxis().set_visible(False)
    # Remove every spine on this axes only, not on the whole current figure.
    sns.despine(ax=ax, offset=5, left=True, bottom=True, top=True, right=True)
    ax.set_ylim(bottom=-2)
    ax.set_ylabel('Distance (no. SNPs)')
    ax.autoscale(axis='x', tight=True)

    return Z, dend, leaf_data

def fig_hap_structure(h, df_samples=df_samples, h_display=None, mutations=None, vspans=[[]], cluster_labels=[], figsize=(10, 8), 
                      fn=None, dpi=150, height_ratios=(3, .2, 1.5, .2), hap_pops=None, legend=False, title=None):
    """Plot a haplotype dendrogram with a population colour bar beneath it, and save it.

    Parameters
    ----------
    h : array, shape (n_variants, n_haplotypes)
        Haplotypes to cluster.
    df_samples : DataFrame
        Sample metadata aligned with the samples ``h`` was derived from (see
        ``plot_dendrogram``). Must have an ``analysis_pop`` column unless
        ``hap_pops`` is given.
    figsize : tuple
        Figure size in inches.
    fn : str, optional
        Output path. Defaults to ``'../figures/{title}.svg'``.
    dpi : int
        Passed to ``savefig``.
    height_ratios : sequence
        Relative row heights. Only the first two entries (dendrogram, colour bar)
        are used, because the grid has two rows.
    hap_pops : array-like, optional
        Population label per haplotype, in haplotype order. Defaults to the
        repeated ``df_samples.analysis_pop``.
    legend : bool
        Draw a population colour legend on the dendrogram.
    title : str, optional
        Title drawn at the top right, also used in the default filename.
    h_display, mutations, vspans, cluster_labels :
        Currently unused; kept for call compatibility.
    """
    fig = plt.figure(figsize=figsize)

    # Two stacked panels: dendrogram above, population colour bar below.
    # GridSpec needs exactly one ratio per row, and the default tuple has four.
    gs_nrows, gs_ncols = 2, 1
    ratios = None if height_ratios is None else list(height_ratios)[:gs_nrows]
    gs = mpl.gridspec.GridSpec(gs_nrows, gs_ncols, hspace=0.04, wspace=0.04,
                               height_ratios=ratios)

    # Dendrogram, drawn from this function's own haplotypes onto its own axes.
    # The old code passed the globals `ht` and `ax` here.
    ax_dend = fig.add_subplot(gs[0, 0])
    z, r, leafdata = plot_dendrogram(h, ax_dend, df_samples)
    ax_dend.set_ylim(bottom=-5)
    if legend:
        handles = [mpl.patches.Patch(color=colour, label=pop) for pop, colour in analysis_popcols.items()]
        ax_dend.legend(handles=handles, loc='upper right', bbox_to_anchor=(1, 1), ncol=3)
    # Integer tick labels that follow the ticks actually drawn. Fixed labels set
    # on non-fixed ticks can drift.
    ax_dend.yaxis.set_major_formatter(mpl.ticker.FuncFormatter(lambda value, pos: str(int(value))))
    ax_dend.set(xticklabels=[])
    ax_dend.set_title(title, fontdict={'fontsize': 18}, loc='right')
    ax_dend.xaxis.set_tick_params(length=3, pad=2)
    ax_dend.yaxis.set_tick_params(length=3, pad=2)

    # Population colour bar: one unit-wide cell per haplotype, in leaf order.
    ax_pops = fig.add_subplot(gs[1, 0])
    if hap_pops is None:
        hap_pops = leafdata.analysis_pop.values
    leaf_pops = np.asarray(hap_pops).take(r['leaves'])
    hap_clrs = [analysis_popcols.get(p, 'black') for p in leaf_pops]
    n_haps = h.shape[1]
    ax_pops.broken_barh(xranges=[(i, 1) for i in range(n_haps)], yrange=(0, 1), color=hap_clrs)
    sns.despine(ax=ax_pops, offset=2, left=True, bottom=True, top=True, right=True)
    ax_pops.set_xticks([])
    ax_pops.set_yticks([])
    ax_pops.set_xlim(0, n_haps)
    ax_pops.yaxis.set_label_position('left')

    out_path = fn if fn is not None else f'../figures/{title}.svg'
    fig.savefig(out_path, dpi=dpi)

In [342]:
# Gene regions (contig, start, end) used for the haplotype-clustering figures.
genes_main = {
    # FIXME(review): these are exactly the Cyp6 highlight coordinates from the H1
    # figure (cyp6_region), not the Ace1 locus marked there (~60,916,071). This
    # looks like a copy-paste error -- confirm before using the 'Ace1' panel.
    'Ace1': dict(chrom='CM023248', start=67473117, end=67501071),
    'Coeae': dict(chrom='CM023249', start=18779698, end=18804511),
    # NOTE(review): exactly 10 kb downstream of the Gste highlight in the H1
    # figure (70,572,788-70,584,603). One of the two is probably a typo.
    'Gste': dict(chrom='CM023249', start=70582788, end=70594603),
    'Cyp6': dict(chrom='CM023248', start=67470071, end=67514071),
}

# Gene regions for the supplementary haplotype-clustering figures.
genes_supp = {
    'Vgsc': dict(chrom='CM023249', start=42804885, end=42848176),
    'Rdl': dict(chrom='CM023249', start=8345440, end=8348441),
    'Cyp9k1': dict(chrom='CM023250', start=9713374, end=9729212),
}

In [None]:
#Now plot Genes for Main text

# One dendrogram figure per main-text gene. fig_hap_structure creates and saves
# its own figure, so no outer subplot grid is needed (the old one was left empty).
for gene, region in genes_main.items():
    df_haps, ds_haps = load_hap_ds(chrom=region['chrom'],
                                   start=region['start'],
                                   end=region['end'],
                                   min_minor_ac=1)
    ht = allel.GenotypeArray(ds_haps['call_genotype']).to_haplotypes()
    fig_hap_structure(h=ht, df_samples=df_haps, title=gene)


In [None]:
#Now plot Genes for Supp
# One dendrogram figure per supplementary gene. fig_hap_structure creates and
# saves its own figure, so no outer subplot grid is needed (the old 4-panel grid
# for 3 genes was left empty).
for gene, region in genes_supp.items():
    df_haps, ds_haps = load_hap_ds(chrom=region['chrom'],
                                   start=region['start'],
                                   end=region['end'],
                                   min_minor_ac=1)
    ht = allel.GenotypeArray(ds_haps['call_genotype']).to_haplotypes()
    fig_hap_structure(h=ht, df_samples=df_haps, title=gene)