-
Notifications
You must be signed in to change notification settings - Fork 254
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6 from stokasto/accumulate_statistics
Add coffee break to allow for accumulation of statistics and reproduce ~100 errors on MNIST
- Loading branch information
Showing
10 changed files
with
288 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# HDF5 inputs and the list files that are generated for the data layers.
hdf5_fns = ["data/train.hdf5", "data/test.hdf5"]
source_fns = ["data/train.txt", "data/test.txt"]

# Every HDF5 file has to exist already; for each one the matching list file
# is (re)written so that it contains the HDF5 path on a single line.
for (hdf5_fn, source_fn) in zip(hdf5_fns, source_fns)
    if isfile(hdf5_fn)
        open(source_fn, "w") do s
            println(s, hdf5_fn)
        end
    else
        println("Data not found, use get-mnist.sh to generate HDF5 data")
        exit(1)
    end
end
|
||
#ENV["MOCHA_USE_NATIVE_EXT"] = "true" | ||
#ENV["OMP_NUM_THREADS"] = 1 | ||
#blas_set_num_threads(1) | ||
ENV["MOCHA_USE_CUDA"] = "true" | ||
|
||
using Mocha | ||
|
||
############################################################ | ||
# This is an example script for training a fully connected | ||
# network with dropout on mnist. | ||
# | ||
# The network size is 784-1200-1200-10 with ReLU units | ||
# in the hidden layers and a softmax output layer. | ||
# The parameters for training the network were chosen | ||
# to reproduce the results from the original dropout paper: | ||
# http://arxiv.org/abs/1207.0580 | ||
# and the corresponding newer JMLR paper: | ||
# http://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf | ||
# | ||
# Our parameters slightly differ. This is mainly due to the | ||
# fact that in the original dropout paper the weights are scaled | ||
# by 0.5 after training whereas we scale them by 2 during training. | ||
# | ||
# The settings in this script should currently produce a model that | ||
# gets 100 errors (or 99 % accuracy) on the test set | ||
# if you run it for the whole 2000 epochs (=600*2000 steps). | ||
# This is slightly better than, but well within the error | ||
# bars of the JMLR paper. | ||
############################################################ | ||
|
||
|
||
# fix the random seed to make results reproducible | ||
srand(12345678) | ||
|
||
# training net layers: an HDF5 data layer (batches of 100) followed by fc1 -> fc2 -> out; | ||
# each layer reads the blobs listed in `bottoms` and writes the blobs listed in `tops` | ||
data_layer = HDF5DataLayer(name="train-data", source=source_fns[1], batch_size=100) | ||
fc1_layer = InnerProductLayer(name="fc1", output_dim=1200, neuron=Neurons.ReLU(), weight_init = GaussianInitializer(std=0.01), bottoms=[:data], tops=[:fc1]) | ||
fc2_layer = InnerProductLayer(name="fc2", output_dim=1200, neuron=Neurons.ReLU(), weight_init = GaussianInitializer(std=0.01), bottoms=[:fc1], tops=[:fc2]) | ||
# the output layer has 10 units and is initialized with constant-zero weights | ||
fc3_layer = InnerProductLayer(name="out", output_dim=10, bottoms=[:fc2], weight_init = ConstantInitializer(0), tops=[:out]) | ||
# the softmax loss takes the :out and :label blobs as input | ||
loss_layer = SoftmaxLossLayer(name="loss", bottoms=[:out,:label]) | ||
|
||
# setup dropout for the different layers: ratio 0.2 on the input blob, 0.5 on both hidden blobs | ||
drop_input = DropoutLayer(name="drop_in", bottoms=[:data], ratio=0.2) | ||
drop_fc1 = DropoutLayer(name="drop_fc1", bottoms=[:fc1], ratio=0.5) | ||
drop_fc2 = DropoutLayer(name="drop_fc2", bottoms=[:fc2], ratio=0.5) | ||
|
||
# backend selection: the CuDNN backend is active, the CPU backend line is kept commented out | ||
sys = System(CuDNNBackend()) | ||
#sys = System(CPUBackend()) | ||
init(sys) | ||
|
||
# layers shared by the training net and the test net that is built further below | ||
common_layers = [fc1_layer, fc2_layer, fc3_layer] | ||
# the dropout layers are only added to the training net | ||
drop_layers = [drop_input, drop_fc1, drop_fc2] | ||
# put training net together, note that the correct ordering will automatically be established by the constructor | ||
net = Net("MNIST-train", sys, [data_layer, common_layers..., drop_layers..., loss_layer]) | ||
|
||
# we let the learning rate decrease by 0.998 in each epoch (=600 batches of size 100) | ||
# and let the momentum increase linearly from 0.5 to 0.9 over 500 epochs | ||
# which is equivalent to an increase step of 0.0008 | ||
# training is done for 2000 epochs | ||
# max_iter = 600 iterations per epoch * 2000 epochs | ||
params = SolverParameters(max_iter=600*2000, regu_coef=0.0, | ||
mom_policy=MomPolicy.Linear(0.5, 0.0008, 600, 0.9), | ||
lr_policy=LRPolicy.Step(0.1, 0.998, 600)) | ||
solver = SGD(params) | ||
|
||
# directory used for both the snapshots and the statistics file | ||
base_dir = "snapshots_dropout_fc" | ||
# save snapshots every 5000 iterations | ||
add_coffee_break(solver, | ||
Snapshot(base_dir, auto_load=true), | ||
every_n_iter=5000) | ||
|
||
# show performance on test data every 600 iterations (one epoch) | ||
# also log everything using the AccumulateStatistics module | ||
# the test net reuses common_layers (without the dropout layers) and ends in an accuracy layer | ||
data_layer_test = HDF5DataLayer(name="test-data", source=source_fns[2], batch_size=100) | ||
acc_layer = AccuracyLayer(name="test-accuracy", bottoms=[:out, :label], report_error=true) | ||
test_net = Net("MNIST-test", sys, [data_layer_test, common_layers..., acc_layer]) | ||
# try_load merges an already existing statistics file on init, | ||
# save rewrites the statistics file at every evening coffee break | ||
stats = AccumulateStatistics([ValidationPerformance(test_net), TrainingSummary()], | ||
try_load = true, save = true, fname = "$(base_dir)/statistics.h5") | ||
add_coffee_break(solver, stats, every_n_iter=600) | ||
|
||
# run the solver on the training net | ||
solve(solver, net) | ||
|
||
# disabled profiling variant of the solve call; it writes the profile to profile.txt | ||
#Profile.init(int(1e8), 0.001) | ||
#@profile solve(solver, net) | ||
#open("profile.txt", "w") do out | ||
# Profile.print(out) | ||
#end | ||
|
||
# tear down both nets and the backend | ||
destroy(net) | ||
destroy(test_net) | ||
shutdown(sys) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
using HDF5, JLD | ||
|
||
export AccumulateStatistics | ||
|
||
# Coffee break that wraps a list of other coffee breaks and accumulates the
# statistics they return, keyed by solver iteration.
#
# Fields:
#   modules  - wrapped coffee breaks; each one is initialized and enjoyed in turn
#   stats    - solver iteration => statistics merged from all modules
#   try_load - when true, `init` merges the contents of an existing `fname`
#   save     - when true, the evening break writes `stats` to `fname`
#   fname    - path of the statistics file
type AccumulateStatistics <: Coffee
    modules :: Array{Coffee, 1}
    stats :: Dict{Integer,StatsDict}
    try_load :: Bool
    save :: Bool
    fname :: String

    # `stats` defaults to an empty dictionary of the field's own type instead of
    # relying on an implicit conversion of a dictionary built from a StatsDict
    AccumulateStatistics(modules :: Array{Coffee, 1}; stats = Dict{Integer,StatsDict}(), try_load = false, save = false, fname = "" ) = new(modules, stats, try_load, save, fname)
end
|
||
# Initialize every wrapped coffee break and, when `try_load` is set and the
# statistics file exists, merge the stored "statistics" entry into
# `coffee.stats`.
function init(coffee::AccumulateStatistics, net::Net)
    for brk in coffee.modules
        init(brk, net)
    end
    if !(coffee.try_load && isfile(coffee.fname))
        return
    end
    @warn("Statistics file already exists, trying to merge!")
    loaded = jldopen(io -> read(io, "statistics"), coffee.fname, "r")
    merge!(coffee.stats, loaded)
end
|
||
# Morning break: merge whatever each wrapped coffee break returns into the
# statistics entry of the current solver iteration, creating the entry first
# if it does not exist yet.
function enjoy(coffee::AccumulateStatistics, time::CoffeeBreakTime.Morning, net::Net, state::SolverState)
    iter = state.iter
    haskey(coffee.stats, iter) || (coffee.stats[iter] = StatsDict())
    for brk in coffee.modules
        returned = enjoy(brk, time, net, state)
        merge!(coffee.stats[iter], returned)
    end
end
|
||
# Evening break: merge whatever each wrapped coffee break returns into the
# statistics entry of the current solver iteration and, when `save` is set,
# rewrite the statistics file with everything accumulated so far.
function enjoy(coffee::AccumulateStatistics, time::CoffeeBreakTime.Evening, net::Net, state::SolverState)
    iter = state.iter
    haskey(coffee.stats, iter) || (coffee.stats[iter] = StatsDict())
    for brk in coffee.modules
        returned = enjoy(brk, time, net, state)
        merge!(coffee.stats[iter], returned)
    end
    if !coffee.save
        return
    end
    jldopen(coffee.fname, "w") do io
        write(io, "statistics", coffee.stats)
    end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
using HDF5, JLD | ||
using ArgParse | ||
|
||
# Open the JLD file `fname` read-only and return its "statistics" entry.
function read_stats(fname)
    return jldopen(file -> read(file, "statistics"), fname, "r")
end
|
||
# Assign a running number (starting at 1) to every statistic of every file.
#
# `fnames` is the list of file names and `names[i]` the statistic names found
# in the i-th file. Returns a dictionary mapping each number to the tuple
# (index of the file, file name, statistic name).
function number_stats(fnames, names)
    res = Dict{Int,Any}()
    n = 1
    for (i, fname) in enumerate(fnames)
        for name in names[i]
            res[n] = (i, fname, name)
            n += 1
        end
    end
    return res
end
|
||
# Print the numbered statistics in ascending order of their number, followed
# by a hint on how to select them for plotting.
function list_stats(numbered_names)
    println("Listing available statistics")
    ordered = sort(collect(keys(numbered_names)))
    for k in ordered
        entry = numbered_names[k]
        fname = entry[2]
        name = entry[3]
        println(" $k : $fname/$name")
    end
    println("Select statistics to plot using -i and specify the numbers 1-$(length(numbered_names)) seperated with ,")
end
|
||
# Return the files that should actually be read: `fnames` itself, or, when
# `to_tmp` is true, freshly made temporary copies of them (one per file, in
# the same order).
function create_safe_files(fnames, to_tmp)
    to_tmp || return fnames
    stats_files = [tempname() for fname in fnames]
    for (src, dst) in zip(fnames, stats_files)
        cp(src, dst)
    end
    return stats_files
end
|
||
# Return the distinct statistic names that occur in any entry of `stats`
# (a dictionary whose values are themselves dictionaries keyed by name).
# The per-entry key lists are concatenated with `reduce` rather than being
# splatted into `vcat`, since `stats` has one entry per logged iteration.
function get_unique_names(stats)
    isempty(stats) && return Any[]
    per_entry = [collect(keys(d)) for d in values(stats)]
    return unique(reduce(vcat, per_entry))
end
|
||
# command line interface: -i selects statistics by number, -l lists them, | ||
# -t works on temporary copies; the positional arguments are the statistics files | ||
s = ArgParseSettings() | ||
@add_arg_table s begin | ||
"--idx", "-i" | ||
help = "a list of indices seperated by , denoting the statistics that should be plotted" | ||
arg_type = String | ||
default = "" | ||
"--list", "-l" | ||
help = "list all available statistics for plotting" | ||
action = :store_true | ||
"--tmp", "-t" | ||
help = "copy the statistics file to a temporary location before plotting (useful when plotting during training)" | ||
action = :store_true | ||
"statistics_filenames" | ||
nargs = '*' | ||
help = "the filenames of the statistics hdf5 files" | ||
required = true | ||
end | ||
|
||
# first parse arguments and read statistics files | ||
parsed_args = parse_args(ARGS, s) | ||
# duplicate file names are dropped | ||
filenames = unique(parsed_args["statistics_filenames"]) | ||
# stats_files are the paths that are actually read (temporary copies when --tmp is given) | ||
stats_files = create_safe_files(filenames, parsed_args["tmp"]) | ||
all_stats = map(read_stats, stats_files) | ||
# get all unique statistic names that were logged in each file | ||
names = map(get_unique_names, all_stats) | ||
# and assign a number to each (the numbering uses the original file names, not the copies) | ||
numbered_names = number_stats(filenames, names) | ||
|
||
# process according to arguments | ||
using PyPlot | ||
# list the statistics when asked to, or when nothing was selected for plotting | ||
if parsed_args["list"] || parsed_args["idx"] == "" | ||
list_stats(numbered_names) | ||
end | ||
|
||
# Plot the statistics selected with -i (if any). The temporary copies made
# for --tmp are removed in the `finally` clause, so they are also cleaned up
# when the index validation below raises an error.
try
    if parsed_args["idx"] != ""
        selected_ind = map(int, split(parsed_args["idx"], ","))
        # numbered_names is keyed 1..length(numbered_names), so 0 is invalid as well
        if any([x < 1 || x > length(numbered_names) for x in selected_ind])
            list_stats(numbered_names)
            error("Invalid index in your list : $selected_ind make sure the indices are between 1 and $(length(numbered_names))")
        end

        figure()
        for ind in selected_ind
            (stats_num, fname, key) = numbered_names[ind]
            stats = all_stats[stats_num]

            # the statistic names are a union over all iterations, so a given
            # statistic need not be present at every logged iteration; plot
            # only the iterations that recorded it
            iters = sort(filter(it -> haskey(stats[it], key), collect(keys(stats))))
            x = zeros(length(iters))
            y = zeros(length(iters))
            for (i, iter) in enumerate(iters)
                x[i] = iter
                y[i] = stats[iter][key]
            end
            plot(x, y, label="$(fname)/$(key)")
        end
        legend()

        # keep the figure open until the user confirms
        print("Hit <enter> to continue")
        readline()
        close()
    end
finally
    # delete temporary files if they were created
    if parsed_args["tmp"]
        for f in stats_files
            rm(f)
        end
    end
end