# Deep Convolutional Generative Adversarial Network

This example builds a DCGAN in Equinox. A DCGAN (Deep Convolutional Generative Adversarial Network) is a Generative Adversarial Network (GAN) whose generator and discriminator are both built from convolutional layers.

The implementation is based on the paper [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434).

Paper authors: Alec Radford (alec@indico.io), Luke Metz (luke@indico.io), and Soumith Chintala (soumith@fb.com)

In [8]:
# imports

import jax

# import jax.numpy as jnp
# import optax
# import torch
from typing import Union

import equinox as eqx

In [9]:
# Hyperparameters

In [11]:
# Generator Model


class Generator(eqx.Module):
    """DCGAN generator: maps a latent sample to an image.

    Layer stack (all transposed convolutions use a 4x4 kernel and no bias):

    * stride 1, padding 0: ``input_shape`` -> ``8 * ngf`` channels
    * stride 2, padding 1: ``8 * ngf`` -> ``4 * ngf``
    * stride 2, padding 1: ``4 * ngf`` -> ``2 * ngf``
    * stride 2, padding 1: ``2 * ngf`` -> ``ngf``
    * stride 2, padding 1: ``ngf`` -> ``output_shape[2]``, then ``tanh``

    Each of the first four convolutions is followed by BatchNorm and ReLU.
    For a latent input of spatial size 1x1 the spatial size goes
    1 -> 4 -> 8 -> 16 -> 32 -> 64, so the output is ``(output_shape[2], 64, 64)``
    (channel-first, as Equinox convolutions expect unbatched ``(C, H, W)``).

    Because BatchNorm keeps running statistics, build and call the model with
    Equinox's stateful API, vmapping over a named ``"batch"`` axis::

        model, state = eqx.nn.make_with_state(Generator)(key=key)
        batched = jax.vmap(
            model, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
        )
        images, state = batched(z, state)  # z: (batch, input_shape, 1, 1)
    """

    # Hidden transposed-conv / BatchNorm / ReLU stack.
    conv_layers: list[eqx.Module]
    # Final transposed convolution to image channels, followed by tanh.
    output_layers: list[eqx.Module]

    def __init__(
        self,
        input_shape: int = 100,
        output_shape: tuple[int, int, int] = (64, 64, 3),
        *,
        key: Union[jax.Array, None] = None,
    ):
        """Build the generator.

        Args:
            input_shape: Number of channels of the latent input.
            output_shape: ``output_shape[2]`` is the number of output image
                channels. ``output_shape[0]`` is reused as the base channel
                width ``ngf``. ``output_shape[1]`` is not used, and the
                spatial output size does not depend on this argument.
            key: PRNG key for weight initialisation. Defaults to
                ``jax.random.PRNGKey(0)`` so argument-free construction
                works; pass distinct keys for independent initialisations.
        """
        if key is None:
            key = jax.random.PRNGKey(0)
        k1, k2, k3, k4, k5 = jax.random.split(key, 5)

        # Base feature width ("ngf"). It is taken from output_shape[0] to
        # preserve the original channel counts.
        ngf = output_shape[0]

        self.conv_layers = [
            *self._upsample_block(input_shape, ngf * 8, stride=1, padding=0, key=k1),
            *self._upsample_block(ngf * 8, ngf * 4, stride=2, padding=1, key=k2),
            *self._upsample_block(ngf * 4, ngf * 2, stride=2, padding=1, key=k3),
            *self._upsample_block(ngf * 2, ngf, stride=2, padding=1, key=k4),
        ]

        self.output_layers = [
            eqx.nn.ConvTranspose2d(
                in_channels=ngf,
                out_channels=output_shape[2],
                kernel_size=4,
                stride=2,
                padding=1,
                use_bias=False,
                key=k5,
            ),
            eqx.nn.Lambda(jax.nn.tanh),
        ]

    @staticmethod
    def _upsample_block(
        in_channels: int, out_channels: int, *, stride: int, padding: int, key: jax.Array
    ) -> list[eqx.Module]:
        """Return [ConvTranspose2d (4x4, no bias), BatchNorm, ReLU]."""
        return [
            eqx.nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=4,
                stride=stride,
                padding=padding,
                use_bias=False,
                key=key,
            ),
            # axis_name must match the axis_name passed to jax.vmap at call time.
            eqx.nn.BatchNorm(input_size=out_channels, axis_name="batch"),
            eqx.nn.Lambda(jax.nn.relu),
        ]

    def __call__(
        self, x: jax.Array, state: eqx.nn.State
    ) -> tuple[jax.Array, eqx.nn.State]:
        """Generate one image from one (unbatched) latent sample.

        Args:
            x: Latent input of shape ``(input_shape, H, W)``. Use ``H = W = 1``
                for the 64x64 output described in the class docstring.
            state: BatchNorm state, e.g. from ``eqx.nn.make_with_state``.

        Returns:
            ``(image, new_state)``. ``new_state`` carries the updated BatchNorm
            statistics.
        """
        for layer in (*self.conv_layers, *self.output_layers):
            # BatchNorm is stateful in Equinox: it consumes and returns state.
            if isinstance(layer, eqx.nn.BatchNorm):
                x, state = layer(x, state)
            else:
                x = layer(x)
        return x, state

In [5]:
# Discriminator Model


class Discriminator(eqx.Module):
    """DCGAN discriminator (placeholder).

    No fields or ``__call__`` are defined yet, so instances hold no
    parameters and cannot be applied to inputs.
    """