Skip to content

Commit

Permalink
Merge pull request #4 from wizenink/dualdupdate
Browse files Browse the repository at this point in the history
Dual update pass for the generator
  • Loading branch information
wizenink committed Feb 13, 2020
2 parents eef9e60 + 9081d59 commit 1776c5a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion models.py
Expand Up @@ -105,4 +105,4 @@ def Discriminator():
#generator = Generator()
#discriminator = Discriminator()
generator = pix2pix.unet_generator(2,norm_type='instancenorm')
discriminator = pix2pix.discriminator(norm_type='instancenorm',target=True)
discriminator = pix2pix.discriminator(norm_type='instancenorm',target=True)
14 changes: 10 additions & 4 deletions train.py
Expand Up @@ -87,17 +87,23 @@ def train_step(noise,input_image,target,generator_optimizer,discriminator_optimi

with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:

gen_output = generator(input_image,training=True)
gen_output = generator(input_image,training=False)

disc_real = discriminator([input_image,target],training=True)
disc_gen = discriminator([input_image,gen_output],training=True)
g_loss = gen_loss(disc_gen,gen_output,target)

#g_loss = gen_loss(disc_gen,gen_output,target)
d_loss_fake,d_loss_real = disc_loss(disc_real,disc_gen)
discriminator_gradients = disc_tape.gradient(d_loss_fake+d_loss_real,discriminator.trainable_variables)
discriminator_optimizer.apply_gradients(zip(discriminator_gradients,discriminator.trainable_variables))

disc_gen = discriminator([input_image,gen_output],training=False)
gen_output = generator(input_image,training=True)
g_loss = gen_loss(disc_gen,gen_output,target)
#d_loss_fake,d_loss_real = disc_loss(disc_real,disc_gen)

generator_gradients = gen_tape.gradient(g_loss,generator.trainable_variables)
discriminator_gradients = disc_tape.gradient(d_loss_fake+d_loss_real,discriminator.trainable_variables)
generator_optimizer.apply_gradients(zip(generator_gradients,generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(discriminator_gradients,discriminator.trainable_variables))
tf.summary.histogram("generator gradients",generator_gradients[0],step=discriminator_optimizer.iterations)
tf.summary.histogram("discriminator_gradients",discriminator_gradients[0],step=discriminator_optimizer.iterations)
return g_loss,d_loss_fake,d_loss_real
Expand Down

0 comments on commit 1776c5a

Please sign in to comment.