# VAE (Variational AutoEncoder)
- MNIST 숫자 이미지를 이용한 VAE 구현 및 학습

In [1]:
import torch
import torchvision.datasets as datasets
from tqdm import tqdm
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import torch.nn.functional as F

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

In [3]:
# Load the MNIST training split (downloaded into dataset/ when missing),
# converting each image with ToTensor.
dataset = datasets.MNIST(
  root='dataset/',
  train=True,
  download=True,
  transform=transforms.ToTensor(),
)
# Shuffled mini-batches of 32 samples.
train_loader = DataLoader(dataset, shuffle=True, batch_size=32)

100%|██████████| 9.91M/9.91M [00:00<00:00, 16.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 481kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.44MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.72MB/s]


In [4]:
# VAE model definition
class VAE(nn.Module):
  """Variational autoencoder with one hidden layer on each side.

  The encoder maps a flattened input to the mean and scale of a diagonal
  Gaussian over the latent code; the decoder maps a latent code back to
  input space and squashes it with a sigmoid.

  Args:
    input_dim: Length of each flattened input vector.
    hidden_dim: Width of the hidden layer in both encoder and decoder.
    z_dim: Dimensionality of the latent code.
  """

  def __init__(self, input_dim, hidden_dim=200, z_dim=20):
    super().__init__()

    # Encoder: input -> hidden -> (mu, sigma)
    self.img2hid = nn.Linear(input_dim, hidden_dim)
    self.hid2mu = nn.Linear(hidden_dim, z_dim)
    self.hid2sigma = nn.Linear(hidden_dim, z_dim)

    # Decoder: latent -> hidden -> reconstruction
    self.z2hid = nn.Linear(z_dim, hidden_dim)
    self.hid2img = nn.Linear(hidden_dim, input_dim)

    self.relu = nn.ReLU()

  def encoder(self, x):
    """Return ``(mu, sigma)`` of the latent Gaussian for inputs ``x``.

    Both are raw linear outputs of shape ``(..., z_dim)``; ``sigma`` is
    not constrained to be positive.
    """
    hidden = self.relu(self.img2hid(x))
    mu = self.hid2mu(hidden)
    sigma = self.hid2sigma(hidden)
    return mu, sigma

  def decoder(self, z):
    """Map latent codes ``z`` to sigmoid-activated reconstructions."""
    hidden = self.relu(self.z2hid(z))
    return torch.sigmoid(self.hid2img(hidden))

  def forward(self, x):
    """Encode ``x``, sample a latent code, and decode it.

    Returns:
      A tuple ``(x_reconst, mu, sigma)``.
    """
    mu, sigma = self.encoder(x)
    epsilon = torch.randn_like(sigma)
    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).
    # The noise must be *scaled* by sigma (not added to it); otherwise the
    # sample always has unit variance and sigma merely shifts the mean.
    z_reparam = mu + sigma * epsilon
    x_reconst = self.decoder(z_reparam)
    return x_reconst, mu, sigma

In [5]:
# Create the model and set up training.
# 784-dimensional input, 200 hidden units, 20 latent dimensions.
model = VAE(input_dim=784, hidden_dim=200, z_dim=20)
model = model.to(device)
# Binary cross-entropy summed (not averaged) over every element.
criterion = nn.BCELoss(reduction='sum')
optimizer = optim.Adam(model.parameters(), lr=3e-4)

In [6]:
# Train for 10 epochs, minimising reconstruction error plus the KL term.
for epoch in range(10):
  # Wrap the loader itself rather than enumerate(...): tqdm can then read
  # len(train_loader) and draw a real progress bar. The batch index was unused.
  for x, _ in tqdm(train_loader):
    # Flatten each image to a 784-vector to match the model's input layer.
    x = x.to(device).view(x.shape[0], 784)

    x_reconst, mu, sigma = model(x)

    reconst_loss = criterion(x_reconst, x)

    # Closed-form KL(N(mu, sigma^2) || N(0, I))
    #   = -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2).
    # The small constant keeps log() finite if a sigma output is exactly 0.
    sigma_sq = sigma.pow(2)
    kl_div = -0.5 * torch.sum(1 + torch.log(sigma_sq + 1e-8) - mu.pow(2) - sigma_sq)

    loss = reconst_loss + kl_div
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


1875it [00:12, 156.22it/s]
1875it [00:11, 166.19it/s]
1875it [00:11, 168.99it/s]
1875it [00:11, 166.09it/s]
1875it [00:11, 168.83it/s]
1875it [00:11, 168.50it/s]
1875it [00:11, 169.07it/s]
1875it [00:11, 167.75it/s]
1875it [00:11, 165.24it/s]
1875it [00:10, 171.82it/s]


In [10]:
# Move the model to the CPU (presumably so inference can feed it dataset
# tensors directly without a device transfer).
model = model.cpu()

# Inference function
def inference(digit, num_samples=5):
  """Generate and save images sampled around one example of ``digit``.

  Encodes the first image in the module-level ``dataset`` labelled
  ``digit`` with the module-level ``model``, draws ``num_samples`` latent
  codes from the resulting Gaussian, decodes each, and writes it to
  ``digit{digit}_sample_{i}.png``.

  Args:
    digit: Label to look up in ``dataset``.
    num_samples: Number of images to generate.

  Raises:
    ValueError: If no sample in ``dataset`` carries the label ``digit``.
  """
  # Only one reference image is ever used, so stop at the first match
  # instead of collecting and encoding several.
  reference = next((img for img, label in dataset if label == digit), None)
  if reference is None:
    raise ValueError(f'no sample with label {digit!r} found in dataset')

  # No gradients are needed for either encoding or decoding here.
  with torch.no_grad():
    mu, sigma = model.encoder(reference.view(1, 784))

    for example in range(num_samples):
      epsilon = torch.randn_like(sigma)
      # Reparameterized sample: noise is scaled by sigma, not added to it.
      z = mu + sigma * epsilon
      out = model.decoder(z).view(-1, 1, 28, 28)
      save_image(out, f'digit{digit}_sample_{example}.png')

In [11]:
# Generate the default number of samples for the digit 7.
inference(digit=7)