# PyTorch qGAN Implementation

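This notebook trains a quantum generative adversarial network (qGAN) with PyTorch: a parameterized quantum circuit, wrapped by Qiskit's `TorchConnector`, acts as the generator, and a classical feed-forward network acts as the discriminator.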

adapted from [PyTorch GAN](https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py)

In [120]:
# Necessary imports

from typing import Union, List, Optional, Iterable

import numpy as np
import torch
from torch import Tensor, stack, reshape
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.optim import Adam
import torch.nn as nn
import torch.nn.functional as F


from qiskit import Aer, QuantumCircuit
from qiskit.utils import QuantumInstance, algorithm_globals
from qiskit.opflow import Gradient, StateFn
from qiskit.circuit.library import TwoLocal
from qiskit.circuit import ParameterExpression, ParameterVector
from qiskit_machine_learning.neural_networks import CircuitQNN
from qiskit_machine_learning.connectors import TorchConnector
from qiskit_machine_learning.datasets.dataset_helper import discretize_and_truncate

# Set seed for random generators
algorithm_globals.random_seed = 42
# PyTorch keeps its own global generator (used for DataLoader shuffling and for the
# initial network weights), so it has to be seeded separately. The trailing `;`
# keeps the returned Generator object out of the cell output.
torch.manual_seed(algorithm_globals.random_seed);

### Load training data

For testing purposes, we use a 2D multivariate normal distribution.
Each dimension is represented by 3 qubits (`data_dim = [3, 3]`), i.e., it is discretized into $2^3 = 8$ grid points.

In [121]:
# Number of qubits used to represent each data dimension
data_dim = [3, 3]

# Draw the raw samples from a generator seeded with the notebook-wide seed, so that
# "Restart Kernel -> Run All" reproduces the same training set.
training_data = np.random.default_rng(algorithm_globals.random_seed).multivariate_normal(
    mean=[0., 0.], cov=[[1, 0], [0, 1]], size=1000, check_valid='warn', tol=1e-8, method='svd'
)

# Define minimal and maximal values for the training data:
# the 5th / 95th percentile of every dimension.
bounds_min = np.percentile(training_data, 5, axis=0)
bounds_max = np.percentile(training_data, 95, axis=0)
# One [min, max] pair per data dimension
bounds = [[lower, upper] for lower, upper in zip(bounds_min, bounds_max)]

# Pre-processing, i.e., gridding.
# NOTE: `training_data` is rebound to the discretized samples here; the raw draws
# are no longer available afterwards.
(
    training_data,
    data_grid,
    grid_elements,
    prob_data,
) = discretize_and_truncate(
    training_data,
    np.array(bounds),
    data_dim,
    return_data_grid_elements=True,
    return_prob=True,
    prob_non_zero=True,
)

# Define the training batch size
batch_size = 100
# drop_last=True keeps every batch at exactly `batch_size` samples
dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, drop_last=True)

### Specify Backend

In [122]:
# declare quantum instance
backend = Aer.get_backend('aer_simulator')
# The number of shots is tied to `batch_size` (presumably so that one generator call
# yields as many samples as a training batch -- the discriminator losses compare both
# against label tensors of the batch length).
# Simulator and transpiler are seeded so that repeated runs give the same results.
qi = QuantumInstance(
    backend,
    shots=batch_size,
    seed_simulator=algorithm_globals.random_seed,
    seed_transpiler=algorithm_globals.random_seed,
)

### Definition of quantum generator and the respective gradient

In [123]:
def generator_(qnn: QuantumCircuit,
               parameters: Union[ParameterVector, ParameterExpression, List[ParameterExpression]]) -> TorchConnector:
    """
    Wrap a parameterized circuit as a sampling quantum generator usable from PyTorch.

    Relies on the module-level `qi` (quantum instance) and `grid_elements`
    (lookup used to interpret each measured sample).

    Args:
        qnn: Quantum neural network ansatz given as a quantum circuit.
        parameters: The parameters of the quantum neural network which are trained.
    Returns:
        Quantum neural network compatible with PyTorch
    """
    def _to_grid_point(sample_index):
        # Translate measured sample indices into the corresponding grid elements
        return grid_elements[sample_index]

    sampling_network = CircuitQNN(
        qnn,
        input_params=[],           # the generator takes no classical input
        weight_params=parameters,
        quantum_instance=qi,
        sampling=True,
        sparse=False,
        input_gradients=True,
        interpret=_to_grid_point,
    )
    # The Qiskit TorchConnector exposes the network as a PyTorch module
    torch_generator = TorchConnector(sampling_network)
    return torch_generator

def generator_grad(qnn: QuantumCircuit,
                   parameters: Union[ParameterVector, ParameterExpression, List[ParameterExpression]],
                   param_values: Iterable,
                   grad_method: str = 'param_shift') -> Iterable:
    """
    Evaluate the custom generator gradient at the given parameter values.

    The gradient is evaluated on the module-level quantum instance `qi`.

    Args:
        qnn: Quantum neural network ansatz given as a quantum circuit.
        parameters: The parameters of the quantum neural network which are trained.
        param_values: The current values of the quantum neural network parameters.
        grad_method: Method used to compute the gradients {'param_shift', 'lin_comb', 'fin_diff'}
    Returns:
        List of gradients for the sampling probabilities of the quantum neural network.
    """
    # Build a callable that maps parameter values to gradient values ...
    evaluate_gradient = Gradient(grad_method=grad_method).gradient_wrapper(
        StateFn(qnn), parameters, backend=qi
    )
    # ... and evaluate it right away, handing the result back as a plain list
    return evaluate_gradient(param_values).tolist()

### Definition of classical discriminator

In [124]:
class Discriminator(nn.Module):
    """
    Classical feed-forward discriminator: Linear -> LeakyReLU -> Linear -> LeakyReLU
    -> Linear -> Sigmoid, producing one value in (0, 1) per sample.
    """

    def __init__(self, input_size: Optional[int] = None):
        """
        Args:
            input_size: Number of features per sample. If None (default), the number
                of entries of the module-level `data_dim` is used.
        """
        super().__init__()

        if input_size is None:
            # Default behaviour: one input feature per data dimension
            input_size = len(data_dim)

        # Attribute names are part of the state_dict keys and are therefore kept as is
        self.Linear_in = nn.Linear(input_size, 51)
        # A single activation module is shared by both hidden layers
        self.Leaky_ReLU = nn.LeakyReLU(0.2, inplace=True)
        self.Linear51 = nn.Linear(51, 26)
        self.Linear26 = nn.Linear(26, 1)
        self.Sigmoid = nn.Sigmoid()

    def forward(self, input: Tensor) -> Tensor:
        """
        Args:
            input: Batch of samples whose last dimension has `input_size` features.
        Returns:
            Tensor with one sigmoid output per sample (last dimension of size 1).
        """
        x = self.Linear_in(input)
        x = self.Leaky_ReLU(x)
        x = self.Linear51(x)
        x = self.Leaky_ReLU(x)
        x = self.Linear26(x)
        x = self.Sigmoid(x)
        return x

### Definition of the quantum neural network ansatz

In [125]:
# Trainable part: RY rotation layers with circular CX entanglement, repeated 3 times,
# acting on one register that holds the qubits of all data dimensions.
ansatz = TwoLocal(
    sum(data_dim),
    rotation_blocks="ry",
    entanglement_blocks="cx",
    entanglement="circular",
    reps=3,
)

# Full generator circuit: a Hadamard on every qubit followed by the ansatz
qnn = QuantumCircuit(ansatz.num_qubits)
qnn.h(qnn.qubits)
qnn = qnn.compose(ansatz)

### Definition of the loss functions

In [126]:
# Loss function
g_loss_fun = nn.BCELoss()
d_loss_fun = nn.BCELoss()

### Evaluation of custom gradients for the generator BCE loss function

In [127]:
def g_loss_fun_grad(qnn: QuantumCircuit,
                    parameters: Union[ParameterVector, ParameterExpression, List[ParameterExpression]],
                    param_values: Iterable,
                    discriminator_: nn.Module,
                    grad_method: str = 'param_shift') -> Iterable:
    """
    Custom gradient of the generator loss function considering the custom gradients of the quantum generator.

    For every trainable parameter, the discriminator is evaluated on the grid elements
    with a stored (non-zero) probability gradient, and the binary cross-entropy against
    the label 1 is weighted with these probability gradients.

    Relies on the module-level `grid_elements` and on `generator_grad`.

    Args:
        qnn: Quantum neural network ansatz given as a quantum circuit.
        parameters: The parameters of the quantum neural network which are trained.
        param_values: The current values of the quantum neural network parameters.
        discriminator_: Classical neural network representing the discriminator.
        grad_method: Method used to compute the gradients {'param_shift', 'lin_comb', 'fin_diff'}
    Returns:
        1D tensor of gradient values, i.e., the gradients of the loss function w.r.t. the
        quantum neural network parameters. The tensor is not attached to an autograd graph.
    """
    grads = generator_grad(qnn, parameters, param_values, grad_method=grad_method)
    loss_grad = []
    # The result is meant to be assigned to `param.grad`, so no autograd graph through
    # the discriminator has to be recorded.
    with torch.no_grad():
        for grad in grads:
            # Presumably a sparse row holding d p(x) / d theta for one parameter;
            # COO format exposes the stored column indices and values directly.
            cx = grad[0].tocoo()
            if len(cx.data) == 0:
                # No stored entries: this parameter does not change any probability,
                # so its loss gradient is zero (0-dim tensor, like the BCE output).
                loss_grad.append(Tensor([0.0]).sum())
                continue
            # Grid points belonging to the stored column indices, as one array so the
            # tensor is not built element-wise from a list of ndarrays.
            samples = Tensor(np.array([grid_elements[index] for index in cx.col]))
            # The generator wants its samples to be classified as real (label 1)
            target = Tensor(len(cx.col), 1).fill_(1.0)
            weight = Tensor([[prob_grad] for prob_grad in cx.data])
            # NOTE(review): the default 'mean' reduction divides by the number of stored
            # entries, which may differ between parameters; a chain rule over the
            # expected loss would suggest reduction='sum' -- confirm the intended scaling.
            loss_grad.append(
                F.binary_cross_entropy(discriminator_(samples), target, weight=weight)
            )
    return stack(loss_grad)

### Definition of the optimizers

In [128]:
# Initialize generator and discriminator
generator = generator_(qnn, ansatz.ordered_parameters)
discriminator = Discriminator()

lr=0.001 #learning rate
b1=0.7 #first momentum parameter
b2=0.999 #second momentum parameter
n_epochs=100 #number of training epochs

#optimizer for the generator
optimizer_G = Adam(generator.parameters(), lr=lr, betas=(b1, b2))
#optimizer for the discriminator
optimizer_D = Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

### Training

In [129]:
# Alternating GAN training: per batch, first one discriminator update, then one
# generator update that uses the custom (quantum) gradient of the generator loss.
for epoch in range(n_epochs):
    for i, data in enumerate(dataloader):

        # Adversarial ground truths. Plain tensors are sufficient: they do not require
        # gradients by default and the `Variable` wrapper is deprecated.
        valid = Tensor(data.size(0), 1).fill_(1.0)
        fake = Tensor(data.size(0), 1).fill_(0.0)

        # Configure input
        real_data = data.float()
        # Generate a batch of samples with the quantum generator
        gen_data = generator()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples.
        # The generated samples are detached: this step must only update the
        # discriminator, and the graph does not have to be retained afterwards.
        real_loss = d_loss_fun(discriminator(real_data), valid)
        fake_loss = d_loss_fun(discriminator(gen_data.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Loss measures generator's ability to fool the discriminator. It is only
        # reported; the gradient used for the update is computed separately below,
        # so no backward pass through this value is needed.
        with torch.no_grad():
            g_loss = g_loss_fun(discriminator(gen_data), valid)

        # Custom gradient of the generator loss w.r.t. the circuit parameters
        g_loss_grad = g_loss_fun_grad(
            qnn, ansatz.ordered_parameters, generator.weight.detach().numpy(), discriminator
        )
        # Hand the custom gradient to the optimizer by setting `.grad` directly
        for param in generator.parameters():
            param.grad = g_loss_grad.detach()
        optimizer_G.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        # TODO: every fixed number of `batches_done`, store or print the generator
        # state, or evaluate the relative entropy.

[Epoch 0/100] [Batch 0/8] [D loss: 0.695544] [G loss: 0.786838]
[Epoch 0/100] [Batch 1/8] [D loss: 0.694922] [G loss: 0.783705]
[Epoch 0/100] [Batch 2/8] [D loss: 0.694174] [G loss: 0.790349]
[Epoch 0/100] [Batch 3/8] [D loss: 0.689924] [G loss: 0.794382]
[Epoch 0/100] [Batch 4/8] [D loss: 0.691398] [G loss: 0.793176]
[Epoch 0/100] [Batch 5/8] [D loss: 0.689089] [G loss: 0.796654]
[Epoch 0/100] [Batch 6/8] [D loss: 0.681667] [G loss: 0.805799]
[Epoch 0/100] [Batch 7/8] [D loss: 0.689685] [G loss: 0.794579]
[Epoch 1/100] [Batch 0/8] [D loss: 0.679661] [G loss: 0.803387]
[Epoch 1/100] [Batch 1/8] [D loss: 0.688019] [G loss: 0.789914]
[Epoch 1/100] [Batch 2/8] [D loss: 0.689248] [G loss: 0.787703]
[Epoch 1/100] [Batch 3/8] [D loss: 0.690922] [G loss: 0.797615]
[Epoch 1/100] [Batch 4/8] [D loss: 0.681390] [G loss: 0.787319]
[Epoch 1/100] [Batch 5/8] [D loss: 0.668715] [G loss: 0.816263]
[Epoch 1/100] [Batch 6/8] [D loss: 0.692583] [G loss: 0.786402]
[Epoch 1/100] [Batch 7/8] [D loss: 0.681

KeyboardInterrupt: 

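### Alternative approach

Instead of overwriting `param.grad` by hand in the training loop, the custom generator gradient could be wrapped in a custom `torch.autograd.Function`. The following class illustrates the pattern with a simple polynomial whose backward pass is implemented manually.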
class LegendrePolynomial3(torch.autograd.Function):
    """
    Example of a custom autograd Function: subclass torch.autograd.Function and
    provide static forward and backward passes that operate on Tensors.

    Evaluates 0.5 * (5 x^3 - 3 x) in the forward pass and supplies the matching
    derivative 1.5 * (5 x^2 - 1) in the backward pass.
    """

    @staticmethod
    def forward(ctx, input):
        """
        Forward pass: takes the input Tensor and returns the output Tensor.
        The input is stashed on the context object `ctx` via
        ctx.save_for_backward so that the backward pass can reuse it.
        """
        ctx.save_for_backward(input)
        cubic_term = 5 * input ** 3
        linear_term = 3 * input
        return 0.5 * (cubic_term - linear_term)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: takes the gradient of the loss with respect to the output
        and returns the gradient of the loss with respect to the input.
        """
        (saved_input,) = ctx.saved_tensors
        # Chain rule: upstream gradient times the local derivative
        scaled_upstream = grad_output * 1.5
        return scaled_upstream * (5 * saved_input ** 2 - 1)