In [1]:
import sgkit as sg
import allel
import plotly.express as px
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import glob
from sklearn.mixture import GaussianMixture
from sklearn.utils.extmath import row_norms

Treemix takes a tab delimited file of allele counts by population - let's define our populations here based on the countries and the pca

In [26]:
ds = sg.load_dataset('/home/dennist/lstm_data/cease/variants_bycohort/combined_cohorts/zarr/combined_cohorts.CM023249.zarr/')

#load and filter metadata
#load and filter metadata
df_samples = pd.read_table('/home/dennist/lstm_data/cease/variant_metadata/cease.combined.metadata.20240703.txt')
df_samples = df_samples[df_samples['qc_pass'] == 1].sort_values('order').reset_index()

#extract gts
gt = allel.GenotypeArray(ds.call_genotype)


#next, ld prune. this takes a wee while

In [27]:
def ld_prune(gn, size, step, threshold=.1, n_iter=1):
    for i in range(n_iter):
        loc_unlinked = allel.locate_unlinked(gn, size=size, step=step, threshold=threshold)
        n = np.count_nonzero(loc_unlinked)
        n_remove = gn.shape[0] - n
        print('iteration', i+1, 'retaining', n, 'removing', n_remove, 'variants')
    return loc_unlinked

In [28]:
n = 500000  # number of SNPs to choose randomly
vidx = np.random.choice(gt.shape[0], n, replace=False)
vidx.sort()
gtr = gt.take(vidx, axis=0)
gnr = gtr.to_n_alt()
locun = ld_prune(gnr, size=500, step=200, threshold=.1, n_iter=2)


iteration 1 retaining 150014 removing 349986 variants
iteration 2 retaining 150014 removing 349986 variants


In [29]:
np.sum(locun)

150014

In [30]:
gt_u = gtr.compress(locun)

In [31]:
# Initialize an empty dictionary to store row indices for each level
pop_dict = {}

# Iterate through unique levels in the 'factor_column'
for level in df_samples['analysis_pop'].unique():
    # Get the row indices where the 'factor_column' matches the current level
    indices = df_samples.index[df_samples['analysis_pop'] == level].tolist()
    
    # Store the indices in the dictionary with the level as the key
    pop_dict[level] = indices

# Remove dictionary entries with fewer than 5 values
pop_dict = {key: value for key, value in pop_dict.items() if len(value) >= 5}

#get all pops too
pop_dict['all']  = df_samples.index.tolist()

In [32]:
ac_subpop = gt_u.count_alleles_subpops(pop_dict)

In [33]:
#count alleles by pop & convert to table
ac_subpop = gt.count_alleles_subpops(pop_dict)
#get segregating variants only
is_seg = ac_subpop['all'].is_segregating()[:]

combined_data = {}

# Process each item in the dictionary
for name, array in ac_subpop.items():
    # Check if the array has at least two columns

    #get seg sites
    array = array.compress(is_seg)
    if array.shape[1] < 2:
        raise ValueError(f"Array {name} does not have at least two columns")
    
    # Combine each element of the two columns into a single string
    combined_array = np.array([f"{row[0]},{row[1]}" for row in array])
    
    # Store the resulting array in the combined_data dictionary
    combined_data[name] = combined_array

# Convert the combined data into a DataFrame
ac_df = pd.DataFrame(combined_data)

##subsample randomly instead of ld pruning
df_sample = ac_df.sample(frac=0.1)

df_sample = df_sample.drop('all', axis=1)


In [23]:
#try to permute across different randomly chosen sets (10 times?)

!mkdir /home/dennist/lstm_scratch/cease_workspace/treemix_20240711_permutes

In [25]:
for i in range(9):
    vidx = np.random.choice(gt.shape[0], n, replace=False)
    vidx.sort()
    gtr = gt.take(vidx, axis=0)
    gnr = gtr.to_n_alt()
    locun = ld_prune(gnr, size=500, step=200, threshold=.1, n_iter=1)
    ac_subpop = gt_u.count_alleles_subpops(pop_dict)

    #count alleles by pop & convert to table
    ac_subpop = gt.count_alleles_subpops(pop_dict)
    #get segregating variants only
    is_seg = ac_subpop['all'].is_segregating()[:]

    combined_data = {}

    # Process each item in the dictionary
    for name, array in ac_subpop.items():
        # Check if the array has at least two columns

        #get seg sites
        array = array.compress(is_seg)
        if array.shape[1] < 2:
            raise ValueError(f"Array {name} does not have at least two columns")
        
        # Combine each element of the two columns into a single string
        combined_array = np.array([f"{row[0]},{row[1]}" for row in array])
        
        # Store the resulting array in the combined_data dictionary
        combined_data[name] = combined_array

    # Convert the combined data into a DataFrame
    ac_df = pd.DataFrame(combined_data)

    ##subsample randomly instead of ld pruning
    df_sample = ac_df.sample(frac=0.1)

    df_sample = df_sample.drop('all', axis=1)


    df_sample.to_csv(f'/home/dennist/lstm_scratch/cease_workspace/treemix_20240711_permutes/pruned_bypop_afs.{i}.txt', sep='\t', quoting=False, index=False)

iteration 1 retaining 156682 removing 343318 variants
iteration 1 retaining 156126 removing 343874 variants
iteration 1 retaining 157239 removing 342761 variants
iteration 1 retaining 156130 removing 343870 variants
iteration 1 retaining 156020 removing 343980 variants
iteration 1 retaining 156072 removing 343928 variants
iteration 1 retaining 156406 removing 343594 variants
iteration 1 retaining 156342 removing 343658 variants
iteration 1 retaining 155990 removing 344010 variants


In [22]:
df_sample.to_csv('/home/dennist/lstm_scratch/cease_workspace/treemix_20240711/pruned_bypop_afs.txt', sep='\t', quoting=False, index=False)

In [None]:
#run treemix using the shell scripts
#find optimal M if any using R script
#optimal M looks like 0/1?
#high data robustness...may plot without migration edges