In [1]:
from sklearn.datasets import make_swiss_roll, make_moons
from matplotlib import pyplot as plt
from genexp.models import DiffusionModel

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

from genexp.sampling import VPSDE, sample_trajectories_ddpm, sample_trajectories_memoryless, EMDiffusionSampler, DDIMSampler, EulerMaruyamaSampler, MemorylessSampler
from genexp.trainers.adjoint_matching import AMTrainerFlow
from genexp.trainers.genexp import FDCTrainerFlow

from matplotlib.widgets import Button, Slider

2025-12-02 19:16:00.205007: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-02 19:16:00.235337: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-12-02 19:16:00.235380: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-12-02 19:16:00.236481: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-12-02 19:16:00.242791: I tensorflow/core/platform/cpu_feature_guar

In [2]:
class LightningDiffusion(LightningModule):
    """PyTorch Lightning wrapper that trains a ``DiffusionModel`` to predict
    the Gaussian noise ``eps`` added to clean samples.

    Args:
        model: The diffusion model to train. It is called as ``model(x_t, t)``
            and must expose ``model.sde.get_alpha_sigma(t)``.
        lr: Learning rate for the Adam optimizer. Defaults to ``1e-3``, the
            value that was previously hard-coded in ``configure_optimizers``.
    """

    def __init__(self, model: DiffusionModel, lr: float = 1e-3):
        super().__init__()
        self.model = model
        self.lr = lr

    def forward(self, *args, **kwargs):
        """Delegate directly to the wrapped diffusion model."""
        return self.model(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        """Compute the noise-prediction loss for one batch.

        Args:
            batch: A 1-tuple ``(x0,)`` of clean samples, as yielded by a
                ``DataLoader`` over a single-tensor ``TensorDataset``.
            batch_idx: Index of the batch (unused).

        Returns:
            Scalar loss ``0.5 * mean((eps - eps_pred) ** 2)``.
        """
        x0, = batch
        # One diffusion time per sample, drawn uniformly from [0, 1) directly
        # on the data's device (no host-to-device copy every step).
        t = torch.rand(x0.shape[0], device=x0.device)
        # Column vector so the schedule values broadcast over the feature dim.
        t_col = t[:, None]
        alpha, sig = self.model.sde.get_alpha_sigma(t_col)
        # Same shape, dtype and device as the clean samples.
        eps = torch.randn_like(x0)

        # Noised sample. NOTE(review): the sqrt implies get_alpha_sigma returns
        # the signal *variance* and the noise *std* — confirm against VPSDE.
        xt = torch.sqrt(alpha) * x0 + sig * eps

        eps_pred = self(xt, t_col)

        loss = torch.mean((eps - eps_pred) ** 2) / 2.
        self.log('loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters, using the learning rate given to ``__init__``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [3]:
# Build a planar swiss-roll dataset, define the noise-prediction network and
# load pretrained weights. A fixed random_state makes the data reproducible.
sworl, r = make_swiss_roll(n_samples=100000, noise=0.1, random_state=0)

# Keep only coordinates 0 and 2 of the 3-D roll to get a 2-D swirl.
dataset = torch.tensor(sworl[:, [0, 2]], dtype=torch.float32)

# Input width 3 — presumably the 2 data coordinates plus the diffusion time;
# TODO confirm against DiffusionModel.forward.
network = nn.Sequential(
    nn.Linear(3, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 2),
)

# NOTE(review): presumably (beta_min, beta_max) of the VP noise schedule —
# confirm against genexp.sampling.VPSDE.
sde = VPSDE(0.1, 12)

model = DiffusionModel(network, sde)
pl_model = LightningDiffusion(model)

# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine
# (the model is moved to the target device in a later cell); weights_only=True
# refuses to unpickle arbitrary objects from the file.
model.load_state_dict(
    torch.load('swirl_model.pth', map_location='cpu', weights_only=True)
)

<All keys matched successfully>

In [None]:
from omegaconf import OmegaConf
import copy

# Build the adjoint-matching fine-tuning trainer around the pretrained model.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
config = OmegaConf.load('../configs/example_fdc.yaml')
am_config = config.adjoint_matching

# nn.Module.to() moves parameters in place and returns the same module, so the
# sampler below wraps the already-moved model.
model = model.to(device)
sampler = EulerMaruyamaSampler(model, data_shape=(2,), device=device)

# Two independent deep copies of the pretrained model — presumably the model to
# fine-tune and the base/reference model; TODO confirm argument order against
# AMTrainerFlow.
# NOTE(review): if AMTrainerFlow also requires a reward-gradient argument
# (e.g. a `grad_reward...` callable), it must be passed here as well.
fdc_trainer = AMTrainerFlow(config, copy.deepcopy(model), copy.deepcopy(model),
                            device=device, sampler=sampler)

In [None]:
# Outer loop over config.num_md_iterations (presumably mirror-descent steps).
# Each iteration runs several rounds of generate-then-finetune, then calls
# update_base_model() — presumably promoting the fine-tuned model to be the new
# base; TODO confirm against the trainer.
for _ in range(config.num_md_iterations):
    for _ in range(config.adjoint_matching.num_iterations):
        # Distinct name so the swiss-roll tensor `dataset` from the setup cell
        # is not silently overwritten.
        am_dataset = fdc_trainer.generate_dataset()
        fdc_trainer.finetune(am_dataset, steps=config.adjoint_matching.finetune_steps)

    fdc_trainer.update_base_model()