Add norm constraints #8

Merged
9 commits merged on Dec 1, 2014
10 changes: 10 additions & 0 deletions examples/mnist/README.md
@@ -0,0 +1,10 @@
The two scripts in this folder can be used to train example models on MNIST.

### mnist.jl
This trains a LeNet-like convolutional neural network on the MNIST dataset (see [the LeNet paper](http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf) for a description of the model).

### mnist_dropout_fc.jl
This trains a fully connected two-layer neural network on MNIST and reproduces the results of the [original dropout paper](http://arxiv.org/abs/1207.0580).
It should currently bottom out at 0.95% error (99.05% accuracy) on the test set.

NOTE: these scripts currently do not select learning parameters using a validation set, so the reported numbers should be taken with a grain of salt.
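A minimal sketch of how one might run these scripts from the Julia REPL, assuming Mocha is installed and the MNIST HDF5 files the scripts read have already been prepared (the exact invocation below is an assumption, not something stated in this PR):

```julia
cd("examples/mnist")             # from the repository root
include("mnist_dropout_fc.jl")   # trains the fully connected dropout network
```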
27 changes: 21 additions & 6 deletions examples/mnist/mnist_dropout_fc.jl
@@ -35,23 +35,38 @@ using Mocha
# by 0.5 after training whereas we scale them by 2 during training.
#
# The settings in this script should currently produce a model that
# gets 100 errors (or 99 % accuracy) on the test set
# gets 95 errors (or 99.05 % accuracy) on the test set
# if you run it for the whole 2000 epochs (=600*2000 steps).
# This is slightly better than, but well within the error
# bars of the JMLR paper.
# This is slightly better than the results of the JMLR paper.
# This difference is likely due to slight differences in the
# learning parameters. Also note that our hyperparameters
# are not chosen using a validation set, as one would do
# for a paper.
############################################################
Owner:
I just want to add a comment here that the script used to convert the MNIST dataset to HDF5 format does some randomization of the order of the data samples, and I didn't fix the random seed there. So other people might not get exactly the same results if they prepare their HDF5 MNIST dataset separately. A fix might be to simply fix the random seed in the data conversion script.

Contributor Author:
Agreed. I think I'll simply fix the random seed in the conversion script then, as it is quite useful to be able to reproduce exact results. I'll still add a comment here, though, since different GPUs/CUDA versions etc. could potentially lead to small changes as well.



# fix the random seed to make results reproducible
srand(12345678)

data_layer = HDF5DataLayer(name="train-data", source=source_fns[1], batch_size=100)
fc1_layer = InnerProductLayer(name="fc1", output_dim=1200, neuron=Neurons.ReLU(), weight_init = GaussianInitializer(std=0.01), bottoms=[:data], tops=[:fc1])
fc2_layer = InnerProductLayer(name="fc2", output_dim=1200, neuron=Neurons.ReLU(), weight_init = GaussianInitializer(std=0.01), bottoms=[:fc1], tops=[:fc2])
fc3_layer = InnerProductLayer(name="out", output_dim=10, bottoms=[:fc2], weight_init = ConstantInitializer(0), tops=[:out])
# the hidden fully connected layers use a ReLU activation; the weights can be given a constraint on their L2 norm via weight_cons
fc1_layer = InnerProductLayer(name="fc1", output_dim=1200, neuron=Neurons.ReLU(),
weight_init = GaussianInitializer(std=0.01),
#weight_cons = L2Cons(4.5),
bottoms=[:data], tops=[:fc1])
fc2_layer = InnerProductLayer(name="fc2", output_dim=1200, neuron=Neurons.ReLU(),
weight_init = GaussianInitializer(std=0.01),
weight_cons = L2Cons(4.5),
bottoms=[:fc1], tops=[:fc2])
fc3_layer = InnerProductLayer(name="out", output_dim=10, bottoms=[:fc2],
weight_init = ConstantInitializer(0),
weight_cons = L2Cons(4.5),
tops=[:out])
loss_layer = SoftmaxLossLayer(name="loss", bottoms=[:out,:label])

# setup dropout for the different layers
# we use 20% dropout on the inputs and 50% dropout in the hidden layers
# as these values were previously found to be good defaults
drop_input = DropoutLayer(name="drop_in", bottoms=[:data], ratio=0.2)
drop_fc1 = DropoutLayer(name="drop_fc1", bottoms=[:fc1], ratio=0.5)
drop_fc2 = DropoutLayer(name="drop_fc2", bottoms=[:fc2], ratio=0.5)
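The header comment at the top of this hunk (scaling by 2 during training instead of halving the weights afterwards) is the usual "inverted" dropout formulation. A minimal sketch of that idea, assuming a simple element-wise mask; this is only an illustration, not Mocha's actual DropoutLayer implementation:

```julia
# Inverted dropout: keep each unit with probability 1 - ratio and scale the
# surviving activations by 1/(1 - ratio), so no rescaling is needed at test time.
function dropout_forward(x::Vector{Float64}, ratio::Float64)
    mask = rand(length(x)) .>= ratio      # 1 = keep, 0 = drop
    return (x .* mask) ./ (1.0 - ratio)   # e.g. scale by 2 when ratio = 0.5
end
```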
2 changes: 2 additions & 0 deletions src/Mocha.jl
@@ -38,10 +38,12 @@ end

include("initializers.jl")
include("regularizers.jl")
include("constraints.jl")
include("neurons.jl")

if Config.use_cuda
include("cuda/regularizers.jl")
include("cuda/constraints.jl")
include("cuda/neurons.jl")
end

61 changes: 61 additions & 0 deletions src/constraints.jl
@@ -0,0 +1,61 @@
export Constraint, NoCons, L2Cons
export constrain!

abstract Constraint

immutable NoCons <: Constraint
coefficient :: FloatingPoint # not used, just for consistent API
every_n_iter :: Int # also not used
end
NoCons() = NoCons(0.0, 0)

immutable L2Cons <: Constraint
coefficient :: FloatingPoint
every_n_iter :: Int
end
L2Cons(coefficient) = L2Cons(coefficient, 1)

############################################################
# No constraint
############################################################
function constrain!(sys::System, cons::NoCons, param)
# do nothing if no constraints apply
end

############################################################
# L2 norm constraint on the weights
############################################################

function apply_l2_cons!{T <: FloatingPoint}(sys::System{CPUBackend}, blob::CPUBlob{T},
coef::FloatingPoint, ninputs::Integer, nunits::Integer)
param = reshape(blob.data, (ninputs, nunits))
# we constrain each column vector
@simd for i = 1:nunits
Owner:
Not sure if you are getting a warning that the metadata for @simd cannot be attached. My reading of the Julia documentation is that @simd only works for very simple loops.

Contributor Author:
Oops, that instruction is left over from a first attempt where I computed the constraint using two for loops (which turned out to be slightly slower); I'll fix that. Strangely, I did not get a warning; I should maybe check my config.

# compute norm and scale using blas
norm = vecnorm(param[:, i])
if norm > coef
scale_factor = (1. / norm) * coef
offset = sizeof(T) * (i-1) * ninputs
BLAS.scal!(ninputs, convert(T, scale_factor), convert(Ptr{T}, param) + offset, 1)
end
end
end

# this constrains a given blob along the last dimension that is not of size 1
# it is a bit ugly but was the easiest way to implement it for now
function constrain!(sys :: System, cons :: L2Cons, param :: Blob)
W = size(param, 1) # source dim in fully connected
H = size(param, 2) # target dim in fully connected
C = size(param, 3)
N = size(param, 4) # number of filters in convolution
if H == 1 && N == 1 && C == 1
# only width left ... this is a bias ... let's constrain that
apply_l2_cons!(sys, param, cons.coefficient, W, 1)
elseif N == 1 && C == 1
# we have only one channel and num -> constrain target dim
apply_l2_cons!(sys, param, cons.coefficient, W, H)
else
# constrain along N -> e.g. the number of filters for a convolutional layer
apply_l2_cons!(sys, param, cons.coefficient, W*H*C, N)
end
end
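For comparison, here is a plain-Julia reference of the column-wise max-norm projection the CPU path above performs (a sketch only, not part of the PR): each column of the reshaped (ninputs x nunits) matrix is rescaled only when its L2 norm exceeds the coefficient.

```julia
# Reference: constrain every column of W to have L2 norm <= coef.
function l2_cons_reference(W::Matrix{Float64}, coef::Float64)
    for j = 1:size(W, 2)
        n = vecnorm(W[:, j])        # L2 norm of column j
        if n > coef
            W[:, j] *= coef / n     # project back onto the norm ball
        end
    end
    return W
end
```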
53 changes: 53 additions & 0 deletions src/cuda/constraints.jl
@@ -0,0 +1,53 @@
############################################################
# apply L2 constraint
############################################################

function apply_l2_cons!{T <: FloatingPoint}(sys::System{CuDNNBackend}, blob::CuTensorBlob{T},
coef::FloatingPoint, ninputs::Integer, nunits::Integer)
# we allocate a bit of temporary memory here
# we could instead also store this in the cons type
# but that would double the memory footprint of a network
# which is prohibitive for large models!
# --
# NOTE stokasto:
# an even better alternative would be to write
# a dedicated kernel for normalization
# but since the weight matrices are usually small
# I am not sure whether that will pay off especially
# since the constraints only apply rarely
# I also tested using cuBLAS nrm2 (cublasSnrm2) but that was way slower
# than computing all norms using gemm
@assert(ninputs*nunits == length(blob))
width, height, channels, num = size(blob)
# allocate
tmpA = make_blob(sys.backend, T, size(blob)...)
onesv = make_blob(sys.backend, ones(T, ninputs, 1, 1, 1))
tmp_norm = make_blob(sys.backend, T, (nunits, 1, 1, 1))
tmp_norm_host = zeros(T, nunits)
# copy blob so that it stays intact
copy!(tmpA, blob)

# we compute the squared norm of all columns of matrix A as:
# ||A||^2 = transpose(A .* A) * ones(size(A))
# square blob inplace
CuVec.mul!(sys, T, tmpA.ptr.p, tmpA.ptr.p, width*height, channels, num)
# and reduce via a gemm against a vector of ones (effectively a gemv) to get the sums
CuBLAS.gemm(sys.backend.cublas_ctx, CuBLAS.OP_T, CuBLAS.OP_N, nunits, 1, ninputs,
convert(T, 1), tmpA.ptr, ninputs, onesv.ptr, ninputs, convert(T, 0), tmp_norm.ptr, nunits)
# copy back for doing the norm size check on the cpu
copy!(tmp_norm_host, tmp_norm)

for i = 1:nunits
# calculate offset in blob vector
offset = sizeof(T) * (i-1) * ninputs
off_ptr = CuPtr(blob.ptr.p + offset)
@inbounds norm = sqrt(tmp_norm_host[i])
if norm > coef
scale_factor = (1. / norm) * coef
CuBLAS.scal(sys.backend.cublas_ctx, ninputs, convert(T, scale_factor), off_ptr, 1)
end
end
destroy(tmpA)
destroy(onesv)
destroy(tmp_norm)
end
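As a quick sanity check of the identity the GPU path relies on (per-column squared norms of A obtained as (A .* A)' * ones), here is a small CPU-side sketch in the same Julia 0.3-era style as the rest of the code; the matrix sizes are arbitrary:

```julia
A = rand(6, 4)                    # 6 "inputs" x 4 "units"
sq_norms = (A .* A)' * ones(6)    # per-column sums of squares via one matrix-vector product
expected = vec(sum(A .^ 2, 1))    # direct column-wise reduction
@assert maximum(abs(sq_norms - expected)) < 1e-12
```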
6 changes: 4 additions & 2 deletions src/layers/convolution.jl
@@ -12,6 +12,8 @@
bias_init :: Initializer = ConstantInitializer(0),
filter_regu :: Regularizer = L2Regu(1),
bias_regu :: Regularizer = NoRegu(),
filter_cons :: Constraint = NoCons(),
bias_cons :: Constraint = NoCons(),
filter_lr :: FloatingPoint = 1.0,
bias_lr :: FloatingPoint = 2.0,
)
@@ -100,8 +102,8 @@ type ConvolutionLayerState <: LayerState

etc = setup_etc(sys, layer, dtype, width, height, channels, batch_size, width_out, height_out, inputs)

parameters = [Parameter("filter", filter, ∇filter, layer.filter_init, layer.filter_regu, layer.filter_lr),
Parameter("bias", bias, ∇bias, layer.bias_init, layer.bias_regu, layer.bias_lr)]
parameters = [Parameter("filter", filter, ∇filter, layer.filter_init, layer.filter_regu, layer.filter_cons, layer.filter_lr),
Parameter("bias", bias, ∇bias, layer.bias_init, layer.bias_regu, layer.bias_cons, layer.bias_lr)]

state = new(layer, blobs, blobs_diff, parameters)
state.filter = filter
6 changes: 4 additions & 2 deletions src/layers/inner-product.jl
@@ -7,6 +7,8 @@
bias_init :: Initializer = ConstantInitializer(0),
weight_regu :: Regularizer = L2Regu(1),
bias_regu :: Regularizer = NoRegu(),
weight_cons :: Constraint = NoCons(),
bias_cons :: Constraint = NoCons(),
weight_lr :: FloatingPoint = 1.0,
bias_lr :: FloatingPoint = 2.0,
neuron :: ActivationFunction = Neurons.Identity()
@@ -63,8 +65,8 @@ type InnerProductLayerState <: LayerState
state.bias_multiplier = make_blob(sys.backend, data_type, nums, 1, 1, 1)
fill!(state.bias_multiplier, 1)

state.parameters = [Parameter("weight", state.W, state.∇W, layer.weight_init, layer.weight_regu, layer.weight_lr),
Parameter("bias", state.b, state.∇b, layer.bias_init, layer.bias_regu, layer.bias_lr)]
state.parameters = [Parameter("weight", state.W, state.∇W, layer.weight_init, layer.weight_regu, layer.weight_cons, layer.weight_lr),
Parameter("bias", state.b, state.∇b, layer.bias_init, layer.bias_regu, layer.bias_cons, layer.bias_lr)]

return state
end
1 change: 1 addition & 0 deletions src/parameter.jl
@@ -6,5 +6,6 @@ type Parameter
gradient :: Blob
initializer :: Initializer
regularizer :: Regularizer
constraint :: Constraint
learning_rate :: FloatingPoint # relative learning rate
end
2 changes: 1 addition & 1 deletion src/regularizers.jl
@@ -43,7 +43,7 @@ function backward(sys::System{CPUBackend}, regu :: L2Regu, global_regu::Floating
end

############################################################
# L2 regularization
# L1 regularization
############################################################
function forward(sys::System{CPUBackend}, regu :: L1Regu, global_regu::FloatingPoint, param :: Blob)
return regu.coefficient * global_regu * sum(abs(param.data))
5 changes: 5 additions & 0 deletions src/solvers/sgd.jl
@@ -39,6 +39,11 @@ function solve(sgd::SGD, net::Net)

update_parameters(net, sgd, state.parameters[j].learning_rate * learning_rate, momentum,
state, state.parameters[j].blob, hist_blob, gradient, data_type)
# apply constraints after update
cons_every = state.parameters[j].constraint.every_n_iter
if cons_every > 0 && solver_state.iter % cons_every == 0
constrain!(net.sys, state.parameters[j].constraint, state.parameters[j].blob)
end
end
end
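Because the check above is driven by every_n_iter, the second argument of L2Cons controls how often the projection runs. A hedged usage sketch mirroring the layer definitions in mnist_dropout_fc.jl (applying the constraint every 10 steps is only an illustration, not a recommendation from this PR):

```julia
# Re-project fc2's weights every 10 SGD iterations instead of after every update.
fc2_layer = InnerProductLayer(name="fc2", output_dim=1200, neuron=Neurons.ReLU(),
                              weight_init = GaussianInitializer(std=0.01),
                              weight_cons = L2Cons(4.5, 10),  # (coefficient, every_n_iter)
                              bottoms=[:fc1], tops=[:fc2])
```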

48 changes: 48 additions & 0 deletions test/constraints/l2.jl
@@ -0,0 +1,48 @@
function test_l2_constraint(sys::System, T, eps)
println("-- Testing L2 constraint on $(typeof(sys.backend)){$T}...")
# this simulates a convolutional filter and applies
# the l2 constraint to it
n_filters = 5
coef = 0.2
param = rand(T, 2,3,4,n_filters) - 0.5
param_after = zeros(T, size(param))
param_blob = make_blob(sys.backend, param)

cons = L2Cons(coef)
constrain!(sys, cons, param_blob)
copy!(param_after, param_blob)
param_after = reshape(param_after, size(param))
for f=1:n_filters
norm2 = vecnorm(param_after[:, :, :, f])
@test norm2 <= coef + eps
end

# this is the same as above but for fully connected weights
n_input = 10
n_out = 12
param = rand(T, n_input,n_out,1,1) - 0.5
param_after = zeros(T, size(param))
param_blob = make_blob(sys.backend, param)

cons = L2Cons(coef)
constrain!(sys, cons, param_blob)
copy!(param_after, param_blob)
param_after = reshape(param_after, size(param))
for f=1:n_out
norm2 = vecnorm(param_after[:, f, :, :])
@test norm2 <= coef + eps
end
end

function test_l2_constraint(sys::System)
test_l2_constraint(sys, Float32, 1e-5)
test_l2_constraint(sys, Float64, 1e-5)
end

if test_cpu
test_l2_constraint(sys_cpu)
end
if test_cudnn
test_l2_constraint(sys_cudnn)
end

5 changes: 5 additions & 0 deletions test/runtests.jl
@@ -45,6 +45,11 @@ include("neurons/sigmoid.jl")
include("regularizers/l2.jl")
include("regularizers/l1.jl")

############################################################
# Constraints
############################################################
include("constraints/l2.jl")

############################################################
# Data Transformers
############################################################