## 安装kears-crontrib

In [8]:
! pip install git+https://www.github.com/keras-team/keras-contrib.git

Collecting git+https://www.github.com/keras-team/keras-contrib.git
  Cloning https://www.github.com/keras-team/keras-contrib.git to /private/var/folders/46/b7dzk4mn6g54qzptv608w7d00000gn/T/pip-t1_7py2z-build
Installing collected packages: keras-contrib
  Running setup.py install for keras-contrib ... [?25ldone
[?25hSuccessfully installed keras-contrib-2.0.8
[33mYou are using pip version 9.0.1, however version 18.0 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


## 读取数据

In [1]:
from glob import glob
import numpy as np

class DataLoader():
    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, domain, batch_size=1, is_testing=False):
        data_type = "train%s" % domain if not is_testing else "test%s" % domain
        path = glob('./datasets/%s/%s/*' % (self.dataset_name, data_type))
        batch_images = np.random.choice(path, size=batch_size)

        imgs = []
        for img_path in batch_images:
            img = self.imread(img_path)
            if not is_testing:
                img = scipy.misc.imresize(img, self.img_res)

                if np.random.random() > 0.5:
                    img = np.fliplr(img)
            else:
                img = scipy.misc.imresize(img, self.img_res)
            imgs.append(img)

        imgs = np.array(imgs)/127.5 - 1.

        return imgs

    def load_batch(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "val"
        path_A = glob('./datasets/%s/%sA/*' % (self.dataset_name, data_type))
        path_B = glob('./datasets/%s/%sB/*' % (self.dataset_name, data_type))
        self.n_batches = int(min(len(path_A), len(path_B)) / batch_size)
        total_samples = self.n_batches * batch_size

        path_A = np.random.choice(path_A, total_samples, replace=False)
        path_B = np.random.choice(path_B, total_samples, replace=False)

        for i in range(self.n_batches-1):
            batch_A = path_A[i*batch_size:(i+1)*batch_size]
            batch_B = path_B[i*batch_size:(i+1)*batch_size]
            imgs_A, imgs_B = [], []
            for img_A, img_B in zip(batch_A, batch_B):
                img_A = self.imread(img_A)
                img_B = self.imread(img_B)

                img_A = scipy.misc.imresize(img_A, self.img_res)
                img_B = scipy.misc.imresize(img_B, self.img_res)

                if not is_testing and np.random.random() > 0.5:
                        img_A = np.fliplr(img_A)
                        img_B = np.fliplr(img_B)

                imgs_A.append(img_A)
                imgs_B.append(img_B)

            imgs_A = np.array(imgs_A)/127.5 - 1.
            imgs_B = np.array(imgs_B)/127.5 - 1.

            yield imgs_A, imgs_B

    def load_img(self, path):
        img = self.imread(path)
        img = scipy.misc.imresize(img, self.img_res)
        img = img/127.5 - 1.
        return img[np.newaxis, :, :, :]

    def imread(self, path):
        return scipy.misc.imread(path, mode='RGB').astype(np.float)

## 引用相关套件

In [2]:
import scipy

from keras.datasets import mnist
from keras_contrib.layers.normalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
import numpy as np
import os

Using TensorFlow backend.


## 建立生成器

In [3]:
# 用U-NET网路结构
def build_generator():

    # 下采样（Down Sampling）
    def conv2d(layer_input, filters, f_size=4):
        d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
        d = LeakyReLU(alpha=0.2)(d)
        d = InstanceNormalization()(d)
        return d
    
    # 上采样（Upsampling）
    def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
        u = UpSampling2D(size=2)(layer_input)
        u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
        if dropout_rate:
            u = Dropout(dropout_rate)(u)
        u = InstanceNormalization()(u)
        u = Concatenate()([u, skip_input])
        return u

    #  输入图片
    d0 = Input(shape=img_shape)

    # 下采样
    d1 = conv2d(d0, gf)
    d2 = conv2d(d1, gf*2)
    d3 = conv2d(d2, gf*4)
    d4 = conv2d(d3, gf*8)

    # 上采样和跨层连结
    u1 = deconv2d(d4, d3, gf*4)
    u2 = deconv2d(u1, d2, gf*2)
    u3 = deconv2d(u2, d1, gf)

    u4 = UpSampling2D(size=2)(u3)
    output_img = Conv2D(channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u4)

    return Model(d0, output_img)

## 建立鉴别器

In [4]:
def build_discriminator():

    # 建立鉴别器网路层
    def d_layer(layer_input, filters, f_size=4, normalization=True):
        d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
        d = LeakyReLU(alpha=0.2)(d)
        if normalization:
            d = InstanceNormalization()(d)
        return d

    img = Input(shape=img_shape)

    d1 = d_layer(img, df, normalization=False)
    d2 = d_layer(d1, df*2)
    d3 = d_layer(d2, df*4)
    d4 = d_layer(d3, df*8)

    validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)

    return Model(img, validity)

## 训练过程

In [5]:
def train(epochs, batch_size=1, sample_interval=50):

    start_time = datetime.datetime.now()

    # 真的图片的答案判定为1生成图片判定为0
    valid = np.ones((batch_size,) + disc_patch)
    fake = np.zeros((batch_size,) + disc_patch)

    for epoch in range(epochs):
        for batch_i, (imgs_A, imgs_B) in enumerate(data_loader.load_batch(batch_size)):

            #  训练鉴别器
            
            # 将图片从A翻译到B再从B翻译到A
            fake_B = g_AB.predict(imgs_A)
            fake_A = g_BA.predict(imgs_B)

            # 训练A鉴别器
            dA_loss_real = d_A.train_on_batch(imgs_A, valid)
            dA_loss_fake = d_A.train_on_batch(fake_A, fake)
            dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

            # 训练B鉴别器
            dB_loss_real = d_B.train_on_batch(imgs_B, valid)
            dB_loss_fake = d_B.train_on_batch(fake_B, fake)
            dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

            # 总鉴别器损失
            d_loss = 0.5 * np.add(dA_loss, dB_loss)


            #  训练生成器
            g_loss = combined.train_on_batch([imgs_A, imgs_B],
                                                    [valid, valid,
                                                    imgs_A, imgs_B,
                                                    imgs_A, imgs_B])

            # 计算训练时间
            elapsed_time = datetime.datetime.now() - start_time

            # 打印进度
            print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
                                                                    % ( epoch, epochs,
                                                                        batch_i, data_loader.n_batches,
                                                                        d_loss[0], 100*d_loss[1],
                                                                        g_loss[0],
                                                                        np.mean(g_loss[1:3]),
                                                                        np.mean(g_loss[3:5]),
                                                                        np.mean(g_loss[5:6]),
                                                                        elapsed_time))

            # 每五回合采样图片
            if batch_i % sample_interval == 0:
                sample_images(epoch, batch_i)

## 图片取样

In [6]:
def sample_images(epoch, batch_i):
    os.makedirs('images/%s' % dataset_name, exist_ok=True)
    r, c = 2, 3

    imgs_A = data_loader.load_data(domain="A", batch_size=1, is_testing=True)
    imgs_B = data_loader.load_data(domain="B", batch_size=1, is_testing=True)

    # 将图片从A转译到B
    fake_B = g_AB.predict(imgs_A)
    fake_A = g_BA.predict(imgs_B)
    # 重建图片（从B转译回A）
    reconstr_A = g_BA.predict(fake_B)
    reconstr_B = g_AB.predict(fake_A)

    # 生成图片
    gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

    # 将图片标准化到0  -  1之间的范围
    gen_imgs = 0.5 * gen_imgs + 0.5

    titles = ['Original', 'Translated', 'Reconstructed']
    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_imgs[cnt])
            axs[i, j].set_title(titles[j])
            axs[i,j].axis('off')
            cnt += 1
    fig.savefig("images/%s/%d_%d.png" % (dataset_name, epoch, batch_i))
    plt.close()

## 建构模型

In [7]:
# 图片参数
img_rows = 128
img_cols = 128
channels = 3
img_shape = (img_rows, img_cols, channels)

In [11]:
# 读取数据
dataset_name = 'horse2zebra'
data_loader = DataLoader(dataset_name=dataset_name, img_res=(img_rows, img_cols))


# 计算D输出的维度(PatchGAN)
patch = int(img_rows / 2**4)
disc_patch = (patch, patch, 1)

# 生成器与鉴别器第一层的神经元数
gf = 32
df = 64

# 损失
lambda_cycle = 10.0               # 循环一致性损失
lambda_id = 0.1 * lambda_cycle    # Identity loss

optimizer = Adam(0.0002, 0.5)

# 建立以及编译鉴别器
d_A = build_discriminator()
d_B = build_discriminator()
d_A.compile(loss='mse',
    optimizer=optimizer,
    metrics=['accuracy'])
d_B.compile(loss='mse',
    optimizer=optimizer,
    metrics=['accuracy'])

In [12]:
# 建立生成器
g_AB = build_generator()
g_BA = build_generator()

# 从A B两个Domain输入图片
img_A = Input(shape=img_shape)
img_B = Input(shape=img_shape)

# 从Domain A 将图片转译到 Domain B
fake_B = g_AB(img_A)
fake_A = g_BA(img_B)
# 从Domain B 将图片转译回 Domain A (重建)
reconstr_A = g_BA(fake_B)
reconstr_B = g_AB(fake_A)

# 建立图片识别ID
img_A_id = g_BA(img_A)
img_B_id = g_AB(img_B)

# 只训练生成器
d_A.trainable = False
d_B.trainable = False

# 利用鉴别器辨别图片真伪
valid_A = d_A(fake_A)
valid_B = d_B(fake_B)

# 训练生成器骗过鉴别器
combined = Model(inputs=[img_A, img_B],
                      outputs=[ valid_A, valid_B,
                                reconstr_A, reconstr_B,
                                img_A_id, img_B_id ])
combined.compile(loss=['mse', 'mse',
                       'mae', 'mae',
                       'mae', 'mae'],
                    loss_weights=[  1, 1,
                                    lambda_cycle, lambda_cycle,
                                    lambda_id, lambda_id ],
                    optimizer=optimizer)

## 训练CycleGAN

In [None]:
train(epochs=20, batch_size=10, sample_interval=10)

`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
  'Discrepancy between trainable weights and collected trainable'


[Epoch 0/20] [Batch 0/106] [D loss: 28.575975, acc:  12%] [G loss: 32.662563, adv: 8.985338, recon: 0.674166, id: 0.594438] time: 0:00:31.423852 


`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.


[Epoch 0/20] [Batch 1/106] [D loss: 18.708851, acc:  10%] [G loss: 23.850031, adv: 3.718084, recon: 0.754779, id: 0.677637] time: 0:00:44.625135 
[Epoch 0/20] [Batch 2/106] [D loss: 6.108274, acc:  18%] [G loss: 33.770657, adv: 8.953201, recon: 0.727703, id: 0.710256] time: 0:00:56.064625 
[Epoch 0/20] [Batch 3/106] [D loss: 3.475364, acc:  28%] [G loss: 29.591640, adv: 6.659639, recon: 0.748904, id: 0.690904] time: 0:01:07.432699 
[Epoch 0/20] [Batch 4/106] [D loss: 3.221425, acc:  20%] [G loss: 17.634134, adv: 0.862468, recon: 0.734473, id: 0.689285] time: 0:01:19.437452 
[Epoch 0/20] [Batch 5/106] [D loss: 0.714048, acc:  47%] [G loss: 17.721844, adv: 2.011209, recon: 0.629395, id: 0.547526] time: 0:01:31.504293 
[Epoch 0/20] [Batch 6/106] [D loss: 0.645739, acc:  43%] [G loss: 13.123373, adv: 0.862360, recon: 0.519493, id: 0.462641] time: 0:01:43.504518 
[Epoch 0/20] [Batch 7/106] [D loss: 0.787574, acc:  43%] [G loss: 13.718512, adv: 1.350994, recon: 0.503539, id: 0.429093] time: 

[Epoch 0/20] [Batch 58/106] [D loss: 0.288422, acc:  62%] [G loss: 5.999347, adv: 0.651062, recon: 0.212798, id: 0.208702] time: 0:12:39.768991 
[Epoch 0/20] [Batch 59/106] [D loss: 0.333338, acc:  48%] [G loss: 6.156004, adv: 0.655350, recon: 0.220536, id: 0.185324] time: 0:12:52.239997 
[Epoch 0/20] [Batch 60/106] [D loss: 0.418815, acc:  33%] [G loss: 6.048318, adv: 0.691613, recon: 0.212553, id: 0.184398] time: 0:13:04.458802 
[Epoch 0/20] [Batch 61/106] [D loss: 0.417786, acc:  32%] [G loss: 6.052221, adv: 0.567244, recon: 0.224231, id: 0.180085] time: 0:13:17.053322 
[Epoch 0/20] [Batch 62/106] [D loss: 0.481978, acc:  35%] [G loss: 8.068579, adv: 0.856917, recon: 0.291965, id: 0.255859] time: 0:13:29.361450 
[Epoch 0/20] [Batch 63/106] [D loss: 0.505195, acc:  32%] [G loss: 6.862192, adv: 0.672155, recon: 0.252961, id: 0.210994] time: 0:13:42.409581 
[Epoch 0/20] [Batch 64/106] [D loss: 0.275887, acc:  55%] [G loss: 6.230003, adv: 0.707281, recon: 0.219813, id: 0.202858] time: 0