In [None]:
class CombinedLoss:
    """Weighted generator loss comparing a generated image ``sr`` to ``hr``.

    The total is the weighted sum of four terms:

    * ``l1_loss``: mean absolute pixel difference between ``hr`` and ``sr``;
    * ``content_loss``: weighted MSE between VGG16 feature maps;
    * ``adversarial_loss``: cross-entropy style term on a discriminator that
      is called as ``discriminator([image, lr])``;
    * ``adversarial_feature_loss``: weighted MSE between the discriminator's
      ``conv_0`` .. ``conv_4`` activations on ``hr`` and on ``sr``.

    NOTE(review): the discriminator is assumed to output probabilities in
    ``[0, 1]`` (``discriminator_loss`` uses ``from_logits=False``) -- confirm.
    """

    # Probabilities are clipped to [_EPS, 1 - _EPS] before taking logs so the
    # adversarial loss stays finite; 1e-7 matches the Keras default epsilon.
    _EPS = 1e-7

    # VGG16 layers used for the perceptual loss, shallow to deep.
    _VGG_LAYER_NAMES = (
        'block1_conv2',
        'block2_conv2',
        'block3_conv3',
        'block4_conv3',
        'block5_conv3',
    )

    def __init__(self, lambda_l1=1.0, lambda_content=1.0, lambda_adv=1e-3,
                 lambda_adv_feat=1e-3):
        """Store the term weights and build the VGG16 feature extractor.

        Args:
            lambda_l1: Weight of the pixel-wise L1 term.
            lambda_content: Weight of the VGG perceptual term.
            lambda_adv: Weight of the adversarial term.
            lambda_adv_feat: Weight of the discriminator feature-matching term.
        """
        self.lambda_l1 = lambda_l1
        self.lambda_content = lambda_content
        self.lambda_adv = lambda_adv
        self.lambda_adv_feat = lambda_adv_feat

        # ImageNet-pretrained VGG16 used purely as a fixed feature extractor.
        vgg = VGG16(include_top=False, weights='imagenet')
        vgg.trainable = False
        vgg_outputs = [vgg.get_layer(name).output
                       for name in self._VGG_LAYER_NAMES]
        # Kept for backward compatibility: one single-output model per layer
        # (they share VGG's weights, so this costs no extra parameters).
        self.vgg_layers = [Model(inputs=vgg.input, outputs=out)
                           for out in vgg_outputs]
        # One multi-output model so each image passes through VGG once
        # instead of once per layer.
        self._vgg_features = Model(inputs=vgg.input, outputs=vgg_outputs)
        self.vgg_weights = [1, 1/2, 1/4, 1/8, 1/16]  # More weight on shallow layers

        # Lazily built (discriminator, feature_extractor) pair; see
        # _discriminator_features.
        self._disc_feature_cache = None

    def l1_loss(self, hr, sr):
        """Return the mean absolute difference between ``hr`` and ``sr``."""
        return tf.reduce_mean(tf.abs(hr - sr))

    def content_loss(self, hr, sr):
        """Return the weighted MSE between VGG16 feature maps of ``hr`` and ``sr``.

        Layer ``i`` (see ``_VGG_LAYER_NAMES``) is weighted by
        ``self.vgg_weights[i]``. Images are fed to VGG16 as-is; no
        ``preprocess_input`` is applied here. NOTE(review): confirm callers
        pass images in the range VGG16 expects.
        """
        hr_features = self._vgg_features(hr)
        sr_features = self._vgg_features(sr)
        loss = 0
        for weight, hr_feat, sr_feat in zip(self.vgg_weights, hr_features, sr_features):
            loss += weight * tf.reduce_mean(tf.square(hr_feat - sr_feat))
        return loss

    def adversarial_loss(self, discriminator, lr, sr, hr):
        """Return ``-mean(log(1 - D([hr, lr])) + log(D([sr, lr])))``.

        Minimising this pushes ``D([sr, lr])`` towards 1, i.e. towards the
        label ``discriminator_loss`` assigns to real samples. The ``hr`` term
        does not depend on ``sr``. Probabilities are clipped to
        ``[_EPS, 1 - _EPS]`` so neither log can reach ``-inf``.
        """
        d_real = tf.clip_by_value(discriminator([hr, lr]), self._EPS, 1.0 - self._EPS)
        d_fake = tf.clip_by_value(discriminator([sr, lr]), self._EPS, 1.0 - self._EPS)
        # Both logs are *added*: the generator should minimise -log(d_fake).
        # Subtracting log(d_fake) would reward outputs that D labels as fake.
        return -tf.reduce_mean(tf.math.log(1.0 - d_real) + tf.math.log(d_fake))

    def _discriminator_features(self, discriminator):
        """Return a multi-output model exposing ``conv_0`` .. ``conv_4`` of ``discriminator``.

        The model is built once per discriminator object and reused, instead
        of being rebuilt on every loss evaluation.
        """
        cache = getattr(self, '_disc_feature_cache', None)
        if cache is None or cache[0] is not discriminator:
            extractor = Model(
                inputs=discriminator.inputs,
                outputs=[discriminator.get_layer(f'conv_{i}').output for i in range(5)],
            )
            cache = (discriminator, extractor)
            self._disc_feature_cache = cache
        return cache[1]

    def adversarial_feature_loss(self, discriminator, lr, sr, hr):
        """Return the weighted MSE between discriminator features of ``hr`` and ``sr``.

        Features are taken from layers ``conv_0`` .. ``conv_4`` with weights
        ``1/2, 1/4, 1/8, 1/16, 1/16`` (summing to 1), both conditioned on ``lr``.
        """
        weights = [1/2, 1/4, 1/8, 1/16, 1/16]
        extractor = self._discriminator_features(discriminator)
        real_features = extractor([hr, lr])
        fake_features = extractor([sr, lr])

        loss = 0
        for weight, real, fake in zip(weights, real_features, fake_features):
            loss += weight * tf.reduce_mean(tf.square(real - fake))
        return loss

    def total_loss(self, discriminator, lr, sr, hr):
        """Return the lambda-weighted sum of the four loss terms."""
        return (self.lambda_l1 * self.l1_loss(hr, sr) +
                self.lambda_content * self.content_loss(hr, sr) +
                self.lambda_adv * self.adversarial_loss(discriminator, lr, sr, hr) +
                self.lambda_adv_feat * self.adversarial_feature_loss(discriminator, lr, sr, hr))


In [None]:
# Binary cross-entropy loss for the discriminator (expects probabilities, not logits)
def discriminator_loss(real_output, fake_output):
    """Return the discriminator's binary cross-entropy loss.

    Predictions on real samples are scored against a target of 1 and
    predictions on generated samples against a target of 0; the two
    losses are summed.

    Args:
        real_output: Discriminator predictions on real samples, treated as
            probabilities (``from_logits=False``).
        fake_output: Discriminator predictions on generated samples, also
            treated as probabilities.

    Returns:
        ``BCE(1, real_output) + BCE(0, fake_output)``.
    """
    cross_entropy = BinaryCrossentropy(from_logits=False)
    real_targets = tf.ones_like(real_output)
    fake_targets = tf.zeros_like(fake_output)
    return (cross_entropy(real_targets, real_output)
            + cross_entropy(fake_targets, fake_output))