In [None]:
from Bio import SeqIO
import glob
import pandas as pd
from itertools import combinations, product
import numpy as np
from scipy import stats 

# Input files
pop_map = ''  # ClusterPop clusters: tab-separated table with 'Strain' and 'Cluster_ID' columns
msa_file = ''  # Concatenated whole genome alignment (FASTA) from phybreak script 2
pi_info_file = 'pi_join_info.csv'  # Nucleotide diversity per block; first column is the block ID
pop0_phybreak_output = ''  # For each tree: whether population I is monophyletic and what percentage of branch length lies within population I
pop1_phybreak_output = ''  # Same as above but for population II
core_sweep_pop_I_outfile = ''  # Output CSV: merged candidate sweep regions for population I (pop0)
core_sweep_pop_II_outfile = ''  # Output CSV: merged candidate sweep regions for population II (pop1)

# Average nucleotide diversity per population. p0 is passed to find_sweeps together
# with pop0 (population I) and p1 together with pop1 (population II).
# NOTE(review): the original comments labelled p1 as "population I" and p0 as
# "population II", the opposite of how they are used; confirm the values.
p1 = 0.0138272139869
p0 = 0.00521122600799
# Confidence level of the binomial interval built by calculate_ci; a block counts as
# reduced in diversity if it falls below that interval's lower bound (or is exactly 0).
alpha = 0.95


def normalize_strain_name(strain):
    """Rewrite a ClusterPop strain name: drop '_' and '.', then turn '--' into '_'.

    NOTE(review): presumably this maps names onto the sequence IDs in msa_file; confirm.
    """
    return strain.replace('_', '').replace('.', '').replace('--', '_')


def strains_in_cluster(cluster_df, cluster_id):
    """Return the normalized names of strains whose 'Cluster_ID' equals `cluster_id`."""
    members = cluster_df.loc[cluster_df['Cluster_ID'] == cluster_id, 'Strain']
    return [normalize_strain_name(strain) for strain in members]


total_alignment = {s.id: str(s.seq) for s in SeqIO.parse(msa_file, 'fasta')}

df = pd.read_table(pop_map)
pop0 = strains_in_cluster(df, 0)
pop1 = strains_in_cluster(df, 0.1)
pop2 = strains_in_cluster(df, 0.2)

# div_df is the dataframe of nucleotide diversity per block
div_df = pd.read_csv(pi_info_file, index_col=0)
div_df['block'] = div_df.index
div_df['Length'] = div_df.end - div_df.start
div_df['Midpoint'] = div_df.start + (div_df.end - div_df.start) // 2
div_df['Start'] = div_df['start']
div_df['End'] = div_df['end']
# Order blocks along the alignment: concatenate_windows merges consecutive rows.
div_df = div_df.sort_values('Midpoint')

In [None]:
def _window_run_record(labels, run_start, run_end, length_cutoff):
    """Return an output record for one run of windows, or None if it is empty or too short."""
    if not labels or run_end - run_start < length_cutoff:
        return None
    return {'Start': run_start,
            'End': run_end,
            'Start tree': min(labels),
            'End tree': max(labels)}


def concatenate_windows(sweepdf, length_cutoff=1000):
    """Merge runs of overlapping windows into contiguous candidate regions.

    Rows are scanned in their existing order, which is expected to follow the
    alignment (find_sweeps passes rows in div_df's Midpoint order). A window
    joins the current run when its Start is <= the largest End seen so far in
    that run; otherwise it starts a new run. Every run whose span
    (max End - min Start) is at least `length_cutoff` becomes one output row.

    Parameters
    ----------
    sweepdf : pandas.DataFrame
        Windows with 'Start' and 'End' columns; index labels identify the
        trees/blocks.
    length_cutoff : number
        Minimum span of a merged run to be reported.

    Returns
    -------
    pandas.DataFrame
        Indexed 0..n-1 with columns 'Start', 'End' (float), 'Start tree' and
        'End tree' (smallest / largest index label in the run) and 'Midpoint'.
        Empty (but with these columns) when no run qualifies.
    """
    records = []
    labels = []
    run_start = run_end = None
    for label, row in sweepdf.iterrows():
        if labels and row['Start'] <= run_end:
            labels.append(label)
            run_start = min(run_start, row['Start'])
            # Running max, so a window nested inside an earlier one cannot end the run early.
            run_end = max(run_end, row['End'])
        else:
            record = _window_run_record(labels, run_start, run_end, length_cutoff)
            if record is not None:
                records.append(record)
            # The window that breaks the previous run is the first member of the new one.
            labels = [label]
            run_start, run_end = row['Start'], row['End']

    record = _window_run_record(labels, run_start, run_end, length_cutoff)
    if record is not None:
        records.append(record)

    final_df = pd.DataFrame(records, columns=['Start', 'End', 'Start tree', 'End tree'])
    final_df = final_df.astype({'Start': float, 'End': float})
    final_df['Midpoint'] = final_df.Start + ((final_df.End - final_df.Start) // 2)
    return final_df

def calc_all_divs(concatenated_positions,
                  population,
                  pi,
                  alpha,
                  alignment=None):
    """Compute within-population nucleotide diversity for each merged region.

    Parameters
    ----------
    concatenated_positions : pandas.DataFrame
        Regions with 'Start', 'End', 'Start tree' and 'End tree' columns
        (as produced by concatenate_windows).
    population : list of str
        Strain IDs whose average pairwise divergence is computed.
    pi : float
        Expected per-site diversity of the population; calculate_ci builds the
        binomial interval around it.
    alpha : float
        Confidence level passed to calculate_ci.
    alignment : dict of str -> str, optional
        Strain ID -> aligned sequence. Defaults to the notebook-level
        ``total_alignment``.

    Returns
    -------
    pandas.DataFrame
        Same index as `concatenated_positions`, with Start, End, Start tree,
        End tree, pop_pi, Length, Midpoint, ci_low and ci_high.
    """
    if alignment is None:
        # Fall back to the alignment loaded from msa_file in the first cell.
        alignment = total_alignment
    # Only population members are ever compared, so slice just those sequences
    # rather than every strain in the alignment.
    members = [strain for strain in population if strain in alignment]

    finaldf = pd.DataFrame(index=concatenated_positions.index, columns=['Start',
                                                                        'End',
                                                                        'Start tree',
                                                                        'End tree',
                                                                        'pop_pi'])
    for i in concatenated_positions.index:
        # NOTE(review): assumes Start/End are 0-based, end-exclusive alignment columns.
        start = int(concatenated_positions.loc[i, 'Start'])
        end = int(concatenated_positions.loc[i, 'End'])
        seqs = {strain: alignment[strain][start:end] for strain in members}
        finaldf.loc[i] = [start,
                          end,
                          concatenated_positions.loc[i, 'Start tree'],
                          concatenated_positions.loc[i, 'End tree'],
                          calc_pop_div(population, seqs)]

    finaldf['Length'] = finaldf.End - finaldf.Start
    finaldf['Midpoint'] = finaldf.Start + ((finaldf.End - finaldf.Start) // 2)
    calculate_ci(finaldf, 'pop_pi', alpha, pi)
    return finaldf
    
def count_divs(s1, s2):
    """Return the proportion of aligned sites at which two sequences differ (p-distance).

    Every mismatching character counts, including gap or ambiguity characters
    if the alignment contains them. Sites are paired with zip, so if the
    sequences differ in length only the shared prefix is compared, while the
    count is still divided by len(s1).

    Returns NaN for an empty `s1` instead of raising ZeroDivisionError.
    """
    if len(s1) == 0:
        return float('nan')
    mismatches = sum(b1 != b2 for b1, b2 in zip(s1, s2))
    return float(mismatches) / len(s1)

def calc_pop_div(pop, s_dict):
    """Average pairwise p-distance among the strains of `pop` that appear in `s_dict`.

    Parameters
    ----------
    pop : list of str
        Strain IDs of the population.
    s_dict : dict of str -> str
        Strain ID -> sequence (already sliced to the region of interest).

    Returns
    -------
    float
        Mean of count_divs over all unordered pairs of present strains, or NaN
        (without numpy's empty-mean RuntimeWarning) when fewer than two strains
        of `pop` are present.
    """
    # Filter once up front; combinations() of the filtered list yields exactly the
    # pairs (in the same order) that per-pair membership checks would keep.
    present = [strain for strain in pop if strain in s_dict]
    distances = [count_divs(s_dict[strain1], s_dict[strain2])
                 for strain1, strain2 in combinations(present, 2)]
    if not distances:
        return np.nan
    return np.average(distances)

def calculate_ci(df, pop_name, alpha, p_pop):
    """Add per-site binomial confidence bounds to `df`, in place.

    For each row, the central `alpha` interval of Binomial(n=Length, p=p_pop)
    is computed and divided by Length, giving bounds on a per-site rate.
    Writes the columns 'ci_low' and 'ci_high'; returns None.
    `pop_name` is not used.
    """
    site_counts = list(df.Length)
    interval = stats.binom.interval(alpha, site_counts, p_pop)
    for column, bound in zip(('ci_low', 'ci_high'), interval):
        df[column] = bound / df.Length

def passes_phylo_criteria(full_df, cutoff, focus_pop):
    """Boolean mask of trees that look like candidate sweeps for `focus_pop`.

    A row passes when all of the following hold:
      * its 'focus' value is strictly below `cutoff`,
      * it is monophyletic ('monophy' == 1), and
      * the diversity in column `focus_pop` is below 'ci_low' or exactly 0.
    """
    pi_focus = full_df[focus_pop]
    below_cutoff = full_df.focus < cutoff
    is_monophyletic = full_df.monophy == 1
    low_diversity = (pi_focus < full_df.ci_low) | (pi_focus == 0)
    return below_cutoff & is_monophyletic & low_diversity

def find_sweeps(phybreak_df,
                pi_info_df,
                pop_pi, alpha,
                pop_name,
                pop_list,
                length_cutoff=500,
                focus_percentile=5):
    """Identify candidate sweep regions for one population.

    Steps:
      1. Join per-block diversity (`pi_info_df`) with the per-tree phybreak
         table and index the result by 'block'.
      2. Attach binomial bounds around `pop_pi` for each block length
         (calculate_ci).
      3. Keep trees whose 'focus' value is below its `focus_percentile`-th
         percentile, that are monophyletic, and whose diversity in column
         `pop_name` is below 'ci_low' or exactly 0 (passes_phylo_criteria).
      4. Merge surviving overlapping windows into regions spanning at least
         `length_cutoff` and recompute diversity on those regions from the
         alignment (calc_all_divs).

    Parameters
    ----------
    phybreak_df, pi_info_df : pandas.DataFrame
        Per-tree phybreak output and per-block diversity table.
    pop_pi : float
        Genome-wide average diversity of the population.
    alpha : float
        Confidence level of the binomial interval.
    pop_name : str
        Diversity column for this population (e.g. 'Pop0').
    pop_list : list of str
        Strain IDs of the population.
    length_cutoff : number, default 500
        Minimum span of a merged region.
    focus_percentile : float, default 5
        Percentile of 'focus' used as the phylogenetic cutoff.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        The joined per-block table and the merged regions (all columns float).
    """
    # NOTE(review): pd.merge without `on=` joins on every column name the two
    # tables share — confirm that is only the block identifier.
    full_df = pd.merge(pi_info_df, phybreak_df)
    full_df.index = full_df.block
    calculate_ci(full_df, pop_name, alpha, pop_pi)

    # Cutoff on the fraction of branch length within the focus population.
    cutoff = np.percentile(full_df.focus, focus_percentile)
    trees_passing_phy_criteria = full_df[passes_phylo_criteria(full_df, cutoff, pop_name)]

    concat = concatenate_windows(trees_passing_phy_criteria, length_cutoff=length_cutoff)
    concat = calc_all_divs(concat, pop_list, pop_pi, alpha)
    concat = concat.astype(float)
    return (full_df, concat)



In [None]:
# Population I (pop0): call candidate sweeps and write the merged regions to CSV.
pop_0_phybreak = pd.read_table(pop0_phybreak_output)
plotdf_pop0, concat_pop0 = find_sweeps(phybreak_df=pop_0_phybreak,
                                       pi_info_df=div_df,
                                       pop_pi=p0,
                                       alpha=alpha,
                                       pop_name='Pop0',
                                       pop_list=pop0)
concat_pop0.to_csv(core_sweep_pop_I_outfile, index=None)

In [None]:
# Population II (pop1): call candidate sweeps and write the merged regions to CSV.
pop_1_phybreak = pd.read_table(pop1_phybreak_output)
plotdf_pop1, concat_pop1 = find_sweeps(phybreak_df=pop_1_phybreak,
                                       pi_info_df=div_df,
                                       pop_pi=p1,
                                       alpha=alpha,
                                       pop_name='Pop1',
                                       pop_list=pop1)
concat_pop1.to_csv(core_sweep_pop_II_outfile, index=None)