<a href="https://colab.research.google.com/github/repairedserver/Test/blob/master/%EC%98%A4%ED%86%A0%EC%9D%B8%EC%BD%94%EB%8D%94.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torchvision
from torchvision import transforms
import torch.nn.functional as F

import torch.nn as nn
import torch.optim as optim

import numpy as np
import cv2
import matplotlib.pyplot as plt

In [2]:
# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch's RNGs (CPU and CUDA) so weight initialisation and DataLoader
# shuffling repeat across Restart & Run All. Some cuDNN kernels may still be
# nondeterministic on GPU.
SEED = 0
torch.manual_seed(SEED)

In [3]:
# MNIST training split as [0, 1] float tensors, served in shuffled batches of 50.
dataset = torchvision.datasets.MNIST(root='./data/', transform=transforms.ToTensor(), download=True)
trainloader = torch.utils.data.DataLoader(dataset=dataset, shuffle=True, batch_size=50)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [8]:
class Flatten(torch.nn.Module):
  """Collapse every dimension after the batch axis: (N, ...) -> (N, prod(...)).

  Used to turn a 4-D conv feature map into a 2-D matrix for nn.Linear.
  """

  def forward(self, x):
    # x.size(0) is the batch dimension; -1 lets view infer the flattened width.
    return x.view(x.size(0), -1)

class Deflatten(nn.Module):
  """Reshape (N, k*s*s) back into a square feature map (N, k, s, s).

  The side length s is recovered as the integer square root of size(1) // k,
  so the flattened width is assumed to describe k square s x s maps.

  Args:
    k: number of output channels.
  """

  def __init__(self, k):
    super().__init__()
    self.k = k

  def forward(self, x):
    dims = x.size()
    batch = dims[0]
    side = int((dims[1] // self.k) ** 0.5)
    return x.view(batch, self.k, side, side)

class Autoencoder(nn.Module):
  """Convolutional autoencoder for 1x28x28 images (MNIST in this notebook).

  The encoder maps (N, 1, 28, 28) -> (N, latent_dim). The decoder maps that code
  back to (N, 1, 28, 28), with values in (0, 1) from the final Sigmoid.

  Args:
    k: base channel width; the encoder uses k, 2k and 4k channels.
    latent_dim: size of the bottleneck code.
  """

  def __init__(self, k=16, latent_dim=10):
    super(Autoencoder, self).__init__()
    # Spatial sizes for a 28x28 input. Encoder: 28 -> 13 -> 6 -> 4.
    # Decoder mirrors it: 4 -> 6 -> 13 -> 28. The final output_padding=1
    # restores the row/column lost to flooring in the first strided conv.
    flat_size = 4 * k * 4 * 4  # 1024 for the default k=16

    self.encoder = nn.Sequential(
        nn.Conv2d(1, k, 3, stride=2),
        nn.ReLU(),
        nn.Conv2d(k, k * 2, 3, stride=2),
        nn.ReLU(),
        nn.Conv2d(2 * k, 4 * k, 3, stride=1),
        nn.ReLU(),
        Flatten(),
        nn.Linear(flat_size, latent_dim),
        nn.ReLU()
    )

    # Must be stored as self.decoder. forward() calls it, and assigning it to
    # self.encoder would silently overwrite the encoder built above.
    self.decoder = nn.Sequential(
        nn.Linear(latent_dim, flat_size),
        nn.ReLU(),
        Deflatten(4 * k),  # (N, flat_size) -> (N, 4k, 4, 4)
        nn.ConvTranspose2d(4 * k, 2 * k, 3, stride=1),
        nn.ReLU(),
        nn.ConvTranspose2d(2 * k, k, 3, stride=2),
        nn.ReLU(),
        nn.ConvTranspose2d(k, 1, 3, stride=2, output_padding=1),
        nn.Sigmoid()
    )

  def forward(self, x):
    """Encode x to the latent code, then decode it into a reconstruction of x."""
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)

    return decoded

In [9]:
# Build the autoencoder and move its parameters onto the selected device.
model = Autoencoder()
model = model.to(device)

In [14]:
# Visualization helpers
def normalize_output(img):
  """Min-max rescale `img` to the range [0, 1].

  Works on anything that supports .min(), .max() and elementwise arithmetic,
  such as NumPy arrays and torch tensors. A constant input has zero range.
  Instead of dividing by zero and producing NaNs, it maps to all zeros.

  Args:
    img: array-like image data.

  Returns:
    The rescaled data, of the same kind as `img`.
  """
  lo = img.min()
  span = img.max() - lo
  if span == 0:
    # Constant image: every element equals the minimum, so this is all zeros.
    return img - lo
  return (img - lo) / span

def check_plot(columns=10, sample_rows=2):
  """Show one training batch next to the model's reconstructions.

  Takes the first batch from `trainloader` and runs it through `model` with
  gradients disabled. It then draws `sample_rows` pairs of rows: each row of
  original images is directly followed by a row of their reconstructions.

  Args:
    columns: images per row.
    sample_rows: number of original/reconstruction row pairs. The number of
      images shown is capped at the batch size.
  """
  # Evaluate in eval mode, then restore whatever mode the caller was in.
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      inputs = next(iter(trainloader))[0].to(device)
      outputs = model(inputs)
  finally:
    model.train(was_training)

  # (N, 1, H, W) -> (N, H, W): drop the single channel so imshow gets 2-D data.
  input_samples = inputs[:, 0].cpu().numpy()  # original images
  reconstructed_samples = outputs[:, 0].cpu().numpy()  # reconstructed images

  n_rows = 2 * sample_rows
  fig, axes = plt.subplots(n_rows, columns, figsize=(columns, n_rows), squeeze=False)
  for ax in axes.flat:
    ax.axis('off')

  n_shown = min(columns * sample_rows, len(input_samples))
  for idx in range(n_shown):
    row, col = divmod(idx, columns)
    # Inputs (ToTensor) and outputs (Sigmoid) both lie in [0, 1]. A fixed
    # scale keeps the two rows visually comparable.
    axes[2 * row, col].imshow(input_samples[idx], cmap='gray', vmin=0, vmax=1)
    axes[2 * row + 1, col].imshow(reconstructed_samples[idx], cmap='gray', vmin=0, vmax=1)

  fig.suptitle('Originals (odd rows) vs. reconstructions (even rows)')
  plt.show()

In [12]:
# Reconstruction objective: mean squared error averaged over every pixel.
criterion = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(params=model.parameters(), lr=0.0001)

In [None]:
# 51 epochs (indices 0..50), so progress is reported at epochs 0, 10, ..., 50.
NUM_EPOCHS = 51
LOG_EVERY = 10

for epoch in range(NUM_EPOCHS):
  model.train()
  running_loss = 0.0
  for inputs, _ in trainloader:  # labels are unused by an autoencoder
    inputs = inputs.to(device)
    optimizer.zero_grad()
    outputs = model(inputs)

    # MSELoss(prediction, target): the reconstruction is the prediction and
    # the clean input image is the target.
    loss = criterion(outputs, inputs)

    loss.backward()
    optimizer.step()
    running_loss += loss.item()

  # Mean of the per-batch losses over this epoch.
  cost = running_loss / len(trainloader)

  if epoch % LOG_EVERY == 0:
    # The value is an MSE, not a percentage, so there is no '%' suffix.
    print('[%d] loss : %.4f' % (epoch + 1, cost))
    check_plot()