<a href="https://colab.research.google.com/github/tackulus/229352/blob/main/Lab09.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **Lab 09**

---

> **229352 Statistical Learning for Data Science 2**

> **Kasidis Torcharoen (610510531)**

In [None]:
from google.colab import drive
drive.mount('/content/drive')

!unzip /content/drive/MyDrive/pokemon.zip;
!mkdir /content/drive/MyDrive/new_pokemon; #Folder to save images of new pokemon
!mkdir /content/drive/MyDrive/GAN_weights #Folder to save models' weights

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.layers import Dense, Conv2D, Conv2DTranspose
from tensorflow.keras.layers import BatchNormalization, LeakyReLU
from tensorflow.keras.layers import Input, Flatten, Reshape, RandomRotation
from tensorflow.keras.models import Model
from tensorflow.image import random_hue

In [None]:
#TODO-0: set the training parameters below
BATCH_SIZE = 16
EPOCHS = 80
noise_dim = 100  # length of the Gaussian input vector of the generator

num_examples_to_generate = 16  # size of the preview grid saved every epoch
IMAGE_SIZE = 64                # images are resized to IMAGE_SIZE x IMAGE_SIZE


images = tf.keras.preprocessing.image_dataset_from_directory(
    directory='/content/pokemon/pokemon', label_mode=None,
    class_names=None, color_mode='rgb', batch_size=BATCH_SIZE,
    image_size=(IMAGE_SIZE, IMAGE_SIZE), shuffle=True)

# Cache the decoded images BEFORE the random augmentation. If the cache came
# after `random_hue`, the hue shifts drawn in the first epoch would be stored
# and replayed unchanged in every later epoch, so there would be no augmentation.
images = images.cache()

# preprocessing + augmentation (re-sampled on every pass over the data)
images = images.map(lambda x: random_hue(x, 0.5))
images = images.map(lambda x: (x - 127.5) / 127.5)  # transform pixels to [-1, 1]
images = images.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)


Found 15467 files belonging to 1 classes.


In [None]:
# Generator: maps a Gaussian noise vector of length noise_dim to a 64x64x3
# image whose pixels lie in [-1, 1] (tanh output).

z = Input(shape=(noise_dim,))

# Project the noise and reshape it into a 4x4 map with 512 channels.
x = Dense(4 * 4 * 64 * 8, use_bias=False)(z)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
x = Reshape((4, 4, 64 * 8))(x)

# Three upsampling stages (TODO-1), each doubling height and width:
# 4 -> 8 -> 16 -> 32.
# see https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2DTranspose
for up_filters in (256, 128, 64):
    x = Conv2DTranspose(up_filters, 4, strides=2, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)

# Final stage 32 -> 64 down to 3 RGB channels. tanh matches the [-1, 1]
# pixel scaling applied to the real images.
out = Conv2DTranspose(3, 4, strides=2, padding='same', use_bias=False,
                      activation='tanh')(x)

generator = Model(inputs=z, outputs=out)


In [None]:
generator.summary()
# Sanity check: generated images must have the same shape as the real ones.
assert tuple(generator.output.shape[1:]) == (IMAGE_SIZE, IMAGE_SIZE, 3), generator.output.shape

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
dense (Dense)                (None, 8192)              819200    
_________________________________________________________________
batch_normalization (BatchNo (None, 8192)              32768     
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 8192)              0         
_________________________________________________________________
reshape (Reshape)            (None, 4, 4, 512)         0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 8, 8, 256)         2097152   
_________________________________________________________________
batch_normalization_1 (Batch (None, 8, 8, 256)         1024  

In [None]:
# Discriminator: takes an IMAGE_SIZE x IMAGE_SIZE RGB image and outputs the
# probability (sigmoid) that the image is real rather than generated.

inp = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))

# Small random rotation as augmentation. Corners are filled with 1.0, which is
# white on the [-1, 1] pixel scale.
y = RandomRotation(0.05, fill_mode='constant', fill_value=1.0)(inp)

# Four strided-conv stages (TODO-2), each halving height and width:
# 64 -> 32 -> 16 -> 8 -> 4.
for down_filters in (64, 128, 256, 512):
    y = Conv2D(down_filters, 4, strides=2, padding='same', use_bias=False)(y)
    y = BatchNormalization()(y)
    y = LeakyReLU(0.2)(y)

y = Flatten()(y)
y_out = Dense(1, activation='sigmoid')(y)

discriminator = Model(inputs=inp, outputs=y_out)

In [None]:
discriminator.summary()
# Sanity check: one real/fake probability per input image.
assert tuple(discriminator.output.shape[1:]) == (1,), discriminator.output.shape

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 64, 64, 3)]       0         
_________________________________________________________________
random_rotation (RandomRotat (None, 64, 64, 3)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 32, 32, 64)        3072      
_________________________________________________________________
batch_normalization_4 (Batch (None, 32, 32, 64)        256       
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 16, 16, 128)       131072    
_________________________________________________________________
batch_normalization_5 (Batch (None, 16, 16, 128)       512 

In [None]:
cross_entropy = tf.keras.losses.BinaryCrossentropy()


def discriminator_loss(real_pred, fake_pred):
    """Discriminator loss on one batch of real and one batch of fake images.

    Uses one-sided label smoothing (TODO-3): real images are scored against a
    target of 0.9 instead of 1.0, and fake images against 0.

    Args:
        real_pred: discriminator outputs for real images.
        fake_pred: discriminator outputs for generated images.

    Returns:
        Sum of the binary cross-entropy on the real and on the fake batch.
    """
    smoothed_real_targets = 0.9 * tf.ones_like(real_pred)
    fake_targets = tf.zeros_like(fake_pred)
    loss_on_real = cross_entropy(smoothed_real_targets, real_pred)
    loss_on_fake = cross_entropy(fake_targets, fake_pred)
    return loss_on_real + loss_on_fake


def generator_loss(fake_pred):
    """Generator loss: cross-entropy of fake predictions against the "real" label 1."""
    wanted_labels = tf.ones_like(fake_pred)
    return cross_entropy(wanted_labels, fake_pred)

In [None]:
# Adam optimizers (TODO-4), both with beta_1 = 0.5.
# The discriminator's learning rate (4e-4) is twice the generator's (2e-4).
generator_optimizer = tf.keras.optimizers.Adam(beta_1=0.5, learning_rate=2e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(beta_1=0.5, learning_rate=4e-4)

In [None]:
#No TODO here

@tf.function  # compile the step into a graph; it is called once per minibatch
def train_step(real_images):
    """Run one adversarial update of both the generator and the discriminator.

    Args:
        real_images: minibatch of real images. Expected in [-1, 1], the same
            range as the generator's tanh output.

    Returns:
        (gen_loss, disc_loss): scalar losses for this minibatch.
    """
    # Generate exactly as many fakes as there are real images. The final
    # batch of an epoch can be smaller than BATCH_SIZE (e.g. 15467 images in
    # batches of 16 leave 11 in the last one).
    batch_size = tf.shape(real_images)[0]
    noise = tf.random.normal(mean=0, stddev=1, shape=[batch_size, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # generate data & classify as fake or real
        generated_images = generator(noise, training=True)

        real_pred = discriminator(real_images, training=True)
        fake_pred = discriminator(generated_images, training=True)

        # compute loss
        gen_loss = generator_loss(fake_pred)
        disc_loss = discriminator_loss(real_pred, fake_pred)

    # compute gradients (each tape differentiates only its own network's loss)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # update parameters with gradient descent
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss


In [None]:
# Fixed random Gaussian vectors, reused every epoch so that the saved grids
# show how the same latent points evolve during training.
noises = tf.random.normal(mean=0, stddev=1, shape=(num_examples_to_generate, noise_dim))

# Side length of the square preview grid (4 for 16 examples).
grid_size = int(np.ceil(np.sqrt(num_examples_to_generate)))

#TODO-5: Question: which line generates new pokemon from random noises?

for epoch in range(EPOCHS):
    print('Starting epoch ', epoch)
    for image_batch in tqdm(images):
        gen_loss, disc_loss = train_step(image_batch)

    # Losses of the last minibatch of the epoch.
    print('generative loss: ', float(gen_loss),
          ' discriminative loss: ', float(disc_loss))

    predictions = generator(noises, training=False)  # This line

    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        ax = fig.add_subplot(grid_size, grid_size, i + 1)
        ax.imshow((predictions[i, :, :, :] + 1) / 2)  # map tanh range [-1, 1] to [0, 1]
        ax.axis('off')

    fig.savefig(f'/content/drive/MyDrive/new_pokemon/image_at_epoch_{epoch:04d}.png')
    plt.show()
    # Release the figure. Otherwise one figure per epoch stays open for the
    # whole run: memory grows and matplotlib warns once more than 20 are open.
    plt.close(fig)

    generator.save_weights(f'/content/drive/MyDrive/GAN_weights/pokemon_generator_{epoch:04d}')
    discriminator.save_weights(f'/content/drive/MyDrive/GAN_weights/pokemon_discriminator_{epoch:04d}')

Starting epoch  0


100%|██████████| 967/967 [03:15<00:00,  4.96it/s]


generative loss:  1.3204978704452515  discriminative loss:  0.8154924511909485
Starting epoch  1


100%|██████████| 967/967 [02:21<00:00,  6.85it/s]


generative loss:  1.0291776657104492  discriminative loss:  0.9107860326766968
Starting epoch  2


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  0.923734724521637  discriminative loss:  1.1342942714691162
Starting epoch  3


100%|██████████| 967/967 [02:25<00:00,  6.64it/s]


generative loss:  0.8083805441856384  discriminative loss:  1.0403069257736206
Starting epoch  4


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  1.7463964223861694  discriminative loss:  0.6544740200042725
Starting epoch  5


100%|██████████| 967/967 [02:23<00:00,  6.72it/s]


generative loss:  0.7588198184967041  discriminative loss:  1.2395234107971191
Starting epoch  6


100%|██████████| 967/967 [02:21<00:00,  6.81it/s]


generative loss:  2.3326802253723145  discriminative loss:  0.620826780796051
Starting epoch  7


100%|██████████| 967/967 [02:23<00:00,  6.76it/s]


generative loss:  2.9399304389953613  discriminative loss:  0.6235018372535706
Starting epoch  8


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  5.216785907745361  discriminative loss:  0.5365172624588013
Starting epoch  9


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  4.076332092285156  discriminative loss:  0.6561959385871887
Starting epoch  10


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  2.6259026527404785  discriminative loss:  0.5829294323921204
Starting epoch  11


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  4.149786949157715  discriminative loss:  0.9684442281723022
Starting epoch  12


100%|██████████| 967/967 [02:34<00:00,  6.28it/s]


generative loss:  2.8042664527893066  discriminative loss:  0.48962628841400146
Starting epoch  13


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  2.9331088066101074  discriminative loss:  0.6152160167694092
Starting epoch  14


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  3.7118923664093018  discriminative loss:  0.41268977522850037
Starting epoch  15


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  4.68742036819458  discriminative loss:  0.3902316093444824
Starting epoch  16


100%|██████████| 967/967 [02:25<00:00,  6.64it/s]


generative loss:  1.9658701419830322  discriminative loss:  0.5961238145828247
Starting epoch  17


100%|██████████| 967/967 [02:23<00:00,  6.72it/s]


generative loss:  1.2910293340682983  discriminative loss:  0.8405327796936035
Starting epoch  18


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  2.1316535472869873  discriminative loss:  0.5610984563827515
Starting epoch  19


100%|██████████| 967/967 [02:29<00:00,  6.49it/s]


generative loss:  2.7558624744415283  discriminative loss:  0.4777185022830963
Starting epoch  20


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]
  app.launch_new_instance()


generative loss:  3.4353930950164795  discriminative loss:  0.7241974472999573
Starting epoch  21


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  2.985231637954712  discriminative loss:  0.4347580075263977
Starting epoch  22


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  4.821601867675781  discriminative loss:  0.5320324301719666
Starting epoch  23


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  3.7320737838745117  discriminative loss:  0.3973955512046814
Starting epoch  24


100%|██████████| 967/967 [02:21<00:00,  6.85it/s]


generative loss:  3.1446404457092285  discriminative loss:  0.4663289487361908
Starting epoch  25


100%|██████████| 967/967 [02:26<00:00,  6.59it/s]


generative loss:  4.435831069946289  discriminative loss:  0.38714176416397095
Starting epoch  26


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  3.5027658939361572  discriminative loss:  0.40226611495018005
Starting epoch  27


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  4.013540744781494  discriminative loss:  0.39785778522491455
Starting epoch  28


100%|██████████| 967/967 [02:29<00:00,  6.49it/s]


generative loss:  3.7437169551849365  discriminative loss:  0.4653446674346924
Starting epoch  29


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  3.7375314235687256  discriminative loss:  0.4333227276802063
Starting epoch  30


100%|██████████| 967/967 [02:34<00:00,  6.26it/s]


generative loss:  3.76383900642395  discriminative loss:  0.4192488193511963
Starting epoch  31


100%|██████████| 967/967 [02:19<00:00,  6.92it/s]


generative loss:  3.771596908569336  discriminative loss:  0.4550090730190277
Starting epoch  32


100%|██████████| 967/967 [02:20<00:00,  6.86it/s]


generative loss:  4.434063911437988  discriminative loss:  0.4313001334667206
Starting epoch  33


100%|██████████| 967/967 [02:27<00:00,  6.54it/s]


generative loss:  4.3126349449157715  discriminative loss:  0.5513419508934021
Starting epoch  34


100%|██████████| 967/967 [02:24<00:00,  6.67it/s]


generative loss:  1.9129252433776855  discriminative loss:  0.7590307593345642
Starting epoch  35


100%|██████████| 967/967 [02:22<00:00,  6.79it/s]


generative loss:  6.390647888183594  discriminative loss:  0.3855164051055908
Starting epoch  36


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  3.333792209625244  discriminative loss:  0.42432528734207153
Starting epoch  37


100%|██████████| 967/967 [02:27<00:00,  6.56it/s]


generative loss:  5.659630298614502  discriminative loss:  0.5074780583381653
Starting epoch  38


100%|██████████| 967/967 [02:34<00:00,  6.27it/s]


generative loss:  2.98885440826416  discriminative loss:  0.5041288733482361
Starting epoch  39


100%|██████████| 967/967 [02:30<00:00,  6.41it/s]


generative loss:  4.600791931152344  discriminative loss:  0.38261350989341736
Starting epoch  40


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  2.333014488220215  discriminative loss:  0.5138341784477234
Starting epoch  41


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  2.6655759811401367  discriminative loss:  0.5427317023277283
Starting epoch  42


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  5.278759002685547  discriminative loss:  0.39750105142593384
Starting epoch  43


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  4.153359413146973  discriminative loss:  0.3764536678791046
Starting epoch  44


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  3.2994437217712402  discriminative loss:  0.44719237089157104
Starting epoch  45


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  3.2188687324523926  discriminative loss:  0.44397836923599243
Starting epoch  46


100%|██████████| 967/967 [02:29<00:00,  6.47it/s]


generative loss:  5.033209800720215  discriminative loss:  0.4033500552177429
Starting epoch  47


100%|██████████| 967/967 [02:30<00:00,  6.44it/s]


generative loss:  6.189250469207764  discriminative loss:  0.4111863970756531
Starting epoch  48


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  4.595890998840332  discriminative loss:  0.3822370767593384
Starting epoch  49


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  3.4842872619628906  discriminative loss:  0.43954089283943176
Starting epoch  50


100%|██████████| 967/967 [02:33<00:00,  6.30it/s]


generative loss:  4.583714485168457  discriminative loss:  0.4247097969055176
Starting epoch  51


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  5.0301923751831055  discriminative loss:  0.3915993869304657
Starting epoch  52


100%|██████████| 967/967 [02:36<00:00,  6.20it/s]


generative loss:  5.845663070678711  discriminative loss:  0.5605529546737671
Starting epoch  53


100%|██████████| 967/967 [02:34<00:00,  6.26it/s]


generative loss:  2.7828118801116943  discriminative loss:  0.5290855169296265
Starting epoch  54


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  4.047154426574707  discriminative loss:  0.3648183345794678
Starting epoch  55


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  4.734324932098389  discriminative loss:  0.3587886393070221
Starting epoch  56


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  4.138100624084473  discriminative loss:  0.4719215929508209
Starting epoch  57


100%|██████████| 967/967 [02:24<00:00,  6.68it/s]


generative loss:  5.0411787033081055  discriminative loss:  0.36282435059547424
Starting epoch  58


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  4.6591596603393555  discriminative loss:  0.38219064474105835
Starting epoch  59


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  4.8508195877075195  discriminative loss:  0.43799787759780884
Starting epoch  60


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  4.689577102661133  discriminative loss:  0.38707736134529114
Starting epoch  61


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  3.7616944313049316  discriminative loss:  0.4521394371986389
Starting epoch  62


100%|██████████| 967/967 [02:36<00:00,  6.16it/s]


generative loss:  5.691154479980469  discriminative loss:  0.3762694001197815
Starting epoch  63


100%|██████████| 967/967 [02:36<00:00,  6.19it/s]


generative loss:  6.028400421142578  discriminative loss:  0.3922988772392273
Starting epoch  64


100%|██████████| 967/967 [02:35<00:00,  6.23it/s]


generative loss:  1.9616658687591553  discriminative loss:  0.6257632970809937
Starting epoch  65


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  3.0615885257720947  discriminative loss:  0.6598601341247559
Starting epoch  66


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  3.542686700820923  discriminative loss:  0.5493724346160889
Starting epoch  67


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  3.724461078643799  discriminative loss:  0.3799181282520294
Starting epoch  68


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  4.1081223487854  discriminative loss:  0.4048711657524109
Starting epoch  69


100%|██████████| 967/967 [02:40<00:00,  6.04it/s]


generative loss:  3.144423484802246  discriminative loss:  0.41483423113822937
Starting epoch  70


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  4.14066743850708  discriminative loss:  0.3743675649166107
Starting epoch  71


100%|██████████| 967/967 [02:32<00:00,  6.35it/s]


generative loss:  3.356078624725342  discriminative loss:  0.5736881494522095
Starting epoch  72


100%|██████████| 967/967 [02:32<00:00,  6.35it/s]


generative loss:  4.872702598571777  discriminative loss:  0.3813190460205078
Starting epoch  73


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]


generative loss:  4.429702281951904  discriminative loss:  0.3871072828769684
Starting epoch  74


100%|██████████| 967/967 [03:21<00:00,  4.79it/s]

Training did not finish all 80 epochs: the log ends during epoch 74. This was probably caused by an internet disconnection of the Colab runtime.

In [None]:
%%shell
# Export this notebook to HTML (the output shows it written to /content/Lab09.html).
jupyter nbconvert --to html /content/Lab09.ipynb

[NbConvertApp] Converting notebook /content/Lab09.ipynb to html
[NbConvertApp] ERROR | Notebook JSON is invalid: Additional properties are not allowed (u'metadata' was unexpected)

Failed validating u'additionalProperties' in stream:

On instance[u'cells'][11][u'outputs'][0]:
{u'metadata': {u'tags': None},
 u'name': u'stdout',
 u'output_type': u'stream',
 u'text': u'Starting epoch  0\n'}
[NbConvertApp] Writing 347357 bytes to /content/Lab09.html


