In [1]:
import torch
import numpy as np
from typing import List

# Traditional method: using explicit grids
def traditional_method(shape: List[int], domain: List[float]):
    """Return grid spacings and FFT frequencies read off explicit coordinate grids.

    For each axis ``k`` a coordinate grid ``torch.linspace(0, domain[k], shape[k])``
    is materialised. ``linspace`` includes *both* endpoints, so the spacing
    taken from its first two samples equals ``domain[k] / (shape[k] - 1)``,
    not ``domain[k] / shape[k]``. The grids use torch's default dtype
    (float32 unless it has been changed), so each spacing carries that rounding.

    Args:
        shape: Number of grid points per axis. Every entry must be at least 2;
            a single-point grid has no second sample and indexing it fails.
        domain: Extent of each axis; must have at least ``len(shape)`` entries.

    Returns:
        ``(dt_list, frequencies)`` where ``dt_list[k]`` is a Python float and
        ``frequencies[k]`` is ``torch.fft.fftfreq(shape[k], d=dt_list[k])``.
    """
    axes = range(len(shape))
    # Materialise every grid first, then derive the spacings from them.
    coords = [torch.linspace(0, domain[ax], shape[ax]) for ax in axes]
    # Spacing = difference of the first two samples, converted to a Python float.
    dt_list = [float(c[1] - c[0]) for c in coords]
    frequencies = [
        torch.fft.fftfreq(n_points, d=step)
        for n_points, step in zip(shape, dt_list)
    ]
    return dt_list, frequencies

# New method: using SpectralConvND's _compute_dt
def compute_dt_new(shape: List[int], domain: List[float]):
    """Return grid spacings and FFT frequencies from the closed-form spacing rule.

    The spacing on axis ``k`` is ``domain[k] / shape[k]``. This is the step
    between ``shape[k]`` equally spaced samples covering ``[0, domain[k])``
    with the right endpoint excluded, as for a periodic grid. It therefore
    differs from a ``torch.linspace`` grid that includes both endpoints.

    Args:
        shape: Number of grid points per axis (a zero entry divides by zero).
        domain: Extent of each axis; must have at least ``len(shape)`` entries.

    Returns:
        ``(dt_list, frequencies)`` where ``dt_list[k] = domain[k] / shape[k]``
        and ``frequencies[k]`` is ``torch.fft.fftfreq(shape[k], d=dt_list[k])``.
    """
    # With plain Python numbers this division runs in double precision,
    # so no float32 grid rounding is involved.
    dt_list = [domain[axis] / n_points for axis, n_points in enumerate(shape)]
    frequencies = [
        torch.fft.fftfreq(n_points, d=step)
        for n_points, step in zip(shape, dt_list)
    ]
    return dt_list, frequencies

# Testing script
def test_methods():
    """Print a side-by-side comparison of both spacing/frequency computations.

    For each hard-coded ``(shape, domain)`` case, this runs ``traditional_method``
    and ``compute_dt_new``. For every axis it prints both ``dt`` values with
    their absolute difference, and the largest absolute difference between
    the two frequency vectors. Nothing is asserted; the function only prints.
    """
    cases = [
        {"shape": [10, 20, 30], "domain": [1.0, 1.0, 1.0]},
        {"shape": [50, 100], "domain": [2.0, 3.0]},
        {"shape": [64, 64, 64], "domain": [1.5, 1.5, 1.5]},
    ]

    for number, case in enumerate(cases, start=1):
        shape, domain = case['shape'], case['domain']
        print(f"\nTest Case {number}")
        print(f"Shape: {shape}, Domain: {domain}")

        ref_dts, ref_freqs = traditional_method(shape, domain)
        new_dts, new_freqs = compute_dt_new(shape, domain)

        print("\nComparing dt values:")
        for axis, (ref_dt, new_dt) in enumerate(zip(ref_dts, new_dts), start=1):
            gap = abs(ref_dt - new_dt)
            print(
                f"  Dimension {axis}: Traditional dt = {ref_dt}, "
                f"New dt = {new_dt}, Difference = {gap:.5e}"
            )

        print("\nComparing frequency values:")
        for axis, (ref_f, new_f) in enumerate(zip(ref_freqs, new_freqs), start=1):
            # Worst-case elementwise disagreement between the two frequency vectors.
            worst = torch.max(torch.abs(ref_f - new_f)).item()
            print(f"  Dimension {axis}: Max Frequency Difference = {worst:.5e}")

if __name__ == "__main__":
    # Run the printed comparison when executed as a script (or notebook cell),
    # but not when this module is merely imported.
    test_methods()



Test Case 1
Shape: [10, 20, 30], Domain: [1.0, 1.0, 1.0]

Comparing dt values:
  Dimension 1: Traditional dt = 0.1111111119389534, New dt = 0.1, Difference = 1.11111e-02
  Dimension 2: Traditional dt = 0.05263157933950424, New dt = 0.05, Difference = 2.63158e-03
  Dimension 3: Traditional dt = 0.03448275849223137, New dt = 0.03333333333333333, Difference = 1.14943e-03

Comparing frequency values:
  Dimension 1: Max Frequency Difference = 5.00000e-01
  Dimension 2: Max Frequency Difference = 5.00000e-01
  Dimension 3: Max Frequency Difference = 5.00000e-01

Test Case 2
Shape: [50, 100], Domain: [2.0, 3.0]

Comparing dt values:
  Dimension 1: Traditional dt = 0.040816325694322586, New dt = 0.04, Difference = 8.16326e-04
  Dimension 2: Traditional dt = 0.03030303120613098, New dt = 0.03, Difference = 3.03031e-04

Comparing frequency values:
  Dimension 1: Max Frequency Difference = 2.50000e-01
  Dimension 2: Max Frequency Difference = 1.66668e-01

Test Case 3
Shape: [64, 64, 64], Domain: