In [None]:
import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization,Input,Dense,LeakyReLU
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import os

In [None]:
(x_train,y_train),(x_test,y_test)=tf.keras.datasets.mnist.load_data()
latten_dims=100  # length of the generator's input noise vector
N,H,W=x_train.shape
D=H*W  # number of pixels per flattened image
# Flatten each image and rescale uint8 pixels from [0, 255] to [-1, 1].
# This matches the generator's tanh output range (and the 0.5*x+0.5
# inverse used when plotting samples). Without it the discriminator can tell
# real from fake purely by pixel magnitude.
x_train=x_train.reshape(-1,D).astype(np.float32)/127.5-1.0

In [None]:
def build_generator(latten_dims, img_size=None):
  """Build the fully connected generator: noise vector -> flattened image.

  Args:
    latten_dims: length of the input noise vector.
    img_size: number of output pixels. Defaults to the module-level ``D``
      (H*W of the training images), so existing one-argument calls are
      unchanged.

  Returns:
    A Keras ``Model``. Its final layer uses ``tanh``, so outputs lie in [-1, 1].
  """
  if img_size is None:
    img_size = D
  i = Input(shape=(latten_dims,))
  x = i
  # Three Dense + LeakyReLU blocks of growing width, each followed by BatchNorm.
  for units in (256, 512, 1024):
    x = Dense(units, activation=LeakyReLU(alpha=0.2))(x)
    x = BatchNormalization(momentum=0.8)(x)
  x = Dense(img_size, activation='tanh')(x)
  return Model(i, x)

In [None]:
def build_discriminator(img_size):
  """Build the fully connected discriminator: flattened image -> score.

  Args:
    img_size: number of input features (pixels) per image.

  Returns:
    A Keras ``Model`` with one sigmoid output in (0, 1) per input row.
  """
  inputs = Input(shape=(img_size,))
  hidden = inputs
  # Two LeakyReLU hidden layers of shrinking width.
  for units in (512, 256):
    hidden = Dense(units, activation=LeakyReLU(alpha=0.2))(hidden)
  score = Dense(1, activation='sigmoid')(hidden)
  return Model(inputs, score)

In [None]:
# Build and compile the discriminator on its own; discriminator.train_on_batch
# is used below to update its weights.
discriminator=build_discriminator(D)
discriminator.compile(optimizer=Adam(0.0002,0.5),
                      loss='binary_crossentropy',
                      metrics=['accuracy'])

generator=build_generator(latten_dims)
z=Input(shape=(latten_dims,))
img=generator(z)

# Freeze the discriminator only *after* compiling it above. Keras keeps the
# trainable state captured at compile time, so the standalone discriminator
# still trains. The combined model compiled below sees it frozen and
# therefore only updates the generator.
discriminator.trainable=False

fake_pred=discriminator(img)

# Combined model: noise -> generator -> (frozen) discriminator -> score.
combined_model=Model(z,fake_pred)
combined_model.compile(optimizer=Adam(0.0002,0.5),
                      loss='binary_crossentropy',
                      metrics=['accuracy'])
print(fake_pred)
print(img)  # symbolic generator output for the noise input z
print(generator)  # the generator Model object itself

# Directory for the sample grids written by sample_images().
os.makedirs('gan_images', exist_ok=True)

KerasTensor(type_spec=TensorSpec(shape=(None, 1), dtype=tf.float32, name=None), name='model_9/dense_23/Sigmoid:0', description="created by layer 'model_9'")
KerasTensor(type_spec=TensorSpec(shape=(None, 784), dtype=tf.float32, name=None), name='model_10/dense_27/Tanh:0', description="created by layer 'model_10'")
<tensorflow.python.keras.engine.functional.Functional object at 0x7f5b66175e50>


In [None]:
# Training configuration: mini-batch size and number of training steps.
batch_size = 32
epochs = 30000
# Constant target vectors reused each step: 1 labels real, 0 labels fake.
zeros, ones = np.zeros(batch_size), np.ones(batch_size)

In [None]:
def sample_images(epoch, rows=5, cols=5):
  """Save a rows x cols grid of generator samples to gan_images/<epoch>.png.

  Args:
    epoch: training step; used only to name the output file.
    rows: number of grid rows (default 5).
    cols: number of grid columns (default 5).

  Draws rows*cols noise vectors, runs them through the module-level
  ``generator`` and reshapes each output to (H, W) for display.
  """
  noise = np.random.randn(rows * cols, latten_dims)
  # verbose=0 keeps predict() from printing a progress bar on every call.
  fake_img = generator.predict(noise, verbose=0)

  # Map the generator's tanh output from [-1, 1] back to [0, 1] for display.
  fake_img = 0.5 * fake_img + 0.5
  # squeeze=False keeps axs 2-D even when rows or cols is 1.
  figs, axs = plt.subplots(rows, cols, squeeze=False)
  for idx in range(rows * cols):
    ax = axs[idx // cols, idx % cols]
    ax.imshow(fake_img[idx].reshape(H, W), cmap='gray')
    ax.axis('off')
  figs.savefig("gan_images/{}.png".format(epoch))
  # Close this specific figure so figures don't accumulate across calls.
  plt.close(figs)

In [None]:
# Alternate one discriminator update and one generator update per iteration.
# Note: each "epoch" here is a single mini-batch step, not a pass over x_train.
d_losses=[]
g_losses=[]
for epoch in range(epochs):
  # 1) Train the discriminator on a batch of real and a batch of fake images.
  idx=np.random.randint(0,x_train.shape[0],batch_size)
  real_img=x_train[idx]

  noise=np.random.randn(batch_size,latten_dims)
  # verbose=0 avoids a per-step progress bar on TF versions that default to one.
  fake_img=generator.predict(noise, verbose=0)

  d_loss_real,d_acc_real=discriminator.train_on_batch(real_img,ones)
  d_loss_fake,d_acc_fake=discriminator.train_on_batch(fake_img,zeros)
  d_loss=0.5*(d_loss_real+d_loss_fake)
  d_acc=0.5*(d_acc_real+d_acc_fake)

  # 2) Train the generator via the combined model (discriminator frozen there),
  #    labelling its fakes as real (ones) so it learns to fool the discriminator.
  noise=np.random.randn(batch_size,latten_dims)
  # Because combined_model was compiled with metrics, this is [loss, accuracy].
  g_loss=combined_model.train_on_batch(noise,ones)

  d_losses.append(d_loss)
  g_losses.append(g_loss)

  if epoch%200==0:
    sample_images(epoch)

  if epoch%100==0:
    # Report a 1-based step counter without mutating the loop variable.
    print("epoch:{}/{}, d_loss:{}, d_acc:{}, g_loss:{}".format(epoch+1,epochs,d_loss,d_acc,g_loss))

  
 

[[0.37144208 0.29833525 0.5325547  ... 0.38207012 0.76647913 0.6192215 ]
 [0.6199043  0.346403   0.52487254 ... 0.55470914 0.43739283 0.49657986]
 [0.7486125  0.4713535  0.578058   ... 0.3557151  0.5639522  0.4354803 ]
 ...
 [0.7617518  0.42528212 0.31750023 ... 0.51912534 0.46253994 0.42117774]
 [0.7144611  0.5273567  0.6133698  ... 0.2835179  0.7234126  0.3469115 ]
 [0.59942645 0.43400732 0.4042995  ... 0.45653665 0.5983046  0.61091703]]
epoch:1/30000, d_loss:0.3208794593811051, d_acc:0.90625, g_loss:[0.7413774728775024, 0.40625]
epoch:101/30000, d_loss:0.12306042015552521, d_acc:0.984375, g_loss:[1.6806788444519043, 0.0]
[[0.9858721  0.05328301 0.42162636 ... 0.85599405 0.7678113  0.05075803]
 [0.03120342 0.00099394 0.03223243 ... 0.927595   0.640355   0.12193048]
 [0.4030832  0.37981606 0.12002817 ... 0.00999779 0.1969066  0.45780215]
 ...
 [0.02378279 0.5230703  0.98536956 ... 0.6335921  0.9852092  0.5160632 ]
 [0.29928824 0.86451167 0.17274073 ... 0.7775471  0.9699929  0.01218185

In [None]:
# List the sample grids written by sample_images() during training.
!ls gan_images

In [None]:
# Show the sample grid saved at step 0 (only exists after the training loop has run).
r=plt.imread('gan_images/0.png')
plt.imshow(r)
plt.axis('off');

FileNotFoundError: ignored

In [None]:
# Show the sample grid saved at step 1000 (only exists after the training loop has run).
r=plt.imread('gan_images/1000.png')
plt.imshow(r)
plt.axis('off');

In [None]:
# Show the sample grid saved at step 5000 (only exists after the training loop has run).
r=plt.imread('gan_images/5000.png')
plt.imshow(r)
plt.axis('off');

In [None]:
# Show the sample grid saved at step 20000 (only exists after the training loop has run).
r=plt.imread('gan_images/20000.png')
plt.imshow(r)
plt.axis('off');