In [None]:
import numpy as np
import torch 
import torch

def _dft_cos_sin(size, dtype, device):
    """Return the (size, size) matrices cos(2*pi*j*k/size) and sin(2*pi*j*k/size)."""
    idx = torch.arange(size, device=device)
    # Reduce j*k modulo size in exact integer arithmetic *before* scaling, so
    # every angle lies in [0, 2*pi) and large transforms do not lose
    # floating-point precision in the trigonometric evaluation.
    phase = (idx.unsqueeze(1) * idx.unsqueeze(0)) % size
    angles = phase.to(dtype) * (2 * torch.pi / size)
    return torch.cos(angles), torch.sin(angles)


def inverse_2d_dft(X):
    """
    Compute the real part of the 2D Inverse Discrete Fourier Transform (IDFT)
    using only real-valued sine and cosine coefficients.

    For X = A + iB the result is

        x[m, n] = 1 / (M * N) * sum_{k, l} ( A[k, l] * cos(t) - B[k, l] * sin(t) ),
        t = 2 * pi * (m * k / M + n * l / N)

    The kernel is separable: with the angle-sum identities
    cos(a + b) = cos a cos b - sin a sin b and
    sin(a + b) = sin a cos b + cos a sin b, the double sum becomes a handful of
    matrix products with (M, M) and (N, N) cosine/sine matrices, so no
    (M, N, M, N) intermediate tensor is ever built.

    Parameters:
    X : torch.Tensor
        Tensor of shape (..., M, N) representing the frequency domain. Usually
        complex; a real tensor is treated as having a zero imaginary part.
        The transform is applied over the last two dimensions.

    Returns:
    x : torch.Tensor
        Real tensor with the same shape and device as X representing the
        spatial domain. Its dtype is the real counterpart of X's dtype,
        promoted to at least float32.

    Raises:
    ValueError
        If X has fewer than two dimensions, or if either of its last two
        dimensions is empty.
    """
    if X.dim() < 2:
        raise ValueError(
            f"inverse_2d_dft expects at least a 2D tensor, got {X.dim()} dimension(s)"
        )
    M, N = X.shape[-2:]
    if M == 0 or N == 0:
        raise ValueError(
            f"inverse_2d_dft requires non-empty transform dimensions, got ({M}, {N})"
        )

    # Real and imaginary parts of X (a real input has no imaginary part)
    if torch.is_complex(X):
        A, B = X.real, X.imag
    else:
        A, B = X, None

    # Work in the precision of the input (never below float32) and on its
    # device, instead of a hard-coded float32 CPU grid.
    dtype = torch.promote_types(A.dtype, torch.float32)
    A = A.to(dtype)

    cos_m, sin_m = _dft_cos_sin(M, dtype, X.device)
    cos_n, sin_n = _dft_cos_sin(N, dtype, X.device)

    # Contract over the column frequency l first. The cosine/sine matrices are
    # symmetric, so no transposes are needed.
    cos_part = A @ cos_n  # sum_l A[k, l] * cos(2*pi*n*l/N)
    sin_part = A @ sin_n  # sum_l A[k, l] * sin(2*pi*n*l/N)
    if B is not None:
        B = B.to(dtype)
        cos_part = cos_part - B @ sin_n
        sin_part = sin_part + B @ cos_n

    # Contract over the row frequency k and apply the 1 / (M * N) normalization.
    return (cos_m @ cos_part - sin_m @ sin_part) / (M * N)




In [None]:

# Create a sample 2D tensor of complex numbers (frequency domain)
X = torch.tensor(
    [[complex(10, 0), complex(2, -2)], [complex(3, 1), complex(4, -4)]],
    dtype=torch.complex64,
)

# Compute the 2D inverse DFT with the sine/cosine implementation
x_custom = inverse_2d_dft(X)

# Compute the reference result: real part of PyTorch's 2D inverse FFT
x_torch = torch.fft.ifft2(X).real

print("Custom Inverse 2D DFT result:")
print(x_custom)

print("PyTorch ifft2 result:")
print(x_torch)

# Fail loudly if the two implementations disagree, rather than relying on a
# visual comparison of the printed tensors.
torch.testing.assert_close(x_custom, x_torch)