# dnadna: Deep Neural Architecture for DNA
This notebook reproduces an example of population size history inference performed by SPIDNA, the deep learning method described in the paper ["Deep learning for population size history inference: design, comparison and combination with approximate Bayesian computation"](https://www.biorxiv.org/content/10.1101/2020.01.20.910539v1.full.pdf) (Sanchez et al.). It uses the dnadna package, which can be installed by following the instructions [here](https://mlgenetics.gitlab.io/dnadna/introduction.html#installation).

We will simulate SNP data for six scenarios whose population size histories are defined by hand (e.g. expansion, decline or bottleneck), and use a pretrained version of SPIDNA to reconstruct these histories. This network was trained on data simulated with msprime under the prior described in the [methods section](https://www.biorxiv.org/content/10.1101/2020.01.20.910539v1.full.pdf#page=9) of Sanchez et al. Therefore, using it to infer population size histories far from this prior may lead to large prediction errors.

## Import

In [None]:
import numpy as np
import pandas as pd
import os
import msprime
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from pathlib import Path

## Simulate scenario
First, we use msprime (Kelleher et al., 2016) to simulate six different scenarios, each defined over 21 time windows.

In [None]:
def simulate_scenario(population_size, population_time, seed, num_replicates, mutation_rate, recombination_rate,
                      segment_length, num_sample):
    """Simulate SNP data under a piecewise-constant population size history with msprime.

    Parameters
    ----------
    population_size : sequence of float
        Effective population size of each time window. ``population_size[0]``
        is the size at the present (time 0), and ``population_size[i]`` applies
        from ``population_time[i]`` backwards in time.
    population_time : sequence of float
        Start time (in generations) of each window. ``population_time[0]`` is
        not used, because the first window always starts at the present.
    seed : int
        Random seed passed to ``msprime.simulate``.
    num_replicates : int
        Number of independent replicates to simulate.
    mutation_rate, recombination_rate : float
        Rates passed unchanged to ``msprime.simulate``.
    segment_length : float
        Length of the simulated sequence.
    num_sample : int
        Number of sampled chromosomes (msprime ``sample_size``).

    Returns
    -------
    list of numpy.ndarray
        One array per replicate, of shape ``(1 + num_sample, num_snps)``:
        - Row 0 holds the distance from each SNP to the previous one (the first
          SNP is measured from 0), truncated to int.
        - The remaining rows hold the genotype (allele index) of each sample
          at each SNP.

    Raises
    ------
    ValueError
        If ``population_size`` has fewer entries than ``population_time``.
    """
    if len(population_size) < len(population_time):
        raise ValueError(
            f'population_size has {len(population_size)} values but '
            f'population_time defines {len(population_time)} windows')

    # Size changes at the start of every window except the first one; the first
    # window's size is the initial size of the single population below.
    demographic_events = [
        msprime.PopulationParametersChange(time=time, growth_rate=0, initial_size=size)
        for time, size in zip(population_time[1:], population_size[1:len(population_time)])
    ]

    population_configurations = [msprime.PopulationConfiguration(
                sample_size=num_sample,
                initial_size=population_size[0])]

    replicates = msprime.simulate(
                length=segment_length,
                population_configurations=population_configurations,
                demographic_events=demographic_events,
                recombination_rate=recombination_rate,
                mutation_rate=mutation_rate,
                num_replicates=num_replicates,
                random_seed=seed)

    data = []
    for tree_sequence in replicates:
        # Read site positions directly instead of decoding every variant's
        # genotypes here and again in genotype_matrix() below.
        positions = np.array([site.position for site in tree_sequence.sites()])
        # Distance to the previous SNP (the first SNP is measured from 0), truncated to int
        distances = np.diff(positions, prepend=0).astype(int)
        # genotype_matrix() is (num_sites, num_samples); transpose to one row per sample
        genotypes = tree_sequence.genotype_matrix().T.astype(np.uint8)
        data.append(np.vstack([distances, genotypes]))
    return data

In [None]:
# Population sizes are defined on a log10 scale, one value per time window. Index 0 is
# the most recent window (starting at time 0); the last value applies from tmax onwards.
scenarios = {'Medium': [3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7, 3.7],   
             'Large': 4.7 * np.ones(shape=21, dtype='float'), 
             'Decline': [2.5, 2.5, 3, 3, 3, 3, 3.2, 3.4, 3.6, 3.8, 4, 4.2, 4.6, 4.6, 4.6, 4.6, 4.6, 4.6, 4.6, 4.6, 4.6], 
             'Expansion': [4.7, 4.7, 4.7, 4.6, 4.6, 4.5, 4.4, 4.3, 4, 3.7, 3.4, 3.4, 3.4, 3.4, 3.4, 3.4, 3.4, 3.4, 3.4, 3.4, 3.4], 
             'Bottleneck': [4.8, 4.8, 4.8, 4.8, 4.8, 4.8, 4.8, 4.8, 4.8, 4.5, 4.15, 3.8, 4.3, 4.8, 4.55, 4.3, 4.05, 3.8, 3.8, 3.8, 3.8], 
             'Zigzag': [4.8, 4.8, 4.8, 4.5, 4.15, 3.8, 4.15, 4.5, 4.8, 4.5, 4.15, 3.8, 4.3, 4.8, 4.55, 4.3, 4.05, 3.8, 3.8, 3.8, 3.8]}
# Convert from log10 to actual effective population sizes
scenarios = {name: 10**np.array(log_sizes) for name, log_sizes in scenarios.items()}
seed = 2
num_replicates = 100
mutation_rate = 1e-8
segment_length = 2e6
time_rate = 0.06
tmax = 130000
num_time_windows = 21
num_sample = 50
# Log-spaced window start times:
#   t_i = (exp(log(1 + time_rate * tmax) * i / (n - 1)) - 1) / time_rate,  n = num_time_windows,
# so that t_0 = 0 and t_{n-1} = tmax.
population_time = [(np.exp(np.log(1 + time_rate * tmax) * i /
                  (num_time_windows - 1)) - 1) / time_rate for i in
                  range(num_time_windows)]
# Seeded generator so the recombination rates (and hence the simulated data) are
# reproducible; drawing from the unseeded global np.random changed them on every run.
rng = np.random.default_rng(seed)
snp_data = {}
for k, population_size in scenarios.items():
    print(f'Simulating scenario \"{k}\"')
    # One recombination rate per scenario, drawn uniformly in [1e-9, 1e-8)
    recombination_rate = rng.uniform(low=1e-9, high=1e-8)
    # NOTE: the same msprime seed is reused for every scenario
    snp_data[k] = simulate_scenario(population_size, population_time, seed, num_replicates, mutation_rate,
                                    recombination_rate, segment_length, num_sample)

## Save data
We save the data into ```.npz``` files with two keys, ```SNP``` and ```POS```. ```SNP``` holds the SNP matrix, and ```POS``` holds its associated position vector, stored as distances between consecutive SNPs.

In [None]:
# Write each replicate to data/<scenario>/<scenario>_<i>.npz and list every written
# file in `file_list`, which the `dnadna predict` command below reads.
data_path = 'data'
Path(data_path).mkdir(parents=True, exist_ok=True)
file_list = []
for scenario, replicates in snp_data.items():
    scenario_dir = Path(data_path) / scenario
    scenario_dir.mkdir(parents=True, exist_ok=True)
    for i, replicate in enumerate(replicates):
        npz_save_path = str(scenario_dir / f'{scenario}_{i}.npz')
        # Row 0 of each simulated replicate holds the positions; the other rows form the SNP matrix
        np.savez(npz_save_path, POS=replicate[0, :], SNP=replicate[1:, :])
        file_list.append(npz_save_path)
# Overwrite rather than append: with 'a', re-running this cell duplicated every entry,
# and without a trailing newline the last path of one run was glued to the first path
# of the next, producing a non-existent file name.
with open('file_list', 'w') as f:
    f.write('\n'.join(file_list) + '\n')

## Inference with dnadna package

Run the _predict_ command on the files listed in `file_list`. The _--progress-bar_ argument displays a progress bar during prediction:

In [None]:
# Predict with the pretrained network on every path listed in `file_list`; predictions are
# written to result.csv, which the plotting cell reads.
# NOTE(review): assumes the pretrained_SPIDNA/ directory is present next to this notebook.
! dnadna predict pretrained_SPIDNA/pretrained_SPIDNA_net.pth $(cat file_list) --progress-bar -o result.csv

You can also use Bash wildcards to specify input paths in dnadna commands, for example:

In [None]:
# Same prediction as above, but the input files are selected with a shell glob over the
# data/<scenario>/ directories; output goes to a separate file, result_bis.csv.
! dnadna predict pretrained_SPIDNA/pretrained_SPIDNA_net.pth data/*/*.npz -o result_bis.csv --progress-bar

## Generate Plot

In [None]:
# Load predictions. Each row's scenario is the name of the directory holding its .npz file
# (data/<scenario>/<scenario>_<i>.npz). Select the 'path' column by name: positional
# `row[0]` access on a label-indexed row is deprecated in pandas, and Path handles the
# OS-specific separator.
predictions = pd.read_csv('result.csv')
predictions['scenario_name'] = predictions['path'].map(lambda p: Path(p).parent.name)
predictions = predictions.drop(columns=['path'])
# The first len(population_time) remaining columns are taken to be the predicted sizes,
# one per time window (the number of windows is no longer hard-coded as 21).
num_windows = len(population_time)

sns.set(style="ticks", font_scale=1.2)
fig = plt.figure(dpi=300, figsize=(15/1.2, 15/1.2))

# Add one step to represent the infinite time
population_time1 = np.append(population_time, population_time[-1]*1.5)

# Each box is centred on its time window [t_i, t_{i+1}] and is half as wide as the window.
# The geometry is the same for every scenario, so it is computed once.
boxplot_widths = np.diff(population_time1) / 2
boxplot_positions = population_time1[:-1] + boxplot_widths

axs = []
for plot_idx, (key, sizes) in enumerate(scenarios.items(), start=1):
    ax = fig.add_subplot(3, 2, plot_idx)
    axs.append(ax)

    # True history as a dashed step line; the last size is repeated over the extra step
    true_sizes = np.append(np.asarray(sizes), sizes[-1])
    sns.lineplot(x=population_time1, y=true_sizes, drawstyle='steps-post', color=(0, 0, 0), ax=ax)
    ax.lines[0].set_linestyle("--")

    # One box per time window, summarising all replicates of this scenario
    scenario_preds = predictions.loc[predictions['scenario_name'] == key].iloc[:, :num_windows]
    ax.boxplot([column for _, column in scenario_preds.items()], positions=boxplot_positions,
               widths=boxplot_widths, sym='k.')

    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.set_ylim(80, 5e5)
    ax.set_yscale('log')
    ax.set_xscale('log')
    ax.set_title(key)

fig.text(0.5, 0.04, 'Generations before present (log scale)', ha='center')
fig.text(0.04, 0.5, 'Effective population size (log scale)', va='center', rotation='vertical')
# In the 3x2 grid: hide the x axis everywhere except the bottom row,
# and the y axis on the right-hand column.
for ax in axs[:4]:
    ax.get_xaxis().set_visible(False)
for ax in axs[1::2]:
    ax.get_yaxis().set_visible(False)
for ax in axs:
    ax.margins(x=0)

plt.show()