# 1. Standard Model in PyTorch
本节将以VAE为例，介绍一个标准的PyTorch网络模型的构建步骤。

## 1.1 Define Parameters
定义网络模型参数。
通过命令行输入方式自定义的基础参数一般包含：

| 参数 | 含义 |
| --- | ---|
|batch size|训练时输入的batch大小，即1个batch包含多少个样本，与显存大小有关|
|epochs|训练轮次数量，一般设置为100|
|device|训练模型时使用的设备（CPU 或 GPU），通常使用 GPU 进行运算（Apple Silicon / MPS："mps"；Nvidia："cuda"）|
|log interval|训练时，多少个batch输出一次log信息，通常包含epoch索引、batch索引、loss信息等|
|learning rate|模型学习率，通常<=1e-2，可设置为1e-3|

In [39]:
# Parameters (User Defined)
import argparse

parser = argparse.ArgumentParser(description='Standard Model Definition in PyTorch')

# Help strings use argparse's %(default)s placeholder, so the text shown by
# --help always matches the actual default value.
parser.add_argument('-b', '--batch-size', type=int, default=128,
                    help='input batch size for training (default: %(default)s)')
parser.add_argument('-e', '--epochs', type=int, default=3,
                    help='number of epochs to train (default: %(default)s)')
parser.add_argument('-d', '--device', type=str, default="cpu",
                    help='device to train model (default: %(default)s)')
parser.add_argument('-log', '--log-interval', type=int, default=10,
                    help='how many batches to wait before logging training status (default: %(default)s)')
parser.add_argument('-lr', '--learning-rate', type=float, default=1e-4,
                    help='learning rate (default: %(default)s)')

# args=[] makes argparse ignore sys.argv, which inside a notebook kernel does
# not contain this script's options; this lets the cell run in .ipynb.
args = parser.parse_args(args=[])
# print(args)

In [40]:
# Parameters (Coder Defined)
seed = 1  # seed for torch's RNG, applied before the model is built

# Extra DataLoader options, enabled only when training on CUDA:
# two worker processes and pinned host memory. Any other device gets none.
if args.device == 'cuda':
    kwargs = {"num_workers": 2, "pin_memory": True}
else:
    kwargs = {}

## 1.2 Model Development
定义模型本身与Loss函数计算方式。

其中，模型本身需要继承nn.Module父类，并至少需要重写以下函数：

|函数名|功能|
|---|---|
|\_\_init\_\_|定义模型需要的网络结构块、实例需要的变量以及可训练参数（可训练参数需通过 nn 层或 nn.Parameter 定义，才会被注册到模型中并计算梯度）|
|forward|前向传播函数|

In [41]:
import torch
from torch import nn, optim
from torch.nn import functional as F

# Model Definition
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: input_dim -> hidden_dim (ReLU) -> two heads giving mu and
    log-variance, each of size latent_dim.
    Decoder: latent_dim -> hidden_dim (ReLU) -> input_dim, squashed by a
    sigmoid so outputs lie in [0, 1].

    Args:
        input_dim: number of features per flattened sample (784 = 28*28 by default).
        hidden_dim: width of the single hidden layer in encoder and decoder.
        latent_dim: size of the latent code z.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Layers are created in the same order as before, so a fixed
        # torch.manual_seed still yields identical initial weights.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # Mu
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log(Var)
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encoder(self, x):
        """Map a flattened batch (batch, input_dim) to (mu, logvar) of q(z|x)."""
        h1 = F.relu(self.fc1(x))
        mu = self.fc21(h1)
        logvar = self.fc22(h1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * std with eps ~ N(0, I) and std = exp(logvar / 2).

        Sampling happens in both train and eval mode; the randomness lives in
        eps, so gradients still flow through mu and logvar.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes (batch, latent_dim) to per-feature probabilities in [0, 1]."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample and decode x.

        x is flattened to (-1, input_dim), so e.g. (batch, 1, 28, 28) images
        are accepted with the default input_dim.

        Returns:
            (x_hat, mu, logvar), where x_hat has shape (batch, input_dim).
        """
        mu, logvar = self.encoder(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Seed torch's global RNG first so the weight initialisation below is reproducible.
torch.manual_seed(seed)
model = VAE()
model = model.to(args.device)  # move parameters to the requested device
optimizer = optim.Adam(params=model.parameters(), lr=args.learning_rate)

In [42]:
# Loss Function Definition
def loss_function(x_hat, x, mu, logvar):
    """Return the negative ELBO, summed over the batch.

    Args:
        x_hat: decoder output, values in [0, 1], e.g. (batch, 784) as returned
            by VAE.forward.
        x: target batch with the same number of elements as x_hat; it is
            reshaped to x_hat's shape, so (batch, 1, 28, 28) images pair with
            (batch, 784) reconstructions. Values must lie in [0, 1] for BCE.
        mu: mean of q(z|x).
        logvar: log-variance of q(z|x), same shape as mu.

    Returns:
        Scalar tensor: summed binary cross-entropy reconstruction term plus
        KL(q(z|x) || N(0, I)).
    """
    # Reconstruction term: summed BCE between reconstruction and input.
    # reshape (instead of a hard-coded view(-1, 784)) works for any feature
    # size and for non-contiguous inputs.
    recon = F.binary_cross_entropy(x_hat, x.reshape(x_hat.shape), reduction='sum')

    # Regularization term: closed-form KL between the diagonal Gaussian
    # N(mu, exp(logvar)) and the standard normal prior.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon + kl

## 1.3 Define Train and Test Processes
（可选）这里可以定义训练阶段和测试阶段的过程，使 main 函数中的实现更为简洁。

In [43]:
import torch.utils.data
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Dataloaders for training and testing. ToTensor() turns each image into a
# CHW float tensor. Neither split downloads (download is False), so the MNIST
# files must already be present under ./data.
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST('./data', train=True, download=False,
                           transform=transforms.ToTensor()),
    batch_size=args.batch_size,
    shuffle=True,
    **kwargs
)

# Evaluation keeps a fixed order (shuffle=False), so the first test batch is
# always the same set of images.
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST('./data', train=False,
                           transform=transforms.ToTensor()),
    batch_size=args.batch_size,
    shuffle=False,
    **kwargs
)

def train(epoch):
    """Run one training epoch over train_loader and print progress.

    Every args.log_interval batches, prints the current batch's loss divided by
    its size. At the end, prints the summed epoch loss divided by the number of
    training samples.
    """
    model.train()  # put layers such as BatchNorm/Dropout (if any) into training mode
    epoch_loss_sum = 0  # loss_function sums over samples, so this is a grand total
    n_train = len(train_loader.dataset)
    n_batches = len(train_loader)

    for step, (images, _labels) in enumerate(train_loader):
        images = images.to(args.device)

        # Forward pass, backprop, then parameter update.
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(images)
        batch_loss = loss_function(reconstruction, images, mu, logvar)
        batch_loss.backward()
        epoch_loss_sum += batch_loss.item()
        optimizer.step()

        if step % args.log_interval != 0:
            continue
        n_in_batch = len(images)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss:{:.6f}'.format(
            epoch, step * n_in_batch, n_train,
            100. * step / n_batches,
            batch_loss.item() / n_in_batch))

    print("[Epoch]:{}, Average Loss:{:.4f}".format(
        epoch, epoch_loss_sum / n_train))

def test(epoch):
    """Evaluate the model on test_loader and print the mean per-sample loss.

    For the first test batch, up to 8 inputs and their reconstructions are
    saved as a grid to ./data/results/rec_<epoch>.png: inputs in the first
    row, reconstructions in the second.
    """
    import os

    model.eval()  # put layers such as BatchNorm/Dropout (if any) into eval mode
    test_loss = 0  # summed loss over every test sample
    # The output folder must exist before save_image writes into it.
    os.makedirs('./data/results', exist_ok=True)

    # Gradients are not needed during evaluation.
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(args.device)
            x_hat, mu, logvar = model(data)
            test_loss += loss_function(x_hat, data, mu, logvar).item()

            # Only the first batch is visualised.
            if i == 0:
                n = min(data.size(0), 8)  # number of samples shown
                # Reshape with the batch's real size: the first batch can be
                # smaller than args.batch_size.
                comparison = torch.cat([data[:n], x_hat.view_as(data)[:n]])
                # make_grid's images-per-row argument is `nrow` (not `nrows`).
                save_image(comparison.cpu(), './data/results/rec_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print("Test Loss:{:.6f}".format(test_loss))

## 1.4 Main()
循环执行 args.epochs 轮 train() + test()，并在每轮结束后从标准正态先验中采样 64 个隐变量，用 decoder 生成图像。

In [44]:
%%time
if __name__ == "__main__":
    import os

    # save_image needs an existing output folder.
    os.makedirs('./data/results', exist_ok=True)
    # Decoder input size (20 for the default VAE); avoids hard-coding it here.
    latent_dim = model.fc3.in_features

    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test(epoch)
        # Test the decoder: decode 64 draws from the N(0, I) prior into images.
        with torch.no_grad():
            sample = torch.randn(64, latent_dim).to(args.device)
            sample = model.decode(sample).cpu()
            save_image(sample.view(64, 1, 28, 28),
                       './data/results/sample_' + str(epoch) + '.png')

[Epoch]:1, Average Loss:260.0498
Test Loss:194.492690
[Epoch]:2, Average Loss:176.7785
Test Loss:163.309465
[Epoch]:3, Average Loss:156.8008
Test Loss:149.016278
CPU times: user 1min 18s, sys: 1min 8s, total: 2min 26s
Wall time: 37.8 s
