In [None]:
import numpy as np
import matplotlib.pyplot as plt
import jax.random as jr
import optax
import jaxopt
from jax import numpy as jnp, jit, vmap, grad
from functools import reduce, partial
from tqdm import tqdm
from wishart_loss import *

import warnings
warnings.filterwarnings('ignore')

In [None]:
def batch_optimize_with_adam(loss_fn, mu_schedule, x_init, num_iterations, learning_rate):
    """Run Adam independently from every starting point in ``x_init`` under a mu schedule.

    Args:
        loss_fn: callable ``loss_fn(mu, params) -> scalar``; differentiated w.r.t. ``params``.
        mu_schedule: callable mapping normalized progress ``t = i / num_iterations``
            (so ``t`` runs over ``[0, 1)``) to the ``mu`` used at iteration ``i``.
        x_init: batch of initial parameter vectors; the leading axis is the batch axis.
        num_iterations: number of Adam steps taken from every starting point.
        learning_rate: Adam learning rate.

    Returns:
        The final parameters, with the same leading batch axis as ``x_init``.
    """
    opt = optax.adam(learning_rate)

    def step_single(mu, params, opt_state):
        # One Adam step for a single starting point at noise level `mu`.
        grads = grad(partial(loss_fn, mu))(params)
        updates, opt_state = opt.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    # Batch once, then compile: each iteration is a single compiled call instead of
    # re-running vmap's batching rules eagerly on every Python-level step.
    # `mu` is shared by the whole batch (in_axes=None); params/state are batched.
    batched_step = jit(vmap(step_single, in_axes=(None, 0, 0)))

    params = jnp.asarray(x_init)
    opt_state = vmap(opt.init)(params)
    for i in range(num_iterations):
        mu = mu_schedule(i / num_iterations)
        params, opt_state = batched_step(mu, params, opt_state)
    return params


def estimated_order(noisy_loss, num_parameters, num_samples=100, mu=0.25, seed=43):
    """Estimate the effective order ``k`` of a noisy loss from how its spread shrinks with ``mu``.

    The estimate inverts the model ``std(loss_mu) = (1 - mu)**k * std(loss_0)`` over random
    parameter vectors, i.e. ``k = log(std_mu / std_0) / log(1 - mu)``.

    Args:
        noisy_loss: callable ``noisy_loss(mu, params) -> scalar``; it is vmapped over a
            batch of parameter vectors.
        num_parameters: length of each parameter vector.
        num_samples: number of random parameter vectors, drawn uniformly from ``[0, 2*pi)``.
        mu: noise level compared against ``mu = 0``; must lie strictly in ``(0, 1)``.
        seed: seed for the parameter draws (the default 43 reproduces earlier results).

    Returns:
        The estimated order ``k`` as a scalar.

    Raises:
        ValueError: if ``mu`` is not strictly between 0 and 1, where ``log(1 - mu)`` is
            zero (division by zero) or undefined.
    """
    if not 0 < mu < 1:
        raise ValueError(f"mu must lie strictly between 0 and 1, got {mu}")

    rng = np.random.default_rng(seed)
    x = 2 * np.pi * rng.uniform(size=(num_samples, num_parameters))

    mu_vals = vmap(partial(noisy_loss, mu))(x)
    vals = vmap(partial(noisy_loss, 0.))(x)

    return np.log(mu_vals.std() / vals.std()) / np.log(1 - mu)


def tanh_schedule(strength: float, max_mu: float, t_stop: float):
    """Build a schedule ``mu(t)`` that decays from ``max_mu`` towards 0 along a tanh curve.

    The returned callable evaluates
    ``max_mu * max(0, tanh(strength * (1 - t / t_stop))) / tanh(strength)``.
    For positive ``strength`` and ``t_stop`` it equals ``max_mu`` at ``t = 0`` and is
    clipped to 0 for every ``t >= t_stop``; a larger ``strength`` keeps mu near
    ``max_mu`` longer before a sharper drop.
    """
    def schedule(t):
        ramp = np.tanh(strength * (1 - t / t_stop))
        clipped = ramp if ramp > 0 else 0
        return max_mu * clipped / np.tanh(strength)

    return schedule

In [None]:
def run_experiments(dim, ranks, schedules, x,
                    num_iterations=4000,
                    learning_rate=0.05,
                    num_matrices=100,
                    seed=42):
    """Compare plain Adam ("classic") with noise-scheduled Adam ("fourier") on Wishart losses.

    For every rank, ``num_matrices`` matrices are produced by
    ``gen_wishart_batch(2**dim, rank, num_matrices, seed)``. For each matrix, every
    starting point in ``x`` is optimized
      * with plain Adam on ``gen_loss(dim, wmatrix)`` via ``jaxopt.OptaxSolver``, and
      * once per schedule in ``schedules`` with ``batch_optimize_with_adam`` on the noisy
        loss ``gen_noisy_loss(dim, wmatrix)`` divided by ``(1 - mu)**k_est + 1e-6``.

    Args:
        dim: number of parameters per starting point (columns of ``x``).
        ranks: iterable of ranks passed to ``gen_wishart_batch``.
        schedules: list of callables ``mu_schedule(t)``, one Fourier run per entry.
        x: batch of starting points, shape ``(num_starts, dim)``.
        num_iterations: Adam steps per run (an upper bound for the classic solver).
        learning_rate: Adam learning rate for both methods.
        num_matrices: matrices generated per rank.
        seed: seed passed to ``gen_wishart_batch`` (the same for every rank).

    Returns:
        ``{rank: {'classic': array, 'fourier_0': array, ...}}``; each array stacks one row
        of final losses (one entry per starting point) per matrix.
    """
    # optax.adam only builds a (stateless) transformation; one instance serves every run.
    opt = optax.adam(learning_rate)
    total_losses = {}
    for rank in ranks:
        print(f'Rank: {rank}')
        wmatrix_batch = gen_wishart_batch(2**dim, rank, num_matrices, seed)

        final_losses_classic = []
        final_losses_fourier = [[] for _ in schedules]

        for i in tqdm(range(num_matrices)):
            wmatrix = wmatrix_batch[i]

            # Classic loss: plain Adam through jaxopt.
            classic_loss_fn = jit(gen_loss(dim, wmatrix))
            solver = jaxopt.OptaxSolver(fun=classic_loss_fn, opt=opt, maxiter=num_iterations, jit=True)
            # OptaxSolver's `state.value` is the loss evaluated *before* its last update,
            # so re-evaluate at the returned params to report the true final loss, as the
            # noisy branch below does. NOTE(review): the solver may also stop before
            # `maxiter` once its `tol` criterion is met.
            final_params_classic = vmap(lambda params: solver.run(params).params)(x)
            final_losses_classic.append(vmap(classic_loss_fn)(final_params_classic))

            # Noisy loss, normalized by the estimated decay (1 - mu)**k_est.
            noisy_loss_fn = lambda mu, params: gen_noisy_loss(dim, wmatrix)(params, mu)
            k_est = estimated_order(noisy_loss_fn, num_parameters=dim, mu=0.5)
            # The +1e-6 guards against a zero denominator when (1 - mu)**k_est vanishes.
            nloss = jit(lambda mu, params: noisy_loss_fn(mu, params) / ((1-mu)**k_est + 1e-6))
            for j, mu_schedule in enumerate(schedules):
                final_params_fourier = batch_optimize_with_adam(nloss,
                                                                mu_schedule,
                                                                x,
                                                                num_iterations,
                                                                learning_rate)
                # Score with the noiseless (mu = 0), un-normalized loss.
                final_losses_fourier[j].append(vmap(partial(noisy_loss_fn, 0))(final_params_fourier))

        d = {'classic': np.array(final_losses_classic)}
        for j, losses in enumerate(final_losses_fourier):
            d[f'fourier_{j}'] = np.array(losses)
        total_losses[rank] = d
    return total_losses


def _binned_metrics(losses, bin_edges, best_sol, threshold):
    """Return ``[prob_best_sol, weighted_mean, left_mass_ratio_5]`` for one loss sample.

    ``prob_best_sol`` is the exact fraction of ``losses`` at or below ``best_sol``; the other
    two come from a histogram on ``bin_edges``, with bin centers standing in for the values.
    """
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    heights, _ = np.histogram(losses, bins=bin_edges)
    total = np.sum(heights)
    prob_best_sol = np.sum(losses <= best_sol) / losses.shape[0]
    weighted_mean = np.sum(bin_centers * heights) / total
    left_mass_ratio_5 = np.sum(heights[bin_centers <= threshold]) / total
    return np.array([prob_best_sol, weighted_mean, left_mass_ratio_5])


def compute_metrics(total_losses, num_groups=10):
    """Summarize final-loss distributions per rank as three metrics per group of runs.

    Each method's losses (``'classic'`` and every ``'fourier_<j>'`` key) are flattened,
    split into ``num_groups`` equal chunks and rounded to 2 decimals. For each chunk and method:
      0. prob_best_sol: fraction of losses <= the classic chunk's 0.1th percentile;
      1. weighted_mean: histogram-binned mean loss;
      2. left_mass_ratio_5: binned fraction of losses whose bin center is <= the classic
         chunk's 5th percentile.
    All methods in a chunk share one 50-bin histogram spanning the classic and every
    fourier sample, so the classic metrics do not depend on the order of the fourier runs.

    Args:
        total_losses: ``{rank: {'classic': array, 'fourier_0': array, ...}}``.
        num_groups: number of chunks; each method's total size must be divisible by it.

    Returns:
        ``{rank: {method: ndarray of shape (num_groups, 3)}}`` with the columns above.
    """
    total_results = {}
    for rank, losses_by_method in total_losses.items():
        fourier_keys = sorted((k for k in losses_by_method if k.startswith('fourier_')),
                              key=lambda k: int(k[len('fourier_'):]))
        methods = ['classic'] + fourier_keys
        grouped = {m: np.round(np.asarray(losses_by_method[m]).reshape(num_groups, -1), 2)
                   for m in methods}

        res = {m: np.zeros((num_groups, 3)) for m in methods}
        for i in range(num_groups):
            losses_classic = grouped['classic'][i]
            # Reference levels come from the classic baseline only.
            best_sol = np.percentile(losses_classic, 0.1)
            threshold = np.percentile(losses_classic, 5)
            bin_edges = np.histogram_bin_edges(
                np.concatenate([grouped[m][i] for m in methods]), bins=50)
            for m in methods:
                res[m][i] = _binned_metrics(grouped[m][i], bin_edges, best_sol, threshold)
        total_results[rank] = res
    return total_results

In [None]:
# Problem size: every starting point has `dim` parameters (the columns of x0).
dim = 8
ranks = list(range(10, 251, 20))  # 10, 30, 50, ..., 250

# Shared random starting points in [0, 2*pi)^dim, reused for every matrix and method.
num_samples = 2000
rng = np.random.default_rng(11)
x0 = 2 * np.pi * rng.uniform(size=(num_samples, dim))

# Single noise schedule: mu starts at 0.4 and is clipped to 0 from t = 0.75 onwards.
schedules = [tanh_schedule(strength=5, max_mu=0.4, t_stop=0.75)]
mu_schedule = schedules[0]

total_losses = run_experiments(dim=dim, ranks=ranks, schedules=schedules, x=x0)
total_results = compute_metrics(total_losses)