### Noah's Code

In [160]:
import tensorflow as tf
import numpy as np
import torch
def torch_autocorr(x):
    if len(x.shape) < 4:
        x = torch.unsqueeze(x, dim=0)

    x = x.detach().cpu().numpy()

    dim = x.shape[1]

    x = 2*x - 1
    x = tf.transpose(x, perm=[0, 3, 1, 2])
    x = tf.cast(x, tf.complex64)

    m = x
    M = np.array(tf.signal.fft2d(m))

    mag = tf.math.abs(M)
    mag = tf.cast(mag, tf.complex64)

    ang = tf.math.atan2(tf.math.imag(M), tf.math.real(M))
    ang = tf.cast(ang, tf.complex64)

    exp1 = tf.math.exp(tf.dtypes.complex(0., -1.)*ang)
    exp2 = tf.math.exp(tf.dtypes.complex(0., 1.)*ang)

    term1 = mag*exp1
    term2 = mag*exp2

    FFtmp = (term1*term2)/(32**2)

    autocorr = tf.signal.ifft2d(FFtmp)
    autocorr = tf.math.real(autocorr)

    autocorr = tf.transpose(autocorr, perm=[0, 2, 3, 1])

    rv = np.int32(np.floor(dim/2))
    #rv = tf.cast(tf.math.floor(dim/2), tf.int32)
    # autocorr = tf.roll(autocorr, rv, 1)
    autocorr = tf.roll(autocorr, rv, 2)
    # # autocorr = tf.signal.ifftshift(autocorr, axes=(1, 2))


    # # Convert the TensorFlow tensor to a NumPy array
    # np_array = autocorr.numpy()
    np_array = autocorr.numpy()
    # Create a PyTorch tensor from the NumPy array
    torch_tensor_autocorr = torch.from_numpy(np_array)

    return torch_tensor_autocorr

### Torch

In [166]:
import torch
import torch.nn as nn

def autocorr(x):
    """
    x is a 3 dimensional torch tensor. 
    """
    if len(x.shape) < 4:
        x = torch.unsqueeze(x, dim=0)

    dim = x.shape[1]

    x = 2*x - 1
    x = torch.permute(x, (0, 3, 1, 2))
    x_complex = x.type(torch.complex64)

    M = torch.fft.fft2(x_complex)
    mag_M_complex = torch.abs(M).type(torch.complex64)

    ang = torch.atan2(M.imag, M.real)
    ang = ang.type(torch.complex64)

    exp1 = torch.exp(torch.complex(torch.tensor([0], dtype=torch.float32), torch.tensor([-1], dtype=torch.float32))*ang)
    exp2 = torch.exp(torch.complex(torch.tensor([0], dtype=torch.float32), torch.tensor([1], dtype=torch.float32))*ang)

    term1 = mag_M_complex*exp1
    term2 = mag_M_complex*exp2

    FFtmp = (term1*term2) / (32**2)

    autocorr = torch.fft.ifft2(FFtmp) # till here the values are the same upto 1e-7 tolerance
    autocorr = autocorr.real
    autocorr = torch.permute(autocorr, (0, 2, 3, 1))

    rv = torch.tensor(dim//2, dtype=torch.int32)
    # #autocorr = torch.roll(autocorr, rv, 1)
    # #autocorr = torch.roll(autocorr, rv, 2)
    autocorr = torch.roll(autocorr, dim//2, 2)
    # return autocorr

    return autocorr    

### testing code

In [170]:
test_input = torch.randn((3, 64, 64))
tolerance = 1e-07
def equal_within_tolerance(arr1, arr2, ):
    return np.isclose(arr1.numpy(), arr2.numpy(), 
                      atol=tolerance
                      )

In [174]:
output1 = torch_autocorr(test_input)
output2 = autocorr(test_input)
print(output1.shape)
print(output2.shape)
print(f"All values equal: {(output1==output2).all()}")
print(equal_within_tolerance(output1, output2))
print(f"All values within the tolerance of {tolerance}: {equal_within_tolerance(output1, output2).all()}")
print(f"{np.sum((equal_within_tolerance(output1, output2))==True)} values within the tolerance value of {tolerance}.")
print(f"{np.sum((equal_within_tolerance(output1, output2))==False)} not values within the tolerance value {tolerance}.")

torch.Size([1, 3, 64, 64])
torch.Size([1, 3, 64, 64])
All values equal: False
[[[[ True  True  True ...  True  True  True]
   [ True  True  True ...  True  True  True]
   [ True  True  True ...  True  True  True]
   ...
   [ True  True  True ...  True  True  True]
   [ True  True  True ...  True  True  True]
   [ True  True  True ...  True  True  True]]

  [[ True  True  True ...  True  True  True]
   [ True  True  True ...  True  True  True]
   [ True  True  True ...  True  True  True]
   ...
   [ True  True  True ...  True  True  True]
   [ True  True  True ...  True  True  True]
   [ True  True  True ...  True  True  True]]

  [[ True  True  True ...  True  True  True]
   [ True  True  True ...  True  True  True]
   [ True  True  True ...  True  True  True]
   ...
   [ True  True  True ...  True  True  True]
   [ True  True  True ...  True  True  True]
   [ True  True  True ...  True  True  True]]]]
All values within the tolerance of 1e-07: False
12285 values within the tolerance va

### Loss Definition

In [175]:
class FFTLoss(nn.Module):
    def __init__(self):
        super(FFTLoss, self).__init__()
        self.mae_loss = nn.L1Loss()
    
    def forward(self, input, target):
        input_autocorr = autocorr(input)
        target_autocorr = autocorr(target)
        diff = self.mae_loss(input_autocorr, target_autocorr)
        return diff

In [178]:
# test the loss
import torch
import torch.nn as nn

input = torch.randn((3, 64, 64), requires_grad=True)
target = torch.randn((3, 64, 64))

mae_loss = nn.L1Loss()
output = mae_loss(input, target)
output.backward()

print('input: ', input)
print('target: ', target)
print('output: ', output)

input:  tensor([[[ 0.7833,  1.7424,  0.8350,  ...,  0.7731,  0.2165,  0.2028],
         [-0.3478, -0.4437, -1.7366,  ..., -1.1579, -0.9653, -0.7782],
         [-0.5331, -0.7700,  0.3380,  ...,  0.3561,  0.0362,  1.5044],
         ...,
         [-0.1079,  1.8757, -0.0354,  ...,  1.7489, -2.1834,  1.2323],
         [ 0.4964, -1.0542,  0.6126,  ...,  1.3857,  2.0068, -0.9700],
         [ 1.5098, -0.2395, -1.3822,  ...,  1.2966,  0.2667, -0.0671]],

        [[ 0.6029, -1.5645,  0.8020,  ...,  0.0533, -0.7314,  0.7086],
         [ 0.2386, -1.9438,  3.0096,  ..., -0.3898, -1.2965, -0.1086],
         [-0.7281,  0.9692,  0.8955,  ..., -0.0182,  1.1141, -0.5534],
         ...,
         [-0.3685, -0.7304, -0.3284,  ..., -1.7926, -1.2779,  0.2217],
         [ 2.0750, -1.1924, -0.2329,  ..., -0.4250, -0.1801,  1.2846],
         [-0.7208,  0.8995,  1.3209,  ..., -0.2966,  1.9810,  1.6138]],

        [[ 0.2652, -0.7919,  0.5994,  ..., -1.2554, -1.4385,  2.0571],
         [ 0.1718, -0.7977,  0.4268, 

In [177]:
# test the loss
import torch
import torch.nn as nn

input = torch.randn((3, 64, 64), requires_grad=True)
target = torch.randn((3, 64, 64))

fft_loss = FFTLoss()
output = fft_loss(input, target)
output.backward()

print('input: ', input)
print('target: ', target)
print('output: ', output)

input:  tensor([[[-0.6055, -0.3839,  0.4894,  ..., -1.5995, -1.0831, -0.6098],
         [ 0.1162, -0.9619, -0.0865,  ..., -0.4587, -0.4353, -0.6610],
         [ 0.8993, -2.2511,  1.5225,  ..., -1.8125,  0.0825, -0.3185],
         ...,
         [-0.2171, -0.6641, -2.8307,  ..., -0.0839,  0.0508, -0.9194],
         [ 1.5035,  2.0641,  1.7570,  ..., -0.0079, -0.2611,  0.2861],
         [-1.6612, -0.4720, -0.2615,  ...,  0.0289, -0.7971, -0.6313]],

        [[ 1.1084, -0.2814, -1.5716,  ...,  0.1583,  0.3641,  1.0905],
         [ 0.0252, -1.0436, -0.3290,  ...,  0.0718, -0.1969,  0.8435],
         [-1.7029,  0.4927, -0.7195,  ...,  0.6377,  2.7637,  1.3378],
         ...,
         [-0.4161,  0.9680, -0.3302,  ..., -1.0823, -0.4520,  0.9102],
         [-0.6099,  1.2596, -1.4773,  ...,  0.5926, -0.5569, -0.6325],
         [ 0.6208,  0.6184,  0.4602,  ...,  0.6756, -0.4513,  0.4019]],

        [[ 0.3724,  0.5834,  0.6066,  ...,  0.3487,  0.1162, -0.1470],
         [ 1.3664, -0.6699, -0.5641, 