In [4]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import gurobipy as gp
from gurobipy import GRB
from ast import literal_eval
import scipy
import json
import tqdm

# Moment Bootstrap

As an alternative to (or in addition to) bootstrapping confidence intervals for probabilities, we can bootstrap confidence intervals for moments, e.g.

$$ \mathbb{E}[X_{1}^{OB}] \quad \mathbb{E}[X_{2}^{OB}] \quad \mathbb{E}[X_{1}^{OB}X_{2}^{OB}] \quad \cdots $$

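These can then be rescaled by the capture efficiency to relate them to the original ($OG$) counts. For binomial capture with a constant efficiency $\beta$, $\mathbb{E}[X_{i}^{OB}] = \beta\,\mathbb{E}[X_{i}^{OG}]$ and $\mathbb{E}[X_{1}^{OB}X_{2}^{OB}] = \beta^{2}\,\mathbb{E}[X_{1}^{OG}X_{2}^{OG}]$. The resulting intervals can be used to form constraints on the probabilities (and hence on the CME), e.g.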

$$ \mathbb{E}[X_{1}^{OG}] = \sum_{x_{1}^{OG}} x_{1}^{OG} p_{1}(x_{1}^{OG}) \in \text{CI} $$

## Code: Simulation

In [5]:
def gillespie(params, n, beta, tmax=100, ts=10, plot=False, initial_state=(0, 0), seed=None):
    '''
    Simulate a sample path of birth-death regulation model.

    Gillespie algorithm to simulate a sample path of the Markov chain described
    by the birth-death regulation stochastic reaction network model with given
    parameters. After a burn-in time of 'tmax' samples are taken from the sample
    path at time intervals of 'ts'. The states / samples are pairs of counts
    (x1, x2) from a pair of genes. Each sampled count is then thinned by
    binomial capture: x_OB ~ Binomial(x_OG, beta).

    Reactions (propensity in state (x1, x2) -> new state):
        transcription gene 1:  k_tx_1            -> (x1 + 1, x2)
        transcription gene 2:  k_tx_2            -> (x1, x2 + 1)
        degradation gene 1:    k_deg_1 * x1      -> (x1 - 1, x2)
        degradation gene 2:    k_deg_2 * x2      -> (x1, x2 - 1)
        regulation:            k_reg * x1 * x2   -> (x1 - 1, x2 - 1)
    If every propensity is zero the state is absorbing and is held constant
    for the rest of the simulation.

    Args:
        params: dict of reaction rate constants 'k_tx_1', 'k_tx_2', 'k_deg_1',
                'k_deg_2', 'k_reg'
        n: sample size
        beta: per cell capture efficiency vector of size n / single value
        tmax: burn-in time of simulation
        ts: time between samples
        plot: toggle plotting of sample path
        initial_state: starting state (x1, x2) of simulation
        seed: seed (or numpy Generator) passed to numpy.random.default_rng and
              used for both the simulation and the capture step; None draws
              fresh entropy

    Returns:
        A dictionary containing results

        Samples without capture efficiency

        'x1_OG': n samples from gene 1
        'x2_OG': n samples from gene 2
        'OG': n pairs of samples

        Samples with capture efficiency

        'x1_OB': n samples from gene 1 affected by capture efficiency
        'x2_OB': n samples from gene 2 affected by capture efficiency
        'OB': n pairs of samples affected by capture efficiency
    '''

    # a single generator drives every random draw so a seed reproduces the output
    rng = np.random.default_rng(seed)

    # simulate for the burn-in time plus the (n - 1) gaps between samples
    t_end = tmax + (n - 1) * ts

    # state changes of the five reactions, in the same order as the rates below
    reaction_jumps = ((1, 0), (0, 1), (-1, 0), (0, -1), (-1, -1))

    t = 0.0
    path = [tuple(initial_state)]
    jump_times = [0.0]

    while t < t_end:
        x1, x2 = path[-1]

        # reaction propensities in the current state
        rates = np.array([
            params['k_tx_1'],
            params['k_tx_2'],
            x1 * params['k_deg_1'],
            x2 * params['k_deg_2'],
            x1 * x2 * params['k_reg'],
        ], dtype=float)
        q_hold = rates.sum()

        # absorbing state: nothing can fire, so the state is constant until the
        # end (dividing by q_hold here would produce NaN jump probabilities)
        if q_hold <= 0:
            jump_times.append(t_end)
            path.append((x1, x2))
            break

        # exponential holding time; avoids -log(0) when a uniform draw is 0
        t += rng.exponential(1.0 / q_hold)
        jump_times.append(t)

        # pick the reaction with probability proportional to its propensity
        d1, d2 = reaction_jumps[rng.choice(len(reaction_jumps), p=rates / q_hold)]
        path.append((x1 + d1, x2 + d2))

    jump_times = np.asarray(jump_times)
    states = np.asarray(path, dtype=np.int64).reshape(-1, 2)

    def state_at(times):
        # step function of the sample path: state entered at the last jump <= time
        return states[np.searchsorted(jump_times, times, side='right') - 1]

    # take values at sampling times as samples from stationary dist
    sampled = state_at(tmax + ts * np.arange(n))
    x1_samples = sampled[:, 0].tolist()
    x2_samples = sampled[:, 1].tolist()

    # apply capture efficiency: for each count, draw from Binomial(count, beta)
    x1_samples_beta = rng.binomial(x1_samples, beta).tolist()
    x2_samples_beta = rng.binomial(x2_samples, beta).tolist()

    # re-combine to pairs of samples
    samples = list(zip(x1_samples, x2_samples))
    samples_beta = list(zip(x1_samples_beta, x2_samples_beta))

    # plot sample paths
    if plot:
        grid = np.linspace(0, t_end, 10000)
        grid_states = state_at(grid)
        plt.plot(grid, grid_states[:, 0], label="X1 sample path", color="blue")
        plt.plot(grid, grid_states[:, 1], label="X2 sample path", color="purple")
        plt.xlabel("Time")
        plt.ylabel("Counts")
        plt.legend()
        plt.show()

    # collect all sample paths: original and observed
    data = {
        'x1_OG': x1_samples,
        'x2_OG': x2_samples,
        'OG': samples,
        'x1_OB': x1_samples_beta,
        'x2_OB': x2_samples_beta,
        'OB': samples_beta
    }

    return data

## Code: Bootstrap 

### Probabilities

In [214]:
def bootstrap_probabilities(data, resamples=None, splits=1, thresh_OB=10, threshM_OB=10, plot=False, printing=False, seed=None):
    '''
    Compute confidence intervals on the distribution of a sample of count pairs.

    Compute 95% percentile-bootstrap confidence intervals for the joint and
    marginal probabilities of the observed sample. States observed fewer than a
    threshold number of times in the original sample have their interval
    replaced by [0, 1] to improve coverage; the states meeting the threshold
    define a box (joint) / interval (marginal) state space truncation.

    Args:
        data: dict of observed counts per cell; data['OB'] is the list of
              (x1, x2) integer count pairs that is resampled (e.g. the output
              of `gillespie`)
        resamples: number of bootstrap resamples (default: the sample size)
        splits: number of chunks the resampling is split into to limit peak
                memory; resamples need not be divisible by splits
        thresh_OB: minimum number of observations of a state pair for its joint
                   CI to be kept and for it to enter the joint truncation
        threshM_OB: minimum number of observations of a single state for its
                    marginal CI to be kept and for it to enter the marginal
                    truncation
        plot: toggle plotting of bootstrap distributions and confidence intervals
        printing: toggle printing of observed state space truncation
        seed: seed passed to numpy.random.default_rng (None: fresh entropy)

    Returns:
        A dictionary containing results

        Sample information:

        'data': the input data dict
        'sample_counts': occurrences of each state pair in the original sample
        'sample_counts_x1': occurrences of each state in the original sample (gene 1)
        'sample_counts_x2': occurrences of each state in the original sample (gene 2)

        Confidence intervals:

        'joint': (2, M + 1, N + 1) numpy array of CI bounds on joint distribution
        'x1': (2, M + 1) numpy array of CI bounds on marginal distribution (gene 1)
        'x2': (2, N + 1) numpy array of CI bounds on marginal distribution (gene 2)

        Truncation information

        'min_x1_OB', 'max_x1_OB', 'min_x2_OB', 'max_x2_OB': joint truncation
        'minM_x1_OB', 'maxM_x1_OB': marginal truncation (gene 1)
        'minM_x2_OB', 'maxM_x2_OB': marginal truncation (gene 2)
        'thresh_flag': True if at least one state pair met thresh_OB (otherwise
                       the joint truncation defaults to the full observed range)
        'thresh_flag_x1': same for the gene 1 marginal and threshM_OB
        'thresh_flag_x2': same for the gene 2 marginal and threshM_OB

    Raises:
        ValueError: if data['OB'] is empty or splits < 1
    '''

    # observed pairs as an (n, 2) integer array
    sample = np.asarray(data['OB'], dtype=np.int64).reshape(-1, 2)
    n = len(sample)
    if n == 0:
        raise ValueError("data['OB'] is empty: cannot bootstrap an empty sample")
    if splits < 1:
        raise ValueError("splits must be at least 1")

    # get bootstrap size: default to sample size
    if resamples is None:
        resamples = n

    rng = np.random.default_rng(seed)

    # largest observed x1 / x2 define the (M + 1) x (N + 1) grid of states
    M, N = (int(v) for v in sample.max(axis=0))
    n_states = (M + 1) * (N + 1)

    # map (x1, x2) pairs to integers x2 + (N + 1) * x1 so states can be bincounted
    integer_sample = sample[:, 1] + (N + 1) * sample[:, 0]

    # counts[r, x1, x2]: occurrences of (x1, x2) in bootstrap resample r
    counts = _bootstrap_state_counts(rng, integer_sample, resamples, splits, n_states)
    counts = counts.reshape(resamples, M + 1, N + 1)

    # sum over columns / rows to give counts of each x1 / x2 state
    x1_counts = counts.sum(axis=2)
    x2_counts = counts.sum(axis=1)

    # 2.5% and 97.5% quantiles for each p(x1, x2), p(x1) and p(x2), as probabilities
    levels = [0.025, 0.975]
    bounds = np.quantile(counts, levels, axis=0) / n
    x1_bounds = np.quantile(x1_counts, levels, axis=0) / n
    x2_bounds = np.quantile(x2_counts, levels, axis=0) / n

    # occurrences per (x1, x2), x1 and x2 in the original sample
    sample_counts = np.bincount(integer_sample, minlength=n_states).reshape(M + 1, N + 1)
    x1_sample_counts = sample_counts.sum(axis=1)
    x2_sample_counts = sample_counts.sum(axis=0)

    # replace CIs of rarely observed states by [0, 1] and find the truncations
    joint_extent = _apply_threshold(bounds, sample_counts, thresh_OB)
    x1_extent = _apply_threshold(x1_bounds, x1_sample_counts, threshM_OB)
    x2_extent = _apply_threshold(x2_bounds, x2_sample_counts, threshM_OB)

    thresh_flag = joint_extent is not None
    thresh_flag_x1 = x1_extent is not None
    thresh_flag_x2 = x2_extent is not None

    # if no states were above threshold: default to the full observed range
    (min_x1_OB, max_x1_OB), (min_x2_OB, max_x2_OB) = joint_extent if thresh_flag else ((0, M), (0, N))
    minM_x1_OB, maxM_x1_OB = x1_extent[0] if thresh_flag_x1 else (0, M)
    minM_x2_OB, maxM_x2_OB = x2_extent[0] if thresh_flag_x2 else (0, N)

    # plotting
    if plot:
        _plot_joint_intervals(counts / n, bounds, (min_x1_OB, max_x1_OB, min_x2_OB, max_x2_OB))
        _plot_marginal_intervals(x1_counts / n, x1_bounds, minM_x1_OB, maxM_x1_OB, "X1 Confidence Intervals")
        _plot_marginal_intervals(x2_counts / n, x2_bounds, minM_x2_OB, maxM_x2_OB, "X2 Confidence Intervals")

    # printing
    if printing:
        print(f"Box truncation: [{min_x1_OB}, {max_x1_OB}] x [{min_x2_OB}, {max_x2_OB}]")
        print(f"Marginal x1 truncation: [{minM_x1_OB}, {maxM_x1_OB}]")
        print(f"Marginal x2 truncation: [{minM_x2_OB}, {maxM_x2_OB}]")

    # collect results
    result_dict = {
        'data': data,
        'sample_counts': sample_counts,
        'sample_counts_x1': x1_sample_counts,
        'sample_counts_x2': x2_sample_counts,
        'joint': bounds,
        'x1': x1_bounds,
        'x2': x2_bounds,
        'min_x1_OB': min_x1_OB,
        'max_x1_OB': max_x1_OB,
        'min_x2_OB': min_x2_OB,
        'max_x2_OB': max_x2_OB,
        'minM_x1_OB': minM_x1_OB,
        'maxM_x1_OB': maxM_x1_OB,
        'minM_x2_OB': minM_x2_OB,
        'maxM_x2_OB': maxM_x2_OB,
        'thresh_flag': thresh_flag,
        'thresh_flag_x1': thresh_flag_x1,
        'thresh_flag_x2': thresh_flag_x2
    }

    return result_dict


def _bootstrap_state_counts(rng, integer_sample, resamples, splits, n_states):
    '''Count each encoded state in `resamples` bootstrap resamples, drawn in `splits` chunks.'''
    n = len(integer_sample)
    counts = np.empty((resamples, n_states), dtype=np.uint32)

    # spread resamples over the chunks so every row of `counts` gets filled
    # (integer division alone would leave resamples % splits rows uninitialised)
    base, extra = divmod(resamples, splits)
    start = 0
    for split in range(splits):
        size = base + (1 if split < extra else 0)
        if size == 0:
            continue

        # size x n resampled states
        boot = rng.choice(integer_sample, size=(size, n))

        # offset row i by i * n_states so a single bincount counts each row
        # separately; int64 prevents silent overflow of the offsets
        boot = boot + np.arange(size, dtype=np.int64)[:, None] * n_states

        counts[start:start + size] = np.bincount(boot.ravel(), minlength=size * n_states).reshape(size, n_states)
        start += size

    return counts


def _apply_threshold(bounds, sample_counts, thresh):
    '''Set CIs of states seen < thresh times to [0, 1] in place; return per-axis (min, max) of kept states, or None.'''
    keep = sample_counts >= thresh
    bounds[0][~keep] = 0.0
    bounds[1][~keep] = 1.0

    if not keep.any():
        return None

    extents = []
    for axis in range(keep.ndim):
        others = tuple(a for a in range(keep.ndim) if a != axis)
        along_axis = keep.any(axis=others) if others else keep
        kept = np.flatnonzero(along_axis)
        extents.append((int(kept[0]), int(kept[-1])))
    return extents


def _plot_joint_intervals(freqs, bounds, box):
    '''Histogram bootstrap frequencies of each state pair with CI lines (green inside the truncation box, red outside).'''
    min_x1, max_x1, min_x2, max_x2 = box
    rows, cols = freqs.shape[1], freqs.shape[2]
    # squeeze=False keeps axs 2-D even when M or N is 0
    fig, axs = plt.subplots(rows, cols, figsize=(10, 10), squeeze=False)
    fig.tight_layout()
    for x1 in range(rows):
        for x2 in range(cols):
            inside = (min_x1 <= x1 <= max_x1) and (min_x2 <= x2 <= max_x2)
            color = "green" if inside else "red"
            ax = axs[x1, x2]
            ax.hist(freqs[:, x1, x2])
            ax.set_title(f"p({x1}, {x2})")
            ax.axvline(bounds[0, x1, x2], color=color)
            ax.axvline(bounds[1, x1, x2], color=color)

    plt.suptitle("X1 X2 Confidence Intervals")
    plt.show()


def _plot_marginal_intervals(freqs, bounds, lo, hi, title):
    '''Histogram bootstrap frequencies of each single state with CI lines (green inside [lo, hi], red outside).'''
    n_states = freqs.shape[1]
    # squeeze=False keeps axs indexable when there is only one state
    fig, axs = plt.subplots(1, n_states, figsize=(10, 3), squeeze=False)
    fig.tight_layout()
    for x in range(n_states):
        color = "green" if lo <= x <= hi else "red"
        ax = axs[0, x]
        ax.hist(freqs[:, x])
        ax.set_title(f"p({x})")
        ax.axvline(bounds[0, x], color=color)
        ax.axvline(bounds[1, x], color=color)

    plt.suptitle(title)
    plt.show()

### Moments

In [215]:
def bootstrap_moments(data, resamples=None, seed=None):
    '''
    Compute confidence intervals on the moments of a sample of count pairs.

    Resample the observed cells (pairs (x1, x2) from data['x1_OB'] and
    data['x2_OB']) with replacement and return the 2.5% and 97.5% quantiles
    (percentile bootstrap, 95% CI) of the first moments and the cross moment
    over the resamples.

    Args:
        data: dict with equal-length lists 'x1_OB', 'x2_OB' of observed integer
              counts per cell (e.g. the output of `gillespie`)
        resamples: integer number of bootstrap resamples to use (default: the
                   sample size)
        seed: seed passed to numpy.random.default_rng (None: fresh entropy)

    Returns:
        A dictionary containing results, each a length-2 array (lower, upper)

        'E_x1': CI bounds on E[X1]
        'E_x2': CI bounds on E[X2]
        'E_x1_x2': CI bounds on E[X1X2]

    Raises:
        ValueError: if the sample is empty or the two genes have different lengths
    '''

    x1 = np.asarray(data['x1_OB'], dtype=np.int64)
    x2 = np.asarray(data['x2_OB'], dtype=np.int64)
    if x1.shape != x2.shape:
        raise ValueError("data['x1_OB'] and data['x2_OB'] must have the same length")

    # get sample size
    n = x1.size
    if n == 0:
        raise ValueError("cannot bootstrap moments of an empty sample")

    # get bootstrap size: default to sample size
    if resamples is None:
        resamples = n

    rng = np.random.default_rng(seed)

    # per-cell product x1 * x2, computed once instead of once per resample
    prods = x1 * x2

    # resampled cell indices: row r holds the cells of bootstrap resample r
    idx = rng.integers(0, n, size=(resamples, n))

    # moments of each resample
    e_x1 = x1[idx].mean(axis=1)
    e_x2 = x2[idx].mean(axis=1)
    e_x1_x2 = prods[idx].mean(axis=1)

    # quantiles over resamples
    levels = [0.025, 0.975]
    result_dict = {
        'E_x1': np.quantile(e_x1, levels),
        'E_x2': np.quantile(e_x2, levels),
        'E_x1_x2': np.quantile(e_x1_x2, levels)
    }

    return result_dict

## Testing

In [230]:
# model rate constants (transcription, degradation, regulation), number of
# cells and constant capture efficiency
params = {'k_tx_1': 5, 'k_tx_2': 5, 'k_deg_1': 1, 'k_deg_2': 1, 'k_reg': 2}
n, beta = 1000, 0.1

# simulate n cells: original (OG) and captured (OB) count pairs
data = gillespie(params, n, beta)

In [231]:
# percentile-bootstrap CIs on joint / marginal probabilities (default settings)
probabilities = bootstrap_probabilities(data=data)

In [232]:
# percentile-bootstrap CIs on E[X1], E[X2] and E[X1X2] (default settings)
moments = bootstrap_moments(data=data)

In [233]:
# moment bounds implied by the probability CIs: weight the lower / upper CI of
# each state by its value and sum over states 0..maxM (marginal truncation)
# NOTE(review): the joint sum below also uses the *marginal* cut, so it can
# include joint states whose CI was replaced by the placeholder [0, 1] (seen
# fewer than thresh_OB times); such states likely drive the E[X1X2] upper
# bound of 9.0 printed below. Upper sums over a truncated range are
# heuristics, not guaranteed upper bounds.
cut_x1, cut_x2 = probabilities['maxM_x1_OB'] + 1, probabilities['maxM_x2_OB'] + 1
values_x1 = np.arange(cut_x1)
values_x2 = np.arange(cut_x2)

prob_E_x1 = (probabilities['x1'][:, :cut_x1] * values_x1).sum(axis=1)
prob_E_x2 = (probabilities['x2'][:, :cut_x2] * values_x2).sum(axis=1)

pair_products = np.outer(values_x1, values_x2)
prob_E_x1_x2 = (probabilities['joint'][:, :cut_x1, :cut_x2] * pair_products).sum(axis=(1, 2))

In [235]:
# compare: CIs from the moment bootstrap vs. moment bounds implied by the
# probability CIs (each E[X1]E[X2] line multiplies the lower and upper bounds)
report = [
    "Moment bounds:\n",
    f"E[X1] = ({moments['E_x1'][0]}, {moments['E_x1'][1]})",
    f"E[X2] = ({moments['E_x2'][0]}, {moments['E_x2'][1]})",
    f"E[X1X2] = ({moments['E_x1_x2'][0]}, {moments['E_x1_x2'][1]})",
    f"E[X1]E[X2] = ({moments['E_x1'][0] * moments['E_x2'][0]}, {moments['E_x1'][1] * moments['E_x2'][1]})",
    "\nProbability moment bounds:\n",
    f"E[X1] = ({prob_E_x1[0]}, {prob_E_x1[1]})",
    f"E[X2] = ({prob_E_x2[0]}, {prob_E_x2[1]})",
    f"E[X1X2] = ({prob_E_x1_x2[0]}, {prob_E_x1_x2[1]})",
    f"E[X1]E[X2] = ({prob_E_x1[0] * prob_E_x2[0]}, {prob_E_x1[1] * prob_E_x2[1]})",
]
for line in report:
    print(line)

Moment bounds:

E[X1] = (0.137, 0.188)
E[X2] = (0.122, 0.17302499999999996)
E[X1X2] = (0.007, 0.025)
E[X1]E[X2] = (0.016714, 0.032528699999999994)

Probability moment bounds:

E[X1] = (0.12795, 0.195)
E[X2] = (0.113, 0.17804999999999996)
E[X1X2] = (0.0, 9.0)
E[X1]E[X2] = (0.014458350000000002, 0.034719749999999994)
