In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
# VAE model: fully connected encoder/decoder with a diagonal-Gaussian latent space.
class VAE(nn.Module):
    """Variational autoencoder with one hidden layer in both encoder and decoder.

    The defaults reproduce the MNIST setup: flattened 28x28 images (784 inputs),
    a 400-unit hidden layer and a 20-dimensional latent space.

    Args:
        input_dim: number of features per flattened input sample.
        hidden_dim: width of the hidden layer in both encoder and decoder.
        latent_dim: dimensionality of the latent variable z.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        # Kept so forward() can flatten inputs without a hard-coded size.
        self.input_dim = input_dim

        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean of q(z|x)
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance of q(z|x)

        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Map flattened inputs of shape (batch, input_dim) to (mu, logvar)."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * std with eps ~ N(0, I) (reparameterization trick).

        Expressing the sample this way keeps z differentiable w.r.t. mu and
        logvar, so the encoder is trainable by backpropagation while z stays random.
        """
        std = torch.exp(0.5 * logvar)  # exp(log σ² / 2) = σ
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes to per-feature values in (0, 1) via a sigmoid."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample and decode a batch.

        Args:
            x: tensor whose elements form samples of ``input_dim`` features each,
               e.g. (batch, 1, 28, 28) MNIST images for the default configuration.

        Returns:
            Tuple (reconstruction of shape (batch, input_dim), mu, logvar).
        """
        # Flatten each sample; reshape (unlike view) also accepts non-contiguous input.
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [8]:
# Loss: negative ELBO = reconstruction term + KL divergence.
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss (negative ELBO), summed over the batch.

    Args:
        recon_x: decoder output of shape (batch, features) with values in [0, 1].
        x: original inputs, reshaped to match ``recon_x`` (e.g. (batch, 1, 28, 28)
           MNIST images become (batch, 784)). Values must lie in [0, 1] because
           they are used as BCE targets.
        mu: mean of q(z|x).
        logvar: log-variance of q(z|x), same shape as ``mu``.

    Returns:
        Scalar tensor: summed binary cross-entropy reconstruction loss plus the
        analytic KL(N(mu, exp(logvar)) || N(0, I)).
    """
    # Infer the feature size from recon_x instead of hard-coding 784;
    # reshape also tolerates non-contiguous inputs.
    target = x.reshape(-1, recon_x.shape[-1])
    bce = F.binary_cross_entropy(recon_x, target, reduction='sum')
    # Closed-form KL for a diagonal Gaussian: -1/2 * sum(1 + log σ² - μ² - σ²)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld

In [9]:
# Data loading: MNIST training split.
DATA_DIR = './data'
BATCH_SIZE = 128

# Only ToTensor, which scales pixels to [0, 1]. Do NOT add
# Normalize((0.5,), (0.5,)) here: it maps pixels to [-1, 1], and
# loss_function uses them as binary-cross-entropy targets, which must be in [0, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = datasets.MNIST(DATA_DIR, train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [10]:
# Initialize the model and optimizer.
# Seed the global torch RNG first so that weight initialization, DataLoader
# shuffling (with no explicit generator it draws from the global RNG) and the
# VAE's reparameterization noise are reproducible on Restart & Run All.
SEED = 0
LEARNING_RATE = 1e-3
torch.manual_seed(SEED)
model = VAE()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [11]:
# Training: one pass over the data loader per call.
def train(epoch, net=None, opt=None, loader=None, loss_fn=None, log_interval=100):
    """Run one training epoch and return the average per-sample loss.

    Any dependency left as None falls back to the notebook-level object of the
    same role (``model``, ``optimizer``, ``train_loader``, ``loss_function``),
    so ``train(epoch)`` keeps working. Passing them explicitly avoids hidden state.

    Args:
        epoch: epoch number, used only in the progress messages.
        net: module returning (reconstruction, mu, logvar) for a batch.
        opt: optimizer over ``net``'s parameters.
        loader: iterable of (data, label) batches with a ``dataset`` attribute.
        loss_fn: callable (recon, data, mu, logvar) -> scalar loss summed over the batch.
        log_interval: print a progress line every this many batches.

    Returns:
        Summed loss over the epoch divided by the number of samples in the dataset.
    """
    # Resolve defaults at call time (not def time) so re-created globals are picked up.
    net = model if net is None else net
    opt = optimizer if opt is None else opt
    loader = train_loader if loader is None else loader
    loss_fn = loss_function if loss_fn is None else loss_fn

    net.train()
    n_samples = len(loader.dataset)
    n_batches = len(loader)
    total_loss = 0.0
    for batch_idx, (data, _) in enumerate(loader):
        opt.zero_grad()  # clear gradients left over from the previous step
        recon_batch, mu, logvar = net(data)  # reconstruction plus posterior mean / log-variance
        loss = loss_fn(recon_batch, data, mu, logvar)
        loss.backward()
        opt.step()
        batch_loss = loss.item()
        total_loss += batch_loss
        if batch_idx % log_interval == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{n_samples} ({100. * batch_idx / n_batches:.0f}%)]\tLoss: {batch_loss / len(data):.6f}')

    avg_loss = total_loss / n_samples
    print(f'====> Epoch: {epoch} Average loss: {avg_loss:.4f}')
    return avg_loss


In [12]:
# Run training: one call to train() per epoch, epochs numbered from 1.
NUM_EPOCHS = 10
for epoch in range(1, NUM_EPOCHS + 1):
    train(epoch)

====> Epoch: 1 Average loss: 163.3530
====> Epoch: 2 Average loss: 121.2747
====> Epoch: 3 Average loss: 114.3030
====> Epoch: 4 Average loss: 111.4394
====> Epoch: 5 Average loss: 109.7427
====> Epoch: 6 Average loss: 108.5880
====> Epoch: 7 Average loss: 107.7486
====> Epoch: 8 Average loss: 107.0784
====> Epoch: 9 Average loss: 106.5675
====> Epoch: 10 Average loss: 106.1237
