In [259]:
import torch
import numpy as np
import cv2

In [224]:
from skimage.measure import compare_ssim, compare_psnr

TODO:
- how to compute R dynamically?

In [4]:
torch.randn(10,10).shape

torch.Size([10, 10])

In [238]:
def psnr(y_hat, y):
    '''
    Both inputs should be of shape:
    batch_size x n_channels x height x width
    
    formula from:
    https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio    
    '''
    
    assert y_hat.size() == y.size(), 'Input dimensions must match.'

    
    R = max(y.max(), y_hat.max()).float()
    batch_mse = (y_hat-y).pow(2).float().mean(dim=(1,2,3))
    
    batch_psnr = 20 * torch.log10(R) - 10 * torch.log10(batch_mse)
    
    return batch_psnr


def interpolation_error(y_hat, y):
    '''
    Computes pixel-wise RMSE over batches.
    Both inputs should be of shape:
    batch_size x n_channels x height x width
    '''
    
    assert y_hat.size() == y.size(), 'Input dimensions must match.'
    
    mse = (y_hat-y).pow(2).mean(dim=(1,2,3))
    
    return mse.sqrt()


def ssim(im1, im2, c1=1., c2=1., c3=1.):
    '''
    Computes the SSIM for color images.
    Both inputs should be of shape:
    batch_size x n_channels x height x width
    '''
    
    assert im1.size() == im2.size(), 'Input dimensions must match.'
    
    im1 = im1.float()
    im2 = im2.float()
    
    # compute image statistics
    mu1 = im1.mean(dim=(1,2,3), keepdims=True)
    mu2 = im2.mean(dim=(1,2,3), keepdims=True)
    sig1 = im1.std(dim=(1,2,3), keepdims=True)
    sig2 = im2.std(dim=(1,2,3), keepdims=True)
    
    
    # calculate covariance
    e1 = (im1-mu1).view(im1.size(0), -1)
    e2 = (im2-mu2).view(im2.size(0), -1)
    cov = (e1 * e2).sum(dim=1) / (e1.size(1)-1)
    
    # remove dimensions
    mu1 = mu1.squeeze()
    mu2 = mu2.squeeze()
    sig1 = sig1.squeeze()
    sig2 = sig2.squeeze()
    
    print(mu1.shape, mu2.shape, e1.shape, e2.shape)
    print(sig1.shape, sig2.shape, cov.shape)
    
    # compute ssim
    L = (2 * mu1 * mu2 + c1) / (mu1**2 + mu2**2 + c1)
    C = (2 * sig1 * sig2 + c2) / (sig1**2 + sig2**2 + c2)
    S = (cov + c3) / (sig1 * sig2 + c3)
    
    return L * C * S
    
    
    
    



    # b_psnr = psnr(y_hat, y)

In [190]:
%%time
s = ssim(y_hat, y)

torch.Size([16]) torch.Size([16]) torch.Size([16, 6220800]) torch.Size([16, 6220800])
torch.Size([16]) torch.Size([16]) torch.Size([16])
Wall time: 1.07 s


In [241]:
im1 = y_hat.float()
im2 = y.float()

In [248]:
(y_hat-y

RuntimeError: Can only calculate the mean of floating types. Got Long instead.

In [191]:
im1 = im1.cuda()
im2 = im2.cuda()


In [218]:
%%time
s = ssim(im1, im2)

torch.Size([16]) torch.Size([16]) torch.Size([16, 6220800]) torch.Size([16, 6220800])
torch.Size([16]) torch.Size([16]) torch.Size([16])
Wall time: 6.98 ms


In [61]:
def create_pred_label_random(dims = (16,3,1080,1920)):
    y_hat = torch.randint(low=0, high=256, size=dims)
    y     = torch.randint(low=0, high=256, size=dims)
    return y_hat, y

In [274]:
y_hat, y = create_pred_label_random()

In [278]:
y_hat /= 255
y /= 255

In [226]:
## check psnr

In [284]:
y_hat[0].shape

torch.Size([3, 1080, 1920])

In [285]:
ssims = []
for i in range(y_hat.size(0)):
    psnr_score = compare_ssim(y_hat[i].numpy(), y[i].numpy(), multichannel=True)
    psnrs.append(psnr_score)

ValueError: win_size exceeds image extent.  If the input is a multichannel (color) image, set multichannel=True.

In [280]:
torch.from_numpy(np.array(psnrs))

tensor([69.1972, 69.2323, 69.2179, 69.2017, 69.2298, 69.2364, 69.1906, 69.2208,
        69.2425, 69.2087, 69.2028, 69.2155, 69.2075, 69.2181, 69.2208, 69.2033],
       dtype=torch.float64)

In [281]:
psnr(y_hat, y)

tensor([21.0664, 21.1015, 21.0871, 21.0709, 21.0990, 21.1056, 21.0598, 21.0900,
        21.1117, 21.0779, 21.0720, 21.0847, 21.0767, 21.0873, 21.0900, 21.0725])

tensor(1)

In [265]:
y_hat /= 255
y /= 255

In [258]:
(y_hat-y).pow(2).float().mean(dim=(1,2,3))

tensor([10922.3701, 10929.9268, 10923.8184, 10924.7900, 10928.0312, 10928.7568,
        10925.1611, 10926.7646, 10930.7734, 10922.0234, 10927.4365, 10916.8906,
        10922.7256, 10924.7266, 10927.9043, 10934.1270])