# Part 1: Nelder Mead

In [None]:
import numpy as np

np.random.seed(42) 
    
class Simplex:
    """ For generality """
    def __init__(self, vertices):        
        self.vertices = np.array(vertices)
        self.n_dim = self.vertices[0].shape[0]  # Number of dimensions
        
        assert self.vertices.shape[0] == self.n_dim + 1, "Simplex must have n+1 vertices in n dimensions"
        
    def __repr__(self):
        return f"Simplex(vertices={self.vertices})"
    
    def sort_simplex(self, func):
        vals = [func(vertex) for vertex in self.vertices]
        sorted_vertices = np.array([self.vertices[i] for i in np.argsort(vals)])
        self.vertices = sorted_vertices
            
    def centroid(self):
        centroid = np.mean(self.vertices[:-1], axis=0)  # Exclude the worst point      
        return centroid
    
    def converged(self, tol=1e-6):
        """ Check if the simplex has converged based on the tolerance. """
        return np.max(np.abs(self.vertices - np.mean(self.vertices, axis=0))) < tol

    def reflect(self, c, alpha):
        """Reflect worst point across the centriod - also used for extension and contraction"""
        x_w = self.vertices[-1]
        return c + alpha * (c - x_w)
    
    def contract(self, c, x_r, beta, mode="outside"):
        if mode == 'outside':
            return c + beta * (x_r - c)
        else:  # inside contraction
            return c - beta * (c - self.vertices[-1])
    
    def diameter(self):
        """Compute the maximum distance between any two vertices."""
        return np.max(np.linalg.norm(self.vertices - self.vertices[:, np.newaxis], axis=-1))
    
    def shrink(self, delta):
        self.vertices[1:] = self.vertices[0] + delta * (self.vertices[1:] - self.vertices[0])


        # best = self.vertices[0]
        # for i in range(1, len(self.vertices)):
        #     for dim in range(self.n_dim):
        #         self.vertices[i, dim] = best[dim] + delta * (self.vertices[i, dim] - best[dim])
 
    
alpha = 1.0  # Reflection
beta = 0.5 # Contraction
gamma = 2.0  # Expansion
delta = 0.5 # Shrinkage    
    
def nelder_mead(func, simplex: Simplex, max_iter=1000, tol=1e-6):
    """
    Perform the Nelder-Mead optimization algorithm in 2D.

    Parameters:
    - func: The objective function to minimize.
    - simplex: A list of points defining the initial simplex.
    - max_iter: Maximum number of iterations.
    - tol: Tolerance for convergence.

    Returns:
    - The point that minimizes the function.
    """

    for k in range(max_iter):
        simplex.sort_simplex(func)

        if simplex.diameter() < tol:
            print(f"Converged at iteration {k}.")
            break
        
        best, worst = simplex.vertices[0], simplex.vertices[-1]
        print(f"iter {k:3d}, best={best}, f={func(best):.6f}")

        centroid = simplex.centroid()
        
        # Reflection & Extension
        reflected = simplex.reflect(centroid, alpha)
        reflected_value = func(reflected)   # Dont recompute (can precompute for all vertices each iteration)
        
        if reflected_value < func(simplex.vertices[-2]):
            
            if reflected_value < func(simplex.vertices[0]): # Expand reflection if it was better than the second worst
                extended = simplex.reflect(centroid, gamma)  

                if func(extended) < reflected_value:
                    simplex.vertices[-1] = extended
                    
                else:
                    simplex.vertices[-1] = reflected
            else:
                simplex.vertices[-1] = reflected
            
            continue

        # Contraction
        if reflected_value >= func(simplex.vertices[-2]):
            # Outside contraction
            contracted = simplex.contract(centroid, reflected, beta, mode='outside')
            contracted_value = func(contracted)
            
            if contracted_value < reflected_value:
                simplex.vertices[-1] = contracted
            else:
                # Shrink
                simplex.shrink(delta)
            
        
        else:
            # Inside contraction
            contracted = simplex.contract(centroid, reflected, beta, mode='inside')
            contracted_value = func(contracted)
            
            if contracted_value < func(simplex.vertices[-1]):
                simplex.vertices[-1] = contracted
            else:
                # Shrink
                simplex.shrink(delta)
            
            
    return simplex.vertices[0]

### Testing GD functions

In [185]:
def f1(x):
    x, y, z = x
    return (x - z)**2 + (2*y + z)**2 + (4*x - 2*y + z)**2 + x + y

start_f1 = np.array([0,0,0])

In [171]:
def f2(x):
    x,y,z = x
    return (x - 1)**2 + (y - 1)**2 + 100*(y-x**2)**2 + 100*(z-y**2)**2

start_f2 = np.array([1.2, 1.2, 1.2])

In [172]:
def f3(x):
    x,y = x
    return (1.5 - x + x*y)**2 + (2.25 - x + x*y**2)**2 + (2.625 - x + x*y**3)**2

start_f3 = np.array([1,1])

In [173]:
# x, y, (z) + f(x,y,(z))
# f1 minimum: [-0.15151515 -0.21212121  0.15151515 -0.19651056]
# f2 minimum: [1. 1. 1. 0.]
# f3 minimum: [3.00000000e+00 5.01501502e-01 5.21242614e-05]

#### F1

In [186]:
simplex = Simplex([start_f1,
                   start_f1 + np.array([1, 0, 0]),
                   start_f1 + np.array([0, 1, 0]),
                   start_f1 + np.array([0, 0, 1])])

min_pt = nelder_mead(f1, simplex, max_iter=1000, tol=1e-7)
min_pt, f1(min_pt), simplex

iter   0, best=[0 0 0], f=0.000000
iter   1, best=[0 0 0], f=0.000000
iter   2, best=[0 0 0], f=0.000000


(array([0, 0, 0]),
 np.int64(0),
 Simplex(vertices=[[0 0 0]
  [0 0 0]
  [0 0 0]
  [0 0 0]]))

#### F2

In [None]:
simplex = Simplex([start_f2,
                   start_f2 + np.array([1, 0, 0]),
                   start_f2 + np.array([0, 1, 0]),
                   start_f2 + np.array([0, 0, 1])])

min_pt = nelder_mead(f2, simplex, max_iter=1000, tol=1e-7)
min_pt, f2(min_pt)

(array([0.99999998, 0.99999997, 0.99999992]), np.float64(6.49875251054266e-15))

#### F3

In [176]:
simplex = Simplex([start_f3,
                   start_f3 + np.array([1, 0]),
                   start_f3 + np.array([0, 1])])

min_pt = nelder_mead(f3, simplex, max_iter=1000, tol=1e-7)
min_pt, f3(min_pt)

(array([2, 0]), np.float64(0.703125))