In [46]:
import math
import scipy as sp
import matplotlib.pyplot as plt
from scipy.stats import bernoulli, uniform, chi2
import numpy as np
from scipy.stats.mstats import gmean
from numpy.testing import assert_allclose
from utils import sprt_mart, eb_selector, psi_E, v_i
np.random.seed(123456789)

In [47]:
stratum_1 = np.random.normal(loc = 0.5, scale = 0.05, size = 200)
stratum_2 = np.random.normal(loc = 0.6, scale = 0.05, size = 100)
strata = [stratum_1, stratum_2]
lam = np.array([0.5])

In [55]:
N = np.array([len(x) for x in strata])
K = len(strata)
w = N/np.sum(N)
u = 1
gamma = 1
marts = [np.ones(x) for x in N]
a = [(gamma/(np.arange(N[k]) + 1)) * np.cumsum(lam*strata[k] - psi_E(lam)*v_i(strata[k])) + (1-gamma)*w[k] for k in np.arange(K)]
running_n = np.zeros(K)
running_a = np.ones(K)
#record which strata are pulled from
selected_strata = np.zeros(np.sum(N))
log_mart = np.zeros(np.sum(N))
i = 0
running_b = np.zeros(K)
while any(running_n < (N-1)):
    next_stratum = eb_selector(running_a = running_a, running_n = running_n, lam = lam, N = N, gamma = gamma)
    selected_strata[i] = next_stratum
    running_n[next_stratum] += 1
    running_a[next_stratum] = a[next_stratum][int(running_n[next_stratum])]
    running_b[next_stratum] += -lam
    eta_star = np.zeros(K)
    active = np.ones(K)
    #greedy algorithm to optimize over eta
    while((np.dot(eta_star, w) < 1/2) and all(eta_star <= u)):
        weight = -running_b / w
        max_index = np.argmax(weight * active)
        active[max_index] = 0
        eta_star[max_index] = np.minimum(u, (1/2 - np.dot(eta_star, w)) / w[max_index])
    i += 1
    print(eta_star)
    print(running_b)
    log_mart[i] = np.sum(running_a) + np.dot(running_b, eta_star)
mart = np.exp(log_mart)
p_value = 1/np.maximum(1,mart)


[0.25 1.  ]
[ 0.  -0.5]
[0.25 1.  ]
[-0.5 -0.5]
[0.75 0.  ]
[-1.  -0.5]
[0.25 1.  ]
[-1. -1.]
[0.25 1.  ]
[-1.5 -1. ]
[0.25 1.  ]
[-1.5 -1.5]
[0.25 1.  ]
[-2.  -1.5]
[0.25 1.  ]
[-2. -2.]
[0.25 1.  ]
[-2.  -2.5]
[0.25 1.  ]
[-2. -3.]
[0.25 1.  ]
[-2.  -3.5]
[0.25 1.  ]
[-2.5 -3.5]
[0.25 1.  ]
[-2.5 -4. ]
[0.25 1.  ]
[-3. -4.]
[0.25 1.  ]
[-3.5 -4. ]
[0.25 1.  ]
[-3.5 -4.5]
[0.25 1.  ]
[-4.  -4.5]
[0.25 1.  ]
[-4.5 -4.5]
[0.25 1.  ]
[-4.5 -5. ]
[0.25 1.  ]
[-5. -5.]
[0.25 1.  ]
[-5.  -5.5]
[0.25 1.  ]
[-5. -6.]
[0.25 1.  ]
[-5.5 -6. ]
[0.25 1.  ]
[-5.5 -6.5]
[0.25 1.  ]
[-6.  -6.5]
[0.25 1.  ]
[-6.5 -6.5]
[0.25 1.  ]
[-6.5 -7. ]
[0.25 1.  ]
[-6.5 -7.5]
[0.25 1.  ]
[-7.  -7.5]
[0.25 1.  ]
[-7. -8.]
[0.25 1.  ]
[-7.5 -8. ]
[0.25 1.  ]
[-8. -8.]
[0.25 1.  ]
[-8.  -8.5]
[0.25 1.  ]
[-8.5 -8.5]
[0.25 1.  ]
[-8.5 -9. ]
[0.25 1.  ]
[-8.5 -9.5]
[0.25 1.  ]
[ -8.5 -10. ]
[0.25 1.  ]
[ -8.5 -10.5]
[0.25 1.  ]
[ -8.5 -11. ]
[0.25 1.  ]
[ -8.5 -11.5]
[0.25 1.  ]
[ -9.  -11.5]
[0.25 1.  ]
[ -9. -12.

In [54]:
log_mart

array([ 0.00000000e+00,  7.80185069e-01, -4.28246850e-02, -1.70353119e-01,
       -6.59046266e-01, -7.95897648e-01, -1.30088622e+00, -1.79148207e+00,
       -2.28786740e+00, -2.78542228e+00, -3.29094759e+00, -3.41924611e+00,
       -3.55346141e+00, -4.05729423e+00, -4.18183873e+00, -4.30584728e+00,
       -4.80304857e+00, -5.30323641e+00, -5.80055805e+00, -6.30275251e+00,
       -6.80346103e+00, -6.92883850e+00, -7.05437015e+00, -7.55221765e+00,
       -7.67835652e+00, -7.80659713e+00, -7.93165820e+00, -8.43364139e+00,
       -8.93443468e+00, -9.06053327e+00, -9.56065053e+00, -1.00594325e+01,
       -1.05578808e+01, -1.10588365e+01, -1.15590209e+01, -1.16867965e+01,
       -1.21886359e+01, -1.23174230e+01, -1.28178773e+01, -1.29436244e+01,
       -1.34428391e+01, -1.39419202e+01, -1.44412502e+01, -1.49413821e+01,
       -1.50665679e+01, -1.55672338e+01, -1.60675892e+01, -1.61943251e+01,
       -1.63209290e+01, -1.68213887e+01, -1.69481283e+01, -1.74481870e+01,
       -1.79479311e+01, -