In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import gurobipy as gp
from gurobipy import GRB
import scipy
from ast import literal_eval
import json
import tqdm

In [2]:
# NumPy random generator with a fixed seed, so any sampling drawn from `rng` is repeatable between runs
rng = np.random.default_rng(2343)

## Code: truncation functions

In [21]:
def B(x1_OB, x2_OB, x1_OG, x2_OG, beta):
    '''
    Joint capture probability P(X1_OB, X2_OB | X1_OG, X2_OG, beta).

    Each observed count is modelled as a binomial draw from its original count
    with success probability beta; the result is the product of the two
    binomial probabilities. Only a single (float) beta is handled for now, so
    no averaging over several capture efficiencies takes place.
    '''
    binom_pmf = scipy.stats.binom.pmf

    # probability of each gene's observed count given its original count
    prob_x1 = binom_pmf(x1_OB, x1_OG, beta)
    prob_x2 = binom_pmf(x2_OB, x2_OG, beta)

    return prob_x1 * prob_x2

In [22]:
def BM(x_OB, x_OG, beta):
    '''
    Marginal capture probability P(X_OB | X_OG, beta).

    The observed count is modelled as a binomial draw from the original count
    with success probability beta. Only a single (float) beta is handled for
    now, so no averaging over several capture efficiencies takes place.
    '''
    capture_dist = scipy.stats.binom

    return capture_dist.pmf(x_OB, x_OG, beta)

In [23]:
def findTrunc(x1_OB, x2_OB, beta, thresh_OG):
    '''
    Bounding box of the original states (x1_OG, x2_OG) whose coefficient
    B(x1_OB, x2_OB, x1_OG, x2_OG, beta) is at least thresh_OG.

    The grid of original states is swept one anti-diagonal at a time, starting
    from (x1_OB, x2_OB). Diagonal number `diag` holds the states
    (x1_OB + s, x2_OB + diag - s) for s = 0..diag. The sweep stops at the first
    diagonal without any coefficient above the threshold that comes after at
    least one diagonal with such a coefficient. The loop has no other exit, so
    it only ends once the threshold has been reached somewhere.

    returns: min_x1_OG, max_x1_OG, min_x2_OG, max_x2_OG
    '''

    # running box limits (the lower limits start at +inf, the upper ones at 0)
    lo_x1, hi_x1, lo_x2, hi_x2 = np.inf, 0, np.inf, 0

    # has any diagonal so far contained a coefficient >= thresh?
    found_any = False

    diag = 0
    while True:

        # does the current diagonal contain a coefficient >= thresh?
        found_on_diag = False

        # walk the diagonal from (x1_OB, x2_OB + diag) down to (x1_OB + diag, x2_OB)
        for shift in range(diag + 1):
            x1_OG = x1_OB + shift
            x2_OG = x2_OB + diag - shift

            if B(x1_OB, x2_OB, x1_OG, x2_OG, beta) >= thresh_OG:

                # grow the box to include this state
                lo_x1 = min(lo_x1, x1_OG)
                hi_x1 = max(hi_x1, x1_OG)
                lo_x2 = min(lo_x2, x2_OG)
                hi_x2 = max(hi_x2, x2_OG)

                found_on_diag = True

        if found_on_diag:
            found_any = True
        elif found_any:
            # first empty diagonal after a non-empty one: finished
            break

        diag += 1

    return lo_x1, hi_x1, lo_x2, hi_x2

In [24]:
def findTruncM(x_OB, beta, threshM_OG):
    '''
    Interval of original states x_OG whose coefficient BM(x_OB, x_OG, beta)
    is at least threshM_OG.

    The search starts at x_OG = x_OB and moves upwards one state at a time:
    first until the coefficient reaches the threshold, then until it drops
    below it again. Both loops stop only when the threshold is crossed.

    returns: minM_OG, maxM_OG (both inclusive)
    '''

    def coeff_at(state):
        # coefficient linking the fixed observed count to one original state
        return BM(x_OB, state, beta)

    # first candidate: the observed count itself
    state = x_OB
    weight = coeff_at(state)

    # climb until the coefficient is no longer below the threshold
    while weight < threshM_OG:
        state += 1
        weight = coeff_at(state)

    lower = state

    # keep climbing while the coefficient stays at or above the threshold
    while weight >= threshM_OG:
        state += 1
        weight = coeff_at(state)

    # `state` is now the first state past the interval, so step back one
    upper = state - 1

    return lower, upper

In [25]:
def preComputeTruncation(M, beta, thresh_OG):
    '''
    Build a lookup table of original truncations for pairs of observed counts.

    M: largest observed state (inclusive) for which truncations are computed
    beta: capture efficiency, passed through to findTrunc
    thresh_OG: threshold for truncation

    returns: dict keyed by the string '(x1_OB, x2_OB)' with values
             (min_x1_OG, max_x1_OG, min_x2_OG, max_x2_OG)
    '''
    lookup = {}

    # only pairs with x2_OB <= x1_OB are computed; the mirrored pair is
    # filled in by swapping the roles of the two genes
    for x1_OB in tqdm.tqdm(range(M + 1)):
        for x2_OB in range(x1_OB + 1):

            box = findTrunc(x1_OB, x2_OB, beta, thresh_OG)
            lo_x1, hi_x1, lo_x2, hi_x2 = box

            # entry for the pair as computed
            lookup[f'({x1_OB}, {x2_OB})'] = (lo_x1, hi_x1, lo_x2, hi_x2)

            # entry for the mirrored pair (written second, as before)
            lookup[f'({x2_OB}, {x1_OB})'] = (lo_x2, hi_x2, lo_x1, hi_x1)

    return lookup

In [26]:
def preComputeTruncationM(M, beta, threshM_OG):
    '''
    Compute dict of original truncations for single observed counts

    M: max observed state that truncations are computed for (inclusive)
    beta: capture efficiency, passed through to findTruncM
    threshM_OG: threshold for truncation

    returns: dict keyed by the string f'{x_OB}' with values (minM_OG, maxM_OG)
    '''
    # store in dictionary (lookup table)
    truncations = {}

    # for each observed count 0..M inclusive (same convention as
    # preComputeTruncation); iterating over range(M + 1) rather than the
    # builtin `max`, which is not an integer and cannot be passed to range()
    for x_OB in tqdm.tqdm(range(M + 1)):

        # compute truncation bounds and store them under the string key
        minM_OG, maxM_OG = findTruncM(x_OB, beta, threshM_OG)
        truncations[f'{x_OB}'] = (minM_OG, maxM_OG)

    return truncations

# Perfect information leads to perfect output?

Many of the challenges of our inference arise from the uncertainty and measurement error present in the observed data, and from having to account for them when estimating the distribution values used in the optimization.

However, an equally important part is the optimization step, which raises the question: if we have the exact distribution values, does the optimization produce exact bounds on the true parameter values? In other words: given perfect information, can we obtain perfect output?

Seeing how close the results in this case are to perfect can tell us how important the bootstrap and optimization parts of the inference are.

## Computing perfect information

To test this we need to compute the exact stationary distribution values

To simplify calculations we focus on the case of no interaction $k_{reg} = 0$ and no observation error $\beta = 100\%$

In this case the stationary distribution of each count is a Poisson distribution with parameter equal to transcription rate / degradation rate, and the joint stationary distribution factorises into a product of the marginals.

In [30]:
# reaction rates: both genes identical, no regulation
params = dict(k_tx_1=1, k_tx_2=1, k_deg_1=1, k_deg_2=1, k_reg=0)

# observed states covered for each gene (inclusive limits)
min_x1_OB, max_x1_OB = 0, 3
min_x2_OB, max_x2_OB = 0, 3

# exact marginals: Poisson with mean k_tx / k_deg, evaluated on the covered states
bounds = {
    'x1': scipy.stats.poisson.pmf(
        np.array(list(range(min_x1_OB, max_x1_OB + 1))),
        params['k_tx_1'] / params['k_deg_1'],
    ),
    'x2': scipy.stats.poisson.pmf(
        np.array(list(range(min_x2_OB, max_x2_OB + 1))),
        params['k_tx_2'] / params['k_deg_2'],
    ),
}

# exact joint: outer product of the two marginals (rows index x1, columns index x2)
bounds['joint'] = np.outer(bounds['x1'], bounds['x2'])

# keep the state limits alongside the values
bounds.update(
    min_x1_OB=min_x1_OB,
    max_x1_OB=max_x1_OB,
    min_x2_OB=min_x2_OB,
    max_x2_OB=max_x2_OB,
)

## Code: distribution computation

In [79]:
def computeDist(params, min_x1_OB, max_x1_OB, min_x2_OB, max_x2_OB):
    '''
    Exact marginal and joint stationary distribution values.

    Each marginal is a Poisson pmf with mean k_tx / k_deg evaluated on the
    states min..max (inclusive); the joint is the outer product of the two
    marginals (rows index x1, columns index x2).

    params: dict with keys 'k_tx_1', 'k_deg_1', 'k_tx_2', 'k_deg_2'
    min_x1_OB, max_x1_OB, min_x2_OB, max_x2_OB: inclusive state limits

    returns: dict with the arrays 'x1', 'x2', 'joint' and the four state
             limits stored under their own names
    '''

    # Poisson means
    mean_x1 = params['k_tx_1'] / params['k_deg_1']
    mean_x2 = params['k_tx_2'] / params['k_deg_2']

    # states each marginal is evaluated on
    states_x1 = np.array(list(range(min_x1_OB, max_x1_OB + 1)))
    states_x2 = np.array(list(range(min_x2_OB, max_x2_OB + 1)))

    marginal_x1 = scipy.stats.poisson.pmf(states_x1, mean_x1)
    marginal_x2 = scipy.stats.poisson.pmf(states_x2, mean_x2)

    return {
        'x1': marginal_x1,
        'x2': marginal_x2,
        'joint': np.outer(marginal_x1, marginal_x2),
        'min_x1_OB': min_x1_OB,
        'max_x1_OB': max_x1_OB,
        'min_x2_OB': min_x2_OB,
        'max_x2_OB': max_x2_OB,
    }

## Code: marginal optimization

In [45]:
def optimization_hyp_single(bounds, beta, truncationsM, gene, K=100, silent=True,
                     print_solution=True, print_truncation=True, threshM_OG=10**-6,
                     time_limit=300):
    '''
    Bound the transcription rate k_tx of a single gene by optimization.

    Builds a Gurobi model whose variables are the original-count marginal
    stationary distribution p and the rates k_tx, k_deg (k_deg fixed to 1),
    constrained by
      - p summing to at most 1,
      - for each observed state, sum over x_OG of BM(x_OB, x_OG, beta) * p[x_OG]
        being both >= and <= the supplied value (i.e. equal to it),
      - the stationary birth-death equations for states 0..(largest state - 1),
    and then minimizes and maximizes k_tx.

    bounds: dict with 'min_x{gene}_OB', 'max_x{gene}_OB' (inclusive observed
            state limits) and 'x{gene}', the marginal values for the states
            min..max in that order (the layout computeDist produces)
    beta: capture efficiency passed to BM / findTruncM
    truncationsM: dict f'{x_OB}' -> (min_OG, max_OG); missing entries are
            computed with findTruncM and WRITTEN BACK into this dict
    gene: gene label used to build the keys of `bounds` (e.g. 1 or 2)
    K: upper bound on the rate variables
    silent: switch off Gurobi console logging
    print_solution: print the resulting bounds, statuses and times
    print_truncation: print the observed / original state ranges used
    threshM_OG: threshold used when a truncation has to be computed
    time_limit: Gurobi time limit (seconds) applied to each solve

    returns: dict with keys
        'status': None (never filled in; kept for backward compatibility)
        'k_tx': [min value, min status, max value, max status]; a value is
                None when no objective value could be retrieved
        'k_deg': 1
        'min_time', 'max_time': runtime of the two solves
        'time': runtime of the last solve (the maximization)
    '''

    # create model
    md = gp.Model('birth-death-regulation-capture-efficiency-hyp')

    # set options
    if silent:
        md.Params.LogToConsole = 0

    # set time limit: 5 minute default
    md.Params.TimeLimit = time_limit

    # State space truncations

    # marginal observed truncation
    min_OB = bounds[f'min_x{gene}_OB']
    max_OB = bounds[f'max_x{gene}_OB']
    observed_states = range(min_OB, max_OB + 1)

    # original truncations: find largest original states needed (to define variables)
    overall_min_OG, overall_max_OG = np.inf, 0

    # for each marginal state used
    for x_OB in observed_states:
        key = f'{x_OB}'

        # compute and cache the original truncation if it is not available
        if key not in truncationsM:
            truncationsM[key] = findTruncM(x_OB, beta, threshM_OG)

        min_OG, max_OG = truncationsM[key]

        # update overall min and max
        overall_max_OG = max(overall_max_OG, max_OG)
        overall_min_OG = min(overall_min_OG, min_OG)

    if print_truncation:
        print(f"Observed counts: [{min_OB}, {max_OB}]")
        print(f"Original counts: [{overall_min_OG}, {overall_max_OG}]")

    # variables

    # marginal stationary distributions: original counts (size = largest original state + 1)
    p = md.addMVar(shape=(overall_max_OG + 1), vtype=GRB.CONTINUOUS, name="p", lb=0, ub=1)

    # aggressive presolve
    # NOTE(review): the original note read "to hopefully ensure this" without
    # saying what is meant to be ensured -- confirm the intent
    md.Params.Presolve = 2

    # reaction rate constants
    rate_names = ['k_tx', 'k_deg']
    rates = md.addVars(rate_names, vtype=GRB.CONTINUOUS, lb=0, ub=K, name=rate_names)

    # constraints

    # fix k_deg = 1 for identifiability
    md.addConstr(rates['k_deg'] == 1)

    # distributional constraints
    md.addConstr(p.sum() <= 1, name="Distribution")

    # marginal stationary distribution bounds: for each observed count
    marginal_values = bounds[f'x{gene}']
    for x_OB in observed_states:

        # original truncation: lookup from pre-computed dict
        min_OG, max_OG = truncationsM[f'{x_OB}']

        # sum over truncation range (INCLUSIVE)
        sum_expr = gp.quicksum([BM(x_OB, x_OG, beta) * p[x_OG] for x_OG in range(min_OG, max_OG + 1)])

        # the value array starts at state min_OB, so index relative to it
        # (identical to indexing by x_OB whenever min_OB == 0)
        target = marginal_values[x_OB - min_OB]

        md.addConstr(sum_expr >= target, name=f"B marginal lb {x_OB}")
        md.addConstr(sum_expr <= target, name=f"B marginal ub {x_OB}")

    # CME
    for x_OG in range(overall_max_OG):

        # no inflow from a state below 0: at x_OG == 0 the index x_OG - 1 wraps
        # round to the last entry of p, but the zero factor removes that term
        x_zero = 0 if x_OG == 0 else 1

        md.addConstr(
            rates['k_tx'] * x_zero * p[x_OG - 1] + \
            rates['k_deg'] * (x_OG + 1) * p[x_OG + 1] - \
            (rates['k_tx'] + rates['k_deg'] * x_OG) * p[x_OG] == 0,
            name=f"Marginal CME {x_OG}"
        )

    # status of optimization
    status_codes = {1: 'LOADED',
                    2: 'OPTIMAL',
                    3: 'INFEASIBLE',
                    4: 'INF_OR_UNBD',
                    5: 'UNBOUNDED',
                    6: 'CUTOFF',
                    7: 'ITERATION_LIMIT',
                    8: 'NODE_LIMIT',
                    9: 'TIME_LIMIT',
                    10: 'SOLUTION_LIMIT',
                    11: 'INTERRUPTED',
                    12: 'NUMERIC',
                    13: 'SUBOPTIMAL',
                    14: 'INPROGRESS',
                    15: 'USER_OBJ_LIMIT',
                    16: 'WORK_LIMIT',
                    17: 'MEM_LIMIT'}

    def _solve(sense):
        '''Optimize k_tx in the given sense; return (value or None, status name, runtime).'''
        md.setObjective(rates['k_tx'], sense)
        try:
            md.optimize()
            value = md.ObjVal
        except (gp.GurobiError, AttributeError):
            # solver error, or no objective value available (e.g. infeasible
            # model); anything else, such as KeyboardInterrupt, propagates
            value = None

        # codes missing from the table are reported instead of raising KeyError
        code = md.status
        return value, status_codes.get(code, f'UNKNOWN ({code})'), md.Runtime

    # optimize: minimize, then maximize
    min_val, min_status, min_time = _solve(GRB.MINIMIZE)
    max_val, max_status, max_time = _solve(GRB.MAXIMIZE)

    # solution dict
    solution = {
        'status': None,
        'k_tx': [min_val, min_status, max_val, max_status],
        'k_deg': 1,
        'min_time': min_time,
        'max_time': max_time
    }

    if print_solution:
        print(f"k_tx lower bound: {solution['k_tx'][0]}, status: {solution['k_tx'][1]}, time: {solution['min_time']}")
        print(f"k_tx upper bound: {solution['k_tx'][2]}, status: {solution['k_tx'][3]}, time: {solution['max_time']}")
        print("k_deg: 1")

    # save runtime (of the most recent solve)
    solution['time'] = md.Runtime

    return solution

## Test: marginal optimization, perfect information

In [139]:
# reaction rates: the two genes differ in transcription rate, no interaction
params = dict(k_tx_1=10, k_tx_2=2, k_deg_1=1, k_deg_2=1, k_reg=0)

# capture efficiency of 100%
beta = 1.0

# marginal truncation lookup, initially empty
# (will be trivial i.e. x_OB: (x_OB, x_OB) as 100% capture)
truncationsM = dict()

# exact values ('bounds') on the observed states 0..1 of both genes
bounds = computeDist(params, min_x1_OB=0, max_x1_OB=1, min_x2_OB=0, max_x2_OB=1)

In [None]:
# bound k_tx for gene 1 from its exact marginal values stored in `bounds`
solution_x1 = optimization_hyp_single(bounds, beta, truncationsM, gene=1, K=100, silent=True,
                     print_solution=True, print_truncation=False, threshM_OG=10**-6,
                     time_limit=300)

k_tx lower bound: 10.000000000000002, status: OPTIMAL, time: 0.0
k_tx upper bound: 10.000000000000002, status: OPTIMAL, time: 0.0
k_deg: 1


In [138]:
# bound k_tx for gene 2 from its exact marginal values stored in `bounds`
solution_x2 = optimization_hyp_single(bounds, beta, truncationsM, gene=2, K=100, silent=True,
                     print_solution=True, print_truncation=False, threshM_OG=10**-6,
                     time_limit=300)

k_tx lower bound: 2.0, status: OPTIMAL, time: 0.0
k_tx upper bound: 2.0, status: OPTIMAL, time: 0.0
k_deg: 1
