Commit
LRN unit-test : 4D -> ND tensor
pluskid committed Dec 20, 2014
1 parent f85928b commit deb64a5
Showing 1 changed file, test/layers/lrn.jl, with 17 additions and 13 deletions.
@@ -25,13 +25,12 @@ function test_lrn_layer(backend::Backend, mode::LRNModeType, tensor_dim, T, eps)
   @test all(abs(got_output - expected_output) .< eps)

   println(" > Backward")
-  #top_diff = float(int(100*rand(size(input))))/100
   top_diff = rand(T, size(input))
   copy!(state.blobs_diff[1], top_diff)
   backward(backend, state, input_blobs, diff_blobs)
   got_grad = zeros(T, size(input))
   copy!(got_grad, diff_blobs[1])
-  expected_grad = lrn_backward(input, top_diff, state)
+  expected_grad = lrn_backward(input, top_diff, state, op_dim)
   @test all(abs(got_grad - expected_grad) .< eps)

   shutdown(backend, state)
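
The next hunk generalizes the reference helpers from fixed 4D (width, height, channels, num) indexing to arbitrary-dimensional blobs by collapsing the tensor around the LRN operating dimension with split_dims(input, op_dim). The definition of split_dims is not part of this diff, so the snippet below is only a minimal sketch of its assumed behavior, inferred from the call site in the next hunk rather than taken from Mocha's source:

# Sketch only: assumed behavior of split_dims, reconstructed from its use below.
# It collapses every dimension before op_dim into one "pre" extent, keeps the
# size of op_dim as the "channel" extent, and collapses the rest into "post".
function split_dims{T}(tensor::Array{T}, op_dim::Int)
  dims = size(tensor)
  pre_dim, chann_dim, post_dim = 1, dims[op_dim], 1
  for i = 1:op_dim-1
    pre_dim *= dims[i]
  end
  for i = op_dim+1:length(dims)
    post_dim *= dims[i]
  end
  return (pre_dim, chann_dim, post_dim)
end

With a classic 4D blob and op_dim = 3 this gives (width*height, channels, num), which is why the old width/height/channel/num loops below can be rewritten over the collapsed (pre, channel, post) shape.
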
@@ -98,26 +97,31 @@ function lrn_forward{T}(input::Array{T}, state, op_dim)
   end
 end

-function lrn_backward_across_channel{T}(input::Array{T}, top_diff::Array{T}, state)
+function lrn_backward_across_channel{T}(input::Array{T}, top_diff::Array{T}, state, op_dim)
   output = zeros(T, size(input))
-  width, height, channels, num = size(input)
+  pre_dim, chann_dim, post_dim = split_dims(input, op_dim)
   pre_pad = div(state.layer.kernel-1,2)
   post_pad = state.layer.kernel - pre_pad - 1

-  for n = 1:num
-    for c = 1:channels
+  canonical_input = reshape(input, (pre_dim, chann_dim, post_dim))
+  canonical_output = reshape(output, (pre_dim, chann_dim, post_dim))
+  canonical_diff = reshape(top_diff, (pre_dim, chann_dim, post_dim))
+
+  for n = 1:post_dim
+    for c = 1:chann_dim
       cstart = c-pre_pad
-      cend = min(c + post_pad, channels)
+      cend = min(c + post_pad, chann_dim)
       cstart = max(1, cstart)

-      tmp = input[:,:,cstart:cend,n].^2 * (state.layer.scale / state.layer.kernel)
-      tmp = (sum(tmp, 3) + state.layer.shift)
+      tmp = canonical_input[:,cstart:cend,n].^2 * (state.layer.scale / state.layer.kernel)
+      tmp = (sum(tmp, 2) + state.layer.shift)

-      output[:,:,c,n] += tmp .^ (-state.layer.power) .* top_diff[:,:,c,n]
+      canonical_output[:,c,n] += tmp .^ (-state.layer.power) .* canonical_diff[:,c,n]

       tmp = -state.layer.power * tmp .^ (-state.layer.power - 1)
       tmp = 2 * state.layer.scale / state.layer.kernel * tmp
-      output[:,:,cstart:cend,n] += tmp .* input[:,:,cstart:cend,n] .* input[:,:,c,n] .* top_diff[:,:,c,n]
+      canonical_output[:,cstart:cend,n] += tmp .* canonical_input[:,cstart:cend,n] .*
+          canonical_input[:,c,n] .* canonical_diff[:,c,n]
     end
   end
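
One way to read the loop body above: with the per-channel normalizer s_c = shift + (scale/kernel) * sum over the window of x_j^2 and incoming gradient d_c, the two accumulation lines compute

  grad_i = s_i^(-power) * d_i  -  (2*scale*power/kernel) * x_i * sum_{c : window(c) contains i} s_c^(-power-1) * x_c * d_c

The first term is the canonical_output[:,c,n] += line and the second is the canonical_output[:,cstart:cend,n] += line, accumulated window by window as c sweeps the channel dimension. Nothing in this expression depends on how many dimensions were folded into pre_dim and post_dim, which is what makes the ND reshape valid.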

@@ -157,9 +161,9 @@ function lrn_backward_within_channel{T}(input::Array{T}, top_diff::Array{T}, state)
   return output
 end

-function lrn_backward{T}(input::Array{T}, top_diff::Array{T}, state)
+function lrn_backward{T}(input::Array{T}, top_diff::Array{T}, state, op_dim)
   if isa(state.layer.mode, LRNMode.AcrossChannel)
-    lrn_backward_across_channel(input, top_diff, state)
+    lrn_backward_across_channel(input, top_diff, state, op_dim)
   elseif isa(state.layer.mode, LRNMode.WithinChannel)
     lrn_backward_within_channel(input, top_diff, state)
   else
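With op_dim threaded through, the reference helpers no longer assume a 4D blob. As a rough illustration of what the updated test entry point accepts (the backend object, the LRNMode instance, the tolerance, and reading tensor_dim as the blob rank are assumptions here, not part of this commit), calls along these lines become possible:

# Hypothetical invocations; test_lrn_layer's signature appears in the first
# hunk header above as (backend::Backend, mode::LRNModeType, tensor_dim, T, eps).
test_lrn_layer(backend, LRNMode.AcrossChannel(), 4, Float64, 1e-8)  # the original 4D case
test_lrn_layer(backend, LRNMode.AcrossChannel(), 5, Float64, 1e-8)  # now an ND (here 5D) blob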
