In [None]:
class CombinedLoss:
    """Weighted generator loss: L1 + VGG19 content + adversarial + discriminator feature matching.

    Calling an instance as ``loss(discriminator, lr, sr, hr)`` returns::

        lambda_l1       * l1_loss(hr, sr)
      + lambda_content  * content_loss(hr, sr)
      + lambda_adv      * adversarial_loss(discriminator, lr, sr, hr)
      + lambda_adv_feat * adversarial_feature_loss(lr, sr, hr)

    ``hr`` is the target image, ``sr`` the generated image, and ``lr`` the
    conditioning input that the discriminator receives alongside either image.
    Images are assumed to lie in ``[-1, 1]`` (``content_loss`` rescales from
    that range) -- TODO confirm against the data pipeline.
    """

    # Default VGG19 layers for the content loss; deeper layers get smaller weights.
    DEFAULT_VGG_LAYER_NAMES = ('block1_conv2', 'block2_conv2', 'block3_conv3', 'block4_conv3', 'block5_conv3')
    DEFAULT_VGG_LAYER_WEIGHTS = (1.0, 1 / 2, 1 / 4, 1 / 8, 1 / 16)
    # Default discriminator layers used for feature matching.
    # NOTE(review): 'conv_2' is skipped -- confirm this is intentional.
    DEFAULT_DISC_LAYER_NAMES = ('conv_0', 'conv_1', 'conv_3', 'conv_4')

    def __init__(self, discriminator, lambda_l1=1.0, lambda_content=1.0, lambda_adv=1e-3, lambda_adv_feat=1e-3,
                 vgg_layer_names=None, vgg_layer_weights=None, disc_layer_names=None):
        """Build the frozen VGG19 feature extractor and, optionally, the discriminator feature model.

        Args:
            discriminator: Keras model exposing the layers named in
                ``disc_layer_names``, or ``None`` to disable the
                feature-matching term (it then contributes 0).
            lambda_l1, lambda_content, lambda_adv, lambda_adv_feat: weights
                of the four loss terms.
            vgg_layer_names: VGG19 layer names for the content loss
                (default ``DEFAULT_VGG_LAYER_NAMES``).
            vgg_layer_weights: one weight per VGG layer
                (default ``DEFAULT_VGG_LAYER_WEIGHTS``).
            disc_layer_names: discriminator layer names for feature matching
                (default ``DEFAULT_DISC_LAYER_NAMES``).

        Raises:
            ValueError: if the numbers of VGG layer names and weights differ.
        """
        self.lambda_l1 = lambda_l1
        self.lambda_content = lambda_content
        self.lambda_adv = lambda_adv
        self.lambda_adv_feat = lambda_adv_feat

        self.layer_names = list(self.DEFAULT_VGG_LAYER_NAMES if vgg_layer_names is None else vgg_layer_names)
        self.layer_weights = list(self.DEFAULT_VGG_LAYER_WEIGHTS if vgg_layer_weights is None else vgg_layer_weights)
        # zip() in content_loss would silently drop the extra entries, so fail loudly instead.
        if len(self.layer_names) != len(self.layer_weights):
            raise ValueError(
                f"Got {len(self.layer_names)} VGG layer names but {len(self.layer_weights)} layer weights"
            )

        vgg = VGG19(include_top=False, weights='imagenet')
        vgg.trainable = False
        # A single multi-output model yields every requested activation from
        # one forward pass, rather than re-running the shared lower layers once
        # per layer.
        self.vgg_features = Model(
            inputs=vgg.input,
            outputs=[vgg.get_layer(name).output for name in self.layer_names],
        )

        self.disc_layer_names = list(self.DEFAULT_DISC_LAYER_NAMES if disc_layer_names is None else disc_layer_names)
        # None means "feature matching disabled"; always define the attribute so
        # that __call__ does not fail with AttributeError.
        self.feature_model = None
        if discriminator is not None:
            # NOTE(review): this model wraps the discriminator passed here, while
            # adversarial_loss uses the one passed to __call__ -- they are
            # presumably the same object; confirm.
            outputs = [discriminator.get_layer(name).output for name in self.disc_layer_names]
            self.feature_model = Model(inputs=discriminator.inputs, outputs=outputs)

    @staticmethod
    def _as_list(outputs):
        """Return model outputs as a list.

        Depending on the Keras version, a single-output model may return a bare
        tensor instead of a list. Without this normalisation, iterating over
        that tensor would walk the batch axis.
        """
        return list(outputs) if isinstance(outputs, (list, tuple)) else [outputs]

    def l1_loss(self, hr, sr):
        """Mean absolute error between ``hr`` and ``sr``.

        Raises:
            ValueError: if the two shapes differ (broadcasting is not allowed).
        """
        if hr.shape != sr.shape:
            raise ValueError(f"Shape mismatch: hr {hr.shape} vs sr {sr.shape}")
        return tf.reduce_mean(tf.abs(hr - sr))

    def content_loss(self, hr, sr):
        """Weighted sum of per-layer MSEs between VGG19 activations of ``hr`` and ``sr``."""
        # Map the assumed [-1, 1] range to [0, 255], which is what the Keras
        # VGG19 preprocessing expects.
        hr_prep = tf.keras.applications.vgg19.preprocess_input((hr + 1.0) / 2.0 * 255.0)
        sr_prep = tf.keras.applications.vgg19.preprocess_input((sr + 1.0) / 2.0 * 255.0)

        hr_features = self._as_list(self.vgg_features(hr_prep))
        sr_features = self._as_list(self.vgg_features(sr_prep))

        loss = 0.0
        for weight, hr_feat, sr_feat in zip(self.layer_weights, hr_features, sr_features):
            loss += weight * tf.reduce_mean(tf.square(hr_feat - sr_feat))
        return loss

    def adversarial_loss(self, discriminator, lr, sr, hr):
        """Generator adversarial loss ``-mean(log(D([sr, lr]) + 1e-10))``.

        Assumes the discriminator outputs probabilities in (0, 1); the 1e-10
        guards against log(0). ``hr`` is unused and is kept for signature
        symmetry with the other terms.
        """
        d_fake = discriminator([sr, lr])
        return -tf.reduce_mean(tf.math.log(d_fake + 1e-10))

    def adversarial_feature_loss(self, lr, sr, hr):
        """Sum of MSEs between discriminator activations on ``[hr, lr]`` and ``[sr, lr]``.

        Returns 0.0 when no discriminator was given to the constructor.
        """
        if self.feature_model is None:
            return 0.0
        real_features = self._as_list(self.feature_model([hr, lr]))
        fake_features = self._as_list(self.feature_model([sr, lr]))

        loss = 0.0
        for real_feature, fake_feature in zip(real_features, fake_features):
            loss += tf.reduce_mean(tf.square(real_feature - fake_feature))
        return loss

    def __call__(self, discriminator, lr, sr, hr):
        """Return the weighted sum of the four generator loss terms."""
        l1_loss = self.lambda_l1 * self.l1_loss(hr, sr)
        content_loss = self.lambda_content * self.content_loss(hr, sr)
        adversarial = self.lambda_adv * self.adversarial_loss(discriminator, lr, sr, hr)
        adv_feature_loss = self.lambda_adv_feat * self.adversarial_feature_loss(lr, sr, hr)
        return l1_loss + content_loss + adversarial + adv_feature_loss


In [None]:
# Generator loss bound to the discriminator `dis` (defined in an earlier cell).
losses = CombinedLoss(discriminator=dis)

In [None]:
# Binary loss function for the discriminator
def discriminator_loss(real_output, fake_output):
    """Binary cross-entropy loss for the discriminator.

    Real outputs are scored against a target of 1 and fake outputs against a
    target of 0. Both inputs are treated as probabilities
    (``from_logits=False``). Returns the sum of the two BCE terms.
    """
    criterion = BinaryCrossentropy(from_logits=False)
    return (
        criterion(tf.ones_like(real_output), real_output)
        + criterion(tf.zeros_like(fake_output), fake_output)
    )