In [4]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Mar 10 20:48:03 2018

@author: lps
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import transforms
import torchvision.datasets as dst
from torchvision.utils import save_image

In [5]:
EPOCH = 15
BATCH_SIZE = 64
n = 2   # num_workers
LATENT_CODE_NUM = 32   
log_interval = 10


transform=transforms.Compose([transforms.ToTensor()])
data_train = dst.MNIST('MNIST_data/', train=True, transform=transform, download=True)
data_test = dst.MNIST('MNIST_data/', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=data_train, num_workers=n,batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=data_test, num_workers=n,batch_size=BATCH_SIZE, shuffle=True)

In [6]:
class VAE(nn.Module):
      def __init__(self):
            super(VAE, self).__init__()
      
            self.encoder = nn.Sequential(
                  nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
                  nn.BatchNorm2d(64),
                  nn.LeakyReLU(0.2, inplace=True),
                  
                  nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
                  nn.BatchNorm2d(128),
                  nn.LeakyReLU(0.2, inplace=True),
                      
                  nn.Conv2d(128, 128, kernel_size=3 ,stride=1, padding=1),
                  nn.BatchNorm2d(128),
                  nn.LeakyReLU(0.2, inplace=True),                  
                  )
            
            self.fc11 = nn.Linear(128 * 7 * 7, LATENT_CODE_NUM)
            self.fc12 = nn.Linear(128 * 7 * 7, LATENT_CODE_NUM)
            self.fc2 = nn.Linear(LATENT_CODE_NUM, 128 * 7 * 7)
            
# nn.ConvTranspose2d 逆卷积 反卷积
            self.decoder = nn.Sequential(                
                  nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
                  nn.ReLU(inplace=True),
                  
                  nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
                  nn.Sigmoid()
                  )

      def reparameterize(self, mu, logvar):
            eps = Variable(torch.randn(mu.size(0), mu.size(1))).cuda()
            z = mu + eps * torch.exp(logvar/2)            
            
            return z
      
      def forward(self, x):
             out1, out2 = self.encoder(x), self.encoder(x)  # batch_s, 8, 7, 7
             mu = self.fc11(out1.view(out1.size(0),-1))     # batch_s, latent
             logvar = self.fc12(out2.view(out2.size(0),-1)) # batch_s, latent
             z = self.reparameterize(mu, logvar)      # batch_s, latent      
             out3 = self.fc2(z).view(z.size(0), 128, 7, 7)    # batch_s, 8, 7, 7
             
             return self.decoder(out3), mu, logvar


def loss_func(recon_x, x, mu, logvar):
      BCE = F.binary_cross_entropy(recon_x, x,  size_average=False)
      KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
      
      return BCE+KLD


vae = VAE().cuda()
optimizer =  optim.Adam(vae.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

torch中的exp函数

torch中的exp函数是指数函数,用于将输入的每个元素取e的指数幂。 在PyTorch中,可以使用torch.exp()函数来计算输入张量的指数。函数输入可以是张量、变量或标量。返回值是一个新的张量,其中每个元素的值为输入张量中相应元素的指数幂。

VAE的损失函数：由两部分的和组成（bce_loss、kld_loss)。bce_loss即为binary_cross_entropy（二分类交叉熵）损失，即用于衡量原图与生成图片的像素误差。kld_loss即为KL-divergence（KL散度），用来衡量潜在变量的分布和单位高斯分布的差异

enumerate函数在PyTorch中是一个强大而便利的工具,用于在迭代过程中获取元素的索引和值。

In [16]:
def train(EPOCH):
      vae.train()
      total_loss = 0
      for i, (data, _) in enumerate(train_loader, 0):
            data = Variable(data).cuda()
            optimizer.zero_grad()
            recon_x, mu, logvar = vae.forward(data)
            loss = loss_func(recon_x, data, mu, logvar)
            loss.backward()
#             total_loss += loss.data[0].item()
            total_loss += loss.item()
            
            optimizer.step()
            
            if i % log_interval == 0:
                  sample = Variable(torch.randn(64, LATENT_CODE_NUM)).cuda()
                  sample = vae.decoder(vae.fc2(sample).view(64, 128, 7, 7)).cpu()
                  save_image(sample.data.view(64, 1, 28, 28),
               'sample_' + str(epoch) + '.png')
                  print('Train Epoch:{} -- [{}/{} ({:.0f}%)] -- Loss:{:.6f}'.format(
                              epoch, i*len(data), len(train_loader.dataset), 
                              100.*i/len(train_loader), loss.item()/len(data)))
                  
      print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, total_loss / len(train_loader.dataset)))
      

In [17]:
for epoch in range(1, EPOCH):
    train(epoch)



====> Epoch: 1 Average loss: 128.6307
====> Epoch: 2 Average loss: 110.2799
====> Epoch: 3 Average loss: 107.3735
====> Epoch: 4 Average loss: 105.8038
====> Epoch: 5 Average loss: 104.7233
====> Epoch: 6 Average loss: 103.8183
====> Epoch: 7 Average loss: 103.1661
====> Epoch: 8 Average loss: 102.5843
====> Epoch: 9 Average loss: 102.1093
====> Epoch: 10 Average loss: 101.5805
====> Epoch: 11 Average loss: 101.2629
====> Epoch: 12 Average loss: 100.8935
====> Epoch: 13 Average loss: 100.6292
====> Epoch: 14 Average loss: 100.3773
