In [116]:
import numpy as np
import time
from scipy import linalg

In [117]:
ar = 1.0; br = 2.0; cr = 1.0  # Ground truth
N = 100  # Number of data points
w_sigma = 1.0  # Sigma of noise
inv_sigma = 1.0/w_sigma
rng = np.random.default_rng()

x_data = []; y_data = []
for i in range(N):
    x = i/100.0
    x_data.append(x)
    y_data.append(np.exp(ar*x*x + br*x + cr) + rng.normal(0, w_sigma*w_sigma))

In [118]:
def runit():
    ae = 2.0; be = -1.0; ce = 4.0  # initial guess
    # Start Gauss-Newton iterations
    iterations = 1000
    cost = 0; lastCost = 0

    for iters in range(iterations):
        H = np.zeros((3, 3))  # Hessian = J^T W^-1 J in Gauss-Newton
        b = np.zeros((3, 1))
        cost = 0

        for i in range(N):
            xi = x_data[i]; yi = y_data[i]
            error = yi - np.exp(ae*xi*xi + be*xi + ce)
            J = np.zeros((3, 1))
            J[0] = -xi*xi*np.exp(ae*xi*xi + be*xi + ce)  # de/da
            J[1] = -xi*np.exp(ae*xi*xi + be*xi + ce)  # de/db
            J[2] = -np.exp(ae*xi*xi + be*xi + ce)  # de/dc

            H += inv_sigma*inv_sigma*J*J.transpose()
            b += -inv_sigma*inv_sigma*error*J

            cost += error*error
        dx = linalg.solve(H, b, assume_a='sym')
        if dx[0] is np.NAN:
            print("Result is nan!")
            break

        if iters > 0 and abs(cost) > abs(lastCost):
            print(f"Cost: {cost} last cost: {lastCost}, iter: {iters}, break.")
            break

        ae += dx[0]
        be += dx[1]
        ce += dx[2]

        lastCost = cost
        print(f"total cost: {cost}\nupdate: {dx.transpose()}\n"
            f"estimated params: {ae}, {be}, {ce}")
    print(f"estimated abc = {ae}, {be}, {ce}")

startTime = time.time()
runit()
endTime = time.time()
print(f"Solve time: {endTime - startTime}")

total cost: 299557.9577453571
update: [[ 0.1256465   0.19445549 -0.9518299 ]]
estimated params: [2.1256465], [-0.80554451], [3.0481701]
total cost: [27847.55143801]
update: [[-0.00637237  0.58755987 -0.88637528]]
estimated params: [2.11927413], [-0.21798464], [2.16179482]
total cost: [1723.6047699]
update: [[-0.5108      1.1694136  -0.72197793]]
estimated params: [1.60847413], [0.95142896], [1.4398169]
total cost: [160.11727935]
update: [[-0.51710977  0.85897309 -0.35457349]]
estimated params: [1.09136436], [1.81040205], [1.08524341]
total cost: [99.55828856]
update: [[-0.09426262  0.14817762 -0.05619095]]
estimated params: [0.99710174], [1.95857967], [1.02905246]
total cost: [98.67859854]
update: [[-0.0040229   0.00625691 -0.00232876]]
estimated params: [0.99307883], [1.96483658], [1.02672369]
total cost: [98.67719108]
update: [[-1.08018104e-04  1.66990199e-04 -6.14056532e-05]]
estimated params: [0.99297081], [1.96500357], [1.02666229]
total cost: [98.6771901]
update: [[-2.77895672e-0