# Imports, load, preprocess data

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output, display
from ipywidgets import IntProgress, Text, Layout
from tqdm import tqdm

from src.binomial import *
from src.bhc import *

In [2]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported automatically before each cell executes (useful while the
# project's own `src` modules are still being edited).
%load_ext autoreload
%autoreload 2

In [3]:
# Discover the column names by parsing only the header row (nrows=0).
# This avoids opening a chunked reader that is never closed, and it still
# defines `column_names` when the file contains a header but no data rows.
column_names = pd.read_csv('full_cleaned_data.csv', nrows=0).columns.tolist()

# 'CEPH ID' is removed from the candidate list so it cannot be picked as one
# of the randomly sampled columns.
column_names.remove('CEPH ID')

In [4]:
# Draw a random subset of 1000 columns (without replacement) and always keep
# 'CEPH ID', which becomes the index of `data`.
#
# A dedicated, seeded Generator makes the subset reproducible across kernel
# restarts without touching NumPy's global random state, so any other code
# that relies on the global RNG is unaffected.
COLUMN_SAMPLE_SEED = 0  # change to draw a different (but still reproducible) subset
column_rng = np.random.default_rng(COLUMN_SAMPLE_SEED)
cols = list(column_rng.choice(column_names, size=1000, replace=False)) + ['CEPH ID']

# Only the selected columns are parsed from the CSV.
data = pd.read_csv('full_cleaned_data.csv', usecols=cols).set_index('CEPH ID')

In [5]:
# Load the per-sample population metadata table, then key it by 'CEPH ID'.
pops = pd.read_csv('HGDP/hgdp/HGDP-CEPH-ID_populations.csv')
pops = pops.set_index('CEPH ID')

# Preview the first rows.
pops.head()

Unnamed: 0_level_0,population,Geographic origin,Region,Pop7Groups,Sex,All LCLs (H1063),Unrelated (1st and 2nd degree) (H951)
CEPH ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
HGDP00001,Brahui,Pakistan,Asia,Central_South_Asia,M,yes,yes
HGDP00003,Brahui,Pakistan,Asia,Central_South_Asia,M,yes,yes
HGDP00005,Brahui,Pakistan,Asia,Central_South_Asia,M,yes,yes
HGDP00007,Brahui,Pakistan,Asia,Central_South_Asia,M,yes,yes
HGDP00009,Brahui,Pakistan,Asia,Central_South_Asia,M,yes,yes


In [6]:
# Attach the population metadata columns to the genotype table. Both frames
# are indexed by 'CEPH ID', and pandas aligns the assigned Series on that
# index; samples missing from `pops` therefore receive NaN.
for c in ['population', 'Geographic origin', 'Region', 'Pop7Groups', 'Sex']:
    data[c] = pops[c]

# Remove rows/samples with null values. `dropna` filters row by row, whereas
# dropping by index label would also discard complete rows that merely share
# a label with an incomplete one.
data = data.dropna(axis=0, how='any')

data.head()

Unnamed: 0_level_0,rs11718605,rs1401161,rs2031797,rs2319220,rs2778652,rs310172,rs3801427,rs4744894,rs518385,rs588952,...,rs9983407,rs9985487,rs9986506,rs999190,rs2404347,population,Geographic origin,Region,Pop7Groups,Sex
CEPH ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
HGDP00448,2,2,2,0,1,2,1,2,2,2,...,1,2,2,1,2,Biaka_Pygmy,Central African Republic,Subsaharan Africa,Africa,M
HGDP00479,2,2,2,0,2,1,2,2,0,1,...,2,2,2,2,2,Biaka_Pygmy,Central African Republic,Subsaharan Africa,Africa,M
HGDP00985,2,2,2,0,2,1,1,2,0,2,...,2,2,1,2,2,Biaka_Pygmy,Central African Republic,Subsaharan Africa,Africa,M
HGDP01094,2,2,2,0,2,0,2,2,2,2,...,1,2,2,0,2,Biaka_Pygmy,Central African Republic,Subsaharan Africa,Africa,M
HGDP00982,2,2,2,0,2,1,1,2,1,2,...,1,2,2,1,2,Mbuti_Pygmy,Democratic Republic of Congo,Subsaharan Africa,Africa,M


# Run BHC clustering

In [7]:
# List the distinct 'Geographic origin' values present after null rows were removed.
data.loc[:, 'Geographic origin'].unique()

array(['Central African Republic', 'Democratic Republic of Congo',
       'Senegal', 'Nigeria', 'Kenya', 'Namibia', 'South Africa',
       'Algeria (Mzab)', 'Israel (Negev)', 'Israel (Carmel)',
       'Israel (Central)', 'Pakistan', 'China', 'Siberia', 'Japan',
       'Cambodia', 'New Guinea', 'Bougainville', 'France', 'Italy',
       'Italy (Bergamo)', 'Orkney Islands', 'Russia Caucasus', 'Russia',
       'Mexico', 'Colombia', 'Brazil'], dtype=object)

In [8]:
N = 20  # total number of samples to cluster, split evenly between the two populations
M = 60  # number of leading columns of `data` used as features
n_per_pop = N // 2

# Take up to `n_per_pop` samples per population and the first M columns.
# The metadata columns were appended after the genotype columns, so the
# leading columns are expected to be genotype values that cast to int8.
D_JPN = data.loc[
    data['Geographic origin'] == 'Japan'
].values[:n_per_pop, :M].astype(np.int8)
D_CAR = data.loc[
    data['Geographic origin'] == 'Central African Republic'
].values[:n_per_pop, :M].astype(np.int8)

D = np.vstack([D_JPN, D_CAR])

# One label per row of D. The counts come from the rows actually selected
# rather than a hard-coded 10, so the labels stay in step with D if N changes
# or a population has fewer than `n_per_pop` samples.
indices = ['Japan%d' % i for i in range(len(D_JPN))]
indices = indices + ['Central_African_Republic%d' % i for i in range(len(D_CAR))]

model = BetaPriorMC(3., 3.)
crp_alpha = 0.8

In [9]:
# Construct the clustering object from the data matrix, the CRP
# concentration value, the data model and the per-row labels.
bhc = BHC(D, crp_alpha, model, indices)

# Progress widgets: an outer bar with max=N, an inner bar sized to
# comb(N, 2), and a full-width text box for status messages.
progress = IntProgress(value=0, max=N)
inner_progress = IntProgress(value=0, max=comb(N, 2))
text = Text(value='', layout=Layout(width='100%'))

# Render all three widgets, in this order.
display(progress, inner_progress, text)

def outer_hook(new_node, left, right, i):
    """Outer-loop callback handed to ``bhc.build_tree`` through ``outer_hooks``.

    Advances the outer progress bar by one step, resets the inner bar and
    resizes it to ``comb(N - i - 1, 2)``, and writes a status line naming the
    ``index`` attributes of ``left`` and ``right``. ``new_node`` is accepted
    but not used.
    """
    progress.value = progress.value + 1

    # Reset the value before changing the maximum.
    inner_progress.value = 0
    remaining = N - i - 1
    inner_progress.max = comb(remaining, 2)

    merged_labels = (i, left.index, right.index)
    text.value = 'i=%d: merging %s and %s' % merged_labels
    
def inner_hook(node, left, right, j):
    """Inner-loop callback handed to ``bhc.build_tree`` through ``inner_hooks``.

    Advances the inner progress bar by one step; none of the arguments are used.
    """
    inner_progress.value = inner_progress.value + 1
    
outer_hooks = [outer_hook]
inner_hooks = [inner_hook]

# Build the tree with the progress callbacks attached. The widgets are closed
# in a `finally` block so they do not linger in the notebook when
# `build_tree` raises or the cell is interrupted.
try:
    bhc.build_tree(inner_hooks, outer_hooks)
finally:
    progress.close()
    inner_progress.close()
    text.close()

IntProgress(value=0, max=20)

IntProgress(value=0, max=190)

Text(value='', layout=Layout(width='100%'))

In [13]:
import pickle as pkl

# Persist the fitted BHC object. The context manager guarantees the file is
# flushed and closed even if pickling fails part-way through.
# NOTE: only unpickle this file from a trusted location; loading a pickle can
# execute arbitrary code.
with open('BHC_SAVED.p', 'wb') as fh:
    pkl.dump(bhc, fh)

  


# Prior distribution plotting

In [None]:
# Shape parameters passed to beta_dist.pdf below.
a, b = 1., 1.

# Evaluate the density on a grid from 0.01 up to (but excluding) 1.0 in
# steps of 0.01, then rescale so the plotted values sum to 1 over the grid.
xs = np.arange(0.01, 1., 0.01)
ys = beta_dist.pdf(xs, a, b)
ys = ys / ys.sum()

f, ax = plt.subplots()
ax.plot(xs, ys)
ax.set_ylim(0., 0.1)
ax.set_xlabel(r'$p$ (Bernoulli parameter)')
y_label = r'$Pr(p | \alpha=%.2f, \beta=%.2f)$' % (a, b)
ax.set_ylabel(y_label)
plt.show()

# Testing

In [None]:
def _genotype_row(origin, row, n_snps=50):
    """Return one sample's leading columns as a single-row int8 array.

    Parameters
    ----------
    origin : str
        Value matched against ``data['Geographic origin']``.
    row : int
        Positional index of the sample among the rows matching ``origin``.
    n_snps : int, optional
        Number of leading columns of ``data`` to keep (default 50).

    Returns
    -------
    numpy.ndarray
        2-D array with one row, cast to ``np.int8``.
    """
    matching = data[data['Geographic origin'] == origin]
    return matching.values[row, :n_snps].reshape(1, -1).astype(np.int8)


# Experiment 1 (disabled): repeat the same two merges 50 times, one
# Japan+Japan and one Japan+Central African Republic, and histogram
# exp(log_rk) of each merged node across the repetitions.
# The analysis cell that follows reads `results` and `lps_JPN` from this block.
if 0:
    results = []
    samples_JPN = []
    lps_JPN = []
    samples_MIX = []
    lps_MIX = []

    # Fixed sample choices. A randomised alternative that was tried:
    #   c_i = np.random.randint(0, 10)
    #   j_i1, j_i2 = np.random.choice(10, replace=False, size=2)
    c_i = 0
    j_i1, j_i2 = 0, 1

    # The input rows do not change between repetitions, so extract them once
    # instead of re-filtering the DataFrame on every iteration.
    d_CAR = _genotype_row('Central African Republic', c_i)
    d_JPN = _genotype_row('Japan', j_i1)
    d_JPN2 = _genotype_row('Japan', j_i2)

    for _ in tqdm(range(50)):
        # Alternative models tried here:
        #   BetaBinomial(alpha=1., beta=1.), BinomialMLE()
        model = BetaPriorMC(alpha=3., beta=3.)
        crp_alpha = 1.

        # A fresh model and fresh leaves are built on every repetition.
        n_CAR = Leaf(d_CAR, crp_alpha, model)
        n_JPN = Leaf(d_JPN, crp_alpha, model)
        n_JPN2 = Leaf(d_JPN2, crp_alpha, model)

        N_JPN = Node.merge(n_JPN, n_JPN2)
        N_MIX = Node.merge(n_JPN, n_CAR)

        results.append((np.exp(N_JPN.log_rk), np.exp(N_MIX.log_rk)))
        samples_JPN.append(N_JPN.samples)
        samples_MIX.append(N_MIX.samples)
        lps_JPN.append(N_JPN.log_pr_data_h1)
        lps_MIX.append(N_MIX.log_pr_data_h1)

    results = np.array(results)
    h1, b1 = np.histogram(results[:, 0], bins=np.linspace(0, 1, 10))
    h2, b2 = np.histogram(results[:, 1], bins=np.linspace(0, 1, 10))

    f, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8), sharey=True)
    # Draw every bar across its own bin. With the default width of 0.8 the
    # bars would be far wider than the bins and overlap each other.
    ax1.bar(b1[:-1], h1, width=np.diff(b1), align='edge')
    ax2.bar(b2[:-1], h2, width=np.diff(b2), align='edge')
    ax1.set_title('JPN')
    ax2.set_title('MIX')
    plt.show()


# Experiment 2 (disabled): run the same two merges once and print the
# per-leaf and per-node quantities side by side.
if 0:
    # Alternative models tried here:
    #   BetaBinomial(alpha=1., beta=1.), BinomialMLE()
    model = BetaPriorMC(alpha=3., beta=3.)
    crp_alpha = 1.

    # Fixed sample choices. A randomised alternative that was tried:
    #   c_i = np.random.randint(0, 10)
    #   j_i1, j_i2 = np.random.choice(10, replace=False, size=2)
    c_i = 0
    j_i1, j_i2 = 0, 1

    d_CAR = _genotype_row('Central African Republic', c_i)
    n_CAR = Leaf(d_CAR, crp_alpha, model)

    d_JPN = _genotype_row('Japan', j_i1)
    n_JPN = Leaf(d_JPN, crp_alpha, model)

    d_JPN2 = _genotype_row('Japan', j_i2)
    n_JPN2 = Leaf(d_JPN2, crp_alpha, model)

    LL_CAR = n_CAR.log_pr_data_tk
    LL_JPN = n_JPN.log_pr_data_tk
    LL_JPN2 = n_JPN2.log_pr_data_tk

    N_JPN = Node.merge(n_JPN, n_JPN2)
    N_MIX = Node.merge(n_JPN, n_CAR)

    # All printed values except the last line of each group are read from
    # `log_*` attributes; the last line is exp(log_rk).
    print('JAPAN NODE')
    print('P(D_i | H_1^i):', LL_JPN, LL_JPN2)
    print('P(D_k | H_1^k):', N_JPN.log_pr_data_h1)
    print('P(D_k | T_k):  ', N_JPN.log_pr_data_tk)
    print('P(H_1^k | D_k):', np.exp(N_JPN.log_rk))

    print('')
    print('MIXED NODE')
    print('P(D_i | H_1^i):', LL_JPN, LL_CAR)
    print('P(D_k | H_1^k):', N_MIX.log_pr_data_h1)
    print('P(D_k | T_k):  ', N_MIX.log_pr_data_tk)
    print('P(H_1^k | D_k):', np.exp(N_MIX.log_rk))

In [None]:
# NOTE: this cell reads `results` and `lps_JPN`, which are only created when
# the repeated-merge experiment in the "Testing" cell is enabled (its `if 0:`
# guard switched on). On a fresh kernel with that experiment disabled, this
# cell raises NameError.

# Threshold on the first column of `results` separating "good" from "bad" runs.
split = 0.6

# flatnonzero always yields a 1-D index array, so the selections below stay
# 1-D even when exactly one run falls on a given side of the threshold.
idx_g = np.flatnonzero(results[:, 0] > split)
idx_b = np.flatnonzero(results[:, 0] <= split)

LP_good = np.array(lps_JPN)[idx_g]
RK_good = results[idx_g, 0]
LP_bad = np.array(lps_JPN)[idx_b]
RK_bad = results[idx_b, 0]

# Medians of each quantity, taken independently. np.median works for any
# number of runs, not just the 50 assumed by fixed positions 24 and 25.
RK_med = np.median(results[:, 0])
LP_med = np.median(lps_JPN)

plt.scatter(LP_good, RK_good, label='good')
plt.scatter(LP_bad, RK_bad, label='bad')
plt.scatter([LP_med], [RK_med], label='median')
plt.xlabel('log_pr_data_h1 (JPN merge)')
plt.ylabel('exp(log_rk) (JPN merge)')
plt.legend()
plt.show()
    