In [1]:
import numpy as np
import pandas as pd
import torch

import torch.nn.functional as F

In [2]:
class BetaScheduler:
    """Builds the per-timestep beta (noise variance) schedule of a diffusion process.

    Only the ``"linear"`` schedule is implemented: betas are evenly spaced
    from ``start`` to ``stop``, both endpoints included.
    """

    def __init__(self,
                 schedule_type: str = "linear",
                 start: float = 0.0001,
                 stop: float = 0.02,
                 ):
        self._schedule_type = schedule_type
        self._start = start
        self._stop = stop

    def generate(self, n_steps: int):
        """Return a 1-D float32 tensor of ``n_steps`` betas.

        Raises:
            ValueError: if the configured schedule type is not supported.
        """
        # Guard clause: reject anything other than the single supported schedule.
        if self._schedule_type != "linear":
            raise ValueError(f"Unknown schedule type: {self._schedule_type}")
        return torch.linspace(self._start, self._stop, n_steps, dtype=torch.float32)


class DiffusionModel:
    """Forward (noising) half of a diffusion process with a fixed beta schedule.

    A single forward step samples
    ``x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps`` with ``eps ~ N(0, I)``.
    """

    def __init__(self,
                 n_steps: int,
                 beta_scheduler: "BetaScheduler",
                 ):
        """
        Args:
            n_steps: number of diffusion timesteps; the beta schedule has this length.
            beta_scheduler: object whose ``generate(n_steps=...)`` returns the betas.
        """
        self._n_steps = n_steps
        self._betas = beta_scheduler.generate(n_steps=n_steps)

    def forward_step(self,
                     batch: torch.Tensor,
                     timestep: "int | torch.Tensor",
                     noise: "torch.Tensor | None" = None,
                     ) -> torch.Tensor:
        """Apply one noising step to ``batch`` at ``timestep`` and return the result.

        Args:
            batch: samples whose first dimension is the batch dimension.
            timestep: a Python int, a 0-d tensor (same step for every sample), or a
                1-D tensor with one step per sample along ``batch``'s first dimension.
            noise: optional standard-normal noise shaped like ``batch``; drawn with
                ``torch.randn_like`` when omitted (pass it for deterministic output).

        Returns:
            ``sqrt(1 - beta_t) * batch + sqrt(beta_t) * noise``.
        """
        # gather() needs an int64 index tensor; a bare Python int would fail there.
        timestep = torch.as_tensor(timestep, dtype=torch.long, device=self._betas.device)
        beta_t = self._gather(input_arr=self._betas, timestep=timestep)
        # Reshape to (k, 1, ..., 1) so the per-sample coefficients broadcast over
        # every non-batch dimension (identical to (-1, 1) for 2-D batches).
        beta_t = beta_t.reshape(-1, *([1] * (batch.dim() - 1)))

        mean = torch.sqrt(1 - beta_t) * batch
        if noise is None:
            noise = torch.randn_like(batch)
        return mean + torch.sqrt(beta_t) * noise

    def _gather(self, input_arr: torch.Tensor, timestep: torch.Tensor):
        """Return the element(s) of ``input_arr`` at ``timestep`` along its last axis,
        reshaped into a column of shape ``(-1, 1)``.
        """
        return input_arr.gather(-1, timestep).reshape(-1, 1)

In [3]:
# Toy dataset: four isotropic Gaussian clusters centred at (0, ±R) and (±R, 0).
SEED = 0
CLUSTER_RADIUS = 4.0
CLUSTER_VARIANCE = 0.1
SAMPLES_PER_CLUSTER = 10

# Seed so the sampled dataset is identical on every Restart & Run All.
torch.manual_seed(SEED)

means = torch.tensor([
    [ 0,  1],
    [ 1,  0],
    [ 0, -1],
    [-1,  0],
], dtype=torch.float32) * CLUSTER_RADIUS
cov = torch.eye(2, dtype=torch.float32) * CLUSTER_VARIANCE

# A batched loc of shape (4, 2) with one shared (2, 2) covariance gives a batch of
# 4 distributions; sampling (SAMPLES_PER_CLUSTER,) yields (SAMPLES_PER_CLUSTER, 4, 2),
# which is flattened into rows of 2-D points.
dist = torch.distributions.MultivariateNormal(loc=means, covariance_matrix=cov)
raw_data = dist.sample(sample_shape=(SAMPLES_PER_CLUSTER,)).reshape(-1, cov.shape[0])

In [4]:
# Number of diffusion timesteps, shared by the standalone schedule and the model
# so the two cannot silently drift apart.
N_STEPS = 250

beta_scheduler = BetaScheduler()
betas = beta_scheduler.generate(N_STEPS)

dm = DiffusionModel(n_steps=N_STEPS, beta_scheduler=beta_scheduler)

In [5]:
# Run one forward noising step on the clean data at timestep 100 (passed as a 0-d tensor).
dm.forward_step(raw_data, timestep=torch.tensor(100))