In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import time


In [2]:
from tensorflow.keras.layers import Conv2D, Dropout, Flatten, Dense, Reshape, Conv2DTranspose, ReLU, BatchNormalization, LeakyReLU
from tensorflow import keras
import tensorflow as tf

In [7]:
import matplotlib.pyplot as plt

In [3]:
def mnist_uni_gen_cnn(input_shape):
    return keras.Sequential([
        # [n, latent] -> [n, 7 * 7 * 128] -> [n, 7, 7, 128]
        Dense(7 * 7 * 128, input_shape=input_shape),
        BatchNormalization(),
        ReLU(),
        Reshape((7, 7, 128)),
        # -> [n, 14, 14, 64]
        Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same'),
        BatchNormalization(),
        ReLU(),
        # -> [n, 28, 28, 32]
        Conv2DTranspose(32, (4, 4), strides=(2, 2), padding='same'),
        BatchNormalization(),
        ReLU(),
        # -> [n, 28, 28, 1]
        Conv2D(1, (4, 4), padding='same', activation=keras.activations.tanh)
    ])


def mnist_uni_disc_cnn(input_shape=(28, 28, 1), use_bn=True):
    model = keras.Sequential()
    # [n, 28, 28, n] -> [n, 14, 14, 64]
    model.add(Conv2D(64, (4, 4), strides=(2, 2), padding='same', input_shape=input_shape))
    if use_bn:
        model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(Dropout(0.3))
    # -> [n, 7, 7, 128]
    model.add(Conv2D(128, (4, 4), strides=(2, 2), padding='same'))
    if use_bn:
        model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(Dropout(0.3))
    model.add(Flatten())
    return model


In [4]:
from tensorflow import keras
import tensorflow as tf
import os
import numpy as np

MNIST_PATH = "./mnist.npz"


def load_mnist(path):
    if os.path.isfile(path):
        with np.load(path, allow_pickle=True) as f:
            x_train, y_train = f['x_train'], f['y_train']
            x_test, y_test = f['x_test'], f['y_test']
        return (x_train, y_train), (x_test, y_test)
    return keras.datasets.mnist.load_data(MNIST_PATH)


def get_half_batch_ds(batch_size):
    return get_ds(batch_size//2)


def get_ds(batch_size):
    (x, y), _ = load_mnist(MNIST_PATH)
    x = _process_x(x)
    y = tf.cast(y, tf.int32)
    ds = tf.data.Dataset.from_tensor_slices((x, y)).cache().shuffle(1024).batch(batch_size) \
        .prefetch(tf.data.experimental.AUTOTUNE)
    return ds


def get_test_x():
    (_, _), (x, _) = load_mnist(MNIST_PATH)
    x = _process_x(x)
    return x


def get_test_69():
    _, (x, y) = load_mnist(MNIST_PATH)
    return _process_x(x[y == 6]), _process_x(x[y == 9])


def get_train_x():
    (x, _), _ = load_mnist(MNIST_PATH)
    x = _process_x(x)
    return x


def _process_x(x):
    return tf.expand_dims(tf.cast(x, tf.float32), axis=3) / 255. * 2 - 1


def get_69_ds():
    (x, y), _ = load_mnist(MNIST_PATH)
    x6, x9 = x[y == 6], x[y == 9]
    return _process_x(x6), _process_x(x9)


def downsampling(imgs, to_shape):
    s = to_shape[:2]
    imgs = tf.random.normal(imgs.shape, 0, 0.2) + imgs
    return tf.image.resize(imgs, size=s)

In [5]:
def set_soft_gpu(soft_gpu):
    if soft_gpu:
        gpus = tf.config.experimental.list_physical_devices('GPU')
        if gpus:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")

In [6]:
def save_weights(model):
    name = model.__class__.__name__.lower()
    os.makedirs("./models/{}".format(name), exist_ok=True)
    model.save_weights("./models/{}/model.ckpt".format(name))

In [8]:
def save_gan(model, ep, **kwargs):
    name = model.__class__.__name__.lower()
    if name in ["dcgan", "wgan", "wgangp", "lsgan", "wgandiv", "sagan", "pggan"]:
        imgs = model.call(100, training=False).numpy()
        _save_gan(name, ep, imgs, show_label=False)
    elif name == "gan":
        data = model.call(5, training=False).numpy()
        plt.plot(data.T)
        plt.xticks((), ())
        dir_ = "visual/{}".format(name)
        os.makedirs(dir_, exist_ok=True)
        path = dir_ + "/{}.png".format(ep)
        plt.savefig(path)
    elif name == "cgan" or name == "acgan":
        img_label = np.arange(0, 10).astype(np.int32).repeat(10, axis=0)
        imgs = model.predict(img_label)
        _save_gan(name, ep, imgs, show_label=True)
    elif name in ["infogan"]:
        img_label = np.arange(0, model.label_dim).astype(np.int32).repeat(10, axis=0)
        img_style = np.concatenate(
            [np.linspace(-model.style_scale, model.style_scale, 10)] * 10).reshape((100, 1)).repeat(model.style_dim, axis=1).astype(np.float32)
        img_info = img_label, img_style
        imgs = model.predict(img_info)
        _save_gan(name, ep, imgs, show_label=False)


    elif name == "stylegan":
        n = 12
        global z1, z2       # z1 row, z2 col
        if "z1" not in globals():
            z1 = np.random.normal(0, 1, size=(n, 1, model.latent_dim))
        if "z2" not in globals():
            z2 = np.random.normal(0, 1, size=(n, 1, model.latent_dim))
        imgs = model.predict([
            np.concatenate(
                (z1.repeat(n, axis=0).repeat(1, axis=1), np.repeat(np.concatenate([z2 for _ in range(n)], axis=0), 2, axis=1)),
                axis=1),
            np.zeros([len(z1)*n, model.img_shape[0], model.img_shape[1]], dtype=np.float32)])
        z1_imgs = -model.predict([z1.repeat(model.n_style, axis=1), np.zeros([len(z1), model.img_shape[0], model.img_shape[1]], dtype=np.float32)])
        z2_imgs = -model.predict([z2.repeat(model.n_style, axis=1), np.zeros([len(z2), model.img_shape[0], model.img_shape[1]], dtype=np.float32)])
        imgs = np.concatenate([z2_imgs, imgs], axis=0)
        rest_imgs = np.concatenate([np.ones([1, 28, 28, 1], dtype=np.float32), z1_imgs], axis=0)
        for i in range(len(rest_imgs)):
            imgs = np.concatenate([imgs[:i*(n+1)], rest_imgs[i:i+1], imgs[i*(n+1):]], axis=0)
        _save_gan(name, ep, imgs, show_label=False, nc=n+1, nr=n+1)
    else:
        raise ValueError(name)
    plt.clf()
    plt.close()

def _img_recenter(img):
    return (img + 1) * 255 / 2

In [9]:
def _save_gan(model_name, ep, imgs, show_label=False, nc=10, nr=10):
    if not isinstance(imgs, np.ndarray):
        imgs = imgs.numpy()
    if imgs.ndim > 3:
        imgs = np.squeeze(imgs, axis=-1)
    imgs = _img_recenter(imgs)
    plt.clf()
    plt.figure(0, (nc * 2, nr * 2))
    for c in range(nc):
        for r in range(nr):
            i = r * nc + c
            plt.subplot(nr, nc, i + 1)
            plt.imshow(imgs[i], cmap="gray_r")
            plt.axis("off")
            if show_label:
                plt.text(23, 26, int(r), fontsize=23)
    plt.tight_layout()
    dir_ = "visual/{}".format(model_name)
    os.makedirs(dir_, exist_ok=True)
    path = dir_ + "/{}.png".format(ep)
    plt.savefig(path)



In [10]:
def cvt_gif(folders_or_gan, shrink=10):
    if not isinstance(folders_or_gan, list):
        folders_or_gan = [folders_or_gan.__class__.__name__.lower()]
    for folder in folders_or_gan:
        folder = "visual/"+folder
        fs = [folder+"/" + f for f in os.listdir(folder)]
        imgs = []
        for f in sorted(fs, key=os.path.getmtime):
            if not f.endswith(".png"):
                continue
            try:
                int(os.path.basename(f).split(".")[0])
            except ValueError:
                continue
            img = Image.open(f)
            img = img.resize((img.width//shrink, img.height//shrink), Image.ANTIALIAS)
            imgs.append(img)
        path = "{}/generating.gif".format(folder)
        if os.path.exists(path):
            os.remove(path)
        imgs[-1].save(path, append_images=imgs, optimize=False, save_all=True, duration=400, loop=0)
        print("saved ", path)

In [15]:
class WGAN(keras.Model):
    """
    Wasserstein 距离作为损失函数， 避免D太强导致G的梯度消失。
    D 最大化 Wasserstein 距离，提高收敛性
    G 最小化 Wasserstein 距离
    Clip D weights，局限住太强的 D，让 G 可以跟上 (Lipschitz 约束)。
    """
    def __init__(self, latent_dim, clip, img_shape):
        super().__init__()
        self.latent_dim = latent_dim
        self.clip = clip
        self.img_shape = img_shape
        self.opt = tf.keras.optimizers.legacy.Adam(0.0002, beta_1=0, beta_2=0.9)
        self.g = self._get_generator()
        self._build_d()

    def _build_d(self):
        self.d = self._get_discriminator()

    def call(self, n, training=None, mask=None):
        return self.g.call(tf.random.normal((n, self.latent_dim)), training=training)

    def _get_generator(self):
        model = mnist_uni_gen_cnn((self.latent_dim,))
        model.summary()
        return model

    def _get_discriminator(self, use_bn=True):
        model = keras.Sequential([
            mnist_uni_disc_cnn(self.img_shape, use_bn),
            keras.layers.Dense(1)
        ], name="critic")
        model.summary()
        return model

    @staticmethod
    def w_distance(fake, real=None):
        # the distance of two data distributions
        if real is None:
            return tf.reduce_mean(fake)
        else:
            return tf.reduce_mean(fake) - tf.reduce_mean(real)

    def train_d(self, real_img):
        with tf.GradientTape() as tape:
            fake_img = self.call(len(real_img), training=False)
            pred_real = self.d.call(real_img, training=True)
            pred_fake = self.d.call(fake_img, training=True)
            loss = self.w_distance(pred_fake, pred_real)   # maximize W distance
        grads = tape.gradient(loss, self.d.trainable_variables)
        self.opt.apply_gradients(zip(grads, self.d.trainable_variables))
        # clip discriminator's weights
        for w in self.d.trainable_weights:
            w.assign(tf.clip_by_value(w, -self.clip, self.clip))
        return loss

    def train_g(self, n):
        with tf.GradientTape() as tape:
            g_img = self.call(n, training=True)
            pred_fake = self.d.call(g_img, training=False)
            loss = -self.w_distance(pred_fake)       # minimize W distance
        grads = tape.gradient(loss, self.g.trainable_variables)
        self.opt.apply_gradients(zip(grads, self.g.trainable_variables))
        return loss


def train(gan, ds, steps, d_loop, batch_size):
    t0 = time.time()
    for t in range(steps):
        for _ in range(d_loop):
            idx = np.random.randint(0, len(ds), batch_size)
            real_img = tf.gather(ds, idx)
            d_loss = gan.train_d(real_img)
        g_loss = gan.train_g(batch_size)
        if t % 1000 == 0:
            t1 = time.time()
            print("t={} | time={:.1f} | d_loss={:.2f} | g_loss={:.2f}".format(
                    t, t1 - t0, d_loss.numpy(), g_loss.numpy(), ))
            t0 = t1
            save_gan(gan, t)
    save_weights(gan)
    cvt_gif(gan)


if __name__ == "__main__":
    LATENT_DIM = 100
    CLIP = 0.01
    D_LOOP = 5
    IMG_SHAPE = (28, 28, 1)
    BATCH_SIZE = 64
    STEP = 20001

    set_soft_gpu(True)
    d = get_train_x()
    m = WGAN(LATENT_DIM, CLIP, IMG_SHAPE)
    train(m, d, STEP, D_LOOP, BATCH_SIZE)

1 Physical GPUs, 1 Logical GPUs
Model: "sequential_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_8 (Dense)             (None, 6272)              633472    
                                                                 
 batch_normalization_18 (Ba  (None, 6272)              25088     
 tchNormalization)                                               
                                                                 
 re_lu_12 (ReLU)             (None, 6272)              0         
                                                                 
 reshape_4 (Reshape)         (None, 7, 7, 128)         0         
                                                                 
 conv2d_transpose_8 (Conv2D  (None, 14, 14, 64)        131136    
 Transpose)                                                      
                                                                 
 batch_normalization_1

NameError: ignored

<Figure size 640x480 with 0 Axes>