
In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

### What is the difference between `nn` and `nn.functional` in PyTorch?

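1. Similarities
   - `nn.()` and `nn.functional.()` provide the same functionality. For example, `nn.Conv2d()` and `nn.functional.conv2d()` both perform convolution, and `nn.Dropout()` and `nn.functional.dropout()` both perform **dropout**.
   - Their runtime efficiency is also **nearly identical**, because the `forward()` of an `nn` module typically just calls the corresponding `nn.functional` function.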

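2. Differences

   - `nn.functional.()` is a plain function, while `nn.()` is a **class wrapper** around it. Every `nn.()` class inherits from the common ancestor `nn.Module()`. As a result, in addition to the functionality of `nn.functional.()`, it **also has the attributes and methods of `nn.Module()`, such as `train()`, `eval()`, `load_state_dict()`, and `state_dict()`**.
   - They are called differently. `nn.()` is first instantiated with its configuration (e.g. layer sizes) and then called on the input. `nn.functional.()` receives the input data together with `weight`, `bias`, and any other parameters in a single call.
   - Because `nn.()` inherits from `nn.Module()`, **it composes naturally with `nn.Sequential()`**. `nn.functional.()` cannot be placed inside `nn.Sequential()`.
   - `nn.()` defines and manages its own `weight` for you. With `nn.functional.()`, you must define the `weight` yourself and pass it in manually on every call, which **hurts code reuse**.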

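3. Summary: when to use which

   > This depends on the complexity of the problem you are solving and on personal style. When `nn.()` cannot provide what you need, `nn.functional.()` is the better choice. It is more flexible (closer to the low-level operations), so you can build exactly the functionality you want on top of it.

   In short, use `nn.()` whenever it fits, and fall back to `nn.functional.()` only when it does not. This expresses the layer structure of the network more clearly. Because all components then inherit from `nn.Module()`, the code also stays consistent and uniform.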
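The differences listed above can be seen in a small side-by-side comparison. This is a minimal sketch that uses only standard PyTorch APIs (`nn.Linear`, `F.linear`, `nn.Sequential`, `nn.Dropout`, `F.dropout`). The tensor sizes and variable names are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8)

# Module style: nn.Linear creates and registers its own weight and bias
layer = nn.Linear(8, 3)
y_module = layer(x)

# Functional style: the caller must supply weight and bias on every call
y_functional = F.linear(x, layer.weight, layer.bias)
print(torch.allclose(y_module, y_functional))  # True: same computation

# Only modules can be composed with nn.Sequential and show up in state_dict()
net = nn.Sequential(nn.Linear(8, 3), nn.ReLU())
print(list(net.state_dict().keys()))  # ['0.weight', '0.bias']

# nn.Dropout follows train()/eval() automatically,
# whereas F.dropout needs the training flag passed explicitly
drop = nn.Dropout(p=0.5)
drop.eval()
print(torch.equal(drop(x), x))  # True: dropout is a no-op in eval mode
print(torch.equal(F.dropout(x, p=0.5, training=False), x))  # True
```

`nn.ReLU` has no parameters, so only the `Linear` layer's tensors appear in the `state_dict()`.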

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create the output directory if it does not exist
sample_dir = 'samples'
os.makedirs(sample_dir, exist_ok=True)

In [7]:
# Hyper Parameters
image_size = 28 * 28
h_dim = 400
z_dim = 20 # z-space/latent space
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

# MNIST dataset
dataset = torchvision.datasets.MNIST(root = './data',
                                     train = True,
                                     transform = transforms.ToTensor(),
                                     download = True)

# Data Loader
data_loader = torch.utils.data.DataLoader(dataset = dataset,
                                          batch_size = batch_size,
                                          shuffle = True)
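The `469` in the training log below follows from these settings. MNIST's training split has 60,000 images. With `batch_size = 128` and the `DataLoader` default `drop_last=False`, the last partial batch is kept, so the loader yields ⌈60000 / 128⌉ batches per epoch. `len(data_loader)` should report the same number. The check below is a quick arithmetic sketch in plain Python, so no dataset download is needed.

```python
import math

num_train_images = 60000  # size of the MNIST training split
batch_size = 128
image_size = 28 * 28      # each 28x28 image is flattened to a 784-dim vector

# With drop_last=False the final, smaller batch is still yielded
steps_per_epoch = math.ceil(num_train_images / batch_size)
last_batch = num_train_images - (steps_per_epoch - 1) * batch_size

print(image_size)       # 784
print(steps_per_epoch)  # 469
print(last_batch)       # 96
```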

In [4]:
# VAE model
class VAE(nn.Module):
  def __init__(self, image_size = 784, h_dim = 400, z_dim = 20):
    super(VAE, self).__init__()
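    self.fc1 = nn.Linear(image_size, h_dim)  # encoder hidden layer
    self.fc2 = nn.Linear(h_dim, z_dim)       # mean (mu) of q(z|x)
    self.fc3 = nn.Linear(h_dim, z_dim)       # log-variance (log_var) of q(z|x)
    self.fc4 = nn.Linear(z_dim, h_dim)       # decoder hidden layer
    self.fc5 = nn.Linear(h_dim, image_size)  # decoder output (pixel logits before sigmoid)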

  def encode(self, x):
    h = F.relu(self.fc1(x))
    return self.fc2(h), self.fc3(h)

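  def reparameterize(self, mu, log_var):
    # Reparameterization trick: z = mu + eps * std with eps ~ N(0, I),
    # so sampling stays differentiable with respect to mu and log_var
    std = torch.exp(log_var / 2)
    # torch.randn_like(std) returns a tensor with the same size as `std`, filled with
    # samples from a normal distribution with mean 0 and variance 1. It is equivalent to
    # torch.randn(std.size(), dtype=std.dtype, layout=std.layout, device=std.device),
    # so eps is created on the same device as the model.
    eps = torch.randn_like(std)
    return mu + eps * std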

  def decode(self, z):
    h = F.relu(self.fc4(z))
    # torch.sigmoid maps outputs to (0, 1) pixel probabilities (F.sigmoid is deprecated)
    return torch.sigmoid(self.fc5(h))

  def forward(self, x):
    mu, log_var = self.encode(x)
    z = self.reparameterize(mu, log_var)
    x_reconst = self.decode(z)
    return x_reconst, mu, log_var
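The `reparameterize` step is what makes sampling trainable. Instead of drawing `z` directly from N(mu, sigma²), the model draws `eps ~ N(0, 1)` and computes `z = mu + eps * sigma`. The randomness then sits in `eps`, and gradients can flow back to `mu` and `log_var`. The minimal standalone check below uses toy zero-valued encoder outputs that were chosen only for illustration. It confirms that dz/dmu = 1 and dz/dlog_var = 0.5 * eps * exp(log_var / 2), which equals `0.5 * eps` when `log_var = 0`.

```python
import torch

torch.manual_seed(0)
batch, z_dim = 4, 20

# Toy stand-ins for the fc2/fc3 encoder outputs, tracked by autograd
mu = torch.zeros(batch, z_dim, requires_grad=True)
log_var = torch.zeros(batch, z_dim, requires_grad=True)

# Reparameterization: z = mu + eps * sigma, with sigma = exp(log_var / 2)
std = torch.exp(log_var / 2)
eps = torch.randn_like(std)  # the only source of randomness; no gradient needed
z = mu + eps * std

# Backpropagate a scalar function of z to the distribution parameters
z.sum().backward()
print(mu.grad.shape, log_var.grad.shape)                   # torch.Size([4, 20]) torch.Size([4, 20])
print(torch.allclose(mu.grad, torch.ones(batch, z_dim)))  # True: dz/dmu = 1
print(torch.allclose(log_var.grad, 0.5 * eps))            # True: at log_var = 0
```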

In [5]:
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
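The training loop below minimizes the sum of two terms. The first is a summed binary cross-entropy reconstruction loss. The second is the KL divergence between q(z|x) = N(mu, sigma²) and the standard normal prior, which has the closed form `-0.5 * sum(1 + log_var - mu**2 - exp(log_var))`. As a sketch, the code below checks that formula against `torch.distributions.kl_divergence`. It also checks `reduction='sum'` BCE against its elementwise definition. The random tensors are illustrative only and are not outputs of the model.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
batch, z_dim, image_size = 8, 20, 784

# Illustrative encoder outputs
mu = torch.randn(batch, z_dim)
log_var = torch.randn(batch, z_dim)
std = torch.exp(log_var / 2)

# Closed-form KL(N(mu, std^2) || N(0, 1)), summed over batch and latent dims
kl_closed = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
prior = Normal(torch.zeros_like(mu), torch.ones_like(std))
kl_library = kl_divergence(Normal(mu, std), prior).sum()
print(torch.allclose(kl_closed, kl_library, rtol=1e-4))  # True

# Summed BCE over every pixel of every image in the batch
x = torch.rand(batch, image_size)                          # targets in [0, 1], like ToTensor() pixels
x_reconst = torch.sigmoid(torch.randn(batch, image_size))  # decoder-like outputs in (0, 1)
bce_sum = F.binary_cross_entropy(x_reconst, x, reduction='sum')
bce_manual = -(x * torch.log(x_reconst) + (1 - x) * torch.log(1 - x_reconst)).sum()
print(torch.allclose(bce_sum, bce_manual, rtol=1e-4))  # True
```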

In [6]:
# Start training
for epoch in range(num_epochs):
  for i, (x, _) in enumerate(data_loader):
    # Forward pass
    x = x.to(device).view(-1, image_size)
    x_reconst, mu, log_var = model(x)

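    # Compute reconstruction loss and KL divergence
    # reduction='sum' sums over all pixels and images in the batch
    # (it replaces the deprecated size_average=False)
    reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
    # Closed-form KL(q(z|x) || N(0, I)), summed over the batch and latent dimensions
    kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())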

    # Backprop and optimize
    loss = reconst_loss + kl_div
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (i + 1) % 10 == 0:
      print("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
            .format(epoch + 1, num_epochs, i + 1, len(data_loader), 
                    reconst_loss.item(), kl_div.item()))
      
  with torch.no_grad():
    # Save the sampled images
    z = torch.randn(batch_size, z_dim).to(device)
    out = model.decode(z).view(-1, 1, 28, 28)
    save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch + 1)))

    # Save the reconstructed images
    out, _, _ = model(x)
    x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim = 3)
    save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch + 1)))



Epoch[1/15], Step [10/469], Reconst Loss: 37055.2930, KL Div: 3659.2480
Epoch[1/15], Step [20/469], Reconst Loss: 30148.3438, KL Div: 1020.1989
Epoch[1/15], Step [30/469], Reconst Loss: 27276.3477, KL Div: 1281.7863
Epoch[1/15], Step [40/469], Reconst Loss: 26898.0918, KL Div: 715.1939
Epoch[1/15], Step [50/469], Reconst Loss: 27114.1270, KL Div: 691.8860
Epoch[1/15], Step [60/469], Reconst Loss: 26966.9863, KL Div: 771.9081
Epoch[1/15], Step [70/469], Reconst Loss: 24098.0703, KL Div: 1114.8950
Epoch[1/15], Step [80/469], Reconst Loss: 23949.5898, KL Div: 1023.3667
Epoch[1/15], Step [90/469], Reconst Loss: 23285.9434, KL Div: 1291.3933
Epoch[1/15], Step [100/469], Reconst Loss: 21779.4297, KL Div: 1349.2325
Epoch[1/15], Step [110/469], Reconst Loss: 20397.9668, KL Div: 1556.3445
Epoch[1/15], Step [120/469], Reconst Loss: 20621.3711, KL Div: 1605.1548
Epoch[1/15], Step [130/469], Reconst Loss: 19545.6387, KL Div: 1715.9426
Epoch[1/15], Step [140/469], Reconst Loss: 19769.1680, KL Div: 