fix for learning rate decay on validation performance
pluskid committed Feb 6, 2015
1 parent b8d8d19 · commit dbd4d92
Showing 2 changed files with 26 additions and 21 deletions.
src/coffee/validation-performance.jl (2 changes: 2 additions & 0 deletions)
@@ -4,6 +4,8 @@ export register
 type ValidationPerformance <: Coffee
   validation_net :: Net

+  ValidationPerformance(net::Net) = new(net, Function[])
+
   # listeners will be notified each time we compute
   # performance on the validation set
   listeners :: Vector{Function}
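For orientation, the listeners field that this new constructor initializes holds callbacks fired after each validation run. A minimal sketch of hooking one up, assuming register simply appends the callback to that vector and that a validation Net called validation_net already exists (the call site that actually invokes the listeners is not part of this diff):

# Hypothetical listener with the same (coffee_lounge, net, state) signature
# that the solver code below passes to policy.listener.
my_listener = (coffee_lounge, net, state) -> begin
  println("validation statistics updated at iteration $(state.iter)")
end

validation = ValidationPerformance(validation_net)  # listeners starts as Function[]
register(validation, my_listener)                   # assumed to push! onto listeners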
src/solvers.jl (45 changes: 24 additions & 21 deletions)
@@ -11,7 +11,7 @@ export load_snapshot
 ############################################################
 abstract LearningRatePolicy
 module LRPolicy
-using ..Mocha.LearningRatePolicy
+using ..Mocha
 type Fixed <: LearningRatePolicy
   base_lr :: FloatingPoint
 end
@@ -37,6 +37,25 @@

 # curr_lr *= gamma whenever performance
 # drops on the validation set
+function decay_on_validation_listener(policy, key::String, coffee_lounge::CoffeeLounge, net::Net, state::SolverState)
+  stats = get_statistics(coffee_lounge, key)
+  index = sort(collect(keys(stats)))
+  if length(index) > 1
+    if stats[index[end]] < stats[index[end-1]]
+      # performance drop
+      info("lr decay %e -> %e", policy.curr_lr, policy.curr_lr*policy.gamma)
+      policy.curr_lr *= policy.gamma
+
+      # revert to a previously saved "good" snapshot
+      if isa(policy.solver, Solver)
+        info("reverting to previous saved snapshot")
+        solver_state = load_snapshot(net, solver.params.load_from, state)
+        info("snapshot at iteration %d loaded", solver_state.iter)
+        copy_solver_state!(state, solver_state)
+      end
+    end
+  end
+end
 type DecayOnValidation <: LearningRatePolicy
   base_lr :: FloatingPoint
   gamma :: FloatingPoint
@@ -50,23 +69,7 @@ type DecayOnValidation <: LearningRatePolicy
     policy = new(base_lr, gamma, key, base_lr)
     policy.solver = nothing
     policy.listener = (coffee_lounge,net,state) -> begin
-      stats = get_statistics(coffee_lounge, key)
-      index = sort(keys(stats))
-      if length(index) > 1
-        if stats[index[end]] < stats[index[end-1]]
-          # performance drop
-          @info("lr decay %e -> %e", policy.curr_lr, policy.curr_lr*policy.gamma)
-          policy.curr_lr *= policy.gamma
-
-          # revert to a previously saved "good" snapshot
-          if isa(policy.solver, Solver)
-            @info("reverting to previous saved snapshot")
-            solver_state = load_snapshot(net, solver.params.load_from, state)
-            @info("snapshot at iteration %d loaded", solver_state.iter)
-            copy_solver_state!(state, solver_state)
-          end
-        end
-      end
+      decay_on_validation_listener(policy, key, coffee_lounge, net, state)
     end

     policy
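As an aside, the decision rule in decay_on_validation_listener (note the added collect around keys, one of the fixes in this commit) reduces to comparing the two most recent entries of the per-iteration statistics. A toy illustration in plain Julia, where the Dict merely stands in for whatever get_statistics returns, which is an assumption made for illustration only:

# Stand-in for get_statistics(coffee_lounge, key): iteration => accuracy.
stats = Dict(1000 => 0.92, 2000 => 0.95, 3000 => 0.94)

index = sort(collect(keys(stats)))
if length(index) > 1 && stats[index[end]] < stats[index[end-1]]
  # 0.94 < 0.95, so this branch fires: the learning rate would be
  # multiplied by gamma and the last "good" snapshot reloaded.
  println("validation performance dropped; decaying learning rate")
end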
@@ -101,8 +104,8 @@ get_learning_rate(policy::LRPolicy.Inv, state::SolverState) =
     policy.base_lr * (1 + policy.gamma * state.iter) ^ (-policy.power)


-function setup(policy::DecayOnValidation, validation::ValidationPerformance, solver::Solver)
-  validation.add_listener(policy.listener)
+function setup(policy::LRPolicy.DecayOnValidation, validation::ValidationPerformance, solver::Solver)
+  register(validation, policy.listener)
   policy.solver = solver
 end

@@ -181,7 +184,7 @@ function get_momentum(policy::MomPolicy.Staged, state::SolverState)
   maxiter = policy.stages[policy.curr_stage][1]
   while state.iter >= maxiter && policy.curr_stage < length(policy.stages)
     policy.curr_stage += 1
-    @info("Staged learning rate policy: switching to stage $(policy.curr_stage)")
+    @info("Staged momentum policy: switching to stage $(policy.curr_stage)")
     maxiter = policy.stages[policy.curr_stage][1]
   end
 end
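Putting the pieces of this commit together, a hedged usage sketch of how the fixed setup would wire the policy to a validation coffee. The DecayOnValidation constructor argument order and the statistics key shown here are assumptions, since neither appears in full in this diff, and validation_net and solver are assumed to exist already:

# Assumed constructor order (base_lr, key, gamma); check the full definition
# of DecayOnValidation in src/solvers.jl before relying on it.
lr_policy  = LRPolicy.DecayOnValidation(0.01, "validation-accuracy", 0.5)
validation = ValidationPerformance(validation_net)

# setup registers lr_policy.listener with the coffee via register(...) and
# stores the solver so a previously saved snapshot can be reloaded when
# validation performance drops.
setup(lr_policy, validation, solver)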
