# Test spectral and histogram loss for 3D images
Jan 6, 2021

In [1]:
import os
import random
import logging
import sys

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torchsummary import summary
from torch.utils.data import DataLoader, TensorDataset

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

import argparse
import time
from datetime import datetime
import glob
import pickle
import yaml
import collections

In [3]:
## numpy code
def f_radial_profile(data, center=(None,None)):
    ''' Module to compute radial profile of a 2D image '''
    y, x = np.indices((data.shape)) # Get a grid of x and y values
    
    if center[0]==None and center[1]==None:
        center = np.array([(x.max()-x.min())/2.0, (y.max()-y.min())/2.0]) # compute centers
        
    # get radial values of every pair of points
    r = np.sqrt((x - center[0])**2 + (y - center[1])**2)
    r = r.astype(np.int)
    
    # Compute histogram of r values
    tbin = np.bincount(r.ravel(), data.ravel())
    nr = np.bincount(r.ravel()) 
    radialprofile = tbin / nr
    
    return radialprofile

def f_compute_spectrum(arr):
#     GLOBAL_MEAN=1.0
#     arr=((arr - GLOBAL_MEAN)/GLOBAL_MEAN)
    y1=np.fft.fft2(arr)
    y1=np.fft.fftshift(y1)
    y2=abs(y1)**2
    z1=f_radial_profile(y2)
    return(z1)
   
def f_compute_batch_spectrum(arr):
    batch_pk=np.array([f_compute_spectrum(i) for i in arr])
    return batch_pk


### Code ###
def f_image_spectrum(x):
    '''
    Data has to be in the form (batch,channel,x,y)
    '''
    print(x.shape)
    mean=[[] for i in range(num_channels)]    
    sdev=[[] for i in range(num_channels)]    

    for i in range(num_channels):
        arr=x[:,i,:,:]
#         print(i,arr.shape)
        batch_pk=f_compute_batch_spectrum(arr)
#         print(batch_pk)
        mean[i]=np.mean(batch_pk,axis=0)
        sdev[i]=np.std(batch_pk,axis=0)
    mean=np.array(mean)
    sdev=np.array(sdev)
    return mean,sdev


In [4]:
####################
### Pytorch code ###
####################

def f_torch_radial_profile(img, center=(None,None)):
    ''' Module to compute radial profile of a 2D image 
    Bincount causes issues with backprop, so not using this code
    '''
    
    y,x=torch.meshgrid(torch.arange(0,img.shape[0]),torch.arange(0,img.shape[1])) # Get a grid of x and y values
    if center[0]==None and center[1]==None:
        center = torch.Tensor([(x.max()-x.min())/2.0, (y.max()-y.min())/2.0]) # compute centers

    # get radial values of every pair of points
    r = torch.sqrt((x - center[0])**2 + (y - center[1])**2)
    r= r.int()
    
#     print(r.shape,img.shape)
    # Compute histogram of r values
    tbin=torch.bincount(torch.reshape(r,(-1,)),weights=torch.reshape(img,(-1,)).type(torch.DoubleTensor))
    nr = torch.bincount(torch.reshape(r,(-1,)))
    radialprofile = tbin / nr
    
    return radialprofile[1:-1]


def f_torch_get_azimuthalAverage_with_batch(image, center=None): ### Not used in this code.
    """
    Calculate the azimuthally averaged radial profile. Only use if you need to combine batches

    image - The 2D image
    center - The [x,y] pixel coordinates used as the center. The default is 
             None, which then uses the center of the image (including 
             fracitonal pixels).
    source: https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/
    """
    
    batch, channel, height, width = image.shape
    # Create a grid of points with x and y coordinates
    y, x = np.indices([height,width])

    if not center:
        center = np.array([(x.max()-x.min())/2.0, (y.max()-y.min())/2.0])

    # Get the radial coordinate for every grid point. Array has the shape of image
    r = torch.tensor(np.hypot(x - center[0], y - center[1]))

    # Get sorted radii
    ind = torch.argsort(torch.reshape(r, (batch, channel,-1)))
    r_sorted = torch.gather(torch.reshape(r, (batch, channel, -1,)),2, ind)
    i_sorted = torch.gather(torch.reshape(image, (batch, channel, -1,)),2, ind)

    # Get the integer part of the radii (bin size = 1)
    r_int=r_sorted.to(torch.int32)

    # Find all pixels that fall within each radial bin.
    deltar = r_int[:,:,1:] - r_int[:,:,:-1]  # Assumes all radii represented
    rind = torch.reshape(torch.where(deltar)[2], (batch, -1))    # location of changes in radius
    rind=torch.unsqueeze(rind,1)
    nr = (rind[:,:,1:] - rind[:,:,:-1]).type(torch.float)       # number of radius bin

    # Cumulative sum to figure out sums for each radius bin

    csum = torch.cumsum(i_sorted, axis=-1)
#     print(csum.shape,rind.shape,nr.shape)

    tbin = torch.gather(csum, 2, rind[:,:,1:]) - torch.gather(csum, 2, rind[:,:,:-1])
    radial_prof = tbin / nr

    return radial_prof


def f_get_rad(img):
    ''' Get the radial tensor for use in f_torch_get_azimuthalAverage '''
    
    height,width=img.shape[-2:]
    # Create a grid of points with x and y coordinates
    y, x = np.indices([height,width])
    
    center=[]
    if not center:
        center = np.array([(x.max()-x.min())/2.0, (y.max()-y.min())/2.0])

    # Get the radial coordinate for every grid point. Array has the shape of image
    r = torch.tensor(np.hypot(x - center[0], y - center[1]))
    
    # Get sorted radii
    ind = torch.argsort(torch.reshape(r, (-1,)))
    
    return r.detach(),ind.detach()


def f_torch_get_azimuthalAverage(image,r,ind):
    """
    Calculate the azimuthally averaged radial profile.

    image - The 2D image
    center - The [x,y] pixel coordinates used as the center. The default is 
             None, which then uses the center of the image (including 
             fracitonal pixels).
    source: https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/
    """
    
#     height, width = image.shape
#     # Create a grid of points with x and y coordinates
#     y, x = np.indices([height,width])

#     if not center:
#         center = np.array([(x.max()-x.min())/2.0, (y.max()-y.min())/2.0])

#     # Get the radial coordinate for every grid point. Array has the shape of image
#     r = torch.tensor(np.hypot(x - center[0], y - center[1]))

#     # Get sorted radii
#     ind = torch.argsort(torch.reshape(r, (-1,)))

    r_sorted = torch.gather(torch.reshape(r, ( -1,)),0, ind)
    i_sorted = torch.gather(torch.reshape(image, ( -1,)),0, ind)
    
    # Get the integer part of the radii (bin size = 1)
    r_int=r_sorted.to(torch.int32)

    # Find all pixels that fall within each radial bin.
    deltar = r_int[1:] - r_int[:-1]  # Assumes all radii represented
    rind = torch.reshape(torch.where(deltar)[0], (-1,))    # location of changes in radius
    nr = (rind[1:] - rind[:-1]).type(torch.float)       # number of radius bin

    # Cumulative sum to figure out sums for each radius bin
    
    csum = torch.cumsum(i_sorted, axis=-1)
    tbin = torch.gather(csum, 0, rind[1:]) - torch.gather(csum, 0, rind[:-1])
    radial_prof = tbin / nr

    return radial_prof

def f_torch_fftshift(real, imag):
    for dim in range(0, len(real.size())):
        real = torch.roll(real, dims=dim, shifts=real.size(dim)//2)
        imag = torch.roll(imag, dims=dim, shifts=imag.size(dim)//2)
    return real, imag

def f_torch_compute_spectrum(arr,r,ind):
    
    GLOBAL_MEAN=1.0
    arr=(arr-GLOBAL_MEAN)/(GLOBAL_MEAN)
    y1=torch.rfft(arr,signal_ndim=2,onesided=False)
    real,imag=f_torch_fftshift(y1[:,:,0],y1[:,:,1])    ## last index is real/imag part
    y2=real**2+imag**2     ## Absolute value of each complex number
    
#     print(y2.shape)
    z1=f_torch_get_azimuthalAverage(y2,r,ind)     ## Compute radial profile
    
    return z1

def f_torch_compute_batch_spectrum(arr,r,ind):
    
    batch_pk=torch.stack([f_torch_compute_spectrum(i,r,ind) for i in arr])
    
    return batch_pk

def f_torch_image_spectrum(x,num_channels,r,ind):
    '''
    Data has to be in the form (batch,channel,x,y)
    '''
    mean=[[] for i in range(num_channels)]    
    sdev=[[] for i in range(num_channels)]    

    for i in range(num_channels):
        arr=x[:,i,:,:]
        batch_pk=f_torch_compute_batch_spectrum(arr,r,ind)
        mean[i]=torch.mean(batch_pk,axis=0)
#         sdev[i]=torch.std(batch_pk,axis=0)/np.sqrt(batch_pk.shape[0])
#         sdev[i]=torch.std(batch_pk,axis=0)
        sdev[i]=torch.var(batch_pk,axis=0)
    
    mean=torch.stack(mean)
    sdev=torch.stack(sdev)
        
    return mean,sdev

def f_compute_hist(data,bins):
    
    try: 
        hist_data=torch.histc(data,bins=bins)
        ## A kind of normalization of histograms: divide by total sum
        hist_data=(hist_data*bins)/torch.sum(hist_data)
    except Exception as e:
        print(e)
        hist_data=torch.zeros(bins)

    return hist_data

### Losses 
def loss_spectrum(spec_mean,spec_mean_ref,spec_std,spec_std_ref,image_size,lambda1):
    ''' Loss function for the spectrum : mean + variance 
    Log(sum( batch value - expect value) ^ 2 )) '''
    
    idx=int(image_size/2) ### For the spectrum, use only N/2 indices for loss calc.
    ### Warning: the first index is the channel number.For multiple channels, you are averaging over them, which is fine.
        
    spec_mean=torch.log(torch.mean(torch.pow(spec_mean[:,:idx]-spec_mean_ref[:,:idx],2)))
    spec_sdev=torch.log(torch.mean(torch.pow(spec_std[:,:idx]-spec_std_ref[:,:idx],2)))
    
    lambda1=lambda1;
    lambda2=lambda1;
    ans=lambda1*spec_mean+lambda2*spec_sdev
    
    if torch.isnan(spec_sdev).any():    print("spec loss with nan",ans)
    
    return ans
    
def loss_hist(hist_sample,hist_ref):
    
    lambda1=1.0
    return lambda1*torch.log(torch.mean(torch.pow(hist_sample-hist_ref,2)))



In [5]:
## Read input
ip_fname='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/3d_data/full_1.npy'
img=np.load(ip_fname,mmap_mode='r')[:20].transpose(0,1,2,3)
img=np.expand_dims(img,axis=1).astype(float)
t_img=torch.from_numpy(img)

In [6]:
t_img.shape

torch.Size([20, 1, 64, 64, 64])

In [7]:
f_compute_hist(t_img,bins=50)

tensor([4.9749e+01, 1.4340e-01, 4.4575e-02, 2.1524e-02, 1.2398e-02, 7.7057e-03,
        5.0068e-03, 3.5572e-03, 2.8324e-03, 2.2221e-03, 1.6308e-03, 1.2302e-03,
        8.7738e-04, 8.6784e-04, 5.4359e-04, 4.3869e-04, 3.7193e-04, 3.4332e-04,
        2.1935e-04, 2.1935e-04, 1.6212e-04, 1.4305e-04, 1.0490e-04, 7.6294e-05,
        7.6294e-05, 8.5831e-05, 8.5831e-05, 6.6757e-05, 6.6757e-05, 0.0000e+00,
        3.8147e-05, 9.5367e-06, 2.8610e-05, 1.9073e-05, 1.9073e-05, 3.8147e-05,
        2.8610e-05, 3.8147e-05, 1.9073e-05, 9.5367e-06, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 9.5367e-06, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 9.5367e-06], dtype=torch.float64)

In [8]:
def f_torch_get_azimuthalAverage_with_batch(image, center=None): ### Not used in this code.
    """
    Calculate the azimuthally averaged radial profile. Only use if you need to combine batches

    image - The 2D image
    center - The [x,y] pixel coordinates used as the center. The default is 
             None, which then uses the center of the image (including 
             fracitonal pixels).
    source: https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/
    """
    
    batch, channel, height, width,depth = image.shape
    # Create a grid of points with x and y coordinates
    x,y,z = np.indices([height,width,depth])

    if not center:
        center = np.array([(x.max()-x.min())/2.0, (y.max()-y.min())/2.0,(z.max()-z.min())/2.0])

    # Get the radial coordinate for every grid point. Array has the shape of image
    r = torch.tensor(np.hypot(x - center[0], y - center[1],z-center[2]))
    
    print(r)
    # Get sorted radii
    print(batch,channel)
    ind = torch.argsort(torch.reshape(r, (batch, channel,-1)))
    r_sorted = torch.gather(torch.reshape(r, (batch, channel, -1,)),2, ind)
    i_sorted = torch.gather(torch.reshape(image, (batch, channel, -1,)),2, ind)

    # Get the integer part of the radii (bin size = 1)
    r_int=r_sorted.to(torch.int32)

    # Find all pixels that fall within each radial bin.
    deltar = r_int[:,:,1:] - r_int[:,:,:-1]  # Assumes all radii represented
    rind = torch.reshape(torch.where(deltar)[2], (batch, -1))    # location of changes in radius
    rind=torch.unsqueeze(rind,1)
    nr = (rind[:,:,1:] - rind[:,:,:-1]).type(torch.float)       # number of radius bin

    # Cumulative sum to figure out sums for each radius bin

    csum = torch.cumsum(i_sorted, axis=-1)
#     print(csum.shape,rind.shape,nr.shape)

    tbin = torch.gather(csum, 2, rind[:,:,1:]) - torch.gather(csum, 2, rind[:,:,:-1])
    radial_prof = tbin / nr

    return radial_prof


f_torch_get_azimuthalAverage_with_batch(t_img[:,:,:4,:4,:4])

tensor([[[2.1213, 2.1213, 2.1213, 2.1213],
         [1.5811, 1.5811, 1.5811, 1.5811],
         [1.5811, 1.5811, 1.5811, 1.5811],
         [2.1213, 2.1213, 2.1213, 2.1213]],

        [[1.5811, 1.5811, 1.5811, 1.5811],
         [0.7071, 0.7071, 0.7071, 0.7071],
         [0.7071, 0.7071, 0.7071, 0.7071],
         [1.5811, 1.5811, 1.5811, 1.5811]],

        [[1.5811, 1.5811, 1.5811, 1.5811],
         [0.7071, 0.7071, 0.7071, 0.7071],
         [0.7071, 0.7071, 0.7071, 0.7071],
         [1.5811, 1.5811, 1.5811, 1.5811]],

        [[2.1213, 2.1213, 2.1213, 2.1213],
         [1.5811, 1.5811, 1.5811, 1.5811],
         [1.5811, 1.5811, 1.5811, 1.5811],
         [2.1213, 2.1213, 2.1213, 2.1213]]], dtype=torch.float64)
20 1


RuntimeError: shape '[20, 1, -1]' is invalid for input of size 64

In [9]:
img=t_img[0,0,:,:,:]
print(img.shape)
r,ind=f_get_rad(t_img)

torch.Size([64, 64, 64])


In [17]:
# y1=torch.rfft(t_img[0,0,:,:,:],signal_ndim=3,onesided=False)
# y1.shape


In [11]:
r,ind

(tensor([[44.5477, 43.8463, 43.1567,  ..., 43.1567, 43.8463, 44.5477],
         [43.8463, 43.1335, 42.4323,  ..., 42.4323, 43.1335, 43.8463],
         [43.1567, 42.4323, 41.7193,  ..., 41.7193, 42.4323, 43.1567],
         ...,
         [43.1567, 42.4323, 41.7193,  ..., 41.7193, 42.4323, 43.1567],
         [43.8463, 43.1335, 42.4323,  ..., 42.4323, 43.1335, 43.8463],
         [44.5477, 43.8463, 43.1567,  ..., 43.1567, 43.8463, 44.5477]],
        dtype=torch.float64),
 tensor([2016, 2079, 2015,  ..., 4032,   63, 4095]))

In [14]:
# f_torch_image_spectrum(t_img,1,r,ind)

In [28]:
def f_torch_get_azimuthalAverage(image,r,ind):
    """
    Calculate the azimuthally averaged radial profile.

    image - The 2D image
    center - The [x,y] pixel coordinates used as the center. The default is 
             None, which then uses the center of the image (including 
             fracitonal pixels).
    source: https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/
    """
    
#     height, width = image.shape
#     # Create a grid of points with x and y coordinates
#     y, x = np.indices([height,width])

#     if not center:
#         center = np.array([(x.max()-x.min())/2.0, (y.max()-y.min())/2.0])

#     # Get the radial coordinate for every grid point. Array has the shape of image
#     r = torch.tensor(np.hypot(x - center[0], y - center[1]))

#     # Get sorted radii
#     ind = torch.argsort(torch.reshape(r, (-1,)))

    r_sorted = torch.gather(torch.reshape(r, ( -1,)),0, ind)
    i_sorted = torch.gather(torch.reshape(image, ( -1,)),0, ind)
    
    # Get the integer part of the radii (bin size = 1)
    r_int=r_sorted.to(torch.int32)

    # Find all pixels that fall within each radial bin.
    deltar = r_int[1:] - r_int[:-1]  # Assumes all radii represented
    rind = torch.reshape(torch.where(deltar)[0], (-1,))    # location of changes in radius
    nr = (rind[1:] - rind[:-1]).type(torch.float)       # number of radius bin

    # Cumulative sum to figure out sums for each radius bin
    
    csum = torch.cumsum(i_sorted, axis=-1)
    tbin = torch.gather(csum, 0, rind[1:]) - torch.gather(csum, 0, rind[:-1])
    radial_prof = tbin / nr

    return radial_prof

def f_torch_fftshift(real, imag):
    for dim in range(0, len(real.size())):
        real = torch.roll(real, dims=dim, shifts=real.size(dim)//2)
        imag = torch.roll(imag, dims=dim, shifts=imag.size(dim)//2)
    return real, imag

def f_torch_compute_spectrum(arr,r,ind):
    
    GLOBAL_MEAN=1.0
    arr=(arr-GLOBAL_MEAN)/(GLOBAL_MEAN)
    y1=torch.rfft(arr,signal_ndim=3,onesided=False)
    print(y1.shape)
    real,imag=f_torch_fftshift(y1[:,:,:,0],y1[:,:,:,1])    ## last index is real/imag part
    print(real.shape,imag.shape)
    y2=real**2+imag**2     ## Absolute value of each complex number
    print(y2.shape)
#     print(y2.shape)
    z1=f_torch_get_azimuthalAverage(y2,r,ind)     ## Compute radial profile
    print(z1.shape)
    return z1

In [29]:
f_torch_compute_spectrum(t_img[0,0,:,:,:],r,ind)

torch.Size([64, 64, 64, 2])
torch.Size([64, 64, 64]) torch.Size([64, 64, 64])
torch.Size([64, 64, 64])
torch.Size([43])


tensor([ 2508273.9225,  7070027.4958,  6802440.1899, 10405998.4226,
        10110390.1638,  8101516.8540,  5885862.6172,  5022353.3353,
         4669633.9063,  5994386.0244,  5980273.9324,  5358895.7123,
         5097610.9158,  4571598.5829,  3101416.3273,  3909796.4399,
         3754593.0063,  3764176.7559,  3314507.3260,  3643855.6688,
         3397347.0271,  2734598.3531,  2899227.8991,  2708118.4934,
         2183583.3825,  2235347.4285,  1958966.0722,  2510536.4087,
         2259801.8579,  2138486.3777,  1844594.8609,  1890788.9211,
         1733705.2404,  1583946.6509,  1464912.0189,  1459921.3987,
         1524975.7922,  1520010.0394,   923301.4526,   780980.9213,
          945995.0148,   696427.8134,  1537155.3631], dtype=torch.float64)

In [38]:
def f_get_rad(img):
    ''' Get the radial tensor for use in f_torch_get_azimuthalAverage '''
    
    height,width,depth=img.shape[-3:]
    # Create a grid of points with x and y coordinates
    x, y, z = np.indices([height,width,depth])
    
    center=[]
    if not center:
        center = np.array([(x.max()-x.min())/2.0, (y.max()-y.min())/2.0, (z.max()-z.min())/2.0])

    # Get the radial coordinate for every grid point. Array has the shape of image
    r = torch.tensor(np.hypot(x - center[0], y - center[1], z-center[2]))
    
    # Get sorted radii
    ind = torch.argsort(torch.reshape(r, (-1,)))
    
    return r.detach(),ind.detach()

r,ind=f_get_rad(t_img)
print(r.shape,ind.shape)

torch.Size([64, 64, 64]) torch.Size([262144])


In [41]:
torch.m

tensor(44.5477, dtype=torch.float64)

In [45]:
np.linalg.norm(np.array([63,63,63])-np.array([0,0,0]))
np.sqrt(63**3)

500.0469977912076

In [36]:
? np.hypot


[0;31mCall signature:[0m   [0mnp[0m[0;34m.[0m[0mhypot[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mType:[0m            ufunc
[0;31mString form:[0m     <ufunc 'hypot'>
[0;31mFile:[0m            ~/.conda/envs/v3/lib/python3.8/site-packages/numpy/__init__.py
[0;31mDocstring:[0m      
hypot(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])

Given the "legs" of a right triangle, return its hypotenuse.

Equivalent to ``sqrt(x1**2 + x2**2)``, element-wise.  If `x1` or
`x2` is scalar_like (i.e., unambiguously cast-able to a scalar type),
it is broadcast for use with each element of the other argument.
(See Examples)

Parameters
----------
x1, x2 : array_like
    Leg of the triangle(s).
    If ``x1.shape != x2.shape``, they must be broadcastable to a common
    shape (which becomes the shape of the output).
out : ndarray, None, or tuple of 