In [None]:
import os
import sys

# Configure XLA through the environment *before* jax is imported, so the settings are
# already in place whenever jax creates its backend (not only when that happens lazily).
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'  # disable up-front GPU memory preallocation
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'

from jax import config

config.update("jax_enable_x64", True)  # enable 64-bit dtypes in jax
# Make the parent directory importable so the project package `uot` resolves.
sys.path.insert(0, os.path.abspath(".."))

import numpy as np
import jax.numpy as jnp
import pandas as pd

import plotly.graph_objects as go

from uot.problems.generators import GaussianMixtureGenerator

from uot.utils.costs import cost_euclid_squared
from uot.problems.inspect_store import plot_1d

from uot.solvers.gradient_ascent import (
    AdamGradientAscentSolver,
    GradientAscentMultiMarginalSGD,
    AMSGradSolver,
)

In [None]:
seed = 55


def get_gaussian_problems(seed, n_points=64, num_datasets=5, name=""):
    """Generate 1-D, single-component Gaussian problems on the interval (-1, 1).

    Parameters
    ----------
    seed : int
        Seed forwarded to ``GaussianMixtureGenerator`` for reproducibility.
    n_points : int, default 64
        Number of support points per marginal, forwarded as ``n_points``.
    num_datasets : int, default 5
        Number of problems the generator should produce.
    name : str, default ""
        Label forwarded to the generator.

    Returns
    -------
    list
        Every problem yielded by ``gen.generate()``, in generation order.
    """
    gen = GaussianMixtureGenerator(
        name=name,
        dim=1,
        num_components=1,
        n_points=n_points,
        num_datasets=num_datasets,
        borders=(-1, 1),
        cost_fn=cost_euclid_squared,
        use_jax=False,
        seed=seed,
    )
    return list(gen.generate())


# Use the `seed` variable (not a repeated literal) so changing it above takes effect.
problems_list = get_gaussian_problems(seed=seed)

In [None]:
# Visualise the two marginals of the first generated problem.
prob1 = problems_list[0]
mu1, nu1 = prob1.get_marginals()

# Discretise both marginals into (support points, weights) pairs.
(mu1_pts, mu1_w), (nu1_pts, nu1_w) = (measure.to_discrete() for measure in (mu1, nu1))

fig1 = plot_1d(mu1_pts, mu1_w, nu1_pts, nu1_w)
fig1

In [None]:
# Inspect the shape of the first cost matrix attached to `prob1`.
first_cost = prob1.get_costs()[0]
first_cost.shape

In [None]:
# Solve the first problem once with AMSGrad to sanity-check the solver.
prob1 = problems_list[0]
ams_params = dict(reg=0.1, maxiter=10_000, tol=1e-6, learning_rate=0.001)

gradient_solver = AMSGradSolver()
result = gradient_solver.solve(
    marginals=prob1.get_marginals(),
    costs=prob1.get_costs(),
    **ams_params,
)

result

In [None]:
from uot.experiments.runner import run_pipeline
from uot.experiments.experiment import Experiment

from uot.solvers.sinkhorn import SinkhornTwoMarginalSolver
from uot.problems.iterator import OnlineProblemIterator
from uot.solvers.solver_config import SolverConfig
from uot.utils.instantiate_solver import instantiate_solver

import time

In [None]:
def solve_fn(problem, solver_instance, measures, costs, *args, **kwargs):
    """Run one solver on one problem and record its wall-clock runtime.

    Parameters
    ----------
    problem : object
        The problem being solved. Unused here; presumably required by the
        experiment runner's calling convention.
    solver_instance : object
        Object exposing ``solve(marginals=..., costs=..., **kwargs)``.
    measures, costs
        Forwarded to ``solver_instance.solve`` as ``marginals`` and ``costs``.
    *args
        Ignored.
    **kwargs
        Solver hyper-parameters, forwarded verbatim to ``solve``.

    Returns
    -------
    The mapping returned by ``solve``, with a ``"runtime"`` entry (seconds) added.
    """
    import jax  # local import so this cell does not rely on `jax` being bound elsewhere

    start = time.perf_counter()
    results = solver_instance.solve(
        marginals=measures,
        costs=costs,
        **kwargs,
    )
    # JAX dispatches device work asynchronously; wait for every array in the result
    # so the timer measures the computation, not just its dispatch.
    jax.block_until_ready(results)
    results["runtime"] = time.perf_counter() - start
    return results


exp = Experiment(
    name="Testing Gradient Ascent Variations",
    solve_fn=solve_fn,
)

# Number of problems drawn from the generator and iterated per solver configuration.
N_PROBLEMS = 15


def adam_params(normalize_cost):
    """Return a fresh Adam hyper-parameter dict using a constant learning-rate schedule.

    Only ``normalize_cost`` varies between the configurations compared in this notebook.
    A new dict (including a new ``schedule_kwargs``) is built on every call, so the
    configurations never share mutable state.
    """
    return {
        "reg": 0.01,
        "normalize_cost": normalize_cost,
        "maxiter": 100_000,
        "tol": 1e-6,
        "learning_rate": 8e-3,
        "schedule": "constant",
        "schedule_kwargs": {},
    }


# Other schedules explored earlier (same reg/maxiter/tol/learning_rate), for reference:
#   "exponential" with schedule_kwargs={"decay_rate": 0.99, "decay_steps": 5_000}
#   "cosine"      with schedule_kwargs={"total_steps": 100_000, "final_lr": 1e-4}
solvers = [
    SolverConfig(
        name="Adam Solver",
        solver=AdamGradientAscentSolver,
        param_grid=[adam_params(normalize_cost=flag) for flag in (True, False)],
        is_jit=True,
    )
]

seed = 55
# A smaller 16-point / 2-dataset Gaussian iterator was also tried earlier.
iterators = [
    OnlineProblemIterator(
        GaussianMixtureGenerator(
            name="Gaussian (1d, 1c, 64p)",
            dim=1,
            num_components=1,
            n_points=64,
            num_datasets=N_PROBLEMS,
            borders=(-1, 1),
            cost_fn=cost_euclid_squared,
            use_jax=False,
            seed=seed,
        ),
        num=N_PROBLEMS,
        cache_gt=False,
    ),
]

results = run_pipeline(
    experiment=exp,
    solvers=solvers,
    iterators=iterators,
    folds=1,
    progress=True,
)

In [None]:
fig = go.Figure()
colors = {'constant': 'blue', 'exponential': 'red', 'cosine': 'green', 'linear_warmup': 'purple'}

# One box trace per schedule. Within each schedule, runs are split along x by
# `normalize_cost`, so different configurations are not pooled into a single box
# (boxmode='group' places the schedules side by side for each x category).
for schedule, schedule_runs in results.groupby('schedule', sort=False):
    fig.add_trace(go.Box(
        x=schedule_runs['normalize_cost'].map(lambda flag: f"normalize_cost={flag}"),
        y=schedule_runs['iterations'].astype(float),
        name=schedule,
        boxpoints='all',  # show all points
        jitter=0.3,
        pointpos=-1.8,
        # Unlisted schedules get plotly's default colour instead of raising KeyError.
        marker_color=colors.get(schedule),
        showlegend=True,
    ))

fig.update_layout(
    title='Iterations Required by Different Learning Rate Schedules',
    xaxis_title='Cost normalization',
    yaxis_title='Number of Iterations',
    boxmode='group',
)
fig.show()

In [None]:
# Distinct schedules in the results (a later cell may also read this name).
schedules = results['schedule'].unique()

fig = go.Figure()

# One line per (schedule, normalize_cost) configuration: several configurations can
# share a schedule, and drawing them as one trace would join points from different runs.
for (schedule, normalize_flag), config_runs in results.groupby(
    ['schedule', 'normalize_cost'], sort=False, dropna=False
):
    config_runs = config_runs.sort_values('problem_index')  # draw each line left to right
    fig.add_trace(go.Scatter(
        x=config_runs['problem_index'],
        y=config_runs['iterations'].astype(float),
        mode='lines+markers',
        name=f"{schedule} (normalize_cost={normalize_flag})",
        hovertemplate='Problem: %{x}<br>Iterations: %{y}<extra></extra>'
    ))

fig.update_layout(
    title='Iterations Required per Problem Index',
    xaxis_title='Problem Index',
    yaxis_title='Number of Iterations',
    showlegend=True,
    hovermode='x unified'
)

fig.show()

In [None]:
fig = go.Figure()

# Work on a copy whose numeric columns are plain floats (solver outputs may be
# JAX arrays); `results` itself is left untouched.
df = pd.DataFrame(results).assign(
    iterations=lambda frame: frame['iterations'].astype(float),
    learning_rate=lambda frame: frame['learning_rate'].astype(float),
)

# Median iteration count for every (learning rate, schedule, normalization) combination.
learning_rates = df.groupby(
    ['learning_rate', 'schedule', 'normalize_cost'], as_index=False, dropna=False
)['iterations'].median()

# One line per (schedule, normalize_cost) configuration, derived from this cell's own
# data rather than from variables left behind by earlier cells.
for (schedule, normalize_flag), config_rows in learning_rates.groupby(
    ['schedule', 'normalize_cost'], sort=False, dropna=False
):
    config_rows = config_rows.sort_values('learning_rate')
    fig.add_trace(go.Scatter(
        x=config_rows['learning_rate'],
        y=config_rows['iterations'],
        mode='lines+markers',
        name=f"{schedule} (normalize_cost={normalize_flag})",
        hovertemplate='Learning Rate: %{x}<br>Iterations: %{y}<extra></extra>'
    ))

fig.update_layout(
    title='Median Number of Iterations by Learning Rate and Schedule',
    xaxis_title='Learning Rate',
    yaxis_title='Median Number of Iterations',
    showlegend=True,
    hovermode='x unified',
    template='plotly_white'
)

fig.show()

In [None]:
fig = go.Figure()

# Final costs split by whether the solver normalised the cost matrix.
normalized = results.loc[results['normalize_cost'].eq(True), 'cost'].astype(float)
non_normalized = results.loc[results['normalize_cost'].eq(False), 'cost'].astype(float)

# Same box styling for both groups, added in a fixed order.
for label, values in (('Normalized Cost', normalized), ('Non-normalized Cost', non_normalized)):
    fig.add_trace(go.Box(y=values, name=label, boxpoints='all', jitter=0.3, pointpos=-1.8))

fig.update_layout(
    title='Distribution of Costs: Normalized vs Non-normalized',
    yaxis_title='Cost Value',
    boxmode='group',
    template='plotly_white',
)

fig.show()

In [None]:
# Calculate relative cost error, pairing runs by problem_index rather than by row order.
def _final_costs_by_problem(normalize_flag):
    """Final costs (as floats) of one normalize_cost configuration, indexed by problem_index."""
    runs = results[results['normalize_cost'].eq(normalize_flag)]
    return runs.set_index('problem_index')['cost'].astype(float).sort_index()


normalized_costs = _final_costs_by_problem(True)
non_normalized_costs = _final_costs_by_problem(False)

# Keep only problems solved under both settings, in the same order, so the element-wise
# comparison never mixes up different problems (or fails on unequal run counts).
shared_problems = normalized_costs.index.intersection(non_normalized_costs.index)
normalized_costs = normalized_costs.loc[shared_problems]
non_normalized_costs = non_normalized_costs.loc[shared_problems]

relative_errors = (
    np.abs(normalized_costs.to_numpy() - non_normalized_costs.to_numpy())
    / np.abs(non_normalized_costs.to_numpy())
)

fig = go.Figure()

fig.add_trace(go.Scatter(
    x=shared_problems.tolist(),
    y=relative_errors,
    mode='lines+markers',
    name='Relative Cost Error',
    hovertemplate='Problem %{x}<br>Relative Error: %{y:.6f}<extra></extra>'
))

fig.update_layout(
    title='Relative Cost Error: |Normalized - Non-normalized| / |Non-normalized|',
    xaxis_title='Problem Index',
    yaxis_title='Relative Error',
    template='plotly_white',
    showlegend=True
)

fig.show()

In [None]:
# compute L-infinity (max absolute) errors between normalized and non-normalized runs
norm = results[results['normalize_cost'].eq(True)].set_index('problem_index').sort_index()
nonnorm = results[results['normalize_cost'].eq(False)].set_index('problem_index').sort_index()

# Each problem must appear once per configuration; otherwise `.loc[idx]` below returns
# several rows and the per-problem comparison is ambiguous.
assert norm.index.is_unique and nonnorm.index.is_unique, (
    "duplicate problem_index within a normalize_cost configuration"
)

common = norm.index.intersection(nonnorm.index)


def linf(a, b):
    """Return max |a - b| over all entries, or NaN when the inputs cannot be compared.

    Best-effort by design: inputs with different shapes, non-numeric contents, or no
    entries yield NaN instead of aborting the whole report.
    """
    try:
        a_arr = np.asarray(a)
        b_arr = np.asarray(b)
        if a_arr.shape != b_arr.shape:
            # Broadcasting mismatched shapes would produce a meaningless number.
            return float('nan')
        diff = np.abs(a_arr - b_arr)
    except (TypeError, ValueError):
        return float('nan')
    return float(diff.max()) if diff.size else float('nan')


rows = []
for idx in common:
    r1 = norm.loc[idx]
    r2 = nonnorm.loc[idx]
    rows.append({
        'problem_index': idx,
        'cost_Linf': abs(float(r1['cost']) - float(r2['cost'])),
        'u_Linf': linf(r1['u_final'], r2['u_final']),
        'v_Linf': linf(r1['v_final'], r2['v_final']),
        'plan_Linf': linf(r1['transport_plan'], r2['transport_plan']),
    })

# Explicit columns keep `set_index` working even when no problem is shared.
err_df = pd.DataFrame(
    rows, columns=['problem_index', 'cost_Linf', 'u_Linf', 'v_Linf', 'plan_Linf']
).set_index('problem_index')
summary = err_df.max().to_dict()

print("Per-problem L-inf errors (normalized vs non-normalized):")
print(err_df)
print("\nMaximum L-inf over problems:")
print(summary)