In [None]:
# One Adam instance per network (generator / discriminator) so that each
# keeps its own optimizer state. Both use lr=1e-4 with beta_1=0.9, beta_2=0.999.
base_gen_optimizer = Adam(
    learning_rate=1e-4,
    beta_1=0.9,
    beta_2=0.999,
)
base_dis_optimizer = Adam(
    learning_rate=1e-4,
    beta_1=0.9,
    beta_2=0.999,
)


def training(HR_image, LR_image, *, clip_norm=1.0):
    """Run one adversarial training step on a single batch.

    Runs the generator and discriminator forward, computes both losses,
    clips every individual gradient tensor to ``clip_norm`` and applies one
    optimizer update to each network.

    Args:
        HR_image: Target batch (presumably the high-resolution ground truth).
            It is scored by the discriminator as the "real" sample and passed
            to the loss and metric helpers.
        LR_image: Generator input batch (presumably low-resolution). It is
            also passed to the discriminator alongside every candidate image.
        clip_norm: Maximum L2 norm for each gradient tensor, applied with
            ``tf.clip_by_norm``. Defaults to 1.0, the previously hard-coded
            value.

    Returns:
        Tuple ``(gen_loss, dis_loss, psnr_value, ssim_value)``. The metrics
        are computed on the generator output produced *before* this step's
        weight update.

    Raises:
        tf.errors.InvalidArgumentError: From ``tf.debugging.check_numerics``
            if the generator output contains NaN or Inf.
    """
    # Persistent because the same recorded forward pass is differentiated
    # twice: once for the generator loss, once for the discriminator loss.
    with tf.GradientTape(persistent=True) as tape:
        # Forward pass
        gen_output = generator(LR_image, training=True)

        tf.debugging.check_numerics(gen_output, "Generator output contains invalid values")

        # The discriminator scores each candidate image paired with LR_image.
        dis_real = discriminator([HR_image, LR_image], training=True)
        dis_fake = discriminator([gen_output, LR_image], training=True)

        # Calculate losses
        gen_loss = gen_losses.total_loss(discriminator, LR_image, gen_output, HR_image)
        dis_loss = discriminator_loss(dis_real, dis_fake)

    # Calculate gradients
    gen_grads = tape.gradient(gen_loss, generator.trainable_variables)
    dis_grads = tape.gradient(dis_loss, discriminator.trainable_variables)
    # A persistent tape keeps its recorded resources until it is dropped;
    # release it now rather than after the updates and metrics below.
    del tape

    # Clip each gradient tensor. tape.gradient yields None for variables that
    # do not influence the loss; tf.clip_by_norm cannot accept None, so those
    # entries are passed through unchanged for the optimizer to skip.
    gen_grads = [None if g is None else tf.clip_by_norm(g, clip_norm) for g in gen_grads]
    dis_grads = [None if g is None else tf.clip_by_norm(g, clip_norm) for g in dis_grads]

    # Apply gradients
    base_gen_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
    base_dis_optimizer.apply_gradients(zip(dis_grads, discriminator.trainable_variables))

    # Metrics use gen_output from the forward pass above (pre-update weights).
    psnr_value, ssim_value = metrics(HR_image, gen_output)

    return gen_loss, dis_loss, psnr_value, ssim_value

# Globally instrument TensorFlow so that every op executed after this call
# (eagerly or inside a tf.function) is checked for NaN/Inf and raises an
# error on the first bad value it produces.
# NOTE(review): while this is enabled, the explicit check_numerics on the
# generator output inside `training` is redundant. Per the TF docs this
# instrumentation adds per-op overhead, so consider disabling it once
# training is numerically stable.
tf.debugging.enable_check_numerics()



In [None]:
def fit(training_dataset, epochs, *, log_every=100):
    """Train the GAN for ``epochs`` passes and record per-step history.

    Args:
        training_dataset: Iterable yielding ``(HR_image, LR_image)`` batches.
            It is iterated once per epoch, so it must be re-iterable (e.g. a
            list or ``tf.data.Dataset``). A one-shot generator would only
            produce data for the first epoch.
        epochs: Number of full passes over ``training_dataset``.
        log_every: Print the step's losses and metrics every ``log_every``
            steps within an epoch, starting at step 0. Values <= 0 disable
            per-step logging. Defaults to 100, the previously hard-coded
            interval.

    Returns:
        Dict with keys ``'gen_loss'``, ``'dis_loss'``, ``'psnr'`` and
        ``'ssim'``. Each value is a flat list holding one host value per
        training step across all epochs.
    """
    history = {
        'gen_loss': [],
        'dis_loss': [],
        'psnr': [],
        'ssim': [],
    }

    for epoch in range(epochs):
        print(f"Epoch: {epoch + 1}/{epochs}")

        for step, (HR_image, LR_image) in enumerate(training_dataset):
            # Convert each returned tensor to a host value exactly once. The
            # *_value names also avoid shadowing the module-level
            # `discriminator_loss` loss function that `training` relies on.
            gen_value, dis_value, psnr_value, ssim_value = (
                tensor.numpy() for tensor in training(HR_image, LR_image)
            )

            history['gen_loss'].append(gen_value)
            history['dis_loss'].append(dis_value)
            history['psnr'].append(psnr_value)
            history['ssim'].append(ssim_value)

            if log_every > 0 and step % log_every == 0:
                print(f"Step: {step}")
                print(f"Generator Total Loss: {gen_value:.4f}")
                print(f"Discriminator Total Loss: {dis_value:.4f}")
                print(f"PSNR Value: {psnr_value:.4f}")
                print(f"SSIM: {ssim_value:.4f}")

    return history