# Imports, load data

In [10]:
import numpy as np
import matplotlib.pyplot as plt
import pickle as pkl
from collections import defaultdict

from IPython.display import clear_output, display
from ipywidgets import IntProgress

from src.lai import compute_window_accuracies, assign_ancestries

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
f = open('data/imputations_t50_TMP.p', 'rb')
imputations = pkl.load(f)
f.close()

In [5]:
H_AFR = np.loadtxt('data/H_AFR.txt')
H_AMR = np.loadtxt('data/H_AMR.txt')
H_EAS = np.loadtxt('data/H_EAS.txt')
H_EUR = np.loadtxt('data/H_EUR.txt')

H_valid = np.zeros((0, H_AFR.shape[1]))
for H in [H_AFR, H_AMR, H_EUR, H_EAS]:
    H_valid = np.vstack([H_valid, H[-50:, :]])
    
del H_AFR, H_AMR, H_EAS, H_EUR

# Compute window accuracies

In [7]:
prog = IntProgress(value=0, max=16285)
display(prog)

def hook(col):
    prog.value = col
    
accs = compute_window_accuracies(imputations, H_valid, window_size=10, hooks=[hook])

prog.close()

# e = 206
# imputations['EUR'][:, -e:].sum()



IntProgress(value=0, max=16285)

AFR (200, 16285)
EUR (200, 16285)
AMR (200, 16285)
EAS (200, 16285)


## Save results

In [9]:
save_path = 'data/accuracies_td50.p'

out = dict()
out['model_info'] = imputations['model_info']
out['H_valid_info'] = imputations['H_valid_info']
out['accuracies'] = accs
try:
    f = open(save_path, 'wb')
    pkl.dump(out, f)
except:
    print('Saving accuracies failed.')
finally:
    f.close()

# Assign local ancestries

In [32]:
prog = IntProgress(value=0, max=100)
display(prog)

def hook(i, g_row):
    prog.value = i

_ = assign_ancestries(accs, hooks=[hook])

prog.close()

IntProgress(value=0)

2.6362664967906754e-07 0.0
[2.6362664967906754e-07, 2.6362664967906754e-07, 2.6362664967906754e-07, 2.6362664967906754e-07, 2.6362664967906754e-07, 2.6362664967906754e-07, 2.6362664967906754e-07, 2.6362664967906754e-07, 2.6362664967906754e-07, 2.6362664967906754e-07]

2.0339409297451338e-07 0.0
[2.0339409297451338e-07, 2.0339409297451338e-07, 2.0339409297451338e-07, 2.0339409297451338e-07, 2.0339409297451338e-07, 2.0339409297451338e-07, 2.0339409297451338e-07, 2.0339409297451338e-07, 2.0339409297451338e-07, 2.0339409297451338e-07]

1.8081694292726785e-07 0.0
[1.8081694292726785e-07, 1.8081694292726785e-07, 1.8081694292726785e-07, 1.8081694292726785e-07, 1.8081694292726785e-07, 1.8081694292726785e-07, 1.8081694292726785e-07, 1.8081694292726785e-07, 1.8081694292726785e-07, 1.8081694292726785e-07]

1.318618733765543e-07 0.0
[1.318618733765543e-07, 1.318618733765543e-07, 1.318618733765543e-07, 1.318618733765543e-07, 1.318618733765543e-07, 1.318618733765543e-07, 1.318618733765543e-07, 1.318

KeyboardInterrupt: 

In [28]:


ans_AFR = defaultdict(int)
for row in range(0, 50):
#     ans_AFR.update(set(list(np.unique(ancestries[row]))))
    ans, cts = np.unique(ancestries[row], return_counts=True)
    for a, c in zip(ans, cts):
        ans_AFR[a] += c
    
    print('='*30)
    print('Row %d' % row)
    for (a, c) in zip(ans, cts):
        print(tuple(a), ':', c)
        
    input()
    clear_output()
    

Row 4
('AFR', 'AMR') : 1
('AFR', 'EAS') : 10845
('AFR', 'AMR') : 9
('AFR', 'EAS') : 5430


KeyboardInterrupt: 