In [None]:
# New log-likelihood: obtained by iteratively reducing the Beta-Binomial likelihood to remove the Beta function B, using the identities B(a + 1, b) = B(a, b) * a / (a + b) and B(a, b + 1) = B(a, b) * b / (a + b)

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.distributions import bernoulli, beta
from tqdm import tqdm


In [None]:
def new_log_marginal(a, b, x, n):
    """Exact Beta-Binomial log marginal likelihood, without the binomial coefficient.

    Computes ``log B(a + x, b + n - x) - log B(a, b)`` as a sum of logs by
    peeling one unit at a time off each argument of the Beta function with
    ``B(p, q + 1) = B(p, q) * q / (p + q)`` and
    ``B(p + 1, q) = B(p, q) * p / (p + q)``. Only elementary ``log`` calls are
    used, so the result is differentiable in ``a`` and ``b``.

    Args:
        a, b: Beta prior parameters (Python numbers or tensors; tensors that
            require grad receive gradients).
        x: number of successes (integer, ``0 <= x <= n``).
        n: number of trials (integer).

    Returns:
        Scalar tensor with the log marginal likelihood.

    Raises:
        ValueError: if ``x`` is outside ``[0, n]``.
    """
    if not 0 <= x <= n:
        raise ValueError(f"need 0 <= x <= n, got x={x}, n={n}")

    dtype = torch.get_default_dtype()

    # Step 1: reduce the second argument b + n - x -> b while the first stays
    # a + x. Step j (1..n-x) contributes (b + n - x - j) / (a + b + n - j).
    j = torch.arange(1, n - x + 1, dtype=dtype)
    log_b_steps = torch.sum(torch.log((b + n - x) - j) - torch.log((a + b + n) - j))

    # Step 2: reduce the first argument a + x -> a while the second is now b.
    # The running total is a + b + x (not a + b + n) after step 1, so step
    # i (1..x) contributes (a + x - i) / (a + b + x - i).
    i = torch.arange(1, x + 1, dtype=dtype)
    log_a_steps = torch.sum(torch.log((a + x) - i) - torch.log((a + b + x) - i))

    return log_b_steps + log_a_steps


In [None]:
def plot_beta(a, b, n_points=100):
    """Plot the Beta(a, b) probability density on [0, 1].

    Args:
        a, b: Beta shape parameters (numbers or tensors accepted by
            ``torch.distributions.Beta``).
        n_points: number of evenly spaced evaluation points on [0, 1]
            (default 100).
    """
    beta_dist = beta.Beta(a, b)
    x = torch.linspace(0, 1, n_points)
    # log_prob then exp gives the density at each grid point.
    density = torch.exp(beta_dist.log_prob(x))

    # Explicit figure/axes so the plot is self-contained and labelled.
    fig, ax = plt.subplots()
    ax.plot(x, density)
    ax.set_xlabel("x")
    ax.set_ylabel("density")
    ax.set_title(f"Beta(a={a}, b={b})")
    plt.show()

In [None]:
# Nearly symmetric example: mode (a - 1) / (a + b - 2) = 0.525, just right of centre.
plot_beta(a=3.1, b=2.9)

In [None]:
def old_log_marginal(a, b, x, n):
    """Stirling approximation to the Beta-Binomial log marginal likelihood.

    Approximates ``log B(a + x, b + n - x) - log B(a, b)`` with
    ``log Gamma(z) ~ (z - 1/2) log z - z + log(2 pi) / 2``. In the
    difference, the ``-z`` and constant terms cancel, so only
    ``(z - 1/2) log z`` terms remain.

    ``a`` and ``b`` must be tensors, because ``torch.log`` is applied to them
    directly. ``x`` and ``n`` may be plain numbers. All operations are
    elementwise, so broadcast tensors give a grid of values.
    """
    def half_log(z):
        # Surviving Stirling term (z - 1/2) * log(z).
        return (z - 1/2) * torch.log(z)

    positive_terms = half_log(a + b) + half_log(x + a) + half_log(n - x + b)
    negative_terms = half_log(a) + half_log(b) + half_log(a + n + b)
    return positive_terms - negative_terms

In [None]:
# Simulate grouped Bernoulli data from a Beta-Binomial model with known (a, b),
# then try to recover (a, b) by SGD on the negative log marginal likelihood.
torch.manual_seed(0)  # reproducible simulation and fit

BATCH_SIZE = 100      # trials per group; all trials in a group share one rho
N_EPOCHS = 100        # was 100_000 epochs (~1e8 SGD steps), which never finishes
LEARNING_RATE = 0.01
MIN_PARAM = 1e-3      # keep a, b > 0 so every log() in the likelihood is defined

a = torch.tensor([100.0])
b = torch.tensor([200.0])

beta_dist = beta.Beta(a, b)
n = 100000
n_groups = n // BATCH_SIZE

# The Beta-Binomial model draws ONE success probability per group of trials.
# Drawing a fresh rho for every single trial makes the trials iid
# Bernoulli(a / (a + b)), so the group sums carry no information about a + b
# and (a, b) cannot be recovered.
rho = beta_dist.sample([n_groups]).repeat_interleave(BATCH_SIZE, dim=0)
bernoulli_dist = bernoulli.Bernoulli(rho)
X = bernoulli_dist.sample()

# Successes per group, computed once rather than re-slicing X every epoch.
group_sums = X[: n_groups * BATCH_SIZE].reshape(n_groups, BATCH_SIZE).sum(dim=1)
successes = [int(s) for s in group_sums]

w_a = torch.tensor([1.0], requires_grad=True)
w_b = torch.tensor([4.0], requires_grad=True)

optimizer = torch.optim.SGD([w_a, w_b], lr=LEARNING_RATE)

for epoch in range(N_EPOCHS):
    for s in successes:
        optimizer.zero_grad()
        loss = -new_log_marginal(w_a, w_b, s, BATCH_SIZE)
        loss.backward()
        optimizer.step()
        # Projected SGD: a plain step can move a or b to <= 0, after which the
        # log terms produce NaN and the fit silently breaks.
        with torch.no_grad():
            w_a.clamp_(min=MIN_PARAM)
            w_b.clamp_(min=MIN_PARAM)

    print(epoch, w_a.item(), w_b.item())


In [None]:
from tqdm import tqdm
# Heat map of the per-trial log marginal likelihood (new_log_marginal) over a
# grid of (a, b), for fixed data: x = 90 successes out of n = 100 trials.
import matplotlib.pyplot as plt

HEATMAP_X, HEATMAP_N = 90, 100

a_vals = torch.linspace(0, 2000, 100)
b_vals = torch.linspace(0, 2000, 100)

log_marginal_vals = torch.zeros((len(a_vals), len(b_vals)))

for i, a_val in enumerate(tqdm(a_vals)):
    for j, b_val in enumerate(b_vals):
        # Divide by n to report the average log likelihood per trial.
        log_marginal_vals[i, j] = new_log_marginal(a_val, b_val, HEATMAP_X, HEATMAP_N) / HEATMAP_N

# Rows index a and columns index b, so with origin='lower' a runs up the y axis
# and b along the x axis. The extent must match the real grid range; it was
# hard-coded to [0, 10] for a grid spanning [0, 2000].
# NOTE: a = 0 / b = 0 lie outside the Beta domain (log(0) terms), so that row
# and column are not finite.
plt.imshow(
    log_marginal_vals,
    extent=[b_vals[0].item(), b_vals[-1].item(), a_vals[0].item(), a_vals[-1].item()],
    origin='lower',
    aspect='auto',
)
plt.colorbar(label='log marginal likelihood per trial')

plt.xlabel('b')
plt.ylabel('a')

plt.show()


In [None]:
# Log marginal likelihood (new_log_marginal) as a function of the number of
# trials n, with a = 10, b = 10 and x = 1 success held fixed.

N_SWEEP_A, N_SWEEP_B, N_SWEEP_X = 10, 10, 1

n_vals = torch.linspace(10, 1000, 10)

log_marginal_vals = torch.zeros((len(n_vals), ))

for i, n_val in enumerate(tqdm(n_vals)):
    # round() rather than int() so a float just below an integer is not truncated.
    log_marginal_vals[i] = new_log_marginal(N_SWEEP_A, N_SWEEP_B, N_SWEEP_X, round(n_val.item()))

# Plot against the actual n values; plotting the bare array put the index
# 0..9 on the x axis despite the 'n' label.
plt.plot(n_vals, log_marginal_vals, marker='o')
plt.xlabel('n')
plt.ylabel('log marginal likelihood')

plt.show()

In [None]:
# 3D scatter of new_log_marginal over a grid of a, b and x, with n = 100.
# Each point sits at (a, b, x) and is coloured by its log marginal value.

from mpl_toolkits.mplot3d import Axes3D
import numpy as np

SCATTER_N = 100

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

a_vals = torch.linspace(0, 100, 10)
b_vals = torch.linspace(0, 100, 10)

x_vals = torch.linspace(0, 100, 10)

log_marginal_vals = torch.zeros((len(a_vals), len(b_vals), len(x_vals)))

for i, a_val in enumerate(tqdm(a_vals)):
    for j, b_val in enumerate(b_vals):
        for k, x_val in enumerate(x_vals):
            log_marginal_vals[i, j, k] = new_log_marginal(a_val, b_val, int(x_val), SCATTER_N)

# indexing='ij' gives a_grid[i, j, k] == a_vals[i] (and likewise for b, x),
# matching the [i, j, k] layout of log_marginal_vals. The default 'xy'
# indexing swaps the first two axes, so colours landed on the wrong (a, b)
# points. Separate grid names also avoid overwriting the simulated data X.
a_grid, b_grid, x_grid = np.meshgrid(
    a_vals.numpy(), b_vals.numpy(), x_vals.numpy(), indexing='ij'
)

# Scatter once and reuse the handle for the colour bar; calling scatter a
# second time drew every point twice.
points = ax.scatter(
    a_grid.ravel(), b_grid.ravel(), x_grid.ravel(),
    c=log_marginal_vals.numpy().ravel(),
)

ax.set_xlabel('a')
ax.set_ylabel('b')
ax.set_zlabel('x')

cbar = fig.colorbar(points, ax=ax, label='log marginal likelihood')

plt.show()




In [None]:
from tqdm import tqdm
# Log marginal likelihood (new_log_marginal) as a function of a, with b = 1
# held fixed and data x = 1 success out of n = 2 trials.
import matplotlib.pyplot as plt

A_SWEEP_B, A_SWEEP_X, A_SWEEP_N = 1, 1, 2

a_vals = torch.linspace(0, 10, 100)

log_marginal_vals = torch.zeros((len(a_vals), ))

for i, a_val in enumerate(tqdm(a_vals)):
    log_marginal_vals[i] = new_log_marginal(a_val, A_SWEEP_B, A_SWEEP_X, A_SWEEP_N)

# Plot against the actual a values; plotting the bare array put the index
# 0..99 on the x axis despite the 'a' label.
plt.plot(a_vals, log_marginal_vals)
plt.xlabel('a')
plt.ylabel('log marginal likelihood')

plt.show()



In [None]:

# Heat map of the Stirling-approximation log marginal (old_log_marginal) over
# a, b in (0, 10], for data x = 150 successes out of n = 200 trials.
OLD_X, OLD_N = 150, 200
OLD_GRID_SIZE = 1000

# Use this cell's own grid instead of whatever a_vals / b_vals earlier cells
# left behind. On a top-to-bottom run those held 100 and 10 points, so only a
# 100x10 corner of the 1000x1000 array was filled and the rest stayed 0.
# Start just above 0 because old_log_marginal takes log(a) and log(b).
old_a_vals = torch.linspace(0.01, 10, OLD_GRID_SIZE)
old_b_vals = torch.linspace(0.01, 10, OLD_GRID_SIZE)

# old_log_marginal is purely elementwise, so evaluate the whole grid at once by
# broadcasting a column of a values against a row of b values (rows = a).
_log_marginal_vals = old_log_marginal(old_a_vals[:, None], old_b_vals[None, :], OLD_X, OLD_N)

plt.imshow(
    _log_marginal_vals,
    extent=[old_b_vals[0].item(), old_b_vals[-1].item(), old_a_vals[0].item(), old_a_vals[-1].item()],
    origin='lower',
    aspect='auto',
)
plt.colorbar(label='approx. log marginal likelihood')
plt.xlabel('b')
plt.ylabel('a')
plt.show()


In [None]:
def torch_binom(n, k):
    """Binomial coefficient C(n, k) computed elementwise via log-gamma.

    ``n`` and ``k`` are tensors (broadcast together). Entries where ``n < k``
    are returned as 0. For those entries ``n`` and ``k`` are first zeroed, so
    lgamma is evaluated at 1 instead of near a pole; this keeps the value and
    its gradient finite. The comparison uses detached copies, so the mask
    itself is not part of the autograd graph.
    """
    valid = n.detach() >= k.detach()
    safe_n = n * valid
    safe_k = k * valid
    log_coeff = (
        torch.lgamma(safe_n + 1)
        - torch.lgamma(safe_n - safe_k + 1)
        - torch.lgamma(safe_k + 1)
    )
    return valid * torch.exp(log_coeff)