In [1]:
import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Union, Sequence
from typing import List, Optional, Tuple, Union

# Scalar type alias: a plain int or float (used in the layer signature's type hints).
Number = Union[int, float]


In [2]:

def _compute_dt(shape, start_points=None, end_points=None):
    """
    Compute uniform spacing (dt) for each dimension based on domain lengths, step sizes,
    start points, and end points. Defaults to a unit domain if not specified.

    Parameters:
    shape (Sequence[int]): The shape of the input excluding batch and channel, i.e. (d_1, d_2, ..., d_n).
    step_sizes (Sequence[float], optional): Step sizes for each dimension. Defaults to shape-based uniform spacing.
    start_points (Sequence[float], optional): Start points for each dimension. Defaults to 0.0 for all dimensions.
    end_points (Sequence[float], optional): End points for each dimension. Defaults to 1.0 for all dimensions.

    Returns:
    dt_list (Sequence[float]): A list of spacings, one per dimension.
    grid (List[torch.Tensor]): A list of grid points for each dimension based on the spacing and domain.
    """
    dim = len(shape)

    # Set default start and end points if not provided
    if start_points is None:
        start_points = torch.zeros(dim).tolist()
    if end_points is None:
        end_points = torch.ones(dim).tolist()

    # Validate that start_points and end_points match the number of dimensions
    if len(start_points) != dim or len(end_points) != dim:
        raise ValueError("Start points and end points must match the number of input dimensions ({dim}).")

    # Compute domain lengths from start and end points
    domain_lengths = [end_points[i] - start_points[i] for i in range(dim)]

    # Generate grid points for each dimension using torch.linspace
    grid = [torch.linspace(start_points[i], end_points[i], steps=shape[i]) for i in range(dim)]

    # Compute dt directly from the grid
    dt_list = [(grid[i][1] - grid[i][0]).item() for i in range(dim)]

    return dt_list, grid


In [3]:
class SpectralConvLaplace1D(nn.Module):
    """1-D pole-residue (Laplace-domain) spectral convolution.

    The input is transformed with an FFT and filtered by a learnable transfer function
    ``H(lambda) = residue / (lambda - pole)``. The output is the sum of two terms:

    * ``x1``: inverse FFT of the filtered spectrum at the input frequencies;
    * ``x2``: a sum of ``exp(pole * t)`` modes on the grid ``t``, weighted by the
      residues at the learned poles.

    These are presumably the steady-state and transient responses, respectively.

    Only ``in_channels``, ``out_channels``, ``n_modes[0]``, ``device`` and the
    ``linspace_*`` arguments are used. The remaining keyword arguments are accepted
    for signature compatibility and are currently ignored.

    Parameters
    ----------
    in_channels, out_channels : int
        Number of input / output channels.
    n_modes : sequence of int
        ``n_modes[0]`` is the number of learnable poles per (in, out) channel pair.
    device : torch.device or str, optional
        Device on which the parameters are created (default: torch's default device).
    linspace_steps : sequence of int, optional
        Number of grid points. If None, the length of the input's trailing dimension(s)
        is used, re-read on every forward call.
    linspace_startpoints, linspace_endpoints : sequence of float, optional
        Grid domain passed to ``_compute_dt`` (default [0, 1]).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        n_modes,
        complex_data=False,
        max_n_modes=None,
        bias=True,
        separable=False,
        resolution_scaling_factor: Optional[Union[Number, List[Number]]] = None,
        xno_block_precision="full",
        rank=0.5,
        factorization=None,
        implementation="reconstructed",
        fixed_rank_modes=False,
        decomposition_kwargs: Optional[dict] = None,
        init_std="auto",
        fft_norm="forward",
        device=None,
        linspace_steps=None,
        linspace_startpoints=None,
        linspace_endpoints=None,
    ):
        super().__init__()

        # Grid configuration as given by the caller; None means "use the defaults".
        self.linspace_steps = linspace_steps
        self.linspace_startpoints = linspace_startpoints
        self.linspace_endpoints = linspace_endpoints

        modes = list(n_modes)
        self.modes1 = modes[0]
        self.scale = 1 / (in_channels * out_channels)
        # Complex poles and residues, each of shape (in_channels, out_channels, modes1).
        self.weights_pole = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat, device=device)
        )
        self.weights_residue = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat, device=device)
        )

    def output_PR(self, lambda1, alpha, weights_pole, weights_residue):
        """Compute output residues at the input frequencies and at the learned poles.

        Parameters
        ----------
        lambda1 : Tensor, shape (n_freq, 1, 1, 1)
            Complex frequencies ``2*pi*i*f`` (as built in ``forward``), broadcast
            against the weights.
        alpha : Tensor, shape (batch, in_channels, n_freq)
            FFT of the input.
        weights_pole, weights_residue : Tensor, shape (in_channels, out_channels, modes)

        Returns
        -------
        output_residue1 : Tensor, shape (batch, out_channels, n_freq)
        output_residue2 : Tensor, shape (batch, out_channels, modes)
        """
        # Transfer function per frequency, shape (n_freq, in_channels, out_channels, modes).
        Hw = weights_residue * torch.div(1, lambda1 - weights_pole)
        output_residue1 = torch.einsum("bix,xiok->box", alpha, Hw)
        output_residue2 = torch.einsum("bix,xiok->bok", alpha, -Hw)
        return output_residue1, output_residue2

    def forward(self, x):
        """Apply the layer.

        Parameters
        ----------
        x : Tensor
            Assumed shape (batch, in_channels, n_points). The last dimension must match
            the grid length.

        Returns
        -------
        Tensor, shape (batch, out_channels, n_points), real-valued.

        Raises
        ------
        ValueError
            If the configured grid length differs from ``x.size(-1)``.
        """
        # Keep the step count local: writing it back to self would cache the first
        # input's length and break later inputs of a different length.
        steps = self.linspace_steps if self.linspace_steps is not None else x.shape[2:]

        dt_list, grid = _compute_dt(
            shape=steps,
            start_points=self.linspace_startpoints,
            end_points=self.linspace_endpoints,
        )
        n_points = x.size(-1)
        # _compute_dt builds the grid on CPU; move it next to the input.
        t = grid[0].to(x.device)
        if t.shape[0] != n_points:
            raise ValueError(
                f"Grid has {t.shape[0]} points but the input has {n_points} "
                "along its last dimension; check linspace_steps."
            )
        dt = dt_list[0]

        # Input spectrum and the matching complex frequencies 2*pi*i*f.
        alpha = torch.fft.fft(x)
        lambda0 = torch.fft.fftfreq(n_points, dt, device=x.device) * 2 * np.pi * 1j
        lambda1 = lambda0.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

        output_residue1, output_residue2 = self.output_PR(
            lambda1, alpha, self.weights_pole, self.weights_residue
        )

        # Term from the input frequencies: inverse FFT back to the grid.
        x1 = torch.real(torch.fft.ifft(output_residue1, n=n_points))

        # Term from the learned poles: sum of exp(pole * t) modes.
        term1 = torch.einsum(
            "bix,kz->bixz", self.weights_pole, t.type(torch.complex64).reshape(1, -1)
        )
        term2 = torch.exp(term1)
        # NOTE(review): "bix,ioxz" contracts the out-channel axis of output_residue2 with
        # the in-channel axis of the poles, so this only runs when in_channels == out_channels.
        x2 = torch.einsum("bix,ioxz->boz", output_residue2, term2)
        x2 = torch.real(x2) / n_points
        return x1 + x2


In [4]:
# Random test input shaped (batch, channels, length). The channel count 4 matches the
# layer below. Seeded so that Restart & Run All reproduces the same tensor.
SEED = 0
torch.manual_seed(SEED)
n_dimensions = (20, 4, 2048)
input = torch.rand(*n_dimensions)  # NOTE: shadows builtin `input`; later cells use this name.

In [5]:
# Layer with 4 input / 4 output channels and a single pole per channel pair.
laplace = SpectralConvLaplace1D(n_modes=(1,), in_channels=4, out_channels=4)

In [6]:
# Sanity check of the input tensor's dimensions.
input.size()

torch.Size([20, 4, 2048])

In [7]:
# Per-sample, per-channel normalisation over the length axis of (batch, channels, length) data.
# InstanceNorm2d on a 3-D tensor treats it as unbatched (C, H, W) and mixes all channels of a
# sample, and its num_features (48) matched nothing. Here num_features = 4 = input channels.
norm = nn.InstanceNorm1d(4)

In [8]:
# Normalisation must preserve the input's dimensions.
norm(input).size()



torch.Size([20, 4, 2048])

In [9]:
# Shape check of the norm -> Laplace layer -> norm pipeline.
(norm(laplace(norm(input)))).size()

torch.Size([20, 4, 2048])

In [21]:
# Scratch configuration, presumably for prototyping a 2-D pole/residue weight layout.
max_modes1 = max_modes2 = 16
in_channels = out_channels = 4
n_modes = (2, 2)

In [22]:
# Same 1 / (C_in * C_out) initialisation scale as SpectralConvLaplace1D.
scale = 1 / (out_channels * in_channels)

In [23]:
# Flat last axis sized for the maximum modes: pole axis 1, pole axis 2, then the residue grid.
total_modes = max_modes1 * max_modes2 + max_modes1 + max_modes2
weight = nn.Parameter(scale * torch.rand(in_channels, out_channels, total_modes, dtype=torch.cfloat))

In [24]:
weight.size()

torch.Size([4, 4, 288])

In [25]:
# Segment boundaries inside the last axis of `weight`:
#   [pole1 (modes1) | pole2 (modes2) | residue (modes1 * modes2)]
# NOTE(review): offsets come from n_modes, not max_modes, so the layout shifts if
# n_modes changes while `weight` stays sized for max_modes -- confirm this is intended.
modes1, modes2 = n_modes
start_pole1, end_pole1 = 0, modes1
start_pole2, end_pole2 = end_pole1, end_pole1 + modes2
start_residue, end_residue = end_pole2, end_pole2 + modes1 * modes2

In [26]:
# Views into the flat `weight`; the residue segment is unfolded into a (modes1, modes2) grid.
weights_pole1 = weight[..., start_pole1:end_pole1].view(weight.shape[0], weight.shape[1], modes1)
weights_pole2 = weight[..., start_pole2:end_pole2].view(weight.shape[0], weight.shape[1], modes2)
weights_residue = weight[..., start_residue:end_residue].view(
    weight.shape[0], weight.shape[1], modes1, modes2
)

tuple(p.shape for p in (weights_pole1, weights_pole2, weights_residue))

(torch.Size([4, 4, 2]), torch.Size([4, 4, 2]), torch.Size([4, 4, 2, 2]))

In [27]:
# Alternative layout: three independent parameters instead of views into one flat tensor.
weights_pole1 = nn.Parameter(scale * torch.rand((in_channels, out_channels, modes1), dtype=torch.cfloat))
weights_pole2 = nn.Parameter(scale * torch.rand((in_channels, out_channels, modes2), dtype=torch.cfloat))
weights_residue = nn.Parameter(
    scale * torch.rand((in_channels, out_channels, modes1, modes2), dtype=torch.cfloat)
)

tuple(p.shape for p in (weights_pole1, weights_pole2, weights_residue))

(torch.Size([4, 4, 2]), torch.Size([4, 4, 2]), torch.Size([4, 4, 2, 2]))