Simple TensorFlow implementation of Spectral Normalization for Generative Adversarial Networks (ICLR 2018)
> python main.py --dataset mnist --sn True
def spectral_norm(w, iteration=1):
    """Return ``w`` divided by a power-iteration estimate of its spectral norm.

    ``w`` is flattened to a 2-D matrix that keeps the last axis as columns.
    A non-trainable variable named ``"u"`` holds the running power-iteration
    vector; it is assigned the newest estimate whenever the returned tensor
    is evaluated.

    Args:
        w: Weight tensor. Its shape is read with ``w.shape.as_list()`` and
            reused for the final reshape.
        iteration: Number of power-iteration steps. Must be at least 1.

    Returns:
        A tensor with the same shape as ``w`` containing ``w / sigma``.

    Raises:
        ValueError: If ``iteration`` is less than 1.
    """
    if iteration < 1:
        # With zero steps v_hat would stay None and tf.matmul would fail
        # with an unrelated-looking error; fail early and clearly instead.
        raise ValueError("iteration must be >= 1, got %r" % (iteration,))

    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])

    u = tf.get_variable(
        "u",
        [1, w_shape[-1]],
        initializer=tf.truncated_normal_initializer(),
        trainable=False,
    )

    u_hat = u
    v_hat = None
    for _ in range(iteration):
        # NOTE(review): l2_norm is defined elsewhere in the file; presumably
        # it rescales its argument to unit L2 norm -- confirm.
        v_hat = l2_norm(tf.matmul(u_hat, tf.transpose(w)))
        u_hat = l2_norm(tf.matmul(v_hat, w))

    # Treat the iteration vectors as constants so that the gradient of sigma
    # reaches w only through the explicit v_hat * w * u_hat^T product below,
    # not through the power-iteration steps themselves.
    u_hat = tf.stop_gradient(u_hat)
    v_hat = tf.stop_gradient(v_hat)

    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
    w_norm = w / sigma

    # Tie the update of u to the output so it runs whenever w_norm is used.
    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm
# Example: a convolution whose kernel is passed through spectral_norm.
# NOTE(review): `x`, `kernel`, `channels` and `stride` come from an enclosing
# scope that is not shown here.
w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels])
b = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
# `padding` is a required argument of tf.nn.conv2d. "SAME" is used as an
# example value; switch to "VALID" if the input is padded explicitly first.
x = tf.nn.conv2d(
    input=x,
    filter=spectral_norm(w),
    strides=[1, stride, stride, 1],
    padding="SAME",
) + b
Junho Kim