In [1]:
# Boilerplate for TensorFlow 2.3: pin the notebook to one specific GPU and enable
# memory growth so TensorFlow allocates GPU memory on demand instead of
# reserving the whole card up front.
import tensorflow as tf
import random
import numpy as np

# Which physical GPU to use (index into the list of visible GPUs). Adjust per
# machine; this notebook was run on the 10th GPU (index 9).
GPU_INDEX = 9

physical_devices = tf.config.list_physical_devices('GPU')
if GPU_INDEX >= len(physical_devices):
    raise RuntimeError(
        f"GPU_INDEX={GPU_INDEX} but only {len(physical_devices)} GPU(s) are available"
    )
tf.config.set_visible_devices(physical_devices[GPU_INDEX], 'GPU')
# TensorFlow requires memory growth to be configured before the GPU is initialized.
tf.config.experimental.set_memory_growth(physical_devices[GPU_INDEX], True)

In [2]:
import os
from glob import glob
from tensorflow.keras.preprocessing.image import load_img, img_to_array, smart_resize

# Target size passed to PIL's Image.resize, i.e. (width, height); square here.
INPUT_SHAPE = (64, 64)

train_path = os.path.abspath("pokemon_jpg/")
# Sorted so the order of images in `data` is deterministic across runs.
images = sorted(glob(os.path.join(train_path, "*.*")))
if not images:
    raise FileNotFoundError(f"No images found in {train_path}")

# Stack into one array and scale pixels from [0, 255] to [-1, 1], matching the
# range of the generator's tanh output.
data = np.stack([img_to_array(load_img(img).resize(INPUT_SHAPE)) for img in images]) / 127.5 - 1

In [3]:
from keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, LeakyReLU, Dropout, ZeroPadding2D, UpSampling2D
from keras.layers import Layer

from keras.models import Model, Sequential
from keras import backend as K
from keras.optimizers import Adam, RMSprop
from keras.callbacks import ModelCheckpoint 
from keras.utils import plot_model
from keras.initializers import RandomNormal

from functools import partial

import numpy as np
import json
import os
import pickle
import matplotlib.pyplot as plt

In [4]:
# Critic (discriminator): four stride-2 5x5 convolutions, 64x64x3 -> 4x4x128.
discriminator_input_shape = (64, 64, 3)
discriminator_conv_filters = [16, 32, 64, 128]
discriminator_conv_kernel_size = [5] * 4
discriminator_conv_strides = [2] * 4

# Generator: Dense projected to (4, 4, 128), then four x2 upsampling stages,
# ending at 64x64 with 3 channels.
initial_layer_shape = (4, 4, 128)
generator_conv_filters = [64, 32, 16, 3]
generator_conv_kernel_size = [5] * 4
# NOTE: only referenced by the commented-out Conv2DTranspose variant of `generator`.
generator_conv_strides = [1] * 4


def generator(z_dim, initial_layer_shape, generator_conv_filters, generator_conv_kernel_size, generator_conv_strides):
    """Build the upsampling generator.

    A latent vector of length ``z_dim`` is projected by a Dense layer and reshaped
    to ``initial_layer_shape``. Each stage then applies ``UpSampling2D`` followed
    by a "same"-padded ``Conv2D``. Every stage except the last is followed by
    BatchNormalization + LeakyReLU. The final stage ends in tanh, so generated
    pixels lie in [-1, 1].

    Args:
        z_dim: length of the latent input vector.
        initial_layer_shape: shape the Dense output is reshaped to, e.g. (4, 4, 128).
        generator_conv_filters: Conv2D filters per stage; the last entry is the
            number of output image channels.
        generator_conv_kernel_size: Conv2D kernel size per stage; its length sets
            the number of stages.
        generator_conv_strides: unused. Kept for interface compatibility with the
            earlier Conv2DTranspose version; upsampling is now done by UpSampling2D.

    Returns:
        A keras ``Model`` mapping ``(None, z_dim)`` to a batch of generated images.

    Raises:
        ValueError: if no stages are configured (the tanh output would be undefined).
    """
    n_stages = len(generator_conv_kernel_size)
    if n_stages == 0:
        raise ValueError("generator needs at least one conv stage")

    generator_input = Input(shape=(z_dim,))
    x = Dense(np.prod(initial_layer_shape), kernel_initializer="he_normal")(generator_input)
    x = BatchNormalization(momentum=0.8)(x)
    x = Reshape(initial_layer_shape)(x)

    for i in range(n_stages):
        x = UpSampling2D()(x)
        x = Conv2D(
            filters=generator_conv_filters[i],
            kernel_size=generator_conv_kernel_size[i],
            padding="same",
            kernel_initializer="he_normal")(x)
        # Hidden stages get BN + LeakyReLU; the last stage feeds the tanh below.
        if i < n_stages - 1:
            x = BatchNormalization(momentum=0.8)(x)
            x = LeakyReLU(alpha=0.2)(x)

    generator_output = Activation("tanh")(x)
    return Model(generator_input, generator_output)

def discriminator(discriminator_input_shape, discriminator_conv_filters, discriminator_conv_kernel_size, discriminator_conv_strides):
    """Build the WGAN critic.

    A stack of strided "same"-padded Conv2D + LeakyReLU(0.2) layers, flattened into
    a single linear output. There is no sigmoid: the output is an unbounded score,
    as required by the Wasserstein loss.

    Args:
        discriminator_input_shape: (H, W, C) of input images.
        discriminator_conv_filters: filters per conv layer.
        discriminator_conv_kernel_size: kernel size per conv layer; its length sets
            the number of conv layers.
        discriminator_conv_strides: stride per conv layer.

    Returns:
        A keras ``Model`` mapping an image batch to scores of shape ``(None, 1)``.
    """
    discriminator_input = Input(discriminator_input_shape)
    x = discriminator_input

    # Iterate over this function's own config, not the generator's global lists.
    for i in range(len(discriminator_conv_kernel_size)):
        x = Conv2D(filters=discriminator_conv_filters[i],
                   kernel_size=discriminator_conv_kernel_size[i],
                   padding="same",
                   strides=discriminator_conv_strides[i],
                   kernel_initializer="he_normal")(x)
        x = LeakyReLU(alpha=0.2)(x)

    x = Flatten()(x)

    discriminator_output = Dense(1, activation=None, kernel_initializer="he_normal")(x)

    return Model(discriminator_input, discriminator_output)

def wasserstein(y_true, y_pred):
    """Wasserstein (critic) loss: the negated mean of label * score.

    With labels of +1 for real and -1 for generated samples, minimizing this raises
    the critic's score on real samples and lowers it on fakes.
    """
    signed_scores = y_pred * y_true
    return -K.mean(signed_scores)


In [5]:
# Length of the latent vector fed to the generator.
z_dims = 100
# Initial Adam step size for both models (decayed once per epoch inside `train`).
LEARNING_RATE = 0.0002

disc = discriminator(discriminator_input_shape, discriminator_conv_filters, discriminator_conv_kernel_size, discriminator_conv_strides)

# Compile the critic while it is still trainable, so disc.train_on_batch updates it.
disc.compile(optimizer=Adam(learning_rate=LEARNING_RATE),
             loss=wasserstein)

# Freeze the critic before compiling the combined model, so gan.train_on_batch only
# updates the generator. Keras captures `trainable` at compile time, so the
# already-compiled `disc` still trains on its own.
disc.trainable = False

gen = generator(z_dims, initial_layer_shape, generator_conv_filters, generator_conv_kernel_size, generator_conv_strides)

# Combined model: latent vector -> generator -> (frozen) critic score.
gan_input = Input(shape=(z_dims,))
gan_output = disc(gen(gan_input))
gan = Model(gan_input, gan_output)
gan.compile(optimizer=Adam(learning_rate=LEARNING_RATE),
            loss=wasserstein)

In [6]:
# Print the critic's layers, output shapes and parameter counts.
disc.summary()

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 64, 64, 3)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 32, 32, 16)        1216      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 32, 32, 16)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 16, 16, 32)        12832     
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 16, 16, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 8, 8, 64)          51264     
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 8, 8, 64)         

In [7]:
# Print the generator's layers, output shapes and parameter counts.
gen.summary()

Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
dense_1 (Dense)              (None, 2048)              206848    
_________________________________________________________________
batch_normalization (BatchNo (None, 2048)              8192      
_________________________________________________________________
reshape (Reshape)            (None, 4, 4, 128)         0         
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 8, 8, 128)         0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 8, 8, 64)          204864    
_________________________________________________________________
batch_normalization_1 (Batch (None, 8, 8, 64)         

In [8]:
# Print the combined model; the critic's weights should show up as non-trainable here.
gan.summary()

Model: "functional_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
functional_3 (Functional)    (None, 64, 64, 3)         485603    
_________________________________________________________________
functional_1 (Functional)    (None, 1)                 272289    
Total params: 757,892
Trainable params: 481,283
Non-trainable params: 276,609
_________________________________________________________________


In [None]:
import matplotlib.pyplot as plt
import wandb

# Start a Weights & Biases run for logging losses.
wandb.init(project="深層学習_最終課題_upsampling_高解像度")

# Loss histories collected during training (critic, generator).
d_history, g_history = [], []

# Output folders for sample grids and model checkpoints.
save_fig_path, save_model_path = (
    os.path.abspath(folder) for folder in ("gan_pokemon_upsample_2", "model_upsample_2")
)


def train(data, n_epochs=6000, batch_size=256, crip_threshold=0.01, n_critic=5):
    """Train the weight-clipping WGAN on ``data``.

    For each generator update, the critic ``disc`` is first updated ``n_critic``
    times. Each critic step uses a random real batch (label +1) and a generated
    batch (label -1), after which every critic weight is clipped to
    [-crip_threshold, crip_threshold]. The generator is then updated once through
    the combined ``gan`` model with label +1.

    Losses are appended to ``d_history`` / ``g_history`` and logged to wandb. Every
    10 epochs a sample grid and checkpoints are saved. Both learning rates are
    multiplied by 0.999 after each epoch.

    Relies on the notebook globals ``disc``, ``gen``, ``gan``, ``z_dims``,
    ``d_history``, ``g_history``, ``sample_images``, ``save_model`` and ``wandb``.

    Args:
        data: array of training images indexed along axis 0, expected in [-1, 1].
        n_epochs: number of epochs.
        batch_size: samples per critic / generator batch.
        crip_threshold: weight-clipping bound for the critic.
        n_critic: critic updates per generator update.

    Raises:
        ValueError: if ``data`` is empty.
    """
    n_samples = len(data)
    if n_samples == 0:
        raise ValueError("train() needs at least one training image")
    # At least one generator step per epoch, even if the dataset is smaller than a
    # batch (indices are drawn independently, so batches may repeat images).
    batch_per_epoch = max(1, n_samples // batch_size)
    # Wasserstein labels: +1 for real samples, -1 for generated ones.
    valid = np.ones((batch_size, 1))
    fake = -np.ones((batch_size, 1))

    for epoch in range(n_epochs):

        for _ in range(batch_per_epoch):

            for _ in range(n_critic):
                # Sample over axis 0 (the images). len(data[0]) would be the image
                # height and would restrict training to the first few images.
                idx = np.random.randint(0, n_samples, batch_size)
                true_imgs = data[idx]

                z = np.random.normal(0, 1, (batch_size, z_dims))
                gen_imgs = gen.predict(z)

                d_loss_real = disc.train_on_batch(true_imgs, valid)
                d_loss_fake = disc.train_on_batch(gen_imgs, fake)
                d_loss = 0.5 * (d_loss_real + d_loss_fake)

                # Clip every critic weight, as in the original weight-clipping WGAN.
                for layer in disc.layers:
                    layer.set_weights([np.clip(w, -crip_threshold, crip_threshold)
                                       for w in layer.get_weights()])

            z = np.random.normal(0, 1, (batch_size, z_dims))
            g_loss = gan.train_on_batch(z, valid)
            d_history.append(d_loss)
            g_history.append(g_loss)

            print("%d [D loss: (%.3f)(R %.3f, F %.3f)]  [G loss: %.3f] " % (epoch, d_loss, d_loss_real, d_loss_fake, g_loss))

            wandb.log({"d_loss": d_loss,
                       "g_loss": g_loss})

        if epoch % 10 == 0:
            sample_images(epoch)
            save_model(epoch)

        # Per-epoch exponential learning-rate decay for both optimizers.
        K.set_value(disc.optimizer.learning_rate, K.get_value(disc.optimizer.learning_rate) * 0.999)
        K.set_value(gan.optimizer.learning_rate, K.get_value(gan.optimizer.learning_rate) * 0.999)
    
def sample_images(epoch):
    """Save a 5x5 grid of generator samples as ``epoch_{epoch}.png`` in ``save_fig_path``.

    Draws latent vectors from N(0, 1) and maps the generator's tanh output from
    [-1, 1] back to [0, 1] for display. Creates ``save_fig_path`` if it is missing.
    Uses the notebook globals ``gen``, ``z_dims`` and ``save_fig_path``.
    """
    r, c = 5, 5
    z = np.random.normal(0, 1, (r * c, z_dims))
    gen_imgs = gen.predict(z)

    # [-1, 1] -> [0, 1]; clip guards against tiny numerical overshoot.
    gen_imgs = np.clip(0.5 * (gen_imgs + 1), 0, 1)

    fig, axs = plt.subplots(r, c, figsize=(15, 15))
    # axs.flat walks the grid row by row, matching the sample order.
    for ax, img in zip(axs.flat, gen_imgs):
        ax.imshow(np.squeeze(img))
        ax.axis('off')

    # savefig does not create missing directories.
    os.makedirs(save_fig_path, exist_ok=True)
    fig.savefig(os.path.join(save_fig_path, f"epoch_{epoch}.png"))
    plt.close(fig)

def save_model(epoch):
    """Write the critic, generator and combined GAN as .h5 files.

    Files go into ``save_model_path/model_epoch{epoch}/``, which is created if
    needed. Uses the notebook globals ``disc``, ``gen``, ``gan`` and ``save_model_path``.
    """
    target_dir = os.path.join(save_model_path, f"model_epoch{epoch}")
    os.makedirs(target_dir, exist_ok=True)
    checkpoints = (
        (disc, 'discriminator.h5'),
        (gen, 'generator.h5'),
        (gan, 'gan.h5'),
    )
    for model, filename in checkpoints:
        model.save(os.path.join(target_dir, filename))
    
# Launch training on the full preprocessed image array with default settings.
train(data=data)

[34m[1mwandb[0m: Currently logged in as: [33mtomato-ai[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.14 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


0 [D loss: (-0.000)(R -0.001, F 0.001)]  [G loss: -0.001] 
0 [D loss: (-0.000)(R -0.001, F 0.001)]  [G loss: -0.001] 
1 [D loss: (-0.000)(R -0.001, F 0.000)]  [G loss: -0.001] 
1 [D loss: (-0.001)(R -0.002, F 0.001)]  [G loss: -0.001] 
2 [D loss: (-0.001)(R -0.002, F 0.001)]  [G loss: -0.001] 
2 [D loss: (-0.001)(R -0.002, F 0.001)]  [G loss: -0.000] 
3 [D loss: (-0.001)(R -0.002, F 0.000)]  [G loss: -0.000] 
3 [D loss: (-0.001)(R -0.002, F -0.000)]  [G loss: 0.000] 
4 [D loss: (-0.002)(R -0.002, F -0.001)]  [G loss: 0.001] 
4 [D loss: (-0.002)(R -0.003, F -0.002)]  [G loss: 0.001] 
5 [D loss: (-0.003)(R -0.003, F -0.003)]  [G loss: 0.002] 
5 [D loss: (-0.004)(R -0.004, F -0.004)]  [G loss: 0.003] 
6 [D loss: (-0.005)(R -0.005, F -0.005)]  [G loss: 0.006] 
6 [D loss: (-0.005)(R -0.005, F -0.005)]  [G loss: 0.009] 
7 [D loss: (-0.004)(R -0.006, F -0.003)]  [G loss: 0.012] 
7 [D loss: (-0.002)(R -0.006, F 0.001)]  [G loss: 0.013] 
8 [D loss: (-0.000)(R -0.007, F 0.006)]  [G loss: 0.011] 