Commit
add comments to dropout example, plot_statistics.jl now supports multiple files for plotting
stokasto committed Nov 29, 2014
1 parent 09ae436 commit b7a6053
Showing 2 changed files with 82 additions and 29 deletions.
37 changes: 33 additions & 4 deletions examples/mnist/mnist_dropout_fc.jl
@@ -18,6 +18,30 @@ ENV["MOCHA_USE_CUDA"] = "true"

using Mocha

############################################################
# This is an example script for training a fully connected
# network with dropout on MNIST.
#
# The network size is 784-1200-1200-10 with ReLU units
# in the hidden layers and a softmax output layer.
# The parameters for training the network were chosen
# to reproduce the results from the original dropout paper:
# http://arxiv.org/abs/1207.0580
# and the corresponding newer JMLR paper:
# http://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf
#
# Our parameters differ slightly, mainly because the original
# dropout paper scales the weights by 0.5 after training,
# whereas we scale them by 2 during training.
#
# The settings in this script should currently produce a model that
# gets 100 errors (i.e. 99% accuracy) on the test set
# if you run it for the full 2000 epochs (= 600*2000 steps).
# This is slightly better than, but well within the error
# bars of, the results in the JMLR paper.
############################################################
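# As an aside (illustrative sketch, not part of the original script): with a
# dropout probability of p = 0.5, scaling the kept activations by 1/p = 2
# during training ("inverted dropout", used here) leaves their expectation
# unchanged, so the weights need no rescaling at test time:
#
#   x = rand(10)                  # some activations
#   mask = rand(10) .> 0.5        # keep each unit with probability 0.5
#   x_train = (mask .* x) * 2.0   # inverted dropout; E[x_train] == x
#   x_test  = x                   # test time: no mask, no weight rescaling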


# fix the random seed to make results reproducible
srand(12345678)

@@ -41,8 +65,12 @@ drop_layers = [drop_input, drop_fc1, drop_fc2]
# put the training net together; the correct layer ordering is established automatically by the constructor
net = Net("MNIST-train", sys, [data_layer, common_layers..., drop_layers..., loss_layer])

# we decay the learning rate by a factor of 0.998 after each epoch (= 600 batches of size 100)
# and let the momentum increase linearly from 0.5 to 0.9 over 500 epochs,
# which corresponds to an increase of 0.0008 per epoch;
# training runs for 2000 epochs
params = SolverParameters(max_iter=600*2000, regu_coef=0.0,
mom_policy=MomPolicy.Linear(0.5, 0.0008, 600, 0.9),
lr_policy=LRPolicy.Step(0.1, 0.998, 600))
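# for reference (illustration only, not part of the original script):
# LRPolicy.Step(0.1, 0.998, 600) multiplies the base rate 0.1 by 0.998 every
# 600 iterations (one epoch), so after e epochs the rate is 0.1 * 0.998^e,
# i.e. roughly 0.0018 after the full 2000 epochs; MomPolicy.Linear adds 0.0008
# every epoch until the momentum saturates at 0.9, after (0.9 - 0.5) / 0.0008 = 500 epochs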
solver = SGD(params)

@@ -53,11 +81,12 @@ add_coffee_break(solver,
every_n_iter=5000)

# show performance on test data every 600 iterations (one epoch)
# also log everything using the AccumulateStatistics module
data_layer_test = HDF5DataLayer(name="test-data", source=source_fns[2], batch_size=100)
acc_layer = AccuracyLayer(name="test-accuracy", bottoms=[:out, :label], report_error=true)
test_net = Net("MNIST-test", sys, [data_layer_test, common_layers..., acc_layer])
stats = AccumulateStatistics([ValidationPerformance(test_net), TrainingSummary()],
try_load = true, save = true, fname = "$(base_dir)/statistics.h5")
add_coffee_break(solver, stats, every_n_iter=600)
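# the statistics saved to statistics.h5 can later be visualized with
# tools/plot_statistics.jl (see the second file in this commit)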

solve(solver, net)
74 changes: 49 additions & 25 deletions tools/plot_statistics.jl
@@ -8,25 +8,42 @@ function read_stats(fname)
return stats
end

function number_stats(fnames, names)
res = Dict()
n = 1
for (i, fname) in enumerate(fnames)
for (j, name) in enumerate(names[i])
res[n] = (i, fname, name)
n += 1
end
end
return res
end
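# for example (hypothetical file and statistic names), calling
# number_stats(["a.h5", "b.h5"], [["accuracy"], ["obj-val", "accuracy"]])
# yields 1 => (1, "a.h5", "accuracy"), 2 => (2, "b.h5", "obj-val") and
# 3 => (2, "b.h5", "accuracy"), i.e. one global index over all (file, statistic) pairs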

function list_stats(numbered_names)
println("Listing available statistics")
for k in sort(collect(keys(numbered_names)))
(_, fname, name) = numbered_names[k]
println(" $k : $fname/$name")
end
println("Select statistics to plot using -i and specify the numbers 1-$(length(names)) seperated with ,")
println("Select statistics to plot using -i and specify the numbers 1-$(length(numbered_names)) seperated with ,")
end

function create_safe_files(fnames, to_tmp)
# copy to temporary files if requested
if to_tmp
stats_files = [tempname() for fname in fnames]
for (tmpfile,fname) in zip(stats_files, fnames)
cp(fname, tmpfile)
end
return stats_files
else
return fnames
end
end

get_unique_names(stats) = unique(vcat(map(collect, map(keys, values(stats)))...))
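# stats maps iteration => Dict(statistic_name => value); this collects the
# union of all statistic names that occur at any logged iteration of one file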

s = ArgParseSettings()
@add_arg_table s begin
"--idx", "-i"
@@ -39,43 +56,48 @@ s = ArgParseSettings()
"--tmp", "-t"
help = "copy the statistics file to a temporary location before plotting (useful when plotting during training)"
action = :store_true
"statistics_filename"
help = "the filename of the statistics hdf5 file"
"statistics_filenames"
nargs = '*'
help = "the filenames of the statistics hdf5 files"
required = true
end
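# usage sketch (hypothetical file names):
#   julia plot_statistics.jl --list stats1.h5 stats2.h5      # list the available statistics
#   julia plot_statistics.jl -t -i 1,3 stats1.h5 stats2.h5   # plot statistics 1 and 3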

# first parse arguments and read statistics files
parsed_args = parse_args(ARGS, s)
filenames = unique(parsed_args["statistics_filenames"])
stats_files = create_safe_files(filenames, parsed_args["tmp"])
all_stats = map(read_stats, stats_files)
# get all unique statistic names that were logged in each file
names = map(get_unique_names, all_stats)
# and assign a number to each
numbered_names = number_stats(filenames, names)

# process according to arguments
using PyPlot
if parsed_args["list"] || parsed_args["idx"] == ""
list_stats(numbered_names)
end

if parsed_args["idx"] != ""
selected_ind = map(int, split(parsed_args["idx"], ","))
if any([x < 0 || x > length(names) for x in selected_ind])
list_statistics(names)
error("Invalid index in your list : $selected_ind make sure the indices are between 1 and $(length(names))")
if any([x < 0 || x > length(numbered_names) for x in selected_ind])
list_stats(numbered_names)
error("Invalid index in your list : $selected_ind make sure the indices are between 1 and $(length(numbered_names))")
end


figure()
for ind in selected_ind
(stats_num, fname, key) = numbered_names[ind]
stats = all_stats[stats_num]

N = length(stats)
x = zeros(N)
y = zeros(N)
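# fill x with the sorted iteration numbers and y with the statistic's value at each iteration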
for (i, iter) in enumerate(sort(collect(keys(stats))))
x[i] = iter
y[i] = stats[iter][key]
end
plot(x, y, label="$(fname)/$(key)")
end
legend()

@@ -86,5 +108,7 @@ end

# delete temporary file if it was created
if parsed_args["tmp"]
for f in stats_files
rm(f)
end
end
