In [9]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from sim import simulation

In [10]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Simulate some data for demo

In [11]:
def p_Y(A, X):
    # conditional distribution of Y|A,X
    return np.random.normal(loc=X*A, scale=0.25)

def p_Z(A, X):
    # conditional distribution of Z|A,X
    return np.random.normal(loc=np.sqrt((X-0.1)*A*(X>=0.1))-np.sqrt((-X-0.1)*(X<=-0.1))*A, scale=0.25)


# the optimal policy
def d_anal(x, p_Y):
    return (p_Y(1, x) - p_Y(0, x) >= 0)

N = 1000
X = np.random.uniform(-1, 1, N) # uniform
A = np.random.binomial(1, 0.5+0.1*X) # distribution for A|X - allow for vector representations
Z = p_Z(d_anal(X, p_Z), X)
Y = p_Y(d_anal(X, p_Y), X)

## run the simulation

In [17]:
n_iters = 1000 # number of iterations
alpha = 0.05 # confidence level - no need to change this
method_typ_lst = ["os", "mb", "naive_unique"]
avg_lcbs = {} # average lower confidence bound
avg_ucbs = {} # average upper confidence bound
for method_typ in method_typ_lst:
    sim = simulation(X, A, Y, Z, method_typ=method_typ, n_iters=n_iters)
    sim.run(alpha)
    avg_lcbs[method_typ] = np.mean(sim.lcb_z)
    avg_ucbs[method_typ] = np.mean(sim.ucb_z)

2023-08-08 08:00:21,421	INFO worker.py:1636 -- Started a local Ray instance.


[2m[36m(worker pid=2225278)[0m running one iteration for task 3: 0.04611921310424805
Running method os.

2023-08-08 08:00:28,103	INFO worker.py:1636 -- Started a local Ray instance.


[2m[36m(worker pid=2227281)[0m ray running time:  4.47353720664978


[2m[33m(raylet)[0m [2023-08-08 08:00:38,015 E 2227148 2227161] (raylet) file_system_monitor.cc:111: /tmp/ray/session_2023-08-08_08-00-26_283047_2202087 is over 95% full, available space: 17250885632; capacity: 502467059712. Object creation will fail if spilling is required.


[2m[36m(worker pid=2227251)[0m ray running time:  11.18195629119873[32m [repeated 4x across cluster][0m
[2m[36m(worker pid=2227274)[0m ray running time:  16.269855976104736[32m [repeated 3x across cluster][0m


[2m[33m(raylet)[0m [2023-08-08 08:00:48,052 E 2227148 2227161] (raylet) file_system_monitor.cc:111: /tmp/ray/session_2023-08-08_08-00-26_283047_2202087 is over 95% full, available space: 17250672640; capacity: 502467059712. Object creation will fail if spilling is required.
[2m[33m(raylet)[0m [2023-08-08 08:00:58,063 E 2227148 2227161] (raylet) file_system_monitor.cc:111: /tmp/ray/session_2023-08-08_08-00-26_283047_2202087 is over 95% full, available space: 17250525184; capacity: 502467059712. Object creation will fail if spilling is required.
[2m[33m(raylet)[0m [2023-08-08 08:01:08,074 E 2227148 2227161] (raylet) file_system_monitor.cc:111: /tmp/ray/session_2023-08-08_08-00-26_283047_2202087 is over 95% full, available space: 17250291712; capacity: 502467059712. Object creation will fail if spilling is required.


[2m[36m(worker pid=2227281)[0m running one iteration for task 2: 41.07969355583191
[2m[36m(worker pid=2227287)[0m ray running time:  19.205219507217407[32m [repeated 2x across cluster][0m
[2m[36m(worker pid=2227284)[0m running one iteration for task 4: 47.137622117996216[32m [repeated 6x across cluster][0m


[2m[33m(raylet)[0m [2023-08-08 08:01:18,085 E 2227148 2227161] (raylet) file_system_monitor.cc:111: /tmp/ray/session_2023-08-08_08-00-26_283047_2202087 is over 95% full, available space: 17250242560; capacity: 502467059712. Object creation will fail if spilling is required.


[2m[36m(worker pid=2227302)[0m running one iteration for task 8: 48.65371632575989[32m [repeated 2x across cluster][0m
Running method mb.

2023-08-08 08:01:24,534	INFO worker.py:1636 -- Started a local Ray instance.


[2m[36m(worker pid=2230934)[0m ray running time:  4.29848051071167
[2m[36m(worker pid=2230934)[0m running one iteration for task 4: 4.9648683071136475


[2m[33m(raylet)[0m [2023-08-08 08:01:34,427 E 2230774 2230786] (raylet) file_system_monitor.cc:111: /tmp/ray/session_2023-08-08_08-01-22_609713_2202087 is over 95% full, available space: 17248829440; capacity: 502467059712. Object creation will fail if spilling is required.


[2m[36m(worker pid=2230944)[0m ray running time:  8.581119537353516[32m [repeated 3x across cluster][0m
[2m[36m(worker pid=2230944)[0m running one iteration for task 5: 9.283493757247925[32m [repeated 3x across cluster][0m
[2m[36m(worker pid=2230923)[0m ray running time:  13.617282390594482[32m [repeated 3x across cluster][0m
[2m[36m(worker pid=2230923)[0m running one iteration for task 1: 14.271918058395386[32m [repeated 3x across cluster][0m


[2m[33m(raylet)[0m [2023-08-08 08:01:44,475 E 2230774 2230786] (raylet) file_system_monitor.cc:111: /tmp/ray/session_2023-08-08_08-01-22_609713_2202087 is over 95% full, available space: 17248710656; capacity: 502467059712. Object creation will fail if spilling is required.


[2m[36m(worker pid=2230924)[0m ray running time:  18.061262607574463[32m [repeated 3x across cluster][0m
[2m[36m(worker pid=2230922)[0m running one iteration for task 9: 17.703144788742065[32m [repeated 2x across cluster][0m
Running method naive_unique.

In [18]:
# print results
print("average lcbs: ", avg_lcbs)
print("average ucbs: ", avg_ucbs)

average lcbs:  {'os': 0.24259361522792702, 'mb': 0.18694018823428832, 'naive_unique': 0.23733974446437775}
average ucbs:  {'os': 0.3078756318580821, 'mb': 0.35989756897948355, 'naive_unique': 0.3093893254531828}
