# Load simulated data

In [1]:
import numpy as np
import torch
from sbi.utils import BoxUniform

from teddy.data.Alphabet import Alphabet
from teddy.data.dataset import MsaLabels

# Lower and upper corners of the uniform box prior, one entry per inferred parameter.
# NOTE(review): presumably ordered as (basic reproduction number, infectious time),
# matching the axis labels used by the plotting cells -- confirm against the simulator.
lower_bound = torch.as_tensor([1.0, 1.0])
upper_bound = torch.as_tensor([5.0, 10.0])
prior = BoxUniform(low=lower_bound, high=upper_bound)

# Nucleotide alphabet handed to MsaLabels whenever a dataset is opened.
alphabet = Alphabet(["A", "C", "G", "T"])

In [None]:
def observe(clock_rate):
    """Return the reference parameters and flattened observation.

    The first entry of the dataset stored under ``data/observation/seq`` is
    used as the observation.

    Args:
        clock_rate: Accepted for signature parity with the other helpers;
            it is not used when locating the observation.

    Returns:
        A ``(theta_0, x_0)`` tuple of tensors, where ``x_0`` is flattened to
        one dimension.
    """
    observed = MsaLabels("data/observation/seq", alphabet)

    true_params = torch.Tensor(observed[0][1])
    flat_observation = torch.Tensor(observed[0][0][0]).flatten()

    return true_params, flat_observation

def load_data(clock_rate):
    """Load the training simulations stored for one clock rate.

    Reads ``data/clock_rate-{clock_rate}/seq`` and skips its first entry;
    every remaining entry contributes one parameter vector and one
    observation.

    Args:
        clock_rate: Value interpolated into the dataset path and the
            progress message.

    Returns:
        A ``(theta, x)`` tuple of tensors; ``x`` is flattened from its
        second dimension onwards so each simulation occupies one row.
    """
    dataset = MsaLabels(f"data/clock_rate-{clock_rate}/seq", alphabet)

    # Single pass over the dataset, dropping the entry at position 0.
    kept = [entry for position, entry in enumerate(dataset) if position != 0]

    params = [entry[1] for entry in kept]
    observations = [entry[0][0] for entry in kept]

    theta = torch.Tensor(np.array(params))
    x = torch.Tensor(np.array(observations)).flatten(start_dim=1)

    print(f"Data for clock rate 1e-{clock_rate} loaded.")

    return theta, x

# NPE training

In [3]:
from sbi.inference import NPE
from sbi.neural_nets import posterior_nn
from sbi.neural_nets.embedding_nets import FCEmbedding

In [4]:
def train_npe(clock_rate, embedding=False):
    """Train a neural posterior estimator on the simulations of one clock rate.

    Args:
        clock_rate: Identifier of the simulated dataset; forwarded to
            ``load_data`` and used in the progress message.
        embedding: If true, observations pass through a fully connected
            embedding network (``FCEmbedding``) in front of a MAF density
            estimator. Otherwise ``NPE`` is built with its defaults.

    Returns:
        The posterior object returned by ``npe.build_posterior()``.
    """
    # Load the data first: the embedding net's input width is read from
    # ``x``. Reading ``x.shape`` before this assignment raised
    # UnboundLocalError whenever ``embedding=True``.
    theta, x = load_data(clock_rate)

    if embedding:
        embedding_net = FCEmbedding(
            input_dim=x.shape[1],
            output_dim=10,
            num_layers=2,
            num_hiddens=50,
        )
        density_estimator = posterior_nn(model="maf", embedding_net=embedding_net)
        npe = NPE(prior=prior, density_estimator=density_estimator)
    else:
        npe = NPE(prior=prior)

    npe.append_simulations(theta, x)
    npe.train()

    print(f"NPE for clock rate 1e-{clock_rate} trained.")

    return npe.build_posterior()

def sample_npe(clock_rate, sample_size=10_000, embedding=False):
    """Train an NPE for ``clock_rate`` and sample its posterior at the observation.

    Args:
        clock_rate: Dataset identifier forwarded to ``observe`` and ``train_npe``.
        sample_size: Number of posterior draws to take.
        embedding: Forwarded to ``train_npe``.

    Returns:
        The samples returned by ``posterior.sample`` conditioned on the
        observation.
    """
    # Only the observation is needed here; the reference parameters are dropped.
    observed_x = observe(clock_rate)[1]
    trained_posterior = train_npe(clock_rate, embedding)
    draws = trained_posterior.sample((sample_size,), x=observed_x)
    print(f"Posterior for clock rate 1e-{clock_rate} sampled.")
    return draws

# NPE plotting

In [5]:
import matplotlib.pyplot as plt
from sbi.analysis import pairplot

In [6]:
def plot_npe(clock_rate, sample_size=10_000, embedding=False):
    """Show a pair plot of the NPE posterior for a single clock rate.

    The reference parameters returned by ``observe`` are overlaid in red.

    Args:
        clock_rate: Dataset identifier; also shown in the figure title.
        sample_size: Number of posterior draws to plot.
        embedding: Forwarded to ``sample_npe``.
    """
    reference_theta = observe(clock_rate)[0]
    posterior_draws = sample_npe(clock_rate, sample_size, embedding)

    axis_labels = ["Basic reproduction number", "Infectious time"]
    _figure, _axes = pairplot(
        posterior_draws,
        points=reference_theta,
        points_colors="r",
        labels=axis_labels,
        figsize=(8, 8),
    )
    plt.suptitle(f"NPE Posterior (clock rate: 1e-{clock_rate})", y=1.02)
    plt.show()

def compare_npe(clock_rates, sample_size=10_000, embedding=False):
    """Overlay the NPE posteriors obtained for several clock rates in one pair plot.

    A separate estimator is trained and sampled for every entry of
    ``clock_rates``; each sample set is labelled ``1e-<clock_rate>`` in the
    figure legend.

    Args:
        clock_rates: Iterable of dataset identifiers.
        sample_size: Number of posterior draws per clock rate.
        embedding: Forwarded to ``sample_npe``.
    """
    # One pass over clock_rates; sampling happens before the label is built.
    runs = [
        (sample_npe(rate, sample_size, embedding), f"1e-{rate}")
        for rate in clock_rates
    ]
    samples = [draws for draws, _ in runs]
    labels = [tag for _, tag in runs]

    axis_labels = ["Basic reproduction number", "Infectious time"]
    fig, _axes = pairplot(
        samples,
        diag="hist",
        upper="scatter",
        points_colors="r",
        labels=axis_labels,
        figsize=(8, 8),
    )
    fig.legend(labels, loc="upper right", bbox_to_anchor=(0.95, 0.95))
    plt.suptitle("NPE Posterior Comparison (various clock rates)", y=1.02, fontsize=14)
    plt.show()

# Sandbox

In [8]:
# Compare the posteriors for the datasets labelled 1e-0 through 1e-6.
compare_npe(list(range(7)))

Observation for clock rate 1e-0 loaded.
Data for clock rate 1e-0 loaded.
 Neural network successfully converged after 32 epochs.NPE for clock rate 1e-0 trained.


  0%|          | 0/10000 [00:00<?, ?it/s]

Posterior for clock rate 1e-0 sampled.
Observation for clock rate 1e-1 loaded.
Data for clock rate 1e-1 loaded.
 Neural network successfully converged after 29 epochs.NPE for clock rate 1e-1 trained.


  0%|          | 0/10000 [00:00<?, ?it/s]

Posterior for clock rate 1e-1 sampled.
Observation for clock rate 1e-2 loaded.
Data for clock rate 1e-2 loaded.
 Neural network successfully converged after 30 epochs.NPE for clock rate 1e-2 trained.


  0%|          | 0/10000 [00:00<?, ?it/s]

Posterior for clock rate 1e-2 sampled.
Observation for clock rate 1e-3 loaded.
Data for clock rate 1e-3 loaded.
 Neural network successfully converged after 27 epochs.NPE for clock rate 1e-3 trained.


  0%|          | 0/10000 [00:00<?, ?it/s]

KeyboardInterrupt: 