In [1]:
from tensorflow.keras.layers import Input,Dense,LeakyReLU,Dropout,BatchNormalization

In [2]:
from tensorflow.keras.models import Model

In [3]:
from tensorflow.keras.optimizers import SGD,Adam

In [4]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sys,os

In [5]:
from tensorflow.keras.datasets import mnist

In [6]:
(X_train,y_train),(X_test,y_test)=mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [7]:
X_train,X_test=X_train/255*2-1,X_test/255*2-1

In [8]:
X_train.shape

(60000, 28, 28)

In [9]:
N,H,W=X_train.shape
D=H*W
X_train=X_train.reshape(-1,D)
X_test=X_test.reshape(-1,D)

In [10]:
latent_dim=100

In [11]:
def build_generator(latent_dim):
  i=Input(shape=(latent_dim,))
  x=Dense(256,activation=LeakyReLU(alpha=0.2))(i)
  x=BatchNormalization(momentum=0.8)(x)
  x=Dense(512,activation=LeakyReLU(alpha=0.2))(x)
  x=BatchNormalization(momentum=0.8)(x)
  x=Dense(1024,activation=LeakyReLU(alpha=0.2))(x)
  x=BatchNormalization(momentum=0.8)(x)
  x=Dense(D,activation='tanh')(x)

  model=Model(i,x)
  return model

In [12]:
def build_discriminator(img_size):
  i=Input(shape=(img_size,))
  x=Dense(512,activation=LeakyReLU(alpha=0.2))(i)
  x=Dense(128,activation=LeakyReLU(alpha=0.2))(x)
  x=Dense(1,activation='sigmoid')(x)
  model=Model(i,x)
  return model

In [13]:
discriminator=build_discriminator(D)
discriminator.compile(loss='binary_crossentropy',optimizer=Adam(0.0002,0.5),metrics=['accuracy'])
generator=build_generator(latent_dim)
z=Input(shape=(latent_dim,))
img=generator(z)
discriminator.trainable=False
fake_pred=discriminator(img)
combined_model=Model(z,fake_pred)
combined_model.compile(loss='binary_crossentropy',optimizer=Adam(0.0002,0.5))

In [14]:
batch_size=32
epochs=30000
sample_period=200
ones=np.ones(batch_size)
zeros=np.zeros(batch_size)
d_losses=[]
g_losses=[]
if not os.path.exists('gan_images'):
  os.makedirs('gan_images')

In [21]:
def sample_images(epoch):
  rows,cols=5,5
  noise=np.random.randn(rows*cols,latent_dim)
  imgs=generator.predict(noise)
  imgs=0.5*imgs+0.5
  fig,axs=plt.subplot(rows,cols)
  idx=0
  for i in range(rows):
    for j in range(cols):
      axs[i,j].imshow(imgs[idx].reshape(H,W),cmap='grey')
      axs[i,j].axis('off')
      idx+=1
  fig.savefig("/content/gan_images/%d.png"%epoch)
  plt.close()

In [22]:
for epoch in range(epochs):
  idx=np.random.randint(0,X_train.shape[0],batch_size)
  real_imgs=X_train[idx]
  noise=np.random.randn(batch_size,latent_dim)
  fake_imgs=generator.predict(noise)
  d_loss_real,d_acc_real=discriminator.train_on_batch(real_imgs,ones)
  d_loss_fake,d_acc_fake=discriminator.train_on_batch(fake_imgs,zeros)
  d_loss=0.5*(d_loss_real+d_loss_fake)
  d_acc=0.5*(d_acc_real+d_acc_fake)

  noise=np.random.randn(batch_size,latent_dim)
  g_loss=combined_model.train_on_batch(noise,ones)

  d_losses.append(d_loss)
  g_losses.append(g_loss)

  if epoch%sample_period==200:
    sample_images(epoch)



[1;30;43mStreaming output truncated to the last 5000 lines.[0m


In [24]:
d_losses

[0.6857441961765289,
 0.33884989470243454,
 0.3326876889914274,
 0.29913355223834515,
 0.3041962841525674,
 0.2844317923299968,
 0.22827821504324675,
 0.20022673066705465,
 0.20259642507880926,
 0.1376525149680674,
 0.14843553584069014,
 0.11837941408157349,
 0.10225427756085992,
 0.09096488635987043,
 0.0767903309315443,
 0.07960741128772497,
 0.08185023441910744,
 0.07692541042342782,
 0.05274247471243143,
 0.05739526520483196,
 0.05641046026721597,
 0.055501352064311504,
 0.04434114298783243,
 0.05869455658830702,
 0.042273186380043626,
 0.0445835841819644,
 0.033792154397815466,
 0.040498074726201594,
 0.03751113568432629,
 0.03227162663824856,
 0.032574190525338054,
 0.03037296747788787,
 0.0336712496355176,
 0.026780678424984217,
 0.03235163877252489,
 0.022215894423425198,
 0.023945620516315103,
 0.025575905572623014,
 0.025109003065153956,
 0.02145282831043005,
 0.02349041565321386,
 0.02496152278035879,
 0.022206067573279142,
 0.018994885962456465,
 0.018249674700200558,
 0.02

In [23]:
g_losses

[0.9376780390739441,
 0.9480072855949402,
 0.9087793827056885,
 1.033583164215088,
 1.0217094421386719,
 1.151721715927124,
 1.2550594806671143,
 1.29756498336792,
 1.4245376586914062,
 1.584409236907959,
 1.6827110052108765,
 1.795168399810791,
 1.8479437828063965,
 2.0429959297180176,
 2.0875582695007324,
 2.202389717102051,
 2.2156641483306885,
 2.3558831214904785,
 2.321070432662964,
 2.3585095405578613,
 2.403209686279297,
 2.5919041633605957,
 2.733942747116089,
 2.596405029296875,
 2.7313220500946045,
 2.81929874420166,
 3.074737548828125,
 2.867539882659912,
 2.9366886615753174,
 2.9208550453186035,
 3.0382041931152344,
 3.040245532989502,
 3.030827522277832,
 3.2341670989990234,
 3.2207207679748535,
 3.335477828979492,
 3.2741806507110596,
 3.2510955333709717,
 3.3125505447387695,
 3.365053653717041,
 3.414140224456787,
 3.354222059249878,
 3.262847900390625,
 3.400722026824951,
 3.575042724609375,
 3.5851552486419678,
 3.6254689693450928,
 3.4821343421936035,
 3.7632622718811