In [None]:
import os

# The project-local imports below (`datasets.utils`, `discretised.trainer`) are presumably
# resolved relative to the repository root, so the working directory is switched there
# *before* they are imported.
# The checkout location can be overridden with the BAYESIAN_FLOW_ROOT environment variable;
# when it is unset the original hard-coded path is used, so existing behaviour is unchanged.
PROJECT_ROOT = os.environ.get(
    'BAYESIAN_FLOW_ROOT',
    '/Users/rupertmenneer/Documents/git/bayesian_flow/',
)
os.chdir(PROJECT_ROOT)

import torch
from datasets.utils import plot_tensor_images
from discretised.trainer import DiscretisedBFNTrainer

# 150e3a3656bc3e6c76366ee98da5b0fd9f7c16ea
# NOTE(review): the hash above looks like a commit / run identifier — confirm what it pins.
# NOTE(review): wandb_project_name=None presumably disables experiment logging — verify in the trainer.
trainer = DiscretisedBFNTrainer(wandb_project_name=None)

# Expose the pieces the later cells read as notebook-level names.
bfn_model = trainer.bfn_model
test_dls = trainer.test_dls


In [None]:
import math
import torch
from bfn_discretised import safe_log
from datasets.utils import float_to_idx, idx_to_float, quantize

class DiscretizedCtsDistribution:
    """Discretised view of a continuous distribution over ``num_bins`` equal-width bins.

    The bin width is ``2 / num_bins`` and the un-clipped branch of ``prob`` evaluates the
    cdf at -1 and 1, i.e. the bins tile an interval of length 2.
    NOTE(review): the exact bin layout comes from ``float_to_idx`` / ``idx_to_float``,
    which are defined outside this cell — presumably bins tile [-1, 1]; confirm there.
    """

    def __init__(self, cts_dist, num_bins, device, batch_dims, clip=True, min_prob=1e-5):
        """Store the wrapped distribution and pre-compute bin geometry.

        Args:
            cts_dist: Object exposing ``cdf`` and ``log_prob`` (e.g. a torch distribution).
            num_bins: Number of bins.
            device: Accepted for signature compatibility; not stored or used by this class.
            batch_dims: Stored as-is; not read by the methods defined here.
            clip: Selects the clipped (True) or renormalised (False) branch of ``prob``.
            min_prob: Threshold below which probabilities are treated as degenerate.
        """
        width = 2.0 / num_bins

        self.num_bins = num_bins
        self.bin_width = width
        self.half_bin_width = width / 2.0
        self.log_bin_width = math.log(width)

        self.cts_dist = cts_dist
        self.batch_dims = batch_dims
        self.clip = clip
        self.min_prob = min_prob

    def prob(self, x):
        """Return the probability mass of the bin that each element of ``x`` maps to."""
        bin_idx = float_to_idx(x, self.num_bins)
        bin_centre = idx_to_float(bin_idx, self.num_bins)

        # cdf at the lower / upper edge of each selected bin.
        lower = self.cts_dist.cdf(bin_centre - self.half_bin_width)
        upper = self.cts_dist.cdf(bin_centre + self.half_bin_width)

        if not self.clip:
            # Renormalise by the mass the continuous distribution places between -1 and 1.
            support_lo = self.cts_dist.cdf(torch.zeros_like(bin_centre) - 1)
            support_hi = self.cts_dist.cdf(torch.ones_like(bin_centre))
            support_mass = support_hi - support_lo
            degenerate = support_mass < self.min_prob
            # Swap in a divisor of 1 where the mass is (near) zero to avoid dividing by ~0.
            support_mass = torch.where(degenerate, (support_mass * 0) + 1, support_mass)
            renormalised = (upper - lower) / support_mass
            # Degenerate entries fall back to a uniform 1 / num_bins.
            uniform = (renormalised * 0) + (1 / self.num_bins)
            return torch.where(degenerate, uniform, renormalised)

        # Clipped variant: the first bin absorbs all mass below it, the last bin all mass above.
        lower = torch.where(bin_idx <= 0, torch.zeros_like(bin_centre), lower)
        upper = torch.where(bin_idx >= (self.num_bins - 1), torch.ones_like(bin_centre), upper)
        return upper - lower

    def log_prob(self, x):
        """Return the log of ``prob(x)``, with a density-based fallback for tiny masses.

        Where the bin mass is below ``min_prob`` the result is the continuous log-density
        at the quantised value plus ``log(bin_width)``; elsewhere it is ``safe_log`` of the mass.
        """
        bin_mass = self.prob(x)
        density_fallback = self.cts_dist.log_prob(quantize(x, self.num_bins)) + self.log_bin_width
        log_mass = safe_log(bin_mass)
        return torch.where(bin_mass < self.min_prob, density_fallback, log_mass)


In [None]:
import torch

from discretised.bfn_discretised import right_pad_dims_to

def reconstruction_loss(data, mean):
    """Negative log-probability of ``data`` under the discretised final-step Normal.

    Builds ``Normal(mean, bfn_model.sigma_one)``, wraps it in a
    ``DiscretizedCtsDistribution`` with ``bfn_model.k`` bins and evaluates the negative
    log-probability of ``data`` flattened from dim 1 onwards.

    Relies on the notebook-level ``bfn_model`` and ``DiscretizedCtsDistribution``.

    Args:
        data: Target tensor with at least two dimensions; flattened to ``[B, -1]``.
            NOTE(review): presumably already scaled to the range the discretiser expects — confirm.
        mean: Predicted mean of the Normal; must broadcast against the flattened data.

    Returns:
        Tensor of negative log-probabilities (not reduced); its shape follows the
        broadcast of the flattened data against ``mean``.
    """
    flat_data = data.flatten(start_dim=1)
    final_dist = torch.distributions.Normal(mean, bfn_model.sigma_one)
    discretised_dist = DiscretizedCtsDistribution(
        final_dist,
        bfn_model.k,
        # Take the device straight from the input instead of allocating a throwaway
        # ones tensor whose only purpose was to expose `.device`.
        device=data.device,
        # Every dim of `mean` except the leading one is passed as a batch dim.
        batch_dims=mean.ndim - 1,
    )
    return -discretised_dist.log_prob(flat_data)

# Reconstruction loss over the test loader: one scalar per batch, then the mean of those.
# NOTE(review): this is an unweighted mean of per-batch means; if the last batch is
# smaller it is weighted the same as full batches — confirm that is intended.
recon_losses = []

# Pure evaluation — no autograd graph is needed, so do not build one.
with torch.no_grad():
    for test_batch in test_dls:
        bs = test_batch.shape[0]
        test_batch = test_batch.view(bs, -1)  # flatten each example -> [B, D]

        # Time value 1.0 for the batch, padded with trailing dims to match the data.
        t = right_pad_dims_to(bfn_model.get_time_at_t(1., bs=bs), test_batch)
        gamma = right_pad_dims_to(bfn_model.get_gamma_t(t), test_batch)

        # Noisy sender sample: Normal centred on gamma * data with variance gamma * (1 - gamma).
        std = torch.sqrt(gamma * (1 - gamma))
        sender_mean = test_batch * gamma
        sender_mu_sample = bfn_model.get_normal_sample(sender_mean, std)

        # Expected value of the model's discretised output distribution:
        # sum over the last dim of probability * bin centre.
        output_distribution = bfn_model.discretised_output_distribution(sender_mu_sample, t, gamma=gamma)
        mean = torch.sum(output_distribution * bfn_model.k_centers, dim=-1)

        rec_loss = reconstruction_loss(test_batch, mean).flatten(start_dim=1).mean()
        # `rec_loss` is already a 0-dim tensor without a graph (no_grad); re-wrapping it in
        # torch.tensor(...) only triggered a copy-construct UserWarning on every batch.
        recon_losses.append(rec_loss)

recon_losses = torch.stack(recon_losses)
print(recon_losses.mean())
