# Generative Models in JAX Framework

## Table of contents:

* GANs
*

## Generative Adversarial Networks:

Model architectures implemented:
* Vanilla GAN
* DCGAN

# Installing Dependencies

## Dataset Loaders

In [None]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


# Fashion-MNIST training split. Files are stored under ./data (downloaded if
# they are not already there) and every sample is passed through ToTensor.
training_data = datasets.FashionMNIST(
    root="data", train=True, transform=ToTensor(), download=True)

# Fashion-MNIST held-out split, stored and transformed the same way.
test_data = datasets.FashionMNIST(
    root="data", train=False, transform=ToTensor(), download=True)

## Helper Functions

In [None]:
@partial(jax.pmap, static_broadcasted_argnums=(1, 2))
def create_state(rng, model_cls, input_shape):
  """Build a fresh training state for a model class.

  The function is pmapped: `model_cls` and `input_shape` are static,
  broadcast arguments, while `rng` is mapped over its leading axis.

  Args:
    rng: PRNG key passed to `model.init`.
    model_cls: Zero-argument callable returning the model to initialise.
    input_shape: Shape of the all-ones dummy input used for initialisation.

  Returns:
    The result of `TrainState.create`, holding the model's apply function,
    an Adam optimiser (lr=0.0002, b1=0.5, b2=0.999) and the initial
    'params' and 'batch_stats' collections.
  """
  net = model_cls()

  # Initialise the variable collections by running the model on dummy data.
  init_vars = net.init(rng, jnp.ones(input_shape))
  optimizer = optax.adam(0.0002, b1=0.5, b2=0.999)

  return TrainState.create(
      apply_fn=net.apply,
      tx=optimizer,
      params=init_vars['params'],
      batch_stats=init_vars['batch_stats'],
  )


@jax.pmap
def sample_from_generator(generator_state, input_noise):
  """Run the generator on `input_noise` with `train=False`.

  Args:
    generator_state: State object exposing `apply_fn`, `params` and
      `batch_stats`.
    input_noise: Input forwarded to the generator's apply function.

  Returns:
    Whatever `apply_fn` returns; `mutable=False` means no variable
    collection is updated by the call.
  """
  variables = {
      'params': generator_state.params,
      'batch_stats': generator_state.batch_stats,
  }
  return generator_state.apply_fn(
      variables, input_noise, train=False, mutable=False)

# Vanilla GAN:

In [None]:
class vanilla_gan(nn.Module):
    """Placeholder for the vanilla GAN model; no layers are defined yet.

    NOTE(review): assumes `nn` is `flax.linen` (the neighbouring code uses
    `nn.ConvTranspose` / `nn.BatchNorm`) -- confirm. Flax modules are
    dataclasses, so hyper-parameters belong in class-level fields rather than
    in a hand-written `__init__`; none are declared yet.
    """

    def __call__(self, *args, **kwargs):
        """Forward pass of the model (not implemented yet).

        Raises:
            NotImplementedError: Always, until the architecture is written.
        """
        raise NotImplementedError(
            "vanilla_gan.__call__ has not been implemented yet.")
        
class Generator(nn.Module):
    """Draft Flax generator; the layer stack has not been written yet.

    NOTE(review): assumes `nn` is `flax.linen` and that `partial` and
    `normal_init` are available in the notebook -- confirm.
    """
    # Base feature width of the network.
    features: int = 64
    # dtype handed to every layer factory below.
    dtype: type = jnp.float32

    @nn.compact
    def __call__(self, x, train: bool = True):
        """Map the input `x` to generated data (not implemented yet).

        Args:
            x: Generator input.
            train: When True, batch norm uses the current batch statistics;
                when False, it uses the running averages.

        Raises:
            NotImplementedError: Always, until the layer stack is written.
        """
        # Layer factories with the shared initialiser / dtype settings.
        conv_transpose = partial(nn.ConvTranspose, padding='VALID',
                                 kernel_init=normal_init(0.02), dtype=self.dtype)
        batch_norm = partial(nn.BatchNorm, use_running_average=not train, axis=-1,
                             scale_init=normal_init(0.02), dtype=self.dtype)

        # TODO: assemble the network from `conv_transpose` / `batch_norm`.
        raise NotImplementedError(
            "Generator.__call__ has not been implemented yet.")

In [None]:
class Generator(hk.Module):
  """Haiku generator: a linear projection followed by transposed convs."""

  def __init__(self, output_channels=(32, 1), name=None):
    """Stores the channel count of each transposed-convolution layer.

    Args:
      output_channels: Output channels of each successive
        `hk.Conv2DTranspose` layer.
      name: Optional Haiku module name.
    """
    super().__init__(name=name)
    self.output_channels = output_channels

  def __call__(self, x):
    """Maps noise latents to images.

    Args:
      x: Latents; the first axis is kept and the rest is projected.

    Returns:
      The tanh of the last transposed-convolution output.
    """
    # Project to 7 * 7 * 64 features and view them as a 7x7 map with 64
    # channels, keeping the leading axis untouched.
    hidden = hk.Linear(7 * 7 * 64)(x)
    leading = hidden.shape[:1]
    hidden = jnp.reshape(hidden, leading + (7, 7, 64))

    # Each stage is ReLU followed by a 5x5, stride-2 transposed convolution.
    for channels in self.output_channels:
      upsample = hk.Conv2DTranspose(output_channels=channels,
                                    kernel_shape=[5, 5],
                                    stride=2,
                                    padding="SAME")
      hidden = upsample(jax.nn.relu(hidden))

    # We use a tanh to ensure that the generated samples are in the same
    # range as the data.
    return jnp.tanh(hidden)

In [None]:
class Discriminator(hk.Module):
  """Haiku discriminator: a stack of convolutions and a 2-way classifier."""

  def __init__(self,
               output_channels=(8, 16, 32, 64, 128),
               strides=(2, 1, 2, 1, 2),
               name=None):
    """Stores the per-layer convolution settings.

    Args:
      output_channels: Output channels of each convolution layer.
      strides: Stride of each convolution layer, paired with
        `output_channels` by position.
      name: Optional Haiku module name.
    """
    super().__init__(name=name)
    self.output_channels = output_channels
    self.strides = strides

  def __call__(self, x):
    """Classifies images as real or fake.

    Args:
      x: Batch of images.

    Returns:
      Logits with two entries per example: 0 = input is fake,
      1 = input is real.
    """
    features = x
    # Each stage is a 5x5 convolution followed by leaky ReLU (slope 0.2).
    for channels, stride in zip(self.output_channels, self.strides):
      conv = hk.Conv2D(output_channels=channels,
                       kernel_shape=[5, 5],
                       stride=stride,
                       padding="SAME")
      features = jax.nn.leaky_relu(conv(features), negative_slope=0.2)

    flat = hk.Flatten()(features)
    return hk.Linear(2)(flat)

In [None]:
class GAN:
  """A basic GAN.

  Bundles the generator / discriminator Haiku transforms, their optimizers,
  the two losses and a jitted update step that trains both networks.
  """

  def __init__(self, num_latents):
    """Builds the network transforms and optimizers.

    Args:
      num_latents: Size of the latent vector fed to the generator.
    """
    self.num_latents = num_latents

    # Define the Haiku network transforms.
    # We don't use BatchNorm so we don't use `with_state`.
    self.gen_transform = hk.without_apply_rng(
        hk.transform(lambda *args: Generator()(*args)))
    self.disc_transform = hk.without_apply_rng(
        hk.transform(lambda *args: Discriminator()(*args)))

    # Build the optimizers.
    self.optimizers = GANTuple(gen=optax.adam(1e-4, b1=0.5, b2=0.9),
                               disc=optax.adam(1e-4, b1=0.5, b2=0.9))

  @functools.partial(jax.jit, static_argnums=0)
  def initial_state(self, rng, batch):
    """Returns the initial parameters and optimizer states.

    Args:
      rng: PRNG key, split between generator and discriminator init.
      batch: Example batch of real data used to initialise the discriminator;
        its leading dimension sets the dummy latent batch size.

    Returns:
      A `GANState` holding the initial params and optimizer states.
    """
    # Generate dummy latents for the generator.
    dummy_latents = jnp.zeros((batch.shape[0], self.num_latents))

    # Get initial network parameters.
    rng_gen, rng_disc = jax.random.split(rng)
    params = GANTuple(gen=self.gen_transform.init(rng_gen, dummy_latents),
                      disc=self.disc_transform.init(rng_disc, batch))
    # The method is jitted, so these prints run while it is being traced.
    print("Generator: \n\n{}\n".format(tree_shape(params.gen)))
    print("Discriminator: \n\n{}\n".format(tree_shape(params.disc)))

    # Initialize the optimizers.
    opt_state = GANTuple(gen=self.optimizers.gen.init(params.gen),
                         disc=self.optimizers.disc.init(params.disc))
    return GANState(params=params, opt_state=opt_state)

  def sample(self, rng, gen_params, num_samples):
    """Generates images from noise latents.

    Args:
      rng: PRNG key used to draw the latents.
      gen_params: Generator parameters.
      num_samples: Number of latent vectors to draw.

    Returns:
      The generator output for `num_samples` standard-normal latents.
    """
    latents = jax.random.normal(rng, shape=(num_samples, self.num_latents))
    return self.gen_transform.apply(gen_params, latents)

  def gen_loss(self, gen_params, rng, disc_params, batch):
    """Generator loss.

    Args:
      gen_params: Generator parameters (the argument being differentiated).
      rng: PRNG key used to sample the fake batch.
      disc_params: Discriminator parameters.
      batch: Real batch; only its leading dimension is used, to size the
        fake batch.

    Returns:
      Mean negative log-probability that the discriminator assigns to the
      "real" class for the generated samples.
    """
    # Sample from the generator.
    fake_batch = self.sample(rng, gen_params, num_samples=batch.shape[0])

    # Evaluate using the discriminator. Recall class 1 is real.
    fake_logits = self.disc_transform.apply(disc_params, fake_batch)
    # log_softmax equals log(softmax(.)) but stays finite when the "real"
    # probability underflows to zero for very confident discriminators.
    loss = -jax.nn.log_softmax(fake_logits)[:, 1]

    return jnp.mean(loss)

  def disc_loss(self, disc_params, rng, gen_params, batch):
    """Discriminator loss.

    Args:
      disc_params: Discriminator parameters (the argument being
        differentiated).
      rng: PRNG key used to sample the fake batch.
      gen_params: Generator parameters.
      batch: Batch of real data.

    Returns:
      Mean of the summed cross-entropy losses on the real and fake batches.
    """
    # Sample from the generator.
    fake_batch = self.sample(rng, gen_params, num_samples=batch.shape[0])

    # For efficiency we process both the real and fake data in one pass.
    real_and_fake_batch = jnp.concatenate([batch, fake_batch], axis=0)
    real_and_fake_logits = self.disc_transform.apply(disc_params,
                                                     real_and_fake_batch)
    real_logits, fake_logits = jnp.split(real_and_fake_logits, 2, axis=0)

    # Class 1 is real.
    real_labels = jnp.ones((batch.shape[0],), dtype=jnp.int32)
    real_loss = sparse_softmax_cross_entropy(real_logits, real_labels)

    # Class 0 is fake.
    fake_labels = jnp.zeros((batch.shape[0],), dtype=jnp.int32)
    fake_loss = sparse_softmax_cross_entropy(fake_logits, fake_labels)

    return jnp.mean(real_loss + fake_loss)

  @functools.partial(jax.jit, static_argnums=0)
  def update(self, rng, gan_state, batch):
    """Performs a parameter update.

    Args:
      rng: PRNG key; split into a carried key and one key per network.
      gan_state: Current `GANState` (params and optimizer states).
      batch: Batch of real data.

    Returns:
      A tuple `(rng, gan_state, log)` with the carried PRNG key, the updated
      state and a dict holding "gen_loss" and "disc_loss".
    """
    rng, rng_gen, rng_disc = jax.random.split(rng, 3)

    # Update the discriminator.
    disc_loss, disc_grads = jax.value_and_grad(self.disc_loss)(
        gan_state.params.disc,
        rng_disc,
        gan_state.params.gen,
        batch)
    disc_update, disc_opt_state = self.optimizers.disc.update(
        disc_grads, gan_state.opt_state.disc)
    disc_params = optax.apply_updates(gan_state.params.disc, disc_update)

    # Update the generator. Its loss is evaluated against the discriminator
    # parameters from before this step's discriminator update.
    gen_loss, gen_grads = jax.value_and_grad(self.gen_loss)(
        gan_state.params.gen,
        rng_gen,
        gan_state.params.disc,
        batch)
    gen_update, gen_opt_state = self.optimizers.gen.update(
        gen_grads, gan_state.opt_state.gen)
    gen_params = optax.apply_updates(gan_state.params.gen, gen_update)

    params = GANTuple(gen=gen_params, disc=disc_params)
    opt_state = GANTuple(gen=gen_opt_state, disc=disc_opt_state)
    gan_state = GANState(params=params, opt_state=opt_state)
    log = {
        "gen_loss": gen_loss,
        "disc_loss": disc_loss,
    }

    return rng, gan_state, log


# Training GAN