In [None]:
import os
import numpy as np
import bz2
import pickle
import zarr
import dask.array as da
import sys
import random
import networkx as nx
import matplotlib.pyplot as plt

In [None]:
# Move into the project checkout (presumably so that `config` and `utils` below
# resolve from there). The location can be overridden with the
# MICROBIOME_EVOLUTION_DIR environment variable; when it is unset the original
# hard-coded local path is used, so existing behaviour is unchanged.
os.chdir(os.environ.get('MICROBIOME_EVOLUTION_DIR',
                        '/Users/Device6/Documents/Research/bgoodlab/microbiome_evolution/'))

In [None]:
import config
from utils import sample_utils, core_gene_utils, diversity_utils, parallel_utils

In [None]:
# Species analysed throughout this notebook; the name is also used to build the
# path of the pairwise-divergence CSV loaded further down.
species_name = 'Bacteroides_vulgatus_57955'
# Project data accessor for this species. Later cells read dh.chromosomes,
# dh.general_mask, dh.gene_names, dh.snp_arr, dh.covered_arr and call
# dh.get_snp_vector(pair).
# NOTE(review): the meaning of mode='QP' is not visible here — confirm in parallel_utils.
dh = parallel_utils.DataHoarder(species_name, mode='QP')

In [None]:
# Chromosome entries kept by dh.general_mask; the pile-up cell indexes this
# again with each pair's coverage array before computing runs.
good_chromo = dh.chromosomes[dh.general_mask]
# Core genes of this species; the pile-up cell compares this element-wise
# against gene names to turn a gene name into an index.
# NOTE(review): presumably sorted by genome position, judging by the helper's name — confirm.
core_genes = core_gene_utils.get_sorted_core_genes(species_name)

In [None]:
# Histogram of between-host pairwise divergences; the clade cutoff used in the
# following cells is chosen by eye from this plot.
div_dir = os.path.join(
    config.analysis_directory,
    'pairwise_divergence',
    'between_hosts',
    '%s.csv' % species_name,
)
div_mat = np.loadtxt(div_dir, delimiter=',')

# Strictly-upper triangle: every pair (i < j) exactly once, no diagonal.
uptri = np.triu_indices(div_mat.shape[0], k=1)
divs = div_mat[uptri]

# Only pairs below 0.03 are drawn.
_ = plt.hist(divs[divs < 0.03], density=True, bins=100)
plt.xlabel('Pairwise synonymous divergence')
plt.ylabel('Density')
# plt.savefig(os.path.join(config.analysis_directory, 'misc', 'B_vulgatus_pairwise.pdf'))

In [None]:
# Samples whose divergence from sample 0 (row 0 of div_mat) is below 0.03.
# very importantly, use this to keep only the main clade of B vulgatus
clade_mask = div_mat[0, :] < 0.03
# Integer positions of the same samples, for random pair sampling.
clade_samples = np.nonzero(clade_mask)[0]

In [None]:
%%time
# pile up plot of the core 4D genome
# For randomly drawn pairs of clade samples, find the gaps between consecutive
# SNPs whose run length exceeds each threshold, and count per site (cumu_runs)
# and per core gene (core_gene_cum) how many pairs have such a gap covering it.
thresholds = [4000, 6000, 8000]
num_pairs = 3000  # number of accepted (non-cousin) pairs to accumulate
cumu_runs = np.zeros([np.sum(dh.general_mask), len(thresholds)])
core_gene_cum = np.zeros([len(core_genes), len(thresholds)])

# random.sample() requires a real sequence on Python 3 (an ndarray is rejected);
# a list of the same elements works on both Python 2 and 3.
clade_sample_list = list(clade_samples)
# Hoisted out of the loops: this mask-indexing does not depend on the pair or
# the event, and was previously redone twice per event.
masked_gene_names = dh.gene_names[dh.general_mask]

# NOTE(review): the loop only advances on accepted pairs, so it never ends if
# no clade pair has divergence >= 0.005.
rep = 0
while rep < num_pairs:
    pair = random.sample(clade_sample_list, 2)
    if div_mat[pair[0], pair[1]] < 0.005: # removing cousin pairs
        continue
    # get the snp data
    snp_vec, coverage_arr = dh.get_snp_vector(pair)
    # get the location in the full array
    snp_to_core = np.nonzero(coverage_arr)[0]
    snp_genome_locs = snp_to_core[np.nonzero(snp_vec)[0]]

    runs = parallel_utils.compute_runs_all_chromosomes(snp_vec, good_chromo[coverage_arr])
    # filter the runs and accumulate
    for i, threshold in enumerate(thresholds):
        is_long_run = runs > threshold
        # run k spans SNP k to SNP k+1, hence the shifted views
        event_starts = snp_genome_locs[:-1][is_long_run]
        event_ends = snp_genome_locs[1:][is_long_run]
        for start, end in zip(event_starts, event_ends):
            cumu_runs[start:end, i] += 1
            gene_start = np.nonzero(core_genes == masked_gene_names[start])[0][0]
            gene_end = np.nonzero(core_genes == masked_gene_names[end])[0][0]
            # NOTE(review): half-open slice — an event that starts and ends in
            # the same gene adds nothing here; confirm this is intended.
            core_gene_cum[gene_start:gene_end, i] += 1
    rep += 1

In [None]:
def count_local_haplotypes(dh, start_idx, end_idx, sample_mask, snp_tol=1):
    """Cluster samples into local haplotype groups within a window of sites.

    Two samples are linked when they differ at no more than ``snp_tol`` sites
    that are covered in both of them. Groups are the connected components of
    the resulting graph, so membership is transitive.

    Parameters
    ----------
    dh : object
        Must expose ``snp_arr`` and ``covered_arr``, both indexed as
        ``[site, sample]``.
    start_idx, end_idx : int
        Half-open site range ``[start_idx, end_idx)`` of the window.
    sample_mask : array-like
        Mask or index selecting the samples (columns) to compare.
    snp_tol : int, optional
        Largest number of differing, co-covered sites at which two samples
        are still linked. Defaults to 1.

    Returns
    -------
    list of list
        One list of sample positions (relative to the selected columns) per
        group, ordered from the largest group to the smallest.
    """
    hap = dh.snp_arr[start_idx:end_idx, sample_mask]
    covered = dh.covered_arr[start_idx:end_idx, sample_mask]
    num_samples = hap.shape[1]
    # perform (dumb) pairwise comparison; the count is symmetric, so each pair
    # is evaluated once and mirrored
    res = np.zeros((num_samples, num_samples))
    for i in range(num_samples):
        for j in range(i, num_samples):
            v = hap[:, i] != hap[:, j]
            c = covered[:, i] & covered[:, j]
            res[i, j] = res[j, i] = np.sum(v & c)
    # from_numpy_matrix was removed in NetworkX 3.0; prefer from_numpy_array
    # and fall back to the old name on releases that predate it.
    build_graph = getattr(nx, 'from_numpy_array', None) or nx.from_numpy_matrix
    # honour the caller's tolerance (it was previously hard-coded to 1)
    G = build_graph(res <= snp_tol)
    ordered = sorted(nx.connected_components(G), key=len, reverse=True)
    # a real list (not a lazy map) so callers can index and re-iterate it
    return [list(component) for component in ordered]

In [None]:
%%time
# computing the sliding window analysis
# Slide a window of `window_size` sites along dh.snp_arr in steps of `stride`
# and record, per window, the fraction of clade samples in each local
# haplotype group (largest group first).
import time

window_size = 4000
stride = 1000
num_sites = dh.snp_arr.shape[0]
num_clade_samples = float(np.sum(clade_mask))
all_fracs = []
t0 = time.time()
start = 0
while start < num_sites:
    # the last few windows are truncated at the end of the array
    end = min(num_sites, start + window_size)
    cs = count_local_haplotypes(dh, start, end, clade_mask)
    # explicit list: map() is lazy on Python 3 and np.array() would wrap the
    # iterator object instead of expanding it
    frac = np.array([len(group) for group in cs]) / num_clade_samples
    all_fracs.append(frac)
    start += stride
    if start % 10000 == 0:
        # fixed-point format: the bare '.2' precision meant two significant
        # digits and fell back to exponent notation for values >= 100
        print("{:.2f}% at {:.2f} secs".format(100 * float(start) / num_sites, time.time() - t0))

In [None]:
import pandas as pd

# One row per sliding window. Each entry of all_fracs holds the group
# fractions of one window, largest group first, so fracs[0] is the top group.
# List comprehensions are used instead of map(): on Python 3 map() returns a
# lazy iterator, which cannot be assigned as a DataFrame column.
df = pd.DataFrame()
df['Start'] = np.arange(0, dh.snp_arr.shape[0], stride)
df['End'] = np.minimum(df['Start'] + window_size, dh.snp_arr.shape[0])
df['MaxFrac'] = [max(fracs) for fracs in all_fracs]
# sum of squared group fractions (plotted as "H1" in the figures below)
df['Heterozygosity'] = [np.sum(fracs*fracs) for fracs in all_fracs]
df['MaxFrac^2'] = df['MaxFrac'] ** 2
# H12: as above, but with the two largest groups pooled into one
df['H12'] = [np.sum(fracs*fracs) + 2*fracs[0]*fracs[1] if len(fracs) > 1 else fracs[0]**2
             for fracs in all_fracs]
# H2: as above, but leaving out the largest group
df['H2'] = [np.sum(fracs*fracs) - fracs[0]**2 for fracs in all_fracs]

In [None]:
# for snp density plot
# Per-site flag: true when, for at least one clade sample, both snp_arr and
# covered_arr are set at that site. Clade columns are selected first, then ANDed.
has_snp = np.sum(
    dh.snp_arr[:, clade_mask] & dh.covered_arr[:, clade_mask],
    axis=1,
) > 0

In [None]:
from matplotlib import cm # for color map

# Four stacked panels along the genome coordinate:
#   0: moving-average SNP density, 1: pile-up of long runs per threshold,
#   2: H1 and H12 per sliding window, 3: H2/H1 per sliding window.
fig, axes = plt.subplots(4, 1, figsize=(8, 6), dpi=300)
snp_window = 10000
# mode is spelled exactly 'same': NumPy deprecated case-insensitive matching
axes[0].plot(np.convolve(has_snp, np.ones(snp_window)/snp_window, mode='same'))
for i in range(3):
    # 3000 is the number of sample pairs accumulated in the pile-up cell
    axes[1].plot(cumu_runs[:, i] / 3000, label=thresholds[i])
# window midpoints, computed once and shared by the panels below
window_mid = (df['Start'] + df['End']) / 2
locs = np.array(window_mid).astype(int)
axes[2].plot(window_mid, df['Heterozygosity'], '.', markersize=1, label='Heterozygosity (H1)')
axes[2].plot(window_mid, df['H12'], '.', markersize=1, label='H12')

axes[3].plot(locs, np.array(df['H2'] / df['Heterozygosity']), label='H2/H1')


# only the bottom panel keeps its x tick labels
for i in range(3):
    axes[i].set_xticklabels([])
axes[1].legend()
axes[2].legend()
axes[0].set_ylabel('SNP density')
axes[1].set_ylabel('Frac of pairs sharing')
axes[3].set_ylabel('H2 / H1')
axes[3].set_xlabel('Core gene 4D location')
plt.tight_layout()

In [None]:
from matplotlib import cm # for color map

# Same four-panel layout as the previous figure, but the H2/H1 panel is
# interpolated to every site and shaded by the pile-up count for the first
# threshold.
fig, axes = plt.subplots(4, 1, figsize=(8, 6), dpi=300)
snp_window = 10000
# mode is spelled exactly 'same': NumPy deprecated case-insensitive matching
axes[0].plot(np.convolve(has_snp, np.ones(snp_window)/snp_window, mode='same'))
for i in range(3):
    # 3000 is the number of sample pairs accumulated in the pile-up cell
    axes[1].plot(cumu_runs[:, i] / 3000, label=thresholds[i])

from scipy.interpolate import interp1d
y = np.array(df['H2'] / df['Heterozygosity'])
x = np.array((df['Start'] + df['End']) / 2).astype(int)
f = interp1d(x, y)
# ndarray min/max instead of the Python builtins, which iterate element by element
xnew = np.arange(x.min(), x.max())
c_val = cumu_runs[xnew, 0]
# normalise to [0, 1] for the colour map; an all-zero pile-up would otherwise
# divide by zero and produce NaN colours
c_peak = c_val.max()
shade = np.abs(c_val / c_peak) if c_peak > 0 else np.zeros_like(c_val)
axes[3].scatter(xnew, f(xnew), s=1, c=cm.Blues(shade), label='H2/H1')

# NOTE(review): the last 4 windows are dropped — presumably the ones truncated
# at the end of the array (window_size / stride = 4); confirm.
axes[2].plot(x[:-4], df['Heterozygosity'][:-4], '.', markersize=1, label='Heterozygosity (H1)')
axes[2].plot(x[:-4], df['H12'][:-4], '.', markersize=1, label='H12')

# only the bottom panel keeps its x tick labels
for i in range(3):
    axes[i].set_xticklabels([])
axes[1].legend()
axes[2].legend()
axes[0].set_ylabel('SNP density')
axes[1].set_ylabel('Frac of pairs sharing')
axes[3].set_ylabel('H2 / H1')
axes[3].set_xlabel('Core gene 4D location')
plt.tight_layout()

In [None]:
# Stand-alone, wide version of the pile-up panel with a fixed y-range.
pileup_fig, pileup_ax = plt.subplots(1, 1, figsize=(8,2), dpi=300)
pileup_ax.set_ylim([0, 0.07])
for i in range(3):
    # 3000 is the number of sample pairs accumulated in the pile-up cell
    pileup_ax.plot(cumu_runs[:, i] / 3000, label=thresholds[i])
# The curves were given labels but no legend was drawn; add it, plus the axis
# labels the multi-panel figures use for the same data.
pileup_ax.legend()
pileup_ax.set_ylabel('Frac of pairs sharing')
pileup_ax.set_xlabel('Core gene 4D location')