In [None]:
import numpy as np
import named_arrays as na
import astropy.units as u
import astropy.visualization

# Enable astropy's matplotlib integration so Quantity arrays (arcsec, mm, ...)
# can be plotted directly; the trailing ';' hides the call's return value in
# the notebook output.
astropy.visualization.quantity_support();

In [None]:
# Sampling grid for the fit inputs: 2 wavelengths from 100 to 500 Å and an
# 11 x 11 grid of field positions from -50 to +50 arcsec, each on its own named
# axis ('wavelength', 'x', 'y').
# Every keyword of the linear space (start / stop / num / axis) is itself a
# spectral-positional vector, assembled here from one (wavelength, x, y) row.
inputs = na.SpectralPositionalVectorLinearSpace(**{
    keyword: na.SpectralPositionalVectorArray(
        wavelength=wavelength,
        position=na.Cartesian2dVectorArray(x=x, y=y),
    )
    for keyword, (wavelength, x, y) in {
        #          wavelength      x                 y
        'start': (100 * u.AA,     -50 * u.arcsec,   -50 * u.arcsec),
        'stop':  (500 * u.AA,      50 * u.arcsec,    50 * u.arcsec),
        'num':   (2,               11,               11),
        'axis':  ('wavelength',    'x',              'y'),
    }.items()
})

inputs

In [None]:
# Three sample times; the 'time' axis makes coefficient `c` time dependent.
t = na.ScalarLinearSpace(-10 * u.s, 10 * u.s, num=3, axis='time')

# Ground-truth polynomial coefficients of the synthetic distortion model:
#   a : linear x-term of the x output           (constant)
#   b : quadratic y-term of the x output        (constant)
#   c : linear y-term of the y output           (varies along 'time')
#   d : quadratic x-term of the y output        (varies along 'wavelength')
a = 1 * (u.mm / u.arcsec)
b = .2 * u.mm / (u.arcsec ** 2)
c = t * (u.mm / (u.arcsec * u.s))
d = .001 * inputs.wavelength * u.mm / (u.AA * u.arcsec ** 2)

# Synthetic "true" outputs, each a degree-2 polynomial in position:
#   x_out = 1 mm + a x + b y^2
#   y_out = 5 mm + c y + d x^2
outputs = na.Cartesian2dVectorArray(
    x=(1 * u.mm) + (a * inputs.position.x) + (b * inputs.position.y ** 2),
    y=(5 * u.mm) + (c * inputs.position.y) + (d * inputs.position.x ** 2),
)

In [None]:
# Fit a degree-2 polynomial in (position.x, position.y) over the 'x' and 'y'
# axes. NOTE(review): the coefficients are compared below with `c` (varies
# with time) and `d` (varies with wavelength), so they are presumably kept
# per time/wavelength sample — confirm against named_arrays docs.
fit = na.PolynomialFitFunctionArray(
    inputs=inputs,
    outputs=outputs,
    degree=2,
    components_polynomial=('position.x', 'position.y'),
    axis_polynomial=('x', 'y'),
)
coefficients = fit.coefficients

# Recovered-minus-true coefficient for each monomial used to build `outputs`;
# every residual should be ~0. Monomial keys join component names with '*'.
print(
    *(
        getattr(coefficients.components[term], component) - truth
        for term, component, truth in (
            ('position.x', 'x', a),
            ('position.y*position.y', 'x', b),
            ('position.y', 'y', c),
            ('position.x*position.x', 'y', d),
        )
    ),
    sep='\n',
)

In [None]:
# Evaluate the quadratic fit at the sample points and report its
# root-mean-square residual against the true outputs. Use the mean (not the
# sum) of the squared residuals: a sum gives a root-sum-square that grows with
# the number of samples, which is not an RMS.
# NOTE(review): `outputs` is a vector array, so the reduction is presumably
# done per component (x, y) — confirm.
best_fit_quad = fit(fit.inputs)
rms_error = np.sqrt(np.square(best_fit_quad.outputs - outputs).mean())
rms_error

In [None]:
# Fit the same data with a degree-1 (linear) polynomial for comparison. The
# true outputs contain y^2 and x^2 terms, which a linear polynomial cannot
# represent, so its residual is expected to be larger.
fit_linear = na.PolynomialFitFunctionArray(
    inputs=inputs,
    outputs=outputs,
    degree=1,
    components_polynomial=('position.x', 'position.y'),
    axis_polynomial=('x', 'y'),
)
# Evaluate the linear fit at its own inputs (built from the same `inputs`).
best_fit_linear = fit_linear(fit_linear.inputs)
# True RMS: mean (not sum) of the squared residuals, then square root.
rms_error = np.sqrt(np.square(best_fit_linear.outputs - outputs).mean())
rms_error

In [None]:
# Unitless (`.value`) y components of the true, quadratic-fit and linear-fit
# outputs; these are the colour values used by this and the next two plots.
original_output_y = fit.outputs.y.value
quadratic_fit_output_y = best_fit_quad.outputs.y.value
linear_fit_output_y = best_fit_linear.outputs.y.value

# One panel per (time, wavelength) sample: rows follow 'time', columns follow
# 'wavelength'; all panels share the position axes.
fig, ax = na.plt.subplots(
    axis_cols='wavelength',
    ncols=fit.shape['wavelength'],
    axis_rows='time',
    nrows=fit.shape['time'],
    sharex=True,
    sharey=True
)
na.plt.pcolormesh(
    fit.broadcasted.inputs.position,
    C=original_output_y,
    ax=ax,
)
fig.suptitle('Original Function Array');

In [None]:
# Grid of panels: one row per 'time' sample, one column per 'wavelength'
# sample, with shared position axes so panels are directly comparable.
fig, ax = na.plt.subplots(
    axis_rows='time',
    nrows=fit.shape['time'],
    axis_cols='wavelength',
    ncols=fit.shape['wavelength'],
    sharex=True,
    sharey=True,
)

# Colour each position by the y output of the degree-2 fit.
na.plt.pcolormesh(
    fit.broadcasted.inputs.position,
    ax=ax,
    C=quadratic_fit_output_y,
)
fig.suptitle('Quadratic Fit');

In [None]:
# Same panel layout as the other plots: rows follow 'time', columns follow
# 'wavelength', and position axes are shared across panels.
fig, ax = na.plt.subplots(
    axis_rows='time',
    nrows=fit.shape['time'],
    axis_cols='wavelength',
    ncols=fit.shape['wavelength'],
    sharex=True,
    sharey=True,
)

# Colour each position by the y output of the degree-1 fit.
na.plt.pcolormesh(
    fit.broadcasted.inputs.position,
    ax=ax,
    C=linear_fit_output_y,
)
fig.suptitle('Linear Fit');