<a href="https://colab.research.google.com/github/taekyungss/computer_vision_planting_grass/blob/main/VAE_MLP%20%ED%8C%8C%EC%9D%B4%ED%86%A0%EC%B9%98%20%EA%B5%AC%ED%98%84.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import numpy as np
import random
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets,transforms
from torchvision.utils import save_image


In [None]:
# --- Output location and image geometry ---
image_path = './images'   # directory that receives the saved image grids
channels = 1              # MNIST has only one channel
img_size = 28             # images are flattened to img_size**2 before encoding

# --- Optimisation settings ---
n_epochs = 30
batch_size = 128
lr = 1e-3
b1, b2 = 0.5, 0.999       # passed to Adam as betas=(b1, b2)

# --- Network widths ---
hidden_dim = 400          # default width of the hidden layers
latent_dim = 10           # default size of the latent vector z

In [None]:
# Create the output directory up front; exist_ok avoids an error on re-runs.
os.makedirs(name=image_path, exist_ok=True)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Preprocessing applied to every sample. Further steps (e.g. resizing or
# augmentation) could be appended to this Compose list.
transform = transforms.Compose([transforms.ToTensor()])

# Build the MNIST train and test splits through torchvision's `datasets`
# module; both share the same root directory and transform.
train, test = (
    datasets.MNIST(root='./data/', train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

# The loaders batch each split. `shuffle` decides whether samples are visited
# in index order (test) or in a shuffled order (train).
train_dataloader = DataLoader(train, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test, batch_size=batch_size, shuffle=False)

In [None]:
def reparameterization(mu, logvar):
    """Draw z ~ N(mu, sigma^2) with the reparameterization trick.

    The sample is written as ``mu + eps * std`` with ``eps ~ N(0, I)`` so that
    gradients can flow through ``mu`` and ``logvar``.

    Args:
        mu: tensor holding the mean of the distribution.
        logvar: tensor holding log(sigma^2); must broadcast with ``mu``.

    Returns:
        A tensor sampled from N(mu, exp(logvar)).
    """
    std = torch.exp(logvar / 2)  # sigma = exp(0.5 * log(sigma^2))
    eps = torch.randn_like(std)
    # The noise must be *scaled* by std; adding std would only shift the mean.
    return mu + eps * std

In [None]:
class Encoder(nn.Module):
    """MLP encoder that maps a flattened image to a sampled latent vector.

    Two ``Linear -> ReLU -> Dropout(0.2)`` hidden layers feed two separate
    linear heads that produce the mean and the log-variance of the latent
    Gaussian.

    Args:
        x_dim: size of the flattened input (defaults to ``img_size**2``).
        h_dim: width of both hidden layers.
        z_dim: size of the latent vector.
    """

    def __init__(self, x_dim=img_size**2, h_dim=hidden_dim, z_dim=latent_dim):
        super().__init__()

        # 1st hidden layer
        self.fc1 = nn.Sequential(
            nn.Linear(x_dim, h_dim),
            nn.ReLU(),
            nn.Dropout(p=0.2),
        )

        # 2nd hidden layer
        self.fc2 = nn.Sequential(
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Dropout(p=0.2),
        )

        # output layer: one linear head each for the mean and the log-variance
        self.mu = nn.Linear(h_dim, z_dim)
        self.logvar = nn.Linear(h_dim, z_dim)

    def forward(self, x):
        """Encode ``x`` and sample a latent vector.

        Args:
            x: input tensor whose last dimension is ``x_dim``.

        Returns:
            Tuple ``(z, mu, logvar)`` where ``z`` is sampled via
            ``reparameterization(mu, logvar)``.
        """
        h = self.fc2(self.fc1(x))

        # Both heads stay purely linear. A ReLU here would force mu >= 0 and
        # logvar >= 0 (variance >= 1), so the posterior could not match the
        # N(0, I) prior that the KL term and the sampling step rely on.
        mu = self.mu(h)
        logvar = self.logvar(h)
        z = reparameterization(mu, logvar)
        return z, mu, logvar

In [None]:
class Decoder(nn.Module):
    """MLP decoder that maps a latent vector back to a flattened image.

    Two ``Linear -> ReLU -> Dropout(0.2)`` hidden layers are followed by a
    linear output layer and a sigmoid, so every output value lies in (0, 1).

    Args:
        x_dim: size of the flattened output (defaults to ``img_size**2``).
        h_dim: width of both hidden layers.
        z_dim: size of the latent vector.
    """

    def __init__(self, x_dim=img_size**2, h_dim=hidden_dim, z_dim=latent_dim):
        super().__init__()

        # 1st hidden layer
        self.fc1 = nn.Sequential(
            nn.Linear(z_dim, h_dim),
            nn.ReLU(),
            nn.Dropout(p=0.2),
        )

        # 2nd hidden layer
        self.fc2 = nn.Sequential(
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Dropout(p=0.2),
        )

        # output layer
        self.fc3 = nn.Linear(h_dim, x_dim)

    def forward(self, z):
        """Decode a latent vector.

        Args:
            z: latent tensor whose last dimension is ``z_dim``.

        Returns:
            Reconstruction tensor with last dimension ``x_dim`` and values in
            (0, 1).
        """
        h = self.fc2(self.fc1(z))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        x_reconst = torch.sigmoid(self.fc3(h))
        return x_reconst


In [None]:
# Move both networks to the selected device, then set up a single Adam
# optimizer that updates the encoder and decoder parameters together.
encoder = Encoder().to(device)
decoder = Decoder().to(device)
optimizer = torch.optim.Adam(
    itertools.chain.from_iterable(
        net.parameters() for net in (encoder, decoder)
    ),
    lr=lr,
    betas=(b1, b2),
)

In [None]:
# Show the encoder's module structure.
print(encoder)

In [None]:
# Show the decoder's module structure.
print(decoder)

In [None]:
def _vae_loss(x, x_reconst, mu, logvar):
    """Return (reconstruction loss, KL divergence), each summed over the batch."""
    reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
    # Closed-form KL( N(mu, exp(logvar)) || N(0, I) ).
    kl_div = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1)
    return reconst_loss, kl_div


# One pass per epoch: train on the training set, then evaluate on the test
# set and save a reconstruction grid and a grid of freshly sampled images.
for epoch in range(n_epochs):
    # Training mode: Dropout layers are active.
    encoder.train()
    decoder.train()

    train_loss = 0
    for i, (x, _) in enumerate(train_dataloader):
        # forward: flatten each image to a vector before encoding
        x = x.view(-1, img_size**2)
        x = x.to(device)
        z, mu, logvar = encoder(x)
        x_reconst = decoder(z)

        # compute reconstruction loss and KL divergence
        reconst_loss, kl_div = _vae_loss(x, x_reconst, mu, logvar)

        # backprop and optimize
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        if (i+1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{n_epochs}], Step [{i+1}/{len(train_dataloader)}], Reconst Loss : {reconst_loss.item():.4f}, KL Div: {kl_div.item():.4f}')

    print(f'===> Epoch: {epoch+1} Average Train Loss: {train_loss/len(train_dataloader.dataset):.4f} ')

    # Evaluation mode: disable Dropout so the test loss, reconstructions and
    # samples are not corrupted by randomly zeroed activations.
    encoder.eval()
    decoder.eval()

    test_loss = 0
    with torch.no_grad():
        for i, (x, _) in enumerate(test_dataloader):
            # forward
            x = x.view(-1, img_size**2)
            x = x.to(device)
            z, mu, logvar = encoder(x)
            x_reconst = decoder(z)

            # compute reconstruction loss and KL divergence
            reconst_loss, kl_div = _vae_loss(x, x_reconst, mu, logvar)

            loss = reconst_loss + kl_div
            test_loss += loss.item()

            # save reconstruction images (first test batch only)
            if i == 0:
                # Put each input next to its reconstruction along the width
                # axis; one (input x, reconstructed x) pair per batch element.
                x_concat = torch.cat(
                    [
                        x.view(-1, channels, img_size, img_size),
                        x_reconst.view(-1, channels, img_size, img_size),
                    ],
                    dim=3,
                )
                save_image(x_concat, os.path.join(image_path,f'reconst-epoch{epoch+1}.png'))

        print(f'===> Epoch: {epoch+1} Average Test Loss: {test_loss/len(test_dataloader.dataset):.4f} ')

        # save sampled images: draw z from N(0, 1) and decode it
        z = torch.randn(batch_size, latent_dim).to(device)
        sampled_images = decoder(z)
        save_image(
            sampled_images.view(-1, channels, img_size, img_size),
            os.path.join(image_path,f'sampled-epoch{epoch+1}.png'),
        )