# Part 1: Nelder Mead

In [None]:
import numpy as np

np.random.seed(42) 
    
class Simplex:
    """ For generality """
    def __init__(self, vertices):        
        self.vertices = np.array(vertices)
        self.n_dim = self.vertices[0].shape[0]  # Number of dimensions
        
        assert self.vertices.shape[0] == self.n_dim + 1, "Simplex must have n+1 vertices in n dimensions"
        
    def __repr__(self):
        return f"Simplex(vertices={self.vertices})"
    
    def sort_simplex(self, func):
        vals = [func(vertex) for vertex in self.vertices]
        sorted_vertices = np.array([self.vertices[i] for i in np.argsort(vals)])
        self.vertices = sorted_vertices
            
    def centroid(self):
        centroid = np.mean(self.vertices[:-1], axis=0)  # Exclude the worst point      
        return centroid
    
    def converged(self, tol=1e-6):
        """ Check if the simplex has converged based on the tolerance. """
        return np.max(np.abs(self.vertices - np.mean(self.vertices, axis=0))) < tol

    def reflect(self, c, alpha):
        """Reflect worst point across the centriod - also used for extension and contraction"""
        x_w = self.vertices[-1]
        return c + alpha * (c - x_w)
    
    def outside_contraction(self, c, x_r, beta):
        return c + beta * (x_r - c)
    
    def inside_contraction(self, c, beta):
        x_w = self.vertices[-1]
        return c + beta * (x_w - c)
    
    def shrink(self, delta):
        best = self.vertices[0]
        self.vertices[1:] = best + delta*(self.vertices[1:] - best)

        # best = self.vertices[0]
        # for i in range(1, len(self.vertices)):
        #     for dim in range(self.n_dim):
        #         self.vertices[i, dim] = best[dim] + delta * (self.vertices[i, dim] - best[dim])
 
    
alpha = 1.0  # Reflection
beta = 0.5 # Contraction
gamma = 2.0  # Expansion
delta = 0.5 # Shrinkage    
    
def nelder_mead(func, simplex: Simplex, max_iter=1000, tol=1e-6):
    """
    Perform the Nelder-Mead optimization algorithm in 2D.

    Parameters:
    - func: The objective function to minimize.
    - simplex: A list of points defining the initial simplex.
    - max_iter: Maximum number of iterations.
    - tol: Tolerance for convergence.

    Returns:
    - The point that minimizes the function.
    """

    for _ in range(max_iter):
        simplex.sort_simplex(func)

        if simplex.converged(tol):
            break

        centroid = simplex.centroid()
        
        # Reflection & Extension
        reflected = simplex.reflect(centroid, alpha)
        reflected_value = func(reflected)   # Dont recompute (can precompute for all vertices each iteration)
        
        if reflected_value < func(simplex.vertices[-2]):
            
            if reflected_value < func(simplex.vertices[-3]): # Expand reflection if it was better than the second worst
                extended = simplex.reflect(centroid, gamma)  

                if func(extended) < reflected_value:
                    simplex.vertices[-1] = extended
                else:
                    simplex.vertices[-1] = reflected
            else:
                simplex.vertices[-1] = reflected
                
            continue

        # Contraction
        if reflected_value >= func(simplex.vertices[-2]):
            # Outside contraction
            contracted = simplex.outside_contraction(centroid, reflected, beta)
            contracted_value = func(contracted)
            
            if contracted_value < reflected_value:
                simplex.vertices[-1] = contracted
            else:
                # Shrink
                simplex.shrink(delta)
        
        else:
            # Inside contraction
            contracted = simplex.inside_contraction(centroid, beta)
            contracted_value = func(contracted)
            
            if contracted_value < func(simplex.vertices[-1]):
                simplex.vertices[-1] = contracted
            else:
                # Shrink
                simplex.shrink(delta)
            
    return simplex.vertices[0]

In [114]:
def f1(x):
    x, y, z = x
    return (x - z)**2 + (2*y + z)**2 + (4*x - 2*y + z)**2 + x + y

start = np.array([1.2, 1.2, 1.2])

In [117]:
simplex = Simplex([start, start + np.array([1, 0, 0]), start + np.array([0, 1, 0]), start + np.array([0, 0, 1])])

min_pt = nelder_mead(f1, simplex, max_iter=1000, tol=1e-7)
min_pt, f1(min_pt)

(array([-0.16666671, -0.22916671,  0.1666667 ]),
 np.float64(-0.19791666666665605))

In [118]:
simplex

Simplex(vertices=[[-0.16666671 -0.22916671  0.1666667 ]
 [-0.16666661 -0.22916666  0.16666656]
 [-0.1666667  -0.22916678  0.16666667]
 [-0.16666657 -0.22916664  0.16666663]])