Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Lorenzo Stella committed Oct 27, 2023
1 parent ebc39db commit 08f3213
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 28 deletions.
23 changes: 9 additions & 14 deletions experiments/dual_svm/runme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ using AdaProx
using Random
using Plots
using LaTeXStrings
using ProximalCore
using ProximalOperators: IndBox, IndZero
using ProximalCore: IndZero
using ProximalOperators: IndBox

pgfplotsx()

Expand All @@ -21,15 +21,10 @@ struct Quadratic{TQ,Tq}
q::Tq
end

function (f::Quadratic)(x)
function AdaProx.eval_with_pullback(f::Quadratic, x)
temp = f.Q * x
return 0.5 * dot(x, temp) + dot(x, f.q)
end

function ProximalCore.gradient!(grad, f::Quadratic, x)
temp = f.Q * x
grad .= temp + f.q
return 0.5 * dot(x, temp) + dot(x, f.q)
quadratic_pullback() = temp + f.q
return 0.5 * dot(x, temp) + dot(x, f.q), quadratic_pullback
end

function run_dsvm(
Expand All @@ -56,7 +51,7 @@ function run_dsvm(
f = Quadratic(Q, q)
g = IndBox(0.0, C)
h = IndZero()
A = y'
A = Matrix(y')

Lf = norm(Q)
x0 = zeros(N)
Expand Down Expand Up @@ -117,7 +112,7 @@ function plot_residual(path)
names_to_plot = []
for name in ["Condat-Vu", "Malitsky-Pock", "AdaPDM"]
matching_names = [k for k in keys(gb) if startswith(k.method, name)]
push!(names_to_plot, find_best(gb, matching_names, :norm_res, 1e-5, :grad_f_evals))
push!(names_to_plot, find_best(gb, matching_names, :norm_res, 1e-5, :f_evals))
end

fig = plot(
Expand All @@ -131,7 +126,7 @@ function plot_residual(path)
continue
end
plot!(
gb[k][!, :grad_f_evals],
gb[k][!, :f_evals],
gb[k][!, :norm_res],
yaxis = :log,
label = k.method,
Expand All @@ -143,7 +138,7 @@ end


function main(;maxit = 10_000)
keys_to_log = [:method, :it, :grad_f_evals, :norm_res]
keys_to_log = [:method, :it, :f_evals, :norm_res]

for C in [0.1, 1]
path = joinpath(@__DIR__, "svmguide3_C_$(C).jsonl")
Expand Down
5 changes: 5 additions & 0 deletions experiments/least_absolute_deviation/runme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ using AdaProx

pgfplotsx()

function AdaProx.eval_with_pullback(f::Zero, x)
zero_pullback() = zero(x)
return f(x), zero_pullback
end

function run_least_absolute_deviation(
filename,
::Type{T} = Float64;
Expand Down
29 changes: 15 additions & 14 deletions experiments/nesterov_worst_case/runme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,29 @@ struct WorstQuadratic{S,T}
L::T
end

function (f::WorstQuadratic)(x)
function AdaProx.eval_with_pullback(f::WorstQuadratic, x)
s = x[1]^2 + x[f.k]^2

for i in 1:(f.k-1)
s += (x[i] - x[i+1])^2
end
return (f.L / 4) * (s / 2 - x[1])
end

function ProximalCore.gradient!(grad, f::WorstQuadratic, x)
grad[1] = (f.L / 4) * (2 * x[1] - x[2] - 1)
for i in 2:(f.k-1)
grad[i] = (f.L / 4) * (2 * x[i] - x[i-1] - x[i+1])
function worst_quadratic_pullback()
grad = zero(x)
grad[1] = (f.L / 4) * (2 * x[1] - x[2] - 1)
for i in 2:(f.k-1)
grad[i] = (f.L / 4) * (2 * x[i] - x[i-1] - x[i+1])
end
grad[f.k] = (f.L / 4) * (2 * x[f.k] - x[f.k-1])
grad[(f.k+1):end] .= 0
return grad
end
grad[f.k] = (f.L / 4) * (2 * x[f.k] - x[f.k-1])
grad[(f.k+1):end] .= 0
# since f is quadratic, f(x) = 1/2 <x, Q x> + <q, x>
# meaning \nabla f(x) = Q x + q
# and f(x) = (1/2) <\nabla f(x), x> + (1/2) <q, x>
# since q = -(L/4) e_1, we obtain the following
return dot(grad, x) / 2 - (f.L / 8) * x[1]

return (f.L / 4) * (s / 2 - x[1]), worst_quadratic_pullback
end

(f::WorstQuadratic)(x) = AdaProx.eval_with_pullback(f, x)[1]

function run_nesterov_worst_case()
k = 100
n = 100
Expand Down
5 changes: 5 additions & 0 deletions experiments/square_root_lasso/runme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ using AdaProx

pgfplotsx()

function AdaProx.eval_with_pullback(f::Zero, x)
zero_pullback() = zero(x)
return f(x), zero_pullback
end

function run_square_root_lasso(
filename,
::Type{T} = Float64;
Expand Down

0 comments on commit 08f3213

Please sign in to comment.