In [None]:
#导入相关的库
import tensorflow as tf
import numpy as np
import pandas as pd
from tensorflow.keras import Model ,layers


In [None]:
# Load the training CSV: a 'label' column plus one column per pixel.
# The discriminator later reshapes each row to 28x28x1, so 784 pixel
# columns are presumably expected — TODO confirm against the file.
train_df = pd.read_csv('train.csv')

# Pixel intensities scaled from 0-255 to [0, 1] and stored as float32.
x = (train_df.drop(columns='label').to_numpy() / 255.0).astype(np.float32)
# Digit labels. The unconditional GAN ignores them, but they stay in the
# dataset so each batch unpacks as (images, labels).
y = train_df['label'].to_numpy()

# Shuffle BEFORE repeat so every pass over the data is reshuffled and
# elements of consecutive epochs are not mixed in the shuffle buffer;
# repeat BEFORE batch so every batch has exactly 128 rows.
train_ds = tf.data.Dataset.from_tensor_slices((x, y))
data_loader = train_ds.shuffle(1000).repeat().batch(128).prefetch(1)

In [None]:
# Generator: maps a batch of latent noise vectors to 28x28x1 images.
class Generator(Model):
    """DCGAN-style generator: dense projection, then two stride-2 transposed convs.

    Output pixels are squashed with tanh into [-1, 1], the same range the
    training step rescales real images into (``real_image*2. - 1.``), so the
    discriminator cannot tell real from fake merely by value range.
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Project the noise vector to a flattened 7x7x128 feature map.
        self.fc1 = layers.Dense(7 * 7 * 128)
        self.bn1 = layers.BatchNormalization()
        # Transposed ("de-")convolutions; stride 2 with SAME padding doubles
        # height and width each time: 7 -> 14 -> 28.
        self.conv2tr1 = layers.Conv2DTranspose(64, 5, strides=2, padding='SAME')
        self.bn2 = layers.BatchNormalization()
        self.conv2tr2 = layers.Conv2DTranspose(1, 5, strides=2, padding='SAME')

    def call(self, x, is_training=False):
        """Generate images from noise.

        Args:
            x: batch of noise vectors, assumed shape (batch, noise_dim).
            is_training: passed to BatchNormalization as ``training``.

        Returns:
            Tensor of shape (batch, 28, 28, 1) with values in [-1, 1].
        """
        x = self.fc1(x)
        x = self.bn1(x, training=is_training)
        x = tf.nn.leaky_relu(x)
        x = tf.reshape(x, shape=[-1, 7, 7, 128])

        x = self.conv2tr1(x)  # -> (batch, 14, 14, 64)
        x = self.bn2(x, training=is_training)
        x = tf.nn.leaky_relu(x)
        x = self.conv2tr2(x)  # -> (batch, 28, 28, 1)
        # Bound the output to [-1, 1]. Without this the raw linear output is
        # unbounded and cannot match the range of the rescaled real images.
        return tf.nn.tanh(x)
# Discriminator: scores 28x28x1 images with two class logits.
class Discriminator(Model):
    """Convolutional classifier returning 2 unnormalized logits per image.

    Which class index means "real" is decided by the labels used in the
    loss functions, not by this model.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Two stride-2 convolutions halve height and width: 28 -> 14 -> 7.
        self.conv1 = layers.Conv2D(64, 5, strides=2, padding='SAME')
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2D(128, 5, strides=2, padding='SAME')
        self.bn2 = layers.BatchNormalization()
        self.flatten = layers.Flatten()
        self.fc1 = layers.Dense(1024)
        self.bn3 = layers.BatchNormalization()
        self.fc2 = layers.Dense(2)

    def call(self, x, is_training=False):
        """Classify a batch of images.

        Args:
            x: images; anything reshapeable to (-1, 28, 28, 1), e.g. flat
                784-element rows or generator output.
            is_training: passed to BatchNormalization as ``training``.

        Returns:
            Logits of shape (batch, 2).
        """
        x = tf.reshape(x, [-1, 28, 28, 1])
        x = self.conv1(x)  # -> (batch, 14, 14, 64)
        x = self.bn1(x, training=is_training)
        x = tf.nn.leaky_relu(x)
        x = self.conv2(x)  # -> (batch, 7, 7, 128)
        x = self.bn2(x, training=is_training)
        x = tf.nn.leaky_relu(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.bn3(x, training=is_training)
        x = tf.nn.leaky_relu(x)
        return self.fc2(x)

# The two adversaries, built once and shared by the training step below.
generator, discriminator = Generator(), Discriminator()


In [None]:
# Generator loss.
def generator_loss(image):
    """Cross-entropy that rewards the generator for fooling the discriminator.

    Args:
        image: discriminator logits for generated images, shape (batch, 2).
            Despite the name these are logits, not pixels.

    Returns:
        Scalar mean cross-entropy of the logits against class 1, i.e. the
        generator wants its fakes assigned class 1. This must be the class
        the discriminator is trained to assign to REAL images.
    """
    # Size the label vector from the logits instead of hard-coding 128, so
    # any batch size works.
    target_labels = tf.ones(tf.shape(image)[:1], dtype=tf.int32)
    gen_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=image, labels=target_labels))
    return gen_loss
# Discriminator loss: real images -> class 1, generated images -> class 0.
def discriminator_loss(fake_image, real_image):
    """Sum of the discriminator's cross-entropy on fake and real batches.

    generator_loss trains fakes toward class 1, so the discriminator must use
    class 1 for REAL and class 0 for FAKE. If it also targeted class 1 for
    fakes, both networks would push fakes toward the same class and there
    would be no adversarial game.

    Args:
        fake_image: discriminator logits for generated images, shape (batch, 2).
        real_image: discriminator logits for real images, shape (batch, 2).

    Returns:
        Scalar: mean fake-batch loss + mean real-batch loss.
    """
    # Label vectors are sized from the logits, so any batch size works.
    fake_labels = tf.zeros(tf.shape(fake_image)[:1], dtype=tf.int32)
    real_labels = tf.ones(tf.shape(real_image)[:1], dtype=tf.int32)
    fake_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=fake_image, labels=fake_labels))
    real_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=real_image, labels=real_labels))
    return fake_loss + real_loss

# Optimizers: a separate Adam instance per network, both with learning
# rate 2e-4.
optimizer_gen = tf.optimizers.Adam(learning_rate=2e-4)
optimizer_disc = tf.optimizers.Adam(learning_rate=2e-4)

In [None]:
def run_optimizer(real_image, noise_dim=100):
    """Run one GAN training step: update the discriminator, then the generator.

    Relies on the module-level ``generator``, ``discriminator``,
    ``generator_loss``, ``discriminator_loss``, ``optimizer_gen`` and
    ``optimizer_disc``.

    Args:
        real_image: batch of real images with pixel values in [0, 1]
            (rows from ``data_loader``; the discriminator reshapes them).
        noise_dim: length of each latent noise vector fed to the generator.
            Defaults to 100, the size used so far.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` of scalar loss tensors.
    """
    # Map real pixels from [0, 1] to [-1, 1].
    real_images = real_image * 2. - 1.
    # One noise vector per real image instead of a hard-coded 128.
    batch_size = real_images.shape[0]

    # --- Discriminator step ---
    # Latent noise ~ N(0, 1). The previous call np.random.normal(-1., 1.)
    # meant mean -1 / std 1: uniform-style (low, high) arguments passed to
    # normal(loc, scale).
    noise = np.random.normal(0., 1., size=[batch_size, noise_dim]).astype(np.float32)
    with tf.GradientTape() as tape:
        fake_images = generator(noise, is_training=True)
        disc_fake = discriminator(fake_images, is_training=True)
        disc_real = discriminator(real_images, is_training=True)
        disc_loss = discriminator_loss(disc_fake, disc_real)
    disc_vars = discriminator.trainable_variables
    optimizer_disc.apply_gradients(zip(tape.gradient(disc_loss, disc_vars), disc_vars))

    # --- Generator step: fresh noise, scored by the just-updated discriminator ---
    noise = np.random.normal(0., 1., size=[batch_size, noise_dim]).astype(np.float32)
    with tf.GradientTape() as tape:
        fake_images = generator(noise, is_training=True)
        disc_fake = discriminator(fake_images, is_training=True)
        gen_loss = generator_loss(disc_fake)
    gen_vars = generator.trainable_variables
    optimizer_gen.apply_gradients(zip(tape.gradient(gen_loss, gen_vars), gen_vars))

    return (gen_loss, disc_loss)

In [None]:
# Training: data_loader repeats forever, so take(1000) bounds the run to
# 1000 batches. Digit labels are discarded; both losses are printed every
# 10th step.
for step, (batch_x, _) in enumerate(data_loader.take(1000)):
    gen_loss, disc_loss = run_optimizer(batch_x)
    if step % 10:
        continue
    print("step: %i, gen_loss: %f, disc_loss: %f" % (step, gen_loss, disc_loss))