In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
import bayescfm as bcfm

# 1) Dummy data
class WhiteNoise(Dataset):
    def __init__(self, n=512, shape=(3,32,32), seed=0):
        g = torch.Generator().manual_seed(seed)
        C,H,W = shape
        self.x = (torch.rand(n, C, H, W, generator=g)*2-1).float()
        self.y = torch.randint(0, 10, (n,), generator=g)
    def __len__(self): return self.x.size(0)
    def __getitem__(self, i): return self.x[i], self.y[i]

loader = DataLoader(WhiteNoise(), batch_size=64, shuffle=True)

# 2) Model
model = bcfm.UNetCFM(
    in_channels=3, out_channels=3, model_channels=64,
    channel_mult=(1,2,2), num_res_blocks=1,
    attn_resolutions=(16,), num_heads=4,
    num_classes=10, class_dropout_prob=0.1,
)

In [2]:
# 3) Plain CFM training (OT path)
device = "cuda" if torch.cuda.is_available() else "cpu"
ema_model = bcfm.train_cfm(model, loader, epochs=5, lr=1e-3, device=device, log_every=10)

[epoch 2] step      10  loss=1.1652
[epoch 3] step      20  loss=1.1688
[epoch 4] step      30  loss=1.1222
[epoch 5] step      40  loss=0.9456


In [5]:
ema_reg = bcfm.train_cgm(
    model, loader, epochs=5, lr=1e-3, device=device,
    lambda_curl=1e-4, 
    lambda_mono=1e-4,
    probes=1,  
    pool_factor=None,
    probe_dist="rademacher", 
    orthogonalize=True,
    penalty_train_flag=False,
    normalize_curl=False,
    log_every = 10
)

[epoch 2] step      10 total=0.9888  cfm=0.9887  curl=0.0001  mono=1.1447
[epoch 3] step      20 total=0.9517  cfm=0.9516  curl=0.0001  mono=1.0154
[epoch 4] step      30 total=0.9348  cfm=0.9347  curl=0.0002  mono=0.9523
[epoch 5] step      40 total=0.8763  cfm=0.8762  curl=0.0001  mono=0.8572
