<a href="https://colab.research.google.com/github/yastiaisyah/DataSynthesis/blob/main/conditional_vae_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from torchvision.datasets import MNIST
from torchvision import transforms

# --- Hyperparameters ---------------------------------------------------------
mb_size = 32       # minibatch size
z_dim = 5          # size of the latent code
y_dim = 10         # Number of condition classes (0-9 digits)
X_dim = 28 * 28    # Flattened image dimensions
h_dim = 128        # hidden-layer width
lr = 1e-3          # learning rate

# --- Data --------------------------------------------------------------------
# Preprocessing pipeline handed to the dataset: convert to a tensor, then
# normalize with mean 0.5 and std 0.5.
# NOTE(review): this pipeline is only attached to the dataset object; code that
# reads `mnist.data` / `mnist.targets` directly presumably bypasses it — confirm
# against the consumers of `mnist`.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# MNIST training split, downloaded into ./data when it is not already present.
mnist = MNIST(root='./data', train=True, transform=transform, download=True)

# Helper functions (the Q and P networks themselves are built in the TRAINING section below)

def sample_z(mu, log_var):
    """Draw a latent sample using the reparameterization trick.

    Computes ``mu + exp(log_var / 2) * eps`` where ``eps`` is standard normal
    noise, so the result remains differentiable with respect to ``mu`` and
    ``log_var``.

    Args:
        mu: Tensor holding the mean of the latent distribution.
        log_var: Tensor holding the log-variance; must be broadcastable
            with ``mu``.

    Returns:
        A tensor with the same shape as ``mu`` containing one sample per
        element.
    """
    # randn_like allocates the noise with mu's shape, dtype and device, so the
    # function keeps working when the model is moved off the default device.
    eps = torch.randn_like(mu)
    std = torch.exp(log_var / 2)
    return mu + std * eps


# =============================== TRAINING ====================================

# Create the model objects.
# Q is the encoder: it takes the flattened image concatenated with the one-hot
# label and outputs 2 * z_dim values (mean and log-variance of the latent code).
Q = nn.Sequential(
    nn.Linear(X_dim + y_dim, h_dim),
    nn.ReLU(),
    nn.Linear(h_dim, z_dim * 2)  # Output both mu and log_var
)

# P is the decoder: it takes a latent code concatenated with the one-hot label
# and outputs X_dim values squashed into (0, 1) by the final sigmoid.
P = nn.Sequential(
    nn.Linear(z_dim + y_dim, h_dim),
    nn.ReLU(),
    nn.Linear(h_dim, X_dim),
    nn.Sigmoid()
)

# A single optimizer updates encoder and decoder jointly.
params = list(Q.parameters()) + list(P.parameters())
solver = optim.Adam(params, lr=lr)

# Reconstruction criterion, built once rather than on every iteration.
recon_criterion = nn.BCELoss(reduction='sum')

num_iters = 100000   # number of training iterations
log_every = 1000     # report the loss and save a preview every this many steps
grid_side = 4        # previews are laid out on a grid_side x grid_side grid
n_preview = grid_side * grid_side

cnt = 0  # running index used to name the saved preview images

for it in range(num_iters):
    # Sample a minibatch (with replacement) straight from the raw uint8 images
    # and scale the pixels to [0, 1] so they are valid BCE targets.
    indices = np.random.randint(0, len(mnist), mb_size)
    X, c = mnist.data[indices], mnist.targets[indices]
    X = X.view(mb_size, -1).float() / 255.0
    c_onehot = torch.zeros(mb_size, y_dim).scatter_(1, c.view(-1, 1), 1)

    # Forward pass: encode, sample the latent code, decode.
    z_params = Q(torch.cat([X, c_onehot], 1))
    z_mu, z_log_var = z_params[:, :z_dim], z_params[:, z_dim:]
    z = sample_z(z_mu, z_log_var)
    X_sample = P(torch.cat([z, c_onehot], 1))

    # Loss: summed BCE averaged over the batch, plus the batch-mean KL
    # divergence between N(z_mu, exp(z_log_var)) and the standard normal.
    recon_loss = recon_criterion(X_sample, X) / mb_size
    kl_loss = torch.mean(
        0.5 * torch.sum(torch.exp(z_log_var) + z_mu ** 2 - 1. - z_log_var, 1)
    )
    loss = recon_loss + kl_loss

    # Backpropagation and optimization
    solver.zero_grad()
    loss.backward()
    solver.step()

    # Print and plot every now and then
    if it % log_every == 0:
        print('Iter-{}; Loss: {:.4}'.format(it, loss.item()))

        # Decode random latent codes, all conditioned on one randomly chosen
        # class. No gradients are needed here, so skip building the graph.
        with torch.no_grad():
            c_preview = torch.zeros(n_preview, y_dim)
            c_preview[:, np.random.randint(0, y_dim)] = 1.0
            z_preview = torch.randn(n_preview, z_dim)
            samples = P(torch.cat([z_preview, c_preview], 1)).numpy()

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(grid_side, grid_side)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = fig.add_subplot(gs[i])
            ax.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

        # exist_ok avoids the separate existence check and the race it implies.
        os.makedirs('out/', exist_ok=True)

        fig.savefig('out/{}.png'.format(str(cnt).zfill(3)), bbox_inches='tight')
        cnt += 1
        plt.close(fig)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 207760132.95it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 25236602.88it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 176044072.71it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz



100%|██████████| 4542/4542 [00:00<00:00, 16089973.62it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Iter-0; Loss: 547.1
Iter-1000; Loss: 143.6
Iter-2000; Loss: 124.5
Iter-3000; Loss: 121.3
Iter-4000; Loss: 120.3
Iter-5000; Loss: 125.1
Iter-6000; Loss: 119.7
Iter-7000; Loss: 135.1
Iter-8000; Loss: 121.1
Iter-9000; Loss: 130.1
Iter-10000; Loss: 121.7
Iter-11000; Loss: 117.5
Iter-12000; Loss: 118.8
Iter-13000; Loss: 127.6
Iter-14000; Loss: 122.8
Iter-15000; Loss: 125.1
Iter-16000; Loss: 118.5
Iter-17000; Loss: 113.9
Iter-18000; Loss: 108.5
Iter-19000; Loss: 122.6
Iter-20000; Loss: 110.9
Iter-21000; Loss: 123.0
Iter-22000; Loss: 106.6
Iter-23000; Loss: 116.8
Iter-24000; Loss: 120.9
Iter-25000; Loss: 119.3
Iter-26000; Loss: 116.7
Iter-27000; Loss: 119.5
Iter-28000; Loss: 110.7
Iter-29000; Loss: 117.6
Iter-30000; Loss: 124.9
Iter-31000; Loss: 112.4
Iter-32000; Loss: 119.3
Iter-33000; Loss: 123.4
Iter-34000; Loss: 120.9
Iter-35000; Loss: 110.7
Iter-36000; Loss: 129.7
Iter-37000; Loss: 125.7
Iter-38000; Loss: 126.9
It