# Score-Based Model

## ノイズスケジュールの定義

線形にノイズの分散が増加するシンプルなスケジュールを定義します。

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import imageio


In [16]:
# Hyperparameters of the linear beta (noise variance) schedule.
timesteps = 1000
beta_start = 0.0001
beta_end = 0.02
# One beta per diffusion step, evenly spaced from beta_start to beta_end.
betas = torch.linspace(start=beta_start, end=beta_end, steps=timesteps)
alphas = 1.0 - betas
# Running product of the alphas along the step axis.
alpha_cumprod = alphas.cumprod(dim=0)

def noise_schedule(t):
    """Look up the forward-process scaling factors for step index ``t``.

    Args:
        t: Index (or tensor of indices) used to index the module-level
            ``alpha_cumprod`` tensor.

    Returns:
        A pair ``(sqrt(alpha_cumprod[t]), sqrt(1 - alpha_cumprod[t]))``.
    """
    alpha_bar_t = alpha_cumprod[t]
    signal_scale = alpha_bar_t.sqrt()
    noise_scale = (1 - alpha_bar_t).sqrt()
    return signal_scale, noise_scale

## スコアネットワークの構築

簡単なMLP（多層パーセプトロン）をスコアネットワークとして使用します。入力はノイズが加えられたデータとタイムステップの埋め込みです。

In [17]:
class ScoreNet(nn.Module):
    """MLP that maps a noisy sample and its timestep to an ``input_dim`` output.

    The timestep is divided by the module-level ``timesteps``, embedded by a
    single linear layer, and concatenated with the input before three fully
    connected layers (ReLU after the first two).
    """

    def __init__(self, input_dim, hidden_dim, time_embed_dim):
        super().__init__()
        # Width of the concatenated [x, time-embedding] feature vector.
        concat_dim = input_dim + time_embed_dim
        # Layers are created in the order fc1, fc2, fc3, time_embed.
        self.fc1 = nn.Linear(concat_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)
        self.time_embed = nn.Linear(1, time_embed_dim)

    def forward(self, x, t):
        """Return the network output for inputs ``x`` at timesteps ``t``.

        Args:
            x: Tensor concatenated with the time embedding along dim 1.
            t: Timestep tensor; cast to float and given a new axis at dim 1.
        """
        # Scale the timestep by the global number of steps, then embed it.
        scaled_t = t.float().unsqueeze(1) / timesteps
        features = torch.cat((x, self.time_embed(scaled_t)), dim=1)
        for hidden_layer in (self.fc1, self.fc2):
            features = F.relu(hidden_layer(features))
        return self.fc3(features)

## 損失関数の定義

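ノイズを加えたデータに対するネットワークの出力を、加えたノイズの符号を反転した値 $-\epsilon$ に近づける損失関数を使用します。真のスコアは $-\epsilon / \sqrt{1-\bar{\alpha}_t}$ なので、ここではスケーリング（重み付け）を省略したスコアを学習していることになります。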

In [18]:
def loss_fn(model, x_0, t):
    """Mean-squared error between the model output and the negated noise.

    Args:
        model: Callable invoked as ``model(x_t, t)``.
        x_0: Clean data batch; Gaussian noise of the same shape is drawn.
        t: Timestep indices passed to ``noise_schedule`` and to the model.

    Returns:
        The scalar MSE loss tensor.
    """
    signal_scale, noise_scale = noise_schedule(t)
    eps = torch.randn_like(x_0)
    # Noised data: per-sample scales gain an axis at dim 1 to broadcast over x_0.
    x_t = signal_scale.unsqueeze(1) * x_0 + noise_scale.unsqueeze(1) * eps
    # The true score is proportional to the negative noise (weighting omitted),
    # so the regression target is simply -eps.
    return F.mse_loss(model(x_t, t), -eps)

## 学習ループ

簡単な学習ループの例です。

In [19]:
# Data generation (example: samples from a one-dimensional Gaussian)
def generate_data(n_samples=500):
    """Return an ``(n_samples, 1)`` tensor of standard normal draws scaled by 2 and shifted by 5."""
    standard_normal = torch.randn(n_samples, 1)
    return standard_normal.mul(2).add(5)

# Build the toy training set and a shuffled mini-batch loader over it.
data = generate_data(n_samples=500)
dataset = TensorDataset(data)
dataloader = DataLoader(dataset, shuffle=True, batch_size=64)

# Visualize the data as a normalized 30-bin histogram.
plt.hist(data.numpy(), density=True, bins=30)
plt.ylabel("Density")
plt.xlabel("Value")
plt.title("Data Distribution")
plt.show()

  plt.show()


In [20]:
# Initialize the model and the optimizer.
input_dim, hidden_dim, time_embed_dim = 1, 128, 32
model = ScoreNet(input_dim, hidden_dim, time_embed_dim)
optimizer = Adam(model.parameters(), lr=1e-3)

epochs = 1000
for epoch in range(epochs):
    for batch in dataloader:
        (x_0,) = batch
        # Draw one uniformly random timestep per sample in the mini-batch.
        t = torch.randint(0, timesteps, (x_0.size(0),))
        optimizer.zero_grad()
        loss = loss_fn(model, x_0, t)
        loss.backward()
        optimizer.step()
    # Report only every 100th epoch; the value is the loss of the last mini-batch.
    if (epoch + 1) % 100 != 0:
        continue
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

Epoch [100/1000], Loss: 0.3763
Epoch [200/1000], Loss: 0.2844
Epoch [300/1000], Loss: 0.3956
Epoch [400/1000], Loss: 0.4622
Epoch [500/1000], Loss: 0.3176
Epoch [600/1000], Loss: 0.4808
Epoch [700/1000], Loss: 0.4839
Epoch [800/1000], Loss: 0.5504
Epoch [900/1000], Loss: 0.4184
Epoch [1000/1000], Loss: 0.3083


In [23]:
def _render_histogram_frame(x_t, step):
    """Render a 30-bin histogram of ``x_t`` and return it as an (h, w, 3) uint8 array."""
    fig, ax = plt.subplots()
    ax.hist(x_t.cpu().numpy(), bins=30, alpha=0.7)
    ax.set_title(f"Timestep: {step}")
    fig.canvas.draw()
    # Read the RGBA canvas buffer and drop the alpha channel to get RGB.
    rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    frame = rgba.reshape((h, w, 4))[..., :3]
    # Close this specific figure so figures do not pile up across steps.
    plt.close(fig)
    return frame


@torch.no_grad()
def sample_and_record(model, n_samples, timesteps, alphas, alpha_cumprod, betas, device="cpu"):
    """Run the reverse diffusion process and record histogram frames along the way.

    Starting from standard normal noise of shape ``(n_samples, 1)``, the loop
    walks ``i = timesteps - 1, ..., 0`` and applies the DDPM ancestral update.
    Fresh noise with variance ``betas[i]`` is added at every step except
    ``i == 0``.

    Args:
        model: Callable ``model(x_t, t)`` whose output estimates the *negated*
            noise, i.e. the ``-noise`` target that ``loss_fn`` trains against.
        n_samples: Number of samples to generate.
        timesteps: Number of reverse steps; also sets how many frames are kept.
        alphas: Tensor indexed by step, holding ``1 - beta``.
        alpha_cumprod: Tensor indexed by step, holding the cumulative product
            of ``alphas``.
        betas: Tensor indexed by step, holding the noise variances.
        device: Device the samples and timestep tensors are moved to.

    Returns:
        ``(samples, frames)``: ``samples`` is the final ``x_t`` as a NumPy
        array and ``frames`` is a list of RGB uint8 images, one per saved step.
    """
    x_t = torch.randn(n_samples, 1).to(device)
    intermediate_frames = []
    # Keep ``timesteps // 20`` frames at evenly spaced steps from the last step
    # down to 0. A set gives O(1) membership tests inside the loop.
    num_steps = timesteps // 20
    steps_to_save = set(np.linspace(timesteps - 1, 0, num_steps, dtype=int).tolist())

    for i in reversed(range(timesteps)):
        t = torch.full((n_samples,), i, dtype=torch.long).to(device)
        sqrt_alpha_t = torch.sqrt(alphas[i])
        beta_t = betas[i]
        neg_noise_pred = model(x_t, t)
        # DDPM posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t).
        # The model predicts -eps, so the correction term is ADDED here;
        # subtracting it would push samples away from the data distribution.
        noise_coef = beta_t / torch.sqrt(1 - alpha_cumprod[i])
        x_t = (1 / sqrt_alpha_t) * (x_t + noise_coef * neg_noise_pred)
        if i > 0:
            # No noise is injected on the final step (i == 0).
            noise = torch.randn_like(x_t)
            posterior_variance = beta_t
            x_t = x_t + torch.sqrt(posterior_variance) * noise

        if i in steps_to_save:
            intermediate_frames.append(_render_histogram_frame(x_t, i))

    return x_t.cpu().numpy(), intermediate_frames

# Run sampling: switch the model to eval mode, then generate 500 samples
# while collecting the intermediate histogram frames.
model.eval()
sampled_data, frames = sample_and_record(
    model,
    500,
    timesteps,
    alphas,
    alpha_cumprod,
    betas,
)

In [24]:
# Save the recorded frames as an animated GIF
output_gif_path = "sampling_process_1d.gif"
# NOTE(review): the unit imageio applies to `duration` may depend on the
# installed imageio version/plugin (seconds vs. milliseconds) — confirm.
imageio.mimsave(output_gif_path, frames, duration=0.1) # duration is intended as the time between frames (seconds)
print(f"GIF saved to: {output_gif_path}")

GIF saved to: sampling_process_1d.gif


In [25]:
# Optionally visualize the final samples: overlay a histogram of freshly drawn
# reference data with a histogram of the generated samples.
plt.figure(figsize=(8, 6))
for values, label in (
    (generate_data(1000).numpy(), 'Original Data'),
    (sampled_data, 'Sampled Data'),
):
    plt.hist(values, bins=30, alpha=0.5, label=label)
plt.title('Final Sampled Data')
plt.legend()
plt.show()

  plt.show()
