In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')

In [3]:
#该函数用于输出生成图片
def plot(samples, sample_count=sample_count):
    fig = plt.figure(figsize=(sample_count, sample_count))
    gs = gridspec.GridSpec(sample_count, sample_count)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(tf.reshape(sample,(28, 28)), cmap='Greys_r')
        
    return fig

In [2]:
# 参数设置
batchsz = 100
z_dim = 100
learning_rate = 1e-4
is_training = True
epochs = 30000
sample_count = 10
network_path = "E:\\C_all\\Desktop\\深度之眼\\paper\\tensorflow\\文献实现\\GAN\\模型输出\\CGAN\\"
img_path = "E:\\C_all\\Desktop\\深度之眼\\paper\\tensorflow\\文献实现\\GAN\\out\\"
if not os.path.exists(img_path):
    os.makedirs(img_path)

In [4]:
class Maxoutlayer(tf.keras.layers.Layer):
    def __init__(self, k, m):
        super(Maxoutlayer, self).__init__()
        self.k = int(k)
        self.m = int(m)

    def build(self, input_shape, dtype=tf.float32):
        self.d = input_shape[-1]
        print(self.d,input_shape)
        self.w = self.add_weight(name='w',
                                 shape=(self.d, self.m, self.k),
                                 initializer='uniform',
                                 dtype=dtype,
                                 trainable=True)
        self.b = self.add_weight(name='b',
                                 shape=(self.m, self.k),
                                 initializer='zero',
                                 dtype=dtype,
                                 trainable=True)
        super(Maxoutlayer, self).build(input_shape)

    def call(self, x):
        outputs = tf.tensordot(x, self.w, axes=1) + self.b
        outputs = tf.reduce_max(outputs, axis=2)
        return outputs
class Generator(keras.Model):
    def __init__(self):
        super(Generator, self).__init__()
#         self.dense1 = keras.layers.Dense(128) # 正常dense
        self.dense1 = Maxoutlayer(5, 128) # maxoutlayer
        self.dense2 = keras.layers.Dense(784)

    def call(self, z, y, training=None):
        zy = tf.concat([z, y], axis=1)
        g_1 = tf.nn.elu(self.dense1(zy))
        mean,variance = tf.nn.moments(g_1,-1) # maxoutlayer
        g_1 = (g_1 - tf.reshape(mean,(batchsz,-1))) / tf.reshape(variance**0.5,(batchsz,-1)) # maxoutlayer
        g_2 = tf.nn.sigmoid(self.dense2(g_1))
        return g_2


class Discriminator(keras.Model):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.dense1 = keras.layers.Dense(128)
        self.dense2 = keras.layers.Dense(1)

    def call(self, x, y, training=None):
        xy = tf.concat([x, y], axis=1)
        d_1 = tf.nn.elu(self.dense1(xy))
        d_logit = self.dense2(d_1)
        d_prob = tf.nn.sigmoid(d_logit)
        return d_prob, d_logit

In [5]:
def sample_z(m, n):
    return tf.random.uniform(maxval=1., minval=-1., shape=[m, n])

In [6]:
(x, y),(x_val,y_val) = keras.datasets.mnist.load_data()

x = tf.convert_to_tensor(x, dtype=tf.float32)/255.
x = tf.reshape(x,(-1, 28*28))
y = tf.convert_to_tensor(y, dtype=tf.int32)
y_onehot = tf.one_hot(y, depth=10) # one_hot成立
train_dataset = tf.data.Dataset.from_tensor_slices((x,y_onehot)).repeat(-1).batch(batchsz)
train_dbiter = iter(train_dataset)

In [7]:
generator = Generator()
discriminator = Discriminator()

g_optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
d_optimizer = tf.optimizers.Adam(learning_rate=learning_rate)

In [8]:
# 交叉熵损失函数
def celoss_ones(logits):
    loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,labels=tf.ones_like(logits))
    return tf.reduce_mean(loss)

def celoss_zeros(logits):
    loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=tf.zeros_like(logits))
    return tf.reduce_mean(loss)

def d_loss_fn(generator, discriminator, batch_z, batch_x, batch_y, is_training):
    fake_data = generator(batch_z, batch_y, is_training)
    d_real,d_real_logits = discriminator(batch_x, batch_y, is_training)
    d_fake,d_fake_logits = discriminator(fake_data, batch_y, is_training)
    d_loss_real = celoss_ones(d_real_logits)
    d_loss_fake = celoss_zeros(d_fake_logits)
    d_loss = d_loss_fake + d_loss_real
    return d_loss

def g_loss_fn(generator, discriminator, batch_z, batch_y, is_training):
    fake_data = generator(batch_z, batch_y, is_training)
    d_fake,d_fake_logits = discriminator(fake_data, batch_y, is_training)
    g_loss = celoss_ones(d_fake_logits)
    return g_loss

In [10]:
def save_image(label, generator, m=sample_count**2, n=z_dim):
    sample_y = tf.one_hot(list(range(10)) * 10,depth=10,dtype=tf.float32)
    z = sample_z(m, n)
    fake_images = generator(z,sample_y)
    fig = plot(fake_images)
    plt.savefig(img_path + "{}.png".format(label),bbox_inches="tight")
    plt.close(fig)

In [11]:
def save_weights_(generator, discriminator, network_path=network_path):
    if not os.path.exists(img_path):
        os.makedirs(img_path)
    generator.save_weights(network_path+"cgan_g")
    discriminator.save_weights(network_path+"cgan_d")
    print("saved total weights.")

In [None]:
for epoch in range(epochs+1):
    batch_z = sample_z(batchsz, z_dim)
    batch_x, batch_y = next(train_dbiter)

    with tf.GradientTape() as tape:
        d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, batch_y, is_training)
    grads = tape.gradient(d_loss, discriminator.trainable_variables)
    d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
    
    with tf.GradientTape() as tape:
        g_loss = g_loss_fn(generator, discriminator, batch_z, batch_y, is_training)
    grads = tape.gradient(g_loss, generator.trainable_variables)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
    
#     if learning_rate >= 0.000001:
#         learning_rate -= learning_rate*0.00004
    
    if epoch % 1000 == 0:
        save_image(str(epoch), generator)
        save_weights_(generator, discriminator)
        print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss))
        print()

110 (100, 110)
saved total weights.
0 d-loss: 1.5539416074752808 g-loss: 1.0622875690460205

saved total weights.
1000 d-loss: 0.5798658132553101 g-loss: 1.7094054222106934

saved total weights.
2000 d-loss: 0.7365269064903259 g-loss: 1.2456793785095215

saved total weights.
3000 d-loss: 0.42904508113861084 g-loss: 1.7211616039276123

saved total weights.
4000 d-loss: 0.20582297444343567 g-loss: 2.4959986209869385

saved total weights.
5000 d-loss: 0.20742002129554749 g-loss: 2.54160213470459

saved total weights.
6000 d-loss: 0.10852529108524323 g-loss: 3.3190321922302246

saved total weights.
7000 d-loss: 0.13999749720096588 g-loss: 3.1363685131073

saved total weights.
8000 d-loss: 0.22361893951892853 g-loss: 2.5899996757507324

saved total weights.
9000 d-loss: 0.38055890798568726 g-loss: 2.4479100704193115

saved total weights.
10000 d-loss: 0.5953718423843384 g-loss: 1.8801110982894897

saved total weights.
11000 d-loss: 0.5042920112609863 g-loss: 1.9203633069992065

saved total 

In [None]:
# # 在custom layer中的 call() 上添加 @tf.function 可以将前向传播过程中不属于Graph的部分 转化为Graph。
# network_path = "E:\\C_all\\Desktop\\深度之眼\\paper\\tensorflow\\文献实现\\GAN\\模型输出\\"
# generator.save_weights(network_path+"original_gan_g")
# discriminator.save_weights(network_path+"original_gan_d")
# print("saved total weights.")
del generator
del discriminator
generator = Generator()
discriminator = Discriminator()
save_image("载入前",generator)
# generator.load_weights(network_path+"cgan_g")
# discriminator.load_weights(network_path+"cgan_d")
# save_image("载入后",generator)