In [1]:
# libraries

using LinearAlgebra
using Random
using RandomMatrices
using StatsBase
using Statistics
using ProximalOperators
using Plots

In [2]:
# sample generation

"""
    matrix_gen(m, n, p, κ)

Generate a random binary-classification instance.

Draws a vector `u_opt = randn(n)` and `m` rows `a_i = randn(n)`, and returns
`(A, b1, b2, u_opt)` where

- `A` is the `m×n` matrix whose `i`-th row is `a_i`,
- `b1[i] = sign(⟨a_i, u_opt⟩)` are the noise-free labels,
- `b2[i]` equals `-b1[i]` with probability `p` and `b1[i]` otherwise.

Random numbers are drawn in the order `u_opt`, then `(a_i, flip_i)` for `i = 1:m`, so a
fixed seed reproduces the same instance.

NOTE(review): `κ` (the intended condition number of `A`) is accepted but not used. The
rows of `A` are i.i.d. standard normal, so `cond(A)` is whatever that draw gives.
"""
function matrix_gen(
    m  :: Int,
    n  :: Int,
    p :: Real, # label-flip probability
    κ :: Real  # intended condition number (currently unused)
    )
    u_opt = randn(n)

    # Concretely typed storage. `Matrix(undef, m, n)` would be a Matrix{Any}, which makes
    # every later product `X * θ` fall back to slow generic arithmetic.
    A = Matrix{Float64}(undef, m, n)

    b1 = zeros(m); b2 = zeros(m)

    # sample([-1, 1], wv) returns -1 with probability p, i.e. flips the label
    wv = aweights([p, 1-p])

    for i = 1:m
        Ai = randn(n)
        A[i, :] = Ai

        bi = sign(Ai'u_opt)
        b1[i] = bi
        b2[i] = sample([-1, 1], wv) * bi
    end

    return A, b1, b2, u_opt
end

matrix_gen (generic function with 1 method)

In [3]:
# Generate data, define the objective function, and set parameters

Random.seed!(304)

# X: m×n data matrix; y: noise-free labels sign.(X * x); y2: the labels y with each sign
# flipped with probability p = 0.01; x: the vector used to generate the labels.
# NOTE(review): κ = 15 is passed through, but matrix_gen does not use it, so cond(X) is
# not controlled.
X, y, y2, x = matrix_gen(1000, 40, 0.01, 15)

m, n = size(X)

"""
    logistic_loss(A, b, θ)

Return the average logistic loss `(1/size(A, 1)) * Σᵢ log(1 + exp(-b[i] * (A*θ)[i]))`.

The per-sample term is evaluated as `log1p(exp(-|t|)) + max(-t, 0)`, which equals
`log(1 + exp(-t))` but stays finite when `|t|` is large.
"""
function logistic_loss(A, b, θ)
    t = b .* (A * θ)
    return sum(@. log1p(exp(-abs(t))) + max(-t, zero(t))) / size(A, 1)
end

# the objective functions: average logistic loss on the noise-free labels y (F)
# and on the noisy labels y2 (F2)
F(θ) = logistic_loss(X, y, θ)

F2(θ) = logistic_loss(X, y2, θ)

# Reference values: the loss at the generating vector x.
# For the noise-free labels y = sign.(X * x), every term log(1 + exp(-c * |⟨a_i, x⟩|))
# decreases in c > 0, so F(x) is not the minimum of F. The stopping rule
# |F(θ) - Fstar| <= ϵ measures closeness to this reference value.
Fstar = F(x)
Fstar2 = F2(x)

# parameter setting
T = 100   # number of independent trials per initial step size
K = 1000  # iteration budget per trial (an Int, so the loop counter k is an Int)
# 15 log-spaced initial step sizes from 1e-2 to 1e5
γinits = exp.(collect(range(log(10^(-2)), log(10^5), length = 15)))
ϵ = 0.05  # tolerance
ϵ_bisection = 1e-03  # NOTE(review): not used by any visible cell
β = 0.6   # step-size decay: γ_k = γinit * k^(-β)

0.6

# Figure 5 (a)

$m = 1000$, $n = 40$, $\kappa(A) = 15$ (nominal: `matrix_gen` does not currently enforce it), $p = 0$.

In [4]:
# gap F(θ₀) - Fstar at the Figure 5 (a) starting point θ₀ = 0.4x
F(x * 0.4) - Fstar

0.11283644369415241

In [5]:
# gap F2(θ₀) - Fstar2 at the Figure 5 (b) starting point θ₀ = 0.3x
F2(x * 0.3) - Fstar2

0.11076506472118836

In [None]:
# experiments for SGM (Figure 5 (a))
#
# For every initial step size γinits[i], run T independent trials of stochastic gradient
# descent on F from θ₀ = 0.4x with step size γ_k = γinit * k^(-β). Record the first
# iteration k at which |F(θ) - Fstar| <= ϵ, or K if that never happens.

# initialization
iterations_SGM = zeros(T, length(γinits))

# main iteration
for i in eachindex(γinits)
    γinit = γinits[i]
    for j = 1:T
        θ = 0.4 * x
        for k = 1:K
            idx = sample(1:m) # sampling
            γ = γinit * k^(-β) # update stepsize

            # SGM update θ ← θ - γ∇f(θ), where ∇f(θ) = -y a / (1 + exp(t)) and t = y⟨a, θ⟩.
            # Writing the weight as 1 / (1 + exp(t)) rather than exp(-t) / (1 + exp(-t))
            # avoids Inf/Inf = NaN once large step sizes push |t| past ~709.
            aᵢ = @view X[idx, :]
            t = y[idx] * dot(aᵢ, θ)
            θ = θ + (γ * y[idx] / (1 + exp(t))) * aᵢ

            if (abs(F(θ) - Fstar) <= ϵ || k >= K)
                iterations_SGM[j, i] = k
                break
            end
        end
    end
end

In [None]:
# One row per initial step size (column of iterations_SGM):
# [5% quantile  median  95% quantile] of the iteration counts over the T trials
sgm_iters = mapreduce(vcat, eachcol(iterations_SGM)) do c
    [quantile(c, 0.05) median(c) quantile(c, 0.95)]
end

In [None]:
# experiments for the truncated model (Figure 5 (a))
#
# Truncated-model step for the sampled loss f(θ) = log(1 + exp(-y⟨a, θ⟩)):
#     θ ← θ - min(γ, f(θ) / ‖∇f(θ)‖²) ∇f(θ),
# a gradient step whose length is capped where the linear model of f reaches zero.

# initialization
iterations_truncated = zeros(T, length(γinits))

# main iteration
for i in eachindex(γinits)
    γinit = γinits[i]
    for j = 1:T
        θ = 0.4 * x
        for k = 1:K
            idx = sample(1:m) # sampling
            γ = γinit * k^(-β) # update stepsize

            # Update for the truncated model (margin t = y⟨a, θ⟩; both forms stay finite)
            aᵢ = @view X[idx, :]
            t = y[idx] * dot(aᵢ, θ)
            fθ = log1p(exp(-abs(t))) + max(-t, zero(t))   # == log(1 + exp(-t))
            ∇f = (-y[idx] / (1 + exp(t))) * aᵢ
            gnorm2 = sum(abs2, ∇f)
            # if the gradient underflows to zero, the cap would be 0/0 = NaN; take no step
            if gnorm2 > 0
                θ = θ - min(γ, fθ / gnorm2) * ∇f
            end

            if (abs(F(θ) - Fstar) <= ϵ || k >= K)
                iterations_truncated[j, i] = k
                break
            end
        end
    end
end

In [None]:
# One row per initial step size (column of iterations_truncated):
# [5% quantile  median  95% quantile] of the iteration counts over the T trials
trunc_iters = mapreduce(vcat, eachcol(iterations_truncated)) do c
    [quantile(c, 0.05) median(c) quantile(c, 0.95)]
end

In [None]:
# experiments for the proximal model (Figure 5 (a))
#
# The proximal step minimizes f(z, θ, a, b, γ) over z for the sampled pair (a, b). The
# exact minimizer is θ + sΔx for some s ∈ (0, 1], where Δx = -γ∇f(θ) is the full SGM
# step. It is approximated by a grid search over s = l/1000, l = 1, …, 1000.

# Proximal objective log(1 + exp(-b⟨a, x⟩)) + ‖x - xk‖² / (2α) for one sample.
# The log-term is computed as log1p(exp(-|t|)) + max(-t, 0) so it never overflows.
function f(x, xk, a, b, α)
    t = b * (a'x)
    return log1p(exp(-abs(t))) + max(-t, zero(t)) + norm(x - xk)^2 / 2 / α
end

# fractions s of the full step tried by the grid search
step_fracs = [l / 1000 for l in 1:1000]

# initialization
iterations_proximal = zeros(T, length(γinits))

# main iteration
for i in eachindex(γinits)
    γinit = γinits[i]
    for j = 1:T
        θ = 0.4 * x
        for k = 1:K
            idx = sample(1:m) # sampling
            γ = γinit * k^(-β) # update stepsize

            Xidx = @view X[idx, :]; yidx = y[idx]

            # Update for the proximal model; 1/(1 + exp(t0)) avoids Inf/Inf for large margins
            t0 = yidx * dot(Xidx, θ)
            Δx = (γ * yidx / (1 + exp(t0))) * Xidx

            # f(θ + sΔx, θ, Xidx, yidx, γ) = log(1 + exp(-(t0 + s*dt))) + s² * q, so all
            # 1000 grid values come from scalars, without forming any candidate vector
            dt = yidx * dot(Xidx, Δx)
            q = sum(abs2, Δx) / 2 / γ
            u = t0 .+ step_fracs .* dt
            searchs = @. log1p(exp(-abs(u))) + max(-u, zero(u)) + step_fracs^2 * q

            θ = θ + step_fracs[argmin(searchs)] * Δx

            if (abs(F(θ) - Fstar) <= ϵ || k >= K)
                iterations_proximal[j, i] = k
                break
            end
        end
    end
end

In [None]:
# One row per initial step size (column of iterations_proximal):
# [5% quantile  median  95% quantile] of the iteration counts over the T trials
prox_iters = mapreduce(vcat, eachcol(iterations_proximal)) do c
    [quantile(c, 0.05) median(c) quantile(c, 0.95)]
end

In [None]:
# experiments for the bundle model (Figure 5 (a))
#
# Two-cut bundle model: linearize the sampled loss at θ and at the trial point
# θy = θ - γ∇f(θ), then take the proximal step on the maximum of the two cuts.
# - If the cut at θy is the larger one at θ̄ = θ - γ∇f(θy), the step is θ̄.
# - Otherwise the step is θ - γ((1 - λ)∇f(θ) + λ∇f(θy)), with λ chosen so that both cuts
#   agree there.

# initialization
iterations_bundle = zeros(T, length(γinits))

# main iteration
for i in eachindex(γinits)
    γinit = γinits[i]
    for j = 1:T
        θ = 0.4 * x
        for k = 1:K
            idx = sample(1:m) # sampling
            γ = γinit * k^(-β) # update stepsize

            # Update for the bundle model. With margin t = y⟨a, θ⟩ the loss is
            # log(1 + exp(-t)) and its gradient is -y a / (1 + exp(t)); both forms below
            # stay finite when |t| is large.
            aᵢ = @view X[idx, :]; yᵢ = y[idx]
            t = yᵢ * dot(aᵢ, θ)
            ∇fθ = (-yᵢ / (1 + exp(t))) * aᵢ
            fθ = log1p(exp(-abs(t))) + max(-t, zero(t))
            θy = θ - γ * ∇fθ

            ty = yᵢ * dot(aᵢ, θy)
            ∇fθy = (-yᵢ / (1 + exp(ty))) * aᵢ
            fθy = log1p(exp(-abs(ty))) + max(-ty, zero(ty))
            θ̄ = θ - γ * ∇fθy
            if fθ + dot(∇fθ, θ̄ - θ) <= fθy + dot(∇fθy, θ̄ - θy)
                θ = θ̄
            else
                # λ makes the two cuts equal at θ - γ((1-λ)∇fθ + λ∇fθy). It lies in [0, 1)
                # in exact arithmetic; the clamp only absorbs rounding error.
                λ = (fθy - fθ + γ * sum(abs2, ∇fθ)) / (γ * sum(abs2, ∇fθy - ∇fθ))
                λ = clamp(λ, 0, 1)
                θ = θ - γ * ((1 - λ) * ∇fθ + λ * ∇fθy)
            end

            if (abs(F(θ) - Fstar) <= ϵ || k >= K)
                iterations_bundle[j, i] = k
                break
            end
        end
    end
end

In [None]:
# One row per initial step size (column of iterations_bundle):
# [5% quantile  median  95% quantile] of the iteration counts over the T trials
bund_iters = mapreduce(vcat, eachcol(iterations_bundle)) do c
    [quantile(c, 0.05) median(c) quantile(c, 0.95)]
end

In [None]:
# Figure 5 (a): median iterations-to-ϵ versus initial step size for each method, with a
# shaded band between the 5% and 95% quantiles.
plt = plot(γinits, [sgm_iters[:, 2] trunc_iters[:, 2] prox_iters[:, 2] bund_iters[:, 2]],
    legend = :left,
    xscale = :log10,
    xlab = "Initial Step Size",
    ylab = "Time to accuracy, eps=$(ϵ)",
    title = "Figure 5. (a)",
    label = ["SGM" "Truncated" "Proximal" "bundle"],
    color = [:red :green :blue :black],
    markershape = [:dtriangle :circle :square :diamond])

# Each band is one closed polygon: walk the 5% curve left to right, then the 95% curve
# right to left, so that no vertex is dropped and the outline does not self-intersect.
for (iters, col) in ((sgm_iters, :red), (trunc_iters, :green),
                     (prox_iters, :blue), (bund_iters, :black))
    band = Shape([γinits; reverse(γinits)], [iters[:, 1]; reverse(iters[:, 3])])
    plot!(plt, band, fillalpha = 0.2, linecolor = nothing, fillcolor = col, label = "")
end
plt

# Figure 5 (b)

$m = 1000$, $n = 40$, $\kappa(A) = 15$ (nominal: `matrix_gen` does not currently enforce it), $p = 0.01$.

In [None]:
# experiments for SGM (Figure 5 (b))
#
# For every initial step size γinits[i], run T independent trials of stochastic gradient
# descent on F2 (noisy labels y2) from θ₀ = 0.3x with step size γ_k = γinit * k^(-β).
# Record the first iteration k at which |F2(θ) - Fstar2| <= ϵ, or K if that never happens.

# initialization
iterations_SGM = zeros(T, length(γinits))

# main iteration
for i in eachindex(γinits)
    γinit = γinits[i]
    for j = 1:T
        θ = 0.3 * x
        for k = 1:K
            idx = sample(1:m) # sampling
            γ = γinit * k^(-β) # update stepsize

            # SGM update θ ← θ - γ∇f(θ), where ∇f(θ) = -y a / (1 + exp(t)) and t = y⟨a, θ⟩.
            # Writing the weight as 1 / (1 + exp(t)) rather than exp(-t) / (1 + exp(-t))
            # avoids Inf/Inf = NaN once large step sizes push |t| past ~709.
            aᵢ = @view X[idx, :]
            t = y2[idx] * dot(aᵢ, θ)
            θ = θ + (γ * y2[idx] / (1 + exp(t))) * aᵢ

            if (abs(F2(θ) - Fstar2) <= ϵ || k >= K)
                iterations_SGM[j, i] = k
                break
            end
        end
    end
end

In [None]:
# One row per initial step size (column of iterations_SGM):
# [5% quantile  median  95% quantile] of the iteration counts over the T trials
sgm_iters = mapreduce(vcat, eachcol(iterations_SGM)) do c
    [quantile(c, 0.05) median(c) quantile(c, 0.95)]
end

In [None]:
# experiments for the truncated model (Figure 5 (b))
#
# Truncated-model step for the sampled loss f(θ) = log(1 + exp(-y⟨a, θ⟩)):
#     θ ← θ - min(γ, f(θ) / ‖∇f(θ)‖²) ∇f(θ),
# a gradient step whose length is capped where the linear model of f reaches zero.

# initialization
iterations_truncated = zeros(T, length(γinits))

# main iteration
for i in eachindex(γinits)
    γinit = γinits[i]
    for j = 1:T
        θ = 0.3 * x
        for k = 1:K
            idx = sample(1:m) # sampling
            γ = γinit * k^(-β) # update stepsize

            # Update for the truncated model (margin t = y⟨a, θ⟩; both forms stay finite)
            aᵢ = @view X[idx, :]
            t = y2[idx] * dot(aᵢ, θ)
            fθ = log1p(exp(-abs(t))) + max(-t, zero(t))   # == log(1 + exp(-t))
            ∇f = (-y2[idx] / (1 + exp(t))) * aᵢ
            gnorm2 = sum(abs2, ∇f)
            # if the gradient underflows to zero, the cap would be 0/0 = NaN; take no step
            if gnorm2 > 0
                θ = θ - min(γ, fθ / gnorm2) * ∇f
            end

            if (abs(F2(θ) - Fstar2) <= ϵ || k >= K)
                iterations_truncated[j, i] = k
                break
            end
        end
    end
end

In [None]:
# One row per initial step size (column of iterations_truncated):
# [5% quantile  median  95% quantile] of the iteration counts over the T trials
trunc_iters = mapreduce(vcat, eachcol(iterations_truncated)) do c
    [quantile(c, 0.05) median(c) quantile(c, 0.95)]
end

In [None]:
# experiments for the proximal model (Figure 5 (b))
#
# The proximal step minimizes f(z, θ, a, b, γ) over z for the sampled pair (a, b). The
# exact minimizer is θ + sΔx for some s ∈ (0, 1], where Δx = -γ∇f(θ) is the full SGM
# step. It is approximated by a grid search over s = l/1000, l = 1, …, 1000.

# Proximal objective log(1 + exp(-b⟨a, x⟩)) + ‖x - xk‖² / (2α) for one sample.
# The log-term is computed as log1p(exp(-|t|)) + max(-t, 0) so it never overflows.
function f(x, xk, a, b, α)
    t = b * (a'x)
    return log1p(exp(-abs(t))) + max(-t, zero(t)) + norm(x - xk)^2 / 2 / α
end

# fractions s of the full step tried by the grid search
step_fracs = [l / 1000 for l in 1:1000]

# initialization
iterations_proximal = zeros(T, length(γinits))

# main iteration
for i in eachindex(γinits)
    γinit = γinits[i]
    for j = 1:T
        θ = 0.3 * x
        for k = 1:K
            idx = sample(1:m) # sampling
            γ = γinit * k^(-β) # update stepsize

            Xidx = @view X[idx, :]; yidx = y2[idx]

            # Update for the proximal model; 1/(1 + exp(t0)) avoids Inf/Inf for large margins
            t0 = yidx * dot(Xidx, θ)
            Δx = (γ * yidx / (1 + exp(t0))) * Xidx

            # f(θ + sΔx, θ, Xidx, yidx, γ) = log(1 + exp(-(t0 + s*dt))) + s² * q, so all
            # 1000 grid values come from scalars, without forming any candidate vector
            dt = yidx * dot(Xidx, Δx)
            q = sum(abs2, Δx) / 2 / γ
            u = t0 .+ step_fracs .* dt
            searchs = @. log1p(exp(-abs(u))) + max(-u, zero(u)) + step_fracs^2 * q

            θ = θ + step_fracs[argmin(searchs)] * Δx

            if (abs(F2(θ) - Fstar2) <= ϵ || k >= K)
                iterations_proximal[j, i] = k
                break
            end
        end
    end
end

In [None]:
# Per initial step size (one column of iterations_proximal): the 5% quantile,
# median and 95% quantile of the recorded iteration counts over the T runs, as a
# (number of step sizes) × 3 matrix [lo median hi]. The column count is taken
# from the data rather than hard-coded.
prox_iters = let A = iterations_proximal
    cols = axes(A, 2)
    lo = [quantile(view(A, :, c), 0.05) for c in cols]
    med = [median(view(A, :, c)) for c in cols]
    hi = [quantile(view(A, :, c), 0.95) for c in cols]
    [lo med hi]
end

In [None]:
# experiments for the bundle model (Figure 5 (b))
#
# Two-cut bundle model: the sampled loss is modelled by the maximum of its
# linearizations at θ and at the stochastic-gradient trial point θy = θ - γ∇f(θ).
# The proximal step on that model is θ - γ((1-λ)∇f(θ) + λ∇f(θy)), where the cut
# weight λ ∈ [0, 1] maximizes the concave one-dimensional dual.

# initialization
iterations_bundle = zeros(T, length(γinits))

# main iteration
for i in eachindex(γinits)
    for j = 1:T
        γinit = γinits[i]
        θ = 0.3 * x
        for k = 1:K
            idx = sample(1:m) # sampling
            γ = γinit * k^(-β) # update stepsize

            Xidx = view(X, idx, :); yidx = y2[idx]

            # First cut at θ. log(1 + exp(-z)) and its gradient are evaluated so that
            # a huge |z| cannot produce Inf/Inf = NaN.
            z = yidx * dot(Xidx, θ)
            fθ = max(-z, 0) + log1p(exp(-abs(z)))
            ∇fθ = (-yidx / (1 + exp(z))) .* Xidx
            θy = θ - γ * ∇fθ

            # Second cut at the trial point θy
            zy = yidx * dot(Xidx, θy)
            fθy = max(-zy, 0) + log1p(exp(-abs(zy)))
            ∇fθy = (-yidx / (1 + exp(zy))) .* Xidx

            # λ = 1 candidate. It is optimal when the second cut is still not below
            # the first one there (the dual derivative at λ = 1 is non-negative).
            θ̄ = θ - γ * ∇fθy
            if fθ + dot(∇fθ, θ̄ - θ) <= fθy + dot(∇fθy, θ̄ - θy)
                θ = θ̄
            else
                # Stationary point of the dual, clamped to [0, 1] against rounding.
                # Equal gradients make λ irrelevant; avoid the 0/0 in that case.
                den = γ * sum(abs2, ∇fθy - ∇fθ)
                λ = den > 0 ?
                    clamp((fθy - fθ + γ * sum(abs2, ∇fθ)) / den, 0.0, 1.0) : 1.0
                θ -= γ * ((1 - λ) * ∇fθ + λ * ∇fθy)
            end

            if (abs(F2(θ) - Fstar2) <= ϵ || k >= K)
                iterations_bundle[j, i] = k
                break
            end
        end
    end
end

In [None]:
# Per initial step size (one column of iterations_bundle): the 5% quantile,
# median and 95% quantile of the recorded iteration counts over the T runs, as a
# (number of step sizes) × 3 matrix [lo median hi]. The column count is taken
# from the data rather than hard-coded.
bund_iters = let A = iterations_bundle
    cols = axes(A, 2)
    lo = [quantile(view(A, :, c), 0.05) for c in cols]
    med = [median(view(A, :, c)) for c in cols]
    hi = [quantile(view(A, :, c), 0.95) for c in cols]
    [lo med hi]
end

In [None]:
# Median iterations-to-accuracy against the initial step size for each method,
# with a shaded band between the 5% and 95% quantiles over the T runs.
# NOTE(review): sgm_iters comes from the cell labelled "Figure 5 (a)", which uses
# the noise-free labels y and starts at 0.4x. The other three methods use y2 and
# start at 0.3x. Confirm this comparison is intended for the panel titled (b).
plt = plot(γinits, [sgm_iters[:, 2] trunc_iters[:, 2] prox_iters[:, 2] bund_iters[:, 2]],
    legend = :left,
    xscale = :log10,
    xlab = "Initial Step Size",
    ylab = "Time to accuracy, eps=$(ϵ)",
    title = "Figure 5. (b)",
    label = ["SGM" "Truncated" "Proximal" "bundle"],
    color = [:red :green :blue :black],
    markershape = [:dtriangle :circle :square :diamond])

# Closed polygon for one quantile band: lower quantiles (column 1) left→right,
# then upper quantiles (column 3) right→left. Listing both in the same direction
# draws a self-intersecting bow-tie, and skipping end points leaves edges missing.
quantile_band(iters) = Shape([γinits; reverse(γinits)],
                             [iters[:, 1]; reverse(iters[:, 3])])

a = quantile_band(sgm_iters)
plot!(a, fillalpha = 0.2, linecolor = nothing, fillcolor = :red, label = "")
b = quantile_band(trunc_iters)
plot!(b, fillalpha = 0.2, linecolor = nothing, fillcolor = :green, label = "")
c = quantile_band(prox_iters)
plot!(c, fillalpha = 0.2, linecolor = nothing, fillcolor = :blue, label = "")
d = quantile_band(bund_iters)
plot!(d, fillalpha = 0.2, linecolor = nothing, fillcolor = :black, label = "")