Skip to content

Commit

Permalink
Merge pull request #8 from rademacher-p/py_mods
Browse files Browse the repository at this point in the history
Py mods
  • Loading branch information
rademacher-p committed Apr 18, 2020
2 parents d81189b + c63dbd0 commit d30abd4
Show file tree
Hide file tree
Showing 4 changed files with 275 additions and 947 deletions.
154 changes: 0 additions & 154 deletions Code/PGR_thesis/_old.py

This file was deleted.

64 changes: 42 additions & 22 deletions Code/PGR_thesis/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@


from util.util import simplex_grid
from rv_obj import deterministic_multi, dirichlet_multi, discrete_multi
# from rv_obj import deterministic_multi, dirichlet_multi, discrete_multi
from rv_obj import DeterministicRE, DirichletRE, FiniteRE

plt.style.use('seaborn')
# plt.style.use('seaborn')

rng = random.default_rng()

Expand Down Expand Up @@ -82,22 +83,29 @@ def theta_c(x): return stats.multivariate_normal(mean=x)
#%% Discrete sets

Y_set = np.array(['a', 'b'])
# Y_set = np.arange(2)
# X_set = np.arange(3)
# Y_set = np.arange(3)
# X_set = np.arange(1)
# X_set = np.arange(6).reshape(3, 2)
X_set = np.stack(np.meshgrid(np.arange(3), np.arange(2)), axis=-1)
# X_set = np.stack(np.meshgrid(np.arange(2), np.arange(2)), axis=-1)
X_set = np.random.random((2,2,2))

i_split_y, i_split_x = Y_set.ndim, X_set.ndim-1
i_split_y, i_split_x = Y_set.ndim, X_set.ndim-0


# YX_set = np.array(list(itertools.product(Y_set.flatten(), X_set.flatten())),
# dtype=[('y', Y_set.dtype), ('x', X_set.dtype)]).reshape(Y_set.shape + X_set.shape)

Y_set_shape, Y_data_shape = Y_set.shape[:i_split_y], Y_set.shape[i_split_y:]
X_set_shape, X_data_shape = X_set.shape[:i_split_x], X_set.shape[i_split_x:]

# tt = list(map(tuple, X_set.reshape((-1,) + X_data_shape)))
tt = [(x,) for x in X_set.reshape((-1,) + X_data_shape)]
# tt = list(X_set.reshape((-1,) + X_data_shape))
xx = np.array(tt, dtype=[('x', X_set.dtype, X_data_shape)]).reshape(X_set_shape) ###

_temp = list(itertools.product(Y_set.reshape((-1,) + Y_data_shape), X_set.reshape((-1,) + X_data_shape)))
YX_set = np.array(_temp, dtype=[('y', Y_set.dtype, Y_data_shape),
('x', X_set.dtype, X_data_shape)]).reshape(Y_set_shape + X_set_shape)
('x', X_set.dtype, X_data_shape)]).reshape(Y_set_shape + X_set_shape)

# YX_set = np.array(list(itertools.product(Y_set, X_set)),
# dtype=[('y', Y_set.dtype, Y_data_shape), ('x', X_set.dtype, X_data_shape)]).reshape(Y_set_shape + X_set_shape)
Expand All @@ -106,36 +114,48 @@ def theta_c(x): return stats.multivariate_normal(mean=x)

n_plt = 10

# val = dirichlet_multi.rvs(YX_set.size, np.ones(YX_set.shape)/YX_set.size)
# prior = deterministic_multi(val)
# # val = dirichlet_multi.rvs(YX_set.size, np.ones(YX_set.shape)/YX_set.size)
# # prior = deterministic_multi(val)
# val = DirichletRE(YX_set.size, np.ones(YX_set.shape)/YX_set.size).rvs()
# prior = DeterministicRE(val)
# t_plt = simplex_grid(n_plt, YX_set.shape)

alpha_0 = YX_set.size
mean = dirichlet_multi.rvs(YX_set.size, np.ones(YX_set.shape) / YX_set.size)
prior = dirichlet_multi(alpha_0, mean, rng)
t_plt = simplex_grid(n_plt, YX_set.shape, hull_mask=(mean < 1 / alpha_0))
alpha_0 = 10*YX_set.size
# mean = dirichlet_multi.rvs(YX_set.size, np.ones(YX_set.shape) / YX_set.size)
mean = DirichletRE(YX_set.size, np.ones(YX_set.shape) / YX_set.size).rvs()
# prior = dirichlet_multi(alpha_0, mean, rng)
prior = DirichletRE(alpha_0, mean, rng)
# t_plt = simplex_grid(n_plt, YX_set.shape, hull_mask=(mean < 1 / alpha_0))


p_theta_plt = prior.pdf(t_plt)
# p_theta_plt = prior.pdf(t_plt)
theta_pmf = prior.rvs()

# prior_plt.sum() / (n_plt**(mean.size-1))


# TODO: add plot methods to RV classes
if YX_set.shape == (3, 1):
_, ax_prior = plt.subplots(num='prior', clear=True, subplot_kw={'projection': '3d'})
sc = ax_prior.scatter(t_plt[:, 0], t_plt[:, 1], t_plt[:, 2], s=15, c=p_theta_plt)
ax_prior.view_init(35, 45)
plt.colorbar(sc)
ax_prior.set(xlabel='$x_1$', ylabel='$x_2$', zlabel='$x_3$')
# if YX_set.shape == (3, 1):
# _, ax_prior = plt.subplots(num='prior', clear=True, subplot_kw={'projection': '3d'})
# sc = ax_prior.scatter(t_plt[:, 0], t_plt[:, 1], t_plt[:, 2], s=15, c=p_theta_plt)
# ax_prior.view_init(35, 45)
# plt.colorbar(sc)
# ax_prior.set(xlabel='$x_1$', ylabel='$x_2$', zlabel='$x_3$')

# TODO: marginal/conditional models to alleviate structured array issues?

theta = discrete_multi(YX_set, theta_pmf, rng)
# theta = discrete_multi(YX_set, theta_pmf, rng)
theta = FiniteRE(YX_set, theta_pmf, rng)

theta.rvs(6)

theta_m = discrete_multi(X_set, theta_pmf.sum(axis=0))
###
theta_m_pmf = theta_pmf.reshape((-1,) + X_set_shape).sum(axis=0)
# theta_m = discrete_multi(X_set, theta_m_pmf)
theta_m = FiniteRE(X_set, theta_m_pmf)
theta_m.mean # TODO: broken, tuple product
theta_m.rvs()
theta_m.pmf(theta_m.rvs(2))

D = theta.rvs(10)

Expand Down
Loading

0 comments on commit d30abd4

Please sign in to comment.