In [1]:
from pandas import DataFrame, read_csv
import pandas as pd
import random
import numpy as np
import json
from collections import defaultdict, Counter
import math
import sys
import itertools
import time
import scipy.stats 
from sklearn.neighbors import KernelDensity
import glob
import matplotlib.pyplot as plt
import seaborn as sns
import re
from scipy.optimize import curve_fit
import copy
import dill
import matplotlib.patches as mpatches
import copy
import dill



chr_lengths= {1: 643000,
                         2: 947000,
                         3: 1100000,
                         4: 1200000,
                         5: 1350000,
                         6: 1420000,
                         7: 1450000,
                         8: 1500000,
                         9: 1550000,
                         10: 1700000,
                         11: 2049999,
                         12: 2300000,
                         13: 2950000,
                         14: 3300000}

class NumpyEncoder(json.JSONEncoder):
    """ Special json encoder for numpy types """

    def default(self, obj):
        if isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
                            np.int16, np.int32, np.int64, np.uint8,
                            np.uint16, np.uint32, np.uint64)):
            return int(obj)
        elif isinstance(obj, (np.float_, np.float16, np.float32,
                              np.float64)):
            return float(obj)
        elif isinstance(obj, (np.ndarray,)):  #### This is the fix
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)

cpalette1 = sns.color_palette('Reds_r', 2)
cpalette2 = sns.color_palette('Blues_r', 3)
cpalette3 = sns.color_palette('Greys_r', 3)

color_map_dict = {'PC': cpalette1[0],
                 'FS': cpalette1[1],
                 'MS': cpalette1[1],
                 'GC': cpalette2[0],
                 'HS': cpalette2[1],
                  'FAV': cpalette2[2],
                  'GGC': cpalette3[0],
                  'HAV': cpalette3[1],
                  'FCS': cpalette3[2]}

In [9]:
param_dict = {}

for param_combo in ['2.0_11.3']:
    sim_types = ['fs', 'ms'] #not needed, but legacy
    raw_sim_data = defaultdict(dict)
    for sim_type in sim_types:    
        raw_sim_data[sim_type]['r_totals'] = json.load(open('../sims/f1_f2/{x}_r_totals_combined.json'.format(x=sim_type,p=param_combo)))
        raw_sim_data[sim_type]['ibd_segment_max'] = json.load(open('../sims/f1_f2/{x}_ibd_segment_max_combined.json'.format(x=sim_type,p=param_combo)))
        raw_sim_data[sim_type]['ibd_segment_numbers'] = json.load(open('../sims/f1_f2/{x}_ibd_segment_numbers_combined.json'.format(x=sim_type,p=param_combo)))
        max_segment_p_dict = defaultdict(dict)
        for relationship in raw_sim_data[sim_type]['ibd_segment_max']:
            for chrom in raw_sim_data[sim_type]['ibd_segment_max'][relationship]:
                max_segment_p_dict[relationship][chrom] = np.asarray([x for x in np.asarray(raw_sim_data[sim_type]['ibd_segment_max'][relationship][chrom]) / chr_lengths[int(chrom)]])
        raw_sim_data[sim_type]['p_ibd_max'] = max_segment_p_dict
    param_dict[param_combo] = copy.deepcopy(raw_sim_data)

for sim_type in sim_types:
    for key,value in param_dict['2.0_11.3'][sim_type].items():
        try:
            p1_p1 = value.pop('P1.P2')
        except:
            continue
            
def fit_beta(relationships,):
    beta_variables = {}
    r_dict = raw_sim_data['fs']['r_totals']
    for relationship in relationships:
        print(relationship)
        r_list = r_dict[relationship]
        r_list = [x if x != 0 else 1e-9 for x in r_list]
        alpha, beta, loc, scale = scipy.stats.beta.fit(r_list, floc=0,fscale=1)
        #x = np.linspace(0, 1, 100)
        beta_variables[relationship] = (alpha, beta, loc, scale)
        
    return beta_variables

plot_relationships = ['P1.B11', 'P1.B21', 'P1.F11', 'P1.F21',
                     'F11.F12', 'F11.F21', 'F21.F22',
                     'F11.B12','B12.F21','F11.B21', 'B21.F21','B11.B21']

beta_relatedness = fit_beta(plot_relationships)

def plot_r_totals(beta_vals_dict, relationships, color_map_dict):
    x = np.linspace(0, 1, 100)
    #plt.figure(figsize = (8,5))
    i= 1
    for relationship in relationships:
        alpha, beta, loc, scale = beta_vals_dict[relationship]
        pdf = scipy.stats.beta.pdf(x, alpha, beta, loc, scale)

        plt.plot(x, pdf, color = color_map_dict[relationship], label = relationship)
        plt.xlabel('Genome-wide Total Relatedness', fontsize = 15)
        plt.ylabel('Density', fontsize = 15)
        plt.xlim(-0.05,1.05)
        plt.ylim(-0.05, 10)
        plt.legend(ncols = 3, fontsize = 12)
        plt.tick_params(axis='both', which='major', labelsize=12)

P1.B11
P1.B21
P1.F11
P1.F21
F11.F12
F11.F21
F21.F22
F11.B12
B12.F21
F11.B21
B21.F21
B11.B21


In [11]:

def create_pdfs(p_max_segment_dict, bandwidth= 0.05):
    max_segment_pdf_dict = defaultdict(dict)
    for relationship in p_max_segment_dict:
        for chromosome in p_max_segment_dict[relationship]:
            data = np.asarray(p_max_segment_dict[relationship][chromosome])
            mask0 = np.asarray(p_max_segment_dict[relationship][chromosome]) <= 0 + bandwidth
            mask1 = np.asarray(p_max_segment_dict[relationship][chromosome]) >= 1 - bandwidth
            mask = ~mask0 & ~mask1
            data = np.asarray(data[mask])

            p0 = (np.sum(mask0) + 1 )/ (len(mask0) + 3)
            p1 = (np.sum(mask1) + 1 )/ (len(mask1) + 3)
            kde = KernelDensity(kernel='gaussian',bandwidth=bandwidth).fit(data.reshape(-1,1))#training of model
            #beta_params = scipy.stats.beta.fit(data)

            max_segment_pdf_dict[relationship][int(chromosome)] =  (p0, p1, kde, bandwidth)#, beta_params)
    return max_segment_pdf_dict


def calc_seg_count_pmf(segment_counts_dict):
    '''empirical segment_count_pmf, with add one smoothing
    theta_i = (x_i + alpha) / (N + alpha * d)
    where x_i is the count of the cateogory i
    alpha is set to 1
    d is the number of categories
    N is the sample count'''
    seg_count_pmf = defaultdict(lambda: defaultdict(lambda : defaultdict(dict)))
    for relationship in segment_counts_dict:
        for chrom in segment_counts_dict[relationship]:
            counts = Counter(segment_counts_dict[relationship][chrom])
            total = np.sum(list(counts.values()))
            max_count= max(counts)
            bins = range(0,max_count + 2) #add an additional category representing max + 1
            n_bins = len(bins)
            for x in bins:
                if x in counts:
                    seg_count_pmf[relationship][chrom][x] = (counts[x] + 1)/(total + n_bins)
                else:
                    seg_count_pmf[relationship][chrom]['misc'] = 1 / (total + n_bins)
    return seg_count_pmf

def evaluate_max_segment_piecewise_pdf(x, relationship, chrom, pdf_dict, bandwidth =0.02):
    '''modeled as a kde with spikes'''
    #print(fs_max_segment_pdf_dict[relationship][int(chrom)])
    p0,p1,kde, tstep, = pdf_dict[relationship][int(chrom)]
                       
    if float(x) <= 0 + bandwidth:
        return p0/tstep
    elif float(x) >= 1 - bandwidth:
        return p1/tstep
    else:
        x_transmute = np.asarray([x])
        #if beta_dis:
        #    pdf = scipy.stats.beta.pdf(x, alpha, beta, loc, scale)
        #else:
        pdf = np.exp(kde.score_samples(x_transmute.reshape(-1,1)))[0]
        pdf = pdf * (1-p0-p1)
        return pdf

    
raw_sim_data = param_dict['2.0_11.3']
tmp_pdfs = defaultdict(dict)
for sim_type in raw_sim_data:
    print(sim_type)
    tmp_pdfs[sim_type]['r_beta'] = fit_beta(raw_sim_data[sim_type]['r_totals'])
    tmp_pdfs[sim_type]['p_max_segment'] = create_pdfs(raw_sim_data[sim_type]['p_ibd_max'])
    tmp_pdfs[sim_type]['segment_count'] = calc_seg_count_pmf(raw_sim_data[sim_type]['ibd_segment_numbers'])


pdfs = {}
pdfs['r_beta'] = tmp_pdfs['fs']['r_beta']
pdfs['p_max_segment'] = tmp_pdfs['fs']['p_max_segment']
pdfs['segment_count'] = tmp_pdfs['fs']['segment_count']

fs
P1.F11
P1.F12
P1.B11
P1.B12
P1.B21
P1.B22
P1.B31
P1.B32
P1.B41
P1.B42
P1.F21
P1.F22
P2.F11
P2.F12
P2.B11
P2.B12
P2.B21
P2.B22
P2.B31
P2.B32
P2.B41
P2.B42
P2.F21
P2.F22
F11.F12
F11.B11
F11.B12
F11.B21
F11.B22
F11.B31
F11.B32
F11.B41
F11.B42
F11.F21
F11.F22
F12.B11
F12.B12
F12.B21
F12.B22
F12.B31
F12.B32
F12.B41
F12.B42
F12.F21
F12.F22
B11.B12
B11.B21
B11.B22
B11.B31
B11.B32
B11.B41
B11.B42
B11.F21
B11.F22
B12.B21
B12.B22
B12.B31
B12.B32
B12.B41
B12.B42
B12.F21
B12.F22
B21.B22
B21.B31
B21.B32
B21.B41
B21.B42
B21.F21
B21.F22
B22.B31
B22.B32
B22.B41
B22.B42
B22.F21
B22.F22
B31.B32
B31.B41
B31.B42
B31.F21
B31.F22
B32.B41
B32.B42
B32.F21
B32.F22
B41.B42
B41.F21
B41.F22
B42.F21
B42.F22
F21.F22
ms


In [2]:
data_df = read_csv('/Users/weswong/Documents/GitHub/cross_data/ferdig_v2/NF54_NHP4026_F2_comparison_stats.txt',
                  sep='\t')


parents = ['AB_BSA_222_NHP4026.P','AB_BSA_220_NF54.P']

nodes = []
for comparison in data_df['comparison']:
    s1,s2 = comparison.split(':')
    if s1 in parents:
        nodes.append(s2)
        print(comparison)

AB_BSA_220_NF54.P:AB_BSA_222_NHP4026.P
AB_BSA_220_NF54.P:HI0001.F1
AB_BSA_220_NF54.P:HI0003.F1
AB_BSA_220_NF54.P:HI0004.F1
AB_BSA_220_NF54.P:HI0005.F1
AB_BSA_220_NF54.P:HI0007.F1
AB_BSA_220_NF54.P:HI0010.F1
AB_BSA_220_NF54.P:HI0019.F1
AB_BSA_220_NF54.P:HI0020.F1
AB_BSA_220_NF54.P:HI0021.F1
AB_BSA_220_NF54.P:HI0022.F1
AB_BSA_220_NF54.P:HI0024.F1
AB_BSA_220_NF54.P:HI0025.F1
AB_BSA_220_NF54.P:HI0026.F1
AB_BSA_220_NF54.P:HI0028.F1
AB_BSA_220_NF54.P:HI0029.F1
AB_BSA_220_NF54.P:HI0030.F1
AB_BSA_220_NF54.P:HI0031.F1
AB_BSA_220_NF54.P:HI0032.F1
AB_BSA_220_NF54.P:HI0033.F1
AB_BSA_220_NF54.P:HI0035.F1
AB_BSA_220_NF54.P:HI0038.F1
AB_BSA_220_NF54.P:HI0039.F1
AB_BSA_220_NF54.P:HI0040.F1
AB_BSA_220_NF54.P:HI0041.F1
AB_BSA_220_NF54.P:HI0047.F1
AB_BSA_220_NF54.P:HI0049.F1
AB_BSA_220_NF54.P:HI0050.F1
AB_BSA_220_NF54.P:HI0053.F1
AB_BSA_220_NF54.P:HI0054.F1
AB_BSA_220_NF54.P:HI0059.F1
AB_BSA_220_NF54.P:HI0060.F1
AB_BSA_220_NF54.P:HI0061.F1
AB_BSA_220_NF54.P:HI0062.F1
AB_BSA_220_NF54.P:HI0063.F1
AB_BSA_22

In [3]:
class Sample_comparison:
    def __init__(self, comparison, r_total, max_ibd_segments, n_ibd_segments):
        self.comparison = comparison
        self.s1 = comparison.split(':')[0]
        self.s2 = comparison.split(':')[1]
        self.r_total = r_total

        self.max_ibd_segment = max_ibd_segments
        self.n_segment_count = n_ibd_segments
        
        
def format_ingest_data(file):
    sample_data = defaultdict(dict)
    df = DataFrame(read_csv(file, sep = '\t'))
    data = df.to_numpy()
    n_ibd_segment_dict = {}
    max_ibd_segment_dict = {}
    for row in data:
        comparison = row[0]
        s1 = comparison.split(':')[0]
        s2 = comparison.split(':')[1]
        relatedness = row[1]
        n_ibd_segments = row[2:16]
        max_ibd_segments = row[16:]
        #for chrom, n_ibd, max_ibd in zip(range(1,15), n_ibd_segments, max_ibd_segments):
        #    n_ibd_segment_dict[chrom] = n_ibd
        #    max_ibd_segment_dict[chrom] = max_ibd
        print(comparison)
        S = Sample_comparison(comparison, relatedness,max_ibd_segments, n_ibd_segments)
        sample_data[s1][s2] = S
        sample_data[s2][s1] = copy.deepcopy(S)
    return sample_data

sample_data = format_ingest_data('/Users/weswong/Documents/GitHub/cross_data/ferdig_v2/NF54_NHP4026_F2_comparison_stats.txt')

sample_data2 = format_ingest_data('/Users/weswong/Documents/GitHub/cross_data/ferdig_v2/NF54_NHP4026_F2_comparison_stats.txt')


AB_BSA_220_NF54.P:AB_BSA_222_NHP4026.P
AB_BSA_220_NF54.P:HI0001.F1
AB_BSA_220_NF54.P:HI0003.F1
AB_BSA_220_NF54.P:HI0004.F1
AB_BSA_220_NF54.P:HI0005.F1
AB_BSA_220_NF54.P:HI0007.F1
AB_BSA_220_NF54.P:HI0010.F1
AB_BSA_220_NF54.P:HI0019.F1
AB_BSA_220_NF54.P:HI0020.F1
AB_BSA_220_NF54.P:HI0021.F1
AB_BSA_220_NF54.P:HI0022.F1
AB_BSA_220_NF54.P:HI0024.F1
AB_BSA_220_NF54.P:HI0025.F1
AB_BSA_220_NF54.P:HI0026.F1
AB_BSA_220_NF54.P:HI0028.F1
AB_BSA_220_NF54.P:HI0029.F1
AB_BSA_220_NF54.P:HI0030.F1
AB_BSA_220_NF54.P:HI0031.F1
AB_BSA_220_NF54.P:HI0032.F1
AB_BSA_220_NF54.P:HI0033.F1
AB_BSA_220_NF54.P:HI0035.F1
AB_BSA_220_NF54.P:HI0038.F1
AB_BSA_220_NF54.P:HI0039.F1
AB_BSA_220_NF54.P:HI0040.F1
AB_BSA_220_NF54.P:HI0041.F1
AB_BSA_220_NF54.P:HI0047.F1
AB_BSA_220_NF54.P:HI0049.F1
AB_BSA_220_NF54.P:HI0050.F1
AB_BSA_220_NF54.P:HI0053.F1
AB_BSA_220_NF54.P:HI0054.F1
AB_BSA_220_NF54.P:HI0059.F1
AB_BSA_220_NF54.P:HI0060.F1
AB_BSA_220_NF54.P:HI0061.F1
AB_BSA_220_NF54.P:HI0062.F1
AB_BSA_220_NF54.P:HI0063.F1
AB_BSA_22

HI0162.F2:HI0225.F2
HI0162.F2:HI0227.F2
HI0162.F2:HI0230.F2
HI0162.F2:HI0232.F2
HI0162.F2:HI0239.F2
HI0162.F2:HI0241.F2
HI0162.F2:HI0242.F2
HI0162.F2:HI0246.F2
HI0162.F2:HI0251.F2
HI0162.F2:HI0253.F2
HI0162.F2:HI0254.F2
HI0162.F2:HI0256.F2
HI0162.F2:HI0258.F2
HI0162.F2:HI0259.F2
HI0162.F2:HI0263.F2
HI0162.F2:HI0265.F2
HI0162.F2:HI0266.F2
HI0162.F2:HI0267.F2
HI0162.F2:HI0268.F2
HI0163.F2:HI0164.F2
HI0163.F2:HI0165.F2
HI0163.F2:HI0170.F2
HI0163.F2:HI0173.F2
HI0163.F2:HI0174.F2
HI0163.F2:HI0175.F2
HI0163.F2:HI0177.F2
HI0163.F2:HI0183.F2
HI0163.F2:HI0185.F2
HI0163.F2:HI0187.F2
HI0163.F2:HI0188.F2
HI0163.F2:HI0189.F2
HI0163.F2:HI0190.F2
HI0163.F2:HI0194.F2
HI0163.F2:HI0196.F2
HI0163.F2:HI0197.F2
HI0163.F2:HI0198.F2
HI0163.F2:HI0199.F2
HI0163.F2:HI0200.F2
HI0163.F2:HI0203.F2
HI0163.F2:HI0206.F2
HI0163.F2:HI0207.F2
HI0163.F2:HI0208.F2
HI0163.F2:HI0211.F2
HI0163.F2:HI0213.F2
HI0163.F2:HI0215.F2
HI0163.F2:HI0216.F2
HI0163.F2:HI0225.F2
HI0163.F2:HI0227.F2
HI0163.F2:HI0230.F2
HI0163.F2:HI0232.F2


HI0081.F1:HI0173.F2
HI0081.F1:HI0174.F2
HI0081.F1:HI0175.F2
HI0081.F1:HI0177.F2
HI0081.F1:HI0183.F2
HI0081.F1:HI0185.F2
HI0081.F1:HI0187.F2
HI0081.F1:HI0188.F2
HI0081.F1:HI0189.F2
HI0081.F1:HI0190.F2
HI0081.F1:HI0194.F2
HI0081.F1:HI0196.F2
HI0081.F1:HI0197.F2
HI0081.F1:HI0198.F2
HI0081.F1:HI0199.F2
HI0081.F1:HI0200.F2
HI0081.F1:HI0203.F2
HI0081.F1:HI0206.F2
HI0081.F1:HI0207.F2
HI0081.F1:HI0208.F2
HI0081.F1:HI0211.F2
HI0081.F1:HI0213.F2
HI0081.F1:HI0215.F2
HI0081.F1:HI0216.F2
HI0081.F1:HI0225.F2
HI0081.F1:HI0227.F2
HI0081.F1:HI0230.F2
HI0081.F1:HI0232.F2
HI0081.F1:HI0239.F2
HI0081.F1:HI0241.F2
HI0081.F1:HI0242.F2
HI0081.F1:HI0246.F2
HI0081.F1:HI0251.F2
HI0081.F1:HI0253.F2
HI0081.F1:HI0254.F2
HI0081.F1:HI0256.F2
HI0081.F1:HI0258.F2
HI0081.F1:HI0259.F2
HI0081.F1:HI0263.F2
HI0081.F1:HI0265.F2
HI0081.F1:HI0266.F2
HI0081.F1:HI0267.F2
HI0081.F1:HI0268.F2
HI0082.F1:HI0084.F1
HI0082.F1:HI0087.F1
HI0082.F1:HI0089.F1
HI0082.F1:HI0090.F1
HI0082.F1:HI0091.F1
HI0082.F1:HI0094.F1
HI0082.F1:HI0101.F1


In [12]:
def calc_logp_rtotal(G_model, data):
        r_total_beta_params = pdfs['r_beta'][G_model]
        P_rtotal = scipy.stats.beta.logpdf(data, *r_total_beta_params)
        return P_rtotal

def calc_logp_max_ibd_segment(G_model, data, chrom):
    #data set to self.max_ibd_segment[p][idx] data
    return np.log(evaluate_max_segment_piecewise_pdf(data, G_model, chrom, pdfs['p_max_segment']))

def calc_logp_p_seg_count(G_model, data, chrom):
    n_segments = data
    if n_segments in pdfs['segment_count'][G_model][str(chrom)].keys():
        P_seg_count = np.log(pdfs['segment_count'][G_model][str(chrom)][n_segments])
    else:
        P_seg_count = np.log(pdfs['segment_count'][G_model][str(chrom)]['misc'])
    return P_seg_count


parents = ['AB_BSA_220_NF54.P', 'AB_BSA_222_NHP4026.P']
for ref_node in list(sample_data.keys())[2:]:
    for triangulation_node in list(sample_data[ref_node].keys())[2:]:
        if (ref_node not in parents) and (triangulation_node not in parents):
            S = sample_data[ref_node][triangulation_node]

            p1_s1 = sample_data['AB_BSA_220_NF54.P'][ref_node]
            p2_s1 = sample_data['AB_BSA_222_NHP4026.P'][ref_node]
            
            p1_s2 = sample_data['AB_BSA_220_NF54.P'][triangulation_node]
            s1_s2 = sample_data[ref_node][triangulation_node]

            r_total_dict = {'P1.s1': p1_s1.r_total,
                           'P2.s1':p2_s1.r_total,
                            'P1.s2': p1_s2.r_total,
                           's1.s2':s1_s2.r_total}

            max_ibd_segment_dict = {'P1.s1': p1_s1.max_ibd_segment,
                                   'P2.s1': p2_s1.max_ibd_segment,
                                   'P1.s2': p1_s2.max_ibd_segment,
                                   's1.s2': s1_s2.max_ibd_segment}

            n_segment_count_dict = {'P1.s1': p1_s1.n_segment_count,
                                   'P2.s1': p2_s1.n_segment_count,
                                   'P1.s2': p1_s2.n_segment_count,
                                   's1.s2': s1_s2.n_segment_count}

            founder_logL_dict = defaultdict(lambda:0)
            triangulation_logL_dict = defaultdict(lambda:0)
            founder_logL_dict['B1'] += calc_logp_rtotal('P1.B11', r_total_dict['P1.s1']) + calc_logp_rtotal('P2.B11', r_total_dict['P2.s1'])
            founder_logL_dict['B2'] += calc_logp_rtotal('P1.B21', r_total_dict['P1.s1']) + calc_logp_rtotal('P2.B21', r_total_dict['P2.s1'])
            founder_logL_dict['F1'] += calc_logp_rtotal('P1.F11', r_total_dict['P1.s1']) + calc_logp_rtotal('P2.F11', r_total_dict['P2.s1'])
            founder_logL_dict['F2'] += calc_logp_rtotal('P1.F21', r_total_dict['P1.s1']) + calc_logp_rtotal('P2.F21', r_total_dict['P2.s1'])

            triangulation_logL_dict['F1:F1'] += calc_logp_rtotal('P1.F11', r_total_dict['P1.s1']) + calc_logp_rtotal('P1.F11', r_total_dict['P1.s2']) + calc_logp_rtotal('F11.F12', r_total_dict['s1.s2']) 
            triangulation_logL_dict['F1:F2'] += calc_logp_rtotal('P1.F11', r_total_dict['P1.s1']) + calc_logp_rtotal('P1.F21', r_total_dict['P1.s2']) + calc_logp_rtotal('F11.F21', r_total_dict['s1.s2']) 
            
            triangulation_logL_dict['F2:F1'] += calc_logp_rtotal('P1.F21', r_total_dict['P1.s1']) + calc_logp_rtotal('P1.F11', r_total_dict['P1.s2']) + calc_logp_rtotal('F11.F21', r_total_dict['s1.s2']) 
            triangulation_logL_dict['F2:F2'] += calc_logp_rtotal('P1.F21', r_total_dict['P1.s1']) + calc_logp_rtotal('P1.F21', r_total_dict['P1.s2']) + calc_logp_rtotal('F21.F22', r_total_dict['s1.s2']) 
            

            for chrom in range(1,15):
                chrom_idx = chrom - 1
                founder_logL_dict['B1'] += calc_logp_max_ibd_segment('P1.B11', max_ibd_segment_dict['P1.s1'][chrom_idx], chrom) + calc_logp_max_ibd_segment('P2.B11', max_ibd_segment_dict['P2.s1'][chrom_idx], chrom)
                founder_logL_dict['B2'] += calc_logp_max_ibd_segment('P1.B21', max_ibd_segment_dict['P1.s1'][chrom_idx], chrom) + calc_logp_max_ibd_segment('P2.B21', max_ibd_segment_dict['P2.s1'][chrom_idx], chrom)
                founder_logL_dict['F1'] += calc_logp_max_ibd_segment('P1.F11', max_ibd_segment_dict['P1.s1'][chrom_idx], chrom) + calc_logp_max_ibd_segment('P2.F11', max_ibd_segment_dict['P2.s1'][chrom_idx], chrom)
                founder_logL_dict['F2'] += calc_logp_max_ibd_segment('P1.F21', max_ibd_segment_dict['P1.s1'][chrom_idx], chrom) + calc_logp_max_ibd_segment('P2.F21', max_ibd_segment_dict['P2.s1'][chrom_idx], chrom)

                founder_logL_dict['B1'] += calc_logp_p_seg_count('P1.B11', n_segment_count_dict['P1.s1'][chrom_idx], chrom) + calc_logp_p_seg_count('P2.B11', n_segment_count_dict['P2.s1'][chrom_idx], chrom)
                founder_logL_dict['B2'] += calc_logp_p_seg_count('P1.B21', n_segment_count_dict['P1.s1'][chrom_idx], chrom) + calc_logp_p_seg_count('P2.B21', n_segment_count_dict['P2.s1'][chrom_idx], chrom)
                founder_logL_dict['F1'] += calc_logp_p_seg_count('P1.F11', n_segment_count_dict['P1.s1'][chrom_idx], chrom) + calc_logp_p_seg_count('P2.F11', n_segment_count_dict['P2.s1'][chrom_idx], chrom)
                founder_logL_dict['F2'] += calc_logp_p_seg_count('P1.F21', n_segment_count_dict['P1.s1'][chrom_idx], chrom) + calc_logp_p_seg_count('P2.F21', n_segment_count_dict['P2.s1'][chrom_idx], chrom)
                
                #Max IBD segment-----
                triangulation_logL_dict['F1:F1'] += calc_logp_max_ibd_segment('P1.F11',max_ibd_segment_dict['P1.s1'][chrom_idx], chrom)
                triangulation_logL_dict['F1:F1'] += calc_logp_max_ibd_segment('P1.F11',max_ibd_segment_dict['P1.s2'][chrom_idx], chrom)
                triangulation_logL_dict['F1:F1'] += calc_logp_max_ibd_segment('F11.F12',max_ibd_segment_dict['s1.s2'][chrom_idx], chrom)

                triangulation_logL_dict['F1:F2'] += calc_logp_max_ibd_segment('P1.F11',max_ibd_segment_dict['P1.s1'][chrom_idx], chrom)
                triangulation_logL_dict['F1:F2'] += calc_logp_max_ibd_segment('P1.F21',max_ibd_segment_dict['P1.s2'][chrom_idx], chrom)
                triangulation_logL_dict['F1:F2'] += calc_logp_max_ibd_segment('F11.F21',max_ibd_segment_dict['s1.s2'][chrom_idx], chrom)
                
                triangulation_logL_dict['F2:F1'] += calc_logp_max_ibd_segment('P1.F21',max_ibd_segment_dict['P1.s1'][chrom_idx], chrom)
                triangulation_logL_dict['F2:F1'] += calc_logp_max_ibd_segment('P1.F11',max_ibd_segment_dict['P1.s2'][chrom_idx], chrom)
                triangulation_logL_dict['F2:F1'] += calc_logp_max_ibd_segment('F11.F21',max_ibd_segment_dict['s1.s2'][chrom_idx], chrom)
                
                triangulation_logL_dict['F2:F2'] += calc_logp_max_ibd_segment('P1.F21',max_ibd_segment_dict['P1.s1'][chrom_idx], chrom)
                triangulation_logL_dict['F2:F2'] += calc_logp_max_ibd_segment('P1.F21',max_ibd_segment_dict['P1.s2'][chrom_idx], chrom)
                triangulation_logL_dict['F2:F2'] += calc_logp_max_ibd_segment('F21.F22',max_ibd_segment_dict['s1.s2'][chrom_idx], chrom)
                
                #Seg Count
                triangulation_logL_dict['F1:F1'] += calc_logp_p_seg_count('P1.F11',n_segment_count_dict['P1.s1'][chrom_idx], chrom)
                triangulation_logL_dict['F1:F1'] += calc_logp_p_seg_count('P1.F11',n_segment_count_dict['P1.s2'][chrom_idx], chrom)
                triangulation_logL_dict['F1:F1'] += calc_logp_p_seg_count('F11.F12',n_segment_count_dict['s1.s2'][chrom_idx], chrom)

                triangulation_logL_dict['F1:F2'] += calc_logp_p_seg_count('P1.F11',n_segment_count_dict['P1.s1'][chrom_idx], chrom)
                triangulation_logL_dict['F1:F2'] += calc_logp_p_seg_count('P1.F21',n_segment_count_dict['P1.s2'][chrom_idx], chrom)
                triangulation_logL_dict['F1:F2'] += calc_logp_p_seg_count('F11.F21',n_segment_count_dict['s1.s2'][chrom_idx], chrom)
                
                triangulation_logL_dict['F2:F1'] += calc_logp_p_seg_count('P1.F21',n_segment_count_dict['P1.s1'][chrom_idx], chrom)
                triangulation_logL_dict['F2:F1'] += calc_logp_p_seg_count('P1.F11',n_segment_count_dict['P1.s2'][chrom_idx], chrom)
                triangulation_logL_dict['F2:F1'] += calc_logp_p_seg_count('F11.F21',n_segment_count_dict['s1.s2'][chrom_idx], chrom)
                
                triangulation_logL_dict['F2:F2'] += calc_logp_p_seg_count('P1.F21',n_segment_count_dict['P1.s1'][chrom_idx], chrom)
                triangulation_logL_dict['F2:F2'] += calc_logp_p_seg_count('P1.F21',n_segment_count_dict['P1.s2'][chrom_idx], chrom)
                triangulation_logL_dict['F2:F2'] += calc_logp_p_seg_count('F21.F22',n_segment_count_dict['s1.s2'][chrom_idx], chrom)
                
            S.founder_logL_dict = founder_logL_dict
            S.max_founder_logL = max(founder_logL_dict, key=founder_logL_dict.get)
            S.triangulation_logL_dict = triangulation_logL_dict
            S.max_triangulation_logL = max(triangulation_logL_dict, key=triangulation_logL_dict.get)
    counts = Counter([sample_data[ref_node][triangulation_node].max_triangulation_logL.split(':')[0] for triangulation_node in list(sample_data[ref_node])[2:]])

    print(ref_node, counts )
    
f1_samples =[]
f2_samples = []
for key in list(sample_data.keys())[2:]:
    if 'F1' in key:
        f1_samples.append(key)
    else:
        f2_samples.append(key)  

HI0001.F1 Counter({'F1': 127, 'F2': 7})
HI0003.F1 Counter({'F1': 131, 'F2': 3})
HI0004.F1 Counter({'F1': 134})
HI0005.F1 Counter({'F1': 130, 'F2': 4})
HI0007.F1 Counter({'F1': 133, 'F2': 1})
HI0010.F1 Counter({'F1': 132, 'F2': 2})
HI0019.F1 Counter({'F1': 131, 'F2': 3})
HI0020.F1 Counter({'F1': 123, 'F2': 11})
HI0021.F1 Counter({'F1': 125, 'F2': 9})
HI0022.F1 Counter({'F1': 134})
HI0024.F1 Counter({'F1': 101, 'F2': 33})
HI0025.F1 Counter({'F1': 123, 'F2': 11})
HI0026.F1 Counter({'F1': 120, 'F2': 14})
HI0028.F1 Counter({'F1': 127, 'F2': 7})
HI0029.F1 Counter({'F1': 130, 'F2': 4})
HI0030.F1 Counter({'F1': 132, 'F2': 2})
HI0031.F1 Counter({'F1': 112, 'F2': 22})
HI0032.F1 Counter({'F1': 134})
HI0033.F1 Counter({'F1': 131, 'F2': 3})
HI0035.F1 Counter({'F1': 133, 'F2': 1})
HI0038.F1 Counter({'F1': 133, 'F2': 1})
HI0039.F1 Counter({'F1': 130, 'F2': 4})
HI0040.F1 Counter({'F1': 133, 'F2': 1})
HI0041.F1 Counter({'F1': 111, 'F2': 23})
HI0047.F1 Counter({'F1': 121, 'F2': 13})
HI0049.F1 Counter({'

In [None]:
f1_classification = defaultdict(list)
f2_classification = defaultdict(list)
for node in list(sample_data.keys())[2:]:
    S = list(sample_data[node].values())[-1]
    if 'F1' in node:
        f1_classification[S.max_founder_logL].append(node)
    else:
        f2_classification[S.max_founder_logL].append(node)
    sorted_C = sorted(dict(S.founder_logL_dict), key = lambda x: S.founder_logL_dict[x], reverse = True)
    #print(node, S.max_founder_logL, [(k,S.founder_logL_dict[k]) for k in sorted_C])

print('F1_pool')
f1_classification_counts = {}
for key in f1_classification:
    f1_classification_counts[key] = len(f1_classification[key])
    
print('F2_pool')
f2_classification_counts = {}
for key in f2_classification:
    f2_classification_counts[key] = len(f2_classification[key])
backcrosses = f1_classification['B1'] + f1_classification['B2'] + f2_classification['B1'] + f2_classification['B2']

In [None]:
f1_triangulated_classification = defaultdict(list)
f1_triangulated_classification['B1'] = f1_classification['B1']
f1_triangulated_classification['B2'] = f1_classification['B2']


f2_triangulated_classification = defaultdict(list)
f2_triangulated_classification['B1'] = f2_classification['B1']
f2_triangulated_classification['B2'] = f2_classification['B2']

f1_nonbackcrosses = [x for x in f1_samples if x not in backcrosses]
for node in f1_nonbackcrosses:
    counts = Counter([sample_data[node][k].max_triangulation_logL.split(':')[0] for k in list(sample_data[node])[2:] if k in f1_nonbackcrosses])
    triangulated_classification = max(counts, key=counts.get)
    f1_triangulated_classification[triangulated_classification].append(node)
f1_classification_counts = defaultdict(lambda: 0)
for key in f1_triangulated_classification:
    f1_classification_counts[key]= len(f1_triangulated_classification[key])
    
f2_nonbackcrosses = [x for x in f2_samples if x not in backcrosses]
for node in f2_nonbackcrosses:
    print(node, counts)
    counts = Counter([sample_data[node][k].max_triangulation_logL.split(':')[0] for k in list(sample_data[node])[2:] if k in f2_nonbackcrosses])
    triangulated_classification = max(counts, key=counts.get)
    f2_triangulated_classification[triangulated_classification].append(node)
    
f2_classification_counts = defaultdict(lambda: 0)
for key in f2_triangulated_classification:
    f2_classification_counts[key]= len(f2_triangulated_classification[key])

In [None]:
integrated_triangulation = json.load(open('full_triangulation_p1_p2.json'))
integrated_triangulation_f1counts = Counter(list(integrated_triangulation['f1_pool'].values()))
integrated_triangulation_f1sum = sum(Counter(list(integrated_triangulation['f1_pool'].values())).values())
integrated_triangulation_f2counts = Counter(list(integrated_triangulation['f2_pool'].values()))
integrated_triangulation_f2sum = sum(Counter(list(integrated_triangulation['f2_pool'].values())).values())

plt.figure(figsize=(8,4))
x_array = np.asarray([0,1,2,3])
plt.subplot(1,2,1)
width = 0.4
plt.bar(x_array, [integrated_triangulation_f1counts[k]/integrated_triangulation_f1sum for k in ['B1', 'B2', 'F1', 'F2']], width = width)
plt.bar(x_array+width, [len(f1_triangulated_classification[k])/integrated_triangulation_f1sum for k in ['B1', 'B2', 'F1', 'F2']], width = width)
plt.xticks(x_array, ['B1', 'B2', 'F1', 'F2'])
plt.ylabel('Proportion')
plt.xlabel('Classification')
plt.title('F1 pool')

plt.subplot(1,2,2)
plt.bar(x_array, [integrated_triangulation_f2counts[k]/integrated_triangulation_f2sum for k in ['B1', 'B2', 'F1', 'F2']], width = width,
       label = 'Integrated')
plt.bar(x_array+width, [len(f2_triangulated_classification[k])/integrated_triangulation_f2sum for k in ['B1', 'B2', 'F1', 'F2']], width = width,
       label = 'Tiered')

plt.xticks(x_array, ['B1', 'B2', 'F1', 'F2'])
plt.ylabel('Proportion')
plt.xlabel('Classification')
plt.title('F2 pool')
plt.legend()



In [None]:
tiered_triangulated_classification = defaultdict(lambda: {})
for classification in f2_triangulated_classification:
    for sample in f2_triangulated_classification[classification]:
        tiered_triangulated_classification['f2_pool'][sample] = classification
for classification in f1_triangulated_classification:
    for sample in f1_triangulated_classification[classification]:
        tiered_triangulated_classification['f1_pool'][sample] = classification

In [None]:
classification_comparator = defaultdict(lambda: [['Sample', 'Integrated', 'Tiered']])

for pool in ['f1_pool', 'f2_pool']:
    for sample in integrated_triangulation[pool]:
        integrated_classification = integrated_triangulation[pool][sample]
        tiered_classification = tiered_triangulated_classification[pool][sample]
        classification_comparator[pool].append([sample, integrated_classification, tiered_classification])
    classification_comparator[pool] = DataFrame(classification_comparator[pool][1:], columns = classification_comparator[pool][0])

In [None]:
classification_comparator['f1_pool'].to_csv('f1_pool_classification.csv', index = False)
classification_comparator['f2_pool'].to_csv('f2_pool_classification.csv', index = False)

In [None]:
def calc_logp_rtotal(G_model, data):
        r_total_beta_params = pdfs['r_beta'][G_model]
        P_rtotal = scipy.stats.beta.logpdf(data, *r_total_beta_params)
        return P_rtotal

def calc_logp_max_ibd_segment(G_model, data, chrom):
    #data set to self.max_ibd_segment[p][idx] data
    return np.log(evaluate_max_segment_piecewise_pdf(data, G_model, chrom, pdfs['p_max_segment']))

def calc_logp_p_seg_count(G_model, data, chrom):
    n_segments = data
    if n_segments in pdfs['segment_count'][G_model][str(chrom)].keys():
        P_seg_count = np.log(pdfs['segment_count'][G_model][str(chrom)][n_segments])
    else:
        P_seg_count = np.log(pdfs['segment_count'][G_model][str(chrom)]['misc'])
    return P_seg_count



genealogy_classes = ['B1', 'B2', 'F1', 'F2']
    
genealogy_keyswaps = {}
for G in genealogy_classes:
    genealogy_keyswaps[G + ':' + G] = ('P1.' + G +'1', #p1_s1
                                       'P2.' + G +'1', #p2_s1
                                       'P1.' + G +'1', #p1_s2
                                       'P2.' + G +'1', #p2_s2
                                       G+ '1.' + G + '2') #s1_s2

# all unique permutations
for x in itertools.permutations(genealogy_classes,2):
    s1,s2 = x
    combo =s1+ '1.' + s2 + '2'
    if combo in ['B21.B12','B11.F12','B21.F12', 'F21.B12', 'F21.B22', 'F21.F12']:
        genealogy_keyswaps[s1 + ':' + s2] = ('P1.' + s1 +'1',
                                           'P2.' + s1 +'1',
                                             'P1.' + s2 +'1',
                                           'P2.' + s2 +'1',
                                           s2+ '1.' + s1 + '2') 
    else:
        genealogy_keyswaps[s1 + ':' + s2] = ('P1.' + s1 +'1',
                                           'P2.' + s1 +'1',
                                             'P1.' + s2 +'1',
                                           'P2.' + s2 +'1',
                                           s1+ '1.' + s2 + '2')  
    
triangulation_dict = defaultdict(dict)
parents = ['AB_BSA_220_NF54.P', 'AB_BSA_222_NHP4026.P']
for ref_node in sample_data2:
    for triangulation_node in sample_data2[ref_node]:
        if (ref_node not in parents) and (triangulation_node not in parents):
            S = sample_data2[ref_node][triangulation_node]

            p1_s1 = sample_data['AB_BSA_220_NF54.P'][ref_node]
            p2_s1 = sample_data['AB_BSA_222_NHP4026.P'][ref_node]
            
            p1_s2 = sample_data['AB_BSA_220_NF54.P'][triangulation_node]
            p2_s2 = sample_data['AB_BSA_222_NHP4026.P'][triangulation_node]
            
            s1_s2 = sample_data[ref_node][triangulation_node]

            r_total_dict = {'P1.s1': p1_s1.r_total,
                           'P2.s1':p2_s1.r_total,
                            'P1.s2': p1_s2.r_total,
                            'P2.s2': p2_s2.r_total,
                           's1.s2':s1_s2.r_total}

            max_ibd_segment_dict = {'P1.s1': p1_s1.max_ibd_segment,
                                   'P2.s1': p2_s1.max_ibd_segment,
                                   'P1.s2': p1_s2.max_ibd_segment,
                                   'P2.s2': p2_s2.max_ibd_segment,
                                   's1.s2': s1_s2.max_ibd_segment}

            n_segment_count_dict = {'P1.s1': p1_s1.n_segment_count,
                                   'P2.s1': p2_s1.n_segment_count,
                                   'P1.s2': p1_s2.n_segment_count,
                                    'P2.s2': p2_s2.n_segment_count,
                                   's1.s2': s1_s2.n_segment_count}
            
            triangulation_logL_dict = defaultdict(lambda:0)
            founder_logL_dict = defaultdict(lambda:0)
            for G_comparison in genealogy_keyswaps:
                #theta refers to the parameter used for the theoretical expectation distribution
                p1_s1_theta, p2_s1_theta, p1_s2_theta, p2_s2_theta, s1_s2_theta = genealogy_keyswaps[G_comparison]
                triangulation_logL_dict[G_comparison] += calc_logp_rtotal(p1_s1_theta, r_total_dict['P1.s1'])
                triangulation_logL_dict[G_comparison] += calc_logp_rtotal(p2_s1_theta, r_total_dict['P2.s1'])
                triangulation_logL_dict[G_comparison] += calc_logp_rtotal(p1_s2_theta, r_total_dict['P1.s2'])
                triangulation_logL_dict[G_comparison] += calc_logp_rtotal(p2_s2_theta, r_total_dict['P2.s2']) 
                triangulation_logL_dict[G_comparison] += calc_logp_rtotal(s1_s2_theta, r_total_dict['s1.s2']) 

                founder_logL_dict['B1'] += calc_logp_rtotal('P1.B11', r_total_dict['P1.s1']) + calc_logp_rtotal('P2.B11', r_total_dict['P2.s1'])
                founder_logL_dict['B2'] += calc_logp_rtotal('P1.B21', r_total_dict['P1.s1']) + calc_logp_rtotal('P2.B21', r_total_dict['P2.s1'])
                founder_logL_dict['F1'] += calc_logp_rtotal('P1.F11', r_total_dict['P1.s1']) + calc_logp_rtotal('P2.F11', r_total_dict['P2.s1'])
                founder_logL_dict['F2'] += calc_logp_rtotal('P1.F21', r_total_dict['P1.s1']) + calc_logp_rtotal('P2.F21', r_total_dict['P2.s1'])

            
            
                for chrom in range(1,15):
                    chrom_idx = chrom - 1
                    #max IBD segment
                    founder_logL_dict['B1'] += calc_logp_max_ibd_segment('P1.B11', max_ibd_segment_dict['P1.s1'][chrom_idx], chrom) + calc_logp_max_ibd_segment('P2.B11', max_ibd_segment_dict['P2.s1'][chrom_idx], chrom)
                    founder_logL_dict['B2'] += calc_logp_max_ibd_segment('P1.B21', max_ibd_segment_dict['P1.s1'][chrom_idx], chrom) + calc_logp_max_ibd_segment('P2.B21', max_ibd_segment_dict['P2.s1'][chrom_idx], chrom)
                    founder_logL_dict['F1'] += calc_logp_max_ibd_segment('P1.F11', max_ibd_segment_dict['P1.s1'][chrom_idx], chrom) + calc_logp_max_ibd_segment('P2.F11', max_ibd_segment_dict['P2.s1'][chrom_idx], chrom)
                    founder_logL_dict['F2'] += calc_logp_max_ibd_segment('P1.F21', max_ibd_segment_dict['P1.s1'][chrom_idx], chrom) + calc_logp_max_ibd_segment('P2.F21', max_ibd_segment_dict['P2.s1'][chrom_idx], chrom)

                    founder_logL_dict['B1'] += calc_logp_p_seg_count('P1.B11', n_segment_count_dict['P1.s1'][chrom_idx], chrom) + calc_logp_p_seg_count('P2.B11', n_segment_count_dict['P2.s1'][chrom_idx], chrom)
                    founder_logL_dict['B2'] += calc_logp_p_seg_count('P1.B21', n_segment_count_dict['P1.s1'][chrom_idx], chrom) + calc_logp_p_seg_count('P2.B21', n_segment_count_dict['P2.s1'][chrom_idx], chrom)
                    founder_logL_dict['F1'] += calc_logp_p_seg_count('P1.F11', n_segment_count_dict['P1.s1'][chrom_idx], chrom) + calc_logp_p_seg_count('P2.F11', n_segment_count_dict['P2.s1'][chrom_idx], chrom)
                    founder_logL_dict['F2'] += calc_logp_p_seg_count('P1.F21', n_segment_count_dict['P1.s1'][chrom_idx], chrom) + calc_logp_p_seg_count('P2.F21', n_segment_count_dict['P2.s1'][chrom_idx], chrom)
                
                
                
                    triangulation_logL_dict[G_comparison] += calc_logp_max_ibd_segment(p1_s1_theta,
                                                                                       max_ibd_segment_dict['P1.s1'][chrom_idx], 
                                                                                       chrom)
                    triangulation_logL_dict[G_comparison] += calc_logp_max_ibd_segment(p2_s1_theta,
                                                                                       max_ibd_segment_dict['P2.s1'][chrom_idx], 
                                                                                       chrom)
                    triangulation_logL_dict[G_comparison] += calc_logp_max_ibd_segment(p1_s2_theta,
                                                                                       max_ibd_segment_dict['P1.s2'][chrom_idx], 
                                                                                       chrom)
                    triangulation_logL_dict[G_comparison] += calc_logp_max_ibd_segment(p2_s2_theta,
                                                                                       max_ibd_segment_dict['P2.s2'][chrom_idx], 
                                                                                       chrom)
                    triangulation_logL_dict[G_comparison] += calc_logp_max_ibd_segment(s1_s2_theta,
                                                                                       max_ibd_segment_dict['s1.s2'][chrom_idx], 
                                                                                       chrom)

                    #Seg Count
                    triangulation_logL_dict[G_comparison] += calc_logp_p_seg_count(p1_s1_theta,
                                                                                   n_segment_count_dict['P1.s1'][chrom_idx], chrom)
                    triangulation_logL_dict[G_comparison] += calc_logp_p_seg_count(p2_s1_theta,
                                                                                   n_segment_count_dict['P2.s1'][chrom_idx], chrom)
                    triangulation_logL_dict[G_comparison] += calc_logp_p_seg_count(p1_s2_theta,
                                                                                   n_segment_count_dict['P1.s2'][chrom_idx], chrom)
                    triangulation_logL_dict[G_comparison] += calc_logp_p_seg_count(p2_s2_theta,
                                                                                   n_segment_count_dict['P2.s2'][chrom_idx], chrom)
                    triangulation_logL_dict[G_comparison] += calc_logp_p_seg_count(s1_s2_theta,
                                                                                   n_segment_count_dict['s1.s2'][chrom_idx], chrom)

            S.founder_logL_dict = founder_logL_dict
            S.max_founder_logL = max(founder_logL_dict, key=founder_logL_dict.get)
            S.triangulation_logL_dict = triangulation_logL_dict
            S.max_triangulation_logL = max(triangulation_logL_dict, key=triangulation_logL_dict.get)
            triangulation_dict[ref_node][triangulation_node] = copy.deepcopy(S)
    counts = Counter([triangulation_dict[ref_node][triangulation_node].max_triangulation_logL.split(':')[0] for triangulation_node in list(triangulation_dict[ref_node])[2:]])
    print(ref_node, counts)
    
f1_samples =[]
f2_samples = []
for key in list(sample_data2.keys())[2:]:
    if 'F1' in key:
        f1_samples.append(key)
    else:
        f2_samples.append(key)  

In [None]:
categorization = defaultdict(dict)
for ref_node in f1_samples:
    counts = Counter([triangulation_dict[ref_node][triangulation_node].max_triangulation_logL.split(':')[0] for triangulation_node in list(triangulation_dict[ref_node])[2:]])
    max_count = max(counts, key=counts.get)
    categorization['f1_pool'][ref_node] = max_count
    
for ref_node in f2_samples:
    counts = Counter([triangulation_dict[ref_node][triangulation_node].max_triangulation_logL.split(':')[0] for triangulation_node in list(triangulation_dict[ref_node])[2:]])
    max_count = max(counts, key=counts.get)
    categorization['f2_pool'][ref_node] = max_count
    
#json.dump(categorization, open('full_triangulation_p1_p2.json', 'w'))

#json.dump(categorization, open('full_triangulation_p1_p2.json', 'w'))
#dill.dump(sample_data, open('full_triangulation_p1_p2.dill', 'wb'))

In [None]:
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
f1_pool_counts = Counter(list(categorization['f1_pool'].values()))
plt.bar(genealogy_classes, [f1_pool_counts[k]/ 55. for k in genealogy_classes])
plt.title('F1 Pool')
plt.ylabel('Proportion')
plt.xlabel('Classification')
plt.subplot(1,2,2)
f2_pool_counts = Counter(list(categorization['f2_pool'].values()))
plt.bar(genealogy_classes, [f2_pool_counts[k]/ 80. for k in genealogy_classes])
plt.title('F2 Pool')
plt.ylabel('Proportion')
plt.xlabel('Classification')
