In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_blobs
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import imageio

# Hyperparameters
timesteps = 200
batch_size = 128
epochs = 300
learning_rate = 1e-3
SEED = 42

# Seed torch so that model initialisation, DataLoader shuffling, training noise and
# sampling noise are reproducible across Restart & Run All
# (make_blobs is seeded separately through `random_state`).
torch.manual_seed(SEED)

# Generate a dataset of two Gaussian blobs centred at (4, 4) and (4, -4)
X, _ = make_blobs(n_samples=10000, centers=[(4,4), (4,-4)], cluster_std=0.5, random_state=SEED)
X = torch.tensor(X, dtype=torch.float32)

# Define beta schedule
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return linearly spaced noise variances beta_1 .. beta_T.

    Args:
        timesteps: number of diffusion steps T.
        beta_start: variance at the first step (default 1e-4).
        beta_end: variance at the last step (default 0.02).

    Returns:
        1-D float tensor of length ``timesteps`` going from ``beta_start`` to
        ``beta_end`` inclusive.
    """
    return torch.linspace(beta_start, beta_end, timesteps)

# Per-step noise variances and their complements alpha_t = 1 - beta_t
betas = linear_beta_schedule(timesteps)
alphas = 1. - betas
# Running product abar_t = prod_{s<=t} alpha_s, plus the two square roots the
# closed-form forward process q(x_t | x_0) needs.
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1. - alphas_cumprod).sqrt()

# Helper function to extract the appropriate t index for a batch of indices
def extract(a, t, x_shape):
    """Gather one schedule value per batch element, shaped for broadcasting.

    Args:
        a: 1-D tensor of per-timestep values (e.g. ``sqrt_alphas_cumprod``).
        t: 1-D integer tensor of timestep indices, one per batch element.
        x_shape: shape of the tensor the result will be multiplied with.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with rank ``len(x_shape)``,
        located on ``t``'s device.
    """
    n = t.shape[0]
    # The schedule tensors are built on the CPU while `t` may live on the GPU;
    # gather needs both on the same device, and the result must match x's device.
    out = a.to(t.device).gather(-1, t)
    return out.reshape(n, *((1,) * (len(x_shape) - 1)))

# Forward diffusion process
def q_sample(x_start, t, noise=None):
    """Draw x_t from q(x_t | x_0) in closed form.

        x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise

    Args:
        x_start: clean batch x_0.
        t: integer tensor of timestep indices, one per batch element.
        noise: optional Gaussian noise shaped like ``x_start``; a fresh
            standard-normal draw is used when omitted.

    Returns:
        The noised batch, same shape as ``x_start``.
    """
    eps = torch.randn_like(x_start) if noise is None else noise
    shape = x_start.shape
    signal_scale = extract(sqrt_alphas_cumprod, t, shape)
    noise_scale = extract(sqrt_one_minus_alphas_cumprod, t, shape)
    return signal_scale * x_start + noise_scale * eps

# Simple MLP model
class SimpleMLP(nn.Module):
    """Noise-prediction MLP conditioned on the diffusion timestep.

    The timestep is appended to each input point as one extra feature, so the
    network maps ``data_dim + 1`` inputs through two ReLU hidden layers to a
    ``data_dim``-dimensional noise estimate.

    Args:
        data_dim: dimensionality of the data points (default 2).
        hidden_dim: width of both hidden layers (default 64).
    """

    def __init__(self, data_dim=2, hidden_dim=64):
        super().__init__()
        # Layer order (and hence state_dict keys net.0 / net.2 / net.4) is unchanged.
        self.net = nn.Sequential(
            nn.Linear(data_dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, data_dim)
        )

    def forward(self, x, t):
        """Predict the noise contained in ``x`` at timestep(s) ``t``.

        Args:
            x: tensor of shape (batch, data_dim).
            t: float tensor of shape (batch,) with timestep values; in this
                notebook callers pass raw integer steps cast to float.

        Returns:
            Tensor of shape (batch, data_dim).
        """
        x = torch.cat([x, t.unsqueeze(-1)], dim=-1)
        return self.net(x)

# Training function
def train(model, dataloader, optimizer, device):
    """Run one epoch of noise-prediction training.

    For every batch a timestep is drawn uniformly from ``[0, timesteps)``, the
    clean points are noised with ``q_sample``, and the model is regressed onto
    the injected noise with a mean-squared-error loss.

    Args:
        model: network called as ``model(x_noisy, t.float())``.
        dataloader: yields batches whose first element is the data tensor.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batch and sampled timesteps are moved to.

    Returns:
        Mean per-batch loss over the epoch, or ``nan`` if the loader was empty.
    """
    model.train()
    criterion = nn.MSELoss()  # stateless, so build it once rather than per batch
    total_loss = 0.0
    n_batches = 0
    # leave=False: a transient bar per epoch instead of one persistent bar per
    # epoch flooding the cell output.
    for batch in tqdm(dataloader, leave=False):
        x = batch[0].to(device)
        optimizer.zero_grad()

        t = torch.randint(0, timesteps, (x.shape[0],), device=device)  # int64 by default
        noise = torch.randn_like(x)
        x_noisy = q_sample(x, t, noise)
        noise_pred = model(x_noisy, t.float())
        loss = criterion(noise_pred, noise)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    # Guard against an empty loader instead of raising ZeroDivisionError.
    return total_loss / n_batches if n_batches else float("nan")

# Sampling function
@torch.no_grad()
def sample(model, n_samples, device, save_path=None):
    """Draw samples by running the reverse (denoising) diffusion chain.

    Args:
        model: noise-prediction network called as ``model(x, t.float())``.
        n_samples: number of 2-D points to generate.
        device: torch device on which the chain runs.
        save_path: if given, a GIF of intermediate states (every 10th step plus
            the first) is written there; if falsy, no frames are rendered.

    Returns:
        Tensor of shape (n_samples, 2) on ``device`` with the final samples.
    """
    model.eval()
    x = torch.randn(n_samples, 2).to(device)
    images = []
    for t in reversed(range(timesteps)):
        t_batch = torch.full((n_samples,), t, device=device, dtype=torch.long)
        predicted_noise = model(x, t_batch.float())
        alpha = alphas[t]
        alpha_hat = alphas_cumprod[t]
        beta = betas[t]
        # Fresh noise on every step except the final one (t == 0).
        noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        # x_{t-1} = (x_t - beta_t / sqrt(1 - abar_t) * eps_theta) / sqrt(alpha_t) + sqrt(beta_t) * z
        x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise

        # Save every 10th frame and the first frame -- only when a GIF is requested.
        if save_path and (t % 10 == 0 or t == timesteps - 1):
            points = x.cpu().numpy()
            fig, ax = plt.subplots(figsize=(6, 6))
            try:
                ax.scatter(points[:, 0], points[:, 1], s=1, c='red', alpha=0.7)
                ax.set_xlim(-6, 6)
                ax.set_ylim(-6, 6)
                ax.set_title(f"Step {timesteps - t}/{timesteps}")
                # Rasterize *before* closing: drawing a closed figure is not
                # supported on every backend.
                images.append(fig2img(fig))
            finally:
                plt.close(fig)

    if save_path:
        # NOTE(review): newer imageio releases deprecate/reject `fps` for GIFs in
        # favour of `duration`; confirm against the installed imageio version.
        imageio.mimsave(save_path, images, fps=5)

    return x

# Helper function to convert matplotlib figure to image
def fig2img(fig):
    """Rasterize a matplotlib figure into an RGBA ``uint8`` array.

    Uses the public Agg canvas method ``buffer_rgba()`` rather than the private
    ``canvas.renderer._renderer`` attribute.

    Args:
        fig: a matplotlib Figure; assumes an Agg-based canvas (true for the
            inline/Agg backends -- confirm if a different backend is used).

    Returns:
        ``numpy.ndarray`` of shape (height, width, 4) and dtype ``uint8``. It is
        a copy, so later redraws of ``fig`` do not change it.
    """
    fig.canvas.draw()
    return np.array(fig.canvas.buffer_rgba())

# Set up device (GPU when available), model, optimizer, and dataloader
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SimpleMLP().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
dataloader = DataLoader(
    TensorDataset(X),
    batch_size=batch_size,
    shuffle=True,
)

# Training loop: one pass over the data per epoch, reporting every 10th epoch
losses = []
for epoch in range(1, epochs + 1):
    loss = train(model, dataloader, optimizer, device)
    losses.append(loss)
    if epoch % 10 == 0:
        print(f"Epoch {epoch}/{epochs} completed, Loss: {loss:.4f}")

# Print final loss
print(f"Final loss: {losses[-1]:.4f}")

# Generate samples and create gif
samples = sample(model, 1000, device, save_path='diffusion_process.gif')

# Plot the per-epoch training loss and write it to disk
loss_fig, loss_ax = plt.subplots(figsize=(10, 6))
loss_ax.plot(range(1, epochs + 1), losses)
loss_ax.set_title("Training Loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_fig.tight_layout()
loss_fig.savefig('training_loss.png')
plt.close(loss_fig)

for message in (
    "Diffusion process gif saved as 'diffusion_process.gif'",
    "Training loss plot saved as 'training_loss.png'",
):
    print(message)

  0%|          | 0/79 [00:00<?, ?it/s]

100%|██████████| 79/79 [00:00<00:00, 368.18it/s]
100%|██████████| 79/79 [00:00<00:00, 356.26it/s]
100%|██████████| 79/79 [00:00<00:00, 531.27it/s]
100%|██████████| 79/79 [00:00<00:00, 463.94it/s]
100%|██████████| 79/79 [00:00<00:00, 560.44it/s]
100%|██████████| 79/79 [00:00<00:00, 396.19it/s]
100%|██████████| 79/79 [00:00<00:00, 406.95it/s]
100%|██████████| 79/79 [00:00<00:00, 392.47it/s]
100%|██████████| 79/79 [00:00<00:00, 365.22it/s]
100%|██████████| 79/79 [00:00<00:00, 397.71it/s]


Epoch 10/300 completed, Loss: 0.7787


100%|██████████| 79/79 [00:00<00:00, 430.86it/s]
100%|██████████| 79/79 [00:00<00:00, 328.90it/s]
100%|██████████| 79/79 [00:00<00:00, 359.69it/s]
100%|██████████| 79/79 [00:00<00:00, 510.38it/s]
100%|██████████| 79/79 [00:00<00:00, 516.20it/s]
100%|██████████| 79/79 [00:00<00:00, 506.50it/s]
100%|██████████| 79/79 [00:00<00:00, 486.96it/s]
100%|██████████| 79/79 [00:00<00:00, 421.97it/s]
100%|██████████| 79/79 [00:00<00:00, 388.75it/s]
100%|██████████| 79/79 [00:00<00:00, 315.96it/s]


Epoch 20/300 completed, Loss: 0.6114


100%|██████████| 79/79 [00:00<00:00, 372.14it/s]
100%|██████████| 79/79 [00:00<00:00, 369.28it/s]
100%|██████████| 79/79 [00:00<00:00, 404.54it/s]
100%|██████████| 79/79 [00:00<00:00, 371.95it/s]
100%|██████████| 79/79 [00:00<00:00, 364.28it/s]
100%|██████████| 79/79 [00:00<00:00, 393.53it/s]
100%|██████████| 79/79 [00:00<00:00, 337.47it/s]
100%|██████████| 79/79 [00:00<00:00, 339.45it/s]
100%|██████████| 79/79 [00:00<00:00, 377.08it/s]
100%|██████████| 79/79 [00:00<00:00, 427.70it/s]


Epoch 30/300 completed, Loss: 0.5458


100%|██████████| 79/79 [00:00<00:00, 351.90it/s]
100%|██████████| 79/79 [00:00<00:00, 419.60it/s]
100%|██████████| 79/79 [00:00<00:00, 410.14it/s]
100%|██████████| 79/79 [00:00<00:00, 404.87it/s]
100%|██████████| 79/79 [00:00<00:00, 504.32it/s]
100%|██████████| 79/79 [00:00<00:00, 534.47it/s]
100%|██████████| 79/79 [00:00<00:00, 456.22it/s]
100%|██████████| 79/79 [00:00<00:00, 410.16it/s]
100%|██████████| 79/79 [00:00<00:00, 372.30it/s]
100%|██████████| 79/79 [00:00<00:00, 393.21it/s]


Epoch 40/300 completed, Loss: 0.5121


100%|██████████| 79/79 [00:00<00:00, 346.11it/s]
100%|██████████| 79/79 [00:00<00:00, 396.39it/s]
100%|██████████| 79/79 [00:00<00:00, 408.06it/s]
100%|██████████| 79/79 [00:00<00:00, 396.76it/s]
100%|██████████| 79/79 [00:00<00:00, 429.27it/s]
100%|██████████| 79/79 [00:00<00:00, 479.42it/s]
100%|██████████| 79/79 [00:00<00:00, 448.28it/s]
100%|██████████| 79/79 [00:00<00:00, 396.68it/s]
100%|██████████| 79/79 [00:00<00:00, 380.63it/s]
100%|██████████| 79/79 [00:00<00:00, 419.69it/s]


Epoch 50/300 completed, Loss: 0.4838


100%|██████████| 79/79 [00:00<00:00, 400.23it/s]
100%|██████████| 79/79 [00:00<00:00, 349.24it/s]
100%|██████████| 79/79 [00:00<00:00, 325.24it/s]
100%|██████████| 79/79 [00:00<00:00, 354.69it/s]
100%|██████████| 79/79 [00:00<00:00, 378.90it/s]
100%|██████████| 79/79 [00:00<00:00, 410.20it/s]
100%|██████████| 79/79 [00:00<00:00, 498.19it/s]
100%|██████████| 79/79 [00:00<00:00, 506.44it/s]
100%|██████████| 79/79 [00:00<00:00, 347.46it/s]
100%|██████████| 79/79 [00:00<00:00, 359.07it/s]


Epoch 60/300 completed, Loss: 0.4486


100%|██████████| 79/79 [00:00<00:00, 345.83it/s]
100%|██████████| 79/79 [00:00<00:00, 453.07it/s]
100%|██████████| 79/79 [00:00<00:00, 391.14it/s]
100%|██████████| 79/79 [00:00<00:00, 388.66it/s]
100%|██████████| 79/79 [00:00<00:00, 370.37it/s]
100%|██████████| 79/79 [00:00<00:00, 476.40it/s]
100%|██████████| 79/79 [00:00<00:00, 500.65it/s]
100%|██████████| 79/79 [00:00<00:00, 528.76it/s]
100%|██████████| 79/79 [00:00<00:00, 514.53it/s]
100%|██████████| 79/79 [00:00<00:00, 414.19it/s]


Epoch 70/300 completed, Loss: 0.4527


100%|██████████| 79/79 [00:00<00:00, 406.88it/s]
100%|██████████| 79/79 [00:00<00:00, 523.43it/s]
100%|██████████| 79/79 [00:00<00:00, 440.88it/s]
100%|██████████| 79/79 [00:00<00:00, 421.43it/s]
100%|██████████| 79/79 [00:00<00:00, 417.82it/s]
100%|██████████| 79/79 [00:00<00:00, 427.67it/s]
100%|██████████| 79/79 [00:00<00:00, 449.86it/s]
100%|██████████| 79/79 [00:00<00:00, 458.32it/s]
100%|██████████| 79/79 [00:00<00:00, 599.07it/s]
100%|██████████| 79/79 [00:00<00:00, 536.90it/s]


Epoch 80/300 completed, Loss: 0.4606


100%|██████████| 79/79 [00:00<00:00, 502.12it/s]
100%|██████████| 79/79 [00:00<00:00, 437.08it/s]
100%|██████████| 79/79 [00:00<00:00, 436.52it/s]
100%|██████████| 79/79 [00:00<00:00, 333.56it/s]
100%|██████████| 79/79 [00:00<00:00, 466.09it/s]
100%|██████████| 79/79 [00:00<00:00, 456.62it/s]
100%|██████████| 79/79 [00:00<00:00, 454.58it/s]
100%|██████████| 79/79 [00:00<00:00, 442.89it/s]
100%|██████████| 79/79 [00:00<00:00, 410.75it/s]
100%|██████████| 79/79 [00:00<00:00, 551.45it/s]


Epoch 90/300 completed, Loss: 0.4469


100%|██████████| 79/79 [00:00<00:00, 552.99it/s]
100%|██████████| 79/79 [00:00<00:00, 513.11it/s]
100%|██████████| 79/79 [00:00<00:00, 568.82it/s]
100%|██████████| 79/79 [00:00<00:00, 451.66it/s]
100%|██████████| 79/79 [00:00<00:00, 431.90it/s]
100%|██████████| 79/79 [00:00<00:00, 456.70it/s]
100%|██████████| 79/79 [00:00<00:00, 424.46it/s]
100%|██████████| 79/79 [00:00<00:00, 408.91it/s]
100%|██████████| 79/79 [00:00<00:00, 386.49it/s]
100%|██████████| 79/79 [00:00<00:00, 430.47it/s]


Epoch 100/300 completed, Loss: 0.4589


100%|██████████| 79/79 [00:00<00:00, 421.38it/s]
100%|██████████| 79/79 [00:00<00:00, 554.86it/s]
100%|██████████| 79/79 [00:00<00:00, 573.88it/s]
100%|██████████| 79/79 [00:00<00:00, 514.04it/s]
100%|██████████| 79/79 [00:00<00:00, 524.34it/s]
100%|██████████| 79/79 [00:00<00:00, 441.64it/s]
100%|██████████| 79/79 [00:00<00:00, 447.91it/s]
100%|██████████| 79/79 [00:00<00:00, 354.74it/s]
100%|██████████| 79/79 [00:00<00:00, 447.56it/s]
100%|██████████| 79/79 [00:00<00:00, 422.90it/s]


Epoch 110/300 completed, Loss: 0.4628


100%|██████████| 79/79 [00:00<00:00, 435.40it/s]
100%|██████████| 79/79 [00:00<00:00, 408.61it/s]
100%|██████████| 79/79 [00:00<00:00, 500.96it/s]
100%|██████████| 79/79 [00:00<00:00, 476.24it/s]
100%|██████████| 79/79 [00:00<00:00, 561.08it/s]
100%|██████████| 79/79 [00:00<00:00, 541.39it/s]
100%|██████████| 79/79 [00:00<00:00, 624.39it/s]
100%|██████████| 79/79 [00:00<00:00, 451.56it/s]
100%|██████████| 79/79 [00:00<00:00, 418.43it/s]
100%|██████████| 79/79 [00:00<00:00, 345.45it/s]


Epoch 120/300 completed, Loss: 0.4493


100%|██████████| 79/79 [00:00<00:00, 382.87it/s]
100%|██████████| 79/79 [00:00<00:00, 418.65it/s]
100%|██████████| 79/79 [00:00<00:00, 465.28it/s]
100%|██████████| 79/79 [00:00<00:00, 328.47it/s]
100%|██████████| 79/79 [00:00<00:00, 446.32it/s]
100%|██████████| 79/79 [00:00<00:00, 553.17it/s]
100%|██████████| 79/79 [00:00<00:00, 521.43it/s]
100%|██████████| 79/79 [00:00<00:00, 582.17it/s]
100%|██████████| 79/79 [00:00<00:00, 554.35it/s]
100%|██████████| 79/79 [00:00<00:00, 431.03it/s]


Epoch 130/300 completed, Loss: 0.4419


100%|██████████| 79/79 [00:00<00:00, 416.95it/s]
100%|██████████| 79/79 [00:00<00:00, 436.33it/s]
100%|██████████| 79/79 [00:00<00:00, 406.16it/s]
100%|██████████| 79/79 [00:00<00:00, 388.10it/s]
100%|██████████| 79/79 [00:00<00:00, 353.73it/s]
100%|██████████| 79/79 [00:00<00:00, 415.64it/s]
100%|██████████| 79/79 [00:00<00:00, 480.28it/s]
100%|██████████| 79/79 [00:00<00:00, 517.54it/s]
100%|██████████| 79/79 [00:00<00:00, 577.42it/s]
100%|██████████| 79/79 [00:00<00:00, 591.07it/s]


Epoch 140/300 completed, Loss: 0.4506


100%|██████████| 79/79 [00:00<00:00, 511.03it/s]
100%|██████████| 79/79 [00:00<00:00, 465.82it/s]
100%|██████████| 79/79 [00:00<00:00, 384.65it/s]
100%|██████████| 79/79 [00:00<00:00, 359.65it/s]
100%|██████████| 79/79 [00:00<00:00, 380.14it/s]
100%|██████████| 79/79 [00:00<00:00, 344.05it/s]
100%|██████████| 79/79 [00:00<00:00, 426.22it/s]
100%|██████████| 79/79 [00:00<00:00, 483.88it/s]
100%|██████████| 79/79 [00:00<00:00, 417.26it/s]
100%|██████████| 79/79 [00:00<00:00, 527.11it/s]


Epoch 150/300 completed, Loss: 0.4394


100%|██████████| 79/79 [00:00<00:00, 547.29it/s]
100%|██████████| 79/79 [00:00<00:00, 565.00it/s]
100%|██████████| 79/79 [00:00<00:00, 524.74it/s]
100%|██████████| 79/79 [00:00<00:00, 464.82it/s]
100%|██████████| 79/79 [00:00<00:00, 484.10it/s]
100%|██████████| 79/79 [00:00<00:00, 452.06it/s]
100%|██████████| 79/79 [00:00<00:00, 411.61it/s]
100%|██████████| 79/79 [00:00<00:00, 397.96it/s]
100%|██████████| 79/79 [00:00<00:00, 479.00it/s]
100%|██████████| 79/79 [00:00<00:00, 407.55it/s]


Epoch 160/300 completed, Loss: 0.4493


100%|██████████| 79/79 [00:00<00:00, 382.55it/s]
100%|██████████| 79/79 [00:00<00:00, 477.76it/s]
100%|██████████| 79/79 [00:00<00:00, 446.55it/s]
100%|██████████| 79/79 [00:00<00:00, 576.49it/s]
100%|██████████| 79/79 [00:00<00:00, 435.29it/s]
100%|██████████| 79/79 [00:00<00:00, 415.20it/s]
100%|██████████| 79/79 [00:00<00:00, 410.54it/s]
100%|██████████| 79/79 [00:00<00:00, 425.19it/s]
100%|██████████| 79/79 [00:00<00:00, 362.48it/s]
100%|██████████| 79/79 [00:00<00:00, 423.46it/s]


Epoch 170/300 completed, Loss: 0.4569


100%|██████████| 79/79 [00:00<00:00, 371.58it/s]
100%|██████████| 79/79 [00:00<00:00, 438.73it/s]
100%|██████████| 79/79 [00:00<00:00, 548.71it/s]
100%|██████████| 79/79 [00:00<00:00, 536.28it/s]
100%|██████████| 79/79 [00:00<00:00, 557.21it/s]
100%|██████████| 79/79 [00:00<00:00, 510.67it/s]
100%|██████████| 79/79 [00:00<00:00, 396.01it/s]
100%|██████████| 79/79 [00:00<00:00, 436.94it/s]
100%|██████████| 79/79 [00:00<00:00, 397.71it/s]
100%|██████████| 79/79 [00:00<00:00, 384.42it/s]


Epoch 180/300 completed, Loss: 0.4509


100%|██████████| 79/79 [00:00<00:00, 374.79it/s]
100%|██████████| 79/79 [00:00<00:00, 434.54it/s]
100%|██████████| 79/79 [00:00<00:00, 308.66it/s]
100%|██████████| 79/79 [00:00<00:00, 450.06it/s]
100%|██████████| 79/79 [00:00<00:00, 498.18it/s]
100%|██████████| 79/79 [00:00<00:00, 523.16it/s]
100%|██████████| 79/79 [00:00<00:00, 561.31it/s]
100%|██████████| 79/79 [00:00<00:00, 377.51it/s]
100%|██████████| 79/79 [00:00<00:00, 437.50it/s]
100%|██████████| 79/79 [00:00<00:00, 504.55it/s]


Epoch 190/300 completed, Loss: 0.4325


100%|██████████| 79/79 [00:00<00:00, 452.76it/s]
100%|██████████| 79/79 [00:00<00:00, 426.96it/s]
100%|██████████| 79/79 [00:00<00:00, 438.58it/s]
100%|██████████| 79/79 [00:00<00:00, 458.40it/s]
100%|██████████| 79/79 [00:00<00:00, 394.93it/s]
100%|██████████| 79/79 [00:00<00:00, 415.07it/s]
100%|██████████| 79/79 [00:00<00:00, 648.68it/s]
100%|██████████| 79/79 [00:00<00:00, 630.30it/s]
100%|██████████| 79/79 [00:00<00:00, 558.97it/s]
100%|██████████| 79/79 [00:00<00:00, 412.82it/s]


Epoch 200/300 completed, Loss: 0.4399


100%|██████████| 79/79 [00:00<00:00, 466.50it/s]
100%|██████████| 79/79 [00:00<00:00, 426.67it/s]
100%|██████████| 79/79 [00:00<00:00, 450.37it/s]
100%|██████████| 79/79 [00:00<00:00, 486.36it/s]
100%|██████████| 79/79 [00:00<00:00, 410.87it/s]
100%|██████████| 79/79 [00:00<00:00, 475.89it/s]
100%|██████████| 79/79 [00:00<00:00, 429.78it/s]
100%|██████████| 79/79 [00:00<00:00, 463.35it/s]
100%|██████████| 79/79 [00:00<00:00, 603.31it/s]
100%|██████████| 79/79 [00:00<00:00, 571.31it/s]


Epoch 210/300 completed, Loss: 0.4494


100%|██████████| 79/79 [00:00<00:00, 572.07it/s]
100%|██████████| 79/79 [00:00<00:00, 474.24it/s]
100%|██████████| 79/79 [00:00<00:00, 398.11it/s]
100%|██████████| 79/79 [00:00<00:00, 426.62it/s]
100%|██████████| 79/79 [00:00<00:00, 460.86it/s]
100%|██████████| 79/79 [00:00<00:00, 435.87it/s]
100%|██████████| 79/79 [00:00<00:00, 414.31it/s]
100%|██████████| 79/79 [00:00<00:00, 400.50it/s]
100%|██████████| 79/79 [00:00<00:00, 398.31it/s]
100%|██████████| 79/79 [00:00<00:00, 425.09it/s]


Epoch 220/300 completed, Loss: 0.4374


100%|██████████| 79/79 [00:00<00:00, 481.55it/s]
100%|██████████| 79/79 [00:00<00:00, 350.70it/s]
100%|██████████| 79/79 [00:00<00:00, 484.83it/s]
100%|██████████| 79/79 [00:00<00:00, 498.58it/s]
100%|██████████| 79/79 [00:00<00:00, 481.19it/s]
100%|██████████| 79/79 [00:00<00:00, 407.05it/s]
100%|██████████| 79/79 [00:00<00:00, 437.78it/s]
100%|██████████| 79/79 [00:00<00:00, 427.92it/s]
100%|██████████| 79/79 [00:00<00:00, 398.23it/s]
100%|██████████| 79/79 [00:00<00:00, 382.97it/s]


Epoch 230/300 completed, Loss: 0.4324


100%|██████████| 79/79 [00:00<00:00, 452.11it/s]
100%|██████████| 79/79 [00:00<00:00, 505.50it/s]
100%|██████████| 79/79 [00:00<00:00, 553.86it/s]
100%|██████████| 79/79 [00:00<00:00, 569.04it/s]
100%|██████████| 79/79 [00:00<00:00, 597.85it/s]
100%|██████████| 79/79 [00:00<00:00, 462.18it/s]
100%|██████████| 79/79 [00:00<00:00, 363.63it/s]
100%|██████████| 79/79 [00:00<00:00, 294.89it/s]
100%|██████████| 79/79 [00:00<00:00, 482.22it/s]
100%|██████████| 79/79 [00:00<00:00, 434.58it/s]


Epoch 240/300 completed, Loss: 0.4665


100%|██████████| 79/79 [00:00<00:00, 355.09it/s]
100%|██████████| 79/79 [00:00<00:00, 461.74it/s]
100%|██████████| 79/79 [00:00<00:00, 428.06it/s]
100%|██████████| 79/79 [00:00<00:00, 521.01it/s]
100%|██████████| 79/79 [00:00<00:00, 592.07it/s]
100%|██████████| 79/79 [00:00<00:00, 593.37it/s]
100%|██████████| 79/79 [00:00<00:00, 477.34it/s]
100%|██████████| 79/79 [00:00<00:00, 415.40it/s]
100%|██████████| 79/79 [00:00<00:00, 452.95it/s]
100%|██████████| 79/79 [00:00<00:00, 436.96it/s]


Epoch 250/300 completed, Loss: 0.4366


100%|██████████| 79/79 [00:00<00:00, 447.48it/s]
100%|██████████| 79/79 [00:00<00:00, 416.33it/s]
100%|██████████| 79/79 [00:00<00:00, 403.62it/s]
100%|██████████| 79/79 [00:00<00:00, 421.24it/s]
100%|██████████| 79/79 [00:00<00:00, 447.64it/s]
100%|██████████| 79/79 [00:00<00:00, 546.66it/s]
100%|██████████| 79/79 [00:00<00:00, 576.50it/s]
100%|██████████| 79/79 [00:00<00:00, 562.59it/s]
100%|██████████| 79/79 [00:00<00:00, 530.23it/s]
100%|██████████| 79/79 [00:00<00:00, 458.38it/s]


Epoch 260/300 completed, Loss: 0.4449


100%|██████████| 79/79 [00:00<00:00, 424.61it/s]
100%|██████████| 79/79 [00:00<00:00, 334.98it/s]
100%|██████████| 79/79 [00:00<00:00, 339.09it/s]
100%|██████████| 79/79 [00:00<00:00, 408.41it/s]
100%|██████████| 79/79 [00:00<00:00, 427.02it/s]
100%|██████████| 79/79 [00:00<00:00, 463.95it/s]
100%|██████████| 79/79 [00:00<00:00, 410.98it/s]
100%|██████████| 79/79 [00:00<00:00, 549.34it/s]
100%|██████████| 79/79 [00:00<00:00, 559.54it/s]
100%|██████████| 79/79 [00:00<00:00, 496.62it/s]


Epoch 270/300 completed, Loss: 0.4344


100%|██████████| 79/79 [00:00<00:00, 392.61it/s]
100%|██████████| 79/79 [00:00<00:00, 396.90it/s]
100%|██████████| 79/79 [00:00<00:00, 477.21it/s]
100%|██████████| 79/79 [00:00<00:00, 380.84it/s]
100%|██████████| 79/79 [00:00<00:00, 357.98it/s]
100%|██████████| 79/79 [00:00<00:00, 427.32it/s]
100%|██████████| 79/79 [00:00<00:00, 464.11it/s]
100%|██████████| 79/79 [00:00<00:00, 436.18it/s]
100%|██████████| 79/79 [00:00<00:00, 498.23it/s]
100%|██████████| 79/79 [00:00<00:00, 602.57it/s]


Epoch 280/300 completed, Loss: 0.4400


100%|██████████| 79/79 [00:00<00:00, 527.75it/s]
100%|██████████| 79/79 [00:00<00:00, 640.54it/s]
100%|██████████| 79/79 [00:00<00:00, 534.16it/s]
100%|██████████| 79/79 [00:00<00:00, 474.07it/s]
100%|██████████| 79/79 [00:00<00:00, 406.57it/s]
100%|██████████| 79/79 [00:00<00:00, 429.12it/s]
100%|██████████| 79/79 [00:00<00:00, 352.68it/s]
100%|██████████| 79/79 [00:00<00:00, 431.78it/s]
100%|██████████| 79/79 [00:00<00:00, 418.64it/s]
100%|██████████| 79/79 [00:00<00:00, 449.06it/s]


Epoch 290/300 completed, Loss: 0.4277


100%|██████████| 79/79 [00:00<00:00, 478.48it/s]
100%|██████████| 79/79 [00:00<00:00, 584.21it/s]
100%|██████████| 79/79 [00:00<00:00, 559.05it/s]
100%|██████████| 79/79 [00:00<00:00, 563.89it/s]
100%|██████████| 79/79 [00:00<00:00, 545.59it/s]
100%|██████████| 79/79 [00:00<00:00, 431.12it/s]
100%|██████████| 79/79 [00:00<00:00, 444.00it/s]
100%|██████████| 79/79 [00:00<00:00, 471.89it/s]
100%|██████████| 79/79 [00:00<00:00, 369.74it/s]
100%|██████████| 79/79 [00:00<00:00, 425.77it/s]


Epoch 300/300 completed, Loss: 0.4365
Final loss: 0.4365
Diffusion process gif saved as 'diffusion_process.gif'
Training loss plot saved as 'training_loss.png'


In [27]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import torch

# Sampling function
@torch.no_grad()
def sample(model, n_samples, device, source_samples=None, target_samples=None,
           save_path='diffusion_process.gif'):
    """Run the reverse diffusion chain and optionally animate it as a GIF.

    The chain is simulated exactly once, up front, keeping each intermediate
    state. The animation callback then only *displays* stored states. This
    matters because FuncAnimation calls the frame function again for frame 0
    during its initial draw (including inside ``save``) when no ``init_func``
    is given. A callback that advanced the chain would apply the first
    reverse step more than once and shift every frame.

    Args:
        model: noise-prediction network called as ``model(x, t.float())``.
        n_samples: number of 2-D points to generate.
        device: torch device the chain runs on.
        source_samples: optional array of shape (N, 2), drawn in blue as reference.
        target_samples: optional array of shape (M, 2), drawn in green as reference.
        save_path: GIF output path; ``None`` skips building the animation.

    Returns:
        Tensor of shape (n_samples, 2) on ``device`` with the final samples.
    """
    model.eval()
    x = torch.randn(n_samples, 2).to(device)
    trajectory = []  # CPU snapshot of the state after each reverse step (only when animating)

    for t in reversed(range(timesteps)):
        t_batch = torch.full((n_samples,), t, device=device, dtype=torch.long)
        predicted_noise = model(x, t_batch.float())
        alpha = alphas[t]
        alpha_hat = alphas_cumprod[t]
        beta = betas[t]
        # Fresh noise on every step except the final one (t == 0).
        noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        if save_path is not None:
            # `x` is rebound (not modified in place) next step, so this snapshot stays valid.
            trajectory.append(x.cpu().numpy())

    if save_path is not None:
        _save_diffusion_animation(trajectory, source_samples, target_samples, save_path)

    return x


def _save_diffusion_animation(trajectory, source_samples, target_samples, save_path):
    """Write a GIF showing one stored reverse-diffusion state per frame."""
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        scatter = ax.scatter([], [], s=5, c='red', alpha=0.7, label='Generated')
        reference_clouds = []
        if source_samples is not None:
            ax.scatter(source_samples[:, 0], source_samples[:, 1], s=10, c='blue', alpha=0.3, label='Source')
            reference_clouds.append(source_samples)
        if target_samples is not None:
            ax.scatter(target_samples[:, 0], target_samples[:, 1], s=10, c='green', alpha=0.3, label='Target')
            reference_clouds.append(target_samples)
        time_text = ax.text(0.02, 0.02, '', transform=ax.transAxes, fontsize=16, verticalalignment='bottom')

        # Axis limits from the reference clouds when given, else from the trajectory itself.
        all_data = np.vstack(reference_clouds if reference_clouds else trajectory)
        x_min, x_max = all_data[:, 0].min(), all_data[:, 0].max()
        y_min, y_max = all_data[:, 1].min(), all_data[:, 1].max()
        margin = 0.1 * max(x_max - x_min, y_max - y_min)
        ax.set_xlim(x_min - margin, x_max + margin)
        ax.set_ylim(y_min - margin, y_max + margin)

        ax.legend(loc='lower left')

        # Remove axis and ticks
        ax.axis('off')

        def update(frame):
            # Pure function of `frame`: harmless if matplotlib calls it repeatedly.
            scatter.set_offsets(trajectory[frame])
            time_text.set_text(f't = {frame}')
            return scatter, time_text

        anim = animation.FuncAnimation(fig, update, frames=len(trajectory), interval=50, blit=True)

        # Save as GIF
        anim.save(save_path, writer='pillow', fps=20)
    finally:
        plt.close(fig)

# Reference clouds for the animation: standard-normal "source" points and the
# training data as the "target"
source_samples = torch.randn(1000, 2).numpy()
target_samples = X.numpy()

# Generate samples and create animation
samples = sample(model, 5000, device, source_samples, target_samples)

# Plot loss (same figure and file as the training cell)
loss_fig, loss_ax = plt.subplots(figsize=(10, 6))
loss_ax.plot(range(1, epochs + 1), losses)
loss_ax.set_title("Training Loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_fig.tight_layout()
loss_fig.savefig('training_loss.png')
plt.close(loss_fig)

for message in (
    "Diffusion process animation saved as 'diffusion_process.gif'",
    "Training loss plot saved as 'training_loss.png'",
):
    print(message)

Diffusion process animation saved as 'diffusion_process.gif'
Training loss plot saved as 'training_loss.png'


In [16]:
# Generate samples. save_path=None: only the final points are needed here, not a GIF
# (this keyword is accepted by the `sample(model, n_samples, device, save_path=None)` signature).
samples = sample(model, 5000, device, save_path=None).cpu().numpy()

# Define font sizes
base_fontsize = 18
title_fontsize = base_fontsize + 2
label_fontsize = base_fontsize
tick_fontsize = base_fontsize - 2

# Plot results. rc_context scopes the base font size to this figure instead of
# permanently mutating the global rcParams for every later plot in the kernel.
with plt.rc_context({'font.size': base_fontsize}):
    fig, (ax_data, ax_samples, ax_loss) = plt.subplots(1, 3, figsize=(20, 6))

    data_np = X.numpy()
    ax_data.scatter(data_np[:, 0], data_np[:, 1], s=5, c='blue', alpha=0.7)
    ax_data.set_title("Original Data", fontsize=title_fontsize)
    ax_data.tick_params(labelsize=tick_fontsize)

    ax_samples.scatter(samples[:, 0], samples[:, 1], s=5, c='red', alpha=0.7)
    ax_samples.set_title("Generated Samples", fontsize=title_fontsize)
    ax_samples.tick_params(labelsize=tick_fontsize)

    ax_loss.plot(range(1, epochs + 1), losses, linewidth=4)
    ax_loss.set_title("Training Loss", fontsize=title_fontsize)
    ax_loss.set_xlabel("Epoch", fontsize=label_fontsize)
    ax_loss.set_ylabel("Loss", fontsize=label_fontsize)
    ax_loss.tick_params(labelsize=tick_fontsize)

    # Add dotted line at minimum loss
    min_loss = min(losses)
    ax_loss.axhline(y=min_loss, color='r', linestyle=':', linewidth=2)

    # Annotate that the minimum reached training objective stays above zero
    ax_loss.text(epochs/2, min_loss, r'$\mathbf{min~~ L_{\rm diffusion} > 0}$',
                 horizontalalignment='right', verticalalignment='bottom',
                 fontsize=label_fontsize, color='r', fontweight='bold')

    fig.tight_layout()
    fig.savefig('diffusion_results.png', dpi=300, bbox_inches='tight')
    plt.close(fig)