<a href="https://colab.research.google.com/github/yastiaisyah/DataSynthesis/blob/main/adversarial_autoencoder_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from torchvision.datasets import MNIST
from torchvision import transforms

# Per-item preprocessing attached to the dataset object. torchvision applies it
# only when items are fetched by indexing (mnist[i]). sample_X reads
# mnist.data directly, so training batches never pass through this Normalize
# step (which would map pixels to [-1, 1]).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# MNIST training split, stored under ./data (downloaded if missing).
mnist = MNIST(root='./data', train=True, transform=transform, download=True)

# Replace the raw 0-255 pixel values with float intensities in [0, 1], the range
# produced by the decoder's Sigmoid and expected by the BCE reconstruction loss.
mnist.data = mnist.data.float().div(255.0)

# Training hyperparameters
mb_size = 32   # images per minibatch
z_dim = 5      # dimensionality of the latent code
h_dim = 128    # width of every hidden layer
lr = 1e-3      # Adam learning rate for all three networks
# Pixels in one flattened image (rows * cols of a single dataset sample).
X_dim = mnist.data[0].numel()

# Encoder Q: flattened image -> latent code (deterministic, no output activation).
Q = nn.Sequential(nn.Linear(X_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, z_dim))

# Decoder P: latent code -> per-pixel intensities squashed by a Sigmoid.
P = nn.Sequential(
    *[nn.Linear(z_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, X_dim), nn.Sigmoid()]
)

# Discriminator D: latent code -> Sigmoid score, trained to be high for
# samples from the prior and low for codes produced by the encoder.
D = nn.Sequential(
    *[nn.Linear(z_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, 1), nn.Sigmoid()]
)

def reset_grad():
    """Clear accumulated gradients on the encoder, decoder and discriminator."""
    for net in (Q, P, D):
        net.zero_grad()

def sample_X(size):
    """Draw a random minibatch of flattened training images.

    Indices are drawn independently (i.e. with replacement) from the whole
    training set.

    Args:
        size: number of images in the batch (0 yields an empty batch).

    Returns:
        A float tensor with one row per image, each row the image's pixels
        flattened, taken from ``mnist.data``.
    """
    indices = np.random.randint(0, len(mnist), size)
    # flatten(start_dim=1) equals view(size, -1) for size > 0, but also works for
    # an empty batch, where view cannot infer the -1 dimension. The former
    # Variable(...) wrapper is dropped: since PyTorch 0.4 it just returns the
    # tensor unchanged.
    return mnist.data[indices].flatten(start_dim=1).float()

# One Adam optimiser per network.
Q_solver = optim.Adam(Q.parameters(), lr=lr)
P_solver = optim.Adam(P.parameters(), lr=lr)
D_solver = optim.Adam(D.parameters(), lr=lr)

# Binary cross-entropy shared by the reconstruction and adversarial losses.
# nn.BCELoss clamps its log terms at -100, so a saturated discriminator output
# (exactly 0 or 1) gives a large finite loss instead of the inf/NaN that a raw
# torch.log(...) formulation can produce.
bce = nn.BCELoss()

# Shape of one image, used to un-flatten decoded samples for plotting.
img_shape = tuple(mnist.data.shape[1:])

os.makedirs('out/', exist_ok=True)
cnt = 0  # index of the next sample grid written to out/

# NOTE(review): an iteration count of 1000000 was noted here as an alternative.
for it in range(100000):
    X = sample_X(mb_size)

    # --- Reconstruction phase: train encoder Q and decoder P end to end ---
    z_sample = Q(X)
    X_sample = P(z_sample)  # Sigmoid output, already within [0, 1]; no clamp needed
    recon_loss = bce(X_sample, X)

    recon_loss.backward()
    P_solver.step()
    Q_solver.step()
    reset_grad()

    # --- Regularization phase ---
    # Discriminator: separate prior samples (real) from encoder codes (fake).
    z_real = torch.randn(mb_size, z_dim)
    # Detach: the D update must not build or backprop a graph through Q.
    z_fake = Q(X).detach()

    D_real = D(z_real)
    D_fake = D(z_fake)
    # Equals -mean(log D(z_real) + log(1 - D(z_fake))) since both batches are
    # the same size, but numerically safe.
    D_loss = (bce(D_real, torch.ones_like(D_real))
              + bce(D_fake, torch.zeros_like(D_fake)))

    D_loss.backward()
    D_solver.step()
    reset_grad()

    # Generator: update the encoder so D scores its codes as prior samples.
    z_fake = Q(X)
    D_fake = D(z_fake)
    G_loss = bce(D_fake, torch.ones_like(D_fake))  # == -mean(log D(z_fake))

    G_loss.backward()
    Q_solver.step()
    reset_grad()

    # Print and plot every now and then
    if it % 1000 == 0:
        print('Iter-{}; D_loss: {:.4}; G_loss: {:.4}; recon_loss: {:.4}'
              .format(it, D_loss.item(), G_loss.item(), recon_loss.item()))

        # Decode the current prior draws; no autograd graph is needed here.
        with torch.no_grad():
            samples = P(z_real).numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(img_shape), cmap='Greys_r')

        plt.savefig('out/{}.png'
                    .format(str(cnt).zfill(3)), bbox_inches='tight')
        cnt += 1
        plt.close(fig)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 74516856.28it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 76571235.03it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 31893395.24it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 989946.41it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






Iter-0; D_loss: 1.397; G_loss: 0.6724; recon_loss: 0.6947
Iter-1000; D_loss: 1.142; G_loss: 0.9264; recon_loss: 0.2744
Iter-2000; D_loss: 1.317; G_loss: 0.9261; recon_loss: 0.2518
Iter-3000; D_loss: 1.863; G_loss: 0.8282; recon_loss: 0.2677
Iter-4000; D_loss: 1.357; G_loss: 0.8837; recon_loss: 0.2719
Iter-5000; D_loss: 0.8553; G_loss: 0.9618; recon_loss: 0.2588
Iter-6000; D_loss: 1.081; G_loss: 0.7807; recon_loss: 0.2753
Iter-7000; D_loss: 1.182; G_loss: 1.382; recon_loss: 0.263
Iter-8000; D_loss: 1.319; G_loss: 0.8952; recon_loss: 0.2699
Iter-9000; D_loss: 4.415; G_loss: 0.03042; recon_loss: 0.2745
Iter-10000; D_loss: 0.7709; G_loss: 1.238; recon_loss: 0.2598
Iter-11000; D_loss: 0.5094; G_loss: 1.451; recon_loss: 0.277
Iter-12000; D_loss: 1.437; G_loss: 1.081; recon_loss: 0.253
Iter-13000; D_loss: 0.7438; G_loss: 1.52; recon_loss: 0.2537
Iter-14000; D_loss: 1.355; G_loss: 0.7209; recon_loss: 0.2447
Iter-15000; D_loss: 1.092; G_loss: 0.867; recon_loss: 0.2568
Iter-16000; D_loss: 1.648;