In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('/Users/rupertmenneer/Documents/git/bayesian_flow/')
from discretised.bfn_discretised_data import DiscretisedBimodalData
from discretised.bfn_discretised import BayesianFlowNetworkDiscretised
from bfn.models import SimpleNeuralNetworkDiscretised
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD
import matplotlib.pyplot as plt
import torch
torch.set_printoptions(precision=5, sci_mode=False)
# Seed torch's global RNG before anything draws from it. DataLoader shuffling and
# network weight init both use it; presumably the dataset sampling does too — TODO confirm.
torch.manual_seed(0)
dataset = DiscretisedBimodalData(n=1000, k=32)
dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
batch = next(iter(dataloader))
bfn_model = BayesianFlowNetworkDiscretised(SimpleNeuralNetworkDiscretised())
# Smoke test: a single loss evaluation on one batch must be finite before we train.
loss = bfn_model.continuous_time_loss_for_discretised_data(batch)
assert torch.isfinite(loss).all(), f'non-finite loss on first batch: {loss}'


In [None]:
batch = next(iter(dataloader))
print(batch.shape)

# Empirical distribution of one training batch (32 bins to match k=32 in the dataset).
plt.hist(batch.numpy(), bins=32, density=True)
plt.xlabel('value')
plt.ylabel('density')
plt.title('One training batch');


In [None]:

# Train a fresh discretised BFN; re-running this cell restarts training from scratch.
bfn_model = BayesianFlowNetworkDiscretised(SimpleNeuralNetworkDiscretised())
optim = Adam(bfn_model.parameters(), lr=3e-6, betas=(0.9, 0.98), weight_decay=0.01)
# optim = SGD(bfn_model.parameters(), lr=0.001)

epochs = 5000
losses = []  # one entry per optimisation step (batch), not per epoch
n_batches_track = 100  # NOTE: despite the name, this is the logging interval in *epochs*
for i in range(epochs):
    epoch_losses = []
    for batch in dataloader:
        optim.zero_grad()
        loss = bfn_model.continuous_time_loss_for_discretised_data(batch)
        loss.backward()
        optim.step()
        epoch_losses.append(loss.item())
    losses.extend(epoch_losses)
    if i % n_batches_track == 0:
        # Report the epoch mean: the last batch alone is noisy, and the final batch
        # of an epoch may be smaller than the others.
        mean_loss = sum(epoch_losses) / max(len(epoch_losses), 1)
        print(f'Epoch {i+1}/{epochs}, Loss: {mean_loss}')



In [None]:
# `losses` receives one value per optimisation step (batch) in the training loop,
# so the x-axis counts steps, not epochs.
plt.plot(losses)
plt.xlabel('training step (batch)')
plt.ylabel('loss')
plt.title('Training loss');

# Sample generation

In [None]:
# Separate loader for a reference batch of true samples. Reusing the name
# `dataloader` here would silently change the training cell's batch size
# (256 -> 64) if training were re-run afterwards.
reference_loader = DataLoader(dataset, batch_size=64, shuffle=True)
batch = next(iter(reference_loader))
batch.shape

In [None]:
# Draw as many model samples as there are reference samples, so both histograms
# are built from the same count (DataLoader stacks the batch along dim 0).
samples, priors = bfn_model.sample_generation_for_discretised_data(n_steps=50, bs=len(batch))
samples = samples.to(torch.float32)
sample_values = samples.detach().numpy()
true_values = batch.numpy()

# Shared bin edges: two independent `bins=64` calls would each derive edges from
# their own min/max, so the bars would not line up for comparison.
value_min = float(min(sample_values.min(), true_values.min()))
value_max = float(max(sample_values.max(), true_values.max()))
if value_max == value_min:  # degenerate case: give the bins a non-zero width
    value_max = value_min + 1.0
bin_edges = torch.linspace(value_min, value_max, 65).numpy()

plt.hist(sample_values, alpha=0.8, bins=bin_edges, label='BFN Samples', color='orange')
plt.hist(true_values, bins=bin_edges, alpha=0.5, label='True Samples')
plt.xlabel('value')
plt.ylabel('count')
plt.title('Samples from Discretised BFN model')
plt.legend();

In [None]:
# Overlay the first 10 entries of priors[:, :, 0, :]. Index 0 on axis 2 presumably
# selects mu (per the original 'Prior mu' label) — TODO confirm against
# sample_generation_for_discretised_data. A title replaces the per-line label, which
# was never displayed because no legend is drawn.
plt.plot(priors.detach().numpy()[:10, :, 0, :].squeeze().T)
plt.title('Prior mu')
plt.ylabel('mu');
# plt.plot(priors.detach().numpy()[0, :, 1, :].squeeze(), label='Prior precision')

In [None]:
 ls ../bfn_github/

In [None]:

from bfn_github.bfn import BayesianFlowNetwork
import torch
# network should learn:
# when x0 = 0, x1 = 1
# when x0 = 1, x1 = 0
def get_datapoint(batch=128, device='cpu'):
    """Sample a toy batch whose two columns are always complementary.

    Column 0 is a uniformly random 0/1 value and column 1 is its negation, so
    every row is either (0, 1) or (1, 0).

    Args:
        batch: number of rows to draw.
        device: torch device on which the tensor is created.

    Returns:
        int64 tensor of shape (batch, 2).
    """
    first = torch.randint(0, 2, size=(batch,), dtype=torch.bool, device=device)
    pair = torch.stack([first, ~first])  # (2, batch), bool
    return pair.long().T  # (batch, 2), int64

X = get_datapoint()  # (B, D=2) with K=2 classes
print(X.shape)
# Sanity-check the rule the network should learn: in every row the two entries
# are complementary (exactly one 0 and one 1).
assert X.shape[1] == 2 and bool((X.sum(dim=1) == 1).all()), 'rows must be complementary'
# plt.title("Dataset")
# plt.scatter(X[:, 0], X[:, 1]);
# plt.grid()

from torch.optim import AdamW
from tqdm.auto import tqdm

# Toy run of the reference BFN implementation on the complementary-pair data.
# The toy_* names keep this cell from overwriting `optim` and `losses` from the
# discretised-BFN training cell; the training-loss plot reads `losses`.
bfn = BayesianFlowNetwork()

toy_optim = AdamW(bfn.parameters(), lr=1e-2)

n_toy_steps = 1000
toy_losses = []
for _ in tqdm(range(n_toy_steps), desc='toy BFN'):
    toy_optim.zero_grad()
    X = get_datapoint(device='cpu')
    loss = bfn.process(X)
    loss.backward()
    toy_optim.step()
    toy_losses.append(loss.item())

bfn.sample(16)

In [None]:
# Scratch check of one_hot on a (1, 2) index tensor with 3 classes. The value is
# discarded: this is not the last expression of the cell, so nothing is displayed.
torch.nn.functional.one_hot(torch.tensor([2, 1]).unsqueeze(0), num_classes=3)
def vectorised_discretised_cdf(mu, sigma, bounds):
    """Clipped Gaussian CDF of N(mu, sigma^2) evaluated at every bound.

    This is the discretised-output CDF G from Bayesian Flow Networks:
    G(x) = 0 for x <= -1, 1 for x >= 1, and the Gaussian CDF otherwise. The
    clipping is inclusive, so when the bin edges include exactly -1 and 1, all
    probability mass outside [-1, 1] is assigned to the two edge bins.

    Args:
        mu: tensor of means, any shape (...), e.g. (B, D).
        sigma: tensor of standard deviations, same shape as ``mu``.
        bounds: 1-D tensor of K evaluation points (bin edges).

    Returns:
        Tensor of shape (..., K) holding G at each bound for each (mu, sigma).
    """
    # Broadcast (..., 1) against (K,) -> (..., K).
    z = (bounds - mu.unsqueeze(-1)) / (sigma.unsqueeze(-1) * 2.0 ** 0.5)
    cdf = 0.5 * (1 + torch.erf(z))
    # Inclusive clipping (<= / >=); the mask (K,) broadcasts over leading dims.
    cdf = cdf.masked_fill(bounds <= -1, 0.0)
    cdf = cdf.masked_fill(bounds >= 1, 1.0)
    return cdf

# Local seeded generator: the demo is reproducible without resetting torch's global RNG.
demo_gen = torch.Generator().manual_seed(0)
mu = torch.randn(2, 2, generator=demo_gen)
sigma = torch.ones(2, 2)
bounds = torch.linspace(-1.1, 1.1, 4)

cdf = vectorised_discretised_cdf(mu, sigma, bounds)
# A CDF evaluated at increasing bounds must be non-decreasing.
assert (torch.diff(cdf, dim=-1) >= 0).all()

print(cdf)