In [1]:
import torch

class Mix_Loss(torch.nn.Module):
    """Mixed MS-SSIM + Gaussian-weighted L1 loss for NCHW image batches (PyTorch).

    loss = alpha * (1 - MS-SSIM) + (1 - alpha) * L1_w

    * One normalised 2-D Gaussian window covering the whole (square) image is
      built per entry of ``sigma``. All statistics are Gaussian-weighted sums
      over the full image.
    * MS-SSIM uses the luminance term ``l`` of the last scale only, multiplied
      by the product of the contrast-structure terms ``cs`` over all scales.
    * L1_w is the absolute difference weighted by the window of the last
      ``sigma`` entry.
    * Both terms are averaged over batch * channels.

    Intermediate tensors (``w``, ``mux``, ``muy``, ``sigmax2``, ``sigmay2``,
    ``sigmaxy``, ``l``, ``cs``, ``Pcs``, ``diff``, ``bottom0data``,
    ``bottom1data``) are stored on the instance after each call.

    Args:
        alpha: weight of the MS-SSIM term; ``1 - alpha`` weights the L1 term.
        c1, c2: SSIM stabilisation constants (squared when stored as C1/C2).
        sigma: standard deviations of the Gaussian windows, one per scale.
        dtype: floating dtype used to build the Gaussian windows.
        Even_input_allowed: if False, inputs with even height/width are rejected.
    """

    def __init__(self,alpha=0.025,c1=0.01,c2=0.03,sigma = (0.5, 1., 2., 4., 8.),
                 dtype = torch.float32,Even_input_allowed=False):
        super().__init__()
        self.C1 = c1 ** 2
        self.C2 = c2 ** 2
        self.sigma = sigma
        self.alpha = alpha

        self.dtype = dtype

        self.Even_input_allowed = Even_input_allowed

        if self.Even_input_allowed:
            print('UserWarning: If the input has an even height and width, it may not work properly.')

    def forward(self, img0, img1):
        """Compute the scalar mixed loss between two image batches.

        Args:
            img0, img1: tensors of identical shape ``(batch, channels, H, W)``
                with ``H == W``. H and W must be odd unless
                ``Even_input_allowed`` was set.

        Returns:
            0-dim tensor ``alpha * loss_MSSSIM + (1 - alpha) * loss_L1``.

        Raises:
            ValueError: if the shapes differ, the images are not square, or the
                size is even while even inputs are not allowed. ValueError
                subclasses Exception, so existing ``except Exception`` handlers
                still catch it.
        """
        if img0.shape != img1.shape:
            raise ValueError("Inputs must have the same dimension.")
        if img0.shape[2] != img0.shape[3]:
            raise ValueError("Must be the same height and width")
        if not self.Even_input_allowed and (img0.shape[3] % 2 != 1 or img0.shape[2] % 2 != 1):
            raise ValueError("The input height and width must be odd.")

        num_scale = len(self.sigma)
        self.width = img0.shape[3]
        self.channels = img0.shape[1]
        self.batch = img0.shape[0]

        # Pixel coordinates centred on the middle of the image: -(W-1)/2 .. (W-1)/2.
        # For odd W this is the integer grid -(W//2) .. W//2, so each window
        # peaks exactly on the centre pixel. The former arange(-W/2, W/2) was
        # shifted by half a pixel and produced asymmetric windows.
        # Build on the input's device so CUDA inputs work after .to(device).
        coords = torch.arange(self.width, dtype=self.dtype, device=img0.device) - (self.width - 1) / 2

        windows = []
        for s in self.sigma:
            profile = torch.exp(-coords ** 2 / (2 * s ** 2))
            window = torch.outer(profile, profile)
            windows.append(window / torch.sum(window))

        # (S, W, W) -> (S, batch, channels, W, W) as a broadcast view (no copy).
        self.w = torch.stack(windows)[:, None, None].expand(
            num_scale, self.batch, self.channels, self.width, self.width)

        # Same values and shape as repeating the inputs S times, but as views.
        self.bottom0data = img0.unsqueeze(0).expand(num_scale, *img0.shape)
        self.bottom1data = img1.unsqueeze(0).expand(num_scale, *img1.shape)

        # Gaussian-weighted means, variances and covariance per scale/sample/channel.
        self.mux = torch.sum(self.w * self.bottom0data, dim=(3, 4), keepdim=True)
        self.muy = torch.sum(self.w * self.bottom1data, dim=(3, 4), keepdim=True)
        self.sigmax2 = torch.sum(self.w * self.bottom0data ** 2, dim=(3, 4), keepdim=True) - self.mux ** 2
        self.sigmay2 = torch.sum(self.w * self.bottom1data ** 2, dim=(3, 4), keepdim=True) - self.muy ** 2
        self.sigmaxy = torch.sum(self.w * self.bottom0data * self.bottom1data, dim=(3, 4), keepdim=True) - self.mux * self.muy
        self.l = (2 * self.mux * self.muy + self.C1) / (self.mux ** 2 + self.muy ** 2 + self.C1)
        self.cs = (2 * self.sigmaxy + self.C2) / (self.sigmax2 + self.sigmay2 + self.C2)
        self.Pcs = torch.prod(self.cs, dim=0, dtype=self.dtype)

        loss_MSSSIM = 1 - torch.sum(self.l[-1] * self.Pcs) / (self.batch * self.channels)
        self.diff = img0 - img1
        loss_L1 = torch.sum(torch.abs(self.diff) * self.w[-1]) / (self.batch * self.channels)
        return self.alpha * loss_MSSSIM + (1 - self.alpha) * loss_L1

In [2]:
# Smoke test: an all-zero vs an all-one 3-channel 129x129 batch in NCHW layout.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
mix_loss = Mix_Loss().to(device)
img0 = torch.zeros((1, 3, 129, 129))
img1 = torch.ones_like(img0)
d = mix_loss(img0, img1)
d

tensor(1.0000)

In [3]:
import tensorflow as tf

class Mix_Loss():
    """Mixed MS-SSIM + Gaussian-weighted L1 loss for NHWC image batches (TensorFlow).

    loss = alpha * (1 - MS-SSIM) + (1 - alpha) * L1_w

    * One normalised 2-D Gaussian window covering the whole (square) image is
      built per entry of ``sigma``. All statistics are Gaussian-weighted sums
      over the full image.
    * MS-SSIM uses the luminance term ``l`` of the last scale only, multiplied
      by the product of the contrast-structure terms ``cs`` over all scales.
    * L1_w is the absolute difference weighted by the window of the last
      ``sigma`` entry.
    * Both terms are averaged over batch * channels.

    Intermediate tensors (``w``, ``mux``, ``muy``, ``sigmax2``, ``sigmay2``,
    ``sigmaxy``, ``l``, ``cs``, ``Pcs``, ``diff``, ``bottom0data``,
    ``bottom1data``) are stored on the instance after each call.

    Args:
        alpha: weight of the MS-SSIM term; ``1 - alpha`` weights the L1 term.
        c1, c2: SSIM stabilisation constants (squared when stored as C1/C2).
        sigma: standard deviations of the Gaussian windows, one per scale.
        dtype: dtype the inputs are cast to and the windows are built in.
        Even_input_allowed: if False, inputs with even height/width are rejected.
    """

    def __init__(self,alpha=0.025,c1=0.01,c2=0.03,sigma = (0.5, 1., 2., 4., 8.),
                 dtype = tf.float32,Even_input_allowed=False):
        self.C1 = c1 ** 2
        self.C2 = c2 ** 2
        self.sigma = sigma
        self.alpha = alpha

        self.dtype = dtype

        self.Even_input_allowed = Even_input_allowed

        if self.Even_input_allowed:
            print('UserWarning: If the input has an even height and width, it may not work properly.')

    def __call__(self, img0, img1):
        """Compute the scalar mixed loss between two image batches.

        Args:
            img0, img1: tensors of identical static shape
                ``(batch, H, W, channels)`` with ``H == W``. H and W must be
                odd unless ``Even_input_allowed`` was set.

        Returns:
            Scalar tensor ``alpha * loss_MSSSIM + (1 - alpha) * loss_L1``.

        Raises:
            ValueError: if the shapes differ, the images are not square, or the
                size is even while even inputs are not allowed. ValueError
                subclasses Exception, so existing handlers still catch it.
        """
        if img0.shape != img1.shape:
            raise ValueError("Inputs must have the same dimension.")
        if img0.shape[1] != img0.shape[2]:
            raise ValueError("Must be the same height and width")
        if not self.Even_input_allowed and (img0.shape[1] % 2 != 1 or img0.shape[2] % 2 != 1):
            raise ValueError("The input height and width must be odd.")

        img0 = tf.cast(img0, dtype=self.dtype)
        img1 = tf.cast(img1, dtype=self.dtype)

        num_scale = len(self.sigma)
        self.width = img0.shape[1]
        self.channels = img0.shape[3]
        self.batch = img0.shape[0]

        # Pixel coordinates centred on the middle of the image: -(W-1)/2 .. (W-1)/2.
        # For odd W each window peaks exactly on the centre pixel. The former
        # range(-W/2, W/2) was shifted by half a pixel (asymmetric windows).
        coords = tf.range(self.width, dtype=self.dtype) - (self.width - 1) / 2

        windows = []
        for s in self.sigma:
            profile = tf.exp(-coords ** 2 / (2 * s ** 2))
            window = profile[:, tf.newaxis] * profile[tf.newaxis, :]  # outer product
            windows.append(window / tf.reduce_sum(window))

        # (S, W, W) -> (S, batch, W, W, channels), matching the NHWC inputs.
        self.w = tf.tile(tf.stack(windows)[:, tf.newaxis, :, :, tf.newaxis],
                         (1, self.batch, 1, 1, self.channels))

        self.bottom0data = tf.tile(img0[tf.newaxis], (num_scale, 1, 1, 1, 1))
        self.bottom1data = tf.tile(img1[tf.newaxis], (num_scale, 1, 1, 1, 1))

        # Gaussian-weighted means, variances and covariance per scale/sample/channel.
        self.mux = tf.reduce_sum(self.w * self.bottom0data, axis=(2, 3), keepdims=True)
        self.muy = tf.reduce_sum(self.w * self.bottom1data, axis=(2, 3), keepdims=True)
        self.sigmax2 = tf.reduce_sum(self.w * self.bottom0data ** 2, axis=(2, 3), keepdims=True) - self.mux ** 2
        self.sigmay2 = tf.reduce_sum(self.w * self.bottom1data ** 2, axis=(2, 3), keepdims=True) - self.muy ** 2
        self.sigmaxy = tf.reduce_sum(self.w * self.bottom0data * self.bottom1data, axis=(2, 3), keepdims=True) - self.mux * self.muy
        self.l = (2 * self.mux * self.muy + self.C1) / (self.mux ** 2 + self.muy ** 2 + self.C1)
        self.cs = (2 * self.sigmaxy + self.C2) / (self.sigmax2 + self.sigmay2 + self.C2)
        self.Pcs = tf.reduce_prod(self.cs, axis=0)

        loss_MSSSIM = 1 - tf.reduce_sum(self.l[-1] * self.Pcs) / (self.batch * self.channels)
        self.diff = img0 - img1
        loss_L1 = tf.reduce_sum(tf.abs(self.diff) * self.w[-1]) / (self.batch * self.channels)
        return self.alpha * loss_MSSSIM + (1 - self.alpha) * loss_L1

In [4]:
# Smoke test: an all-zero vs an all-one 3-channel 129x129 batch in NHWC layout.
mix_loss = Mix_Loss()
img0 = tf.zeros([1, 129, 129, 3])
img1 = tf.ones_like(img0)
d = mix_loss(img0, img1)
d

<tf.Tensor: shape=(), dtype=float32, numpy=0.9999975>

In [5]:
import numpy as np

class Mix_Loss():
    """Mixed MS-SSIM + Gaussian-weighted L1 loss for NHWC image batches (NumPy).

    loss = alpha * (1 - MS-SSIM) + (1 - alpha) * L1_w

    * One normalised 2-D Gaussian window covering the whole (square) image is
      built per entry of ``sigma``. All statistics are Gaussian-weighted sums
      over the full image.
    * MS-SSIM uses the luminance term ``l`` of the last scale only, multiplied
      by the product of the contrast-structure terms ``cs`` over all scales.
    * L1_w is the absolute difference weighted by the window of the last
      ``sigma`` entry.
    * Both terms are averaged over batch * channels.

    Intermediate arrays (``w``, ``mux``, ``muy``, ``sigmax2``, ``sigmay2``,
    ``sigmaxy``, ``l``, ``cs``, ``Pcs``, ``diff``, ``bottom0data``,
    ``bottom1data``) are stored on the instance after each call. ``w``,
    ``bottom0data`` and ``bottom1data`` are read-only broadcast views.

    Args:
        alpha: weight of the MS-SSIM term; ``1 - alpha`` weights the L1 term.
        c1, c2: SSIM stabilisation constants (squared when stored as C1/C2).
        sigma: standard deviations of the Gaussian windows, one per scale.
        dtype: dtype the inputs are converted to and sums are accumulated in.
        Even_input_allowed: if False, inputs with even height/width are rejected.
    """

    def __init__(self,alpha=0.025,c1=0.01,c2=0.03,sigma = (0.5, 1., 2., 4., 8.),
                 dtype = np.float32,Even_input_allowed=False):
        self.C1 = c1 ** 2
        self.C2 = c2 ** 2
        self.sigma = sigma
        self.alpha = alpha

        self.dtype = dtype

        self.Even_input_allowed = Even_input_allowed

        if self.Even_input_allowed:
            print('UserWarning: If the input has an even height and width, it may not work properly.')

    def __call__(self, img0, img1):
        """Compute the scalar mixed loss between two image batches.

        Args:
            img0, img1: arrays (or array-likes) of identical shape
                ``(batch, H, W, channels)`` with ``H == W``. H and W must be
                odd unless ``Even_input_allowed`` was set.

        Returns:
            NumPy scalar ``alpha * loss_MSSSIM + (1 - alpha) * loss_L1``.

        Raises:
            ValueError: if the shapes differ, the images are not square, or the
                size is even while even inputs are not allowed. ValueError
                subclasses Exception, so existing handlers still catch it.
        """
        # Convert first so array-likes (e.g. nested lists) also have a .shape.
        img0 = np.asarray(img0, dtype=self.dtype)
        img1 = np.asarray(img1, dtype=self.dtype)

        if img0.shape != img1.shape:
            raise ValueError("Inputs must have the same dimension.")
        if img0.shape[1] != img0.shape[2]:
            raise ValueError("Must be the same height and width")
        if not self.Even_input_allowed and (img0.shape[1] % 2 != 1 or img0.shape[2] % 2 != 1):
            raise ValueError("The input height and width must be odd.")

        num_scale = len(self.sigma)
        self.width = img0.shape[2]
        self.channels = img0.shape[3]
        self.batch = img0.shape[0]
        stacked_shape = (num_scale,) + img0.shape

        # Pixel coordinates centred on the middle of the image: -(W-1)/2 .. (W-1)/2.
        # For odd W each window peaks exactly on the centre pixel. The former
        # arange(-W/2, W/2) was shifted by half a pixel (asymmetric windows).
        coords = np.arange(self.width, dtype=self.dtype) - (self.width - 1) / 2

        windows = []
        for s in self.sigma:
            profile = np.exp(-coords ** 2 / (2 * s ** 2))
            window = np.outer(profile, profile)
            windows.append(window / np.sum(window, dtype=self.dtype))
        kernels = np.stack(windows).astype(self.dtype, copy=False)

        # (S, W, W) -> (S, batch, H, W, channels) as a broadcast view (no copy).
        self.w = np.broadcast_to(kernels[:, np.newaxis, :, :, np.newaxis], stacked_shape)
        self.bottom0data = np.broadcast_to(img0, stacked_shape)
        self.bottom1data = np.broadcast_to(img1, stacked_shape)

        # Gaussian-weighted means, variances and covariance per scale/sample/channel.
        self.mux = np.sum(self.w * self.bottom0data, axis=(2, 3), keepdims=True, dtype=self.dtype)
        self.muy = np.sum(self.w * self.bottom1data, axis=(2, 3), keepdims=True, dtype=self.dtype)
        self.sigmax2 = np.sum(self.w * self.bottom0data ** 2, axis=(2, 3), keepdims=True, dtype=self.dtype) - self.mux ** 2
        self.sigmay2 = np.sum(self.w * self.bottom1data ** 2, axis=(2, 3), keepdims=True, dtype=self.dtype) - self.muy ** 2
        self.sigmaxy = np.sum(self.w * self.bottom0data * self.bottom1data, axis=(2, 3), keepdims=True, dtype=self.dtype) - self.mux * self.muy
        self.l = (2 * self.mux * self.muy + self.C1) / (self.mux ** 2 + self.muy ** 2 + self.C1)
        self.cs = (2 * self.sigmaxy + self.C2) / (self.sigmax2 + self.sigmay2 + self.C2)
        self.Pcs = np.prod(self.cs, axis=0, dtype=self.dtype)

        loss_MSSSIM = 1 - np.sum(self.l[-1] * self.Pcs, dtype=self.dtype) / (self.batch * self.channels)
        self.diff = img0 - img1
        loss_L1 = np.sum(np.abs(self.diff) * self.w[-1], dtype=self.dtype) / (self.batch * self.channels)

        return self.alpha * loss_MSSSIM + (1 - self.alpha) * loss_L1

In [6]:
# Smoke test: an all-zero vs an all-one 3-channel 129x129 batch in NHWC layout.
mix_loss = Mix_Loss()
img0 = np.zeros([1, 129, 129, 3])
img1 = np.ones_like(img0)
d = mix_loss(img0, img1)
d

0.9999974434671457