In [69]:
import os

import torch
from torch.autograd import Variable

In [70]:
def to_var(x):
  """Wrap a tensor in an autograd ``Variable``, using the GPU when one is available.

  Args:
    x: tensor to wrap.

  Returns:
    ``Variable(x.cuda())`` when ``torch.cuda.is_available()`` is true,
    otherwise ``Variable(x)``.
  """
  on_gpu = torch.cuda.is_available()
  return Variable(x.cuda() if on_gpu else x)

In [71]:
import torchvision.datasets as dsets
import torchvision.transforms as transforms

In [72]:
# Preprocessing applied to every MNIST image: convert it to a tensor, then
# normalise with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5.
# Normalize documents `mean` and `std` as per-channel sequences, so they are
# passed as 1-tuples (one entry for the single grey channel) instead of bare
# floats, which some torchvision versions cannot iterate over.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

In [73]:
# MNIST splits stored under ./data, both preprocessed with `transform`.
# The training split is downloaded if missing; the test split is created with
# download=False, so it expects the files to be on disk already.
train_dataset = dsets.MNIST('./data', train=True, transform=transform, download=True)
test_dataset = dsets.MNIST('./data', train=False, transform=transform, download=False)

In [74]:
import torch.utils.data as Data

In [75]:
# Mini-batch iterator over the training split: 100 samples per batch, shuffled.
data_loader = Data.DataLoader(train_dataset, batch_size=100, shuffle=True)

In [76]:
import torch.nn as nn

In [77]:
# Fully connected discriminator: a flattened 28x28 image (784 values) goes in,
# a single sigmoid score between 0 and 1 comes out.
D = nn.Sequential(
    nn.Linear(in_features=28 * 28, out_features=256),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(in_features=256, out_features=256),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(in_features=256, out_features=1),
    nn.Sigmoid(),
)

In [78]:
# Fully connected generator: a 64-dimensional latent vector goes in, a
# flattened 28x28 image (784 values) squashed by Tanh into [-1, 1] comes out.
G = nn.Sequential(
    nn.Linear(in_features=64, out_features=256),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(in_features=256, out_features=256),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(in_features=256, out_features=28 * 28),
    nn.Tanh(),
)

In [79]:
# Move both fully connected networks onto the GPU when CUDA is available.
if torch.cuda.is_available():
  for _net in (D, G):
    _net.cuda()

In [80]:
# Binary cross-entropy loss, plus one Adam optimizer per network; both
# optimizers use a learning rate of 3e-4.
loss_fn = nn.BCELoss()
d_opt = torch.optim.Adam(D.parameters(), lr=3e-4)
g_opt = torch.optim.Adam(G.parameters(), lr=3e-4)

In [81]:
from torchvision.utils import save_image 

In [82]:
def denorm(x):
  """Map values from the [-1, 1] range back to [0, 1].

  Computes ``(x + 1) / 2`` and clamps the result to ``[0, 1]``, so inputs
  outside ``[-1, 1]`` saturate at the nearest bound.

  Args:
    x: tensor to rescale.

  Returns:
    A tensor of the same shape with every value in ``[0, 1]``.
  """
  rescaled = (x + 1) / 2
  return rescaled.clamp(min=0, max=1)

In [83]:
# Train the fully connected GAN (D against G) for 200 epochs. Every mini-batch
# performs one discriminator update followed by one generator update; sample
# images are written to /content/data at the end of each epoch.
os.makedirs("/content/data", exist_ok=True)         # make sure the directory the images are saved into exists
for epoch in range(200):
  for i, (images,_) in enumerate(data_loader):      # enumerate supplies the batch index i
    batch_size = images.size(0)
    images = to_var(images.view(batch_size, -1))    # one flat row per image

    # BCE targets: 1 for real images, 0 for generated ones.
    real_labels = to_var(torch.ones(batch_size,1))
    fake_labels = to_var(torch.zeros(batch_size,1))

    # ---- Discriminator update ----
    outputs = D(images)
    d_loss_real = loss_fn(outputs, real_labels)
    real_score = outputs

    z = to_var(torch.randn(batch_size, 64))       # randn: latent noise tensor with positive and negative values
    # detach() stops d_loss.backward() at the generated images: only D is
    # updated in this step, so gradients for G would be computed and then
    # thrown away by G.zero_grad() below.
    fake_images = G(z).detach()
    outputs = D(fake_images)
    d_loss_fake = loss_fn(outputs, fake_labels)
    fake_score = outputs

    d_loss = d_loss_real+d_loss_fake
    D.zero_grad()
    d_loss.backward()
    d_opt.step()

    # ---- Generator update ----
    # Fresh noise, scored against the "real" labels so that G is rewarded
    # when D classifies its output as real.
    z = to_var(torch.randn(batch_size, 64))
    fake_images = G(z)
    outputs = D(fake_images)

    g_loss = loss_fn(outputs, real_labels)
    D.zero_grad()
    G.zero_grad()
    g_loss.backward()
    g_opt.step()

    if (i+1)%300 == 0:
      # .item() turns the 0-dim tensors into Python floats for formatting.
      print("Epoch %-3d, batch: %-3d, d_loss: %-.4f, g_loss: %-.4f,"
      "D(x): %-.2f, D(G(z)): %-.2f"
      %(epoch,i+1,d_loss.item(),g_loss.item(),real_score.mean().item(),fake_score.mean().item()))
  # Save a reference grid of real images once, after the first epoch.
  if epoch == 0:
    images = images.view(images.size(0), 1, 28, 28)
    save_image(denorm(images),"/content/data/real_images.png")
  # Save the images generated for the last batch of this epoch.
  fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
  save_image(denorm(fake_images),"/content/data/fake_images-%d.png"%(epoch+1))

Epoch 0  , batch: 300, d_loss: 0.2485, g_loss: 4.3422,D(x): 0.98, D(G(z)): 0.16
Epoch 0  , batch: 600, d_loss: 1.0102, g_loss: 1.8690,D(x): 0.80, D(G(z)): 0.45
Epoch 1  , batch: 300, d_loss: 0.8826, g_loss: 4.1570,D(x): 0.77, D(G(z)): 0.26
Epoch 1  , batch: 600, d_loss: 1.6384, g_loss: 2.1110,D(x): 0.67, D(G(z)): 0.55
Epoch 2  , batch: 300, d_loss: 1.3412, g_loss: 1.2795,D(x): 0.64, D(G(z)): 0.52
Epoch 2  , batch: 600, d_loss: 0.4454, g_loss: 1.9359,D(x): 0.85, D(G(z)): 0.23
Epoch 3  , batch: 300, d_loss: 0.5015, g_loss: 2.7302,D(x): 0.89, D(G(z)): 0.26
Epoch 3  , batch: 600, d_loss: 1.7415, g_loss: 2.1919,D(x): 0.62, D(G(z)): 0.41
Epoch 4  , batch: 300, d_loss: 0.8704, g_loss: 1.8869,D(x): 0.75, D(G(z)): 0.27
Epoch 4  , batch: 600, d_loss: 0.9640, g_loss: 3.7414,D(x): 0.67, D(G(z)): 0.24
Epoch 5  , batch: 300, d_loss: 1.3091, g_loss: 1.4398,D(x): 0.71, D(G(z)): 0.41
Epoch 5  , batch: 600, d_loss: 1.0260, g_loss: 1.6658,D(x): 0.67, D(G(z)): 0.36
Epoch 6  , batch: 300, d_loss: 0.3484, g

# DCGAN: convolutional generator and discriminator

In [105]:
class DisCriminator(nn.Module):
  """Convolutional discriminator: two strided 5x5 conv stages and a sigmoid head.

  Author's note (translated): this network is deliberately kept weak.

  The linear head is sized for 128 feature maps of 7x7, which is what the two
  stride-2 convolutions produce from a 1x28x28 input.
  """
  def __init__(self):
    super(DisCriminator, self).__init__()
    self.conv1 = nn.Conv2d(
        in_channels = 1,
        out_channels = 64,
        # Author's note (translated): the kernel covers a wide area, so the
        # discriminator learns poorly, which gives the generator a chance to train.
        kernel_size= 5,
        stride = 2,
        padding = 2,
        bias = True
    )
    # The activation and dropout modules are shared by both conv stages.
    self.leaky_relu = nn.LeakyReLU()
    self.dropout_2d = nn.Dropout2d(0.3)
    self.conv2 = nn.Conv2d(
        in_channels = 64,
        out_channels = 128,
        kernel_size= 5,
        stride = 2,
        padding = 2,
        bias = True
    )
    self.linear1 = nn.Linear(128*7*7, 1, bias=True)
    self.sigmoid = nn.Sigmoid()
  def forward(self, x):
    """Score a batch of images.

    Args:
      x: image batch of shape (N, 1, 28, 28).

    Returns:
      Tensor of shape (N, 1) holding one sigmoid score per image.

    Raises:
      RuntimeError: if the convolutional features of ``x`` do not amount to
        128*7*7 values per sample (e.g. the spatial size is not 28x28).
    """
    out = self.dropout_2d(self.leaky_relu(self.conv1(x)))
    out = self.dropout_2d(self.leaky_relu(self.conv2(out)))
    # Pin the batch dimension while flattening. With view(-1, 128*7*7) an
    # input of the wrong spatial size could be silently re-chunked into extra
    # batch rows; here it raises instead.
    out = out.view(out.size(0), 128*7*7)
    out = self.linear1(out)
    return self.sigmoid(out)

In [106]:
class Generator(nn.Module):
  """Convolutional generator: latent vector -> single-channel 28x28 image in [-1, 1].

  A bias-free linear layer expands the latent vector to 256 maps of 7x7. A 5x5
  convolution reduces them to 128 maps, and two stride-2 transposed
  convolutions upsample 7x7 -> 14x14 -> 28x28 before the final Tanh.

  Args:
    latent_dim: size of the input noise vector.
    batchnorm: when true, batch normalisation follows the linear layer and
      the first two convolutional layers; when false the three norm
      attributes are ``None``.
  """
  def __init__(self, latent_dim = 100, batchnorm= True):
    super(Generator, self).__init__()
    self.latent_dim = latent_dim
    self.batchnorm = batchnorm

    # Stage 1: latent vector -> 256 x 7 x 7 values.
    self.linear1 = nn.Linear(latent_dim, 7*7*256, bias=False)
    self.bn1d1 = nn.BatchNorm1d(256*7*7) if batchnorm else None
    self.leaky_relu = nn.LeakyReLU()   # shared by every hidden stage

    # Stage 2: 256 -> 128 channels, spatial size unchanged.
    self.conv1 = nn.Conv2d(256, 128, kernel_size=5, stride=1, padding=2, bias=False)
    self.bn2d1 = nn.BatchNorm2d(128) if batchnorm else None

    # Stage 3: 128 -> 64 channels, spatial size doubled.
    self.conv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False)
    self.bn2d2 = nn.BatchNorm2d(64) if batchnorm else None

    # Stage 4: 64 -> 1 channel, spatial size doubled again.
    self.conv3 = nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False)
    self.tanh = nn.Tanh()

  def _norm(self, layer, out):
    """Apply ``layer`` to ``out`` only when batch normalisation is enabled."""
    return layer(out) if self.batchnorm else out

  def forward(self, x):
    """Generate images from latent vectors.

    Args:
      x: latent batch of shape (N, latent_dim).

    Returns:
      Tensor of shape (N, 1, 28, 28) with values in [-1, 1].
    """
    hidden = self.leaky_relu(self._norm(self.bn1d1, self.linear1(x)))
    hidden = hidden.view((-1, 256, 7, 7))
    hidden = self.leaky_relu(self._norm(self.bn2d1, self.conv1(hidden)))
    hidden = self.leaky_relu(self._norm(self.bn2d2, self.conv2(hidden)))
    return self.tanh(self.conv3(hidden))

In [107]:
# Build the convolutional generator, then the discriminator, with default arguments.
DCG, DCD = Generator(), DisCriminator()

In [108]:
# Move both convolutional networks onto the GPU when CUDA is available.
if torch.cuda.is_available():
  for _net in (DCG, DCD):
    _net.cuda()

In [109]:
# One Adam optimizer per convolutional network, both with betas=(0.5, 0.999).
# The discriminator's learning rate (1e-3) is ten times the generator's (1e-4).
dcd_opt = torch.optim.Adam(DCD.parameters(), lr=1e-3, betas=(0.5, 0.999))
dcg_opt = torch.optim.Adam(DCG.parameters(), lr=1e-4, betas=(0.5, 0.999))

In [110]:
# Train the convolutional GAN (DCD against DCG) for 50 epochs. Every mini-batch
# performs one discriminator update followed by one generator update; sample
# images are written to /content/data1 at the end of each epoch.
os.makedirs("/content/data1", exist_ok=True)     # make sure the directory the images are saved into exists
for epoch in range(50):
  for i, (images,_) in enumerate(data_loader):
    batch_size = images.size(0)
    images = to_var(images)                      # images keep their (N, 1, 28, 28) layout for the conv net

    # BCE targets: 1 for real images, 0 for generated ones.
    real_labels = to_var(torch.ones(batch_size,1))
    fake_labels = to_var(torch.zeros(batch_size,1))

    # ---- Discriminator update ----
    outputs = DCD(images)
    d_loss_real = loss_fn(outputs, real_labels)
    real_score = outputs

    z = to_var(torch.randn(batch_size, 100))
    # detach() stops d_loss.backward() at the generated images: only DCD is
    # updated in this step, so gradients for DCG would be computed and then
    # thrown away by DCG.zero_grad() below.
    fake_images = DCG(z).detach()
    outputs = DCD(fake_images)
    d_loss_fake = loss_fn(outputs, fake_labels)
    fake_score = outputs

    d_loss = d_loss_real+d_loss_fake
    DCD.zero_grad()
    d_loss.backward()
    dcd_opt.step()

    # ---- Generator update ----
    # Fresh noise, scored against the "real" labels so that DCG is rewarded
    # when DCD classifies its output as real.
    z = to_var(torch.randn(batch_size, 100))
    fake_images = DCG(z)
    outputs = DCD(fake_images)

    g_loss = loss_fn(outputs, real_labels)
    DCD.zero_grad()
    DCG.zero_grad()
    g_loss.backward()
    dcg_opt.step()

    if (i+1)%300 == 0:
      # .item() turns the 0-dim tensors into Python floats for formatting.
      print("Epoch %-3d, batch: %-3d, d_loss: %-.4f, g_loss: %-.4f,"
      "D(x): %-.2f, D(G(z)): %-.2f"
      %(epoch,i+1,d_loss.item(),g_loss.item(),real_score.mean().item(),fake_score.mean().item()))
  # Save a reference grid of real images once, after the first epoch.
  if epoch == 0:
    images = images.view(images.size(0), 1, 28, 28)
    save_image(denorm(images),"/content/data1/real_images.png")
  # Save the images generated for the last batch of this epoch.
  fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
  save_image(denorm(fake_images),"/content/data1/fake_images-%d.png"%(epoch+1))

Epoch 0  , batch: 300, d_loss: 0.8629, g_loss: 1.7478,D(x): 0.56, D(G(z)): 0.10
Epoch 0  , batch: 600, d_loss: 0.7261, g_loss: 1.9313,D(x): 0.74, D(G(z)): 0.25
Epoch 1  , batch: 300, d_loss: 0.8500, g_loss: 1.2757,D(x): 0.66, D(G(z)): 0.27
Epoch 1  , batch: 600, d_loss: 0.8623, g_loss: 1.4074,D(x): 0.68, D(G(z)): 0.29
Epoch 2  , batch: 300, d_loss: 0.7339, g_loss: 1.7774,D(x): 0.71, D(G(z)): 0.25
Epoch 2  , batch: 600, d_loss: 0.9170, g_loss: 1.8160,D(x): 0.74, D(G(z)): 0.35
Epoch 3  , batch: 300, d_loss: 0.8777, g_loss: 1.5642,D(x): 0.71, D(G(z)): 0.29
Epoch 3  , batch: 600, d_loss: 0.8016, g_loss: 1.6359,D(x): 0.75, D(G(z)): 0.29
Epoch 4  , batch: 300, d_loss: 0.9661, g_loss: 1.6579,D(x): 0.71, D(G(z)): 0.32
Epoch 4  , batch: 600, d_loss: 0.8975, g_loss: 1.3144,D(x): 0.63, D(G(z)): 0.24
Epoch 5  , batch: 300, d_loss: 0.8548, g_loss: 1.3861,D(x): 0.63, D(G(z)): 0.21
Epoch 5  , batch: 600, d_loss: 0.9315, g_loss: 1.2002,D(x): 0.62, D(G(z)): 0.21
Epoch 6  , batch: 300, d_loss: 0.7656, g