本教程来源 [抛开数学，轻松学懂 VAE](https://zhouyifan.net/2022/12/19/20221016-VAE/)

In [11]:
###下载好了图片后，可以用下面的代码创建Dataloader
import os

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class CelebADataset(Dataset):
    def __init__(self, root, img_shape=(64, 64)) -> None:
        super().__init__()
        self.root = root
        self.img_shape = img_shape
        self.filenames = sorted(os.listdir(root))

    def __len__(self) -> int:
        return len(self.filenames)

    def __getitem__(self, index: int):
        path = os.path.join(self.root, self.filenames[index])
        img = Image.open(path).convert('RGB')
        pipeline = transforms.Compose([
            transforms.CenterCrop(168),
            transforms.Resize(self.img_shape),
            transforms.ToTensor()
        ])
        return pipeline(img)


def get_dataloader(root='data/celebA/img_align_celeba', **kwargs):
    dataset = CelebADataset(root, **kwargs)
    return DataLoader(dataset, batch_size=16, shuffle=True)

In [12]:
if __name__ == '__main__':
    data_path = 'data/celebA/img_align_celeba'
    dataloader = get_dataloader()
    img = next(iter(dataloader))
    print(img.shape)
    # Concat 4x4 images
    N, C, H, W = img.shape
    assert N == 16
    img = torch.permute(img, (1, 0, 2, 3))
    img = torch.reshape(img, (C, 4, 4 * H, W))
    img = torch.permute(img, (0, 2, 1, 3))
    img = torch.reshape(img, (C, 4 * H, 4 * W))
    img = transforms.ToPILImage()(img)
    img.save('work_dirs/tmp.jpg')

torch.Size([16, 3, 64, 64])


[VAE maps inputs into a multivariate normal distribution](https://hackernoon.com/how-to-sample-from-latent-space-with-variational-autoencoder)

In [13]:
import torch
import torch.nn as nn


class VAE(nn.Module):
    '''
    VAE for 64x64 face generation. The hidden dimensions can be tuned.
    '''
    def __init__(self, hiddens=[16, 32, 64, 128, 256], latent_dim=128) -> None:
        super().__init__()

        # encoder
        prev_channels = 3
        modules = []
        img_length = 64
        for cur_channels in hiddens:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(prev_channels,
                              cur_channels,
                              kernel_size=3,
                              stride=2,
                              padding=1), nn.BatchNorm2d(cur_channels),
                    nn.ReLU()))
            prev_channels = cur_channels
            img_length //= 2
        self.encoder = nn.Sequential(*modules)
        self.mean_linear = nn.Linear(prev_channels * img_length * img_length,
                                     latent_dim)
        self.var_linear = nn.Linear(prev_channels * img_length * img_length,
                                    latent_dim)
        self.latent_dim = latent_dim
        # decoder
        modules = []
        self.decoder_projection = nn.Linear(
            latent_dim, prev_channels * img_length * img_length)
        self.decoder_input_chw = (prev_channels, img_length, img_length)
        for i in range(len(hiddens) - 1, 0, -1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hiddens[i],
                                       hiddens[i - 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hiddens[i - 1]), nn.ReLU()))
        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(hiddens[0],
                                   hiddens[0],
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(hiddens[0]), nn.ReLU(),
                nn.Conv2d(hiddens[0], 3, kernel_size=3, stride=1, padding=1),
                nn.ReLU()))
        self.decoder = nn.Sequential(*modules)
    def forward(self, x):
        encoded = self.encoder(x)
        print('encoded.shape',encoded.shape)##torch.Size([16, 256, 2, 2])
        encoded = torch.flatten(encoded, 1)
        mean = self.mean_linear(encoded)
        print('mean.shape',mean.shape)##torch.Size([16, latent_dim = 128])
        logvar = self.var_linear(encoded)
        print('logvar.shape',logvar.shape)##torch.Size([16, latent_dim = 128])
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        z = eps * std + mean
        x = self.decoder_projection(z)
        x = torch.reshape(x, (-1, *self.decoder_input_chw))
        decoded = self.decoder(x)

        return decoded, mean, logvar
    def sample(self, device='cuda'):
        z = torch.randn(1, self.latent_dim).to(device)
        x = self.decoder_projection(z)
        x = torch.reshape(x, (-1, *self.decoder_input_chw))
        decoded = self.decoder(x)
        return decoded

## 来自chatgpt

### KL散度计算公式

假设 \( z \) 是潜在变量，其后验分布 \( q(z|x) \) 为正态分布 \( N(\mu, \sigma^2) \)，其中 \( \mu \) 和 \( \sigma \) 是由神经网络从数据 \( x \) 中学到的。如果先验分布 \( p(z) \) 是标准正态分布 \( N(0,1) \)，那么KL散度的计算公式为：
$$\text{KL}(q(z|x) || p(z)) = -\frac{1}{2} \sum_{i=1}^d \left( 1 + \log(\sigma_i^2) - \mu_i^2 - \sigma_i^2 \right)$$

### 代码解释

```python
kl_loss = torch.mean(
    -0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar), 1), 0)
```

- `mean` 和 `logvar` 分别代表 \( \mu \) 和 \( \log(\sigma^2) \)。
- `torch.sum(..., 1)` 对每个样本的所有维度求和，计算出每个样本的KL散度。
- `torch.mean(..., 0)` 计算所有样本的平均KL散度，这是整个数据批次的平均KL散度。

以上代码段在计算每个维度的KL散度后，对所有维度求和，并最后计算所有样本的平均值，从而得到整个数据批次的平均KL散度。
----------------------------------------
要完整地展示变分自编码器中KL散度公式从积分形式到简化表达式的代数简化过程，我们将详细解析每一步。这涉及到一些复杂的数学操作，包括完成平方和积分。下面是详细的推导步骤：

### 初始积分形式
我们从KL散度的定义开始：
$$
\text{KL}(q(z|x) || p(z)) = \int q(z|x) \log \frac{q(z|x)}{p(z)} dz
$$
代入正态分布的密度函数，\( q(z|x) = N(\mu, \sigma^2) \) 和 \( p(z) = N(0, 1) \)，我们有：
$$
\log \frac{q(z|x)}{p(z)} = \log \left(\frac{1}{\sigma} e^{-\frac{(z-\mu)^2}{2\sigma^2} + \frac{z^2}{2}} \right)
$$
$$
= -\log \sigma - \frac{(z-\mu)^2}{2\sigma^2} + \frac{z^2}{2}
$$

### 展开和简化
现在我们要将这些项进一步展开和简化：
$$
-\frac{(z-\mu)^2}{2\sigma^2} + \frac{z^2}{2} = -\frac{z^2 - 2z\mu + \mu^2}{2\sigma^2} + \frac{z^2}{2}
$$
我们可以重写这个表达式为：
$$
-\frac{z^2}{2\sigma^2} + \frac{z\mu}{\sigma^2} - \frac{\mu^2}{2\sigma^2} + \frac{z^2}{2}
$$
将相同项 \(z^2\) 合并：
$$
\left(\frac{1}{2} - \frac{1}{2\sigma^2}\right) z^2 + \frac{z\mu}{\sigma^2} - \frac{\mu^2}{2\sigma^2}
$$
简化：
$$
\frac{1-\sigma^2}{2\sigma^2} z^2 + \frac{z\mu}{\sigma^2} - \frac{\mu^2}{2\sigma^2}
$$

### 计算积分
现在，将这个结果代入积分中。利用高斯积分的性质，只有 \(z\) 的平方项和常数项会对积分产生非零贡献，而线性项的积分为零（因为高斯函数关于其均值对称）：
$$\text{KL}(q(z|x) || p(z)) = -\log \sigma - \int \left(\frac{1-\sigma^2}{2\sigma^2} z^2 + \frac{\mu^2}{2\sigma^2}\right) q(z|x) dz
$$
利用高斯分布 \(q(z|x)\) 的期望性质，我们有：
$$
\int z^2 q(z|x) dz = \sigma^2 + \mu^2
$$
代入上述公式，我们可以得到：
$$
\text{KL}(q(z|x) || p(z)) = -\log \sigma - \left(\frac{1-\sigma^2}{2\sigma^2} (\sigma^2 + \mu^2) + \frac{\mu^2}{2\sigma^2}\right)
$$
这进一步简化为：
$$
= -\log \sigma - \frac{1-\sigma^2}{2\sigma^2} \sigma^2 - \frac{\mu^2}{2\sigma^2} - \frac{\mu^2}{2\sigma^2}
$$
$$
= -\log \sigma - \frac{1-\sigma^2}{2} - \frac{\mu^2}{\sigma^2}
$$
最后，化简得到：
$$
\text{KL}(q(z|x) || p(z)) = \frac{1}{2} (\mu^2 + \sigma^2 - \log \sigma^2 - 1)
$$


In [14]:
from time import time

import torch
import torch.nn.functional as F
from torchvision.transforms import ToPILImage

# from dldemos.VAE.load_celebA import get_dataloader
# from dldemos.VAE.model import VAE

# Hyperparameters
n_epochs = 1
# n_epochs = 10
kl_weight = 0.00025
lr = 0.005


def loss_fn(y, y_hat, mean, logvar):
    recons_loss = F.mse_loss(y_hat, y)
    kl_loss = torch.mean(
        -0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar), 1), 0)
    loss = recons_loss + kl_loss * kl_weight
    return loss

In [18]:
def train(device, dataloader, model):
    optimizer = torch.optim.Adam(model.parameters(), lr)
    dataset_len = len(dataloader.dataset)

    begin_time = time()
    # train
    for i in range(n_epochs):
        loss_sum = 0
        for x in dataloader:
            # print('x.shape',x.shape)
            x = x.to(device)
            y_hat, mean, logvar = model(x)
            loss = loss_fn(x, y_hat, mean, logvar)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_sum += loss
            break##只训练一次 为了看看输出
        loss_sum /= dataset_len
        training_time = time() - begin_time
        minute = int(training_time // 60)
        second = int(training_time % 60)
        print(f'epoch {i}: loss {loss_sum} {minute}:{second}')
        torch.save(model.state_dict(), 'dldemos/VAE/model.pth')


In [19]:
model = VAE()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataloader = get_dataloader()
model.to(device)
train(device, dataloader, model)
# model.load_state_dict(torch.load('dldemos/VAE/model.pth'))

encoded.shape torch.Size([16, 256, 2, 2])
mean.shape torch.Size([16, 128])
logvar.shape torch.Size([16, 128])
epoch 0: loss 2.5405049655091716e-06 0:0


In [22]:
##定义生成图片的函数
def generate(device, model):
    model.eval()
    output = model.sample(device)
    output = output[0].detach().cpu()
    img = ToPILImage()(output)
    img.save('work_dirs/tmp.jpg')

In [23]:
###生成图片
generate(device, model)