Commit
Merge pull request #6 from stokasto/accumulate_statistics
Add coffee break to allow for accumulation of statistics and reproduce ~100 errors on MNIST
pluskid committed Nov 29, 2014
2 parents 3392bb6 + b7a6053 commit 77274da
Showing 10 changed files with 288 additions and 4 deletions.
102 changes: 102 additions & 0 deletions examples/mnist/mnist_dropout_fc.jl
@@ -0,0 +1,102 @@
hdf5_fns = ["data/train.hdf5", "data/test.hdf5"]
source_fns = ["data/train.txt", "data/test.txt"]
for i = 1:length(hdf5_fns)
if !isfile(hdf5_fns[i])
println("Data not found, use get-mnist.sh to generate HDF5 data")
exit(1)
else
open(source_fns[i], "w") do s
println(s, hdf5_fns[i])
end
end
end

#ENV["MOCHA_USE_NATIVE_EXT"] = "true"
#ENV["OMP_NUM_THREADS"] = 1
#blas_set_num_threads(1)
ENV["MOCHA_USE_CUDA"] = "true"

using Mocha

############################################################
# This is an example script for training a fully connected
# network with dropout on MNIST.
#
# The network size is 784-1200-1200-10 with ReLU units
# in the hidden layers and a softmax output layer.
# The parameters for training the network were chosen
# to reproduce the results from the original dropout paper:
# http://arxiv.org/abs/1207.0580
# and the corresponding newer JMLR paper:
# http://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf
#
# Our parameters differ slightly. This is mainly because in the
# original dropout paper the weights are scaled by 0.5 after
# training, whereas we scale them by 2 during training.
#
# The settings in this script should currently produce a model that
# gets 100 errors (or 99% accuracy) on the test set
# if you run it for the whole 2000 epochs (= 600*2000 steps).
# This is slightly better than, but well within the error
# bars of, the results in the JMLR paper.
############################################################
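
# Illustration only (not Mocha internals): the factor-2 training-time scaling
# mentioned above is "inverted" dropout, which divides the kept activations
# by (1 - ratio) during training so the weights need no rescaling at test
# time. A minimal sketch, assuming a simple elementwise mask:
#
#   function dropout_train(x, ratio)
#       mask = rand(size(x)...) .> ratio   # keep each unit with probability (1 - ratio)
#       return (x .* mask) ./ (1 - ratio)  # e.g. scale by 2 for ratio = 0.5
#   end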


# fix the random seed to make results reproducible
srand(12345678)

data_layer = HDF5DataLayer(name="train-data", source=source_fns[1], batch_size=100)
fc1_layer = InnerProductLayer(name="fc1", output_dim=1200, neuron=Neurons.ReLU(), weight_init = GaussianInitializer(std=0.01), bottoms=[:data], tops=[:fc1])
fc2_layer = InnerProductLayer(name="fc2", output_dim=1200, neuron=Neurons.ReLU(), weight_init = GaussianInitializer(std=0.01), bottoms=[:fc1], tops=[:fc2])
fc3_layer = InnerProductLayer(name="out", output_dim=10, bottoms=[:fc2], weight_init = ConstantInitializer(0), tops=[:out])
loss_layer = SoftmaxLossLayer(name="loss", bottoms=[:out,:label])

# setup dropout for the different layers
drop_input = DropoutLayer(name="drop_in", bottoms=[:data], ratio=0.2)
drop_fc1 = DropoutLayer(name="drop_fc1", bottoms=[:fc1], ratio=0.5)
drop_fc2 = DropoutLayer(name="drop_fc2", bottoms=[:fc2], ratio=0.5)

sys = System(CuDNNBackend())
#sys = System(CPUBackend())
init(sys)

common_layers = [fc1_layer, fc2_layer, fc3_layer]
drop_layers = [drop_input, drop_fc1, drop_fc2]
# put the training net together; the constructor automatically establishes the correct ordering of the layers
net = Net("MNIST-train", sys, [data_layer, common_layers..., drop_layers..., loss_layer])

# the learning rate is multiplied by 0.998 after every epoch (= 600 batches of size 100),
# and the momentum increases linearly from 0.5 to 0.9 over 500 epochs,
# i.e. in steps of (0.9 - 0.5) / 500 = 0.0008 per epoch;
# training runs for 2000 epochs
params = SolverParameters(max_iter=600*2000, regu_coef=0.0,
mom_policy=MomPolicy.Linear(0.5, 0.0008, 600, 0.9),
lr_policy=LRPolicy.Step(0.1, 0.998, 600))
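
# sanity check for the schedules above (a sketch of the intended math, not
# Mocha internals; one epoch = 600 iterations):
#   lr_at(epoch)  = 0.1 * 0.998^epoch               # e.g. lr_at(500) ≈ 0.0368
#   mom_at(epoch) = min(0.5 + 0.0008*epoch, 0.9)    # (0.9 - 0.5) / 500 = 0.0008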
solver = SGD(params)

base_dir = "snapshots_dropout_fc"
# save snapshots every 5000 iterations
add_coffee_break(solver,
Snapshot(base_dir, auto_load=true),
every_n_iter=5000)

# show performance on test data every 600 iterations (one epoch)
# also log everything using the AccumulateStatistics module
data_layer_test = HDF5DataLayer(name="test-data", source=source_fns[2], batch_size=100)
acc_layer = AccuracyLayer(name="test-accuracy", bottoms=[:out, :label], report_error=true)
test_net = Net("MNIST-test", sys, [data_layer_test, common_layers..., acc_layer])
stats = AccumulateStatistics([ValidationPerformance(test_net), TrainingSummary()],
try_load = true, save = true, fname = "$(base_dir)/statistics.h5")
add_coffee_break(solver, stats, every_n_iter=600)
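
# note: report_error=true above makes the accuracy layer also log
# "test-accuracy-error"; multiplying that rate by the 10000 MNIST test
# images gives the absolute error count (~100 at the end of training)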

solve(solver, net)

#Profile.init(int(1e8), 0.001)
#@profile solve(solver, net)
#open("profile.txt", "w") do out
# Profile.print(out)
#end

destroy(net)
destroy(test_net)
shutdown(sys)
6 changes: 5 additions & 1 deletion src/coffee-break.jl
@@ -10,9 +10,12 @@ type Morning <: CoffeeBreakTimeType end
type Evening <: CoffeeBreakTimeType end
end # module CoffeeBreakTime

# statistics returned by every coffee break
typealias StatsDict Dict{String, Real}

abstract Coffee
function init(::Coffee, ::Net) end
function enjoy(::Coffee, ::CoffeeBreakTimeType, ::Net, ::SolverState) end
function enjoy(::Coffee, ::CoffeeBreakTimeType, ::Net, ::SolverState) return StatsDict() end
function destroy(::Coffee, ::Net) end

type CoffeeBreak
@@ -34,4 +37,5 @@ end

include("coffee/training-summary.jl")
include("coffee/validation-performance.jl")
include("coffee/accumulator.jl")
include("coffee/snapshot.jl")
51 changes: 51 additions & 0 deletions src/coffee/accumulator.jl
@@ -0,0 +1,51 @@
using HDF5, JLD

export AccumulateStatistics

type AccumulateStatistics <: Coffee
modules :: Array{Coffee, 1}
stats :: Dict{Integer,StatsDict}
try_load :: Bool
save :: Bool
fname :: String

AccumulateStatistics(modules :: Array{Coffee, 1}; stats = Dict(StatsDict()), try_load = false, save = false, fname = "" ) = new(modules, stats, try_load, save, fname)
end

function init(coffee::AccumulateStatistics, net::Net)
for m in coffee.modules
init(m, net)
end
if coffee.try_load && isfile(coffee.fname)
@warn("Statistics file already exists, trying to merge!")
stats = jldopen(coffee.fname, "r") do file
read(file, "statistics")
end
merge!(coffee.stats, stats)
end
end

function enjoy(coffee::AccumulateStatistics, time::CoffeeBreakTime.Morning, net::Net, state::SolverState)
step = state.iter
if ! haskey(coffee.stats, step)
coffee.stats[step] = StatsDict()
end
for m in coffee.modules
merge!(coffee.stats[step], enjoy(m, time, net, state))
end
end

function enjoy(coffee::AccumulateStatistics, time::CoffeeBreakTime.Evening, net::Net, state::SolverState)
step = state.iter
if ! haskey(coffee.stats, step)
coffee.stats[step] = StatsDict()
end
for m in coffee.modules
merge!(coffee.stats[step], enjoy(m, time, net, state))
end
if coffee.save
jldopen(coffee.fname, "w") do file
write(file, "statistics", coffee.stats)
end
end
end
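
For reference, the saved JLD file holds a single "statistics" entry mapping iteration numbers to the merged stats of each break. Reading it back might look like this (a sketch; the values shown are illustrative):

stats = jldopen("snapshots_dropout_fc/statistics.h5", "r") do file
    read(file, "statistics")
end
# stats is a Dict{Integer,StatsDict}, e.g.
#   600 => ["obj-val" => 0.23, "test-accuracy-accuracy" => 0.957, ...]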
1 change: 1 addition & 0 deletions src/coffee/training-summary.jl
@@ -6,4 +6,5 @@ end
function enjoy(::TrainingSummary, ::CoffeeBreakTime.Evening, ::Net, state::SolverState)
summary = @sprintf("%06d :: TRAIN obj-val = %.8f", state.iter, state.obj_val)
@info(summary)
return StatsDict(["obj-val" => state.obj_val])
end
3 changes: 2 additions & 1 deletion src/coffee/validation-performance.jl
@@ -16,8 +16,9 @@ function enjoy(coffee::ValidationPerformance, ::CoffeeBreakTime.Morning, ::Net,
end
end

show_statistics(coffee.validation_net, title="Performance on Validation Set")
result = show_statistics(coffee.validation_net, title="Performance on Validation Set")
reset_statistics(coffee.validation_net)
return result
end
function destroy(coffee::ValidationPerformance, ::Net)
# We don't destroy here as we didn't construct the network
6 changes: 6 additions & 0 deletions src/layers/accuracy.jl
@@ -1,5 +1,6 @@
@defstruct AccuracyLayer StatLayer (
name :: String = "accuracy",
report_error :: Bool = false,
(bottoms :: Vector{Symbol} = Symbol[], length(bottoms) == 2),
)

@@ -29,6 +30,11 @@ end
function show_statistics(state::AccuracyLayerState)
accuracy = @sprintf("%.4f%%", state.accuracy*100)
@info(" Accuracy (avg over $(state.n_accum)) = $accuracy")
res = Dict(["$(state.layer.name)-accuracy" => state.accuracy])
if state.layer.report_error
res["$(state.layer.name)-error"] = 1 - state.accuracy
end
return res
end

function forward(sys::System{CPUBackend}, state::AccuracyLayerState, inputs::Vector{Blob})
4 changes: 3 additions & 1 deletion src/net.jl
@@ -47,16 +47,18 @@ function destroy(net::Net)
end

function show_statistics(net::Net; title="Network Statistics")
res = StatsDict()
@info("")
@info("## $title")
@info("---------------------------------------------------------")
for i = 1:length(net.layers)
if isa(net.layers[i], StatLayer)
show_statistics(net.states[i])
merge!(res, show_statistics(net.states[i]))
end
end
@info("---------------------------------------------------------")
@info("")
return res
end
function reset_statistics(net::Net)
for i = 1:length(net.layers)
4 changes: 3 additions & 1 deletion src/solvers.jl
@@ -143,9 +143,11 @@ end
# General utilities that could be used by all solvers
############################################################
function update_solver_state(state::SolverState, obj_val :: Float64)
state.iter += 1
state.obj_val = obj_val
end
function update_solver_time(state::SolverState)
state.iter += 1
end
function stop_condition_satisfied(solver::Solver, state::SolverState, net::Net)
if state.iter >= solver.params.max_iter
return true
1 change: 1 addition & 0 deletions src/solvers/sgd.jl
@@ -44,6 +44,7 @@ function solve(sgd::SGD, net::Net)

update_solver_state(solver_state, obj_val)
check_coffee_breaks(CoffeeBreakTime.Evening(), sgd, solver_state, net)
update_solver_time(solver_state)
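# the iteration counter is now advanced only after the Evening coffee
# break, so accumulated statistics are keyed by the iteration at which
# obj_val was computed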

if stop_condition_satisfied(sgd, solver_state, net)
break
114 changes: 114 additions & 0 deletions tools/plot_statistics.jl
@@ -0,0 +1,114 @@
using HDF5, JLD
using ArgParse

function read_stats(fname)
stats = jldopen(fname, "r") do file
read(file, "statistics")
end
return stats
end

function number_stats(fnames, names)
res = Dict()
n = 1
for (i, fname) in enumerate(fnames)
for (j, name) in enumerate(names[i])
res[n] = (i, fname, name)
n += 1
end
end
return res
end

function list_stats(numbered_names)
println("Listing available statistics")
for k in sort(collect(keys(numbered_names)))
(_, fname, name) = numbered_names[k]
println(" $k : $fname/$name")
end
println("Select statistics to plot using -i and specify the numbers 1-$(length(numbered_names)) seperated with ,")
end

function create_safe_files(fnames, to_tmp)
# copy to temporary file if requested
if to_tmp
stats_files = [tempname() for fname in fnames]
for (tmpfile,fname) in zip(stats_files, fnames)
cp(fname, tmpfile)
end
return stats_files
else
return fnames
end
end

get_unique_names(stats) = unique(vcat(map(collect, map(keys, values(stats)))...))

s = ArgParseSettings()
@add_arg_table s begin
"--idx", "-i"
help = "a list of indices seperated by , denoting the statistics that should be plotted"
arg_type = String
default = ""
"--list", "-l"
help = "list all available statistics for plotting"
action = :store_true
"--tmp", "-t"
help = "copy the statistics file to a temporary location before plotting (useful when plotting during training)"
action = :store_true
"statistics_filenames"
nargs = '*'
help = "the filenames of the statistics hdf5 files"
required = true
end

# first parse arguments and read statistics files
parsed_args = parse_args(ARGS, s)
filenames = unique(parsed_args["statistics_filenames"])
stats_files = create_safe_files(filenames, parsed_args["tmp"])
all_stats = map(read_stats, stats_files)
# get all unique statistic names that were logged in each file
names = map(get_unique_names, all_stats)
# and assign a number to each
numbered_names = number_stats(filenames, names)

# process according to arguments
using PyPlot
if parsed_args["list"] || parsed_args["idx"] == ""
list_stats(numbered_names)
end

if parsed_args["idx"] != ""
selected_ind = map(int, split(parsed_args["idx"], ","))
if any([x < 1 || x > length(numbered_names) for x in selected_ind])
list_stats(numbered_names)
error("Invalid index in your list: $selected_ind. Make sure the indices are between 1 and $(length(numbered_names))")
end

figure()
for ind in selected_ind
(stats_num, fname, key) = numbered_names[ind]
stats = all_stats[stats_num]

N = length(stats)
x = zeros(N)
y = zeros(N)
for (i, iter) in enumerate(sort(collect(keys(stats))))
x[i] = iter
y[i] = stats[iter][key]
end
plot(x, y, label="$(fname)/$(key)")
end
legend()

print("Hit <enter> to continue")
readline()
close()
end

# delete temporary file if it was created
if parsed_args["tmp"]
for f in stats_files
rm(f)
end
end
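
A hypothetical invocation of the tool above (the statistics file path is illustrative):

    julia tools/plot_statistics.jl --list snapshots_dropout_fc/statistics.h5
    julia tools/plot_statistics.jl --tmp -i 1,3 snapshots_dropout_fc/statistics.h5

The first call lists the available statistics with their indices; the second plots entries 1 and 3 after copying the file to a temporary location, which is useful while training is still writing to it.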
