##### Continuous Diffusion Language Model

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
# Sinusoidal timestep embedding
def get_timestep_embedding(timesteps, embedding_dim):
    """
    Encode timesteps as concatenated sine/cosine features.

    The first ``embedding_dim // 2`` output columns are sines and the next
    ``embedding_dim // 2`` are cosines, both evaluated at
    ``t * exp(-ln(10000) * k / half_dim)`` for ``k = 0 .. half_dim - 1``.
    When ``embedding_dim`` is odd, one trailing zero column is appended.

    Args:
        timesteps: [batch] or [batch, 1] tensor, int or float
        embedding_dim: int, width of the returned embedding
    Returns:
        [batch, embedding_dim] float tensor on the same device as ``timesteps``
    """
    # A 1-D input becomes a column so it broadcasts against the frequency row.
    steps = timesteps.unsqueeze(-1) if timesteps.dim() == 1 else timesteps

    n_freqs = embedding_dim // 2
    positions = torch.arange(n_freqs, dtype=torch.float32, device=steps.device)
    log_freqs = -np.log(10000) * positions / n_freqs

    angles = steps.float() * torch.exp(log_freqs)  # [batch, n_freqs]
    features = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    # sin/cos halves only cover an even width; zero-pad the last column if odd.
    if embedding_dim % 2 == 1:
        features = F.pad(features, (0, 1))
    return features
    
class SimpleMLP(nn.Module):
    """
    Small MLP used as the denoising model (noise prediction).

    The input vector is concatenated with a sinusoidal embedding of the
    timestep and passed through Linear -> ReLU -> Linear, giving an output
    with the same width as the input.
    """

    def __init__(self, dim, t_emb_dim=32):
        """
        Args:
            dim: width of the data vectors (input and output)
            t_emb_dim: width of the timestep embedding appended to the input
        """
        super().__init__()
        self.t_emb_dim = t_emb_dim
        hidden_width = 128
        layers = [
            nn.Linear(dim + t_emb_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        """
        Args:
            x: [batch, dim] input vectors
            t: [batch] or [batch, 1] timesteps; embedded before use
        Returns:
            [batch, dim] network output
        """
        time_features = get_timestep_embedding(t, self.t_emb_dim)  # [batch, t_emb_dim]
        return self.net(torch.cat((x, time_features), dim=-1))

In [5]:
class DiffusionLMContinuous:
    """
    Continuous (Gaussian) diffusion over fixed-width vectors.

    Holds a linear beta schedule as NumPy arrays (``beta``, ``alpha``,
    ``alpha_bar``) and a ``SimpleMLP`` that is trained to predict the noise
    added by the forward process.
    """

    def __init__(self, dim, timesteps=1000, device='cpu', t_emb_dim=32):
        """
        Args:
            dim: dimension of data (e.g. word embedding dim)
            timesteps: number of diffusion steps (T)
            device: torch device
            t_emb_dim: dimension of timestep embedding
        """
        self.dim = dim
        self.timesteps = timesteps
        self.device = device
        self.beta = np.linspace(1e-4, 0.02, timesteps)  # Linear noise schedule
        self.alpha = 1. - self.beta
        self.alpha_bar = np.cumprod(self.alpha)
        self.model = SimpleMLP(dim, t_emb_dim).to(device)
        self.t_emb_dim = t_emb_dim
        ## TODO: EMA.

    @staticmethod
    def _timestep_index(t):
        """Return ``t`` (int, ndarray or tensor) as a flat integer NumPy index array."""
        if isinstance(t, torch.Tensor):
            t = t.detach().cpu().numpy()
        idx = np.asarray(t).reshape(-1)
        if not np.issubdtype(idx.dtype, np.integer):
            # Float timesteps are truncated so they can index the schedule.
            idx = idx.astype(np.int64)
        return idx

    def _lookup(self, schedule, idx):
        """Gather ``schedule[idx]`` as a float32 ``[len(idx), 1]`` tensor on ``self.device``."""
        return torch.from_numpy(schedule[idx]).float().to(self.device).view(-1, 1)

    def q_sample(self, x0, t):
        """
        Forward process (add noise):
            q(x_t | x_0) = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
        Args:
            x0: [batch, dim], original data
            t: [batch] timesteps as ndarray or tensor, or a single int that
               is applied to every row
        Returns:
            xt: [batch, dim], noisy data
            noise: [batch, dim], added noise
        """
        alpha_bar_t = self._lookup(self.alpha_bar, self._timestep_index(t))
        noise = torch.randn_like(x0)  # Gaussian noise
        xt = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * noise  # Add noise
        return xt, noise

    def p_sample(self, xt, t):
        """
        Reverse process (denoise one step):
            p(x_{t-1} | x_t) = N(mean, beta_t * I)
            mean = (1 / sqrt(alpha_t)) * (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * pred_noise)
        Noise is added per row, only for rows whose timestep is > 0.
        Args:
            xt: [batch, dim], current noisy data
            t: [batch] or [batch, 1] timesteps as ndarray or tensor, or a
               single int / 0-dim tensor applied to every row
        Returns:
            next_xt: [batch, dim], sample for the previous step
        """
        # Resolve schedule indices on the host once, before any device copy.
        idx = self._timestep_index(t)
        if not isinstance(t, torch.Tensor):
            t = torch.as_tensor(np.asarray(t)).to(self.device)
        if t.dim() == 0:
            # The model expects one timestep per row.
            t = t.expand(xt.shape[0])

        pred_noise = self.model(xt, t)  # using the model.
        alpha_t = self._lookup(self.alpha, idx)
        alpha_bar_t = self._lookup(self.alpha_bar, idx)
        beta_t = self._lookup(self.beta, idx)
        sqrt_one_minus_alpha_bar_t = torch.sqrt(1 - alpha_bar_t)
        mean = (1 / torch.sqrt(alpha_t)) * (xt - (1 - alpha_t) / sqrt_one_minus_alpha_bar_t * pred_noise)

        # Decide per row rather than from t[0], so mixed-timestep batches
        # keep their t == 0 rows noise-free.
        needs_noise = idx > 0
        if needs_noise.any():
            mask = torch.from_numpy(needs_noise).to(self.device).view(-1, 1)
            noise = torch.randn_like(xt)
            mean = torch.where(mask, mean + torch.sqrt(beta_t) * noise, mean)
        return mean

    def train_step(self, x0, optimizer):
        """
        Run one optimisation step on the noise-prediction MSE loss.

        Args:
            x0: [batch, dim], clean training data
            optimizer: torch optimizer over ``self.model`` parameters
        Returns:
            float, the loss value for this step
        """
        batch = x0.shape[0]
        t = np.random.randint(0, self.timesteps, size=(batch,))  # Random timestep
        t_tensor = torch.tensor(t, dtype=torch.long, device=self.device)
        xt, noise = self.q_sample(x0, t)  # Add noise
        pred_noise = self.model(xt, t_tensor)  # Predict noise (t_tensor will be encoded)
        loss = F.mse_loss(pred_noise, noise)  # MSE loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # TODO: EMA.

        return loss.item()

    def sample(self, batch_size):
        """
        Generate samples by running the reverse process from t = T-1 down to 0.

        Args:
            batch_size: number of samples to draw
        Returns:
            [batch_size, dim] tensor (no autograd graph attached)
        """
        xt = torch.randn(batch_size, self.dim, device=self.device)  # Start from pure noise.
        # TODO: Classifier-free guidance.
        # Sampling never backpropagates; without no_grad the graph would grow
        # across every reverse step.
        with torch.no_grad():
            for t in reversed(range(self.timesteps)):
                t_arr = np.full((batch_size,), t)
                xt = self.p_sample(xt, t_arr)  # Using the model.
        return xt

In [None]:
'''
Suppose we have a simple word vector space, each word is a vector.
This demo is only for illustrating the basic principles of continuous diffusion models. 
Actual diffusion language modeling requires more complex encoding and decoding strategies.
'''
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dim = 8
lm = DiffusionLMContinuous(dim, timesteps=100, device=device)
optimizer = torch.optim.Adam(lm.model.parameters(), lr=1e-3)

# Toy training set: 10 random vectors tiled 100 times, i.e. 1000 rows per step.
x0 = torch.randn((10, dim), device=device).repeat(100, 1)

# Training: one full-batch optimisation step per iteration, logged every 100th.
for ep in range(500):
    loss = lm.train_step(x0, optimizer)
    if ep % 100:
        continue
    print(f'Epoch {ep}, Loss: {loss:.4f}')

# Sampling: draw 5 vectors and move them to host memory as a NumPy array.
samples = lm.sample(5).detach().cpu().numpy()
print('sampling shape:', samples.shape)

Epoch 0, Loss: 1.0513
Epoch 100, Loss: 0.5817
Epoch 200, Loss: 0.5004
Epoch 300, Loss: 0.4290
Epoch 400, Loss: 0.3646
sampling shape: (5, 8)
