In [None]:
!pip install imageio

In [None]:
import os
import pickle
import imageio
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

from tqdm import tqdm
from scipy import linalg
import numpy as np

from kaiwu.classical import SimulatedAnnealingOptimizer
from model import Generator, Discriminator, QGAN
from utils import show_result, show_train_hist

In [None]:
# training parameters
batch_size = 128
lr = 0.0002
train_epoch = 100
num_visible = 256
num_hidden = 100
seed = 0  # global RNG seed so that Restart & Run All reproduces the results

# Seed torch (CPU and CUDA generators) and numpy before any random tensor or
# network weight is created below or in later cells.
torch.manual_seed(seed)
np.random.seed(seed)

# results save folder
save_path = './MNIST_QGAN_results/'
# makedirs(exist_ok=True) creates the missing parent folder as needed and is a
# no-op for folders that already exist, replacing the manual isdir/mkdir checks.
for sub_dir in ('Random_results', 'Fixed_results'):
    os.makedirs(os.path.join(save_path, sub_dir), exist_ok=True)

# Select the compute device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
fixed_z_ = torch.randn((5 * 5, 100)).to(device)  # fixed noise: 25 latent vectors of size 100

In [None]:
# data_loader
# ToTensor yields pixel values in [0, 1]; Normalize with mean = std = 0.5 maps them
# to [-1, 1]. MNIST is single-channel, so one mean/std pair is enough.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5,), std=(0.5,))])
mnist_train = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

# network
image_dim = 28 * 28  # a flattened 28x28 MNIST image
G = Generator(input_size=100, n_class=image_dim).to(device)
D = Discriminator(input_size=image_dim, n_class=1).to(device)
sampler = SimulatedAnnealingOptimizer(
    initial_temperature=1e4,
    cutoff_temperature=1e2,
    alpha=0.9,
    size_limit=batch_size,
)
qgan = QGAN(G, D, sampler, num_visible, num_hidden).to(device)

# Binary Cross Entropy loss
# NOTE(review): BCE_loss is not referenced anywhere else in the visible notebook;
# the training losses come from qgan.d_loss / qgan.g_loss / qgan.rbm_loss.
BCE_loss = nn.BCELoss()

# One Adam optimizer per trainable component; the RBM uses a learning rate 10x smaller.
G_optimizer = optim.Adam(G.parameters(), lr=lr)
D_optimizer = optim.Adam(D.parameters(), lr=lr)
rbm_optimizer = optim.Adam(qgan.bm.parameters(), lr=lr * 0.1)

# Per-epoch mean losses, appended to by the training loop.
train_hist = {'D_losses': [], 'G_losses': []}

print("Training start!")
for epoch in range(train_epoch):
    # Return G and D to training mode at the start of every epoch. The FID cell
    # calls .eval() on them, and show_result may do the same, so re-running this
    # cell would otherwise train with dropout/batch-norm in eval behaviour.
    G.train()
    D.train()

    D_losses = []
    G_losses = []
    rbm_losses = []

    for x_, _ in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{train_epoch}'):
        # Flatten each 28x28 image into a 784-dim vector.
        x_ = x_.view(-1, 28 * 28).to(device)

        # train discriminator D
        # ============================================================
        D_optimizer.zero_grad()
        d_loss = qgan.d_loss(x_)
        d_loss.backward()
        D_optimizer.step()
        D_losses.append(d_loss.item())

        # train generator G
        # ============================================================
        G_optimizer.zero_grad()
        g_loss = qgan.g_loss()
        g_loss.backward()
        G_optimizer.step()
        G_losses.append(g_loss.item())

        # train RBM
        # ============================================================
        rbm_optimizer.zero_grad()
        # Extract features of the real images with D. Under no_grad no autograd
        # graph is built through D, so the RBM step cannot push gradients into it.
        with torch.no_grad():
            binarized_features = D.get_feature(x_)
        rbm_loss = qgan.rbm_loss(binarized_features)
        rbm_loss.backward()
        rbm_optimizer.step()
        rbm_losses.append(rbm_loss.item())

    # Average each loss over the epoch once; reuse the value for logging and history.
    d_mean = torch.tensor(D_losses).mean().item()
    g_mean = torch.tensor(G_losses).mean().item()
    rbm_mean = torch.tensor(rbm_losses).mean().item()
    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f, loss_rbm: %.3f' % (
        (epoch + 1), train_epoch, d_mean, g_mean, rbm_mean))

    p = f'{save_path}/Random_results/MNIST_GAN_' + str(epoch + 1) + '.png'
    fixed_p = f'{save_path}/Fixed_results/MNIST_GAN_' + str(epoch + 1) + '.png'
    # NOTE(review): the two calls below differ only in the output path, so the
    # "Fixed_results" frames are presumably NOT generated from the fixed noise
    # `fixed_z_`; check how utils.show_result chooses its latent input.
    show_result(qgan, (epoch+1), save=True, path=p)
    show_result(qgan, (epoch+1), save=True, path=fixed_p)
    train_hist['D_losses'].append(d_mean)
    train_hist['G_losses'].append(g_mean)
    # setdefault keeps this working when train_hist was created with only D/G keys.
    train_hist.setdefault('rbm_losses', []).append(rbm_mean)

print("Training finish!... save training results")
torch.save(G.state_dict(), f"{save_path}/generator_param.pkl")
torch.save(D.state_dict(), f"{save_path}/discriminator_param.pkl")
# The RBM parameters (qgan.bm) are updated during training by rbm_optimizer, so
# persist them as well; they are presumably needed to reproduce QGAN samples.
torch.save(qgan.bm.state_dict(), f"{save_path}/rbm_param.pkl")
with open(f'{save_path}/train_hist.pkl', 'wb') as f:
    pickle.dump(train_hist, f)

show_train_hist(train_hist, save=True, path=f'{save_path}/MNIST_GAN_train_hist.png')

# Stitch the per-epoch "Fixed_results" frames into an animated GIF, in epoch order.
# NOTE(review): newer imageio releases may warn about the top-level
# `imageio.imread` and the GIF `fps` keyword (superseded by `duration`); since
# the install cell does not pin a version, confirm against the installed imageio.
frame_paths = [f'{save_path}/Fixed_results/MNIST_GAN_{epoch_no}.png'
               for epoch_no in range(1, train_epoch + 1)]
images = [imageio.imread(frame_path) for frame_path in frame_paths]
imageio.mimsave(f'{save_path}/generation_animation.gif', images, fps=5)

In [None]:
def calculate_fid(real_features, fake_features, eps=1e-6):
    """
    Compute the Fréchet distance between two feature sets (FID).

    Each set is summarised by its mean and covariance, and the distance is
        ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2)).

    Args:
        real_features: [N1, d] array of features of real samples.
        fake_features: [N2, d] array of features of generated samples.
        eps: diagonal offset added to BOTH covariances when the matrix square
            root of their product is not finite (e.g. singular covariances).

    Returns:
        The FID as a Python float.
    """
    real = np.asarray(real_features, dtype=np.float64)
    fake = np.asarray(fake_features, dtype=np.float64)

    mu1 = np.atleast_1d(np.mean(real, axis=0))
    mu2 = np.atleast_1d(np.mean(fake, axis=0))
    # np.cov returns a 0-d array when d == 1; promote it to 2-D so that sqrtm,
    # np.eye and np.trace all receive proper matrices.
    sigma1 = np.atleast_2d(np.cov(real, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(fake, rowvar=False))

    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # Regularise each covariance (not their product) and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean, _ = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=False)

    # sqrtm may return a complex array whose imaginary parts come from numerical
    # error; Tr(Re(C)) == Re(Tr(C)), so keeping the real part gives the same FID.
    covmean = np.real(covmean)

    fid = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean)
    return float(fid)

def compute_fid_score(G, D, dataloader, device, num_samples=10000, z_dim=100, gen_batch_size=1000):
    """
    Compute an FID-style score between generated and real images, using the
    discriminator's ``get_feature`` output as the embedding (not Inception).

    Args:
        G: generator, called as ``G(z)`` on a batch of standard-normal latents.
        D: discriminator providing ``get_feature(x)`` for flattened images.
        dataloader: iterable of ``(images, labels)`` batches of real images.
        device: device that G and D live on.
        num_samples: maximum number of real and of generated samples to embed.
        z_dim: size of the latent vector fed to ``G``.
        gen_batch_size: number of fake images generated per forward pass
            (default 1000, the previously hard-coded chunk size).

    Returns:
        The value returned by ``calculate_fid`` for the two feature sets.

    Raises:
        ValueError: if ``num_samples`` or ``gen_batch_size`` is not positive, or
            ``dataloader`` yields no batches.

    G and D are switched to eval mode for the computation and restored to their
    previous train/eval mode afterwards, so a later training cell is unaffected.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")
    if gen_batch_size <= 0:
        raise ValueError(f"gen_batch_size must be positive, got {gen_batch_size}")

    g_was_training, d_was_training = G.training, D.training
    G.eval()
    D.eval()
    try:
        # Both feature passes run without autograd: nothing here is backpropagated.
        with torch.no_grad():
            # Collect real-image features.
            real_feats = []
            total = 0
            for x, _ in tqdm(dataloader, desc="Extracting real features"):
                x = x.view(x.size(0), -1).to(device)
                # assumes get_feature returns a flat (batch, d) feature matrix — TODO confirm
                real_feats.append(D.get_feature(x).cpu().numpy())
                total += x.size(0)
                if total >= num_samples:
                    break
            if not real_feats:
                raise ValueError("dataloader yielded no real images")
            real_feats = np.concatenate(real_feats, axis=0)[:num_samples]

            # Generate fake images and extract their features.
            # NOTE(review): fakes come from standard-normal z fed directly to G;
            # confirm this matches how the QGAN draws latents (it may use the RBM sampler).
            fake_feats = []
            total = 0
            while total < num_samples:
                z = torch.randn(min(num_samples - total, gen_batch_size), z_dim).to(device)
                feat = D.get_feature(G(z)).cpu().numpy()
                fake_feats.append(feat)
                total += feat.shape[0]
            fake_feats = np.concatenate(fake_feats, axis=0)[:num_samples]
    finally:
        G.train(g_was_training)
        D.train(d_was_training)

    return calculate_fid(real_feats, fake_feats)

# Compare up to 5000 real images against 5000 generated ones, embedding both with D's features.
fid_score = compute_fid_score(G, D, dataloader=train_loader, device=device, num_samples=5000)
print("FID Score: {:.2f}".format(fid_score))