In [20]:
from interaction_inference import simulation
import numpy as np
import matplotlib.pyplot as plt
import scipy
import tqdm
from ast import literal_eval
import json
import gurobipy as gp
from gurobipy import GRB

In [19]:
status_codes = {
    1: 'LOADED',
    2: 'OPTIMAL',
    3: 'INFEASIBLE',
    4: 'INF_OR_UNBD',
    5: 'UNBOUNDED',
    6: 'CUTOFF',
    7: 'ITERATION_LIMIT',
    8: 'NODE_LIMIT',
    9: 'TIME_LIMIT',
    10: 'SOLUTION_LIMIT',
    11: 'INTERRUPTED',
    12: 'NUMERIC',
    13: 'SUBOPTIMAL',
    14: 'INPROGRESS',
    15: 'USER_OBJ_LIMIT'
}

In [4]:
rng = np.random.default_rng(7)

## Observed Probability Bootstrap (Existing work)

In [125]:
def bootstrap_probabilities(sample, resamples=None, splits=1, thresh_OB=10, threshM_OB=10, plot=False, printing=False):

    # get sample size
    n = len(sample)

    # get bootstrap size: default to sample size
    if resamples is None:
        resamples = n

    # initialize random generator
    rng = np.random.default_rng()

    # convert string to tuple if neccessary (pandas reading csv to string)
    if type(sample[0]) == str:
        sample = [literal_eval(count_pair) for count_pair in sample]

    # compute maximum x1 and x2 values
    M, N = np.max(sample, axis=0)
    M, N = int(M), int(N)

    # map (x1, x2) pairs to integers: x2 + (N + 1) * x1
    integer_sample = np.array([x[1] + (N + 1)*x[0] for x in sample], dtype='uint32')

    # maxiumum of integer sample
    D = (M + 1)*(N + 1) - 1

    # number of bootstrap samples per split (split to reduce memory usage)
    resamples_split = resamples // splits

    # setup count array
    counts = np.empty((resamples, M + 1, N + 1), dtype='uint32')

    # BS bootstrap samples: split into 'splits' number of BS_split x n arrays
    for split in range(splits):

        # BS_split bootstrap samples as BS_split x n array
        bootstrap_split = rng.choice(integer_sample, size=(resamples_split, n))

        # offset row i by (D + 1)i
        bootstrap_split += np.arange(resamples_split, dtype='uint32')[:, None]*(D + 1)

        # flatten, count occurances of each state and reshape, reversing map to give counts of each (x1, x2) pair
        counts_split = np.bincount(bootstrap_split.ravel(), minlength=resamples_split*(D + 1)).reshape(-1, M + 1, N + 1)

        # add to counts
        counts[(split * resamples_split):((split + 1) * resamples_split), :, :] = counts_split

    # sum over columns / rows to give counts (/n) of each x1 / x2 state
    x1_counts = counts.sum(axis=2)
    x2_counts = counts.sum(axis=1)

    # compute 2.5% and 97.5% quantiles for each p(x1, x2), p(x1) and p(x2)
    bounds = np.quantile(counts, [0.025, 0.975], axis=0)
    x1_bounds = np.quantile(x1_counts, [0.025, 0.975], axis=0)
    x2_bounds = np.quantile(x2_counts, [0.025, 0.975], axis=0)

    # scale to probability
    bounds = bounds / n
    x1_bounds = x1_bounds / n
    x2_bounds = x2_bounds / n

    # count occurances per (x1, x2) in the in original sample
    sample_counts = np.bincount(integer_sample, minlength=D + 1).reshape(M + 1, N + 1)

    # sum over columns / rows to give counts per x1 / x2 state
    x1_sample_counts = sample_counts.sum(axis=1)
    x2_sample_counts = sample_counts.sum(axis=0)

    # set truncation bounds
    min_x1_OB, max_x1_OB, min_x2_OB, max_x2_OB = M, 0, N, 0
    minM_x1_OB, maxM_x1_OB = M, 0
    minM_x2_OB, maxM_x2_OB = N, 0

    # set flag for changes
    thresh_flag = False
    thresh_flag_x1 = False
    thresh_flag_x2 = False

    # replace CI's for states below threshold occurances by [0, 1] bounds
    for x1 in range(M + 1):
        for x2 in range(N + 1):
            # below: replace
            if sample_counts[x1, x2] < thresh_OB:
                bounds[:, x1, x2] = [0.0, 1.0]
            # above: update truncation
            else:
                # check if smaller than current min
                if x1 < min_x1_OB:
                    min_x1_OB = x1
                    thresh_flag = True
                if x2 < min_x2_OB:
                    min_x2_OB = x2
                    thresh_flag = True
                # check if larger than current max
                if x1 > max_x1_OB:
                    max_x1_OB = x1
                    thresh_flag = True
                if x2 > max_x2_OB:
                    max_x2_OB = x2
                    thresh_flag = True

    for x1 in range(M + 1):
        # below: replace
        if x1_sample_counts[x1] < threshM_OB:
            x1_bounds[:, x1] = [0.0, 1.0]
        # above: update truncation
        else:
            # check if smaller than current min
            if x1 < minM_x1_OB:
                minM_x1_OB = x1
                thresh_flag_x1 = True
            # check if larger than current max
            if x1 > maxM_x1_OB:
                maxM_x1_OB = x1
                thresh_flag_x1 = True

    for x2 in range(N + 1):
        # below: replace
        if x2_sample_counts[x2] < threshM_OB:
            x2_bounds[:, x2] = [0.0, 1.0]
        # above: update truncation
        else:
            # check if smaller than current min
            if x2 < minM_x2_OB:
                minM_x2_OB = x2
                thresh_flag_x2 = True
            # check if larger than current max
            if x2 > maxM_x2_OB:
                maxM_x2_OB = x2
                thresh_flag_x2 = True

    # if no states were above threshold: default to max range, report
    if not thresh_flag:
        min_x1_OB, max_x1_OB, min_x2_OB, max_x2_OB = 0, M, 0, N
    if not thresh_flag_x1:
        minM_x1_OB, maxM_x1_OB = 0, M
    if not thresh_flag_x2:
        minM_x2_OB, maxM_x2_OB = 0, N

    # plotting
    if plot:
        fig, axs = plt.subplots(M + 1, N + 1, figsize=(10, 10))
        fig.tight_layout()
        for x1 in range(M + 1):
            for x2 in range(N + 1):
                # within truncation: green CI lines
                if (x1 >= min_x1_OB) and (x2 >= min_x2_OB) and (x1 <= max_x1_OB) and (x2 <= max_x2_OB):
                    color = "green"
                else:
                    color = "red"
                axs[x1, x2].hist(counts[:, x1, x2] / n)
                axs[x1, x2].set_title(f"p({x1}, {x2})")
                axs[x1, x2].axvline(bounds[0, x1, x2], color=color)
                axs[x1, x2].axvline(bounds[1, x1, x2], color=color)

        plt.suptitle("X1 X2 Confidence Intervals")
        plt.show()

        fig, axs = plt.subplots(1, M + 1, figsize=(10, 3))
        fig.tight_layout()
        for x1 in range(M + 1):
            # within truncation: green CI lines
            if (x1 >= minM_x1_OB) and (x1 <= maxM_x1_OB):
                color = "green"
            else:
                color = "red"
            axs[x1].hist(x1_counts[:, x1] / n)
            axs[x1].set_title(f"p({x1})")
            axs[x1].axvline(x1_bounds[0, x1], color=color)
            axs[x1].axvline(x1_bounds[1, x1], color=color)

        plt.suptitle("X1 Confidence Intervals")
        plt.show()

        fig, axs = plt.subplots(1, N + 1, figsize=(10, 3))
        fig.tight_layout()
        for x2 in range(N + 1):
            # within truncation: green CI lines
            if (x2 >= minM_x2_OB) and (x2 <= maxM_x2_OB):
                color = "green"
            else:
                color = "red"
            axs[x2].hist(x2_counts[:, x2] / n)
            axs[x2].set_title(f"p({x2})")
            axs[x2].axvline(x2_bounds[0, x2], color=color)
            axs[x2].axvline(x2_bounds[1, x2], color=color)

        plt.suptitle("X2 Confidence Intervals")
        plt.show()

    # printing
    if printing:
        print(f"Box truncation: [{min_x1_OB}, {max_x1_OB}] x [{min_x2_OB}, {max_x2_OB}]")
        print(f"Marginal x1 truncation: [{minM_x1_OB}, {maxM_x1_OB}]")
        print(f"Marginal x2 truncation: [{minM_x2_OB}, {maxM_x2_OB}]")

    # collect results
    truncation_OB = {
        'min_x1_OB': min_x1_OB,
        'max_x1_OB': max_x1_OB,
        'min_x2_OB': min_x2_OB,
        'max_x2_OB': max_x2_OB
    }
    truncationM_OB = {
        'minM_x1_OB': minM_x1_OB,
        'maxM_x1_OB': maxM_x1_OB,
        'minM_x2_OB': minM_x2_OB,
        'maxM_x2_OB': maxM_x2_OB
    }

    result_dict = {
        'bounds': bounds,
        'x1_bounds': x1_bounds,
        'x2_bounds': x2_bounds,
        'truncation_OB': truncation_OB,
        'truncationM_OB': truncationM_OB
    }

    return result_dict

## $f_{m}$ Bootstrap

In [441]:
def bootstrap_f(sample, beta, thresh_OB=10, threshM_OB=10, resamples=None, printing=False):

    # get sample size
    n = len(sample)

    # get bootstrap size: default to sample size
    if resamples is None:
        resamples = n

    # initialize random generator
    rng = np.random.default_rng()

    # convert string to tuple if neccessary (pandas reading csv to string)
    if type(sample[0]) == str:
        sample = [literal_eval(count_pair) for count_pair in sample]

    # compute maximum x1 and x2 values
    M, N = np.max(sample, axis=0)
    M, N = int(M), int(N)

    # map (x1, x2) pairs to integers: x2 + (N + 1) * x1
    integer_sample = np.array([x[1] + (N + 1)*x[0] for x in sample], dtype='uint32')

    # maxiumum of integer sample
    D = (M + 1)*(N + 1) - 1

    # setup f arrays
    fm1m2 = np.zeros((2, M + 1, N + 1))
    fm1 = np.zeros((2, M + 1))
    fm2 = np.zeros((2, N + 1))

    # loop over states
    for m1 in range(M + 1):
        for m2 in range(N + 1):

            # capture for cells with counts (m1, m2)
            beta_m = beta[(sample == np.array([m1, m2])).sum(axis=1) == 2]

            # if empty
            if beta_m.size == 0:
                
                # store [0, 1] bounds
                fm1m2[:, m1, m2] = [0.0, 1.0]

                # move to next pair
                continue

            # bootstrap resample
            boot = rng.choice(beta_m, size=(resamples, len(beta_m)))

            # estimate E[beta|(m1, m2)]
            estimates = boot.mean(axis=1)

            # quantile for confidence intervals
            interval = np.quantile(estimates, [0.025, 0.975], axis=0)

            # store
            fm1m2[:, m1, m2] = interval

    # marginals
    x1_sample = np.array([x[0] for x in sample])
    x2_sample = np.array([x[1] for x in sample])

    for m1 in range(M + 1):

        # capture for cells with count m1
        beta_m = beta[x1_sample == m1]

        # if empty
        if beta_m.size == 0:
            
            # store [0, 1] bounds
            fm1[:, m1] = [0.0, 1.0]

            # move to next pair
            continue

        # bootstrap resample
        boot = rng.choice(beta_m, size=(resamples, len(beta_m)))

        # estimate E[beta|m]
        estimates = boot.mean(axis=1)

        # quantile for confidence intervals
        interval = np.quantile(estimates, [0.025, 0.975], axis=0)

        # store
        fm1[:, m1] = interval

    for m2 in range(N + 1):

        # capture for cells with count m2
        beta_m = beta[x2_sample == m2]

        # if empty
        if beta_m.size == 0:
            
            # store [0, 1] bounds
            fm2[:, m2] = [0.0, 1.0]

            # move to next pair
            continue

        # bootstrap resample
        boot = rng.choice(beta_m, size=(resamples, len(beta_m)))

        # estimate E[beta|m]
        estimates = boot.mean(axis=1)

        # quantile for confidence intervals
        interval = np.quantile(estimates, [0.025, 0.975], axis=0)

        # store
        fm2[:, m2] = interval

    # count occurances per (x1, x2) in the in original sample
    sample_counts = np.bincount(integer_sample, minlength=D + 1).reshape(M + 1, N + 1)

    # sum over columns / rows to give counts per x1 / x2 state
    x1_sample_counts = sample_counts.sum(axis=1)
    x2_sample_counts = sample_counts.sum(axis=0)
    
    # set truncation bounds
    min_x1_OB, max_x1_OB, min_x2_OB, max_x2_OB = M, 0, N, 0
    minM_x1_OB, maxM_x1_OB = M, 0
    minM_x2_OB, maxM_x2_OB = N, 0

    # set flag for changes
    thresh_flag = False
    thresh_flag_x1 = False
    thresh_flag_x2 = False

    # replace CI's for states below threshold occurances by [0, 1] bounds
    for x1 in range(M + 1):
        for x2 in range(N + 1):
            # below: replace
            if sample_counts[x1, x2] < thresh_OB:
                fm1m2[:, x1, x2] = [0.0, 1.0]
            # above: update truncation
            else:
                # check if smaller than current min
                if x1 < min_x1_OB:
                    min_x1_OB = x1
                    thresh_flag = True
                if x2 < min_x2_OB:
                    min_x2_OB = x2
                    thresh_flag = True
                # check if larger than current max
                if x1 > max_x1_OB:
                    max_x1_OB = x1
                    thresh_flag = True
                if x2 > max_x2_OB:
                    max_x2_OB = x2
                    thresh_flag = True

    for x1 in range(M + 1):
        # below: replace
        if x1_sample_counts[x1] < threshM_OB:
            fm1[:, x1] = [0.0, 1.0]
        # above: update truncation
        else:
            # check if smaller than current min
            if x1 < minM_x1_OB:
                minM_x1_OB = x1
                thresh_flag_x1 = True
            # check if larger than current max
            if x1 > maxM_x1_OB:
                maxM_x1_OB = x1
                thresh_flag_x1 = True

    for x2 in range(N + 1):
        # below: replace
        if x2_sample_counts[x2] < threshM_OB:
            fm2[:, x2] = [0.0, 1.0]
        # above: update truncation
        else:
            # check if smaller than current min
            if x2 < minM_x2_OB:
                minM_x2_OB = x2
                thresh_flag_x2 = True
            # check if larger than current max
            if x2 > maxM_x2_OB:
                maxM_x2_OB = x2
                thresh_flag_x2 = True

    # if no states were above threshold: default to max range, report
    if not thresh_flag:
        min_x1_OB, max_x1_OB, min_x2_OB, max_x2_OB = 0, M, 0, N
    if not thresh_flag_x1:
        minM_x1_OB, maxM_x1_OB = 0, M
    if not thresh_flag_x2:
        minM_x2_OB, maxM_x2_OB = 0, N

    # printing
    if printing:
        print(f"Box truncation: [{min_x1_OB}, {max_x1_OB}] x [{min_x2_OB}, {max_x2_OB}]")
        print(f"Marginal x1 truncation: [{minM_x1_OB}, {maxM_x1_OB}]")
        print(f"Marginal x2 truncation: [{minM_x2_OB}, {maxM_x2_OB}]")

    # collect results
    truncation_OB = {
        'min_x1_OB': min_x1_OB,
        'max_x1_OB': max_x1_OB,
        'min_x2_OB': min_x2_OB,
        'max_x2_OB': max_x2_OB
    }
    truncationM_OB = {
        'minM_x1_OB': minM_x1_OB,
        'maxM_x1_OB': maxM_x1_OB,
        'minM_x2_OB': minM_x2_OB,
        'maxM_x2_OB': maxM_x2_OB
    }

    result_dict = {
        'fm1m2': fm1m2,
        'fm1': fm1,
        'fm2': fm2,
        'truncation_OB': truncation_OB,
        'truncationM_OB': truncationM_OB
    }

    return result_dict

## Downsampled Scale Optimization (Marginals)

In [442]:
def optimize(probs, fs, silent=True, time_limit=300, print_solution=True):

    # WLS license
    options = json.load(open("C:/WLS_credentials.json"))

    # silent
    if silent:
        options['OutputFlag'] = 0

    # environment context
    with gp.Env(params=options) as env:

        # model context
        with gp.Model('test-construction', env=env) as model:

            # model settings
            model.Params.TimeLimit = time_limit
            K = 100

            # size
            M1 = probs['truncationM_OB']['maxM_x1_OB']
            M2 = probs['truncationM_OB']['maxM_x2_OB']

            # variables
            pd1 = model.addMVar(shape=M1 + 1, vtype=GRB.CONTINUOUS, name="pd1", lb=0, ub=1)
            pd2 = model.addMVar(shape=M2 + 1, vtype=GRB.CONTINUOUS, name="pd2", lb=0, ub=1)
            k_tx_1 = model.addVar(vtype=GRB.CONTINUOUS, name="k_tx_1", lb=0, ub=K)
            k_tx_2 = model.addVar(vtype=GRB.CONTINUOUS, name="k_tx_2", lb=0, ub=K)
            fm1 = model.addMVar(shape=M1 + 1, vtype=GRB.CONTINUOUS, name="fm1", lb=0, ub=1)
            fm2 = model.addMVar(shape=M2 + 1, vtype=GRB.CONTINUOUS, name="fm2", lb=0, ub=1)

            # constraints

            # base
            model.addConstr(pd1.sum() <= 1, name="Dist_pd1")
            model.addConstr(pd2.sum() <= 1, name="Dist_pd2")

            # probabilities
            model.addConstr(pd1 <= probs['x1_bounds'][1, :M1 + 1], name="pd1_UB")
            model.addConstr(pd1 >= probs['x1_bounds'][0, :M1 + 1], name="pd1_LB")
            model.addConstr(pd2 <= probs['x2_bounds'][1, :M2 + 1], name="pd2_UB")
            model.addConstr(pd2 >= probs['x2_bounds'][0, :M2 + 1], name="pd2_LB")

            # CME rates
            model.addConstr(fm1 <= fs['fm1'][1, :M1 + 1], name="fm1_UB")
            model.addConstr(fm1 >= fs['fm1'][0, :M1 + 1], name="fm1_LB")
            model.addConstr(fm2 <= fs['fm2'][1, :M2 + 1], name="fm2_UB")
            model.addConstr(fm2 >= fs['fm2'][0, :M2 + 1], name="fm2_LB")

            # CME
            z = model.addVar()
            model.addConstr(z == 0)

            for m1 in range(M1 - 1):

                model.addConstr(
                    z == (m1 != 0)*fm1[m1 - 1]*k_tx_1*pd1[m1 - 1] - fm1[m1]*pd1[m1] + (m1 + 1)*pd1[m1 + 1] - m1*pd1[m1],
                    name=f"CME_x1_{m1}"
                )

            for m2 in range(M2 - 1):

                model.addConstr(
                    z == (m2 != 0)*fm2[m2 - 1]*k_tx_2*pd2[m2 - 1] - fm2[m2]*pd2[m2] + (m2 + 1)*pd2[m2 + 1] - m2*pd2[m2],
                    name=f"CME_x2_{m2}"
                )

            # optimize: test feasibility
            model.setObjective(0, GRB.MINIMIZE)
            try:
                model.optimize()
            except gp.GurobiError:
                print("GurobiError")

            # collect solution information
            solution = {
                'status': status_codes[model.status],
                'time': model.Runtime
            }

            if solution['status'] == "INFEASIBLE":
                model.computeIIS()
                model.write('iis-dsp.ilp')

    # print
    if print_solution:
        print(f"Optimization status: {solution['status']}")
        print(f"Runtime: {solution['time']}")

    return solution

In [443]:
# model parameters
params = {
    'k_tx_1': 1,
    'k_tx_2': 1,
    'k_deg_1': 1,
    'k_deg_2': 1,
    'k_reg': 0.25
}

# simulate sample
sample = simulation.gillespie_birth_death(params, 1000)

# capture efficiencies
beta = rng.beta(1, 2, size=1000)

# downsample
x1_sample = [x[0] for x in sample]
x2_sample = [x[1] for x in sample]
x1_downsampled = rng.binomial(x1_sample, beta)
x2_downsampled = rng.binomial(x2_sample, beta)
downsampled = list(zip(x1_downsampled, x2_downsampled))

In [444]:
probs = bootstrap_probabilities(downsampled, thresh_OB=10, threshM_OB=10, printing=True)

Box truncation: [0, 2] x [0, 2]
Marginal x1 truncation: [0, 2]
Marginal x2 truncation: [0, 2]


In [445]:
fs = bootstrap_f(downsampled, beta, thresh_OB=10, threshM_OB=10, printing=True)

Box truncation: [0, 2] x [0, 2]
Marginal x1 truncation: [0, 2]
Marginal x2 truncation: [0, 2]


In [446]:
sol = optimize(probs, fs)

Optimization status: INFEASIBLE
Runtime: 0.004999876022338867


Incredible performance, using only marginals able to detect interaction strength as low as ~0.25 at Beta(1, 2) capture, which rivals even moment detection performance.

In [450]:
for k_reg in [0.5, 0.4, 0.3, 0.2, 0.15, 0.1, 0.05, 0.0]:

    print(f"k_reg: {k_reg}")

    # model parameters
    params = {
        'k_tx_1': 1,
        'k_tx_2': 1,
        'k_deg_1': 1,
        'k_deg_2': 1,
        'k_reg': k_reg
    }

    # simulate sample
    sample = simulation.gillespie_birth_death(params, 1000)

    # capture efficiencies
    beta = rng.beta(1, 2, size=1000)

    # downsample
    x1_sample = [x[0] for x in sample]
    x2_sample = [x[1] for x in sample]
    x1_downsampled = rng.binomial(x1_sample, beta)
    x2_downsampled = rng.binomial(x2_sample, beta)
    downsampled = list(zip(x1_downsampled, x2_downsampled))

    # bootstrap probabilities
    probs = bootstrap_probabilities(downsampled, thresh_OB=10, threshM_OB=10, printing=False)

    # bootstrap f's
    fs = bootstrap_f(downsampled, beta, thresh_OB=10, threshM_OB=10, printing=False)

    # optimize
    sol = optimize(probs, fs)

k_reg: 0.5
Optimization status: INFEASIBLE
Runtime: 0.0010001659393310547
k_reg: 0.4
Optimization status: INFEASIBLE
Runtime: 0.0009999275207519531
k_reg: 0.3
Optimization status: INFEASIBLE
Runtime: 0.0010001659393310547
k_reg: 0.2
Optimization status: INFEASIBLE
Runtime: 0.0009999275207519531
k_reg: 0.15
Optimization status: OPTIMAL
Runtime: 0.003000020980834961
k_reg: 0.1
Optimization status: OPTIMAL
Runtime: 0.009000062942504883
k_reg: 0.05
Optimization status: OPTIMAL
Runtime: 0.003000020980834961
k_reg: 0.0
Optimization status: OPTIMAL
Runtime: 0.00800013542175293


## Downsampled Scale Optimization (Joint)

Does not appear to perform better than marginals

In [451]:
def optimize_joint(probs, fs, silent=True, time_limit=300, print_solution=True, marginal=True, joint=True):

    # WLS license
    options = json.load(open("C:/WLS_credentials.json"))

    # silent
    if silent:
        options['OutputFlag'] = 0

    # environment context
    with gp.Env(params=options) as env:

        # model context
        with gp.Model('test-construction', env=env) as model:

            # model settings
            model.Params.TimeLimit = time_limit
            K = 100

            # size
            M1 = probs['truncationM_OB']['maxM_x1_OB']
            M2 = probs['truncationM_OB']['maxM_x2_OB']

            # variables
            pd = model.addMVar(shape=(M1 + 1, M2 + 1), vtype=GRB.CONTINUOUS, name="pd", lb=0, ub=1)
            pd1 = model.addMVar(shape=M1 + 1, vtype=GRB.CONTINUOUS, name="pd1", lb=0, ub=1)
            pd2 = model.addMVar(shape=M2 + 1, vtype=GRB.CONTINUOUS, name="pd2", lb=0, ub=1)

            k_tx_1 = model.addVar(vtype=GRB.CONTINUOUS, name="k_tx_1", lb=0, ub=K)
            k_tx_2 = model.addVar(vtype=GRB.CONTINUOUS, name="k_tx_2", lb=0, ub=K)

            fm = model.addMVar(shape=(M1 + 1, M2 + 1), vtype=GRB.CONTINUOUS, name="fm", lb=0, ub=1)
            fm1 = model.addMVar(shape=M1 + 1, vtype=GRB.CONTINUOUS, name="fm1", lb=0, ub=1)
            fm2 = model.addMVar(shape=M2 + 1, vtype=GRB.CONTINUOUS, name="fm2", lb=0, ub=1)

            # constraints

            # base
            model.addConstr(pd1.sum() <= 1, name="Dist_pd1")
            model.addConstr(pd2.sum() <= 1, name="Dist_pd2")

            model.addConstr(pd.sum() <= 1, name="Dist_pd")

            # probabilities
            if marginal:
                model.addConstr(pd1 <= probs['x1_bounds'][1, :M1 + 1], name="pd1_UB")
                model.addConstr(pd1 >= probs['x1_bounds'][0, :M1 + 1], name="pd1_LB")
                model.addConstr(pd2 <= probs['x2_bounds'][1, :M2 + 1], name="pd2_UB")
                model.addConstr(pd2 >= probs['x2_bounds'][0, :M2 + 1], name="pd2_LB")

            # joint
            if joint:
                model.addConstr(pd <= probs['bounds'][1, :M1 + 1, :M2 + 1], name="pd_UB")
                model.addConstr(pd >= probs['bounds'][0, :M1 + 1, :M2 + 1], name="pd_LB")

            # CME rates
            model.addConstr(fm1 <= fs['fm1'][1, :M1 + 1], name="fm1_UB")
            model.addConstr(fm1 >= fs['fm1'][0, :M1 + 1], name="fm1_LB")
            model.addConstr(fm2 <= fs['fm2'][1, :M2 + 1], name="fm2_UB")
            model.addConstr(fm2 >= fs['fm2'][0, :M2 + 1], name="fm2_LB")

            # joint
            model.addConstr(fm <= fs['fm1m2'][1, :M1 + 1, :M2 + 1], name="fm_UB")
            model.addConstr(fm >= fs['fm1m2'][0, :M1 + 1, :M2 + 1], name="fm_LB")

            # CME
            z = model.addVar()
            model.addConstr(z == 0)

            if marginal:
                for m1 in range(M1 - 1):

                    model.addConstr(
                        z == (m1 != 0)*fm1[m1 - 1]*k_tx_1*pd1[m1 - 1] - fm1[m1]*pd1[m1] + (m1 + 1)*pd1[m1 + 1] - m1*pd1[m1],
                        name=f"CME_x1_{m1}"
                    )

                for m2 in range(M2 - 1):

                    model.addConstr(
                        z == (m2 != 0)*fm2[m2 - 1]*k_tx_2*pd2[m2 - 1] - fm2[m2]*pd2[m2] + (m2 + 1)*pd2[m2 + 1] - m2*pd2[m2],
                        name=f"CME_x2_{m2}"
                    )

            if joint:
                for m1 in range(M1 - 1):
                    for m2 in range(M2 - 1):

                        model.addConstr(
                            z == (m1 != 0)*fm[m1 - 1, m2]*k_tx_1*pd[m1 - 1, m2] + \
                                 (m2 != 0)*fm[m1, m2 - 1]*k_tx_2*pd[m1, m2 - 1] + \
                                 (m1 + 1)*pd[m1 + 1, m2] + \
                                 (m2 + 1)*pd[m1, m2 + 1] - \
                                 (fm[m1, m2]*k_tx_1 + fm[m1, m2]*k_tx_2 + m1 + m2)*pd[m1, m2],
                            name=f"CME_{m1}_{m2}"
                        )


            # optimize: test feasibility
            model.setObjective(0, GRB.MINIMIZE)
            try:
                model.optimize()
            except gp.GurobiError:
                print("GurobiError")

            # collect solution information
            solution = {
                'status': status_codes[model.status],
                'time': model.Runtime
            }

            if solution['status'] == "INFEASIBLE":
                model.computeIIS()
                model.write('iis-dsp.ilp')

    # print
    if print_solution:
        print(f"Optimization status: {solution['status']}")
        print(f"Runtime: {solution['time']}")

    return solution

In [453]:
for k_reg in [0.5, 0.4, 0.3, 0.2, 0.15, 0.1, 0.05, 0.0]:

    print(f"k_reg: {k_reg}")

    # model parameters
    params = {
        'k_tx_1': 1,
        'k_tx_2': 1,
        'k_deg_1': 1,
        'k_deg_2': 1,
        'k_reg': k_reg
    }

    # simulate sample
    sample = simulation.gillespie_birth_death(params, 1000)

    # capture efficiencies
    beta = rng.beta(1, 2, size=1000)

    # downsample
    x1_sample = [x[0] for x in sample]
    x2_sample = [x[1] for x in sample]
    x1_downsampled = rng.binomial(x1_sample, beta)
    x2_downsampled = rng.binomial(x2_sample, beta)
    downsampled = list(zip(x1_downsampled, x2_downsampled))

    # bootstrap probabilities
    probs = bootstrap_probabilities(downsampled, thresh_OB=20, threshM_OB=20, printing=False)

    # bootstrap f's
    fs = bootstrap_f(downsampled, beta, thresh_OB=20, threshM_OB=20, printing=False)

    # optimize
    print("Marginal:")
    sol_marginal = optimize_joint(probs, fs, silent=True, marginal=True, joint=False)
    print("Joint:")
    sol_joint = optimize_joint(probs, fs, silent=True, marginal=True, joint=True)

k_reg: 0.5
Marginal:
Optimization status: INFEASIBLE
Runtime: 0.0009999275207519531
Joint:
Optimization status: INFEASIBLE
Runtime: 0.0009999275207519531
k_reg: 0.4
Marginal:
Optimization status: INFEASIBLE
Runtime: 0.0019998550415039062
Joint:
Optimization status: INFEASIBLE
Runtime: 0.0010001659393310547
k_reg: 0.3
Marginal:
Optimization status: INFEASIBLE
Runtime: 0.0009999275207519531
Joint:
Optimization status: INFEASIBLE
Runtime: 0.0010001659393310547
k_reg: 0.2
Marginal:
Optimization status: INFEASIBLE
Runtime: 0.0009999275207519531
Joint:
Optimization status: INFEASIBLE
Runtime: 0.0009999275207519531
k_reg: 0.15
Marginal:
Optimization status: OPTIMAL
Runtime: 0.009000062942504883
Joint:
Optimization status: OPTIMAL
Runtime: 0.00599980354309082
k_reg: 0.1
Marginal:
Optimization status: OPTIMAL
Runtime: 0.003000020980834961
Joint:
Optimization status: OPTIMAL
Runtime: 0.005000114440917969
k_reg: 0.05
Marginal:
Optimization status: OPTIMAL
Runtime: 0.005000114440917969
Joint:
Opti

# NOTE: OB counts not independent (as per cell capture introduces correlation)

Means we cannot factorise downsampled probability variables in optimization: pd != pd1 @ pd2

In [399]:
M1, M2 = probs['bounds'].shape[1:]
p1 = np.zeros(M1)
p2 = np.zeros(M2)
p = np.zeros((M1, M2))
for b in beta:
    p += scipy.stats.poisson.pmf(range(M1), b * params['k_tx_1'] / params['k_deg_1'])[:, None] * scipy.stats.poisson.pmf(range(M2), b * params['k_tx_2'] / params['k_deg_2'])[None, :] / len(beta)
    p1 += scipy.stats.poisson.pmf(range(M1), b * params['k_tx_1'] / params['k_deg_1']) / len(beta)
    p2 += scipy.stats.poisson.pmf(range(M2), b * params['k_tx_2'] / params['k_deg_2']) / len(beta)
#probs_truth = p1[:, None] * p2[None, :]

In [400]:
probs['bounds'][0, :, :] < p, p < probs['bounds'][1, :, :]

(array([[ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True]]),
 array([[ True,  True,  True,  True,  True,  True],
        [False,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True]]))

In [401]:
probs['x1_bounds'][0, :] < p1, p1 < probs['x1_bounds'][1, :]

(array([ True,  True,  True,  True,  True]),
 array([ True,  True,  True,  True,  True]))

In [402]:
probs['x2_bounds'][0, :] < p2, p2 < probs['x2_bounds'][1, :]

(array([ True,  True,  True,  True,  True,  True]),
 array([ True,  True,  True,  True,  True,  True]))

In [404]:
def fm1m2(m1, m2, p):
    fm = 0
    for b in beta:
        fm += b * scipy.stats.poisson.pmf(m1, b*params['k_tx_1']/params['k_deg_1']) * scipy.stats.poisson.pmf(m2, b*params['k_tx_2']/params['k_deg_2']) * (1 / len(beta)) / p[m1, m2]
    return fm

def fmi(mi, i, p):
    fm = 0
    for b in beta:
        fm += b * scipy.stats.poisson.pmf(mi, b*params[f'k_tx_{i}']/params[f'k_deg_{i}']) * (1 / len(beta)) / p[mi]
    return fm

fms = np.zeros((M1, M2))
for m1 in range(M1):
    for m2 in range(M2):
        fms[m1, m2] = fm1m2(m1, m2, p)
f1s = np.zeros(M1)
f2s = np.zeros(M2)
for m1 in range(M1):
    f1s[m1] = fmi(m1, 1, p1)
for m2 in range(M2):
    f2s[m2] = fmi(m2, 2, p2)

In [405]:
fs['fm1m2'][0, :, :] <= fms, fms <= fs['fm1m2'][1, :, :]

(array([[False,  True,  True,  True,  True,  True],
        [False,  True,  True,  True,  True,  True],
        [False,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True]]),
 array([[ True, False, False,  True,  True,  True],
        [ True, False,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True]]))

In [406]:
fs['fm1'][0, :] <= f1s, f1s <= fs['fm1'][1, :]

(array([ True,  True,  True,  True,  True]),
 array([ True,  True,  True,  True,  True]))

In [407]:
fs['fm2'][0, :] <= f2s, f2s <= fs['fm2'][1, :]

(array([ True,  True,  True,  True,  True,  True]),
 array([ True,  True, False,  True,  True,  True]))