In [59]:
import time
import numpy as np
import autograd.numpy as anp
from autograd import grad
from scipy.optimize import brentq, NonlinearConstraint, Bounds

from algorithms import GD, GDA, Projector


def true_euclidean_projector(x: np.ndarray) -> np.ndarray:
    """
    Computes the Euclidean projection of vector x onto the set C:
    C = {y in R^n_{++} | product(y_i) >= 1}

    If x (after clamping to a tiny positive floor) already satisfies the
    product constraint it is returned as is. Otherwise the constraint is
    active and the KKT stationarity condition y_i - x_i = lambda / y_i gives
        y_i(lambda) = (x_i + sqrt(x_i^2 + 4*lambda)) / 2,
    with lambda > 0 chosen by Brent's method so that sum(log(y_i)) = 0.

    The root bracket is built analytically: y_i(lambda) >= 1 holds as soon as
    lambda >= 1 - x_i, so lambda = 1 - min(x) makes every component >= 1 and
    the log-sum non-negative. This removes the need for a capped doubling
    search, which could give up and return an infeasible point when x has
    strongly negative components.

    Args:
        x: A 1D numpy array representing the vector to project. Entries may be
           negative or zero.

    Returns:
        A 1D numpy array representing the projected vector P_C(x).
        If Brent's method still raises ValueError (not expected with the
        analytic bracket), a message is printed and the clamped vector
        max(x, 1e-12) is returned as a best-effort fallback; that fallback is
        not guaranteed to lie in C.
    """
    epsilon = 1e-12  # positivity floor so that log() is always defined
    x = np.asarray(x)

    # --- Step 1: Simple Bounds Check and Early Exit ---
    # Clamp onto the (numerically) strictly positive orthant.
    y_bounds = np.maximum(x, epsilon)

    # log(product(y)) >= 0  <=>  sum(log(y)) >= 0
    if np.sum(np.log(y_bounds)) >= 0:
        return y_bounds

    # --- Step 2: Active Constraint, P_C(x) lies on product(y_i) = 1 ---
    is_negative = x < 0

    def get_y(lam):
        """Calculates y_i(lambda) from the KKT stationarity condition."""
        root = np.sqrt(x**2 + 4 * lam)
        # For x_i < 0 the textbook form (x_i + root) / 2 subtracts two nearly
        # equal numbers. The algebraically identical 2*lam / (root - x_i)
        # avoids that cancellation; its denominator is root + |x_i| > 0.
        # Non-negative entries get a dummy denominator of 1 so no 0/0 occurs.
        safe_denominator = np.where(is_negative, root - x, 1.0)
        val = np.where(is_negative, 2 * lam / safe_denominator, (x + root) / 2)
        # Use epsilon to ensure strict positivity (lam = 0 yields 0 for x_i <= 0).
        return np.maximum(val, epsilon)

    def objective(lam):
        """Root function in lambda: sum(log(y_i(lambda))), zero when product(y) = 1."""
        return np.sum(np.log(get_y(lam)))

    # --- Step 3: Bracket for brentq ---
    # objective(0) < 0 here, because get_y(0) equals y_bounds, which failed Step 1.
    lam_min = 0.0
    # Analytic upper end (see docstring). Never below 1.0, which is already
    # sufficient whenever min(x) >= 0.
    lam_max = max(1.0, 1.0 - float(np.min(x)))

    # Rounding can leave objective(lam_max) a hair below zero (a component that
    # should be exactly 1 evaluates to 1 - ulp); a few doublings absorb that.
    for _ in range(64):
        if not objective(lam_max) < 0:
            break
        lam_max *= 2

    # --- Step 4: Find the Optimal Lambda ---
    try:
        # objective is monotonically increasing in lambda, so the bracketed root is unique.
        lam_opt = brentq(objective, lam_min, lam_max)
        return get_y(lam_opt)
    except ValueError as e:
        # Bracket without a sign change (e.g. non-finite input). Best-effort fallback.
        print(f"Brent's method failed: {e}. Falling back to bounds projection.")
        return y_bounds


def run_experiment(n_values, gda_multiplier=5.0, with_fixed_step_size=None):
    """
    Runs the GDA and GD solvers on the same random test problem for each dimension.

    For every n the objective is
        f(x) = a.x + alpha * x.x + beta / sqrt(1 + beta * x.x) * e.x
    with e = (0, 1, ..., n-1), a drawn uniformly from [1e-12, 1), and both
    solvers projecting with true_euclidean_projector. Both start from the same
    random x0 and are capped at 1000 iterations.

    Args:
        n_values: Iterable of problem dimensions.
        gda_multiplier: GDA's initial step lambda_0 is gda_multiplier * step_size.
        with_fixed_step_size: Step size used by GD (and as the base for GDA).
            If None, 1 / L is used, where L is computed from n and beta below.

    Returns:
        A list of tuples (n, res_gda, time_gda, res_gd, time_gd); the res_*
        objects are whatever the solvers' solve() methods return and the
        times are elapsed wall-clock seconds for each solve() call.
    """
    # Reseeds NumPy's global RNG so every call draws the same problems.
    np.random.seed(42)
    results = []
    for n in n_values:
        e = np.arange(0, n, dtype=np.float64)
        a = np.random.uniform(1e-12, 1.0, size=n)
        beta = 0.741271
        alpha = 3 * (beta**1.5) * np.sqrt(n + 1)
        # NOTE(review): L looks like a gradient-Lipschitz-type bound; here it
        # only determines the default step size 1 / L.
        L = 4 * (beta**1.5) * np.sqrt(n) + 3 * alpha
        step_size = 1.0 / L if with_fixed_step_size is None else with_fixed_step_size

        def objective(x):
            # Written with autograd.numpy (anp) rather than plain numpy.
            xt_x = anp.dot(x, x)
            at_x = anp.dot(a, x)
            et_x = anp.dot(e, x)
            term3 = (beta / anp.sqrt(1 + beta * xt_x)) * et_x
            return at_x + alpha * xt_x + term3

        x0 = np.random.rand(n)

        gda_solver = GDA(objective, projector=true_euclidean_projector)
        # perf_counter() is monotonic and high-resolution; time.time() can be
        # too coarse for millisecond-scale solves and report 0.000000 s.
        start_time = time.perf_counter()
        res_gda = gda_solver.solve(
            x0=x0,
            lambda_0=gda_multiplier * step_size,
            sigma=0.1,
            kappa=0.5,
            max_iter=1000,
            stop_if_stationary=True,
        )
        time_gda = time.perf_counter() - start_time

        gd_solver = GD(objective, projector=true_euclidean_projector)
        start_time = time.perf_counter()
        res_gd = gd_solver.solve(
            x0=x0,
            step_size=step_size,
            max_iter=1000,
            stop_if_stationary=True,
        )
        time_gd = time.perf_counter() - start_time
        results.append((n, res_gda, time_gda, res_gd, time_gd))

    return results

In [60]:
# Problem dimensions n passed to run_experiment by the experiment cells.
n_values = [
    10, 20, 50, 100, 200,
    500, 1000, 2000, 3000, 10000,
]

In [61]:

def print_results(results, title="Results"):
    """Print a fixed-width ASCII table comparing the GDA and GD runs.

    Args:
        results: Iterable of (n, res_gda, time_gda, res_gd, time_gd) tuples.
            Each res_* object must expose ``f_opt`` and ``x_history``; the
            "iters" column shows ``len(x_history)``.
        title: Caption centred in the 100-character banner above the table.
    """
    def rule(widths):
        # Horizontal separator: '+' joints with runs of '-' of the given widths.
        return "+" + "+".join("-" * width for width in widths) + "+"

    banner_rule = rule((100,))
    group_rule = rule((8, 45, 45))
    column_rule = rule((8, 18, 13, 12, 18, 13, 12))

    # Banner with the centred title.
    print(banner_rule)
    print("|" + format(title, "^100") + "|")
    print(banner_rule)

    # First header row: one wide cell per algorithm.
    group_cells = [
        format(label, "^43")
        for label in ("Algorithm GDA (proposed)", "Algorithm GD")
    ]
    print("|" + format("n", "^7") + " | " + " | ".join(group_cells) + " |")
    print(group_rule)

    # Second header row: the same three metrics repeated for each algorithm.
    metric_cells = [
        format("f_opt", "^16"),
        format("time (s)", "^11"),
        format("iters", "^10"),
    ]
    print("|" + format("", ">7") + " | " + " | ".join(metric_cells * 2) + " |")
    print(column_rule)

    # One data row per problem size.
    for n, res_gda, time_gda, res_gd, time_gd in results:
        cells = [format(n, "6d")]
        for res, elapsed in ((res_gda, time_gda), (res_gd, time_gd)):
            cells.append(format(res.f_opt, "16.6f"))
            cells.append(format(elapsed, "11.6f"))
            cells.append(format(len(res.x_history), "10d"))
        print("| " + " | ".join(cells) + " |")
    print(column_rule)

In [62]:
# No fixed step: run_experiment falls back to step_size = 1 / L,
# and GDA starts from lambda_0 = 5.0 * step_size.
results_0 = run_experiment(
    n_values,
    gda_multiplier=5.0,
    with_fixed_step_size=None,
)
print_results(
    results_0,
    title="Results with Adaptive Step Size (gda_multiplier=5.0)",
)

+----------------------------------------------------------------------------------------------------+
|                        Results with Adaptive Step Size (gda_multiplier=5.0)                        |
+----------------------------------------------------------------------------------------------------+
|   n    |          Algorithm GDA (proposed)           |                Algorithm GD                 |
+--------+---------------------------------------------+---------------------------------------------+
|        |      f_opt       |  time (s)   |   iters    |      f_opt       |  time (s)   |   iters    |
+--------+------------------+-------------+------------+------------------+-------------+------------+
|     10 |        80.080567 |    0.007197 |         18 |        80.080567 |    0.004858 |         19 |
|     20 |       219.590296 |    0.005251 |         17 |       219.590296 |    0.003741 |         19 |
|     50 |       852.496543 |    0.002952 |         17 |       852.496543

In [63]:
# No fixed step: run_experiment falls back to step_size = 1 / L,
# and GDA starts from lambda_0 = 2.0 * step_size.
results_1 = run_experiment(
    n_values,
    gda_multiplier=2.0,
    with_fixed_step_size=None,
)
print_results(
    results_1,
    title="Results with Adaptive Step Size (gda_multiplier=2.0)",
)

+----------------------------------------------------------------------------------------------------+
|                        Results with Adaptive Step Size (gda_multiplier=2.0)                        |
+----------------------------------------------------------------------------------------------------+
|   n    |          Algorithm GDA (proposed)           |                Algorithm GD                 |
+--------+---------------------------------------------+---------------------------------------------+
|        |      f_opt       |  time (s)   |   iters    |      f_opt       |  time (s)   |   iters    |
+--------+------------------+-------------+------------+------------------+-------------+------------+
|     10 |        80.080567 |    0.005774 |          8 |        80.080567 |    0.005454 |         19 |
|     20 |       219.590296 |    0.002177 |          8 |       219.590296 |    0.005913 |         19 |
|     50 |       852.496543 |    0.001804 |          8 |       852.496543

In [64]:
# Fixed step_size = 0.1 for GD; GDA starts from lambda_0 = 5.0 * 0.1.
results_2 = run_experiment(
    n_values,
    gda_multiplier=5.0,
    with_fixed_step_size=0.1,
)
print_results(
    results_2,
    title="Results with Fixed Step Size (gda_multiplier=5.0, step_size=0.1)",
)

+----------------------------------------------------------------------------------------------------+
|                  Results with Fixed Step Size (gda_multiplier=5.0, step_size=0.1)                  |
+----------------------------------------------------------------------------------------------------+
|   n    |          Algorithm GDA (proposed)           |                Algorithm GD                 |
+--------+---------------------------------------------+---------------------------------------------+
|        |      f_opt       |  time (s)   |   iters    |      f_opt       |  time (s)   |   iters    |
+--------+------------------+-------------+------------+------------------+-------------+------------+
|     10 |        80.080567 |    0.014760 |         46 |        80.080567 |    0.001612 |          8 |
|     20 |       219.590296 |    0.016659 |         58 |       219.590296 |    0.004688 |         13 |
|     50 |       852.496543 |    0.026083 |         90 |       852.496543

In [65]:
# Fixed step_size = 0.1 for GD; GDA starts from lambda_0 = 2.0 * 0.1.
results_3 = run_experiment(
    n_values,
    gda_multiplier=2.0,
    with_fixed_step_size=0.1,
)
print_results(
    results_3,
    title="Results with Fixed Step Size (gda_multiplier=2.0, step_size=0.1)",
)

+----------------------------------------------------------------------------------------------------+
|                  Results with Fixed Step Size (gda_multiplier=2.0, step_size=0.1)                  |
+----------------------------------------------------------------------------------------------------+
|   n    |          Algorithm GDA (proposed)           |                Algorithm GD                 |
+--------+---------------------------------------------+---------------------------------------------+
|        |      f_opt       |  time (s)   |   iters    |      f_opt       |  time (s)   |   iters    |
+--------+------------------+-------------+------------+------------------+-------------+------------+
|     10 |        80.080567 |    0.006069 |         19 |        80.080567 |    0.003014 |          8 |
|     20 |       219.590296 |    0.008320 |         26 |       219.590296 |    0.002887 |         13 |
|     50 |       852.496543 |    0.009598 |         41 |       852.496543

In [68]:
# Fixed step_size = 0.05 for GD; GDA starts from lambda_0 = 5.0 * 0.05.
results_4 = run_experiment(
    n_values,
    gda_multiplier=5.0,
    with_fixed_step_size=0.05,
)
print_results(
    results_4,
    title="Results with Fixed Step Size (gda_multiplier=5.0, step_size=0.05)",
)

+----------------------------------------------------------------------------------------------------+
|                 Results with Fixed Step Size (gda_multiplier=5.0, step_size=0.05)                  |
+----------------------------------------------------------------------------------------------------+
|   n    |          Algorithm GDA (proposed)           |                Algorithm GD                 |
+--------+---------------------------------------------+---------------------------------------------+
|        |      f_opt       |  time (s)   |   iters    |      f_opt       |  time (s)   |   iters    |
+--------+------------------+-------------+------------+------------------+-------------+------------+
|     10 |        80.080567 |    0.008128 |         24 |        80.080567 |    0.001741 |         13 |
|     20 |       219.590296 |    0.006121 |         33 |       219.590296 |    0.000000 |          9 |
|     50 |       852.496543 |    0.009683 |         50 |       852.496543

In [70]:
# Fixed step_size = 0.05 for GD; GDA starts from lambda_0 = 2.0 * 0.05.
results_5 = run_experiment(
    n_values,
    gda_multiplier=2.0,
    with_fixed_step_size=0.05,
)
print_results(
    results_5,
    title="Results with Fixed Step Size (gda_multiplier=2.0, step_size=0.05)",
)

+----------------------------------------------------------------------------------------------------+
|                 Results with Fixed Step Size (gda_multiplier=2.0, step_size=0.05)                  |
+----------------------------------------------------------------------------------------------------+
|   n    |          Algorithm GDA (proposed)           |                Algorithm GD                 |
+--------+---------------------------------------------+---------------------------------------------+
|        |      f_opt       |  time (s)   |   iters    |      f_opt       |  time (s)   |   iters    |
+--------+------------------+-------------+------------+------------------+-------------+------------+
|     10 |        80.080567 |    0.003303 |          8 |        80.080567 |    0.005770 |         13 |
|     20 |       219.590296 |    0.002999 |         13 |       219.590296 |    0.002006 |          9 |
|     50 |       852.496543 |    0.005893 |         21 |       852.496543