# Danmen-GAN

断面二次モーメントを上げ下げできるヤツ

<img src="https://raw.githubusercontent.com/p-geon/DanmenGAN/master/Gainen.png"></img>


# 基本設定

In [0]:
# Mount data
# Mount Google Drive and move into the experiment directory (Colab-specific).
from google.colab import drive
drive.mount('/content/gdrive')

%cd /content/gdrive/My\ Drive/GAN/exp

# Pin the TensorFlow / TensorFlow Addons versions this notebook was written against.
! pip install tensorflow==2.2
! pip install tensorflow_addons==0.6.0 --no-deps

import sys # for version checks (Params records sys.version)
import datetime # used to timestamp the parameter / image file names

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

from skimage.io import imsave
from skimage import img_as_ubyte # convert float images to uint8 before saving (avoids a warning)

print(tf.__version__)

# StealthFlow provides the FID helpers used below.
! pip install StealthFlow==0.0.13
from stealthflow.fid import FIDNumpy, FIDTF

# 断面二次モーメントを求めるグラフ

- `calc_second_moment_of_area`: 画像 (`4-D Tensor`, 形状 `(batch, H, W, 1)`) を入力とし、`I_x`, `I_y`, `I_r` (それぞれ形状 `(batch, 1)` の Tensor) を出力するグラフ
- `calc_second_moment_of_area_for_generated_images`: 上の関数を `tf.function` でラップし、モデルの外側 (生成画像の確認時など) から呼べるようにしたもの

In [0]:
class SecondMomentOfArea:
    """Differentiable area-moment statistics of batched grayscale images.

    Pixel intensities are treated as area density: a pixel with value 0.5
    contributes half the area of a fully lit pixel. Pixel centres sit at
    0.5, 1.5, ..., size - 0.5, so a fully lit image has its centroid at the
    geometric centre of the image (e.g. (14, 14) for 28x28).

    NOTE(review): despite the name, each pixel is weighted by the *absolute*
    distance |d| to the neutral axis, not by d**2 as in the textbook
    definition I_x = ∫_A y^2 dA. The normalisation constants use the same
    |d| form, so values are relative to a fully lit image.
    """

    def __init__(self, img_shape=(28, 28)):
        """Precompute coordinate grids and normalisation constants.

        Args:
            img_shape: (height, width) of the images later passed to
                `calc_second_moment_of_area`. Non-square shapes are supported.
        """
        height, width = img_shape[0], img_shape[1]

        # Pixel-centre coordinates: [0.5, 1.5, ..., size - 0.5].
        row_coords = 0.5 + np.arange(height)  # vertical (y) position of each row
        col_coords = 0.5 + np.arange(width)   # horizontal (x) position of each column

        # (height, width) grids holding the x / y coordinate of every pixel.
        distance_matrix_x = np.tile(col_coords, (height, 1))
        distance_matrix_y = np.tile(row_coords[:, np.newaxis], (1, width))

        # Normalisation = moment of a fully lit image, whose neutral axes pass
        # through the image centre (same |d| weighting as below).
        norm_I_x = width * np.sum(np.abs(row_coords - height / 2.0))
        norm_I_y = height * np.sum(np.abs(col_coords - width / 2.0))

        # Store everything as float32 TF constants.
        self.arange_x = tf.constant(col_coords, dtype=tf.float32)  # (width,)
        self.arange_y = tf.constant(row_coords, dtype=tf.float32)  # (height,)
        self.distance_matrix_x = tf.constant(distance_matrix_x[np.newaxis, :, :, np.newaxis], dtype=tf.float32)  # (1, H, W, 1)
        self.distance_matrix_y = tf.constant(distance_matrix_y[np.newaxis, :, :, np.newaxis], dtype=tf.float32)  # (1, H, W, 1)
        self.norm_I_x = tf.constant(norm_I_x, dtype=tf.float32)  # ()
        self.norm_I_y = tf.constant(norm_I_y, dtype=tf.float32)  # ()

    def calc_second_moment_of_area(self, img):
        """Compute the normalised area moments of a batch of images.

        Args:
            img: 4-D tensor of shape (batch, height, width, 1) whose spatial
                size matches `img_shape`. Values are used as area density;
                callers in this notebook pass images scaled to [0, 1].

        Returns:
            Tuple (I_x, I_y, I_r), each of shape (batch, 1):
            I_x -- moment about the horizontal neutral axis (vertical spread),
            I_y -- moment about the vertical neutral axis (horizontal spread),
            I_r -- polar moment, (I_x + I_y) / 2.
            Each is passed through an identity Lambda layer named
            "I_x" / "I_y" / "I_r" purely to give it a readable name.
        """
        height, width = img.shape[1], img.shape[2]

        # Neutral axes = intensity-weighted centroid. `density` is the mean
        # intensity, so mean(coord * img) / density == sum(coord * img) / sum(img).
        # divide_no_nan yields 0 instead of NaN for an all-black image.
        density = tf.reduce_sum(img, axis=[1, 2], keepdims=True) / (height * width)  # (batch, 1, 1, 1)
        x_moment = tf.math.divide_no_nan(tf.math.multiply(self.distance_matrix_x, img), density)
        y_moment = tf.math.divide_no_nan(tf.math.multiply(self.distance_matrix_y, img), density)
        neutral_axis_x = tf.math.reduce_mean(x_moment, axis=[1, 2])  # (batch, 1)
        neutral_axis_y = tf.math.reduce_mean(y_moment, axis=[1, 2])  # (batch, 1)

        # I_x: every pixel weighted by its vertical distance to the
        # horizontal neutral axis. (batch, H) -> (batch, H, 1, 1), which
        # broadcasts over the columns of `img`.
        dy = tf.math.abs(self.arange_y - neutral_axis_y)
        dy = tf.reshape(dy, shape=[-1, height, 1, 1])
        I_x = tf.math.reduce_sum(tf.math.multiply(dy, img), axis=[1, 2]) / self.norm_I_x  # (batch, 1)

        # I_y: every pixel weighted by its horizontal distance to the
        # vertical neutral axis. (batch, W) -> (batch, 1, W, 1).
        dx = tf.math.abs(self.arange_x - neutral_axis_x)
        dx = tf.reshape(dx, shape=[-1, 1, width, 1])
        I_y = tf.math.reduce_sum(tf.math.multiply(dx, img), axis=[1, 2]) / self.norm_I_y  # (batch, 1)

        # Polar moment; divided by 2 so it stays on the same scale as I_x / I_y.
        I_r = (I_x + I_y) / 2.0

        # Identity layers that only attach a readable name to each tensor.
        I_x = tf.keras.layers.Lambda(lambda t: t, name="I_x")(I_x)
        I_y = tf.keras.layers.Lambda(lambda t: t, name="I_y")(I_y)
        I_r = tf.keras.layers.Lambda(lambda t: t, name="I_r")(I_r)

        return I_x, I_y, I_r

    @tf.function
    def calc_second_moment_of_area_for_generated_images(self, img):
        """Graph-compiled wrapper around `calc_second_moment_of_area`.

        Used outside the Keras models, e.g. to check the moments of the
        generated sample images.
        """
        return self.calc_second_moment_of_area(img)

# GAN

Generator, Discriminator, Combined の三つで構成されている。

In [0]:
def build_generator(params, smoa):
    """Build the generator: noise -> (image, I_x, I_y, I_r).

    Args:
        params: object providing NOISE_DIM (length of the noise vector).
        smoa: SecondMomentOfArea instance used to compute the moments of the
            generated image inside the model graph.

    Returns:
        tf.keras.Model with input "noise" of shape (NOISE_DIM,) and outputs
        [generated_image (28, 28, 1) in [-1, 1], I_x, I_y, I_r].
    """
    z_in = tf.keras.layers.Input(shape=(params.NOISE_DIM, ), name="noise")

    # (NOISE_DIM, ) -> (1024, )
    x = tf.keras.layers.Dense(1024)(z_in)
    x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
    x = tf.keras.layers.BatchNormalization(momentum=0.8)(x)

    # (1024, ) -> (7*7*64, ) -> (7, 7, 64)
    # Feed the 1024-unit block's output (x), not the raw noise; otherwise
    # that block is disconnected from the outputs and never used.
    x = tf.keras.layers.Dense(7*7*64)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
    x = tf.keras.layers.BatchNormalization(momentum=0.8)(x)
    x = tf.keras.layers.Reshape(target_shape=(7, 7, 64))(x)

    # (7, 7, 64) -> (14, 14, 32)
    x = tf.keras.layers.Conv2DTranspose(32, kernel_size=(5, 5)
        , padding='same', strides=(2, 2), use_bias=False, activation=None)(x)
    x = tf.keras.layers.BatchNormalization(momentum=0.8)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)

    # (14, 14, 32) -> (28, 28, 1), squashed to [-1, 1] by tanh
    x = tf.keras.layers.Conv2DTranspose(1, kernel_size=(5, 5)
        , padding='same', strides=(2, 2), use_bias=False, activation=None)(x)
    img = tf.math.tanh(x)
    # Named identity layer for the image output (`img` itself is reused below).
    y = tf.keras.layers.Lambda(lambda t: t, name="generated_image")(img)

    # Moments of the generated image, computed as a side branch of the graph.
    # Pixel intensity is treated as area, i.e. faint strokes count as less
    # area. This differs from a real print, but is accepted here.
    img01 = (img + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    I_x, I_y, I_r = smoa.calc_second_moment_of_area(img01)

    return tf.keras.Model(inputs=z_in, outputs=[y, I_x, I_y, I_r])

def build_discriminator():
    """Build the discriminator: (28, 28, 1) image -> probability it is real.

    Two stride-2 5x5 convolutions (32 then 64 filters) with LeakyReLU, a
    1024-unit dense layer with LeakyReLU, and a single sigmoid output.
    """
    image_in = tf.keras.layers.Input(shape=(28, 28, 1))

    # (28, 28, 1) -> (14, 14, 32) -> (7, 7, 64)
    h = image_in
    for filters in (32, 64):
        h = tf.keras.layers.Conv2D(filters, kernel_size=(5, 5), strides=(2, 2), padding='same')(h)
        h = tf.keras.layers.LeakyReLU(alpha=0.2)(h)

    # (7, 7, 64) -> (7*7*64, ) -> (1024, )
    h = tf.keras.layers.Reshape(target_shape=(7*7*64, ))(h)
    h = tf.keras.layers.Dense(1024)(h)
    h = tf.keras.layers.LeakyReLU(alpha=0.2)(h)

    # (1024, ) -> (1, ), mapped into [0, 1]
    logit = tf.keras.layers.Dense(1)(h)
    probability = tf.math.sigmoid(logit)

    return tf.keras.Model(inputs=image_in, outputs=probability)
    
#-----------------------------------------------------------------------------------------------------------------------------

class GAN:
    """Owns the discriminator, the generator and the combined training model.

    The discriminator is compiled on its own, then set to
    `trainable = False` before the combined model (noise -> G -> D) is built
    and compiled.
    """

    def __init__(self, params, smoa):
        self.params = params
        self.smoa = smoa
        self.build_model()

    @staticmethod
    def _named_identity(tensor, name):
        """Pass `tensor` through an identity Lambda layer called `name` (used as a loss key)."""
        return tf.keras.layers.Lambda(lambda t: t, name=name)(tensor)

    def build_model(self):
        """Create and compile `discriminator`, `generator` and `combined`."""
        cfg = self.params

        self.discriminator = build_discriminator()
        self.discriminator.compile(
            loss=cfg.discriminator_loss,
            optimizer=cfg.discriminator_optimizer,
            metrics=cfg.discriminator_metrics,
        )
        # Frozen before the combined model is built, after D's own compile.
        self.discriminator.trainable = False

        self.generator = build_generator(cfg, self.smoa)
        self.combined = self.build_combined()
        self.combined.compile(
            loss=cfg.generator_loss,
            optimizer=cfg.generator_optimizer,
            metrics=cfg.generator_metrics,
            loss_weights=cfg.generator_loss_weights,
        )

    def build_combined(self):
        """Return the model noise -> [possibility, I_x, I_y, I_r].

        Each output goes through a named identity layer so that the loss,
        metric and loss-weight dictionaries can refer to it by name.
        """
        z_in = tf.keras.layers.Input(shape=(self.params.NOISE_DIM, ))
        generated = self.generator(z_in)

        image = self._named_identity(generated[0], "generated_image")
        moments = [
            self._named_identity(tensor, label)
            for tensor, label in zip(generated[1:], ("I_x", "I_y", "I_r"))
        ]
        possibility = self._named_identity(self.discriminator(image), "possibility")

        return tf.keras.Model(inputs=z_in, outputs=[possibility, *moments])

# Custom Metrics: Second Moment of Area

断面二次モーメントを、`tf.keras` の評価時に出力できるようにしたもの。

In [0]:
class MetricsAverageSecondMomentOfArea(tf.keras.metrics.Mean):
    """Average of a second-moment-of-area model output (I_x / I_y / I_r).

    Only `y_pred` is averaged; `y_true` (the target vector) is ignored.
    Building on `tf.keras.metrics.Mean` keeps a total/count pair, so the
    result is a true average even when the metric accumulates over several
    batches. When it is reset after every batch it equals that batch's mean.
    """

    def __init__(self, name="average_second_moment_of_area", **kwargs):
        super(MetricsAverageSecondMomentOfArea, self).__init__(name=name, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate `y_pred` (optionally weighted); `y_true` is unused."""
        del y_true  # only the predicted moment is of interest
        y_pred = tf.cast(y_pred, self.dtype)
        return super(MetricsAverageSecondMomentOfArea, self).update_state(y_pred, sample_weight=sample_weight)

# Custom Metrics: Frechet Inception Distance

一行で FID を計算するツールを作ったので、それを使った。

詳しくは https://note.com/hyper_pigeon/n/n9c5643413cd7 に書いてある。便利。

In [0]:
# Module-level FID calculator (batch size 50, scaling=True).
# NOTE(review): not referenced elsewhere in this notebook section —
# calc_fid_for_generator constructs its own FIDNumpy instance.
calc_fid = FIDNumpy(batch_size=50, scaling=True)

# generate samples / calc FID

- `generate_samples`: 結果確認用の画像生成とその保存、ついでにその断面二次モーメントも計算して結果の確認。`calc_second_moment_of_area_for_generated_images` はここで叩く。

- `calc_fid_for_generator`: FID の計算。とにかく遅い。


In [0]:
def _tile_images(batch):
    """Arrange a (N, H, W, 1) batch into one 2-D image on a square grid.

    The grid has ceil(sqrt(N)) cells per side (at least 1) and is filled
    row by row; unused cells stay 0 (black). E.g. [img1, img2, img3] becomes
        [img1, img2,
         img3, 0   ]
    """
    num_images, img_h, img_w = batch.shape[0], batch.shape[1], batch.shape[2]
    length = max(1, int(np.ceil(np.sqrt(num_images))))
    img_buffer = np.zeros(shape=[length * img_h, length * img_w], dtype=np.float32)
    for idx in range(num_images):
        row, col = divmod(idx, length)
        img_buffer[row * img_h:(row + 1) * img_h, col * img_w:(col + 1) * img_w] = batch[idx, :, :, 0]
    return img_buffer


def generate_samples(params, GAN, iteration):
    """Generate images from the fixed noise, print their moments, show and save them.

    Args:
        params: provides FIXED_NOISE_FOR_PREDICT, BATCH_SIZE and
            EXPERIMENTAL_NAME (used in the output file name).
        GAN: object with `generator` (returns [images, I_x, I_y, I_r], images
            in [-1, 1]) and `smoa` (SecondMomentOfArea).
        iteration: current iteration number, used in the output file name.
    """
    # Only the image output is used; the moments are recomputed below
    # through the standalone tf.function to exercise that code path.
    batch = GAN.generator.predict(params.FIXED_NOISE_FOR_PREDICT, batch_size=params.BATCH_SIZE)[0]
    batch = batch / 2.0 + 0.5  # [-1.0, 1.0] -> [0.0, 1.0]

    I_x, I_y, I_r = GAN.smoa.calc_second_moment_of_area_for_generated_images(tf.convert_to_tensor(batch))
    print(f"I_x:{I_x.numpy().mean():.3f}, I_y:{I_y.numpy().mean():.3f}, I_r:{I_r.numpy().mean():.3f}")

    img_buffer = _tile_images(batch)

    # Display, then save as 8-bit PNG.
    plt.imshow(img_buffer, cmap = "gray", vmin=0.0, vmax=1.0)
    plt.show()
    imsave(fname=f"_{params.EXPERIMENTAL_NAME}_{iteration}_image.png", arr=img_as_ubyte(img_buffer))

def calc_fid_for_generator(params, GAN, MNIST, num_data=5000):
    """Compute the FID between generated images and real MNIST training images.

    Args:
        params: provides BATCH_SIZE and NOISE_DIM.
        GAN: object whose `generator.predict` returns [images, I_x, I_y, I_r]
            with images in [-1, 1].
        MNIST: object whose `train_X` holds real images in [-1, 1].
        num_data: number of real images (and of generated images) compared.

    Returns:
        The value returned by `FIDNumpy`.
    """
    import time
    start_time = time.time()

    # Both sets are rescaled from [-1, 1] to [0, 1].
    orig = (MNIST.train_X[:num_data] / 2.0 + 0.5).astype(np.float32)
    num_data = orig.shape[0]  # train_X may hold fewer images than requested
    print(orig.shape, orig.dtype)

    # Same shape/dtype as the real set. Every slot is filled, including a
    # final partial batch, so no all-zero images leak into the FID.
    gens = np.zeros(shape=orig.shape, dtype=np.float32)
    for start in range(0, num_data, params.BATCH_SIZE):
        stop = min(start + params.BATCH_SIZE, num_data)
        noise = np.random.normal(0, 1, (stop - start, params.NOISE_DIM))
        generated = GAN.generator.predict(noise, batch_size=params.BATCH_SIZE)[0]
        gens[start:stop] = generated / 2.0 + 0.5

    fid = FIDNumpy(batch_size=params.BATCH_SIZE, scaling=True)(gens, orig)
    print("FID", fid)
    print(f"Spent time: {time.time()-start_time}[s]")
    return fid

# Loss Utils

ロスなどの値を毎イテレーション記録し、一定のスパンごとにその平均と標準偏差を求めてグラフにする。FID は毎イテレーション計算すると時間がかかりすぎるので、グラフ作成時にのみ計算する。

全然いい書き方が思いつかない。

In [0]:
class Loss:
    """Collect per-iteration training statistics and plot span summaries.

    `dump_loss_dicts` is called once per training iteration. `show_loss` is
    called once per reporting span: it reduces the values collected since the
    previous call to a mean and a standard deviation per series, then draws a
    six-panel figure (accuracy, log-loss, MSE, I_x/I_y/I_r, FID, I/FID) and
    saves it. FID is supplied separately through `dump_FID`, because it is
    too slow to compute every iteration.
    """

    # Per-iteration series. For each name N the instance keeps the raw values
    # of the current span in `self.N`, and one value per span in
    # `self.mean_N` / `self.std_N`.
    _SERIES = (
        "acc_D_real", "acc_D_fake",
        "loss_D_real", "loss_D_fake", "loss_G_fake",
        "MSE_I_x", "MSE_I_y", "MSE_I_r",
        "I_x", "I_y", "I_r",
    )

    # (moment, errorbar fmt, plot linestyle, colour) shared by the moment panels.
    _MOMENT_STYLES = (
        ("I_x", '.-', '-', 'steelblue'),
        ("I_y", '.:', ':', 'salmon'),
        ("I_r", '.-.', '-.', 'dimgray'),
    )

    def __init__(self, params):
        """
        Args:
            params: provides EXPERIMENTAL_NAME (figure file name) and
                ITERATION_SPAN (x-axis tick spacing).
        """
        self.iteration = []  # iteration number of every show_loss call
        self.__initialize_losses()
        self.FID_list = []  # one FID value per show_loss call

        # Shared plot styling.
        self.capsize = 7
        self.markersize = 7
        self.alpha = 0.5
        self.params = params

    def dump_FID(self, FID):
        """Record the FID computed for the current span."""
        self.FID_list.append(FID)

    def dump_loss_dicts(self, D_loss_real_dict, D_loss_fake_dict, G_loss_dict):
        """Record one iteration's values.

        Args:
            D_loss_real_dict / D_loss_fake_dict: discriminator results on real /
                generated images, with keys "accuracy" and "loss".
            G_loss_dict: combined-model results with keys "possibility_loss",
                "I_{x,y,r}_mean_squared_error" and
                "I_{x,y,r}_average_second_moment_of_area".
        """
        self.acc_D_real.append(D_loss_real_dict["accuracy"])
        self.acc_D_fake.append(D_loss_fake_dict["accuracy"])

        self.loss_D_real.append(D_loss_real_dict["loss"])
        self.loss_D_fake.append(D_loss_fake_dict["loss"])
        self.loss_G_fake.append(G_loss_dict["possibility_loss"])

        self.MSE_I_x.append(G_loss_dict["I_x_mean_squared_error"])
        self.MSE_I_y.append(G_loss_dict["I_y_mean_squared_error"])
        self.MSE_I_r.append(G_loss_dict["I_r_mean_squared_error"])

        self.I_x.append(G_loss_dict["I_x_average_second_moment_of_area"])
        self.I_y.append(G_loss_dict["I_y_average_second_moment_of_area"])
        self.I_r.append(G_loss_dict["I_r_average_second_moment_of_area"])

    def show_loss(self, iteration):
        """Summarise the current span, then draw, save and show the figure.

        `dump_FID` must already have been called for this span: the FID
        panels plot `FID_list` against `self.iteration`, which must have the
        same length.

        Args:
            iteration: current iteration number; appended to the x-axis data
                and used in the saved file name.
        """
        self.iteration.append(iteration)
        self.__summarize_losses()

        fig = plt.figure(figsize=(8,16), dpi=108)

        # GAN accuracy, in percent
        ax = self.__arrange_graph(fig.add_subplot(6, 1, 1))
        ax.set_ylim([0.0, 100.0])
        ax.set_ylabel("%")
        self.__errorbar(ax, 100.0*np.asarray(self.mean_acc_D_real), 100.0*np.asarray(self.std_acc_D_real),
                        '.-', 'teal', "Accuracy:D_real_to_real")
        self.__errorbar(ax, 100.0*np.asarray(self.mean_acc_D_fake), 100.0*np.asarray(self.std_acc_D_fake),
                        '.:', 'lightcoral', "Accuracy:D_fake_to_fake")
        ax.legend()

        # GAN loss on a log scale
        ax = self.__arrange_graph(fig.add_subplot(6, 1, 2))
        ax.set_ylim([-5.0, 5.0])
        ax.set_ylabel("log-loss")
        for mean, std, fmt, color, label in (
            (self.mean_loss_D_real, self.std_loss_D_real, '.-', 'teal', "loss:D_real_to_real"),
            (self.mean_loss_D_fake, self.std_loss_D_fake, '.:', 'lightcoral', "loss:D_fake_to_fake"),
            (self.mean_loss_G_fake, self.std_loss_G_fake, '.-.', 'orangered', "loss:G_fake_to_real"),
        ):
            log_mean, log_err = self.__to_log_scale(mean, std)
            self.__errorbar(ax, log_mean, log_err, fmt, color, label)
        ax.legend()

        # MSE of I_x / I_y / I_r against their targets
        ax = self.__arrange_graph(fig.add_subplot(6, 1, 3))
        ax.set_ylim([0.0, 1.0])
        ax.set_ylabel("MSE")
        for moment, fmt, _, color in self._MOMENT_STYLES:
            self.__errorbar(ax, getattr(self, "mean_MSE_" + moment), getattr(self, "std_MSE_" + moment),
                            fmt, color, "MSE:" + moment)
        ax.legend()

        # Average I_x / I_y / I_r
        ax = self.__arrange_graph(fig.add_subplot(6, 1, 4))
        ax.set_ylim([0.0, 0.25])
        ax.set_ylabel("SMoA")
        for moment, fmt, _, color in self._MOMENT_STYLES:
            self.__errorbar(ax, getattr(self, "mean_" + moment), getattr(self, "std_" + moment),
                            fmt, color, "Ave:" + moment)
        ax.legend()

        # FID
        ax = fig.add_subplot(6, 1, 5)
        ax.set_ylim([0.0, 200.0])
        ax = self.__arrange_graph(ax)
        ax.set_ylabel("FID")
        ax.plot(self.iteration, self.FID_list, marker='.', linestyle="-" , markersize=self.markersize, color='steelblue', label="FID")
        ax.legend()

        # I_x, I_y, I_r divided by FID
        ax = self.__arrange_graph(fig.add_subplot(6, 1, 6))
        ax.set_ylim([0.0, 0.003])
        ax.set_ylabel("I/FID")
        fid = np.asarray(self.FID_list)
        for moment, _, linestyle, color in self._MOMENT_STYLES:
            ax.plot(self.iteration, np.asarray(getattr(self, "mean_" + moment)) / fid, marker='.',
                    linestyle=linestyle, markersize=self.markersize, color=color, label=moment + "/FID")
        ax.legend()

        # File name uses the current iteration (not the whole history list).
        plt.savefig(f'{self.params.EXPERIMENTAL_NAME}-{iteration}.png')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures over a long run

    def __errorbar(self, ax, mean, std, fmt, color, label):
        """Draw one mean ± std series against `self.iteration` with the shared styling."""
        ax.errorbar(self.iteration, mean, yerr=std, capsize=self.capsize, fmt=fmt,
                    markersize=self.markersize, ecolor=color, markeredgecolor=color,
                    color=color, alpha=self.alpha, label=label)

    @staticmethod
    def __to_log_scale(mean, std):
        """Return (log(mean), std / mean).

        std / mean is the first-order (delta-method) standard deviation of
        log(L); unlike log(std) it is never negative.
        """
        mean = np.asarray(mean, dtype=np.float64)
        return np.log(mean), np.asarray(std, dtype=np.float64) / mean

    def __arrange_graph(self, ax):
        """Apply the common x-axis label, limits and tick spacing; return `ax`."""
        ax.set_xlabel('iteration')
        ax.set_xlim(-0.5, 10+self.iteration[-1])
        ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(self.params.ITERATION_SPAN))
        return ax

    def __initialize_losses(self):
        """Create empty raw, mean and std lists for every series."""
        self.__flush_losses()
        for name in self._SERIES:
            setattr(self, "mean_" + name, [])
            setattr(self, "std_" + name, [])

    def __flush_losses(self):
        """Clear the raw per-iteration values of the current span."""
        for name in self._SERIES:
            setattr(self, name, [])

    def __summarize_losses(self):
        """Append mean/std of every series for the finished span, then clear the raw values."""
        for name in self._SERIES:
            values = getattr(self, name)
            getattr(self, "mean_" + name).append(np.mean(values))
            getattr(self, "std_" + name).append(np.std(values))

        print(self.mean_acc_D_real, self.std_acc_D_real)

        self.__flush_losses()

    def calc_stat(self):
        """Summarise the whole run.

        Returns:
            dict with max/min/mean of the per-span mean I_x, I_y, I_r, max/min
            FID, and max/min of each per-span mean moment divided by that
            span's FID.
        """
        fid = np.asarray(self.FID_list)
        I_x_FID = np.asarray(self.mean_I_x) / fid
        I_y_FID = np.asarray(self.mean_I_y) / fid
        I_r_FID = np.asarray(self.mean_I_r) / fid
        return {
            "max_I_x": np.max(self.mean_I_x),
            "min_I_x": np.min(self.mean_I_x),
            "max_I_y": np.max(self.mean_I_y),
            "min_I_y": np.min(self.mean_I_y),
            "max_I_r": np.max(self.mean_I_r),
            "min_I_r": np.min(self.mean_I_r),

            "mean_I_x": np.mean(self.mean_I_x),
            "mean_I_y": np.mean(self.mean_I_y),
            "mean_I_r": np.mean(self.mean_I_r),

            "max_FID": np.max(self.FID_list),
            "min_FID": np.min(self.FID_list),

            "max_I_x_FID": np.max(I_x_FID),
            "max_I_y_FID": np.max(I_y_FID),
            "max_I_r_FID": np.max(I_r_FID),

            "min_I_x_FID": np.min(I_x_FID),
            "min_I_y_FID": np.min(I_y_FID),
            "min_I_r_FID": np.min(I_r_FID),
        }

# Train

- MNIST: MNIST のデータを格納したクラス
- Params: パラメータを格納したクラス

- trainer: 学習の本体

In [0]:
class MNIST:
    """MNIST training images scaled to [-1, 1], plus a batched tf.data pipeline.

    Only `train_X` is used. The test split could later be used, e.g. to
    compare the strength of the second moments of area.
    """

    def __init__(self, params):
        """
        Args:
            params: provides BATCH_SIZE and NUM_EPOCHS.
        """
        (train_X, _), (_, _) = tf.keras.datasets.mnist.load_data()
        train_X = train_X.astype(np.float32).reshape((-1, 28, 28, 1)) / 255.0
        self.train_X = train_X * 2.0 - 1.0  # range: [0.0, 1.0] -> [-1.0, 1.0]

        # Shuffle individual images before batching (reshuffled on every
        # repetition), so batch composition changes from epoch to epoch.
        # drop_remainder keeps every batch exactly BATCH_SIZE images, the
        # size of the fixed label arrays used during training.
        self.tfdata = (
            tf.data.Dataset.from_tensor_slices(self.train_X)
            .shuffle(buffer_size=self.train_X.shape[0])
            .batch(params.BATCH_SIZE, drop_remainder=True)
            .repeat(params.NUM_EPOCHS)
        )

class Params:
    """Hyper-parameters and fixed arrays for one experiment.

    A Markdown dump of the settings is written to
    `_{EXPERIMENTAL_NAME}_params.md` during construction, before the large
    arrays (label vectors, fixed noise) are added.
    """

    def __init__(self):
        self.tf_version = tf.__version__
        self.sys_versionm = sys.version  # (sic) name kept: it is a key in saved params files

        self.ITERATION_SPAN = 4000  # save/plot/evaluate every this many iterations
        self.NUM_EPOCHS = 20
        self.BATCH_SIZE  = 50
        # NOTE(review): equals iterations per epoch for 60k images; not used in this section.
        self.MAX_ITERATION = 60000//self.BATCH_SIZE

        self.NOISE_DIM   = 128

        self.discriminator_loss = 'binary_crossentropy'
        self.discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        self.discriminator_metrics = ['accuracy']

        # Keys match the names of the combined model's outputs.
        self.generator_loss = {'possibility': 'binary_crossentropy', 'I_x': 'mean_squared_error', 'I_y': 'mean_squared_error', "I_r": "mean_squared_error"}
        self.generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        self.generator_metrics = ['accuracy' , 'mean_squared_error', MetricsAverageSecondMomentOfArea()]

        self.generator_loss_weights = {"possibility": 1.0, "I_x": 0.0, "I_y": 0.0, "I_r": 0.0}

        EXP_TYPE = "GAN"
        # Timestamp in JST (UTC+9), using an aware datetime instead of the
        # deprecated, naive datetime.utcnow().
        jst = datetime.timezone(datetime.timedelta(hours=9))
        self.EXPERIMENTAL_NAME = f"{EXP_TYPE}_{datetime.datetime.now(jst).strftime('%Y-%m-%d %H:%M:%S')}"
        self.save_params()  # saved now, before the noise matrix etc. clutter the dump

        # Arrays reused every iteration.
        self.REAL = np.ones(shape=(self.BATCH_SIZE, 1))
        self.FAKE = np.zeros(shape=(self.BATCH_SIZE, 1))
        self.FIXED_NOISE_FOR_PREDICT = np.random.normal(0, 1, (self.BATCH_SIZE, self.NOISE_DIM))
        # The all-ones / all-zeros vectors double as targets for the moment
        # losses: index 0 -> REAL (ones), index 1 -> FAKE (zeros).
        self.opt_I_x = [self.REAL, self.FAKE][0]
        self.opt_I_y = [self.REAL, self.FAKE][0]
        self.opt_I_r = [self.REAL, self.FAKE][0]

    def save_params(self):
        """Write every instance attribute as a `key, value` line to `_{EXPERIMENTAL_NAME}_params.md`."""
        with open(f'_{self.EXPERIMENTAL_NAME}_params.md', 'w', encoding='utf-8') as f:
            f.write(f"# Exp: {self.EXPERIMENTAL_NAME}\n")
            f.write("## params\n")
            for k, v in self.__dict__.items():
                f.write(f"{k}, {v}\n")

class Trainer:
    """Build data, models and loss tracking for one experiment and run training."""

    def __init__(self, params):
        self.params = params
        self.MNIST = MNIST(self.params)
        self.smoa = SecondMomentOfArea()
        self.GAN = GAN(self.params, self.smoa)
        self.Loss = Loss(self.params)

        # Save architecture diagrams of the three models.
        tf.keras.utils.plot_model(self.GAN.generator, to_file='_model_generator.png', show_shapes=True)
        tf.keras.utils.plot_model(self.GAN.discriminator, to_file='_model_discriminator.png', show_shapes=True)
        # Keep only Layer instances in the combined model's private `_layers`
        # list before plotting it. NOTE(review): relies on private Keras API;
        # presumably works around plot_model failing on non-Layer entries.
        self.GAN.combined._layers = [
            layer for layer in self.GAN.combined._layers if isinstance(layer, tf.keras.layers.Layer)
        ]
        tf.keras.utils.plot_model(self.GAN.combined, to_file="_model_combined.png", show_shapes=True)

    def train(self):
        """Run the GAN training loop over `MNIST.tfdata`.

        Every ITERATION_SPAN iterations: save the generator, show/save sample
        images, compute FID and plot the collected statistics.

        Returns:
            The summary dict from `Loss.calc_stat()`.
        """
        print("[start training]")
        cfg = self.params

        for iteration, batch in enumerate(self.MNIST.tfdata, start=1):
            y_real = batch
            # Size labels and noise to the actual batch, so a short final
            # batch cannot mismatch the fixed-size (BATCH_SIZE, 1) arrays.
            n = int(batch.shape[0])
            real, fake = cfg.REAL[:n], cfg.FAKE[:n]
            z_in = np.random.normal(0, 1, (n, cfg.NOISE_DIM))

            # Only the image output is needed to train the discriminator.
            y_gen = self.GAN.generator.predict(z_in)[0]

            D_loss_real_dict = self.GAN.discriminator.train_on_batch(y_real, real, return_dict=True)
            D_loss_fake_dict = self.GAN.discriminator.train_on_batch(y_gen, fake, return_dict=True)

            # Train the generator through the combined model.
            G_targets = [real, cfg.opt_I_x[:n], cfg.opt_I_y[:n], cfg.opt_I_r[:n]]
            G_loss_dict = self.GAN.combined.train_on_batch(z_in, G_targets, return_dict=True)

            self.Loss.dump_loss_dicts(D_loss_real_dict, D_loss_fake_dict, G_loss_dict)

            if iteration % cfg.ITERATION_SPAN == 0:
                print(f"iteration: {iteration}")
                self.GAN.generator.save(f'model-{cfg.EXPERIMENTAL_NAME}-{iteration}.hdf5')
                generate_samples(params=cfg, GAN=self.GAN, iteration=iteration)
                FID = calc_fid_for_generator(params=cfg, GAN=self.GAN, MNIST=self.MNIST)
                self.Loss.dump_FID(FID)
                self.Loss.show_loss(iteration=iteration)
        return self.Loss.calc_stat()

# 学習

In [0]:
# Single run with the default Params.
# NOTE: train() returns the summary dict from Loss.calc_stat(), so `trainer`
# holds that dict, not the Trainer instance.
trainer = Trainer(Params()).train()

# パラメータサーチ用

外側から Params をいじる

- `NAME`: 各実験の保存用の名前を書く場所。自動にしてもよかった
- `loss_weights`: これがロスの配分になる
-  `power_I_x` / `power_I_y` / `power_I_r`: `0`は断面二次モーメントが増すような学習、`1` は下がるような学習になる

In [0]:
NAME = ["Vanilla", "I_x+75", "I_y+75", "I_r+75", "I_x-500", "I_y-500", "I_r-500"]

# Loss weights per experiment, indexed in parallel with NAME. The last two
# rows have no entry in NAME and are therefore not run by the loop below.
loss_weights = [
    {"possibility": 1.0, "I_x": 0.0, "I_y": 0.0, "I_r": 0.0},    # Vanilla
    {"possibility": 1.0, "I_x": 75.0, "I_y": 0.0, "I_r": 0.0},   # I_x+75
    {"possibility": 1.0, "I_x": 0.0, "I_y": 75.0, "I_r": 0.0},   # I_y+75
    {"possibility": 1.0, "I_x": 0.0, "I_y": 0.0, "I_r": 75.0},   # I_r+75
    {"possibility": 1.0, "I_x": 500.0, "I_y": 0.0, "I_r": 0.0},  # I_x-500
    {"possibility": 1.0, "I_x": 0.0, "I_y": 500.0, "I_r": 0.0},  # I_y-500
    {"possibility": 1.0, "I_x": 0.0, "I_y": 0.0, "I_r": 500.0},  # I_r-500
    {"possibility": 1.0, "I_x": 50.0, "I_y": 50.0, "I_r": 0.0},  # I_x+50, I_y-50
    {"possibility": 1.0, "I_x": 50.0, "I_y": 50.0, "I_r": 0.0},  # I_x-50, I_y+50
]
# Target selector per experiment: 0 -> REAL (all ones, pushes the moment up),
# 1 -> FAKE (all zeros, pushes the moment down).
power_I_x = [0, 0, 0, 0, 1, 0, 0, 0, 1]
power_I_y = [0, 0, 0, 0, 0, 1, 0, 1, 0]
power_I_r = [0, 0, 0, 0, 0, 0, 1, 0, 0]

# Alias of the name list: inside EXP1.__init__ the parameter `NAME` shadows it.
_EXPERIMENT_NAMES = NAME


class EXP1:
    """Train one configuration from the tables above and save its statistics."""

    def __init__(self, NAME, index=None):
        """
        Args:
            NAME: experiment name; used as EXPERIMENTAL_NAME (file prefixes).
            index: row of loss_weights / power_I_* to use. Defaults to the
                position of NAME in the module-level list of names.
        """
        if index is None:
            index = _EXPERIMENT_NAMES.index(NAME)
        params = Params()
        params.EXPERIMENTAL_NAME = NAME
        params.generator_loss_weights = loss_weights[index]
        params.opt_I_x = [params.REAL, params.FAKE][power_I_x[index]]
        params.opt_I_y = [params.REAL, params.FAKE][power_I_y[index]]
        params.opt_I_r = [params.REAL, params.FAKE][power_I_r[index]]
        params.save_params()
        trainer = Trainer(params)
        self.statdict = trainer.train()
        self.save_params(params.EXPERIMENTAL_NAME)

    def save_params(self, EXPERIMENTAL_NAME):
        """Write the summary statistics to `results_{EXPERIMENTAL_NAME}_params.md`."""
        with open(f'results_{EXPERIMENTAL_NAME}_params.md', 'w', encoding='utf-8') as f:
            f.write(f"# Exp: {EXPERIMENTAL_NAME}\n")
            f.write("## params\n")
            for k, v in self.statdict.items():
                f.write(f"{k}, {v}\n")


for i, name in enumerate(NAME):
    print(f"{i}, {name}")
    exp1 = EXP1(name, i)

以上