In [1]:
# Boilerplate for TensorFlow 2.3: pin the notebook to a single GPU and let
# TensorFlow allocate GPU memory on demand instead of reserving it all up front.
import tensorflow as tf
import random
import numpy as np
import matplotlib.pyplot as plt

GPU_INDEX = 4  # physical GPU used on the original training host; adjust for your machine

physical_devices = tf.config.experimental.list_physical_devices('GPU')
if physical_devices:
    if GPU_INDEX < len(physical_devices):
        selected_gpu = physical_devices[GPU_INDEX]
    else:
        # Fall back to the first GPU instead of crashing with IndexError on smaller hosts.
        selected_gpu = physical_devices[0]
        print(f"GPU index {GPU_INDEX} not available; falling back to {selected_gpu.name}")
    tf.config.set_visible_devices(selected_gpu, 'GPU')
    # TensorFlow only accepts this before the GPUs are initialised (i.e. before any model/op is built).
    tf.config.experimental.set_memory_growth(selected_gpu, True)
# With no GPU present, TensorFlow simply runs on the CPU and nothing needs configuring.

In [2]:
import os
from glob import glob
from tensorflow.keras.preprocessing.image import load_img, img_to_array, smart_resize

# Passed to PIL's Image.resize, which takes (width, height); square here, so order does not matter.
INPUT_SHAPE = (64, 64)

train_path = os.path.abspath("pokemon_jpg/")
# Sort so the sample order does not depend on the filesystem's listing order.
images = sorted(glob(os.path.join(train_path, "*.*")))
if not images:
    # np.stack would otherwise fail with an opaque "need at least one array" error.
    raise FileNotFoundError(f"No image files found in {train_path}")

# Load and resize every image, then rescale pixels from [0, 255] to [-1, 1]
# so real images share the range of the generator's tanh output.
data = np.stack([img_to_array(load_img(img).resize(INPUT_SHAPE)) for img in images]) / 127.5 - 1

In [3]:
from keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, LeakyReLU, Dropout, ZeroPadding2D, UpSampling2D, ReLU
from keras.layers import Layer

from keras.models import Model, Sequential
from keras import backend as K
from keras.optimizers import Adam, RMSprop
from keras.callbacks import ModelCheckpoint 
from keras.utils import plot_model
from keras.initializers import RandomNormal

from functools import partial

import numpy as np
import json
import os
import pickle
import matplotlib.pyplot as plt

In [4]:
# Discriminator: four stride-2 conv blocks take the 64x64x3 input down to a 4x4x128 feature map.
discriminator_input_shape = (64, 64, 3)
discriminator_conv_filters = [16 * 2 ** k for k in range(4)]  # [16, 32, 64, 128]
discriminator_conv_kernel_size = [4] * 4
discriminator_conv_strides = [2] * 4

# Generator: a Dense projection reshaped to 4x4x128, then four conv stages
# whose last one emits 3 channels to match the discriminator input.
initial_layer_shape = (4, 4, 128)
generator_conv_filters = [64, 32, 16, 3]
generator_conv_kernel_size = [4] * 4
generator_conv_strides = [1] * 4


def generator(initial_layer_shape, generator_conv_filters, generator_conv_kernel_size, generator_conv_strides, z_dim=100):
    """Build the DCGAN generator: latent vector -> image with values in [-1, 1].

    The latent vector is projected by a Dense layer, batch-normalised and reshaped
    to ``initial_layer_shape``. Each stage then applies UpSampling2D (x2) followed
    by a "same"-padded Conv2D. Every stage except the last is followed by
    BatchNormalization + ReLU, and the network ends in tanh.

    Args:
        initial_layer_shape: (height, width, channels) the Dense projection is reshaped to.
        generator_conv_filters: filters per conv stage; the last entry is the
            number of output image channels.
        generator_conv_kernel_size: kernel size per conv stage.
        generator_conv_strides: stride per conv stage.
        z_dim: length of the latent input vector.

    Returns:
        A keras ``Model`` mapping ``(None, z_dim)`` noise to generated images.

    Raises:
        ValueError: if the per-stage lists are empty or differ in length.
    """
    n_stages = len(generator_conv_kernel_size)
    if n_stages == 0 or not (len(generator_conv_filters) == n_stages == len(generator_conv_strides)):
        raise ValueError(
            "generator_conv_filters, generator_conv_kernel_size and generator_conv_strides "
            "must be non-empty and of equal length")

    generator_input = Input(shape=(z_dim,))
    x = Dense(int(np.prod(initial_layer_shape)))(generator_input)
    # NOTE(review): no nonlinearity follows this BatchNormalization; DCGAN-style
    # generators usually apply ReLU here — confirm this is intentional.
    x = BatchNormalization()(x)
    x = Reshape(initial_layer_shape)(x)

    stages = zip(generator_conv_filters, generator_conv_kernel_size, generator_conv_strides)
    for i, (filters, kernel_size, strides) in enumerate(stages):
        x = UpSampling2D()(x)
        # strides was previously accepted but ignored; with the default all-1 strides this is unchanged.
        x = Conv2D(filters=filters,
                   kernel_size=kernel_size,
                   strides=strides,
                   padding="same")(x)
        if i < n_stages - 1:
            x = BatchNormalization()(x)
            x = ReLU()(x)

    # Applied after the loop so the output is always bound.
    generator_output = Activation("tanh")(x)
    return Model(generator_input, generator_output)

def discriminator(discriminator_input_shape, discriminator_conv_filters, discriminator_conv_kernel_size, discriminator_conv_strides):
    """Build the DCGAN discriminator: image -> sigmoid probability that it is real.

    There is one Conv2D ("same" padding) + LeakyReLU(0.2) block per entry of
    ``discriminator_conv_filters``. Every block except the first gets
    BatchNormalization between the convolution and the activation. The final
    feature map is flattened and fed to a single sigmoid unit.

    Args:
        discriminator_input_shape: shape of one input image.
        discriminator_conv_filters: filters per conv block; its length sets the depth.
        discriminator_conv_kernel_size: kernel size per conv block.
        discriminator_conv_strides: stride per conv block.

    Returns:
        A keras ``Model`` mapping images to a ``(None, 1)`` sigmoid output.
    """
    image_in = Input(discriminator_input_shape)
    features = image_in

    for block_idx, n_filters in enumerate(discriminator_conv_filters):
        features = Conv2D(filters=n_filters,
                          kernel_size=discriminator_conv_kernel_size[block_idx],
                          padding="same",
                          strides=discriminator_conv_strides[block_idx])(features)
        # The first block works directly on pixels and is left without batch norm.
        if block_idx > 0:
            features = BatchNormalization()(features)
        features = LeakyReLU(alpha=0.2)(features)

    flat = Flatten()(features)
    realness = Dense(1, activation="sigmoid")(flat)
    return Model(image_in, realness)


In [5]:
# Build a stand-alone discriminator purely to inspect its layer shapes and parameter counts.
dis = discriminator(
    discriminator_input_shape=discriminator_input_shape,
    discriminator_conv_filters=discriminator_conv_filters,
    discriminator_conv_kernel_size=discriminator_conv_kernel_size,
    discriminator_conv_strides=discriminator_conv_strides,
)
dis.summary()

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 64, 64, 3)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 32, 32, 16)        784       
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 32, 32, 16)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 16, 16, 32)        8224      
_________________________________________________________________
batch_normalization (BatchNo (None, 16, 16, 32)        128       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 16, 16, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 8, 8, 64)         

In [6]:
# Build a stand-alone generator purely to inspect its architecture;
# a later cell rebuilds `gen` for the actual training run.
gen = generator(
    initial_layer_shape=initial_layer_shape,
    generator_conv_filters=generator_conv_filters,
    generator_conv_kernel_size=generator_conv_kernel_size,
    generator_conv_strides=generator_conv_strides,
)
gen.summary()

Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
dense_1 (Dense)              (None, 2048)              206848    
_________________________________________________________________
batch_normalization_3 (Batch (None, 2048)              8192      
_________________________________________________________________
reshape (Reshape)            (None, 4, 4, 128)         0         
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 8, 8, 128)         0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 8, 8, 64)          131136    
_________________________________________________________________
batch_normalization_4 (Batch (None, 8, 8, 64)         

In [7]:
z_dims = 100  # length of the latent noise vector fed to the generator

disc = discriminator(discriminator_input_shape, discriminator_conv_filters, discriminator_conv_kernel_size, discriminator_conv_strides)

# `lr` is a deprecated alias; `learning_rate` is the documented argument name.
disc.compile(optimizer=Adam(learning_rate=0.0002),
             loss="binary_crossentropy")

# Freeze the discriminator only *after* compiling it. `disc` keeps updating its
# weights through its own train_on_batch, while the combined `gan` compiled
# below treats them as non-trainable and only updates the generator.
disc.trainable = False

# Pass z_dims explicitly so the generator input always matches the noise sampled in training.
gen = generator(initial_layer_shape, generator_conv_filters, generator_conv_kernel_size, generator_conv_strides, z_dim=z_dims)

gan_input = Input(shape=(z_dims,))
gan_output = disc(gen(gan_input))
gan = Model(gan_input, gan_output)
# NOTE(review): the generator path uses RMSprop while the discriminator uses Adam — confirm intentional.
gan.compile(optimizer=RMSprop(learning_rate=0.0002),
            loss="binary_crossentropy")

In [8]:
# Summarise the stacked noise -> generator -> discriminator model. The discriminator
# was frozen before `gan` was compiled, so its weights count as non-trainable here.
gan.summary()

Model: "functional_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_5 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
functional_7 (Functional)    (None, 64, 64, 3)         388403    
_________________________________________________________________
functional_5 (Functional)    (None, 1)                 175985    
Total params: 564,388
Trainable params: 384,083
Non-trainable params: 180,305
_________________________________________________________________


In [9]:
import matplotlib.pyplot as plt
import wandb
import keras.backend as K

wandb.init(project="深層学習_最終課題_DCGAN")

# Per-step loss histories for the discriminator and the generator.
d_history, g_history = [], []

# Output locations (resolved against the working directory) for sample grids and model checkpoints.
save_fig_path, save_model_path = (os.path.abspath(p) for p in ("gan_pokemon_dc", "model_dc"))

def make_noisy_label(labels, prob):
    """Return a copy of ``labels`` with a fraction ``prob`` of its rows flipped (x -> 1 - x).

    Used as label noise for the discriminator. Exactly ``int(prob * len(labels))``
    distinct rows are flipped, clipped to ``[0, len(labels)]``. The input array is
    never modified, so callers can safely reuse the same label array every step.

    Args:
        labels: numpy array of 0/1 labels, indexed along the first axis.
        prob: fraction of rows to flip.

    Returns:
        A new array with the same shape and dtype as ``labels``.
    """
    n = labels.shape[0]
    num = min(max(int(prob * n), 0), n)
    # Sample without replacement: picking a row twice would flip it back and silently undo the noise.
    flipped = np.random.choice(n, size=num, replace=False)
    noisy = labels.copy()
    noisy[flipped] = 1 - noisy[flipped]
    return noisy

def train(data, n_epochs=6000, batch_size=256):
    """Train the DCGAN with alternating discriminator / generator updates.

    Each epoch runs ``data.shape[0] // batch_size`` steps (at least one). Each step:
      1. samples a random batch of real images (with replacement) and a batch of
         generated images;
      2. trains ``disc`` on both, using label noise from ``make_noisy_label``;
      3. trains ``gan`` (the generator through the frozen discriminator) on fresh
         noise with target 1.

    Losses are printed, appended to ``d_history`` / ``g_history`` and logged to
    wandb. Every 10 epochs a sample grid and the models are saved. After every
    epoch both learning rates are multiplied by 0.999.

    Args:
        data: numpy array of training images with samples on the first axis
            (the loading cell scales it to [-1, 1]).
        n_epochs: number of epochs to run.
        batch_size: number of real (and of fake) images per step.
    """
    n_samples = data.shape[0]
    # A dataset smaller than one batch would otherwise run zero steps per epoch.
    batch_per_epoch = max(1, n_samples // batch_size)
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(n_epochs):

        for _ in range(batch_per_epoch):

            # Sample over the whole dataset; len(data[0]) would be the height of one image.
            idx = np.random.randint(0, n_samples, batch_size)
            true_imgs = data[idx]

            z = np.random.normal(0, 1, (batch_size, z_dims))
            gen_imgs = gen.predict(z)

            # Pass copies so label noise never leaks into `valid`/`fake`, which are
            # reused every step and `valid` is also the generator's target below.
            noisy_label_true = make_noisy_label(valid.copy(), 0.05)
            d_loss_real = disc.train_on_batch(true_imgs, noisy_label_true)
            noisy_label_fake = make_noisy_label(fake.copy(), 0.05)
            d_loss_fake = disc.train_on_batch(gen_imgs, noisy_label_fake)
            d_loss = 0.5 * (d_loss_real + d_loss_fake)
            d_history.append(d_loss)

            z = np.random.normal(0, 1, (batch_size, z_dims))
            g_loss = gan.train_on_batch(z, valid)
            g_history.append(g_loss)

            print ("%d [D loss: (%.3f)(R %.3f, F %.3f)]  [G loss: %.3f] " % (epoch, d_loss, d_loss_real, d_loss_fake, g_loss))

            wandb.log({"d_loss": d_loss,
                       "g_loss": g_loss,
                       "learning_rate": K.get_value(gan.optimizer.learning_rate)})

        if epoch % 10 == 0:
            sample_images(epoch)
            save_model(epoch)

        # Exponential learning-rate decay applied once per epoch to both optimizers.
        K.set_value(disc.optimizer.learning_rate, K.get_value(disc.optimizer.learning_rate)*0.999)
        K.set_value(gan.optimizer.learning_rate, K.get_value(gan.optimizer.learning_rate)*0.999)
        
    
def sample_images(epoch):
    """Save a 5x5 grid of generator samples to ``save_fig_path/epoch_{epoch}.png``.

    Args:
        epoch: epoch number, used only in the output file name.
    """
    r, c = 5, 5
    z = np.random.normal(0, 1, (r * c, z_dims))
    gen_imgs = gen.predict(z)

    # Map the tanh output from [-1, 1] to [0, 1] for imshow; clip guards rounding overshoot.
    gen_imgs = np.clip(0.5 * (gen_imgs + 1), 0, 1)

    fig, axs = plt.subplots(r, c, figsize=(15, 15))
    # axs.flat walks the grid row by row.
    for cnt, ax in enumerate(axs.flat):
        ax.imshow(np.squeeze(gen_imgs[cnt]))
        ax.axis('off')

    # Nothing else creates this directory, and savefig fails if it is missing.
    os.makedirs(save_fig_path, exist_ok=True)
    fig.savefig(os.path.join(save_fig_path, f"epoch_{epoch}.png"))
    # Close this specific figure so figures do not pile up during a long training run.
    plt.close(fig)

def save_model(epoch):
    """Checkpoint disc, gen and gan as .h5 files under ``save_model_path/model_epoch{epoch}``.

    Args:
        epoch: epoch number, used to name the checkpoint folder (created if needed).
    """
    save_folder = os.path.join(save_model_path, f"model_epoch{epoch}")
    os.makedirs(save_folder, exist_ok=True)
    checkpoints = ((disc, 'discriminator.h5'), (gen, 'generator.h5'), (gan, 'gan.h5'))
    for model, filename in checkpoints:
        model.save(os.path.join(save_folder, filename))
    
# Launch training with the default schedule, spelled out explicitly.
train(data, n_epochs=6000, batch_size=256)

[34m[1mwandb[0m: Currently logged in as: [33mtomato-ai[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.14 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


0 [D loss: (1.018)(R 1.323, F 0.712)]  [G loss: 0.688] 
0 [D loss: (0.617)(R 0.546, F 0.687)]  [G loss: 0.683] 
1 [D loss: (0.535)(R 0.422, F 0.648)]  [G loss: 0.682] 
1 [D loss: (0.502)(R 0.460, F 0.543)]  [G loss: 0.688] 
2 [D loss: (0.551)(R 0.536, F 0.566)]  [G loss: 0.692] 
2 [D loss: (0.635)(R 0.554, F 0.716)]  [G loss: 0.695] 
3 [D loss: (0.726)(R 0.667, F 0.785)]  [G loss: 0.693] 
3 [D loss: (0.727)(R 0.668, F 0.786)]  [G loss: 0.692] 
4 [D loss: (0.753)(R 0.668, F 0.837)]  [G loss: 0.690] 
4 [D loss: (0.754)(R 0.698, F 0.811)]  [G loss: 0.684] 
5 [D loss: (0.773)(R 0.720, F 0.827)]  [G loss: 0.682] 
5 [D loss: (0.711)(R 0.677, F 0.746)]  [G loss: 0.681] 
6 [D loss: (0.715)(R 0.687, F 0.743)]  [G loss: 0.681] 
6 [D loss: (0.726)(R 0.698, F 0.755)]  [G loss: 0.684] 
7 [D loss: (0.719)(R 0.689, F 0.749)]  [G loss: 0.681] 
7 [D loss: (0.729)(R 0.683, F 0.775)]  [G loss: 0.684] 
8 [D loss: (0.743)(R 0.685, F 0.801)]  [G loss: 0.684] 
8 [D loss: (0.758)(R 0.719, F 0.797)]  [G loss: 