In [None]:
import os
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm import tqdm

# 添加kaiwu的license信息



# MNIST数据集加载

In [None]:
class MNISTWithBatch(Dataset):
    """MNIST wrapper that yields ``(image, batch_idx)`` instead of ``(image, label)``.

    Samples are assigned, in dataset order, to ``num_batches`` contiguous groups
    of (nearly) equal size. The group id replaces the original digit label.

    Args:
        root: Directory passed to ``datasets.MNIST``.
        train: Forwarded to ``datasets.MNIST``.
        transform: Forwarded to ``datasets.MNIST``.
        download: Forwarded to ``datasets.MNIST``.
        num_batches: Number of groups the samples are split into (must be >= 1).

    Raises:
        ValueError: If ``num_batches`` is smaller than 1.
    """

    def __init__(self, root, train=True, transform=None, download=True, num_batches=6):
        if num_batches < 1:
            raise ValueError(f"num_batches must be >= 1, got {num_batches}")
        self.mnist = datasets.MNIST(root=root, train=train, transform=transform, download=download)
        self.num_batches = num_batches
        self.batch_indices = self._create_batch_indices()

    def _create_batch_indices(self):
        """Return a 1-D integer tensor mapping every sample index to a group id.

        The ids always lie in ``[0, num_batches - 1]``. The previous formula,
        ``idx // (num_samples // num_batches)``, produced an extra group whenever
        ``num_samples`` was not divisible by ``num_batches`` and divided by zero
        when ``num_samples < num_batches``.
        """
        num_samples = len(self.mnist)
        if num_samples == 0:
            return torch.zeros(0, dtype=torch.long)
        # Identical to the old formula when num_samples is divisible by num_batches.
        return (torch.arange(num_samples) * self.num_batches) // num_samples

    def __len__(self):
        """Number of samples in the wrapped MNIST dataset."""
        return len(self.mnist)

    def __getitem__(self, idx):
        """Return ``(image, batch_idx)`` for sample ``idx``; the digit label is dropped."""
        image, _ = self.mnist[idx]  # ignore the original label
        batch_idx = self.batch_indices[idx]
        return image, batch_idx

# Binarized MNIST数据集加载

In [None]:
# 定义数据转换操作
def flatten_tensor(x):
    """Flatten ``x`` into a 1-D tensor holding the same elements in row-major order.

    ``reshape`` is used instead of ``view`` so that non-contiguous inputs
    (for example a transposed tensor) are flattened instead of raising. For
    contiguous inputs the result is the same view as before.
    """
    return x.reshape(-1)
# Preprocessing applied to every sample: convert the image to a tensor,
# then flatten it with flatten_tensor.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(flatten_tensor),
    ]
)

In [None]:
# Training hyper-parameters.
batch_size = 256
epochs = 20
lr = 1e-3
kl_beta = 1e-5  # forwarded to model.neg_elbo as its second argument

# Every artefact of this run (checkpoints, metric logs) is written below save_path.
model_name = "QVAE_annealing_tanh"
save_path = f"./models/{model_name}"
os.makedirs(save_path, exist_ok=True)

In [None]:
# Validation and training splits, both flattened by `transform`.
val_dataset = MNISTWithBatch(root='../../data', train=False, download=False, transform=transform)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

train_dataset = MNISTWithBatch(root='../../data', train=True, download=False, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Per-feature mean over the whole training set.
# Sums and sample counts are accumulated instead of averaging per-batch means:
# the last batch is smaller whenever len(train_dataset) is not a multiple of
# batch_size, and averaging batch means would over-weight its samples.
feature_sum = 0
num_samples_seen = 0
for x, _ in train_dataloader:
    feature_sum = feature_sum + x.sum(dim=0)
    num_samples_seen += x.size(0)
mean_x = feature_sum / num_samples_seen
mean_x = mean_x.cpu().numpy()


In [None]:
# This cell used to repeat the previous cell verbatim (dataset construction,
# dataloaders and the mean_x pass), which cost a second full sweep over the
# training set without changing any result. The objects built by the previous
# cell are reused as they are; only cheap sanity checks remain here.
assert len(train_dataset) > 0 and len(val_dataset) > 0, "datasets must not be empty"
assert mean_x.ndim == 1, "mean_x should be a flat per-feature mean vector"


In [None]:
from qvae import QVAE

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

input_dim = 784   # dimensionality of a flattened image
hidden_dim = 512  # dimensionality after the fc1 compression
latent_dim = 256  # dimensionality of the latent variables

# Sizes of the RBM visible and hidden layers.
num_var1 = 128
num_var2 = 128

dist_beta = 10    # beta of the overlapping (smoothing) distribution

# mean_x (per-feature training mean computed earlier) is the last constructor argument.
model = QVAE(input_dim, hidden_dim, latent_dim, num_var1, num_var2, dist_beta, mean_x)

model.to(device)

In [None]:
# Adam over every model parameter, using the learning rate from the config cell.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [None]:
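  # SECURITY: never commit registry credentials. Export KAIWU_PYPI_TOKEN in your shell before running this.
  # The access token that used to be embedded in this line must be treated as leaked and rotated.
  # pip install kaiwu==1.0.26a --index-url "http://__token__:${KAIWU_PYPI_TOKEN}@10.0.0.239:5005/api/v4/projects/114/packages/pypi/simple" --trusted-host 10.0.0.239 --extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple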


In [None]:
# Early-stopping bookkeeping. NOTE(review): neither name is referenced by the
# cells visible in this notebook; confirm whether they are still needed.
patience_counter, best_state_dict = 0, None

# Per-epoch training metrics, appended to by the training loop.
loss_history, elbo_history, kl_history, cost_history = [], [], [], []

In [None]:
# Train for exactly `epochs` passes over the training set.
# range(1, epochs) stopped one epoch early (19 of 20), so the log never reached
# "Epoch 20/20"; the upper bound is now inclusive.
for epoch in range(1, epochs + 1):
    model.train()
    total_loss, total_elbo, total_kl, total_cost = 0, 0, 0, 0
    for x, _ in train_dataloader:
        x = x.to(device)

        optimizer.zero_grad()

        # neg_elbo returns eight values; the last two are not needed for training.
        output, recon_x, neg_elbo, wd_loss, kl, cost, _, _ = model.neg_elbo(x, kl_beta)
        loss = neg_elbo + wd_loss
        loss.backward()

        optimizer.step()

        total_loss += loss.item()
        total_elbo += neg_elbo.item()
        total_kl += kl.item()
        total_cost += cost.item()

    # Average the per-batch metrics over the number of batches in the epoch.
    num_train_batches = len(train_dataloader)
    avg_loss = total_loss / num_train_batches
    avg_elbo = total_elbo / num_train_batches
    avg_kl = total_kl / num_train_batches
    avg_cost = total_cost / num_train_batches

    loss_history.append(avg_loss)
    elbo_history.append(avg_elbo)
    kl_history.append(avg_kl)
    cost_history.append(avg_cost)

    # One checkpoint per epoch, named after the epoch number.
    model_save_path = os.path.join(save_path, f'davepp_epoch{epoch}.pth')
    torch.save(model.state_dict(), model_save_path)

    print(f"Epoch {epoch}/{epochs}: Loss: {avg_loss:.4f}, elbo: {avg_elbo:.4f}, KL: {avg_kl:.4f}, Cost: {avg_cost:.4f}")


In [None]:
def save_list_to_txt(filename, data):
    """Write every value of ``data`` to ``filename``, one per line with six decimals.

    The file is opened in write mode, so existing content is replaced.
    """
    formatted_lines = (f"{value:.6f}\n" for value in data)
    with open(filename, "w") as handle:
        handle.writelines(formatted_lines)


# Persist each metric history next to the checkpoints, one file per metric.
# The dict preserves the original write order: loss, elbo, cost, kl.
history_files = {
    "loss_history.txt": loss_history,
    "elbo_history.txt": elbo_history,
    "cost_history.txt": cost_history,
    "kl_history.txt": kl_history,
}
for filename, history in history_files.items():
    save_list_to_txt(os.path.join(save_path, filename), history)

In [None]:
# # 模型加载模型文件
# model.load_state_dict(torch.load(os.path.join(save_path, "davepp_epoch10.pth")))

In [None]:
def plot_flattened_images_grid(features: torch.Tensor, grid_size: int = 8, save_path: str = None):
    """
    Show, and optionally save, the first ``grid_size * grid_size`` 28x28 grayscale images.

    Args:
        features (torch.Tensor): Tensor of shape [N, 784]; each row is one flattened 28x28 image.
        grid_size (int): Edge length of the image grid (default 8, i.e. the first 64 images).
        save_path (str): If given, the figure is written to this path. The parent
            directory is created when it does not exist.

    Raises:
        AssertionError: If ``features`` is not [N, 784] or holds fewer than
            ``grid_size * grid_size`` rows.
    """
    assert features.dim() == 2 and features.size(1) == 784, "features 应为 [N, 784] 的张量"
    num_images = grid_size * grid_size
    assert features.size(0) >= num_images, f"features 中至少应包含 {num_images} 张图像"

    features_numpy = features[:num_images].detach().cpu().numpy()

    # squeeze=False keeps `axes` two-dimensional even when grid_size == 1.
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(5, 5), squeeze=False)
    # axes.flat walks the grid row by row, matching idx = row * grid_size + col.
    for idx, ax in enumerate(axes.flat):
        ax.imshow(features_numpy[idx].reshape(28, 28), cmap='gray')
        ax.axis('off')

    fig.tight_layout()

    # The save step was commented out, so save_path silently did nothing.
    if save_path:
        target_dir = os.path.dirname(save_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        fig.savefig(save_path, bbox_inches='tight')

    # Show and release the figure whether or not it was saved; previously this
    # only happened when save_path was given.
    plt.show()
    plt.close(fig)

In [None]:
# Take one batch from the (shuffled) training loader for visualisation; the
# batch-index labels returned alongside the images are discarded.
features, _ = next(iter(train_dataloader))

In [None]:
# Show the first 64 original images of the sampled batch.
# os.path.join replaces the placeholder-less f-string concatenation.
plot_flattened_images_grid(features, grid_size=8, save_path=os.path.join(save_path, 'original.png'))

In [None]:
# Move the sampled batch to the device the model lives on before running it.
features = features.to(device)

In [None]:
# Switch the model to evaluation mode for the reconstruction pass.
model.eval()

# Stale call kept for reference: it unpacks six values and omits kl_beta,
# unlike the eight-value call below.
# output, recon_x, neg_elbo, wd_loss, total_kl, cost = model.neg_elbo(features)

# `output` is what the next cell plots as the reconstruction.
# NOTE(review): this runs with autograd enabled; wrapping it in torch.no_grad()
# is presumably safe but depends on neg_elbo's internals — confirm.
output, recon_x, neg_elbo, wd_loss, kl, cost, q, zeta = model.neg_elbo(features, kl_beta)

In [None]:
# Show the first 64 model outputs for the sampled batch.
# os.path.join replaces the placeholder-less f-string concatenation.
plot_flattened_images_grid(output, grid_size=8, save_path=os.path.join(save_path, 'recon_x.png'))

In [None]:
from kaiwu.classical import SimulatedAnnealingOptimizer

# Draw latent samples from the model's RBM through a simulated-annealing sampler.
# NOTE(review): the meaning of size_limit and alpha is defined by the kaiwu SDK — confirm there.
sampler = SimulatedAnnealingOptimizer(size_limit=100, alpha=0.99)
z = model.rbm.sample(sampler)
# Keep the sample shape; the next cell draws smoothing noise of the same shape.
shape = z.shape

In [None]:
from torch.distributions import Exponential

# Smoothing distribution, parameterised by dist_beta.
smoothing_dist = Exponential(dist_beta)

# Pure inference: run under no_grad, as generate_images_qvae does for the same
# computation, so no autograd graph is built for the decoder pass.
with torch.no_grad():
    # Sample smoothing noise with the same shape and device as z.
    zeta = smoothing_dist.sample(shape).to(z.device)
    # Keep the sampled noise where z == 0 and use 0 everywhere else.
    zeta = torch.where(z == 0., zeta, 0)
    # Decode, add the model's train_bias, then squash with a sigmoid.
    generated_x = torch.sigmoid(model.decoder(zeta) + model.train_bias)

plot_flattened_images_grid(generated_x, grid_size=8, save_path=os.path.join(save_path, 'generated_x.png'))

In [None]:
from torchmetrics.image.fid import FrechetInceptionDistance

def get_real_images(dataloader, n_images=10000):
    """Collect the first ``n_images`` images yielded by ``dataloader``.

    Batches are read until at least ``n_images`` images were gathered (or the
    loader is exhausted), concatenated along dim 0 and truncated to
    ``n_images``. The second element of every batch is ignored.
    """
    collected = []
    num_collected = 0
    for batch_imgs, _ in dataloader:
        collected.append(batch_imgs)
        num_collected += batch_imgs.shape[0]
        if num_collected >= n_images:
            break
    stacked = torch.cat(collected, dim=0)
    return stacked[:n_images]

def generate_images_original_vae(model, latent_dim, n_images=10000, batch_size=64):
    """Generate ``n_images`` images by decoding standard-normal latent vectors.

    Args:
        model: Model exposing a ``decoder`` callable; it is put into eval mode.
        latent_dim: Size of every latent vector fed to the decoder.
        n_images: Number of images to return (must be >= 1).
        batch_size: Number of latent vectors decoded per step.

    Returns:
        CPU tensor with the first ``n_images`` decoder outputs concatenated along dim 0.
    """
    model.eval()
    imgs = []
    # Ceiling division: n_images // batch_size dropped the trailing partial
    # batch, so fewer than n_images images were returned (9984 for the defaults).
    num_batches = (n_images + batch_size - 1) // batch_size
    with torch.no_grad():
        for _ in tqdm(range(num_batches)):
            z = torch.randn(batch_size, latent_dim).to(device)
            img = model.decoder(z).cpu()
            imgs.append(img)
    return torch.cat(imgs, dim=0)[:n_images]

def generate_images_qvae(model, latent_dim, n_images=10000, batch_size=64):
    """Generate ``n_images`` images by decoding smoothed samples of the model's RBM.

    Args:
        model: Model exposing ``rbm.sample``, ``decoder`` and ``train_bias``;
            it is put into eval mode.
        latent_dim: Unused; kept so the signature matches ``generate_images_original_vae``.
        n_images: Number of images to return (must be >= 1).
        batch_size: Unused; the number of images produced per step is whatever
            ``model.rbm.sample`` returns.

    Returns:
        Tensor with the first ``n_images`` generated images concatenated along dim 0.

    Raises:
        RuntimeError: If ``model.rbm.sample`` returns an empty sample.
    """
    model.eval()
    imgs = []
    num_generated = 0
    sampler = SimulatedAnnealingOptimizer(alpha=0.95)
    # Loop-invariant, so it is built once instead of on every iteration.
    smoothing_dist = Exponential(dist_beta)
    # The old loop ran n_images // batch_size times, which assumed each
    # rbm.sample call yields batch_size rows although batch_size is never passed
    # to the sampler. Sample until enough images exist instead.
    # assumes z is (num_samples, latent) so dim 0 counts images — TODO confirm
    with torch.no_grad(), tqdm(total=n_images) as progress:
        while num_generated < n_images:
            z = model.rbm.sample(sampler)
            if z.shape[0] == 0:
                raise RuntimeError("model.rbm.sample returned no samples")
            # Smoothing noise with the same shape and device as z.
            zeta = smoothing_dist.sample(z.shape).to(z.device)
            # Keep the sampled noise where z == 0 and use 0 everywhere else.
            zeta = torch.where(z == 0., zeta, 0)
            generated_x = torch.sigmoid(model.decoder(zeta) + model.train_bias)

            imgs.append(generated_x)
            num_generated += generated_x.shape[0]
            progress.update(generated_x.shape[0])
    return torch.cat(imgs, dim=0)[:n_images]

# Module-level resize to 299x299, used by preprocess below.
resize = transforms.Resize((299, 299))


def preprocess(images):
    """Rescale ``images`` to [0, 1] when their maximum exceeds 1, then resize to 299x299.

    NOTE(review): compute_fid_in_batches defines its own inner ``preprocess``;
    no call to this module-level version is visible in this notebook.
    """
    exceeds_unit_range = images.max() > 1.0
    scaled = images / 255.0 if exceeds_unit_range else images
    return resize(scaled)


def compute_fid_in_batches(fake_imgs, real_imgs, batch_size=64, feature=64):
    """
    Compute the FID score for flattened 28x28 images (e.g. MNIST), batch by batch.

    Args:
        fake_imgs: Generated images, shape (N, 784), float values in [0, 1].
        real_imgs: Real images, shape (M, 784), float values in [0, 1].
        batch_size: Number of images fed to the metric per update.
        feature: Forwarded to ``FrechetInceptionDistance(feature=...)``. The
            default of 64 matches the previous hard-coded value.

    Returns:
        The FID score as a Python float.

    Note:
        The metric runs on the module-level ``device``.
    """
    fid = FrechetInceptionDistance(feature=feature).to(device)
    # Built once instead of once per batch.
    resize = transforms.Resize((299, 299), antialias=True)

    def _to_inception_input(images):
        # Restore the image layout: (B, 1, 28, 28).
        images = images.view(-1, 1, 28, 28)
        # Repeat the single channel three times.
        images = images.repeat(1, 3, 1, 1)
        # Resize to 299x299.
        return resize(images)

    def _to_uint8(images):
        # Non-tensor inputs are converted as float. Casting them to uint8 before
        # scaling (the previous behaviour) truncated [0, 1] data to 0 first.
        if not isinstance(images, torch.Tensor):
            images = torch.as_tensor(images, dtype=torch.float32)
        # Scale to [0, 255] and cast to uint8 (input assumed float in [0, 1]).
        return (images * 255).clamp(0, 255).to(torch.uint8)

    fake_imgs = _to_uint8(fake_imgs)
    real_imgs = _to_uint8(real_imgs)

    # Feed both image sets to the metric in batches.
    for start in range(0, len(real_imgs), batch_size):
        batch = _to_inception_input(real_imgs[start:start + batch_size])
        fid.update(batch.to(device), real=True)

    for start in range(0, len(fake_imgs), batch_size):
        batch = _to_inception_input(fake_imgs[start:start + batch_size])
        fid.update(batch.to(device), real=False)

    return fid.compute().item()

In [None]:
# Collect real images and generated images.
real_imgs = get_real_images(val_dataloader, n_images=10000)
print(f"Real images shape: {real_imgs.shape}")

# NOTE: these images come from generate_images_qvae. The variable names
# (`fake_imgs_original_vae`, `fid_original`) are kept so that any later cell
# relying on them keeps working.
fake_imgs_original_vae = generate_images_qvae(model, latent_dim=latent_dim)
print(f"Generated images shape: {fake_imgs_original_vae.shape}")

# Compute FID batch by batch (lower memory use).
fid_original = compute_fid_in_batches(fake_imgs_original_vae, real_imgs)
# The label used to read "Original VAE FID" although the images are QVAE samples.
print(f"QVAE FID: {fid_original:.2f}")