In [1]:
import random
from math import exp
import numpy as np
from scipy import linalg

In [4]:
# ground truth coefficients
ar = 1.0
br = 2.0
cr = 1.0

# initial estimates
ae = 2.0
be = -1.0
ce = 5.0

# number of data points
N = 100

w_sigma = 1.0 # noise sigma
inv_sigma = 1.0 # w_sigma
sig_sq = inv_sigma ** 2

# create data
x_data, y_data = [], []
for i in range(N):
    x = i / 100.0
    x_data.append(x)
    y_data.append(exp(ar * x ** 2 + br * x + cr) + random.gauss(0, w_sigma ** 2))

In [9]:
# Start Gauss-Newton iterations
iterations = 100
cost = 0
lastCost = 0

for iters in range(iterations):
    H = np.zeros((3, 3))  # Hessian = J^T W^-1 J in Gauss-Newton
    b = np.zeros((3, 1))
    cost = 0

    for i in range(N):
        xi = x_data[i]
        yi = y_data[i]

        fun = exp((ae * xi ** 2) + (be * xi) + ce) # function, repeated b/c of chain rule, DRY
        error = yi - fun
        deda = -xi * xi * fun
        dedb = -xi * fun
        dedc = -fun
        J = np.array([deda, dedb, dedc], dtype = np.float32).reshape((3,1))

        H += sig_sq * J * J.T
        b += -sig_sq * error * J

        cost += error ** 2

    # Solve

    dx = linalg.solve(H, b, assume_a='sym')

    if dx[0] is np.NAN:
        print("Result is NaN!")
        break

    if iters > 0 and cost >= lastCost:
        print(f"Cost: {cost} last cost: {lastCost}, iter: {iters}, break.")
        break


    ae += dx[0]
    be += dx[1]
    ce += dx[2]

    lastCost = cost

    print(f"total cost: {cost}\nupdate: {dx.transpose()}\n"
            f"estimated params: {ae}, {be}, {ce}")
            
print(f"estimated abc = {ae}, {be}, {ce}")

total cost: 35629.477431077146
update: [[-0.0591923   0.59830553 -0.89830097]]
estimated params: [2.04991737], [-0.10373918], [2.15836817]
total cost: 2150.4571530314383
update: [[-0.48549784  1.13308559 -0.73492349]]
estimated params: [1.56441953], [1.02934641], [1.42344468]
total cost: 146.03432357271265
update: [[-0.48447287  0.82982363 -0.35780673]]
estimated params: [1.07994665], [1.85917004], [1.06563794]
total cost: 79.88425064973343
update: [[-0.06990631  0.11323037 -0.04475417]]
estimated params: [1.01004034], [1.97240041], [1.02088377]
total cost: 79.3299276553055
update: [[ 0.00058146 -0.0007521   0.00019878]]
estimated params: [1.0106218], [1.97164832], [1.02108255]
total cost: 79.32990202569955
update: [[-7.47818214e-06  1.15071763e-05 -4.20535872e-06]]
estimated params: [1.01061432], [1.97165983], [1.02107835]
total cost: 79.32990202117965
update: [[-2.30998616e-07  3.16659562e-07 -9.69559630e-08]]
estimated params: [1.01061409], [1.97166014], [1.02107825]
Cost: 79.329902