In [28]:
import torch
import numpy as np
from mrspy.plot import plot

def MathSinc(x):
    """Unnormalized sinc: sin(x) / x, evaluated element-wise.

    ``torch.sinc`` is the *normalized* sinc, sin(pi*t) / (pi*t), defined as 1 at t == 0.
    Evaluating it at t = x / pi therefore gives sin(x) / x, with the value 1 at x == 0.
    """
    normalized_arg = x / np.pi
    return torch.sinc(normalized_arg)

def calcSRMatrixApprox(MaxPhase, NumPixels, k, Partitions, b=None, ZeroThreshold=None):
    """Second-order approximation of the super-resolution (SR) matrix and a companion matrix.

    For every k sample i and every cell j between consecutive ``Partitions`` (centre c,
    half-width delta), with a = MaxPhase / (Partitions[-1] - Partitions[0])**2:

        Lin = 2*a*c + b + k
        E   = a*c**2 + b*c + k*c
        A[i, j] = exp(1j*E) * (2*delta*MathSinc(Lin*delta) + 1j*a*HighOrder2)

    where HighOrder2 = integral_{-delta}^{delta} t**2 * cos(Lin*t) dt. These closed forms
    correspond to integrating exp(1j*(a*y**2 + (b + k)*y)) over the cell, with the
    quadratic term about the cell centre expanded as exp(1j*a*t**2) ~ 1 + 1j*a*t**2.

    Args:
        MaxPhase: total quadratic phase across the partition span (scalar).
        NumPixels: only used to build the default ``k`` when ``k`` is None. The matrix
            width always comes from ``Partitions``.
        k: tensor of k sample positions (reshaped to a column), or None to use
            ``-2 * a * arange(NumPixels)``.
        Partitions: 1-D tensor of N+1 cell borders (N cells).
        b: linear phase coefficient; defaults to ``-2 * a * Partitions[0]``.
        ZeroThreshold: cells with |Lin| below this use the Lin -> 0 limits of the
            closed forms (which are 0/0 there). Defaults to ``10 * float32 eps``.

    Returns:
        (A, ADerivative, IdxPositions, PartitionsUsed):
        A, ADerivative -- complex tensors of shape (len(k), N). ADerivative[i, j] is
            exp(1j*E) * integral_{-delta}^{delta} t*exp(1j*Lin*t) dt.
        IdxPositions -- cell centres, shape (N, 1).
        PartitionsUsed -- the borders reshaped to a column, shape (N+1, 1).
    """
    DefaultZeroThreshold = 10 * torch.finfo(torch.float32).eps

    # Handle the optional parameters
    if ZeroThreshold is None:
        ZeroThreshold = DefaultZeroThreshold

    if b is None:
        b = -(2 * MaxPhase / (Partitions[-1] - Partitions[0])**2) * Partitions[0]  # if b is not provided, calculate it

    aEffective = MaxPhase / (Partitions[-1] - Partitions[0])**2

    # If k is not defined, calculate it
    if k is None:
        k = -2 * aEffective * torch.arange(NumPixels).float()

    # Cell borders as a column; centre and half-width of every cell.
    Partitions = Partitions.reshape(-1, 1)
    IdxPositions = (Partitions[:-1] + Partitions[1:]) / 2
    delta = Partitions[1:] - IdxPositions

    k = k.reshape(-1, 1)
    NumKs = k.shape[0]
    NumCells = IdxPositions.shape[0]

    # Rows index k samples, columns index cells. The column count is taken from the
    # partitions (not NumPixels), so every matrix below has the same shape.
    deltaMat = delta.reshape(1, -1).expand(NumKs, NumCells)
    IdxPosMat = IdxPositions.reshape(1, -1).expand(NumKs, NumCells)
    kMat = k.expand(NumKs, NumCells)

    LinCoeffMat = 2 * aEffective * IdxPosMat + b + kMat
    LinCoeff_x_delta_Mat = LinCoeffMat * deltaMat
    SinMat = torch.sin(LinCoeff_x_delta_Mat)
    CosMat = torch.cos(LinCoeff_x_delta_Mat)

    SincInput = LinCoeff_x_delta_Mat
    ExpInput = aEffective * IdxPosMat**2 + b * IdxPosMat + kMat * IdxPosMat

    # Where Lin ~ 0 the closed forms below are 0/0; substitute their limits there.
    ZeroLinCoeffMatIdxs = torch.abs(LinCoeffMat) < ZeroThreshold

    HighOrder2 = 2 * ((LinCoeff_x_delta_Mat**2 - 2) * SinMat +
                      2 * LinCoeff_x_delta_Mat * CosMat) / (LinCoeffMat**3)
    # Limit Lin -> 0: integral of t**2 over [-delta, delta] = 2*delta**3/3.
    HighOrder2 = torch.where(
        ZeroLinCoeffMatIdxs, ((2 / 3) * deltaMat**3).to(HighOrder2.dtype), HighOrder2)

    DerivativeOrder1 = 2 * 1j / LinCoeffMat**2 * (SinMat - LinCoeff_x_delta_Mat * CosMat)
    # sin(x) - x*cos(x) ~ x**3/3, so this term tends to 2j*Lin*delta**3/3 (-> 0); without
    # this guard an exactly-zero Lin produced NaN in ADerivative.
    DerivativeOrder1 = torch.where(
        ZeroLinCoeffMatIdxs,
        ((2j / 3) * LinCoeffMat * deltaMat**3).to(DerivativeOrder1.dtype),
        DerivativeOrder1)

    PhaseMat = torch.exp(1j * ExpInput)

    # 2nd order calculation for A
    A = PhaseMat * ((2 * deltaMat) * MathSinc(SincInput) + 1j * aEffective * HighOrder2)
    ADerivative = PhaseMat * DerivativeOrder1

    PartitionsUsed = Partitions

    return A, ADerivative, IdxPositions, PartitionsUsed

def calcInvA(a_rad2cmsqr, LPE, NumPE, ShiftPE, SPENAcquireSign, ky1RelativePos, GaussRelativeWidth):
    """Build a Gaussian-weighted conjugate transpose of the SR matrix.

    Args (units hedged from the parameter names -- confirm with the caller):
        a_rad2cmsqr: quadratic phase coefficient a (presumably rad/cm^2).
        LPE: field-of-view length along the phase-encode axis.
        NumPE: number of ky samples; also used as the number of pixels.
        ShiftPE: FOV shift, divided by 10 before use (NOTE(review): presumably mm -> cm).
        SPENAcquireSign: sign multiplying the pixel borders and the ky samples.
        ky1RelativePos: fractional position inside the first pixel that the first ky
            sample maps to (b is chosen so that yk[0] = P0 + (P1 - P0) * ky1RelativePos).
        GaussRelativeWidth: scales the standard deviation of the Gaussian weighting.

    Returns:
        (InvA, AFinal): AFinal is the first output of ``calcSRMatrixApprox``. InvA is the
        conjugate transpose of AFinal multiplied element-wise by a Gaussian of the
        (pixel-unit) distance between each sample position yk and each pixel centre.
    """
    MaxPhase = a_rad2cmsqr * LPE**2  # [rad]
    NumPixels = NumPixelsFinal = NumPE

    def _borders(count):
        # count + 1 pixel-border positions, sign-flipped and shifted.
        return SPENAcquireSign * torch.linspace(-LPE/2, LPE/2, count + 1) + ShiftPE / 10

    Partitions = _borders(NumPixels)
    PartitionsFinal = _borders(NumPixelsFinal)

    # ky sample positions, one per phase-encode step.
    sample_idx = torch.arange(NumPE).float()
    ky = -2 * SPENAcquireSign * a_rad2cmsqr * sample_idx * LPE / NumPE

    # Linear phase term, converted to a plain Python scalar.
    first_width = Partitions[1] - Partitions[0]
    b = (-ky[0] + -2 * a_rad2cmsqr * (Partitions[0] + first_width * ky1RelativePos)).item()

    AFinal = calcSRMatrixApprox(MaxPhase, NumPixelsFinal, ky, PartitionsFinal, b)[0]

    # Variance of the Gaussian weighting, in squared pixel units.
    GaussWeightVar = (GaussRelativeWidth * np.pi * NumPixelsFinal**2 / MaxPhase)**2

    # Position y associated with each ky, and the centre of each final pixel.
    yk = -(b + ky) / (2 * a_rad2cmsqr)
    yPixels = (PartitionsFinal[:-1] + PartitionsFinal[1:]) / 2

    # Sample-to-pixel distances rescaled to pixel units: rows = ky samples, cols = pixels.
    DistMat = NumPixelsFinal / LPE * (yk[:, None] - yPixels[None, :])
    GaussWeight = torch.exp(-DistMat**2 / (2 * GaussWeightVar))

    weighted = AFinal * GaussWeight
    InvA = weighted.transpose(0, 1).conj()
    return InvA, AFinal


# Main execution: build the Gaussian-weighted inverse SR matrix.
# NOTE(review): units presumably follow calcInvA's parameter names
# (alfa in rad/cm^2, L in cm) -- confirm.
alfa = -47.1239
L = [4, 4]
Finalryxacq = torch.zeros((256, 256))  # in this cell only its first dimension is used
aSign = -1

InvA, AFinal = calcInvA(
    a_rad2cmsqr=alfa,
    LPE=L[1],
    NumPE=Finalryxacq.size(0) // 2,
    ShiftPE=0,
    SPENAcquireSign=-aSign,
    ky1RelativePos=0,
    GaussRelativeWidth=0.9,
)


In [29]:
data = InvA

# Quick sanity statistics of the inverse SR matrix: mean, |max|, |min|, std, one entry.
for stat in (data.mean(), data.abs().max(), data.abs().min(), data.std(), data[10, 11]):
    print(stat)

tensor(0.0001-2.3628e-05j)
tensor(0.0312)
tensor(6.9432e-05)
tensor(0.0190)
tensor(-0.0282+0.0135j)
