In [None]:
%load_ext autoreload
%autoreload 2
import os
import random
import sys

import numpy as np

# Put the repo root on sys.path so the project packages below resolve.
# Set BFN_REPO_ROOT to override the machine-specific default instead of editing this cell.
REPO_ROOT = os.environ.get('BFN_REPO_ROOT', '/Users/rupertmenneer/Documents/git/bayesian_flow/')
sys.path.append(REPO_ROOT)
from datasets.bfn_discretised_toy_data import DiscretisedBimodalData
from discretised.bfn_discretised import BayesianFlowNetworkDiscretised
from models.simple_models import SimpleNeuralNetworkDiscretised
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD, AdamW
import matplotlib.pyplot as plt
import torch
torch.set_printoptions(precision=5, sci_mode=False)

# Seed every RNG the notebook may touch so Restart & Run All reproduces the results.
# torch covers DataLoader shuffling and model init. numpy/random are seeded in case
# DiscretisedBimodalData draws from them (TODO confirm which RNG it uses).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

k = 5  # presumably the number of discretisation bins; shared by the data, the model and the plots
dataset = DiscretisedBimodalData(n=5000, k=k)
dataloader = DataLoader(dataset, batch_size=1024, shuffle=True, drop_last=True)
batch = next(iter(dataloader))

In [None]:
from torch_ema import ExponentialMovingAverage

device = 'cpu'
epochs = 1000
n_batches_track = 100  # log every `n_batches_track` epochs (despite the name, it counts epochs, not batches)
# Anomaly detection re-checks every backward op and slows training drastically.
# Enable it only while hunting NaNs/Infs.
detect_anomaly = False

bfn_model = BayesianFlowNetworkDiscretised(SimpleNeuralNetworkDiscretised(), device=device, k=k).to(device)
optim = AdamW(bfn_model.parameters(), lr=3e-4, betas=(0.9, 0.98), weight_decay=0.01)
# Exponential moving average of the weights; sampling cells use it via `ema.average_parameters()`.
ema = ExponentialMovingAverage(bfn_model.parameters(), decay=0.995)
torch.autograd.set_detect_anomaly(detect_anomaly)

losses = []  # mean training loss of each epoch, as Python floats
for i in range(epochs):
    epoch_losses = []
    for batch in dataloader:
        optim.zero_grad()
        loss = bfn_model.discrete_time_loss_for_discretised_data(batch.to(device))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(bfn_model.parameters(), max_norm=2.0)
        optim.step()
        # Update the moving average with the new parameters from the last optimizer step
        ema.update()
        epoch_losses.append(loss.item())
    # NaN (not ZeroDivisionError) if the loader yielded no batches, matching torch.mean([]).
    epoch_mean = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('nan')
    losses.append(epoch_mean)
    # Also log the final epoch, which the plain modulo check would skip.
    if i % n_batches_track == 0 or i == epochs - 1:
        print(f'Epoch {i+1}/{epochs}, Loss: {epoch_mean:.5f}')



In [None]:
# Histogram of one batch of data. NOTE: `batch` is whichever cell bound it last.
# After the training cell it is the final training batch.
print(batch.shape)
plt.hist(batch.squeeze().cpu().numpy(), bins=k)
plt.xlabel('value')
plt.ylabel('count')
plt.title('One batch of discretised training data');

In [None]:
# Curve of the per-epoch mean loss collected by the training loop.
ax = plt.gca()
ax.plot(losses)
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.set_title('Training loss')

# Sample generation

In [None]:
# Scratch check: a (32, 1) tensor of random integers drawn from [1, 25).
torch.randint(low=1, high=25, size=(32, 1))

In [None]:
# Sanity check with the raw (non-EMA) weights: count distinct values in a real batch and in
# generated samples. If samples are quantised to the k bins, both should be <= k (TODO confirm).
# A separate loader keeps the training `dataloader` (batch 1024, drop_last) intact, and `k`
# comes from the config cell rather than being re-hard-coded here.
bs = 1000
eval_loader = DataLoader(dataset, batch_size=bs, shuffle=True)
batch = next(iter(eval_loader))
print(len(torch.unique(batch)))
# Generation needs no gradients; skip building the autograd graph.
with torch.no_grad():
    samples, priors = bfn_model.sample_generation_for_discretised_data(sample_shape=(bs, 1), n_steps=100)
samples = samples.to(torch.float32)
print(len(torch.unique(samples)))

In [None]:
# Compare samples drawn with the EMA-averaged weights against a fresh batch of real data.
# A separate loader keeps the training `dataloader` intact, and `k` comes from the config cell.
import numpy as np

bs = 256
eval_loader = DataLoader(dataset, batch_size=bs, shuffle=True)
batch = next(iter(eval_loader))
# Swap in the EMA weights for sampling only; generation needs no gradients.
with ema.average_parameters(), torch.no_grad():
    samples, priors = bfn_model.sample_generation_for_discretised_data(sample_shape=(bs, 1), n_steps=100)
samples = samples.to(torch.float32)

bfn_values = samples.detach().cpu().numpy().ravel()
true_values = batch.detach().cpu().numpy().ravel()
# Shared bin edges. With independent `bins=k`, each histogram spans only its own min..max,
# so the overlaid bars would not line up whenever the two sets cover different ranges.
bin_edges = np.histogram_bin_edges(np.concatenate([bfn_values, true_values]), bins=k)
plt.hist(bfn_values, bins=bin_edges, alpha=0.8, label='BFN Samples', color='orange')
plt.hist(true_values, bins=bin_edges, alpha=0.5, label='True Samples')
plt.xlabel('value')
plt.ylabel('count')
plt.title('Samples from Discretised BFN model')
plt.legend();

In [None]:
# Plot the trajectories of the prior mean, one faint line per column of the transposed array.
# Assumes index 0 on the third axis of `priors` selects mu and that rows after `.T` are
# generation steps (as the x-label says). TODO confirm against the sampler's output layout.
mu_trajectories = priors.detach().numpy()[:, :, 0, :].squeeze().T
ax = plt.gca()
ax.plot(mu_trajectories, label='Prior mu', alpha=0.03);
ax.set_xlabel('Generation step')
ax.set_ylabel('Prior mu')
ax.set_title('Mean samples over time from Discretised BFN model')