<a href="https://colab.research.google.com/github/shere-khan/machine_learning/blob/master/WassersteinGANPytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Pick the compute device: CUDA when PyTorch reports one, CPU otherwise.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

# Optimisation settings.
LEARNING_RATE = 5e-5
BATCH_SIZE = 64
NUM_EPOCHS = 5

# Image settings (also used by the preprocessing pipeline).
IMAGE_SIZE = 64
CHANNELS_IMG = 1

# Model sizes: latent dimension and base feature widths.
Z_DIM = 128
FEATURES_CRITIC = 64
FEATURES_GEN = 64

# WGAN-specific settings; presumably critic steps per generator step and the
# weight-clipping bound -- confirm against the training loop.
CRITIC_ITERATIONS = 5
WEIGHT_CLIP = 0.01

\begin{align}
\underset{\|f\|_L \leq 1}{\max}\:\mathbb{E}_{x\sim\mathbb{P}_r}[f(x)] - \mathbb{E}_{x\sim\mathbb{P}_\theta}[f(x)]
\end{align}

The discriminator (critic) wants to separate these two expectations (maximize the difference), while the generator wants to bring them closer together (minimize it).

In [None]:
# Preprocessing pipeline: resize to IMAGE_SIZE, convert to a tensor, then
# normalise every channel with mean 0.5 and std 0.5.
# NOTE(review): this rebinds `transforms`, hiding the torchvision.transforms
# module imported at the top; re-running this cell would then fail because the
# Compose object has no `Compose` attribute. The name is kept because the
# dataset cell reads it.
transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[.5] * CHANNELS_IMG,
        std=[.5] * CHANNELS_IMG,
    ),
])

In [None]:
# MNIST stored under dataset/ (downloaded when requested via download=True),
# with the preprocessing pipeline applied to each image.
dataset = datasets.MNIST(
    root='dataset/',
    download=True,
    transform=transforms,
)

# Shuffled mini-batches of BATCH_SIZE samples.
loader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

In [None]:
class Discriminator(nn.Module):
  """Convolutional critic that maps an image to a single unbounded score.

  The network is five strided convolutions. The first four halve the spatial
  size each time; the last one (kernel 4, stride 2, no padding) collapses a
  4x4 feature map to 1x1, so a 64x64 input yields an output of shape
  (N, 1, 1, 1). There is no final activation, so the score is not squashed
  into a fixed range.

  Args:
    channels_img: Number of channels in the input images.
    features_d: Base channel width; hidden layers use 1x, 2x, 4x and 8x this.
  """

  def __init__(self, channels_img, features_d):
    super(Discriminator, self).__init__()
    self.disc = nn.Sequential(
        # First layer has no normalisation, only conv + LeakyReLU.
        nn.Conv2d(
            channels_img, features_d, kernel_size=4, stride=2, padding=1
        ),
        nn.LeakyReLU(.2),
        self._block(features_d, features_d * 2, 4, 2, 1),
        self._block(features_d * 2, features_d * 4, 4, 2, 1),
        self._block(features_d * 4, features_d * 8, 4, 2, 1),
        # Reduce the remaining 4x4 map to one value per sample.
        nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0)
    )

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Return a Conv2d -> InstanceNorm2d -> LeakyReLU(0.2) stage.

    The convolution has no bias; the affine InstanceNorm2d that follows
    supplies a learnable per-channel shift instead.
    """
    return nn.Sequential(
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False
        ),
        nn.InstanceNorm2d(out_channels, affine=True),
        nn.LeakyReLU(.2)
    )

  # This must sit at class level. It was previously nested inside `_block`
  # after its return statement, so the module had no usable forward pass.
  def forward(self, x):
    """Run the critic on a batch of images `x` and return the scores."""
    return self.disc(x)

In [None]:
class Generator(nn.Module):
  """Transposed-convolution generator that maps noise to 64x64 images.

  Input is a noise tensor of shape (N, channels_noise, 1, 1). The first block
  expands it to 4x4; each later stage doubles the spatial size
  (4 -> 8 -> 16 -> 32 -> 64). The final Tanh keeps outputs in [-1, 1].

  Args:
    channels_noise: Number of channels of the input noise tensor.
    channels_img: Number of channels of the generated images.
    features_g: Base channel width; hidden layers use 16x, 8x, 4x and 2x this.
  """

  def __init__(self, channels_noise, channels_img, features_g):
    super(Generator, self).__init__()
    self.net = nn.Sequential(
        # 1x1 -> 4x4
        self._block(channels_noise, features_g * 16, 4, 1, 0),
        # Each following block takes the previous block's output channels
        # (multiples of features_g, not of channels_noise) and doubles the
        # spatial size via stride 2 / padding 1.
        self._block(features_g * 16, features_g * 8, 4, 2, 1),  # 8x8
        self._block(features_g * 8, features_g * 4, 4, 2, 1),   # 16x16
        self._block(features_g * 4, features_g * 2, 4, 2, 1),   # 32x32
        nn.ConvTranspose2d(
            features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
        ),  # 64x64
        nn.Tanh(),
    )

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Return a ConvTranspose2d -> BatchNorm2d -> ReLU stage.

    The transposed convolution has no bias; the BatchNorm2d that follows
    supplies a learnable per-channel shift instead.
    """
    return nn.Sequential(
        nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False
        ),
        nn.BatchNorm2d(out_channels),
        nn.ReLU()
    )

  def forward(self, x):
    """Generate images from a batch of noise `x`."""
    return self.net(x)