In [1]:
import math
import scipy as sp
import matplotlib.pyplot as plt
from scipy.stats import bernoulli, uniform, chi2
import numpy as np
from scipy.stats.mstats import gmean
from numpy.testing import assert_allclose
from utils import sprt_mart, get_eb_p_value, eb_selector, psi_E, v_i, pm_lambda
SEED = 123456789  # global NumPy seed; every later np.random.* draw in this notebook depends on it
np.random.seed(SEED)

In [2]:
# Two synthetic strata drawn from the global NumPy RNG:
#   stratum_1: 200 draws ~ Normal(mean 0.5, sd 0.05)
#   stratum_2: 100 draws ~ Normal(mean 0.6, sd 0.05)
# Stratum 1 is drawn first, so results are reproducible under the seed set above.
stratum_1, stratum_2 = (
    np.random.normal(loc=center, scale=0.05, size=count)
    for center, count in ((0.5, 200), (0.6, 100))
)
strata = [stratum_1, stratum_2]

In [4]:
# p-value sequence from the packaged implementation (utils.get_eb_p_value), gamma = 1.
get_eb_p_value(gamma=1, strata=strata)

array([1.        , 0.51235982, 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.     

In [5]:
# Step-by-step computation of a stratified log-martingale and its p-value
# sequence (the get_eb_p_value cell above produces a p-value sequence from the
# same strata with gamma = 1).
# TODO: allow lambda to take arbitrary user-supplied values; when none are given,
# fall back to the predictable-mixture choice from pm_lambda (current behavior).


def _greedy_eta(b, weights, upper, target=1 / 2):
    """Greedily choose eta so that weights . eta reaches `target`, with 0 <= eta <= upper.

    Strata are filled one at a time in decreasing order of -b[k] / weights[k]
    (a fractional-knapsack greedy). This minimizes b . eta over
    {weights . eta = target, 0 <= eta <= upper}.

    Args:
        b: per-stratum coefficients (running_b below).
        weights: positive per-stratum weights (w below).
        upper: per-stratum upper bound on eta.
        target: value weights . eta must reach (1/2 here).

    Returns:
        np.ndarray of eta values, one per stratum.
    """
    eta = np.zeros(len(weights))
    active = np.ones(len(weights), dtype=bool)
    while np.dot(eta, weights) < target and active.any():
        # Mask already-filled strata with -inf. Multiplying by a 0/1 mask instead
        # lets an inactive stratum win argmax when every active ratio is <= 0,
        # which would overwrite its eta and can loop forever.
        ratio = np.where(active, -b / weights, -np.inf)
        k = np.argmax(ratio)
        active[k] = False
        eta[k] = np.minimum(upper, (target - np.dot(eta, weights)) / weights[k])
    return eta


lam = [pm_lambda(x) for x in strata]  # per-stratum lambda sequences, indexed by sample below
gamma = 1
N = np.array([len(x) for x in strata])  # stratum sizes
K = len(strata)  # number of strata
w = N / np.sum(N)  # stratum weights, proportional to stratum size
u = 1  # upper bound on each eta_k in the greedy step
# a[k][j] = gamma * mean_{i<=j}(lam*x - psi_E(lam)*v_i(x)) over stratum k's first j+1 samples
#           + (1 - gamma) * w[k]
a = [
    (gamma / (np.arange(N[k]) + 1))
    * np.cumsum(lam[k] * strata[k] - psi_E(lam[k]) * v_i(strata[k]))
    + (1 - gamma) * w[k]
    for k in range(K)
]
running_n = np.zeros(K)
# NOTE(review): running_a starts at 1 and running_n at 0, and running_n is
# incremented before indexing, so a[k][0] is never read. Confirm this is intended.
running_a = np.ones(K)
running_b = np.zeros(K)
running_lam = np.array([x[0] for x in lam])
selected_strata = np.zeros(np.sum(N))  # stratum chosen at each step
log_mart = np.zeros(np.sum(N))
i = 0
while any(running_n < (N - 1)):
    # NOTE(review): assumes eb_selector never picks a stratum with
    # running_n == N - 1; otherwise the indexing below runs past the end.
    next_stratum = eb_selector(running_a=running_a, running_n=running_n, lam=running_lam, N=N, gamma=gamma)
    selected_strata[i] = next_stratum
    running_n[next_stratum] += 1
    j = int(running_n[next_stratum])
    running_lam[next_stratum] = lam[next_stratum][j]
    running_a[next_stratum] = a[next_stratum][j]
    running_b[next_stratum] -= running_lam[next_stratum]
    eta_star = _greedy_eta(running_b, w, u)
    i += 1
    log_mart[i] = np.sum(running_a) + np.dot(running_b, eta_star)
mart = np.exp(log_mart)
p_value = 1 / np.maximum(1, mart)

In [6]:
# Display the p-value sequence from the manual computation above:
# one entry per step, p = 1 / max(1, exp(log_mart)).
p_value

array([1.        , 0.51235982, 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.     