<a href="https://colab.research.google.com/github/youngmin5068/GAN_research/blob/main/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as utils
from torch.utils.data import Dataset,DataLoader,random_split
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image
import numpy as np
import matplotlib.pyplot as plt
import os
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Directory where real/fake sample grids are written during training.
sample_dir = '/content/drive/MyDrive/samples'
# exist_ok=True creates the directory when missing without a separate
# (race-prone) existence check.
os.makedirs(sample_dir, exist_ok=True)

# Hyperparameters
latent_size = 64      # length of the generator's noise vector z
hidden_size = 256     # width of the hidden layers in G and D
image_size = 28 * 28  # flattened 28x28 image (784 values)
num_epochs = 300
batch_size = 100
from operator import truediv

# Preprocessing pipeline: PIL image -> tensor, then per-pixel (x - 0.5) / 0.5,
# so inputs share the [-1, 1] range of the generator's Tanh output.
_preprocess_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = transforms.Compose(_preprocess_steps)

class ganDataset(Dataset):
  """Image-only view of one MNIST split, for unsupervised GAN training.

  torchvision's MNIST yields ``(image, label)`` pairs; ``__getitem__`` here
  returns only the image, so a DataLoader over this dataset yields plain
  image batches.

  Args:
    transform: Callable that torchvision applies to each image (e.g. the
      module-level ``transform`` pipeline), or ``None`` for no transform.
      The previous default ``True`` is not callable, so applying it to an
      image failed.
    train: Truthy selects the MNIST training split, falsy the test split.
    root: Directory the MNIST files are read from / downloaded into.
  """

  def __init__(self, transform=None, train=True, root="/content/drive/MyDrive/data"):
    self.train = train
    self.transform = transform
    # Build (and download if missing) only the requested split; the original
    # always built both splits even though each instance used just one.
    self.mnist = datasets.MNIST(root=root,
                                train=bool(train),
                                transform=self.transform,
                                download=True)

  def __len__(self):
    return len(self.mnist)

  def __getitem__(self, idx):
    # Drop the digit label; the GAN only needs the image.
    image, _ = self.mnist[idx]
    return image



# Build the train split, then the test split, and one shuffled loader per split.
train_dataset, test_dataset = (
    ganDataset(transform=transform, train=is_train_split)
    for is_train_split in (True, False)
)
train_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=True)
    for split in (train_dataset, test_dataset)
)


# Generator: noise vector of length latent_size -> flattened image of length
# image_size; Tanh keeps outputs in [-1, 1], the range produced by the
# Normalize(mean=0.5, std=0.5) preprocessing.
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, image_size), nn.Tanh(),
)

# Discriminator: flattened image -> single score in (0, 1) via Sigmoid,
# trained with BCE to be high for real and low for generated images.
D = nn.Sequential(
    nn.Linear(image_size, hidden_size), nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size), nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1), nn.Sigmoid(),
)

# Module.to() moves parameters in place and returns the module itself.
G, D = G.to(device), D.to(device)

def imshow(img):
    """Display one image tensor, normalized to [-1, 1], as a grayscale plot.

    Args:
        img: Image tensor whose size-1 dimensions are squeezed away before
            plotting (so e.g. (1, 28, 28) or (1, 1, 28, 28) both work). It may
            live on the GPU or require grad; it is detached and copied to the
            CPU first, as ``imshow_grid`` already does.
    """
    # Without detach/cpu, .numpy() raises on CUDA tensors or tensors that
    # require grad (e.g. raw generator output).
    img = img.detach().cpu()
    img = (img + 1) / 2   # [-1, 1] -> [0, 1]
    img = img.squeeze()   # drop size-1 dimensions
    np_img = img.numpy()
    # Fix the colour scale to [0, 1]; otherwise matplotlib autoscales scalar
    # data to its own min/max and the de-normalization above has no effect.
    plt.imshow(np_img, cmap='gray', vmin=0, vmax=1)
    plt.show()

def imshow_grid(img):
    """Tile a batch of image tensors into one grid and display it.

    Assumes pixel values are normalized to [-1, 1]. The batch is copied to
    the CPU and detached from autograd, since it is only used for display.
    """
    grid = utils.make_grid(img.cpu().detach())
    grid = (grid + 1) / 2  # [-1, 1] -> [0, 1]
    # make_grid yields channel-first (C, H, W); matplotlib wants (H, W, C).
    plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
    plt.show()

def denorm(x):
    """Map a tensor from the [-1, 1] training range back to [0, 1].

    Values that fall outside the range are clamped, so the result is always
    a valid image intensity for ``save_image``.
    """
    shifted = x.add(1).div(2)
    return shifted.clamp(min=0, max=1)

# Both players are trained with binary cross-entropy on D's sigmoid output,
# each with its own Adam optimizer (learning rate 2e-4).
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr=2e-4)
g_optimizer = optim.Adam(G.parameters(), lr=2e-4)

# Mean D(x) on real and D(G(z)) on fake images, appended every training step.
dx_epoch, dgx_epoch = [], []
total_step = len(train_loader)

for epoch in range(num_epochs):
    for i, images in enumerate(train_loader):
        # Use the batch's real size so a short final batch cannot break the
        # reshape or mismatch the label tensors.
        cur_batch = images.size(0)
        images = images.reshape(cur_batch, -1).to(device)

        real_labels = torch.ones(cur_batch, 1, device=device)
        fake_labels = torch.zeros(cur_batch, 1, device=device)

        # ---- Discriminator step: push D(x) toward 1 and D(G(z)) toward 0 ----
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        z = torch.randn(cur_batch, latent_size, device=device)
        fake_images = G(z)
        # detach(): the D update must not backprop into G; skipping that
        # graph avoids a wasted backward pass through the generator.
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        # Clears D grads left over from the previous generator step.
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator step: make D classify fresh fakes as real ----
        z = torch.randn(cur_batch, latent_size, device=device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)

        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, d_loss.item(), g_loss.item(),
                          real_score.mean().item(), fake_score.mean().item()))

        # Scores are recorded once per step (not per epoch).
        dx_epoch.append(real_score.mean().item())
        dgx_epoch.append(fake_score.mean().item())

    # Save the last real batch of the first epoch once, for comparison.
    if epoch == 0:
        real_grid = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(real_grid), os.path.join(sample_dir, 'real_images.png'))

    # Save the last generated batch of every epoch (files numbered from 1).
    fake_grid = fake_images.detach().reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_grid), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch + 1)))



print("images의 shape : ", images.shape)

# D(x) / D(G(z)) are appended once per training *step* by the training loop,
# so the x-axis is the step index. Its length comes from the recorded data
# instead of a hard-coded 180000, which only matched one specific
# num_epochs / batch_size combination.
steps = np.arange(len(dx_epoch))
plt.figure(figsize=(12, 8))
plt.xlabel('step')
plt.ylabel('score')
plt.plot(steps, dx_epoch, 'g', label='D(x)')
plt.plot(steps, dgx_epoch, 'b', label='D(G(z))')
plt.legend()
plt.show()

# Show the saved real batch next to the samples from the final epoch.
from PIL import Image
real_path = os.path.join(sample_dir, 'real_images.png')
# Sample files are numbered with epoch + 1, so the last one is
# fake_images-{num_epochs}.png (the old hard-coded "-299" was second-to-last).
fake_path = os.path.join(sample_dir, 'fake_images-{}.png'.format(num_epochs))
# np.array() loads the pixels before the `with` block closes the file.
with Image.open(real_path) as im:
    real_img = np.array(im)
with Image.open(fake_path) as im:
    fake_img = np.array(im)

plt.figure(figsize=(12, 8))
plt.subplot(1, 2, 1)
plt.imshow(real_img)
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(fake_img)
plt.axis('off')
plt.show()

In [None]:
# After training, `images` holds the last batch seen by the training loop.
print("images의 shape : ", images.shape, sep=" ")

images의 shape :  torch.Size([100, 784])
