In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
import bayescfm as bcfm

# 1) Dummy data
class WhiteNoise(Dataset):
    """Synthetic dataset of uniform-noise images paired with random labels.

    All samples are generated once in ``__init__`` from a private
    ``torch.Generator`` seeded with ``seed``, so two instances built with the
    same arguments hold identical tensors.

    Args:
        n: Number of samples to generate.
        shape: ``(C, H, W)`` of each image; must have exactly three entries.
        seed: Seed for the private generator.
    """

    def __init__(self, n=512, shape=(3, 32, 32), seed=0):
        rng = torch.Generator().manual_seed(seed)
        channels, height, width = shape
        # torch.rand draws from [0, 1); rescale so the images live in [-1, 1).
        unit_noise = torch.rand(n, channels, height, width, generator=rng)
        self.x = (unit_noise * 2 - 1).float()
        # Labels are drawn after the images so the RNG stream order stays fixed.
        self.y = torch.randint(0, 10, (n,), generator=rng)

    def __len__(self):
        """Return the number of stored samples."""
        return self.x.shape[0]

    def __getitem__(self, i):
        """Return the ``(image, label)`` pair stored at index ``i``."""
        image = self.x[i]
        label = self.y[i]
        return image, label

# Seed everything up front so a fresh "Restart Kernel -> Run All" reproduces
# the same batches. The global seed covers any code that draws from torch's
# default RNG without an explicit generator.
SEED = 0
torch.manual_seed(SEED)

# Give the loader its own seeded generator so the shuffle order does not
# depend on how much of the global RNG stream earlier cells consumed.
loader = DataLoader(
    WhiteNoise(seed=SEED),
    batch_size=64,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
# 2) Plain CFM training (OT path)
# Fresh network for this experiment: 3 input / 3 output channels to match the
# (3, 32, 32) dummy images, and 10 classes to match the labels in [0, 10).
model = bcfm.UNetCFM(
    in_channels=3,
    out_channels=3,
    model_channels=64,
    channel_mult=(1, 2, 2),
    num_res_blocks=1,
    attn_resolutions=(16,),
    num_heads=4,
    num_classes=10,
    class_dropout_prob=0.1,
)

# NOTE(review): the return value is presumably the EMA copy of `model`
# (going by the variable name) -- confirm against bcfm.train_cfm.
ema_model = bcfm.train_cfm(
    model,
    loader,
    epochs=5,
    lr=1e-3,
    device=device,
    log_every=10,
)

[epoch 2] step      10  loss=1.1652
[epoch 3] step      20  loss=1.1688
[epoch 4] step      30  loss=1.1222
[epoch 5] step      40  loss=0.9456


In [6]:
# 3) Regularized training (autodiff gradient-field penalties)
# Start from a freshly initialised network so this run does not continue from
# the weights trained in the previous cell.
model = bcfm.UNetCFM(
    in_channels=3,
    out_channels=3,
    model_channels=64,
    channel_mult=(1, 2, 2),
    num_res_blocks=1,
    attn_resolutions=(16,),
    num_heads=4,
    num_classes=10,
    class_dropout_prob=0.1,
)

# Same optimisation settings as the plain CFM run, plus the curl / mono
# penalty weights and the probe configuration passed through to bcfm.
ema_reg = bcfm.train_cgm(
    model,
    loader,
    epochs=5,
    lr=1e-3,
    device=device,
    lambda_curl=1e-4,
    lambda_mono=1e-4,
    probes=1,
    pool_factor=None,
    probe_dist="rademacher",
    orthogonalize=True,
    penalty_train_flag=False,
    normalize_curl=False,
    log_every=10,
)

[epoch 2] step      10 total=1.1469  cfm=1.1468  curl=0.0000  mono=1.0478
[epoch 3] step      20 total=0.9570  cfm=0.9569  curl=0.0000  mono=1.1282
[epoch 4] step      30 total=0.9619  cfm=0.9618  curl=0.0000  mono=1.1074
[epoch 5] step      40 total=0.9354  cfm=0.9353  curl=0.0000  mono=1.0598


In [12]:
# 4) Bayesian training (SG-MCMC posterior sampling)
model = bcfm.UNetCFM(
    in_channels=3, out_channels=3, model_channels=64,
    channel_mult=(1, 2, 2), num_res_blocks=1,
    attn_resolutions=(16,), num_heads=4,
    num_classes=10, class_dropout_prob=0.1,
)

# Step budget: len(loader) batches per epoch. With the 512-sample dummy set
# and batch size 64 that is 8 steps per epoch, i.e. 80 steps for 10 epochs
# (the recorded log indeed ends at step 80).
bayes_epochs = 10
sgmcmc_thin = 5
sgmcmc_max_samples = 10
total_steps = bayes_epochs * len(loader)

# The original hard-coded burn-in of 100 steps exceeds the 80-step run, which
# would leave no steps for collecting samples. Reserve just enough steps after
# burn-in for `sgmcmc_max_samples` draws at the chosen thinning interval.
# NOTE(review): assumes burn-in and thinning are counted in optimisation steps
# from the start of training -- confirm against bcfm.train_cgm_bayes.
burnin_steps = max(0, total_steps - sgmcmc_thin * sgmcmc_max_samples)

ema_bayes, posterior = bcfm.train_cgm_bayes(
    model, loader, epochs=bayes_epochs, lr=2e-3, device=device,
    lambda_curl=1e-4, lambda_mono=1e-4,
    probes=1,
    sgmcmc_enable=True, sgmcmc_alg="sghmc",
    sgmcmc_eta=2e-6, sgmcmc_temperature=0.1, sgmcmc_friction=0.2,
    sgmcmc_collect=True,
    sgmcmc_burnin_steps=burnin_steps,
    sgmcmc_thin=sgmcmc_thin,
    sgmcmc_max_samples=sgmcmc_max_samples,
    return_posterior=True,
    log_every=10,
)

[epoch 2] step      10 total=1.4534  cfm=1.4533  curl=0.0002  mono=0.6938  (SG-MCMC sghmc, eta=2.00e-06, T=0.1)
[epoch 3] step      20 total=1.4558  cfm=1.4557  curl=0.0003  mono=0.6935  (SG-MCMC sghmc, eta=2.00e-06, T=0.1)
[epoch 4] step      30 total=1.4667  cfm=1.4666  curl=0.0003  mono=0.6971  (SG-MCMC sghmc, eta=2.00e-06, T=0.1)
[epoch 5] step      40 total=1.4561  cfm=1.4560  curl=0.0002  mono=0.6961  (SG-MCMC sghmc, eta=2.00e-06, T=0.1)
[epoch 7] step      50 total=1.4595  cfm=1.4594  curl=0.0003  mono=0.6938  (SG-MCMC sghmc, eta=2.00e-06, T=0.1)
[epoch 8] step      60 total=1.4695  cfm=1.4695  curl=0.0002  mono=0.6963  (SG-MCMC sghmc, eta=2.00e-06, T=0.1)
[epoch 9] step      70 total=1.4621  cfm=1.4620  curl=0.0002  mono=0.7010  (SG-MCMC sghmc, eta=2.00e-06, T=0.1)
[epoch 10] step      80 total=1.4795  cfm=1.4795  curl=0.0003  mono=0.7039  (SG-MCMC sghmc, eta=2.00e-06, T=0.1)


In [6]:
# 5) Laplace approximation around a trained model
model = bcfm.UNetCFM(
    in_channels=3, out_channels=3, model_channels=64,
    channel_mult=(1, 2, 2), num_res_blocks=1,
    attn_resolutions=(16,), num_heads=4,
    num_classes=10, class_dropout_prob=0.1,
)

# Train before fitting the Laplace approximation: with this call commented
# out, `model` still holds its random initial weights, and the training log
# recorded under this cell could not have been produced by the cell as written.
ema_lap = bcfm.train_cgm(
    model, loader, epochs=5, lr=1e-3, device=device,
    lambda_curl=1e-4,
    lambda_mono=1e-4,
    probes=1,
    pool_factor=None,
    probe_dist="rademacher",
    orthogonalize=True,
    penalty_train_flag=False,
    normalize_curl=False,
    log_every=10,
)

# NOTE(review): the recorded run of this cell ended with
# "AssertionError: call fit() before sampling models" raised after fit() had
# been called. The cause is not visible from this notebook -- check that
# fit() actually populates the posterior state for this configuration, and
# whether the keyword should be `approx` rather than `mode`.
# NOTE(review): `model` is wrapped here as in the original; if the EMA weights
# returned by train_cgm are the intended expansion point, pass `ema_lap` instead.
lap = bcfm.LaplaceCFM(model, mode='last_layer')  # or approx="last_layer"
lap.fit(loader, device=device)
lap_models = lap.posterior_sample(5, device=device, seed=0)

[epoch 2] step      10 total=1.1435  cfm=1.1434  curl=0.0000  mono=0.9787
[epoch 3] step      20 total=0.8632  cfm=0.8631  curl=0.0000  mono=1.1182
[epoch 4] step      30 total=1.0151  cfm=1.0150  curl=0.0000  mono=1.0077
[epoch 5] step      40 total=0.8965  cfm=0.8964  curl=0.0000  mono=1.0472


AssertionError: call fit() before sampling models

In [10]:
# NOTE(review): `v_samp` is not defined in any cell visible in this notebook,
# so this cell relies on leftover kernel state and will raise NameError on a
# fresh "Restart Kernel -> Run All". Add (or restore) the cell that creates it.
v_samp[2].shape

torch.Size([4, 3, 32, 32])