In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_blobs
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import imageio

# Hyperparameters
timesteps = 200       # length of the diffusion chain (size of the beta schedule)
batch_size = 512      # mini-batch size handed to the DataLoader
epochs = 200          # number of passes of the training loop
learning_rate = 1e-3  # Adam step size

# Toy dataset: two Gaussian blobs centred at (4, 4) and (4, -4); the labels
# returned by make_blobs are not needed, so only the coordinates are kept.
X = torch.tensor(
    make_blobs(
        n_samples=10000,
        centers=[(4, 4), (4, -4)],
        cluster_std=0.5,
        random_state=42,
    )[0],
    dtype=torch.float32,
)

# Define beta schedule
def linear_beta_schedule(timesteps):
    """Return a 1-D tensor of `timesteps` betas spaced evenly from 1e-4 to 0.02.

    Both endpoints are included, as produced by ``torch.linspace``.
    """
    first_beta, last_beta = 0.0001, 0.02
    return torch.linspace(start=first_beta, end=last_beta, steps=timesteps)

# Coefficients derived once from the beta schedule and reused by the forward
# (q_sample) and reverse (sample) processes.
betas = linear_beta_schedule(timesteps)
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)  # running product of alphas up to each step
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()

# Helper function to extract the appropriate t index for a batch of indices
def extract(a, t, x_shape):
    """Look up one schedule value per batch element and shape it for broadcasting.

    Args:
        a: 1-D tensor of per-timestep values (e.g. a precomputed schedule).
        t: 1-D integer (long) tensor of timestep indices, one per batch element.
        x_shape: shape of the tensor the result will be combined with; only
            its rank is used.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape)`` dimensions,
        located on ``t``'s device.
    """
    n = t.shape[0]
    # `gather` needs both operands on the same device. The lookup table may
    # live on the CPU while `t` was created on the GPU, so move the table to
    # wherever the indices are before indexing.
    out = a.to(t.device).gather(-1, t)
    return out.reshape(n, *((1,) * (len(x_shape) - 1)))

# Forward diffusion process
def q_sample(x_start, t, noise=None):
    """Noise clean samples to timestep `t` in one shot.

    Computes ``sqrt_alphas_cumprod[t] * x_start +
    sqrt_one_minus_alphas_cumprod[t] * noise`` per batch element.

    Args:
        x_start: batch of clean samples.
        t: 1-D long tensor with one timestep index per batch element.
        noise: optional noise tensor shaped like `x_start`; standard normal
            noise is drawn when omitted.

    Returns:
        Tensor shaped like `x_start` holding the noised samples.
    """
    eps = torch.randn_like(x_start) if noise is None else noise

    signal_scale = extract(sqrt_alphas_cumprod, t, x_start.shape)
    noise_scale = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)

    return signal_scale * x_start + noise_scale * eps

# Simple MLP model
class SimpleMLP(nn.Module):
    """Three-layer ReLU MLP that maps a 2-D point plus a scalar timestep to a 2-D output."""

    def __init__(self):
        super().__init__()
        # Input is the 2-D point concatenated with one timestep feature.
        in_features, hidden, out_features = 2 + 1, 64, 2
        layers = [
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        """Append `t` as an extra trailing feature of `x` and run the network.

        Args:
            x: tensor whose last dimension has size 2.
            t: tensor with one fewer dimension than `x` (one value per row),
                in a dtype that can be concatenated with `x`.

        Returns:
            Tensor with the same leading dimensions as `x` and last dimension 2.
        """
        t_column = t.unsqueeze(-1)
        features = torch.cat((x, t_column), dim=-1)
        return self.net(features)

# Training function
def train(model, dataloader, optimizer, device):
    """Run one pass over `dataloader` training `model` to predict the injected noise.

    For every batch a random timestep is drawn per sample, the batch is noised
    with `q_sample`, and the MSE between the predicted and the true noise is
    minimised with one optimizer step.

    Args:
        model: network called as ``model(x_noisy, t)`` with float timesteps.
        dataloader: iterable of batches whose first element is the data tensor.
        optimizer: optimizer updating `model`'s parameters.
        device: device the batch and the sampled timesteps are placed on.

    Returns:
        Mean of the per-batch losses over the pass (a Python float).
    """
    model.train()
    criterion = nn.MSELoss()
    batch_losses = []

    for batch in tqdm(dataloader):
        x0 = batch[0].to(device)
        optimizer.zero_grad()

        # One uniformly drawn timestep index in [0, timesteps) per sample
        steps = torch.randint(0, timesteps, (x0.shape[0],), device=device).long()
        eps = torch.randn_like(x0)
        eps_hat = model(q_sample(x0, steps, eps), steps.float())
        batch_loss = criterion(eps_hat, eps)

        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())

    return sum(batch_losses) / len(dataloader)

# Use the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model, optimizer and shuffled mini-batch loader over the toy dataset
model = SimpleMLP()
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
dataloader = DataLoader(TensorDataset(X), batch_size=batch_size, shuffle=True)

# Training loop: record the mean loss of every epoch, report every 50th
losses = []
for epoch in range(epochs):
    loss = train(model, dataloader, optimizer, device)
    losses.append(loss)
    if (epoch + 1) % 50:
        continue
    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss:.4f}")

import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import torch

# Sampling function
@torch.no_grad()
def sample(model, n_samples, device, source_samples, target_samples):
    """Run the reverse diffusion chain from Gaussian noise and save it as a GIF.

    Args:
        model: noise-prediction network called as ``model(x, t)`` with float
            timesteps.
        n_samples: number of 2-D points to generate.
        device: torch device the chain is run on.
        source_samples: 2-D NumPy array; columns 0/1 are drawn as the blue
            "Source" cloud.
        target_samples: 2-D NumPy array; columns 0/1 are drawn as the green
            "Target" cloud.

    Returns:
        Tensor of shape ``(n_samples, 2)`` with the points after the last
        reverse step.

    Side effects:
        Writes 'diffusion_process.gif' to the current working directory.
    """
    model.eval()
    x = torch.randn(n_samples, 2).to(device)

    # Run the whole reverse chain BEFORE animating. When no init_func is given,
    # FuncAnimation draws the first frame an extra time, so a callback that
    # advances the chain would apply the first denoising step more than once.
    # Pre-computing keeps the callback a pure lookup.
    trajectory = []
    for t in reversed(range(timesteps)):
        t_batch = torch.full((n_samples,), t, device=device, dtype=torch.long)
        predicted_noise = model(x, t_batch.float())
        alpha = alphas[t]
        alpha_hat = alphas_cumprod[t]
        beta = betas[t]
        # No fresh noise is injected on the final step (t == 0)
        noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        noise_coeff = (1 - alpha) / torch.sqrt(1 - alpha_hat)
        x = 1 / torch.sqrt(alpha) * (x - noise_coeff * predicted_noise) + torch.sqrt(beta) * noise
        trajectory.append(x.cpu().numpy())

    fig, ax = plt.subplots(figsize=(10, 10))
    scatter = ax.scatter([], [], s=25, c='red', alpha=0.7, label='Generated')
    ax.scatter(source_samples[:, 0], source_samples[:, 1], s=30, c='blue', alpha=0.3, label='Source')
    ax.scatter(target_samples[:, 0], target_samples[:, 1], s=30, c='green', alpha=0.3, label='Target')
    time_text = ax.text(0.05, 0.95, '', transform=ax.transAxes, fontsize=20, horizontalalignment='left', verticalalignment='top')

    # Set axis limits based on the data
    all_data = np.vstack((source_samples, target_samples))
    x_min, x_max = all_data[:, 0].min(), all_data[:, 0].max()
    y_min, y_max = all_data[:, 1].min(), all_data[:, 1].max()
    margin = 0.05 * max(x_max - x_min, y_max - y_min)  # Reduced margin
    ax.set_xlim(x_min - margin, x_max + margin)
    ax.set_ylim(y_min - margin, y_max + margin)

    ax.legend(loc='lower left', fontsize=20)

    # Remove axis and ticks
    ax.axis('off')

    def update(frame):
        # Stateless: frame k shows the points after k + 1 reverse steps
        scatter.set_offsets(trajectory[frame])
        time_text.set_text(f't = {frame}/{timesteps}')
        return scatter, time_text

    anim = animation.FuncAnimation(fig, update, frames=len(trajectory), interval=50, blit=True)

    # Save as GIF with tight layout
    plt.tight_layout()
    anim.save('diffusion_process.gif', writer='pillow', fps=20)
    plt.close(fig)

    return x

# Generate source and target samples
source_samples = torch.randn(2000, 2).numpy()
target_samples = X.numpy()[:2000]

# Generate samples and create animation
samples = sample(model, 3000, device, source_samples, target_samples)

# Plot loss with increased font and tick size.
# The x-axis is derived from len(losses) instead of the `epochs` global:
# `losses` may come from a training run with a different epoch count (cells
# can be re-run after `epochs` is edited), and a length mismatch would make
# plt.plot raise.
plt.figure(figsize=(10, 6))
plt.plot(range(1, len(losses) + 1), losses, linewidth=3)
plt.title("Training Loss", fontsize=20)
plt.xlabel("Epoch", fontsize=18)
plt.ylabel("Loss", fontsize=18)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

# Add red dotted line at minimum value
min_loss = min(losses)
plt.axhline(y=min_loss, color='red', linestyle=':', linewidth=2)

# Add text 'L_diffusion > 0', centred horizontally over the plotted epochs
plt.text(len(losses) / 2, min_loss*1.1, r'$L_\mathrm{diffusion} > 0$', fontsize=16,
         horizontalalignment='center', verticalalignment='bottom')

plt.tight_layout()
plt.savefig('training_loss.png', dpi=300, bbox_inches='tight')
plt.close()

print("Diffusion process animation saved as 'diffusion_process.gif'")
print("Training loss plot saved as 'training_loss.png'")

  0%|          | 0/79 [00:00<?, ?it/s]

100%|██████████| 79/79 [00:00<00:00, 395.73it/s]
100%|██████████| 79/79 [00:00<00:00, 476.12it/s]
100%|██████████| 79/79 [00:00<00:00, 498.84it/s]
100%|██████████| 79/79 [00:00<00:00, 504.34it/s]
100%|██████████| 79/79 [00:00<00:00, 479.81it/s]
100%|██████████| 79/79 [00:00<00:00, 428.74it/s]
100%|██████████| 79/79 [00:00<00:00, 442.41it/s]
100%|██████████| 79/79 [00:00<00:00, 479.27it/s]
100%|██████████| 79/79 [00:00<00:00, 511.85it/s]
100%|██████████| 79/79 [00:00<00:00, 399.03it/s]
100%|██████████| 79/79 [00:00<00:00, 288.71it/s]
100%|██████████| 79/79 [00:00<00:00, 516.99it/s]
100%|██████████| 79/79 [00:00<00:00, 533.38it/s]
100%|██████████| 79/79 [00:00<00:00, 501.83it/s]
100%|██████████| 79/79 [00:00<00:00, 543.09it/s]
100%|██████████| 79/79 [00:00<00:00, 466.62it/s]
100%|██████████| 79/79 [00:00<00:00, 443.96it/s]
100%|██████████| 79/79 [00:00<00:00, 462.58it/s]
100%|██████████| 79/79 [00:00<00:00, 464.91it/s]
100%|██████████| 79/79 [00:00<00:00, 407.34it/s]
100%|██████████| 79/

Epoch 50/400, Loss: 0.4462


100%|██████████| 79/79 [00:00<00:00, 461.24it/s]
100%|██████████| 79/79 [00:00<00:00, 440.31it/s]
100%|██████████| 79/79 [00:00<00:00, 446.97it/s]
100%|██████████| 79/79 [00:00<00:00, 461.54it/s]
100%|██████████| 79/79 [00:00<00:00, 413.36it/s]
100%|██████████| 79/79 [00:00<00:00, 483.83it/s]
100%|██████████| 79/79 [00:00<00:00, 463.08it/s]
100%|██████████| 79/79 [00:00<00:00, 352.97it/s]
100%|██████████| 79/79 [00:00<00:00, 458.03it/s]
100%|██████████| 79/79 [00:00<00:00, 420.43it/s]
100%|██████████| 79/79 [00:00<00:00, 505.46it/s]
100%|██████████| 79/79 [00:00<00:00, 492.87it/s]
100%|██████████| 79/79 [00:00<00:00, 487.93it/s]
100%|██████████| 79/79 [00:00<00:00, 445.97it/s]
100%|██████████| 79/79 [00:00<00:00, 392.94it/s]
100%|██████████| 79/79 [00:00<00:00, 445.47it/s]
100%|██████████| 79/79 [00:00<00:00, 529.57it/s]
100%|██████████| 79/79 [00:00<00:00, 513.06it/s]
100%|██████████| 79/79 [00:00<00:00, 415.48it/s]
100%|██████████| 79/79 [00:00<00:00, 450.08it/s]
100%|██████████| 79/

Epoch 100/400, Loss: 0.4499


100%|██████████| 79/79 [00:00<00:00, 439.97it/s]
100%|██████████| 79/79 [00:00<00:00, 329.58it/s]
100%|██████████| 79/79 [00:00<00:00, 443.46it/s]
100%|██████████| 79/79 [00:00<00:00, 442.95it/s]
100%|██████████| 79/79 [00:00<00:00, 411.57it/s]
100%|██████████| 79/79 [00:00<00:00, 379.69it/s]
100%|██████████| 79/79 [00:00<00:00, 413.58it/s]
100%|██████████| 79/79 [00:00<00:00, 469.58it/s]
100%|██████████| 79/79 [00:00<00:00, 460.68it/s]
100%|██████████| 79/79 [00:00<00:00, 430.52it/s]
100%|██████████| 79/79 [00:00<00:00, 354.98it/s]
100%|██████████| 79/79 [00:00<00:00, 459.01it/s]
100%|██████████| 79/79 [00:00<00:00, 507.61it/s]
100%|██████████| 79/79 [00:00<00:00, 456.55it/s]
100%|██████████| 79/79 [00:00<00:00, 529.21it/s]
100%|██████████| 79/79 [00:00<00:00, 351.23it/s]
100%|██████████| 79/79 [00:00<00:00, 416.92it/s]
100%|██████████| 79/79 [00:00<00:00, 385.22it/s]
100%|██████████| 79/79 [00:00<00:00, 396.23it/s]
100%|██████████| 79/79 [00:00<00:00, 540.12it/s]
100%|██████████| 79/

Epoch 150/400, Loss: 0.4447


100%|██████████| 79/79 [00:00<00:00, 335.99it/s]
100%|██████████| 79/79 [00:00<00:00, 487.94it/s]
100%|██████████| 79/79 [00:00<00:00, 520.50it/s]
100%|██████████| 79/79 [00:00<00:00, 487.59it/s]
100%|██████████| 79/79 [00:00<00:00, 493.39it/s]
100%|██████████| 79/79 [00:00<00:00, 495.93it/s]
100%|██████████| 79/79 [00:00<00:00, 455.31it/s]
100%|██████████| 79/79 [00:00<00:00, 362.83it/s]
100%|██████████| 79/79 [00:00<00:00, 322.01it/s]
100%|██████████| 79/79 [00:00<00:00, 497.65it/s]
100%|██████████| 79/79 [00:00<00:00, 481.90it/s]
100%|██████████| 79/79 [00:00<00:00, 455.17it/s]
100%|██████████| 79/79 [00:00<00:00, 458.09it/s]
100%|██████████| 79/79 [00:00<00:00, 452.37it/s]
100%|██████████| 79/79 [00:00<00:00, 518.05it/s]
100%|██████████| 79/79 [00:00<00:00, 496.54it/s]
100%|██████████| 79/79 [00:00<00:00, 430.99it/s]
100%|██████████| 79/79 [00:00<00:00, 459.04it/s]
100%|██████████| 79/79 [00:00<00:00, 438.51it/s]
100%|██████████| 79/79 [00:00<00:00, 464.03it/s]
100%|██████████| 79/

Epoch 200/400, Loss: 0.4516


100%|██████████| 79/79 [00:00<00:00, 349.31it/s]
100%|██████████| 79/79 [00:00<00:00, 469.38it/s]
100%|██████████| 79/79 [00:00<00:00, 513.96it/s]
100%|██████████| 79/79 [00:00<00:00, 504.13it/s]
100%|██████████| 79/79 [00:00<00:00, 543.61it/s]
100%|██████████| 79/79 [00:00<00:00, 442.39it/s]
100%|██████████| 79/79 [00:00<00:00, 512.47it/s]
100%|██████████| 79/79 [00:00<00:00, 498.12it/s]
100%|██████████| 79/79 [00:00<00:00, 534.61it/s]
100%|██████████| 79/79 [00:00<00:00, 512.31it/s]
100%|██████████| 79/79 [00:00<00:00, 425.93it/s]
100%|██████████| 79/79 [00:00<00:00, 487.08it/s]
100%|██████████| 79/79 [00:00<00:00, 518.91it/s]
100%|██████████| 79/79 [00:00<00:00, 518.21it/s]
100%|██████████| 79/79 [00:00<00:00, 523.44it/s]
100%|██████████| 79/79 [00:00<00:00, 510.15it/s]
100%|██████████| 79/79 [00:00<00:00, 484.62it/s]
100%|██████████| 79/79 [00:00<00:00, 527.00it/s]
100%|██████████| 79/79 [00:00<00:00, 523.12it/s]
100%|██████████| 79/79 [00:00<00:00, 533.32it/s]
100%|██████████| 79/

Epoch 250/400, Loss: 0.4442


100%|██████████| 79/79 [00:00<00:00, 414.79it/s]
100%|██████████| 79/79 [00:00<00:00, 458.65it/s]
100%|██████████| 79/79 [00:00<00:00, 324.48it/s]
100%|██████████| 79/79 [00:00<00:00, 487.97it/s]
100%|██████████| 79/79 [00:00<00:00, 429.72it/s]
100%|██████████| 79/79 [00:00<00:00, 517.47it/s]
100%|██████████| 79/79 [00:00<00:00, 493.80it/s]
100%|██████████| 79/79 [00:00<00:00, 525.98it/s]
100%|██████████| 79/79 [00:00<00:00, 541.39it/s]
100%|██████████| 79/79 [00:00<00:00, 529.04it/s]
100%|██████████| 79/79 [00:00<00:00, 565.37it/s]
100%|██████████| 79/79 [00:00<00:00, 503.11it/s]
100%|██████████| 79/79 [00:00<00:00, 506.30it/s]
100%|██████████| 79/79 [00:00<00:00, 466.69it/s]
100%|██████████| 79/79 [00:00<00:00, 423.82it/s]
100%|██████████| 79/79 [00:00<00:00, 506.78it/s]
100%|██████████| 79/79 [00:00<00:00, 471.69it/s]
100%|██████████| 79/79 [00:00<00:00, 510.52it/s]
100%|██████████| 79/79 [00:00<00:00, 534.20it/s]
100%|██████████| 79/79 [00:00<00:00, 502.83it/s]
100%|██████████| 79/

Epoch 300/400, Loss: 0.4317


100%|██████████| 79/79 [00:00<00:00, 505.41it/s]
100%|██████████| 79/79 [00:00<00:00, 382.46it/s]
100%|██████████| 79/79 [00:00<00:00, 422.41it/s]
100%|██████████| 79/79 [00:00<00:00, 462.98it/s]
100%|██████████| 79/79 [00:00<00:00, 479.59it/s]
100%|██████████| 79/79 [00:00<00:00, 549.28it/s]
100%|██████████| 79/79 [00:00<00:00, 501.82it/s]
100%|██████████| 79/79 [00:00<00:00, 546.36it/s]
100%|██████████| 79/79 [00:00<00:00, 480.23it/s]
100%|██████████| 79/79 [00:00<00:00, 491.76it/s]
100%|██████████| 79/79 [00:00<00:00, 504.23it/s]
100%|██████████| 79/79 [00:00<00:00, 502.75it/s]
100%|██████████| 79/79 [00:00<00:00, 466.11it/s]
100%|██████████| 79/79 [00:00<00:00, 339.58it/s]
100%|██████████| 79/79 [00:00<00:00, 451.81it/s]
100%|██████████| 79/79 [00:00<00:00, 452.39it/s]
100%|██████████| 79/79 [00:00<00:00, 460.16it/s]
100%|██████████| 79/79 [00:00<00:00, 482.20it/s]
100%|██████████| 79/79 [00:00<00:00, 374.41it/s]
100%|██████████| 79/79 [00:00<00:00, 400.53it/s]
100%|██████████| 79/

Epoch 350/400, Loss: 0.4437


100%|██████████| 79/79 [00:00<00:00, 419.04it/s]
100%|██████████| 79/79 [00:00<00:00, 488.18it/s]
100%|██████████| 79/79 [00:00<00:00, 526.61it/s]
100%|██████████| 79/79 [00:00<00:00, 496.55it/s]
100%|██████████| 79/79 [00:00<00:00, 503.76it/s]
100%|██████████| 79/79 [00:00<00:00, 517.55it/s]
100%|██████████| 79/79 [00:00<00:00, 462.90it/s]
100%|██████████| 79/79 [00:00<00:00, 438.55it/s]
100%|██████████| 79/79 [00:00<00:00, 409.41it/s]
100%|██████████| 79/79 [00:00<00:00, 481.34it/s]
100%|██████████| 79/79 [00:00<00:00, 347.50it/s]
100%|██████████| 79/79 [00:00<00:00, 411.10it/s]
100%|██████████| 79/79 [00:00<00:00, 475.49it/s]
100%|██████████| 79/79 [00:00<00:00, 375.84it/s]
100%|██████████| 79/79 [00:00<00:00, 432.23it/s]
100%|██████████| 79/79 [00:00<00:00, 492.27it/s]
100%|██████████| 79/79 [00:00<00:00, 493.73it/s]
100%|██████████| 79/79 [00:00<00:00, 476.97it/s]
100%|██████████| 79/79 [00:00<00:00, 489.59it/s]
100%|██████████| 79/79 [00:00<00:00, 565.63it/s]
100%|██████████| 79/

Epoch 400/400, Loss: 0.4406
Diffusion process animation saved as 'diffusion_process.gif'
Training loss plot saved as 'training_loss.png'


In [49]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import torch

# Sampling function
@torch.no_grad()
def sample(model, n_samples, device, source_samples, target_samples):
    """Denoise the given source points step by step and save the motion as a GIF.

    The chain starts from `source_samples` (not from freshly drawn noise) and a
    thin line is drawn from every starting point to its current position.

    Args:
        model: noise-prediction network called as ``model(x, t)`` with float
            timesteps.
        n_samples: number of points to move; must equal the number of rows of
            `source_samples`, since one timestep per row is passed to `model`.
        device: torch device the chain is run on.
        source_samples: 2-D NumPy array of starting points (columns 0/1 = x/y).
        target_samples: 2-D NumPy array drawn as the green "Target" cloud.

    Returns:
        Tensor with the points after the last reverse step.

    Side effects:
        Writes 'diffusion_process.gif' to the current working directory.
    """
    model.eval()
    # Cast explicitly so a float64 NumPy input still matches the model weights
    x = torch.tensor(source_samples).float().to(device)  # Start with source samples

    # Run the chain up front and keep one host-side snapshot per step. This
    # keeps the animation callback stateless and replaces the per-line
    # device-to-host copies of single scalars with one bulk copy per step.
    states = [x.cpu().numpy()]
    for t in reversed(range(timesteps)):
        t_batch = torch.full((n_samples,), t, device=device, dtype=torch.long)
        predicted_noise = model(x, t_batch.float())
        alpha = alphas[t]
        alpha_hat = alphas_cumprod[t]
        beta = betas[t]
        # No fresh noise is injected on the final step (t == 0)
        noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        noise_coeff = (1 - alpha) / torch.sqrt(1 - alpha_hat)
        x = 1 / torch.sqrt(alpha) * (x - noise_coeff * predicted_noise) + torch.sqrt(beta) * noise
        states.append(x.cpu().numpy())

    fig, ax = plt.subplots(figsize=(10, 10))
    scatter = ax.scatter([], [], s=25, c='red', alpha=0.7, label='Generated')
    ax.scatter(source_samples[:, 0], source_samples[:, 1], s=30, c='blue', alpha=0.3, label='Source')
    ax.scatter(target_samples[:, 0], target_samples[:, 1], s=30, c='green', alpha=0.3, label='Target')
    time_text = ax.text(0.05, 0.95, '', transform=ax.transAxes, fontsize=20, horizontalalignment='left', verticalalignment='top')

    # Set axis limits based on the data
    all_data = np.vstack((source_samples, target_samples))
    x_min, x_max = all_data[:, 0].min(), all_data[:, 0].max()
    y_min, y_max = all_data[:, 1].min(), all_data[:, 1].max()
    margin = 0.05 * max(x_max - x_min, y_max - y_min)  # Reduced margin
    ax.set_xlim(x_min - margin, x_max + margin)
    ax.set_ylim(y_min - margin, y_max + margin)

    ax.legend(loc='lower left', fontsize=20)

    # Remove axis and ticks
    ax.axis('off')

    # Draw lines between initial and final states
    lines = [ax.plot([], [], color='red', alpha=0.1, linewidth=0.5)[0] for _ in range(n_samples)]

    def update(frame):
        # Frame 0 shows the untouched source points, frame k the state after
        # k reverse steps; frames past `timesteps` hold the final state.
        step = min(frame, timesteps)
        current = states[step]

        scatter.set_offsets(current)
        if step == 0:
            time_text.set_text('Initial source samples')
        else:
            time_text.set_text(f't = {step}/{timesteps}')

        # Update lines
        for i, line in enumerate(lines):
            line.set_data([source_samples[i, 0], current[i, 0]], [source_samples[i, 1], current[i, 1]])

        return scatter, time_text, *lines

    anim = animation.FuncAnimation(fig, update, frames=timesteps+20, interval=50, blit=True)

    # Save as GIF with tight layout
    plt.tight_layout()
    anim.save('diffusion_process.gif', writer='pillow', fps=30)
    plt.close(fig)

    return x

# Generate source and target samples
source_samples = torch.randn(2000, 2).numpy()
target_samples = X.numpy()[:2000]

# Generate samples and create animation
samples = sample(model, 2000, device, source_samples, target_samples)

# Plot loss with increased font and tick size.
# The x-axis is derived from len(losses) instead of the `epochs` global:
# `losses` was filled by an earlier cell and may come from a run with a
# different epoch count, in which case plt.plot would raise on the mismatch.
plt.figure(figsize=(10, 6))
plt.plot(range(1, len(losses) + 1), losses, linewidth=3)
plt.title("Training Loss", fontsize=20)
plt.xlabel("Epoch", fontsize=18)
plt.ylabel("Loss", fontsize=18)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

# Add red dotted line at minimum value
min_loss = min(losses)
plt.axhline(y=min_loss, color='red', linestyle=':', linewidth=2)

# Add text 'L_diffusion > 0', centred horizontally over the plotted epochs
plt.text(len(losses) / 2, min_loss*1.1, r'$L_\mathrm{diffusion} > 0$', fontsize=16,
         horizontalalignment='center', verticalalignment='bottom')

plt.tight_layout()
plt.savefig('training_loss.png', dpi=300, bbox_inches='tight')
plt.close()

print("Diffusion process animation saved as 'diffusion_process.gif'")
print("Training loss plot saved as 'training_loss.png'")

Diffusion process animation saved as 'diffusion_process.gif'
Training loss plot saved as 'training_loss.png'


In [46]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import torch

# Sampling function
@torch.no_grad()
def sample(model, n_samples, device, source_samples, target_samples):
    """Animate a single-shot x_0 prediction from the last timestep and save it as a GIF.

    The source points are treated as samples at timestep ``timesteps - 1``; the
    model's noise estimate is used once to jump to the implied clean points, and
    the animation interpolates linearly from the source points to that
    prediction, then holds on it.

    Args:
        model: noise-prediction network called as ``model(x, t)`` with float
            timesteps.
        n_samples: number of points; must equal the number of rows of
            `source_samples`.
        device: torch device the prediction is computed on.
        source_samples: 2-D NumPy array of starting points (columns 0/1 = x/y).
        target_samples: 2-D NumPy array drawn as the green "Target" cloud.

    Returns:
        The starting points as a float tensor on `device`.
        NOTE(review): the predicted points are only animated, not returned --
        confirm this is what callers expect.

    Side effects:
        Prints the shapes/types of the two point clouds and writes
        'one_step_prediction.gif' to the current working directory.
    """
    model.eval()
    x = torch.tensor(source_samples).float().to(device)

    fig, ax = plt.subplots(figsize=(10, 10))
    scatter = ax.scatter([], [], s=25, c='red', alpha=0.7, label='Generated')
    ax.scatter(source_samples[:, 0], source_samples[:, 1], s=30, c='blue', alpha=0.3, label='Source')
    ax.scatter(target_samples[:, 0], target_samples[:, 1], s=30, c='green', alpha=0.3, label='Target')
    time_text = ax.text(0.05, 0.95, '', transform=ax.transAxes, fontsize=20)

    # Set axis limits based on the data
    all_data = np.vstack((source_samples, target_samples))
    x_min, x_max = all_data[:, 0].min(), all_data[:, 0].max()
    y_min, y_max = all_data[:, 1].min(), all_data[:, 1].max()
    margin = 0.1 * max(x_max - x_min, y_max - y_min)
    ax.set_xlim(x_min - margin, x_max + margin)
    ax.set_ylim(y_min - margin, y_max + margin)

    ax.legend(loc='lower left', fontsize=20)

    # Remove axis and ticks
    ax.axis('off')

    # Draw lines between initial and final states
    lines = [ax.plot([], [], color='red', alpha=0.1, linewidth=0.5)[0] for _ in range(n_samples)]

    # Number of interpolation frames (the animation adds 15 hold frames after them)
    n_frames = 30

    # Compute destination samples, treating the source points as x_t at t = T-1
    last_step = timesteps - 1
    t_last = torch.full((n_samples,), float(last_step), device=device)
    predicted_noise = model(x, t_last)
    # x_0 = (x_t - sqrt(1 - alpha_bar_t) * epsilon_theta) / sqrt(alpha_bar_t)
    alpha_bar_t = alphas_cumprod[last_step]
    destination_samples = ((1 / torch.sqrt(alpha_bar_t)) * (x -
                               torch.sqrt(1 - alpha_bar_t) * predicted_noise)).cpu().numpy()

    print(destination_samples.shape, type(destination_samples))
    print(source_samples.shape, type(source_samples))

    def update(frame):
        # Interpolation weight in [0, 1]. The frame index is clamped to
        # n_frames - 1 (the divisor), so the hold frames sit exactly on the
        # destination instead of extrapolating past it.
        progress = min(frame, n_frames - 1) / (n_frames - 1)
        current_samples = (1 - progress) * source_samples + progress * destination_samples

        scatter.set_offsets(current_samples)
        if progress < 1:
            time_text.set_text('Initial source samples')
        else:
            time_text.set_text('One step prediction')

        # Update lines
        for i, line in enumerate(lines):
            line.set_data([source_samples[i, 0], current_samples[i, 0]],
                          [source_samples[i, 1], current_samples[i, 1]])

        return scatter, time_text, *lines

    anim = animation.FuncAnimation(fig, update, frames=n_frames+15, interval=50, blit=True)

    # Save as GIF with tight layout; the extra frames act as a pause at the end
    plt.tight_layout()
    anim.save('one_step_prediction.gif', writer='pillow', fps=30)
    plt.close(fig)

    return x

# Source cloud: 2000 standard-normal points; target cloud: first 2000 data points
source_samples = torch.randn((2000, 2)).numpy()
target_samples = X[:2000].numpy()

# Build and save the animation for these points
samples = sample(
    model=model,
    n_samples=2000,
    device=device,
    source_samples=source_samples,
    target_samples=target_samples,
)


print("One step prediction animation saved as 'one_step_prediction.gif'")

(2000, 2) <class 'numpy.ndarray'>
(2000, 2) <class 'numpy.ndarray'>
One step prediction animation saved as 'one_step_prediction.gif'


In [32]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import torch

# Sampling function
@torch.no_grad()
def sample(model, n_samples, device, source_samples, target_samples):
    """Save a two-frame GIF: the source points, then the points minus the model output.

    The model is queried with a timestep vector of ones, and its output is
    subtracted directly from the source points.

    Args:
        model: network called as ``model(x, t)``.
        n_samples: number of points; must equal the number of rows of
            `source_samples`.
        device: torch device the model is evaluated on.
        source_samples: 2-D NumPy array of starting points (columns 0/1 = x/y).
        target_samples: 2-D NumPy array drawn as the green "Target" cloud.

    Returns:
        Tensor ``x - model(x, ones)`` on `device`.

    Side effects:
        Writes 'one_step_diffusion.gif' to the current working directory.
    """
    model.eval()
    x = torch.tensor(source_samples).float().to(device)

    def one_step_estimate():
        # Timestep input is a vector of ones with one entry per point
        unit_t = torch.ones(n_samples).to(device)
        return x - model(x, unit_t)

    fig, ax = plt.subplots(figsize=(10, 10))
    scatter = ax.scatter([], [], s=25, c='red', alpha=0.7, label='Generated')
    ax.scatter(source_samples[:, 0], source_samples[:, 1], s=30, c='blue', alpha=0.3, label='Source')
    ax.scatter(target_samples[:, 0], target_samples[:, 1], s=30, c='green', alpha=0.3, label='Target')
    time_text = ax.text(0.05, 0.95, '', transform=ax.transAxes, fontsize=20)

    # Axis limits: bounding box of both clouds plus 10% of its larger side
    cloud = np.vstack((source_samples, target_samples))
    xs, ys = cloud[:, 0], cloud[:, 1]
    x_lo, x_hi = xs.min(), xs.max()
    y_lo, y_hi = ys.min(), ys.max()
    pad = 0.1 * max(x_hi - x_lo, y_hi - y_lo)
    ax.set_xlim(x_lo - pad, x_hi + pad)
    ax.set_ylim(y_lo - pad, y_hi + pad)

    ax.legend(loc='lower left', fontsize=20)

    # Hide the axes and ticks
    ax.axis('off')

    # One (initially empty) connector line per point
    lines = []
    for _ in range(n_samples):
        (connector,) = ax.plot([], [], color='red', alpha=0.1, linewidth=0.5)
        lines.append(connector)

    def update(frame):
        current_samples = source_samples if frame == 0 else one_step_estimate().cpu().numpy()

        scatter.set_offsets(current_samples)
        time_text.set_text(f't = {frame}')

        # Connect every starting point to its current position
        for idx, connector in enumerate(lines):
            start_xy = source_samples[idx]
            now_xy = current_samples[idx]
            connector.set_data([start_xy[0], now_xy[0]], [start_xy[1], now_xy[1]])

        return scatter, time_text, *lines

    anim = animation.FuncAnimation(fig, update, frames=2, interval=1000, blit=True)

    # Save as GIF with tight layout
    plt.tight_layout()
    anim.save('one_step_diffusion.gif', writer='pillow', fps=1)
    plt.close(fig)

    return one_step_estimate()

# Source cloud: 2000 standard-normal points; target cloud: first 2000 data points
source_samples = torch.randn((2000, 2)).numpy()
target_samples = X[:2000].numpy()

# Build and save the animation for these points
samples = sample(
    model=model,
    n_samples=2000,
    device=device,
    source_samples=source_samples,
    target_samples=target_samples,
)

print("One-step diffusion animation saved as 'one_step_diffusion.gif'")


One-step diffusion animation saved as 'one_step_diffusion.gif'
