In [3]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('/Users/rupertmenneer/Documents/git/bayesian_flow/')
from discretised.bfn_discretised_data import DiscretisedBimodalData
from discretised.bfn_discretised import BayesianFlowNetworkDiscretised
from bfn.models import SimpleNeuralNetworkDiscretised
from torch.utils.data import DataLoader
from torch.optim import AdamW
import matplotlib.pyplot as plt
import torch

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [67]:
import torch
# original 
def discretised_cdf(mu, sigma, x):
    if x < -1:
        return 0
    if x > 1:
        return 1
    else:
        return 0.5 * (1 + torch.erf((x-mu) / ( sigma*torch.sqrt( torch.tensor(2.0) ) ) ) )
    

def vectorised_cdf(mu, sigma, x):

    # ensure shapes align for correct broadcasting
    mu = mu.unsqueeze(-1)  # Shape: [B, D, 1]
    sigma = sigma.unsqueeze(-1)  # Shape: [B, D, 1]
    x = x.unsqueeze(0).unsqueeze(0)  # Shape: [1, 1, K]
    assert mu.dim() == sigma.dim() == x.dim()

    cdf_func = 0.5 * (1 + torch.erf((x - mu) / (sigma * torch.sqrt(torch.tensor(2.0)))))

    # Apply conditions directly without squeezing, using broadcasting
    lower_mask = x < -1
    upper_mask = x > 1

    # Apply masks
    cdf_func = torch.where(lower_mask, torch.zeros_like(cdf_func), cdf_func)
    cdf_func = torch.where(upper_mask, torch.ones_like(cdf_func), cdf_func)

    return cdf_func


batch = 8
k = 4
d = 1

upper_bounds = torch.linspace(-1.1, 1.1, k)
mu = torch.randn((batch, d))
sigma = torch.randn((batch, d))

# Correct the loop variable shadowing issue
result_discretised = torch.zeros((batch, d, k))
for i in range(batch):
    for j in range(d):
        for l in range(k):  # Use a different variable here to avoid shadowing
            result_discretised[i, j, l] = discretised_cdf(mu[i, j], sigma[i, j], upper_bounds[l])

result_vectorised = vectorised_cdf(mu, sigma, upper_bounds)

print(result_vectorised.shape, result_discretised.shape)
torch.allclose(result_discretised, result_vectorised)



torch.Size([8, 1, 4]) torch.Size([8, 1, 4])


True

In [63]:
# original 
def discretised_cdf(mu, sigma, x):
    if x < -1:
        return 0
    if x > 1:
        return 1
    else:
        return 0.5 * (1 + torch.erf((x-mu) / ( sigma*torch.sqrt( torch.tensor(2.0) ) ) ) )

def vectorised_cdf_fixed(mu, sigma, x):
    # ensure shapes align for correct broadcasting
    mu = mu.unsqueeze(-1)  # Shape: [B, D, 1]
    sigma = sigma.unsqueeze(-1)  # Shape: [B, D, 1]
    x_expanded = x.unsqueeze(0).unsqueeze(0)  # Shape: [1, 1, K]

    cdf_func = 0.5 * (1 + torch.erf((x_expanded - mu) / (sigma * torch.sqrt(torch.tensor(2.0)))))

    # Apply conditions directly without squeezing, using broadcasting
    lower_mask = x_expanded < -1
    upper_mask = x_expanded > 1

    # Apply masks
    cdf_func = torch.where(lower_mask, torch.zeros_like(cdf_func), cdf_func)
    cdf_func = torch.where(upper_mask, torch.ones_like(cdf_func), cdf_func)

    return cdf_func

batch = 8
k = 4
d = 1

upper_bounds = torch.linspace(-1.1, 1.1, k)
mu = torch.randn((batch, d))
sigma = torch.randn((batch, d))

# Correct the loop variable shadowing issue
result_discretised_fixed = torch.zeros((batch, d, k))
for i in range(batch):
    for j in range(d):
        for l in range(k):  # Use a different variable here to avoid shadowing
            result_discretised_fixed[i, j, l] = discretised_cdf(mu[i, j], sigma[i, j], upper_bounds[l])

result_vectorised_fixed = vectorised_cdf_fixed(mu, sigma, upper_bounds)

print(torch.allclose(result_discretised_fixed, result_vectorised_fixed, atol=1e-6))


True


In [4]:
dataset = DiscretisedBimodalData(k=16)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

bfn_model = BayesianFlowNetworkDiscretised(SimpleNeuralNetworkDiscretised())
optim = AdamW(bfn_model.parameters(), lr=3e-3, betas=(0.9, 0.98), weight_decay=0.01)

epochs = 1000
losses = []
for i in range(epochs):
    for _, batch in enumerate(dataloader):
        optim.zero_grad()
        loss = bfn_model.continuous_time_loss_for_discretised_data(batch)
        loss.backward()
        optim.step()
        losses.append(loss.item())
    print(f'Epoch {i+1}/{epochs}, Loss: {loss.item()}')

plt.plot(losses)

Epoch 1/1000, Loss: 2.725398302078247
Epoch 2/1000, Loss: 2.975722074508667
Epoch 3/1000, Loss: 2.170651912689209
Epoch 4/1000, Loss: 1.492960810661316
Epoch 5/1000, Loss: 1.0062901973724365
Epoch 6/1000, Loss: 0.8573547005653381
Epoch 7/1000, Loss: 1.0479390621185303
Epoch 8/1000, Loss: 0.6261180639266968
Epoch 9/1000, Loss: 0.6320828199386597
Epoch 10/1000, Loss: 0.9604189395904541


KeyboardInterrupt: 

This example is 1 DIMENSIONAL (D), with 16 BINS (K)

In [None]:
dataset = DiscretisedBimodalData(k=16)
dataloader = DataLoader(dataset, batch_size=1024, shuffle=True)
batch = next(iter(dataloader))
print(batch.shape)
# plt.hist(batch.numpy(), bins=16)

# Sample generation

In [None]:
samples = bfn_model.sample_generation_for_discretised_data(n_steps=20, bs=1024).to(torch.float32)
plt.hist(samples.detach().numpy(), bins=16)
plt.hist(batch.numpy(), bins=16, alpha=0.1)

In [None]:
 ls ../bfn_github/

In [None]:

from bfn_github.bfn import BayesianFlowNetwork
import torch
# network should learn:
# when x0 = 0, x1 = 1
# when x0 = 1, x1 = 0
def get_datapoint(batch=128, device='cpu'):
    x0 = torch.randint(0, 2, size=(batch,), dtype=torch.bool, device=device)
    x1 = ~x0

    X = torch.stack([x0, x1], dim=0)
    return X.long().transpose(0, 1)

X = get_datapoint()  # (B, D=2) with K=2 classes 
print(X.shape)
# plt.title("Dataset")
# plt.scatter(X[:, 0], X[:, 1]);
# plt.grid()

from torch.optim import AdamW
from tqdm.auto import tqdm

bfn = BayesianFlowNetwork()

optim = AdamW(bfn.parameters(), lr=1e-2)


n = 1000
losses = []
for i in tqdm(range(n)):
    optim.zero_grad()
    X = get_datapoint(device='cpu')
    loss = bfn.process(X)
    loss.backward()
    optim.step()
    losses.append(loss.item())

bfn.sample(16)

In [None]:
torch.nn.functional.one_hot(torch.tensor([2, 1]).unsqueeze(0), num_classes=3)
def vectorised_discretised_cdf(mu, sigma, bounds):
    # input is mu, sigma -> B x D, and bounds -> K
    lower_mask = bounds < -1
    upper_mask = bounds > 1

    # output is B x D x K
    result = torch.zeros(mu.shape[0], mu.shape[1], bounds.shape[0])
    result = 0.5 * (1 + torch.erf((bounds - mu.unsqueeze(-1)) / (sigma.unsqueeze(-1) * torch.sqrt(torch.tensor(2.0)))))
    # clip result depending on boundsa
    result[:, :, lower_mask] = 0
    result[:, :, upper_mask] = 1

    return result

mu = torch.randn(2, 2)
sigma = torch.ones(2, 2)
bounds = torch.linspace(-1.1, 1.1, 4)

cdf = vectorised_discretised_cdf(mu, sigma, bounds)

print(cdf)