In [1]:
from __future__ import annotations
from typing import Callable, Any, Literal
from pathlib import Path

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import datasets, transforms, utils
from tqdm.auto import tqdm, trange

In [2]:
# This UNET-style prediction model was originally included as part of the Score-based generative modelling tutorial
# by Yang Song et al: https://colab.research.google.com/drive/120kYYBOVa1i0TD85RjlEkFjaWDxSFUx3?usp=sharing


class GaussianFourierProjection(nn.Module):
    """Encode time steps as fixed random Fourier features.

    A vector of ``embed_dim // 2`` frequencies is drawn once from a standard
    normal, multiplied by ``scale`` and stored as a frozen parameter
    (``requires_grad=False``). ``forward`` returns the sines followed by the
    cosines of ``x * W * 2 * pi`` concatenated along the last dimension.
    """

    def __init__(self, embed_dim, scale=30.0):
        super().__init__()
        # One frequency per sine/cosine pair. Wrapping the tensor in a
        # Parameter keeps it in the state dict, while requires_grad=False
        # keeps it out of the optimisation.
        frequencies = torch.randn(embed_dim // 2) * scale
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        # Outer product of the inputs (dim 0) with the frequencies.
        angles = x.unsqueeze(1) * self.W.unsqueeze(0) * 2 * torch.pi
        features = (torch.sin(angles), torch.cos(angles))
        return torch.cat(features, dim=-1)


class Dense(nn.Module):
    """Linear layer whose output gets two trailing singleton dimensions.

    The trailing ``1 x 1`` axes let the result broadcast against convolutional
    feature maps when it is added to them.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        projected = self.dense(x)
        # (..., output_dim) -> (..., output_dim, 1, 1)
        return projected.unsqueeze(-1).unsqueeze(-1)


class ScoreNet(nn.Module):
    """Time-conditioned U-Net for single-channel images.

    Four convolutional encoder stages reduce the spatial resolution and three
    transposed-convolution decoder stages plus a final transposed convolution
    restore it. Every stage adds a projection of the time embedding to its
    feature map before group normalisation and the SiLU activation.
    """

    def __init__(
        self,
        marginal_prob_std: Callable[[torch.Tensor], torch.Tensor] | None,
        channels=(32, 64, 128, 256),
        embed_dim=256,
    ):
        """Build the network.

        Args:
          marginal_prob_std: Optional function of the time input ``t`` whose
            result divides the network output. Pass ``None`` to return the
            output unscaled.
          channels: Number of feature-map channels at each of the four
            resolutions.
          embed_dim: Size of the Gaussian random feature time embedding.
        """
        super().__init__()
        c0, c1, c2, c3 = channels[0], channels[1], channels[2], channels[3]

        # Time embedding: random Fourier features followed by a linear layer.
        self.embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim),
        )

        # Encoder: each stage is conv -> (+ time projection) -> group norm.
        self.conv1 = nn.Conv2d(1, c0, 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim, c0)
        self.gnorm1 = nn.GroupNorm(4, num_channels=c0)
        self.conv2 = nn.Conv2d(c0, c1, 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim, c1)
        self.gnorm2 = nn.GroupNorm(32, num_channels=c1)
        self.conv3 = nn.Conv2d(c1, c2, 3, stride=2, bias=False)
        self.dense3 = Dense(embed_dim, c2)
        self.gnorm3 = nn.GroupNorm(32, num_channels=c2)
        self.conv4 = nn.Conv2d(c2, c3, 3, stride=2, bias=False)
        self.dense4 = Dense(embed_dim, c3)
        self.gnorm4 = nn.GroupNorm(32, num_channels=c3)

        # Decoder: the inputs of tconv3/tconv2/tconv1 are the previous decoder
        # output concatenated with the matching encoder output, hence the
        # doubled input channel counts.
        self.tconv4 = nn.ConvTranspose2d(c3, c2, 3, stride=2, bias=False)
        self.dense5 = Dense(embed_dim, c2)
        self.tgnorm4 = nn.GroupNorm(32, num_channels=c2)
        self.tconv3 = nn.ConvTranspose2d(
            c2 + c2, c1, 3, stride=2, bias=False, output_padding=1
        )
        self.dense6 = Dense(embed_dim, c1)
        self.tgnorm3 = nn.GroupNorm(32, num_channels=c1)
        self.tconv2 = nn.ConvTranspose2d(
            c1 + c1, c0, 3, stride=2, bias=False, output_padding=1
        )
        self.dense7 = Dense(embed_dim, c0)
        self.tgnorm2 = nn.GroupNorm(32, num_channels=c0)
        self.tconv1 = nn.ConvTranspose2d(c0 + c0, 1, 3, stride=1)

        self.act = nn.SiLU()
        self.marginal_prob_std = marginal_prob_std

    def _stage(self, conv, dense, norm, inputs, t_embed):
        """Run one conv -> add time projection -> group norm -> activation stage."""
        features = conv(inputs)
        features += dense(t_embed)
        return self.act(norm(features))

    def forward(self, x, t):
        # Activated time embedding shared by every stage.
        t_embed = self.act(self.embed(t))

        # Encoding path
        h1 = self._stage(self.conv1, self.dense1, self.gnorm1, x, t_embed)
        h2 = self._stage(self.conv2, self.dense2, self.gnorm2, h1, t_embed)
        h3 = self._stage(self.conv3, self.dense3, self.gnorm3, h2, t_embed)
        h4 = self._stage(self.conv4, self.dense4, self.gnorm4, h3, t_embed)

        # Decoding path; skip connections concatenate encoder features on the
        # channel dimension before each transposed convolution.
        h = self._stage(self.tconv4, self.dense5, self.tgnorm4, h4, t_embed)
        h = self._stage(
            self.tconv3, self.dense6, self.tgnorm3, torch.cat([h, h3], dim=1), t_embed
        )
        h = self._stage(
            self.tconv2, self.dense7, self.tgnorm2, torch.cat([h, h2], dim=1), t_embed
        )
        h = self.tconv1(torch.cat([h, h1], dim=1))

        # Normalize output (if given a std function)
        if self.marginal_prob_std is not None:
            h /= self.marginal_prob_std(t)[:, None, None, None]
        return h

In [3]:
# ExponentialMovingAverage implementation as used in pytorch vision
# https://github.com/pytorch/vision/blob/main/references/classification/utils.py#L159

# BSD 3-Clause License

# Copyright (c) Soumith Chintala 2016,
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.

# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.

# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """Exponentially decayed running average of a model's parameters and buffers.

    Each update computes
    ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``.
    The averaging itself is delegated to
    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_,
    which is configured with ``use_buffers=True``.
    """

    def __init__(self, avg_model, decay, avg_device="cpu"):
        def _blend(running, current, _num_averaged):
            # Convex combination; the update counter is not needed.
            return (1 - decay) * current + decay * running

        super().__init__(
            avg_model, device=avg_device, avg_fn=_blend, use_buffers=True
        )

In [4]:
class DDPM(nn.Module):
    def __init__(
        self,
        network,
        max_t: int = 100,
        beta_1: float = 1e-4,
        beta_max_t: float = 2e-2,
        predict_mean_by: Literal["e", "u", "x0"] = "e",
        reduce_variance_by: Literal["low-discrepency", "importance-sampling"]
        | None = None,
    ):
        """
        Initialize Denoising Diffusion Probabilistic Model

        Parameters
        ----------
        network: nn.Module
            The inner neural network used by the diffusion process. Typically a Unet.
        beta_1: float
            beta_t value at t=1
        beta_max_t: [float]
            beta_t value at t=T (last step)
        max_t: int
            The number of diffusion steps.
        predict_mean_by: str
            Which target (epsilon, x_t-1 mean, x_0 mean) the network should predict.
        reduce_variance_by: str | None
            An optional method to reduce the variance of the loss estimate.
        """

        super().__init__()

        # Normalize time input before evaluating neural network
        # Reshape input into image format and normalize time value before sending it to network model
        self._network = network

        # Total number of time steps
        self.max_t = max_t
        self.predict_mean_by = predict_mean_by
        self.reduce_variance_by = reduce_variance_by

        # Registering as buffers to ensure they get transferred to the GPU automatically
        self.register_buffer("beta", torch.linspace(beta_1, beta_max_t, max_t + 1))
        self.register_buffer("alpha", 1 - self.beta)
        self.register_buffer("alpha_bar", self.alpha.cumprod(dim=0))

        if self.reduce_variance_by == "importance-sampling":
            self.imps_warmup = True
            history_length = 10
            self.register_buffer(
                "imps_idx", torch.zeros((self.max_t + 1,), dtype=torch.int64)
            )
            self.register_buffer(
                "imps_hist", torch.zeros((self.max_t + 1, history_length))
            )
            startup_ts = (
                torch.arange(1, self.max_t + 1)
                .expand((history_length, max_t))
                .flatten()
            )
            self.register_buffer(
                "imps_startup_ts", startup_ts[torch.randperm(startup_ts.numel())]
            )
            self.imps_startup_idx = 0

    def network(self, x, t):
        return self._network(
            x.reshape(-1, 1, 28, 28), t.squeeze() / self.max_t
        ).reshape(-1, 28 * 28)

    def forward_diffusion(self, x0, t, epsilon):
        """
        q(x_t | x_0)
        Forward diffusion from an input datapoint x0 to an xt at timestep t, provided a N(0,1) noise sample epsilon.
        Note that we can do this operation in a single step

        Parameters
        ----------
        x0: torch.tensor
            x value at t=0 (an input image)
        t: int
            step index
        epsilon:
            noise sample

        Returns
        -------
        torch.tensor
            image at timestep t
        """

        mean = torch.sqrt(self.alpha_bar[t]) * x0
        std = torch.sqrt(1 - self.alpha_bar[t])

        return mean + std * epsilon

    def reverse_diffusion(self, xt, t, epsilon):
        """
        p(x_{t-1} | x_t)
        Single step in the reverse direction, from x_t (at timestep t) to x_{t-1}, provided a N(0,1) noise sample epsilon.

        Parameters
        ----------
        xt: torch.tensor
            x value at step t
        t: int
            step index
        epsilon:
            noise sample

        Returns
        -------
        torch.tensor
            image at timestep t-1
        """
        if self.predict_mean_by == "e":
            mean = (
                1.0
                / torch.sqrt(self.alpha[t])
                * (
                    xt
                    - (self.beta[t])
                    / torch.sqrt(1 - self.alpha_bar[t])
                    * self.network(xt, t)
                )
            )
        elif self.predict_mean_by == "u":
            mean = self.network(xt, t)
        elif self.predict_mean_by == "x0":
            full_epsilon = xt - self.network(xt, t)
            mean = (1.0 / torch.sqrt(1 - self.beta[t])) * (
                xt - (self.beta[t] / torch.sqrt(1 - self.alpha_bar[t])) * full_epsilon
            )
        else:
            raise ValueError(f"Invalid value: {self.predict_mean_by=}")

        mask = t > 0
        std = torch.where(
            mask,
            torch.sqrt(
                ((1 - self.alpha_bar[t - 1]) / (1 - self.alpha_bar[t])) * self.beta[t]
            ),
            0,
        )

        return mean + std * epsilon

    @torch.no_grad()
    def sample(self, shape):
        """
        Sample from diffusion model (Algorithm 2 in Ho et al, 2020)

        Parameters
        ----------
        shape: tuple
            Specify shape of sampled output. For MNIST: (nsamples, 28*28)

        Returns
        -------
        torch.tensor
            sampled image
        """

        # Sample xT: Gaussian noise
        xt = torch.randn(shape).to(self.beta.device)
        for t in range(self.max_t, 0, -1):
            noise = torch.randn_like(xt) if t > 1 else 0
            t = torch.tensor(t).expand(xt.shape[0], 1).to(self.beta.device)
            xt = self.reverse_diffusion(xt, t, noise)

        return xt

    def elbo_simple(self, x0: torch.Tensor):
        """
        ELBO training objective (Algorithm 1 in Ho et al, 2020)

        Parameters
        ----------
        x0: torch.tensor
            Input image

        Returns
        -------
        float
            ELBO value
        """
        # Sample time step t
        b_size = x0.shape[0]
        if self.reduce_variance_by is None:
            t = torch.randint(1, self.max_t, (b_size, 1)).to(x0.device)
            t_weights = torch.ones_like(t)
        elif self.reduce_variance_by == "low-discrepency":
            # Remove last, since it's equal under modulus. Add extra one to compensate rounding at start and end
            ts_base = torch.linspace(1, self.max_t + 1, b_size + 1)[:-1]
            # Random addition between 0 and the step size of ts_base
            ts_offset = torch.rand_like(ts_base) * ts_base[1]
            # Taking a t value from a random index, yields a marginal distribution that is uniform over the reals [1, max_t+1).
            # Taking the floor, yields a marginal distribution that is unfirom over the integers [1, max_t]
            # Clamp to ensure floating point errors don't cause out-of-bounds
            # Unlike "Variational Diffusion Models", individual positions (e.g. the first t) are not uniformly distributed,
            # this could be sovled with a random permutation, but that is needless work.
            t = (
                (ts_base + ts_offset)
                .floor()
                .int()
                .clamp(1, self.max_t)
                .view((b_size, 1))
                .to(x0.device)
            )
            t_weights = torch.ones_like(t)
        elif self.reduce_variance_by == "importance-sampling":
            if (
                self.imps_warmup
                and self.imps_startup_idx + b_size < self.imps_startup_ts.numel()
            ):
                t = (
                    self.imps_startup_ts[
                        self.imps_startup_idx : self.imps_startup_idx + b_size
                    ]
                    .view((b_size, 1))
                    .to(x0.device)
                )
                t_weights = torch.ones_like(t)
                self.imps_startup_idx += b_size
            elif self.imps_warmup:
                self.imps_warmup = False
                print(f"Warmup complete, just-before counts\n{self.imps_idx}")
                missing_warmup = (
                    self.imps_startup_idx + b_size - self.imps_startup_ts.numel()
                )
                missing_elems = torch.randint(1, self.max_t, (missing_warmup,)).to(x0.device)
                t = torch.cat(
                    (self.imps_startup_ts[self.imps_startup_idx :], missing_elems),
                    dim=0,
                ).view((b_size, 1))
                t_weights = torch.ones_like(t)
            else:
                t = torch.randint(1, self.max_t, (b_size, 1)).to(x0.device)
                weights = torch.sqrt(self.imps_hist.mean(dim=1)).to(x0.device)
                t_weights = weights[t.view(-1)].view_as(t)
        else:
            raise ValueError(f"Invalid value: {self.reduce_variance_by=}")

        # Sample noise
        epsilon = torch.randn_like(x0)

        xt = self.forward_diffusion(x0, t, epsilon)

        if self.predict_mean_by == "e":
            target = epsilon
        elif self.predict_mean_by == "u":
            target = self.forward_diffusion(x0, t - 1, epsilon)
        elif self.predict_mean_by == "x0":
            target = x0
        else:
            raise ValueError(f"Invalid value: {self.predict_mean_by=}")

        unweighted_se = nn.MSELoss(reduction="none")(target, self.network(xt, t)).mean(
            tuple(range(1, len(target.shape)))
        )
        assert unweighted_se.shape == (
            b_size,
        ), f"Strange shape {unweighted_se.shape=} {x0.shape=}"

        if self.reduce_variance_by == "importance-sampling":
            uniq_t, rev_t, count_t = torch.unique(
                t.view(-1), return_inverse=True, return_counts=True
            )
            assert rev_t.shape == (b_size,), f"Strange shape {rev_t=}"
            # Nth occurence (zero-indexed) of the particular value of t
            same_t = torch.eq(rev_t.view((-1, 1)), rev_t.view((1, -1)))
            nth_occurence = torch.triu(same_t).sum(dim=0) - 1
            assert nth_occurence.shape == (b_size,), f"Strange shape {nth_occurence=}"
            hist_idx = (self.imps_idx[t.view(-1)] + nth_occurence) % self.imps_hist.shape[1]
            assert (
                len(self.imps_hist.shape) == 2
            ), f"Strange shape {self.imps_hist.shape=}"
            assert hist_idx.shape == (b_size,), f"Strange shape {hist_idx.shape=}"
            self.imps_hist[t.view(-1), hist_idx] = unweighted_se.detach()**2
            self.imps_idx[uniq_t] += count_t

        return -(unweighted_se * t_weights / t_weights.sum()).mean()

    def loss(self, x0):
        """
        Loss function. Just the negative of the ELBO.
        """
        return -self.elbo_simple(x0)

In [5]:
def train(
    train_model: nn.Module,
    train_optimizer: torch.optim.Optimizer,
    train_scheduler: torch.optim.lr_scheduler.LRScheduler,
    train_dataloader: torch.utils.data.DataLoader[tuple[torch.Tensor, Any]],
    train_epochs: int,
    train_device: torch.device,
    ema: bool = True,
    per_epoch_callback: Callable | None = None,
    double_bar: bool = False,
    extra_desc: str | None = None,
):
    """
    Training loop

    Parameters
    ----------
    train_model: nn.Module
        Pytorch model; must provide a ``loss(x)`` method
    train_optimizer: optim.Optimizer
        Pytorch optimizer to be used for training
    train_scheduler: optim.LRScheduler
        Pytorch learning rate scheduler (stepped after every batch)
    train_dataloader: utils.DataLoader
        Pytorch dataloader yielding ``(input, label)`` pairs; the label is ignored
    train_epochs: int
        Number of epochs to train
    train_device: torch.device
        Pytorch device specification
    ema: Boolean
        Whether to activate Exponential Model Averaging
    per_epoch_callback: function
        Called at the end of every epoch with the averaged model when ``ema``
        is enabled, otherwise with ``train_model``
    double_bar: bool
        Whether to add a progress bar within an epoch
    extra_desc: str
        Extra progress bar description
    """

    if ema:
        ema_global_step_counter = 0
        # The averaged model is updated every ema_steps optimizer steps
        ema_steps = 10
        ema_decay = 0.995
        # Same adjustment as the torchvision classification reference: the
        # per-update weight of the live model, alpha = 1 - decay, is scaled by
        # batch_size * ema_steps / epochs and capped at 1.
        ema_batch_size = train_dataloader.batch_size or 1
        ema_adjust = ema_batch_size * ema_steps / max(train_epochs, 1)
        ema_alpha = min(1.0, (1.0 - ema_decay) * ema_adjust)
        ema_model = ExponentialMovingAverage(
            train_model, avg_device=train_device, decay=1.0 - ema_alpha
        )

    # Setup progress bar
    full_bar = trange(
        train_epochs,
        desc="Training" if not extra_desc else f"Training ({extra_desc})",
        smoothing=0.05,
        unit="epoch",
        position=1,
    )
    for epoch in full_bar:
        # Switch to train mode
        train_model.train()

        epoch_bar = tqdm(
            train_dataloader,
            leave=False,
            position=2,
            desc=f"Epoch {epoch}",
            disable=not double_bar,
        )
        for x, _label in epoch_bar:
            x = x.to(train_device)
            train_optimizer.zero_grad()
            loss = train_model.loss(x)
            loss.backward()
            train_optimizer.step()
            train_scheduler.step()

            # Update progress bar
            epoch_bar.set_postfix(
                loss=f"⠀{loss.item():12.4f}",
                lr=f"{train_scheduler.get_last_lr()[0]:.2E}",
                refresh=False,
            )

            if ema:
                ema_global_step_counter += 1
                if ema_global_step_counter % ema_steps == 0:
                    ema_model.update_parameters(train_model)

        if per_epoch_callback:
            per_epoch_callback(ema_model.module if ema else train_model)

In [6]:
# Parameters
T = 1000  # number of diffusion steps; passed to DDPM as max_t
learning_rate = 1e-3  # learning rate given to the Adam optimizer
epochs = 100  # number of training epochs per model
batch_size = 256  # mini-batch size of the training DataLoader

In [7]:
# Rather than treating MNIST images as discrete objects, as done in Ho et al 2020,
# we here treat them as continuous input data, by dequantizing the pixel values (adding noise to the input data)
# Also note that we map the 0..255 pixel values to [-1, 1), and that we process the 28x28 pixel values as a flattened 784 tensor.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # Dequantize pixel values: go back to the 0..255 integer grid, add
        # U[0, 1) noise and divide by 256 so that the result stays within
        # [0, 1). (Adding rand / 255 to the [0, 1] values would overshoot 1.)
        transforms.Lambda(lambda x: (x * 255 + torch.rand(x.shape)) / 256),
        transforms.Lambda(lambda x: (x - 0.5) * 2.0),  # Map from [0, 1) -> [-1, 1)
        transforms.Lambda(lambda x: x.flatten()),
    ]
)

# Download and transform train dataset
dataloader_train = torch.utils.data.DataLoader(
    datasets.MNIST("./mnist_data", download=True, train=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)

In [8]:
# Select device: prefer CUDA, then Apple's Metal backend (MPS), else the CPU.
# torch.backends.mps.is_available() is used because torch.mps.is_available
# only exists in recent PyTorch releases.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps:0"
else:
    device = "cpu"

print(f"Running on {device}")

Running on mps:0


In [9]:
def reporter(train_model: nn.Module):
    """Per-epoch callback: draw a few samples from the model and show them in a row."""
    n_images = 10

    # Sampling is done in eval mode and without gradient tracking
    train_model.eval()
    with torch.no_grad():
        flat_samples = train_model.sample((n_images, 28 * 28)).cpu()

        # Undo the [-1, 1] scaling of the training data and clip to [0, 1]
        pixels = ((flat_samples + 1) / 2).clamp(0.0, 1.0)

        # Arrange the 28x28 images side by side and display them
        image_grid = utils.make_grid(pixels.reshape(-1, 1, 28, 28), nrow=n_images)
        plt.gca().set_axis_off()
        plt.imshow(transforms.functional.to_pil_image(image_grid), cmap="gray")
        plt.show()

In [10]:
# Base models, three different ways to predict the mean
for tgt in ("e", "u", "x0"):
    checkpoint_path = Path(f"ddpm_models/base_{tgt}.pt")
    if checkpoint_path.exists():
        print(f"Skipping target {tgt}, already exists")
        continue

    # Create the output directory up front, so that a missing directory cannot
    # make the save fail after the whole training run.
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

    torch.manual_seed(0)

    # Construct Unet
    # The original ScoreNet expects a function with std for all the
    # different noise levels, such that the output can be rescaled.
    # Since we are not predicting the score, we skip this rescaling
    # by passing None.
    mnist_unet = ScoreNet(None)

    # Construct model
    model = DDPM(mnist_unet, max_t=T, predict_mean_by=tgt).to(device)
    model.compile(fullgraph=True)

    # Construct optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Setup simple scheduler
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)

    # Call training loop
    train(
        model,
        optimizer,
        scheduler,
        dataloader_train,
        train_epochs=epochs,
        train_device=device,
        ema=True,
        per_epoch_callback=reporter,
        extra_desc=f"target={tgt}",
    )

    torch.save(model, checkpoint_path)

Skipping target e, already exists
Skipping target u, already exists
Skipping target x0, already exists


In [11]:
# Variance-reduced models, two different ways to sample the time step
for mthd in ("low-discrepency", "importance-sampling"):
    checkpoint_path = Path(f"ddpm_models/variance_{mthd}.pt")
    if checkpoint_path.exists():
        print(f"Skipping method {mthd}, already exists")
        continue

    # Create the output directory up front, so that a missing directory cannot
    # make the save fail after the whole training run.
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

    torch.manual_seed(0)

    # Construct Unet
    # The original ScoreNet expects a function with std for all the
    # different noise levels, such that the output can be rescaled.
    # Since we are predicting the noise (rather than the score), we
    # skip this rescaling by passing None.
    mnist_unet = ScoreNet(None)

    # Construct model
    model = DDPM(mnist_unet, max_t=T, reduce_variance_by=mthd).to(device)
    model.compile(fullgraph=True)

    # Construct optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Setup simple scheduler
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)

    # Call training loop
    train(
        model,
        optimizer,
        scheduler,
        dataloader_train,
        train_epochs=epochs,
        train_device=device,
        ema=True,
        per_epoch_callback=reporter,
        extra_desc=f"method={mthd}",
    )

    torch.save(model, checkpoint_path)

Skipping method low-discrepency, already exists
Skipping method importance-sampling, already exists
