In [None]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from google.colab import drive
drive.mount('/content/drive')

class CelebADataset(Dataset):
    """Unlabelled image dataset backed by a flat directory of image files.

    Args:
        image_dir: Directory whose entries are the image files.
        transform: Optional callable applied to each RGB PIL image.
        max_images: Upper bound on the number of files kept (the full
            archive is large, so only a prefix is used). Pass ``None`` to
            keep every file. Defaults to 400.
    """

    def __init__(self, image_dir, transform=None, max_images=400):
        self.image_dir = image_dir
        self.transform = transform
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # selected prefix identical on every run and every machine.
        names = sorted(os.listdir(image_dir))
        self.file_list = names if max_images is None else names[:max_images]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return the (optionally transformed) RGB image at position ``idx``."""
        image_path = os.path.join(self.image_dir, self.file_list[idx])
        # convert() yields an independent in-memory image, so the underlying
        # file can be closed as soon as the conversion is done.
        with Image.open(image_path) as handle:
            image = handle.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image

# Location of the aligned CelebA images on the mounted Google Drive.
image_dir = '/content/drive/MyDrive/VAE/pytorch-mnist-VAE-master/img_align_celeba/img_align_celeba'

# Preprocessing: resize every image to 28x28 and convert it to a tensor.
# Normalisation is intentionally left disabled: the models reconstruct their
# own input, so the original mean/variance of the pixels is kept.
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


# Build the CelebA dataset object.
celeba_dataset = CelebADataset(image_dir, transform)

# Batch size and whether the training data is shuffled.
batch_size = 64
shuffle = True

# Loader over the complete (unsplit) dataset.
celeba_loader = torch.utils.data.DataLoader(celeba_dataset, batch_size=batch_size, shuffle=shuffle)

# Train/test split.
train_ratio = 0.8  # fraction of samples used for training
dataset_size = len(celeba_dataset)
train_size = int(train_ratio * dataset_size)
test_size = dataset_size - train_size
print(train_size,test_size)

# Seed the split so that the same images land in the test set on every run.
# Without this, reloading saved weights in a fresh kernel would evaluate on a
# different split that may contain images the model was trained on.
split_seed = 0
split_generator = torch.Generator().manual_seed(split_seed)
train_dataset, test_dataset = random_split(
    celeba_dataset, [train_size, test_size], generator=split_generator)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
320 80


In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
device

device(type='cuda')

In [None]:
class VAE(nn.Module):
    """Fully-connected variational autoencoder.

    Args:
        x_dim: Number of features of one flattened input sample.
        h_dim1: Width of the first encoder layer / last hidden decoder layer.
        h_dim2: Width of the second encoder layer / first decoder layer.
        z_dim: Dimensionality of the latent variable.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(VAE, self).__init__()
        # Remembered so forward() can flatten inputs of any configured size
        # instead of assuming 3*28*28 features.
        self.x_dim = x_dim

        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)  # latent mean head
        self.fc32 = nn.Linear(h_dim2, z_dim)  # latent log-variance head
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """Map a flattened batch ``x`` to ``(mu, log_var)`` of the latent variable.

        The input goes through two fully-connected layers with ReLU, then
        through two separate linear heads.
        """
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)  # mu, log_var

    def sampling(self, mu, log_var):
        """Draw ``z = mu + exp(0.5 * log_var) * eps`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)  # return z sample

    def decoder(self, z):
        """Decode latent ``z`` into a reconstruction with values in (0, 1).

        Two fully-connected layers with ReLU are followed by a sigmoid output.
        """
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.fc6(h))

    def forward(self, x):
        """Encode, sample and decode a batch.

        Args:
            x: Input batch; it is flattened to ``(-1, x_dim)`` before encoding.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encoder(x.reshape(-1, self.x_dim))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

# Instantiate the VAE for flattened 3x28x28 images with a 2-D latent space,
# then move it to the selected device (the bare last expression echoes the layers).
vae = VAE(x_dim=3 * 28 * 28, h_dim1=512, h_dim2=256, z_dim=2)
vae.to(device)

VAE(
  (fc1): Linear(in_features=2352, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc31): Linear(in_features=256, out_features=2, bias=True)
  (fc32): Linear(in_features=256, out_features=2, bias=True)
  (fc4): Linear(in_features=2, out_features=256, bias=True)
  (fc5): Linear(in_features=256, out_features=512, bias=True)
  (fc6): Linear(in_features=512, out_features=2352, bias=True)
)

In [None]:
# Sanity check: print the tensor shape of every training batch.
for batch in train_loader:
    print(batch.shape)

torch.Size([64, 3, 28, 28])
torch.Size([64, 3, 28, 28])
torch.Size([64, 3, 28, 28])
torch.Size([64, 3, 28, 28])
torch.Size([64, 3, 28, 28])


In [None]:
# Adam (library-default hyper-parameters) over every parameter of the VAE.
optimizer = optim.Adam(params=vae.parameters())
# Reconstruction error + KL divergence:
# binary cross-entropy (BCE) measures the reconstruction error, and the KL term
# measures how far the latent distribution is from a standard normal.
def loss_function(recon_x, x, mu, log_var):
    """Summed VAE loss for one batch.

    Args:
        recon_x: Flattened reconstructions with values in [0, 1].
        x: Original batch; reshaped to match ``recon_x`` (any image size).
        mu: Latent means.
        log_var: Latent log-variances.

    Returns:
        Scalar tensor ``BCE + KLD``, summed (not averaged) over the batch.
    """
    # Match the target to the reconstruction instead of assuming 3*784 features.
    target = x.reshape(recon_x.shape)
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

In [None]:
def train(epoch):
    """Run one training epoch of the VAE and print progress.

    Relies on the module-level ``vae``, ``optimizer``, ``loss_function``
    (the four-argument VAE version), ``train_loader`` and ``device``.

    Args:
        epoch: Epoch number, used only in the printed messages.
    """
    vae.train()
    running_loss = 0
    sample_count = len(train_loader.dataset)
    batch_count = len(train_loader)
    for step, batch in enumerate(train_loader):
        batch = batch.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass: reconstruction plus latent mean and log-variance.
        reconstruction, mu, log_var = vae(batch)
        batch_loss = loss_function(reconstruction, batch, mu, log_var)

        # Backward pass and parameter update.
        batch_loss.backward()
        optimizer.step()

        batch_loss_value = batch_loss.item()
        running_loss += batch_loss_value

        # Progress line every 100 batches (per-sample loss of this batch).
        if step % 100 != 0:
            continue
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(batch), sample_count,
            100. * step / batch_count, batch_loss_value / len(batch)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, running_loss / sample_count))

In [None]:
def test():
    """Evaluate the VAE on ``test_loader`` and print the per-sample loss.

    Relies on the module-level ``vae``, ``loss_function`` (the four-argument
    VAE version), ``test_loader`` and ``device``.
    """
    vae.eval()
    accumulated = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            reconstruction, mu, log_var = vae(batch)
            # Batch losses are sums, so they can simply be added up.
            batch_loss = loss_function(reconstruction, batch, mu, log_var)
            accumulated += batch_loss.item()

    # Average over the number of test samples.
    accumulated /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(accumulated))

In [None]:
# Train the VAE and evaluate after every epoch; the range covers epochs 1..399.
first_epoch, end_epoch = 1, 400
for epoch in range(first_epoch, end_epoch):
    train(epoch)
    test()

====> Epoch: 1 Average loss: 1615.5004
====> Test set loss: 1590.9292
====> Epoch: 2 Average loss: 1576.3232
====> Test set loss: 1567.9051
====> Epoch: 3 Average loss: 1555.3573
====> Test set loss: 1543.0051
====> Epoch: 4 Average loss: 1523.8096
====> Test set loss: 1512.3978
====> Epoch: 5 Average loss: 1493.5364
====> Test set loss: 1481.6777
====> Epoch: 6 Average loss: 1460.4941
====> Test set loss: 1466.5116
====> Epoch: 7 Average loss: 1448.4058
====> Test set loss: 1456.5253
====> Epoch: 8 Average loss: 1439.0099
====> Test set loss: 1470.5283
====> Epoch: 9 Average loss: 1431.3832
====> Test set loss: 1454.2132
====> Epoch: 10 Average loss: 1419.3416
====> Test set loss: 1437.3203
====> Epoch: 11 Average loss: 1414.0177
====> Test set loss: 1441.9769
====> Epoch: 12 Average loss: 1412.2003
====> Test set loss: 1433.1761
====> Epoch: 13 Average loss: 1405.7690
====> Test set loss: 1424.7568
====> Epoch: 14 Average loss: 1400.1804
====> Test set loss: 1424.7262
====> Epoch: 15

KeyboardInterrupt: ignored

In [None]:
# Generate new images: decode latent vectors drawn from a standard normal
# and save them together in one image file.
num_samples, latent_dim = 64, 2
with torch.no_grad():
    # Random latent codes, one row per generated image.
    z = torch.randn(num_samples, latent_dim).to(device)
    # Decode; the output has the same number of rows as the input.
    sample = vae.decoder(z).to(device)
    # Restore the (channels, height, width) layout before saving.
    sample_grid = sample.view(num_samples, 3, 28, 28)
    save_image(sample_grid, './VAE_' + '.png')

In [None]:
# Persist the trained VAE weights; change the destination path as needed.
torch.save(obj=vae.state_dict(), f='/content/drive/MyDrive/VAE/pytorch-mnist-VAE-master/VAE_params.pkl')

**AE**


In [None]:
class AE(nn.Module):
    """Fully-connected (deterministic) autoencoder.

    Same layer sizes as the VAE in this notebook, but with a single linear
    head that produces the latent code directly.

    Args:
        x_dim: Number of features of one flattened input sample.
        h_dim1: Width of the first encoder layer / last hidden decoder layer.
        h_dim2: Width of the second encoder layer / first decoder layer.
        z_dim: Dimensionality of the latent code.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(AE, self).__init__()
        # Remembered so forward() can flatten inputs of any configured size
        # instead of assuming 3*28*28 features.
        self.x_dim = x_dim

        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc3 = nn.Linear(h_dim2, z_dim)  # only one layer for hidden variable
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """Map a flattened batch ``x`` to its latent code."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc3(h)  # return hidden variable

    def decoder(self, z):
        """Decode latent ``z`` into a reconstruction with values in (0, 1)."""
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.fc6(h))

    def forward(self, x):
        """Encode and decode a batch.

        Args:
            x: Input batch; it is flattened to ``(-1, x_dim)`` before encoding.

        Returns:
            Tuple ``(reconstruction, z)``.
        """
        z = self.encoder(x.reshape(-1, self.x_dim))
        return self.decoder(z), z  # return reconstruction and hidden variable

# Instantiate the plain autoencoder with the same layer sizes as the VAE,
# then move it to the selected device (the bare last expression echoes the layers).
ae = AE(x_dim=3 * 28 * 28, h_dim1=512, h_dim2=256, z_dim=2)
ae.to(device)

AE(
  (fc1): Linear(in_features=2352, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=2, bias=True)
  (fc4): Linear(in_features=2, out_features=256, bias=True)
  (fc5): Linear(in_features=256, out_features=512, bias=True)
  (fc6): Linear(in_features=512, out_features=2352, bias=True)
)

In [None]:
# Adam over the autoencoder's parameters; this rebinds the name `optimizer`
# that the VAE cells used earlier.
optimizer = optim.Adam(params=ae.parameters())
# Reconstruction error only (this rebinds the VAE's `loss_function`).
def loss_function(recon_x, x):
    """Summed binary cross-entropy between reconstruction and input.

    Args:
        recon_x: Flattened reconstructions with values in [0, 1].
        x: Original batch; reshaped to match ``recon_x`` (any image size).

    Returns:
        Scalar tensor, summed (not averaged) over the batch.
    """
    # Match the target to the reconstruction instead of assuming 3*784 features.
    target = x.reshape(recon_x.shape)
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')
    return BCE  # no KL divergence

In [None]:
def train(epoch):
    """Run one training epoch of the autoencoder and print progress.

    Relies on the module-level ``ae``, ``optimizer``, ``loss_function``
    (the two-argument AE version), ``train_loader`` and ``device``.

    Args:
        epoch: Epoch number, used only in the printed messages.
    """
    ae.train()
    running_loss = 0
    sample_count = len(train_loader.dataset)
    batch_count = len(train_loader)
    for step, batch in enumerate(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()

        # The AE returns the reconstruction and the latent code (no mu / log_var).
        reconstruction, _latent = ae(batch)
        batch_loss = loss_function(reconstruction, batch)

        batch_loss.backward()
        optimizer.step()

        batch_loss_value = batch_loss.item()
        running_loss += batch_loss_value

        # Progress line every 100 batches (per-sample loss of this batch).
        if step % 100 != 0:
            continue
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(batch), sample_count,
            100. * step / batch_count, batch_loss_value / len(batch)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, running_loss / sample_count))

def test():
    """Evaluate the autoencoder on ``test_loader`` and print the per-sample loss.

    Relies on the module-level ``ae``, ``loss_function`` (the two-argument
    AE version), ``test_loader`` and ``device``.
    """
    ae.eval()
    test_loss = 0
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            recon_batch, z = ae(data)  # no mu and log_var

            # sum up batch loss
            test_loss += loss_function(recon_batch, data).item()

    # Normalise exactly once, after all batches have been accumulated.
    # Dividing inside the loop shrank the running total on every batch and
    # under-reported the loss whenever the test set spans several batches.
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
# Train the autoencoder and evaluate after every epoch; the range covers epochs 1..2999.
first_epoch, end_epoch = 1, 3000
for epoch in range(first_epoch, end_epoch):
    train(epoch)
    test()

====> Epoch: 1 Average loss: 1617.5077
====> Test set loss: 332.2645
====> Epoch: 2 Average loss: 1581.2021
====> Test set loss: 328.1815
====> Epoch: 3 Average loss: 1552.6415
====> Test set loss: 320.7577
====> Epoch: 4 Average loss: 1492.7379
====> Test set loss: 308.6075
====> Epoch: 5 Average loss: 1459.2379
====> Test set loss: 304.8855
====> Epoch: 6 Average loss: 1443.2845
====> Test set loss: 303.6268
====> Epoch: 7 Average loss: 1434.5585
====> Test set loss: 301.5566
====> Epoch: 8 Average loss: 1426.1610
====> Test set loss: 301.1121
====> Epoch: 9 Average loss: 1422.1063
====> Test set loss: 300.7966
====> Epoch: 10 Average loss: 1418.1871
====> Test set loss: 300.5694
====> Epoch: 11 Average loss: 1413.4543
====> Test set loss: 299.5793
====> Epoch: 12 Average loss: 1406.6066
====> Test set loss: 298.5708
====> Epoch: 13 Average loss: 1396.6222
====> Test set loss: 297.2416
====> Epoch: 14 Average loss: 1390.5991
====> Test set loss: 296.6579
====> Epoch: 15 Average loss:

KeyboardInterrupt: ignored

## 保存重构图片

AE 和 VAE 都训练大约 3K 次后再进行比对。

In [None]:
# Save a side-by-side comparison: up to 64 random test images, their AE
# reconstructions and their VAE reconstructions, each as one image file.
# Plain Python ints are used as indices so the dataset is never indexed with tensors.
subset_indices = torch.randperm(len(test_dataset))[:64].tolist()
subset = torch.utils.data.Subset(test_dataset, indices=subset_indices)
loader = torch.utils.data.DataLoader(subset, batch_size=64, shuffle=False)
with torch.no_grad():
    for data in loader:
        data = data.to(device)
        # Use the real batch size: the test split may hold fewer than 64
        # images, in which case a hard-coded view(64, ...) would fail.
        num_images = data.size(0)

        recon_batch, z = ae(data)
        save_image(data.view(num_images, 3, 28, 28), './orig_' + '.png')
        save_image(recon_batch.view(num_images, 3, 28, 28), './AE_recon_' + '.png')

        recon_batch, _, _ = vae(data)

        save_image(recon_batch.view(num_images, 3, 28, 28), './VAE_recon_' + '.png')

In [None]:
# Persist the trained autoencoder weights.
torch.save(obj=ae.state_dict(), f='/content/drive/MyDrive/VAE/pytorch-mnist-VAE-master/AE_params.pkl')

只用MSE评估

In [None]:
# Quantitative evaluation
# MSE
def evaluate_reconstruction(model, loader):
    """Average per-sample reconstruction MSE of ``model`` over ``loader``.

    Args:
        model: Module whose forward pass returns a tuple with the flattened
            reconstruction as its first element. It is switched to eval mode.
        loader: Iterable of image batches that exposes ``loader.dataset``
            (its length is used as the number of samples).

    Returns:
        Python float: batch-size-weighted mean of the per-batch MSE.
    """
    model.eval()
    mse_loss = nn.MSELoss(reduction='mean')
    # Run on the device the model lives on rather than on a notebook global;
    # a model without parameters leaves each batch on its current device.
    first_param = next(model.parameters(), None)
    total_loss = 0.0
    with torch.no_grad():
        for images in loader:
            # Flatten each sample, whatever the image size is.
            images = images.reshape(images.size(0), -1)
            if first_param is not None:
                images = images.to(first_param.device)
            reconstructions = model(images)[0]
            loss = mse_loss(reconstructions, images)
            # Weight by batch size so a short final batch is not over-counted.
            total_loss += loss.item() * images.size(0)

    # Average reconstruction loss per sample.
    avg_loss = total_loss / len(loader.dataset)
    return avg_loss

# Score both trained models on the held-out split (AE first, then VAE)
# and report their average reconstruction MSE.
ae_reconstruction_loss, vae_reconstruction_loss = (
    evaluate_reconstruction(net, test_loader) for net in (ae, vae)
)

print(f'AE Reconstruction Loss: {ae_reconstruction_loss:.4f}')
print(f'VAE Reconstruction Loss: {vae_reconstruction_loss:.4f}')

AE Reconstruction Loss: 0.0639
VAE Reconstruction Loss: 0.0664


使用MSE和峰值信噪比作为评估指标

In [None]:
import torch
import torch.nn.functional as F

def calculate_psnr(reconstructions, originals, max_pixel_value=1.0):
    """Peak signal-to-noise ratio (dB) between two tensors.

    Computed as ``20*log10(max_pixel_value) - 10*log10(MSE)``, where the MSE
    is the mean over all elements.

    Args:
        reconstructions: Reconstructed values.
        originals: Reference values, same shape as ``reconstructions``.
        max_pixel_value: Largest possible pixel value. Defaults to 1.0.

    Returns:
        Scalar tensor holding the PSNR.
    """
    peak = torch.tensor(max_pixel_value)
    mean_squared_error = F.mse_loss(reconstructions, originals)
    signal_term = 20 * torch.log10(peak)
    noise_term = 10 * torch.log10(mean_squared_error)
    return signal_term - noise_term

def evaluate_reconstruction(model, loader):
    """Average per-sample reconstruction MSE and PSNR of ``model`` over ``loader``.

    Args:
        model: Module whose forward pass returns a tuple with the flattened
            reconstruction as its first element. It is switched to eval mode.
        loader: Iterable of image batches that exposes ``loader.dataset``
            (its length is used as the number of samples).

    Returns:
        Tuple ``(avg_loss, avg_psnr)`` of Python floats. Both are
        batch-size-weighted means of the per-batch values; the PSNR is
        computed per batch by ``calculate_psnr``.
    """
    model.eval()
    mse_loss = nn.MSELoss(reduction='mean')
    # Run on the device the model lives on rather than on a notebook global;
    # a model without parameters leaves each batch on its current device.
    first_param = next(model.parameters(), None)
    total_loss = 0.0
    total_psnr = 0.0
    with torch.no_grad():
        for images in loader:
            # Flatten each sample, whatever the image size is.
            images = images.reshape(images.size(0), -1)
            if first_param is not None:
                images = images.to(first_param.device)
            reconstructions = model(images)[0]
            loss = mse_loss(reconstructions, images)
            # Weight by batch size so a short final batch is not over-counted.
            total_loss += loss.item() * images.size(0)
            psnr = calculate_psnr(reconstructions, images)
            total_psnr += psnr.item() * images.size(0)

    avg_loss = total_loss / len(loader.dataset)
    avg_psnr = total_psnr / len(loader.dataset)
    return avg_loss, avg_psnr
# Load the AE parameters.
# NOTE(review): both checkpoints are read from a `pytorch_celeba_VAE/` sub-folder,
# while the save cells earlier in this notebook write to its parent folder.
# Confirm the files were moved there, or align the paths.
ae = AE(x_dim=3*784, h_dim1= 512, h_dim2=256, z_dim=2)
ae.load_state_dict(torch.load('/content/drive/MyDrive/VAE/pytorch-mnist-VAE-master/pytorch_celeba_VAE/AE_params.pkl', map_location=torch.device('cpu')))
ae.to(device)

# Load the VAE parameters.
vae = VAE(x_dim=3*784, h_dim1= 512, h_dim2=256, z_dim=2)
vae.load_state_dict(torch.load('/content/drive/MyDrive/VAE/pytorch-mnist-VAE-master/pytorch_celeba_VAE/VAE_params.pkl', map_location=torch.device('cpu')))
vae.to(device)

# Score both restored models on the held-out split.
ae_reconstruction_loss, ae_psnr = evaluate_reconstruction(ae, test_loader)
vae_reconstruction_loss, vae_psnr = evaluate_reconstruction(vae, test_loader)

print(f'AE Reconstruction Loss: {ae_reconstruction_loss:.4f}')
print(f'AE PSNR: {ae_psnr:.2f} dB')
print(f'VAE Reconstruction Loss: {vae_reconstruction_loss:.4f}')
print(f'VAE PSNR: {vae_psnr:.2f} dB')


AE Reconstruction Loss: 0.0194
AE PSNR: 17.13 dB
VAE Reconstruction Loss: 0.0203
VAE PSNR: 16.96 dB


In [None]:
pwd

'/content'

In [None]:
num_steps = 10  # number of interpolation steps
with torch.no_grad():
    # Two random end points in the 2-D latent space.
    z1 = torch.randn(1, 2).to(device)  # first point
    z2 = torch.randn(1, 2).to(device)  # second point

    # Mixing weights running from 0 (pure z1) to 1 (pure z2).
    mixing_weights = [float(step) / (num_steps - 1) for step in range(num_steps)]

    # Decode every linearly interpolated latent point into an image.
    decoded_steps = [
        vae.decoder((1 - weight) * z1 + weight * z2)
        for weight in mixing_weights
    ]

    # Stack the decoded images and save them as one image file.
    interpolated_images = torch.cat(decoded_steps, dim=0)
    save_image(interpolated_images.view(num_steps, 3, 28, 28), './interpolated_images.png')