In [1]:
import torch
import torch.special as special

def calculate_chebyshev_vectorized(s, s0, coeffs):
    """Calculate Chebyshev polynomial contribution using torch.special"""
    # Map s to the range [-1, 1] for Chebyshev polynomials
    z = (2 * (s - s0) / s0) - 1
    
    # Convert coefficients to tensor if not already
    if not isinstance(coeffs, torch.Tensor):
        coeffs = torch.tensor(coeffs, dtype=s.dtype, device=s.device)
        
    # Reshape for broadcasting
    z_flat = z.reshape(-1)
    batch_size = z_flat.shape[0]
    n = len(coeffs)
    
    # Calculate Chebyshev polynomials for all orders and all z values at once
    # Create indices tensor for each polynomial order
    indices = torch.arange(n, device=s.device).reshape(n, 1).expand(n, batch_size)
    
    # Calculate all Chebyshev polynomials at once
    cheby_values = special.chebyshev_polynomial_t(indices, z_flat.expand(n, batch_size))
    
    # Multiply by coefficients and sum
    coeffs_expanded = coeffs.reshape(-1, 1).expand(-1, batch_size)
    result = torch.sum(coeffs_expanded * cheby_values, dim=0)
    
    # Reshape back to original shape
    return result.reshape(s.shape)

def calculate_k_matrix_optimized(s, channels_data, k_parameters, resmasses):
    """Calculate K-matrix using optimized batch operations"""
    num_channels = len(channels_data["masses"])
    batch_size = s.numel()
    device = s.device
    dtype = torch.complex128
    
    # Reshape s for batch operations
    s_flat = s.reshape(-1)
    
    # Initialize K-matrix with zeros
    k_mat = torch.zeros((batch_size, num_channels, num_channels), dtype=dtype, device=device)
    
    # Add pole terms - vectorized across all s values
    for i, mass in enumerate(resmasses):
        # Vectorized calculation for 1/(mass^2 - s)
        mass_term = 1.0 / (mass**2 - s_flat).reshape(batch_size, 1, 1)
        
        # Convert k_parameter to tensor if not already
        if not isinstance(k_parameters[i], torch.Tensor):
            k_param = torch.tensor(k_parameters[i], dtype=dtype, device=device)
        else:
            k_param = k_parameters[i]
            
        # Expand k_param for broadcasting
        k_param_expanded = k_param.unsqueeze(0).expand(batch_size, -1, -1)
        
        # Add contribution to k_mat
        k_mat += k_param_expanded * mass_term
    
    # Add Chebyshev polynomial terms for each channel
    for i in range(num_channels):
        cheby_val = calculate_chebyshev_vectorized(
            s_flat, 
            channels_data["s0"][i], 
            channels_data["cheby_coeffs"][i]
        )
        # Add to diagonal elements
        k_mat[:, i, i] += cheby_val
    
    # Add Adler zero term if needed
    if "sL" in channels_data and channels_data["sL"] > 0:
        adler_factor = (s_flat - channels_data["sL"]) / s_flat
        adler_factor = adler_factor.reshape(batch_size, 1, 1)
        k_mat *= adler_factor
    
    return k_mat

def calculate_true_momentum_optimized(masses, s):
    """Calculate true momentum with optimized batch operations"""
    # Ensure s is flattened for batch operations
    s_flat = s.reshape(-1)
    batch_size = s_flat.shape[0]
    
    # Extract masses
    m1, m2 = masses[0], masses[1]
    
    # Vectorized calculation
    sqrt_term = torch.sqrt((s_flat - (m1 + m2)**2) * (s_flat - (m1 - m2)**2))
    momentum = 0.5 * sqrt_term / s_flat
    
    # Reshape to original shape
    return momentum.reshape(s.shape)

def calculate_phase_space_matrix_optimized(s, channels_data, J):
    """Calculate phase space matrix with optimized batch operations"""
    num_channels = len(channels_data["masses"])
    batch_size = s.numel()
    device = s.device
    dtype = torch.complex128
    
    # Reshape s for batch operations
    s_flat = s.reshape(-1)
    
    # Initialize rho matrix
    rho_mat = torch.zeros((batch_size, num_channels, num_channels), dtype=dtype, device=device)
    
    # Calculate rho for each channel
    for i in range(num_channels):
        # Calculate threshold
        masses = channels_data["masses"][i]
        threshold = sum(masses)**2
        
        # Calculate momentum
        momentum = calculate_true_momentum_optimized(masses, s_flat)
        
        # Calculate rho
        rho = 2 * momentum / torch.sqrt(s_flat)
        
        # Handle sheet selection if needed
        if "sheet" in channels_data and channels_data["sheet"][i] == 1:
            mask = s_flat.real < threshold
            rho = torch.where(mask, -1j * rho, rho)
        
        # Set diagonal elements
        rho_mat[:, i, i] = rho
    
    return rho_mat

def calculate_amplitude_optimized(s, channels_data, k_parameters, resmasses, J):
    """Fully optimized amplitude calculation using batch operations"""
    num_channels = len(channels_data["masses"])
    batch_size = s.numel()
    device = s.device
    dtype = torch.complex128
    
    # Reshape s for batch operations
    s_flat = s.reshape(-1)
    
    # Calculate K-matrix
    k_mat = calculate_k_matrix_optimized(s_flat, channels_data, k_parameters, resmasses)
    
    # Calculate phase space matrix
    rho_mat = calculate_phase_space_matrix_optimized(s_flat, channels_data, J)
    
    # Calculate denominator: I - i*K*rho using batch matrix multiplication
    identity_batched = torch.eye(num_channels, dtype=dtype, device=device).unsqueeze(0).expand(batch_size, -1, -1)
    k_rho_batched = torch.bmm(k_mat, rho_mat)
    denominator_batched = identity_batched - 1j * k_rho_batched
    
    # Batch matrix inversion
    denominator_inv_batched = torch.linalg.inv(denominator_batched)
    
    # Calculate output phase space factors
    phsp_batched = torch.zeros((batch_size, num_channels, num_channels), dtype=dtype, device=device)
    for i in range(num_channels):
        momentum = calculate_true_momentum_optimized(channels_data["masses"][i], s_flat)
        phsp_batched[:, i, i] = momentum**(J + 0.5) / s_flat**(0.25)
    
    # Final amplitude calculation using batch matrix multiplication
    temp = torch.bmm(k_mat, denominator_inv_batched)
    amplitude_batched = torch.bmm(temp, phsp_batched)
    
    # Reshape back to original format
    amplitude = amplitude_batched.reshape(*s.shape, num_channels, num_channels)
    amplitude = amplitude.permute(*range(len(s.shape), len(s.shape) + 2), *range(len(s.shape)))
    
    return amplitude

In [2]:
def example_optimized_usage():
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Define channels in a batch-friendly structure
    channels_data = {
        "masses": [
            torch.tensor([0.139, 0.139], device=device),  # Channel 1: pion masses
            torch.tensor([0.494, 0.494], device=device)   # Channel 2: kaon masses
        ],
        "couplings": [
            torch.tensor([0.5, 0.3], device=device),      # Channel 1 couplings
            torch.tensor([0.4, 0.2], device=device)       # Channel 2 couplings
        ],
        "cheby_coeffs": [
            torch.tensor([0.1, 0.05, -0.02], device=device),  # Channel 1 Chebyshev coeffs
            torch.tensor([0.08, 0.03], device=device)         # Channel 2 Chebyshev coeffs
        ],
        "pole_type": [1, 1],
        "s0": [1.0, 1.0],
        "sL": 0.1,  # Adler zero position
        "sheet": [0, 0]  # Physical sheet for both channels
    }
    
    # Define K-matrix parameters
    num_channels = len(channels_data["masses"])
    
    # Example: one resonance
    k_param1 = torch.zeros((num_channels, num_channels), dtype=torch.complex128, device=device)
    k_param1[0, 0] = 0.5
    k_param1[0, 1] = 0.2
    k_param1[1, 0] = 0.2
    k_param1[1, 1] = 0.3
    k_parameters = [k_param1]
    
    # Resonance masses
    resmasses = torch.tensor([0.770], device=device)  # Example: rho meson mass
    
    # Spin
    J = 1
    
    # Energy values to evaluate (vectorized calculation for multiple s values)
    # Create a grid of s values
    s_real = torch.linspace(0.5, 1.0, 10, device=device)
    s_imag = torch.linspace(0.0, 0.1, 5, device=device)
    s_grid_real, s_grid_imag = torch.meshgrid(s_real, s_imag, indexing='ij')
    s_values = torch.complex(s_grid_real, s_grid_imag)
    
    # Calculate amplitude for all s values at once
    amp = calculate_amplitude_optimized(s_values, channels_data, k_parameters, resmasses, J)
    
    # Print a few results
    print(f"Amplitude shape: {amp.shape}")
    print(f"Amplitude at s = {s_values[0, 0].item()}: {amp[0, 0]}")
    print(f"Amplitude at s = {s_values[-1, -1].item()}: {amp[-1, -1]}")
    
    return amp

# Run the example
amp_result = example_optimized_usage()

RuntimeError: "chebyshev_polynomial_t_cpu" not implemented for 'ComplexFloat'