<td>
<a href="https://colab.research.google.com/github/raoulg/aiforgis/blob/main/notebooks/pso.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
</td>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from mpl_toolkits.mplot3d import Axes3D
import random


class ParticleSwarmOptimizer:
    """Minimise an objective function with global-best Particle Swarm Optimization.

    Every particle carries a velocity that is pulled towards its own best
    position so far (cognitive term, weighted by ``c1``) and towards the best
    position found by the whole swarm (social term, weighted by ``c2``);
    ``w`` controls how much of the previous velocity is kept. Positions are
    clipped to ``bounds`` after every move.

    Args:
        objective_function: Callable mapping a 1-D position array of length
            ``n_dim`` to a scalar cost. Lower is better.
        bounds: Pair ``(low, high)`` of per-dimension lower/upper bounds
            (array-likes of equal length, or scalars for a 1-D problem).
        n_particles: Number of particles in the swarm (must be >= 1).
        w: Inertia weight.
        c1: Cognitive coefficient.
        c2: Social coefficient.

    Raises:
        ValueError: If the bounds are malformed or ``n_particles < 1``.
    """

    def __init__(self, objective_function, bounds, n_particles, w, c1, c2):
        self.objective_function = objective_function
        # Normalise to 1-D float arrays so lists, tuples and scalars all work.
        self.bounds_low = np.atleast_1d(np.asarray(bounds[0], dtype=float))
        self.bounds_high = np.atleast_1d(np.asarray(bounds[1], dtype=float))
        if self.bounds_low.ndim != 1 or self.bounds_low.shape != self.bounds_high.shape:
            raise ValueError(
                "bounds must be a pair (low, high) of equal-length 1-D sequences"
            )
        if np.any(self.bounds_low > self.bounds_high):
            raise ValueError("every lower bound must be <= its upper bound")
        if n_particles < 1:
            raise ValueError("n_particles must be at least 1")
        self.n_dim = self.bounds_low.size
        self.n_particles = n_particles
        self.w = w  # Inertia weight
        self.c1 = c1  # Cognitive coefficient
        self.c2 = c2  # Social coefficient

        # Initialize swarm
        self._initialize_swarm()

    def _evaluate(self, positions):
        """Return an array holding the objective value of each row of ``positions``."""
        return np.array([self.objective_function(p) for p in positions])

    def _initialize_swarm(self):
        """Initializes the particles, velocities, best positions and history."""
        self.particles = np.random.uniform(
            low=self.bounds_low,
            high=self.bounds_high,
            size=(self.n_particles, self.n_dim),
        )
        self.velocities = np.zeros((self.n_particles, self.n_dim))

        self.pbest_positions = self.particles.copy()
        self.pbest_costs = self._evaluate(self.pbest_positions)
        # Costs of the particles' current positions; reused for history so the
        # objective is evaluated only once per particle per iteration.
        self.current_costs = self.pbest_costs.copy()

        self.gbest_idx = np.argmin(self.pbest_costs)
        self.gbest_position = self.pbest_positions[self.gbest_idx].copy()
        self.gbest_cost = self.pbest_costs[self.gbest_idx]

        self.history = []  # For storing states for visualization
        # Total iterations run so far, so repeated optimize() calls keep
        # numbering history entries uniquely (Plotly frame names rely on it).
        self.iteration_count = 0

    def _update_velocities(self):
        """Apply the inertia + cognitive + social velocity update to every particle."""
        r1 = np.random.rand(self.n_particles, self.n_dim)
        r2 = np.random.rand(self.n_particles, self.n_dim)

        cognitive_component = self.c1 * r1 * (self.pbest_positions - self.particles)
        social_component = self.c2 * r2 * (self.gbest_position - self.particles)

        self.velocities = (
            self.w * self.velocities + cognitive_component + social_component
        )

    def _update_particle_positions(self):
        """Move the particles by their velocities and clip them to the bounds."""
        self.particles += self.velocities
        self.particles = np.clip(self.particles, self.bounds_low, self.bounds_high)

    def _evaluate_and_update_bests(self):
        """Evaluate current positions and update personal and global bests."""
        self.current_costs = self._evaluate(self.particles)

        # Update personal bests
        improved_mask = self.current_costs < self.pbest_costs
        self.pbest_positions[improved_mask] = self.particles[improved_mask]
        self.pbest_costs[improved_mask] = self.current_costs[improved_mask]

        # Update global best; pbest_costs already holds every particle's best.
        current_min_idx = np.argmin(self.pbest_costs)
        if self.pbest_costs[current_min_idx] < self.gbest_cost:
            self.gbest_idx = current_min_idx
            self.gbest_position = self.pbest_positions[self.gbest_idx].copy()
            self.gbest_cost = self.pbest_costs[self.gbest_idx]

    def optimize(self, max_iterations, verbose=True, store_history=False):
        """Run the PSO loop for ``max_iterations`` further iterations.

        Args:
            max_iterations: Number of iterations to run in this call.
            verbose: Print the initial best, progress every 10th iteration
                and the final result.
            store_history: Append a snapshot (particles, velocities, current
                costs, global best) to ``self.history`` after each iteration.

        Returns:
            Tuple ``(gbest_position, gbest_cost)``.
        """
        if verbose:
            print(
                f"Initial Global Best Cost: {self.gbest_cost:.6f} at {self.gbest_position}"
            )

        for _ in range(max_iterations):
            self._update_velocities()
            self._update_particle_positions()
            self._evaluate_and_update_bests()
            self.iteration_count += 1
            iteration = self.iteration_count

            if store_history:
                self.history.append(
                    {
                        "iteration": iteration,
                        "particles": self.particles.copy(),
                        "velocities": self.velocities.copy(),
                        "costs": self.current_costs.copy(),
                        "gbest_cost": self.gbest_cost,
                        "gbest_position": self.gbest_position.copy(),
                    }
                )

            if verbose and iteration % 10 == 0:  # Print every 10 iterations
                print(f"Iteration {iteration}: Best Cost = {self.gbest_cost:.6f}")

        if verbose:
            print("\nOptimization Finished.")
            print("Global Best Position:", self.gbest_position)
            print("Global Best Cost:", self.gbest_cost)

        return self.gbest_position, self.gbest_cost


class Visualizer:
    """Turn a ParticleSwarmOptimizer's stored history into an animated 3-D Plotly figure.

    The optimizer must have been run with ``store_history=True``. Each history
    entry becomes one animation frame showing the particles (coloured by
    cost), their scaled velocity vectors, the swarm's best position and,
    optionally, a marker at a known optimum. Only the first three
    dimensions of the search space are plotted.
    """

    def __init__(self, optimizer: ParticleSwarmOptimizer):
        # The history list is shared with the optimizer, not copied.
        self.history = optimizer.history
        self.bounds_low = optimizer.bounds_low
        self.bounds_high = optimizer.bounds_high

    @staticmethod
    def _frame_traces(entry, min_cost, max_cost, optimum, vel_scale):
        """Build the Plotly traces (particles, gbest, optimum, velocities) for one history entry."""
        particles = entry["particles"]
        velocities = entry["velocities"]
        gbest_pos = entry["gbest_position"]

        traces = [
            go.Scatter3d(
                x=particles[:, 0],
                y=particles[:, 1],
                z=particles[:, 2],
                mode="markers",
                marker=dict(
                    size=5,
                    color=entry["costs"],
                    colorscale="Viridis",
                    # Shared colour range across frames keeps colours comparable.
                    cmin=min_cost,
                    cmax=max_cost,
                    colorbar=dict(title="Cost"),
                    opacity=0.8,
                ),
                name="Particles",
            ),
            go.Scatter3d(
                x=[gbest_pos[0]],
                y=[gbest_pos[1]],
                z=[gbest_pos[2]],
                mode="markers",
                marker=dict(
                    size=10,
                    color="gold",
                    symbol="diamond",
                    line=dict(width=1, color="black"),
                ),
                name=f"GBest ({entry['gbest_cost']:.2f})",
            ),
        ]

        if optimum is not None:
            traces.append(
                go.Scatter3d(
                    x=[optimum[0]],
                    y=[optimum[1]],
                    z=[optimum[2]],
                    mode="markers",
                    marker=dict(
                        size=6, color="blue", symbol="x", line=dict(width=1, color="black")
                    ),
                    name="True Optimum",
                )
            )

        # Velocities drawn as short line segments (Plotly has no direct 3-D quiver).
        for p_idx, (start_point, velocity) in enumerate(zip(particles, velocities)):
            end_point = start_point + velocity * vel_scale
            traces.append(
                go.Scatter3d(
                    x=[start_point[0], end_point[0]],
                    y=[start_point[1], end_point[1]],
                    z=[start_point[2], end_point[2]],
                    mode="lines",
                    line=dict(color="red", width=2),
                    # One legend entry for the whole group of velocity lines.
                    showlegend=(p_idx == 0),
                    name="Velocities" if p_idx == 0 else None,
                )
            )
        return traces

    def create_plotly_animation(
        self,
        width=800,
        height=600,
        duration=300,
        optimum=(4, 5, -6),
        vel_scale=0.1,
    ):
        """Return an animated Plotly figure of the stored PSO history.

        Args:
            width: Figure width in pixels.
            height: Figure height in pixels.
            duration: Per-frame duration in milliseconds, used both by the
                Play button and by the slider steps.
            optimum: (x, y, z) of a known optimum to mark, or ``None`` to omit
                the marker. Defaults to the optimum used in this notebook.
            vel_scale: Factor applied to velocities when drawing them.

        Returns:
            A ``plotly.graph_objects.Figure``, or ``None`` (after printing a
            message) if there is no history.

        Raises:
            ValueError: If the stored particles have fewer than 3 dimensions.
        """
        if not self.history:
            print("No history to plot.")
            return None
        if self.history[0]["particles"].shape[1] < 3:
            raise ValueError("create_plotly_animation needs at least 3 dimensions to plot")

        # Determine global min/max costs for consistent color scaling
        all_costs = np.concatenate([entry["costs"] for entry in self.history])
        min_cost, max_cost = np.min(all_costs), np.max(all_costs)
        if min_cost == max_cost:  # Avoids a degenerate colour range
            max_cost += 1

        frames = []
        slider_steps = []
        for entry in self.history:
            frame_name = f"Iter {entry['iteration']}"
            frames.append(
                go.Frame(
                    data=self._frame_traces(entry, min_cost, max_cost, optimum, vel_scale),
                    name=frame_name,
                )
            )
            slider_steps.append(
                dict(
                    args=[
                        [frame_name],
                        dict(
                            frame=dict(duration=duration, redraw=True),
                            mode="immediate",
                            transition=dict(duration=100),
                        ),
                    ],
                    label=str(entry["iteration"]),
                    method="animate",
                )
            )

        # Initial figure state mirrors the first frame.
        fig = go.Figure(
            data=self._frame_traces(self.history[0], min_cost, max_cost, optimum, vel_scale),
            frames=frames,
        )

        fig.update_layout(
            title="Particle Swarm Optimization (Plotly)",
            width=width,
            height=height,
            scene=dict(
                xaxis=dict(title="X", range=[self.bounds_low[0], self.bounds_high[0]]),
                yaxis=dict(title="Y", range=[self.bounds_low[1], self.bounds_high[1]]),
                zaxis=dict(title="Z", range=[self.bounds_low[2], self.bounds_high[2]]),
                aspectmode="cube",  # Ensures aspect ratio is good for 3D
            ),
            updatemenus=[
                dict(
                    type="buttons",
                    buttons=[
                        dict(
                            label="Play",
                            method="animate",
                            args=[
                                None,
                                dict(
                                    frame=dict(duration=duration, redraw=True),
                                    fromcurrent=True,
                                    mode="immediate",
                                    transition=dict(duration=100),
                                ),
                            ],
                        ),
                        dict(
                            label="Pause",
                            method="animate",
                            # Animating to [None] with zero duration stops playback.
                            args=[
                                [None],
                                dict(
                                    frame=dict(duration=0, redraw=False),
                                    mode="immediate",
                                    transition=dict(duration=0),
                                ),
                            ],
                        ),
                    ],
                )
            ],
            sliders=[
                dict(
                    steps=slider_steps,
                    active=0,
                    transition=dict(duration=0),
                    x=0.1,  # Slider position
                    len=0.9,  # Slider length
                    currentvalue=dict(
                        font=dict(size=14),
                        prefix="Iteration: ",
                        visible=True,
                        xanchor="right",
                    ),
                )
            ],
            margin=dict(l=50, r=50, b=50, t=100, pad=4),
            legend=dict(
                orientation="h",  # Horizontal orientation
                yanchor="bottom",
                y=1.02,  # Position legend slightly above the plot
                xanchor="center",
                x=0.5,  # Center the legend horizontally
            ),
        )
        return fig


def viz_rastrigin(fun, optimum=(4, 5, -6), plot_range_delta=4, resolution=200):
    """Plot a filled-contour 2-D slice of a 3-D objective through its optimum.

    The slice fixes z at the optimum's z-coordinate. ``fun`` is evaluated on a
    ``resolution`` x ``resolution`` grid spanning ``optimum +/- plot_range_delta``
    in x and y, and the optimum is marked with a red dot.

    Args:
        fun: Callable taking a length-3 list ``[x, y, z]`` and returning a scalar.
        optimum: (x, y, z) of the known global minimum. Defaults to the
            shifted Rastrigin optimum used in this notebook.
        plot_range_delta: Half-width of the plotted window in x and y. The
            shifted Rastrigin oscillates with period 1, so +/-4 shows several
            local minima.
        resolution: Number of grid points per axis.
    """
    x_opt, y_opt, z_opt = optimum

    x_vals_range = np.linspace(x_opt - plot_range_delta, x_opt + plot_range_delta, resolution)
    y_vals_range = np.linspace(y_opt - plot_range_delta, y_opt + plot_range_delta, resolution)

    X1, Y1 = np.meshgrid(x_vals_range, y_vals_range)
    Z_fixed_val = z_opt
    F1_values = np.zeros(X1.shape)
    # Evaluate point by point so ``fun`` need not accept array inputs.
    for idx in np.ndindex(X1.shape):
        F1_values[idx] = fun([X1[idx], Y1[idx], Z_fixed_val])

    plt.figure(figsize=(10, 8))
    contour1 = plt.contourf(
        X1, Y1, F1_values, levels=50, cmap="viridis"
    )  # Filled contour
    plt.contour(
        X1, Y1, F1_values, levels=contour1.levels, colors="k", linewidths=0.5
    )  # Contour lines
    plt.colorbar(contour1, label="f(x, y, z_opt)")
    plt.plot(
        x_opt,
        y_opt,
        "ro",
        markersize=8,
        label=f"Global Minimum at ({x_opt}, {y_opt}, {z_opt})",
    )
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title(f"Slice of Rastrigin Function at z = {Z_fixed_val:.2f}")
    plt.legend()
    plt.grid(True, linestyle=":", alpha=0.6)
    plt.show()

In [None]:
from plotly.offline import init_notebook_mode
import numpy as np

In [None]:
# Enable inline Plotly rendering in the notebook; one call per session suffices.
init_notebook_mode(connected=True)

# Search box: row 0 holds the lower bounds, row 1 the upper bounds for (x, y, z).
bounds = np.array([[-10] * 3, [10] * 3])

In [None]:
def objective(params):
    """Return the squared Euclidean distance from ``params`` to (4, 5, -6).

    A shifted sphere function: convex, with a single global minimum of 0 at
    (4, 5, -6). Only the first three entries of ``params`` are used.
    """
    target = (4, 5, -6)
    return sum((params[axis] - centre) ** 2 for axis, centre in enumerate(target))

In [None]:
# Convex warm-up problem: a small swarm with modest attraction coefficients.
pso_plotly = ParticleSwarmOptimizer(
    objective_function=objective,
    bounds=bounds,
    n_particles=20,
    w=0.5,
    c1=0.8,
    c2=0.9,
)
gbest_pos, gbest_cost = pso_plotly.optimize(100, verbose=True, store_history=True)

In [None]:
# Build the animation; as the cell's last expression, the figure is displayed.
vis = Visualizer(optimizer=pso_plotly)
vis.create_plotly_animation(duration=500, width=800, height=600)

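From [Wikipedia](https://en.wikipedia.org/wiki/Rastrigin_function):
> In mathematical optimization, the Rastrigin function is a non-convex function used as a performance test problem for optimization algorithms. It is a typical example of non-linear multimodal function.

The version below is shifted so that its global minimum (value 0) lies at $(4, 5, -6)$. This is the same point as the quadratic objective above, but here it is surrounded by a regular grid of local minima that can trap the swarm.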

In [None]:
def rastrigin_function(params):
    """Evaluate the 3-D Rastrigin function shifted to have its minimum at (4, 5, -6).

    f = A*n + sum_i (d_i**2 - A*cos(2*pi*d_i)) with A = 10, n = 3 and
    d = (x - 4, y - 5, z + 6). The global minimum value is 0 at (4, 5, -6).
    Works element-wise if the coordinates are NumPy arrays.
    """
    amplitude = 10
    centre = (4, 5, -6)
    coords = (params[0], params[1], params[2])

    total = amplitude * len(centre)  # A * n with n = 3 dimensions
    for coord, offset in zip(coords, centre):
        delta = coord - offset
        total += delta**2 - amplitude * np.cos(2 * np.pi * delta)
    return total

In [None]:
# Show a z = -6 slice through the shifted Rastrigin landscape.
viz_rastrigin(fun=rastrigin_function)

In [None]:
# Multimodal problem: a larger swarm and stronger attraction coefficients.
pso_plotly = ParticleSwarmOptimizer(
    objective_function=rastrigin_function,
    bounds=bounds,
    n_particles=100,
    w=0.5,
    c1=1.5,
    c2=1.5,
)
gbest_pos, gbest_cost = pso_plotly.optimize(50, verbose=True, store_history=True)

In [None]:
# Build the animation; as the cell's last expression, the figure is displayed.
vis = Visualizer(optimizer=pso_plotly)
vis.create_plotly_animation(duration=500, width=800, height=600)