<a href="https://colab.research.google.com/github/srirambandi/GAN/blob/master/wgan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# When running in Colab, first install the library (%pip targets the running kernel's environment).
%pip install import-ai
# Then fetch the MNIST example data (or upload it manually from the library's examples directory).
# GitHub removed Subversion support in January 2024, so `svn checkout` of a GitHub URL no longer
# works; a shallow, sparse git clone downloads just examples/MNIST instead.
!git clone --depth 1 --filter=blob:none --sparse https://github.com/srirambandi/ai.git ai-repo
!git -C ai-repo sparse-checkout set examples/MNIST
!cp -r ai-repo/examples/MNIST .

In [2]:
import ai
import numpy as np

In [3]:
# Sizes shared by the generator and critic definitions below.
z_dim = 100            # length of the noise vector fed to the generator
gf_dim = df_dim = 64   # base channel width: gf_* for the generator, df_* for the critic

In [4]:
# Seed the library, and seed NumPy's global RNG explicitly as well: every noise vector in this
# notebook is drawn with np.random.randn, so reproducibility should not depend on whether
# ai.manual_seed happens to cover NumPy. Reseeding with the same value is harmless if it does.
ai.manual_seed(2357)
np.random.seed(2357)

In [5]:
def data_generator(m, data_dir='MNIST'):
    """Yield batches of MNIST images forever, cycling over the train + test sets.

    Args:
        m: batch size; must be positive and no larger than the total number of images.
        data_dir: directory containing ``train.npy`` and ``test.npy``. Each file is a
            pickled dict whose ``'data'`` entry is an image array with the sample
            axis first (presumably (N, 28, 28) -- TODO confirm).

    Yields:
        Arrays of shape (1, H, W, m): a leading channel axis and the batch axis last.
        Pixel values are scaled to [-1, 1] so real images share the range of the
        generator's tanh output. A trailing incomplete batch is never yielded.

    Raises:
        ValueError: if ``m`` is not positive, or exceeds the number of images (the
            batch loop below would otherwise spin forever without yielding).
    """
    if m <= 0:
        raise ValueError('batch size m must be positive, got {}'.format(m))

    # the files hold pickled dicts (allow_pickle=True); only load them from a trusted source
    train_dict = np.load('{}/train.npy'.format(data_dir), allow_pickle=True)
    test_dict = np.load('{}/test.npy'.format(data_dir), allow_pickle=True)
    data = np.concatenate([train_dict.item()['data'], test_dict.item()['data']])
    data = data.transpose(1, 2, 0)   # making data batch-last
    # add a channel dimension and map raw pixels (assumed 0..255) onto [-1, 1]
    data = data.reshape(1, *data.shape) / 127.5 - 1

    n_batches = data.shape[-1] // m
    if n_batches == 0:
        raise ValueError('batch size m={} exceeds the {} available images'.format(m, data.shape[-1]))

    while True:
        for batch in range(n_batches):
            yield data[..., batch * m:(batch + 1) * m]

In [6]:
class Generator(ai.Module):
    """DCGAN-style generator: maps a `z_dim` noise vector to a single-channel image.

    A linear projection is reshaped to an (8*gf_dim, 2, 2) map, then four stride-2
    transposed convolutions upsample it; the BatchNorm shapes pin the intermediate
    spatial sizes at 2 -> 4 -> 7 -> 14. The `a` argument appears to be the extra output
    size needed to reach each target (a=0 only for 4 -> 7) -- presumably output padding;
    confirm against ai.ConvTranspose2d.
    """

    def __init__(self):
        # Attributes are assigned in the original order, in case ai.Module derives
        # parameter order (printing / saving) from assignment order.
        self.g_fc = ai.Linear(z_dim, 8*gf_dim * 2 * 2)
        self.g_bn1 = ai.BatchNorm((8*gf_dim, 2, 2))

        self.g_deconv1 = ai.ConvTranspose2d(8*gf_dim, 4*gf_dim, kernel_size=5, stride=2, padding=2, a=1)
        self.g_bn2 = ai.BatchNorm((4*gf_dim, 4, 4))

        self.g_deconv2 = ai.ConvTranspose2d(4*gf_dim, 2*gf_dim, kernel_size=5, stride=2, padding=2, a=0)
        self.g_bn3 = ai.BatchNorm((2*gf_dim, 7, 7))

        self.g_deconv3 = ai.ConvTranspose2d(2*gf_dim, gf_dim, kernel_size=5, stride=2, padding=2, a=1)
        self.g_bn4 = ai.BatchNorm((gf_dim, 14, 14))

        self.g_deconv4 = ai.ConvTranspose2d(gf_dim, 1, kernel_size=5, stride=2, padding=2, a=1)

    def forward(self, z):
        """Generate images from noise `z` (drawn as (z_dim, batch) in this notebook).

        Every hidden stage is BatchNorm followed by ReLU; the output goes through tanh,
        so pixel values lie in (-1, 1).
        """
        projected = ai.G.reshape(self.g_fc(z), (8*gf_dim, 2, 2))
        h = ai.G.relu(self.g_bn1(projected))
        upsampling_stages = (
            (self.g_deconv1, self.g_bn2),
            (self.g_deconv2, self.g_bn3),
            (self.g_deconv3, self.g_bn4),
        )
        for deconv, norm in upsampling_stages:
            h = ai.G.relu(norm(deconv(h)))
        return ai.G.tanh(self.g_deconv4(h))

In [7]:
class Critic(ai.Module):
    """WGAN critic: assigns each single-channel input image one unbounded real score.

    Four stride-2 convolutions downsample the image; the BatchNorm shapes pin the
    intermediate spatial sizes at 7 -> 4 -> 2. Channel widths are multiples of `df_dim`
    (previously the literal 64, which left the `df_dim` setting without effect).
    """

    def __init__(self):
        self.d_conv1 = ai.Conv2d(1, df_dim, kernel_size=5, stride=2, padding=2)
        self.d_conv2 = ai.Conv2d(df_dim, 2*df_dim, kernel_size=5, stride=2, padding=2)
        self.d_bn1 = ai.BatchNorm((2*df_dim, 7, 7))
        self.d_conv3 = ai.Conv2d(2*df_dim, 3*df_dim, kernel_size=5, stride=2, padding=2)
        self.d_bn2 = ai.BatchNorm((3*df_dim, 4, 4))
        self.d_conv4 = ai.Conv2d(3*df_dim, 4*df_dim, kernel_size=5, stride=2, padding=2)
        self.d_bn3 = ai.BatchNorm((4*df_dim, 2, 2))
        # input size = elements in the final (4*df_dim, 2, 2) feature map (1024 when df_dim == 64)
        self.d_fc = ai.Linear(4*df_dim * 2 * 2, 1)

    def forward(self, image):
        """Score a batch of images; higher/lower has meaning only relative to other scores.

        `image` is presumably (1, 28, 28, batch), batch-last as yielded by data_generator.
        The last layer is linear with no squashing activation, as a WGAN critic requires.
        """
        o1 = ai.G.lrelu(self.d_conv1(image))
        o2 = ai.G.lrelu(self.d_bn1(self.d_conv2(o1)))
        o3 = ai.G.lrelu(self.d_bn2(self.d_conv3(o2)))
        o4 = ai.G.lrelu(self.d_bn3(self.d_conv4(o3)))
        # NOTE(review): the 4-D feature map goes straight into the Linear layer;
        # assumes ai.Linear flattens the non-batch axes itself -- TODO confirm
        o5 = self.d_fc(o4)
        return o5

In [8]:
# Build both networks (generator first) and show their layer summaries, one after the other.
generator, critic = Generator(), Critic()
print(generator, critic, sep='\n')

Generator(
  g_fc: Linear(input_features=100, output_features=2048, bias=True)
  g_bn1: BatchNorm((512, 2, 2), axis=-1, momentum=0.9, bias=True)
  g_deconv1: ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), a=(1, 1), bias=True)
  g_bn2: BatchNorm((256, 4, 4), axis=-1, momentum=0.9, bias=True)
  g_deconv2: ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), a=(0, 0), bias=True)
  g_bn3: BatchNorm((128, 7, 7), axis=-1, momentum=0.9, bias=True)
  g_deconv3: ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), a=(1, 1), bias=True)
  g_bn4: BatchNorm((64, 14, 14), axis=-1, momentum=0.9, bias=True)
  g_deconv4: ConvTranspose2d(64, 1, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), a=(1, 1), bias=True)
)
Critic(
  d_conv1: Conv2d(1, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=True)
  d_conv2: Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=True)
  d_bn1: BatchNorm((128, 7, 7)

In [9]:
# Optimisation settings; the values match the defaults in Algorithm 1 of the WGAN paper
# (RMSProp, learning rate 5e-5, clipping bound 0.01).
lr = 5e-5   # learning rate shared by both networks
c = 0.01    # clipping bound used by the critic update in the training loop
g_optim, c_optim = [ai.Optimizer(net.parameters(), optim_fn='RMSProp', lr=lr)
                    for net in (generator, critic)]

In [10]:
# Training bookkeeping plus the endless stream of real image batches.
epoch = it = 0
m = 64          # batch size
n_critic = 5    # number of critic updates per generator update
data = data_generator(m)

In [11]:
def evaluate(n_samples=None):
    """Sample images from the current generator without recording a gradient graph.

    Args:
        n_samples: number of images to generate; defaults to the batch size `m`.

    Returns:
        The generator's output for `n_samples` fresh noise vectors.
    """
    if n_samples is None:
        n_samples = m

    previous_mode = ai.G.grad_mode
    ai.G.grad_mode = False
    try:
        # generate images like real data
        z = np.random.randn(z_dim, n_samples)
        fake_images = generator.forward(z)
    finally:
        # restore gradient tracking even if the forward pass raises; otherwise every
        # later training step would silently run without gradients
        ai.G.grad_mode = previous_mode
    return fake_images

In [None]:
# WGAN training: n_critic critic updates (each followed by weight clipping) per generator update.
n_epochs = 10
n_batches = 70000 // m   # batches per epoch: data_generator serves all 70000 train + test images
log_every = 20           # print losses once every `log_every` generator updates

while epoch < n_epochs:
    epoch += 1
    it = 0   # real batches consumed this epoch; advances by n_critic per generator update

    while it < n_batches:

        # freeze generator before optimizing critic
        for p in generator.parameters():
            p.eval_grad = False

        # train the critic to separate real from generated data
        for _ in range(n_critic):

            real_images = next(data)
            if real_images.shape[-1] != m:   # defensive: skip an incomplete batch
                continue

            c_loss_real = critic.forward(real_images)

            z = np.random.randn(z_dim, m)
            fake_images = generator.forward(z)
            c_loss_fake = critic.forward(fake_images)

            # minimizing D(real) - D(fake) pushes real scores down and fake scores up;
            # the generator step below minimizes D(fake), so the sign conventions agree
            c_loss = c_loss_real - c_loss_fake
            c_loss.grad = np.ones(c_loss.shape)

            c_loss.backward()
            c_optim.step()
            c_optim.zero_grad()

            # WGAN weight clipping: clamp the critic's *weights* (not its gradients) to [-c, c]
            # after every update. Clipping gradients after zero_grad() would clip zeros.
            # NOTE(review): assumes each parameter stores its values in a writable ndarray `.data`.
            for p in critic.parameters():
                np.clip(p.data, -c, c, out=p.data)

        # unfreeze generator
        for p in generator.parameters():
            p.eval_grad = True

        # train the generator to fool the critic with fake data
        z = np.random.randn(z_dim, m)
        fake_images = generator.forward(z)

        g_loss = critic.forward(fake_images)
        g_loss.grad = np.ones(g_loss.shape)

        g_loss.backward()
        g_optim.step()
        g_optim.zero_grad()
        c_optim.zero_grad()   # discard critic gradients produced by the generator's backward pass

        if (it // n_critic) % log_every == 0:
            print('epoch: {}, iter: {}, c_loss: {}, g_loss: {}, sum_loss: {}'.format(epoch, it, c_loss.data[0, 0], g_loss.data[0, 0], (c_loss.data[0, 0] + g_loss.data[0, 0])))
        it += n_critic

    samples = evaluate()   # generated images from the end of this epoch, kept for inspection
    print('Epoch {} completed.'.format(epoch))
    generator.save()
    critic.save()