<a href="https://colab.research.google.com/github/srirambandi/GAN/blob/master/gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# when running in colab notebooks, first install library
!pip install import-ai
# upload respective dataset manually from examples directory of the library or download as below
!apt install subversion
!svn checkout https://github.com/srirambandi/ai/trunk/examples/MNIST

Collecting import-ai
  Downloading https://files.pythonhosted.org/packages/16/c9/61a99b75a3ccd70ddd207ea12e28b9baa17feba9fffb2d7181d03d1b8147/import_ai-1.3.11-py3-none-any.whl
Collecting graphviz>=0.14
  Downloading https://files.pythonhosted.org/packages/83/cc/c62100906d30f95d46451c15eb407da7db201e30f42008f3643945910373/graphviz-0.14-py2.py3-none-any.whl
Installing collected packages: graphviz, import-ai
  Found existing installation: graphviz 0.10.1
    Uninstalling graphviz-0.10.1:
      Successfully uninstalled graphviz-0.10.1
Successfully installed graphviz-0.14 import-ai-1.3.11
Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following package was automatically installed and is no longer required:
  libnvidia-common-440
Use 'apt autoremove' to remove it.
The following additional packages will be installed:
  libapr1 libaprutil1 libserf-1-1 libsvn1
Suggested packages:
  db5.3-util libapache2-mod-svn subversion-tools
The following 

In [2]:
# Mount Google Drive at /content/drive; the training loop writes checkpoints under it.
import google.colab.drive as drive

drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [3]:
import os

import ai
import numpy as np

In [4]:
z_dim = 100
gf_dim = 64
df_dim = 64

In [5]:
ai.manual_seed(2357)

In [6]:
def data_generator(m):
    train_dict = np.load('MNIST/train.npy', allow_pickle=True)
    test_dict = np.load('MNIST/test.npy', allow_pickle=True)
    data = np.concatenate([train_dict.item()['data'], test_dict.item()['data']])
    data = data.transpose(1, 2, 0)   # making data batch-last
    data = data.reshape(1, *data.shape) / 255   # adding channel dimension and normalizing data
    epoch = 0
    
    while True:
        epoch += 1
        for batch in range(int(data.shape[-1] / m)):
            yield data[...,batch * m:(batch + 1) * m], epoch

In [7]:
class Generator(ai.Module):
    def __init__(self):
        self.g_fc = ai.Linear(z_dim, 8*gf_dim * 2 * 2)
        self.g_bn1 = ai.BatchNorm((8*gf_dim, 2, 2))
        self.g_deconv1 = ai.ConvTranspose2d(8*gf_dim, 4*gf_dim, kernel_size=5, stride=2, padding=2, a=1)
        self.g_bn2 = ai.BatchNorm((4*gf_dim, 4, 4))
        self.g_deconv2 = ai.ConvTranspose2d(4*gf_dim, 2*gf_dim, kernel_size=5, stride=2, padding=2, a=0)
        self.g_bn3 = ai.BatchNorm((2*gf_dim, 7, 7))
        self.g_deconv3 = ai.ConvTranspose2d(2*gf_dim, gf_dim, kernel_size=5, stride=2, padding=2, a=1)
        self.g_bn4 = ai.BatchNorm((gf_dim, 14, 14))
        self.g_deconv4 = ai.ConvTranspose2d(gf_dim, 1, kernel_size=5, stride=2, padding=2, a=1)
        
    def forward(self, z):
        o1 = ai.G.reshape(self.g_fc(z), (8*gf_dim, 2, 2))
        o2 = ai.G.relu(self.g_bn1(o1))
        o3 = ai.G.relu(self.g_bn2(self.g_deconv1(o2)))
        o4 = ai.G.relu(self.g_bn3(self.g_deconv2(o3)))
        o5 = ai.G.relu(self.g_bn4(self.g_deconv3(o4)))
        fake_image = ai.G.tanh(self.g_deconv4(o5))
        return fake_image

In [8]:
class Discriminator(ai.Module):
    def __init__(self):
        self.d_conv1 = ai.Conv2d(1, 64, kernel_size=5, stride=2, padding=2)
        self.d_conv2 = ai.Conv2d(64, 2*64, kernel_size=5, stride=2, padding=2)
        self.d_bn1 = ai.BatchNorm((2*64, 7, 7))
        self.d_conv3 = ai.Conv2d(2*64, 3*64, kernel_size=5, stride=2, padding=2)
        self.d_bn2 = ai.BatchNorm((3*64, 4, 4))
        self.d_conv4 = ai.Conv2d(3*64, 4*64, kernel_size=5, stride=2, padding=2)
        self.d_bn3 = ai.BatchNorm((4*64, 2, 2))
        self.d_fc = ai.Linear(1024, 1)
        
    def forward(self, image):
        o1 = ai.G.lrelu(self.d_conv1(image))
        o2 = ai.G.lrelu(self.d_bn1(self.d_conv2(o1)))
        o3 = ai.G.lrelu(self.d_bn2(self.d_conv3(o2)))
        o4 = ai.G.lrelu(self.d_bn3(self.d_conv4(o3)))
        o5 = self.d_fc(o4)
        return ai.G.sigmoid(o5)

In [9]:
generator = Generator()
discriminator = Discriminator()
# generator.load('/content/drive/My Drive/GAN/Generator.npy')
# discriminator.load('/content/drive/My Drive/GAN/Discriminator.npy')
print(generator)
print(discriminator)

Generator(
  g_fc: Linear(input_features=100, output_features=2048, bias=True)
  g_bn1: BatchNorm((512, 2, 2), axis=-1, momentum=0.9, bias=True)
  g_deconv1: ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), a=(1, 1), bias=True)
  g_bn2: BatchNorm((256, 4, 4), axis=-1, momentum=0.9, bias=True)
  g_deconv2: ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), a=(0, 0), bias=True)
  g_bn3: BatchNorm((128, 7, 7), axis=-1, momentum=0.9, bias=True)
  g_deconv3: ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), a=(1, 1), bias=True)
  g_bn4: BatchNorm((64, 14, 14), axis=-1, momentum=0.9, bias=True)
  g_deconv4: ConvTranspose2d(64, 1, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), a=(1, 1), bias=True)
)
Discriminator(
  d_conv1: Conv2d(1, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=True)
  d_conv2: Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=True)
  d_bn1: BatchNorm((128

In [10]:
lr = 0.0002
beta1 = 0.5
L = ai.Loss(loss_fn='BCELoss')
g_optim = ai.Optimizer(generator.parameters(), optim_fn='Adam', lr=lr, beta1=beta1)
d_optim = ai.Optimizer(discriminator.parameters(), optim_fn='Adam', lr=lr, beta1=beta1)

In [11]:
it, epoch = 0, 0
m = 64   # batch size
n_discriminator = 1   # number of descriminator updates per generator update

# real images data generator
data = data_generator(m)

sample_z = np.random.uniform(-1, 1, (z_dim, m))
# sampled_images = np.load('/content/drive/My Drive/GAN/sampled_images.npy')
sampled_images = None

In [12]:
def sampler(sampled_images):
    ai.G.grad_mode = False

    # generate images like real data
    fake_images = generator.forward(sample_z).data
    fake_images = (fake_images + 1.) / 2.

    if sampled_images is not None:
        sampled_images = np.concatenate([sampled_images, fake_images], axis=-1)
    else:
        sampled_images = fake_images
    
    ai.G.grad_mode = True

    return sampled_images

In [13]:
for it in range(10000):
    
    # freeze generator before optimizing descriminator
    for p in generator.parameters():
        p.eval_grad = False

    # training descriminator to identify real/fake data
    for _ in range(n_discriminator):

        real_images, epoch = data.__next__()
        real_labels = np.ones((1, m))
        if (real_images.shape[-1] != m):
            continue

        real_probs = discriminator.forward(real_images)
        d_loss_real = L.loss(real_probs, real_labels)

        z = np.random.randn(z_dim, m)
        fake_images = generator.forward(z)
        fake_labels =  np.zeros((1, m))

        fake_probs = discriminator.forward(fake_images)
        d_loss_fake = L.loss(fake_probs, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        d_loss.grad = np.zeros(d_loss.shape)

        d_loss.backward()
        d_optim.step()
        d_optim.zero_grad()

    # unfreeze generator
    for p in generator.parameters():
        p.eval_grad = True

    # training generator to fool descriminator with fake data
    z = np.random.uniform(-1, 1, (z_dim, m))
    fake_images = generator.forward(z)
    fake_labels =  np.ones((1, m))

    fake_probs = discriminator.forward(fake_images)
    g_loss = L.loss(fake_probs, fake_labels)

    g_loss.backward()
    g_optim.step()
    g_optim.zero_grad()
    d_optim.zero_grad()

    if it%10 == 0:
        print('Iter: {}, Epoch: {}, d_loss: {}, g_loss: {}'.format(it, epoch, d_loss.data[0, 0], g_loss.data[0, 0]))
    
    if it%100 == 0:
        sampled_images=sampler(sampled_images)
        np.save('/content/drive/My Drive/GAN/sampled_images.npy', sampled_images)
        generator.save('/content/drive/My Drive/GAN/Generator.npy')
        discriminator.save('/content/drive/My Drive/GAN/Discriminator.npy')

using Adam
using Adam
Iter: 0, Epoch: 1, d_loss: 1.4335286368865323, g_loss: 0.833239350290024
saving model...
Successfully saved model in /content/drive/My Drive/GAN/Generator.npy
saving model...
Successfully saved model in /content/drive/My Drive/GAN/Discriminator.npy
Iter: 10, Epoch: 1, d_loss: 1.2385491067579015, g_loss: 0.8104526989292754
Iter: 20, Epoch: 1, d_loss: 1.0340996498594213, g_loss: 0.9235075225132299
Iter: 30, Epoch: 1, d_loss: 0.7504887007381555, g_loss: 1.2236338665452886
Iter: 40, Epoch: 1, d_loss: 0.6435615975799105, g_loss: 1.380117115563181
Iter: 50, Epoch: 1, d_loss: 0.7209061828820125, g_loss: 1.4645621703774203
Iter: 60, Epoch: 1, d_loss: 0.4668926745634556, g_loss: 1.582988254356857
Iter: 70, Epoch: 1, d_loss: 0.49439195351069876, g_loss: 1.6848709696560589
Iter: 80, Epoch: 1, d_loss: 0.4686239724260851, g_loss: 1.7883883879502942
Iter: 90, Epoch: 1, d_loss: 0.4997568281571594, g_loss: 1.55228654791657
Iter: 100, Epoch: 1, d_loss: 0.45042082957443313, g_loss:

KeyboardInterrupt: ignored