# In [None]:
import tensorflow as tf
from tensorflow import keras
import time
import numpy as np

# build CycleGAN
class CycleGAN(keras.Model):
    """Label-conditioned CycleGAN with Wasserstein-style adversarial losses.

    Holds two generators (``g12``, ``g21``) and two discriminators (``d12``,
    ``d21``). Every network is called with an ``[image, label]`` pair.
    Discriminators are trained with the Wasserstein loss plus a gradient
    penalty; generators are trained with a Wasserstein or sliced-Wasserstein
    adversarial term plus ``lambda_`` times the cycle-consistency loss.

    The network builders (``condition_mnist_uni_img2img``,
    ``condition_mnist_uni_img2img_sem``, ``condition_mnist_uni_disc_cnn``) and
    ``sw_loss`` are expected to be defined elsewhere in the file.
    """

    def __init__(self, lambda_, img_shape, model_type,use_identity=False):
        """Build the four networks, the optimizers and the loss objects.

        Args:
            lambda_: weight of the cycle-consistency term in the generator loss.
            img_shape: stored on the instance.
                NOTE(review): the builders below hard-code ``(28, 28, 1)`` and
                do not read this value -- confirm whether that is intended.
            model_type: ``'sem'`` selects the SEM generator variant; any other
                value selects the plain variant.
            use_identity: stored on the instance; not read anywhere in this
                class.
        """
        super().__init__()
        self.lambda_ = lambda_
        self.img_shape = img_shape
        self.use_identity = use_identity

        self.g12 = self._get_generator("g12", model_type)
        self.g21 = self._get_generator("g21", model_type)
        self.d12 = self._get_discriminator("d12")
        self.d21 = self._get_discriminator("d21")

        self.opt_G = keras.optimizers.Adam(0.0002, beta_1=0.5)
        self.opt_D = keras.optimizers.Adam(0.0002, beta_1=0.5)
        self.loss_bool = keras.losses.BinaryCrossentropy(from_logits=True)
        self.loss_img = keras.losses.MeanAbsoluteError()  # a better result when using mse

    def d_loss_wasserstein(self, real_logits, fake_logits):
        """Return the critic loss ``mean(fake_logits) - mean(real_logits)``."""
        return tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits)

    def g_loss_wasserstein(self, fake_logits):
        """Return the generator loss ``-mean(fake_logits)``."""
        return -tf.reduce_mean(fake_logits)

    def wasserstein_gradient_penalty(self, x, x_fake, y, discriminator):
        """Return the WGAN-GP penalty ``mean((||grad_x_hat D|| - 1)^2)``.

        ``x_hat`` is a random interpolation between ``x`` and ``x_fake``. The
        interpolation factor is drawn independently for every sample and the
        gradient norm is taken per sample (over all non-batch axes), so the
        penalty does not depend on the batch size.

        Args:
            x: batch of real images (first axis is the batch axis).
            x_fake: batch of generated images with the same shape as ``x``.
            y: conditioning input passed to the discriminator next to ``x_hat``.
            discriminator: callable taking ``[image, label]`` and returning the
                critic output.

        Returns:
            Scalar tensor with the gradient penalty (weight 1).
        """
        x = tf.convert_to_tensor(x)

        # One epsilon per sample, broadcast over the remaining axes.
        eps_shape = [tf.shape(x)[0]] + [1] * (len(x.shape) - 1)
        epsilon = tf.random.uniform(eps_shape, 0.0, 1.0, dtype=x.dtype)
        x_hat = epsilon * x + (1 - epsilon) * x_fake

        with tf.GradientTape() as t:
            t.watch(x_hat)
            d_hat = discriminator([x_hat, y], training=False)
        gradients = t.gradient(d_hat, x_hat)

        # Per-sample L2 norm; the small constant keeps the gradient of sqrt
        # finite when a sample's gradient is exactly zero.
        non_batch_axes = list(range(1, len(gradients.shape)))
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), axis=non_batch_axes) + 1e-12)

        return tf.reduce_mean((slopes - 1.0) ** 2)

    def _get_generator(self, name, model_type):
        """Build a generator; ``model_type == 'sem'`` picks the SEM variant."""
        # NOTE(review): input shape is hard-coded instead of using self.img_shape.
        if model_type == 'sem':
            model = condition_mnist_uni_img2img_sem((28, 28, 1), name=name)
            print('Load ResNet with SEM module...')
        else:
            model = condition_mnist_uni_img2img((28, 28, 1), name=name)
            print('Load ResNet module...')

        return model

    def _get_discriminator(self, name):
        """Build a discriminator via ``condition_mnist_uni_disc_cnn``."""
        # NOTE(review): input shape is hard-coded instead of using self.img_shape.
        return condition_mnist_uni_disc_cnn((28, 28, 1), name=name)

    def cycle_loss(self, real_img1, real_y1, real_img2, real_y2):
        """Translate both batches, map them back and compare to the inputs.

        Returns:
            ``(cycle_loss, fake2, fake1)`` where ``fake2 = g12([real_img1,
            real_y2])``, ``fake1 = g21([real_img2, real_y1])`` and
            ``cycle_loss`` is the sum of the two ``loss_img`` reconstruction
            errors.
        """
        fake2 = self.g12([real_img1, real_y2])
        fake1 = self.g21([real_img2, real_y1])

        loss1 = self.loss_img(real_img1, self.g21([fake2, real_y1]))
        loss2 = self.loss_img(real_img2, self.g12([fake1, real_y2]))

        return loss1 + loss2, fake2, fake1

    def train_g(self, img1, y1, img2, y2, loss_type):
        """Run one optimizer step on both generators.

        Args:
            img1, y1: image batch and labels of the first domain.
            img2, y2: image batch and labels of the second domain.
            loss_type: ``'wd'`` or ``'wd-sem'`` uses the Wasserstein generator
                loss; any other value uses ``sw_loss``.

        Returns:
            ``(adversarial_loss, cycle_loss)`` where ``adversarial_loss`` is
            the sum of the two adversarial terms (without the cycle term).
        """
        with tf.GradientTape() as tape:
            cycle_loss, fake2, fake1 = self.cycle_loss(img1, y1, img2, y2)
            pred2 = self.d12([fake2, y2])
            pred1 = self.d21([fake1, y1])

            if loss_type == 'wd' or loss_type == 'wd-sem':
                # Wasserstein distance
                d_loss12 = self.g_loss_wasserstein(pred2)
                d_loss21 = self.g_loss_wasserstein(pred1)
            else:
                # Sliced Wasserstein distance between the critic outputs on
                # real and generated images.
                # NOTE(review): batch_size=32 is hard-coded here -- confirm it
                # matches the batch size used by the caller.
                r_y2 = self.d12([img2, y2])
                d_loss12 = sw_loss(r_y2, pred2, num_projections=32, batch_size=32)

                r_y1 = self.d21([img1, y1])
                d_loss21 = sw_loss(r_y1, pred1, num_projections=32, batch_size=32)

            loss = d_loss12 + d_loss21 + self.lambda_ * cycle_loss

        var = self.g12.trainable_variables + self.g21.trainable_variables
        grads = tape.gradient(loss, var)
        self.opt_G.apply_gradients(zip(grads, var))

        return d_loss12 + d_loss21, cycle_loss

    def train_d(self, img1, y1_, img2, y2_):
        """Run one optimizer step on both discriminators.

        Each discriminator loss is the Wasserstein critic loss on a real and a
        generated batch plus the gradient penalty.

        Returns:
            The summed loss of ``d12`` and ``d21``.
        """
        with tf.GradientTape() as d_tape:
            fake2 = self.g12([img1, y2_])
            fake1 = self.g21([img2, y1_])

            # adversarial_1
            y2 = self.d12([img2, y2_])
            pred2 = self.d12([fake2, y2_])
            loss_12 = self.d_loss_wasserstein(y2, pred2)
            loss_12 += self.wasserstein_gradient_penalty(
                x=img2, x_fake=fake2, y=y2_, discriminator=self.d12)

            # adversarial_2
            y1 = self.d21([img1, y1_])
            pred1 = self.d21([fake1, y1_])
            loss_21 = self.d_loss_wasserstein(y1, pred1)
            loss_21 += self.wasserstein_gradient_penalty(
                x=img1, x_fake=fake1, y=y1_, discriminator=self.d21)

            # total adversarial loss
            dis_loss = loss_12 + loss_21

        var = self.d12.trainable_variables + self.d21.trainable_variables
        dis_grads = d_tape.gradient(dis_loss, var)
        self.opt_D.apply_gradients(zip(dis_grads, var))
        return dis_loss

    def train_on_step(self, img1, y1, img2, y2, loss_type):
        """Train the discriminators once, then the generators once.

        Returns:
            ``(g_loss, cyc_loss, d_loss)`` -- note the order: the cycle loss is
            the second element and the discriminator loss the third.
        """
        d_loss = self.train_d(img1, y1, img2, y2)
        g_loss, cyc_loss = self.train_g(img1, y1, img2, y2, loss_type)

        return g_loss, cyc_loss, d_loss

def train(seed,loss_type,gan, x0, y0, ds_x, ds_y, test6, testy_6, test9, testy_9, step, batch_size):
    """Train ``gan`` for ``step`` iterations and write logs and checkpoints.

    Each iteration samples ``batch_size`` random indices (with replacement)
    from both datasets, calls ``gan.train_on_step`` and then ``save_gan`` on
    the test data. Every 200 steps (except step 0) the t-SNE plot, the model
    weights and the translated images are saved. After the loop the final
    artefacts, the three loss curves and the score/rate tables are written
    under ``drive/Shared drives/Ziqiang/Mnist_condi/For_github/<loss_type>/<seed>``.

    Args:
        seed: run identifier used as the last path component and forwarded
            unchanged to ``_save_img2img_gan``.
        loss_type: forwarded to ``gan.train_on_step`` and used in the path.
        gan: model exposing ``train_on_step`` and ``save_weights``.
        x0, y0: images and labels of the first dataset.
        ds_x, ds_y: images and labels of the second dataset.
        test6, testy_6, test9, testy_9: evaluation inputs forwarded to
            ``save_gan``.
        step: number of training iterations; must be at least 1.
        batch_size: number of samples drawn from each dataset per iteration.

    Raises:
        ValueError: if ``step`` is smaller than 1.
    """
    # Function-scope import: ``os`` is not among the module-level imports
    # visible next to this function.
    import os

    if step < 1:
        raise ValueError('step must be at least 1, got {}'.format(step))

    loss_G = []
    loss_D = []
    loss_CYC = []
    s_value = []
    r_value = []

    general = ('drive/Shared drives/Ziqiang/Mnist_condi/For_github/'
               + loss_type + '/' + str(seed))
    dir_ = general + '/visual/'
    dir_loss = general + '/loss/'
    dir_model = general + '/models/'
    dir_utils = general + '/others/'

    os.makedirs(dir_, exist_ok=True)
    os.makedirs(dir_loss, exist_ok=True)
    os.makedirs(dir_utils, exist_ok=True)
    t0 = time.time()

    # Fixed reference score that every later score is compared against.
    score0 = 0.9

    for t in range(step):
        idx6 = np.random.randint(0, len(x0), batch_size)
        img6 = tf.gather(x0, idx6)
        y6 = tf.gather(y0, idx6)

        idx9 = np.random.randint(0, len(ds_x), batch_size)
        img9 = tf.gather(ds_x, idx9)
        y9 = tf.gather(ds_y, idx9)

        # train_on_step returns (generator, cycle, discriminator) in this order.
        g_loss, cyc_loss, d_loss = gan.train_on_step(img6, y6, img9, y9, loss_type)

        # presumably: score, embedding inputs for t-SNE, and image grids --
        # save_gan is defined elsewhere; verify against its definition.
        current_score, x_eval, y_eval, img, imgs = save_gan(
            gan, t, img6=test6, img9=test9, y6=testy_6, y9=testy_9)

        loss_G.append(g_loss.numpy())
        loss_D.append(d_loss.numpy())
        loss_CYC.append(cyc_loss.numpy())

        if t == 0:
            print('initial:', score0)
            s_value.append([score0, t])
            r_value.append([0, t])
            continue

        s_value.append([current_score, t])

        if t % 200 != 0:
            continue

        current_rate = 1 - current_score / score0
        r_value.append([current_rate, t])
        print('Rate:',current_rate,'at step:',t,'Cost: gen',g_loss.numpy(), 'dis',d_loss.numpy(),'cyc', cyc_loss.numpy())

        #visual
        t_sne(x_eval, y_eval, n_class=18, savename=dir_ + "{}.png".format(t))

        #models
        model_dir = dir_model + 'model_' + str(t)
        os.makedirs(model_dir, exist_ok=True)
        gan.save_weights(model_dir + '/model.ckpt')

        #MNIST image
        _save_img2img_gan(t, img, imgs, seed, loss_type)

    # Final artefacts, taken from the last completed step ``t``.
    #visual
    rate_last = 1 - current_score / score0
    r_value.append([rate_last, t])
    print('last rate:', rate_last)
    t_sne(x_eval, y_eval, n_class=18, savename=dir_ + "{}.png".format(t))

    #model
    model_dir = dir_model + 'model_' + str(t)
    os.makedirs(model_dir, exist_ok=True)
    gan.save_weights(model_dir + '/model.ckpt')

    #MNIST image
    _save_img2img_gan(t, img, imgs, seed, loss_type)

    t1 = time.time()
    print('running time:', t1 - t0)

    #loss
    np.savetxt(dir_loss + 'loss_g.txt', np.array(loss_G))
    np.savetxt(dir_loss + 'loss_d.txt', np.array(loss_D))
    np.savetxt(dir_loss + 'loss_cyc.txt', np.array(loss_CYC))

    #others
    np.savetxt(dir_utils + 'score.txt', np.array(s_value))
    np.savetxt(dir_utils + 'rate.txt', np.array(r_value))

    print('running time:', t1 - t0)
    print('*' * 20)
    # The first entry is the placeholder rate 0 recorded at step 0.
    print('Average rate is:', np.array(r_value)[1:, 0].mean())
    print('*' * 20)