## Ising Model GAN 

In [1]:
import tensorflow as tf
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import sys
sys.path.insert(0, "..")
from input_pipeline import dataset_tfrecord_pipeline
import time
from IPython import display
from model import make_discriminator_model, make_generator_model, train_step

## Load data

In [2]:
# Number of samples per batch; used for the input pipeline below and also
# forwarded to train_step() from inside train().
batch_size = 64

# Relative path to the TFRecord file that holds the training samples.
train_path = '../../GetData/Python/Data/Data2.5.tfrecord'
# Build the batched training dataset.
# NOTE(review): flatten=False presumably keeps each sample as a 2-D lattice
# with a channel axis instead of a flat vector -- confirm in input_pipeline.
train_ds = dataset_tfrecord_pipeline(train_path, flatten=False, batch_size=batch_size)

## Training

In [None]:
def generate_and_save_images(model, epoch, test_input):
    """Plot a grid of generated samples, save it as a PNG, and display it.

    Args:
        model: Callable invoked as ``model(test_input, training=False)``. Its
            output is indexed as ``[i, :, :, 0]``, so the sample index must be
            the first axis and channel 0 is what gets drawn.
        epoch: Integer used to build the output file name
            ``image_at_epoch_NNNN.png`` (zero-padded to four digits). The file
            is written to the current working directory.
        test_input: Input forwarded unchanged to ``model``.

    The grid is 4x4 for up to 16 samples (the layout this notebook uses) and
    grows to the smallest square that fits when more samples are supplied, so
    a larger preview batch no longer fails inside ``plt.subplot``.
    """
    import math  # local import: only needed to size the grid

    # Round the model output to integer values before plotting.
    # NOTE(review): presumably this snaps outputs onto discrete spin values --
    # confirm against the generator's output activation.
    predictions = tf.round(model(test_input, training=False))
    num_images = predictions.shape[0]

    # Never shrink below the original 4x4 layout; enlarge only when required.
    grid_side = max(4, math.ceil(math.sqrt(num_images)))

    fig = plt.figure(figsize=(4, 4))

    for i in range(num_images):
        plt.subplot(grid_side, grid_side, i + 1)
        plt.imshow(predictions[i, :, :, 0])
        plt.axis('off')

    plt.savefig(f'image_at_epoch_{epoch:04d}.png')
    plt.show()
    # Close explicitly so figures do not accumulate over many epochs when the
    # active backend does not dispose of them in show().
    plt.close(fig)

def train(dataset, epochs, gen_loss_log, disc_loss_log, start_epoch=0):
    """Run the GAN training loop, previewing generated samples after each epoch.

    Args:
        dataset: Iterable of image batches; every batch is passed to
            ``train_step``.
        epochs: Number of passes to make over ``dataset``.
        gen_loss_log: Object forwarded to ``train_step`` as the generator loss
            log (presumably appended to there -- confirm in ``model``).
        disc_loss_log: Object forwarded to ``train_step`` as the discriminator
            loss log (same caveat).
        start_epoch: Number of epochs already completed by earlier calls.
            Epochs are reported and image files are numbered from
            ``start_epoch + 1``. The default of 0 keeps the original file
            names; pass the count of finished epochs to stop a later run from
            overwriting the images of an earlier one.

    Besides its arguments this function reads the module-level names
    ``batch_size``, ``noise_dim``, ``generator``, ``discriminator``,
    ``generator_optimizer``, ``discriminator_optimizer`` and
    ``random_vector_for_generation``.
    """
    for epoch in range(epochs):
        # One 1-based number shared by the saved image and the log message, so
        # the two always refer to the same epoch.
        epoch_number = start_epoch + epoch + 1
        start = time.time()

        for images in tqdm(dataset):
            train_step(
                images,
                gen_loss_log,
                disc_loss_log,
                batch_size,
                noise_dim,
                generator,
                discriminator,
                generator_optimizer,
                discriminator_optimizer,
            )

        # Replace the previous epoch's output in the notebook cell rather than
        # stacking a new figure under it.
        display.clear_output(wait=True)
        generate_and_save_images(
            generator,
            epoch_number,
            random_vector_for_generation,
        )
        # The reported time covers training plus preview generation.
        print(f"Time taken for epoch {epoch_number} is {time.time() - start} sec")
    

# Length of each latent noise vector; also forwarded to train_step() by train().
noise_dim = 100
# Number of samples drawn for the per-epoch preview grid.
num_examples_to_generate = 16
# Latent inputs are sampled once and reused for every preview, so successive
# epoch images show the same latent points.
random_vector_for_generation = tf.random.normal(
    shape=(num_examples_to_generate, noise_dim)
)

# The two networks, built by the project's model module.
generator = make_generator_model()
discriminator = make_discriminator_model()

# Separate Adam optimizers. The discriminator starts with a much smaller
# learning rate (1e-6) than the generator (1e-4); a later cell raises it once
# the pretraining run has finished.
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-6)

## Train the GAN!

In [None]:
# Epoch count passed to train() for the main training run.
EPOCHS=20
# Epoch count passed to train() for the pretraining run.
EPOCHS_START=5
# Loss logs handed to train(), which forwards them to train_step(). They are
# created once here, so both training runs share the same two lists.
gen_loss_log=[]
disc_loss_log=[]

## Pretrain

In [None]:
%%time
# Pretraining run: EPOCHS_START epochs with the optimizers exactly as they
# were configured when created.
train(
    dataset=train_ds,
    epochs=EPOCHS_START,
    gen_loss_log=gen_loss_log,
    disc_loss_log=disc_loss_log,
)

In [None]:
# After pretraining, raise the discriminator's learning rate from its initial
# 1e-6 to 1e-4 (the value the generator optimizer was created with).
new_learning_rate = 1e-4
# Use the documented `learning_rate` attribute: the short `lr` alias is not
# provided by every Keras optimizer implementation, whereas `learning_rate`
# is. Assigning in place keeps the optimizer's existing state.
discriminator_optimizer.learning_rate.assign(new_learning_rate)

## Train

In [None]:
%%time
# Main training run: EPOCHS further epochs, appending to the same loss logs.
# NOTE: this call numbers its epochs from 1 again, so the
# image_at_epoch_NNNN.png files written by the pretraining run are overwritten.
train(
    dataset=train_ds,
    epochs=EPOCHS,
    gen_loss_log=gen_loss_log,
    disc_loss_log=disc_loss_log,
)

## Plot the loss of the generator and discriminator

In [None]:
# Draw both loss curves in one figure: the discriminator on the left y-axis
# (red) and the generator on a twin right y-axis (blue), sharing the x-axis.
# NOTE(review): the x-axis is labelled 'Epochs', but the spacing of the logged
# values depends on how often train_step() records a loss -- confirm.
fig, ax1 = plt.subplots()
ax1.set_xlabel('Epochs')
ax2 = ax1.twinx()

for axis, losses, label, color in (
    (ax1, disc_loss_log, 'Discriminator Loss', 'tab:red'),
    (ax2, gen_loss_log, 'Generator Loss', 'tab:blue'),
):
    axis.plot(np.asarray(losses), color=color)
    # Colour the axis label and tick labels like the curve they belong to.
    axis.set_ylabel(label, color=color)
    axis.tick_params(axis='y', labelcolor=color)

plt.show()

In [None]:
# Run the generator in inference mode on the fixed preview noise vectors.
# Unlike generate_and_save_images(), no rounding is applied here, so these are
# the generator's raw outputs.
predictions = generator(random_vector_for_generation, training=False)

In [None]:
# Show channel 0 of the first generated sample as an image.
plt.imshow(predictions[0, :, :, 0])