diff --git a/src/Mocha.jl b/src/Mocha.jl index 596f824..330c3db 100644 --- a/src/Mocha.jl +++ b/src/Mocha.jl @@ -15,6 +15,7 @@ include("exception.jl") include("utils/blas.jl") include("utils/math.jl") include("utils/io.jl") +include("utils/tensor.jl") if Config.use_native_extension include("utils/im2col-native.jl") diff --git a/src/cuda/layers/accuracy.jl b/src/cuda/layers/accuracy.jl index fe6b026..6d0c03a 100644 --- a/src/cuda/layers/accuracy.jl +++ b/src/cuda/layers/accuracy.jl @@ -1,6 +1,7 @@ -function setup_etc(backend::GPUBackend, layer::AccuracyLayer, inputs) - width, height, channels, num = get_whcn(inputs[1]) - etc = make_blob(backend, eltype(inputs[1]), (width,height,1,num)) +function setup_etc(backend::GPUBackend, layer::AccuracyLayer, op_dim::Int, inputs) + dims = [size(inputs[1])...] + dims[op_dim] = 1 + etc = make_blob(backend, eltype(inputs[1]), dims...) return etc end function shutdown(backend::GPUBackend, state::AccuracyLayerState) @@ -11,8 +12,7 @@ function forward(backend::GPUBackend, state::AccuracyLayerState, inputs::Vector{ pred = inputs[1] label = inputs[2] - width, height, channels, num = get_whcn(pred) - spatial_dim = width*height + spatial_dim, pred_dim, num = split_dims(pred, state.op_dim) data_type = eltype(pred) x_block = int(ceil(float64(num)/CUDA.THREADS_PER_BLOCK_X)); @@ -26,7 +26,7 @@ function forward(backend::GPUBackend, state::AccuracyLayerState, inputs::Vector{ error("Unsupported data type $data_type") end CUDA.launch(kernel, (x_block,y_block),(CUDA.THREADS_PER_BLOCK_X,CUDA.THREADS_PER_BLOCK_Y), - (pred.ptr.p, label.ptr.p, state.etc.ptr.p, num, channels, spatial_dim)); + (pred.ptr.p, label.ptr.p, state.etc.ptr.p, num, pred_dim, spatial_dim)); N = num * spatial_dim accuracy = CuBLAS.dot(backend.cublas_ctx, data_type, N, state.etc.ptr, 1, state.etc.ptr, 1) diff --git a/src/cuda/layers/argmax.jl b/src/cuda/layers/argmax.jl index ae78eee..c2c435d 100644 --- a/src/cuda/layers/argmax.jl +++ b/src/cuda/layers/argmax.jl @@ -3,8 +3,7 @@ function forward(backend::GPUBackend, state::ArgmaxLayerState, inputs::Vector{Bl input = inputs[i] output = state.blobs[i] - width, height, channels, num = get_whcn(input) - spatial_dim = width*height + spatial_dim, channels, num = split_dims(input, state.dims[i]) data_type = eltype(input) x_block = int(ceil(float64(num)/CUDA.THREADS_PER_BLOCK_X)); diff --git a/src/cuda/layers/channel-pooling.jl b/src/cuda/layers/channel-pooling.jl index 42b0c85..073fbaf 100644 --- a/src/cuda/layers/channel-pooling.jl +++ b/src/cuda/layers/channel-pooling.jl @@ -1,16 +1,14 @@ -function setup_etc(backend::GPUBackend, layer::ChannelPoolingLayer, inputs, pooled_chann) +function setup_etc(backend::GPUBackend, layer::ChannelPoolingLayer, inputs, blobs) if isa(layer.pooling, Pooling.Max) masks = Array(CuPtr, length(inputs)) for i = 1:length(inputs) - masks[i] = CUDA.cualloc(Csize_t, get_width(inputs[i]) * get_height(inputs[i]) * - pooled_chann[i] * get_num(inputs[i])) + masks[i] = CUDA.cualloc(Csize_t, length(blobs[i])) end etc = masks elseif isa(layer.pooling, Pooling.Mean) integrals = Array(CuPtr, length(inputs)) for i = 1:length(inputs) - integrals[i] = CUDA.cualloc(eltype(inputs[i]), get_width(inputs[i]) * get_height(inputs[i]) * - get_chann(inputs[i])) + integrals[i] = CUDA.cualloc(eltype(inputs[i]), prod(size(inputs[i])[1:end-1])) end etc = integrals else @@ -39,9 +37,9 @@ function forward(backend::GPUBackend, pool::StdPoolingFunction, output = state.blobs[i] if isa(pool, Pooling.Max) - cuda_max_channel_pooling_forward(backend, input, output, 
state.etc[i], state.layer) + cuda_max_channel_pooling_forward(backend, input, output, state.etc[i], state.layer, state.op_dims[i]) elseif isa(pool, Pooling.Mean) - cuda_mean_channel_pooling_forward(backend, input, output, state.etc[i], state.layer) + cuda_mean_channel_pooling_forward(backend, input, output, state.etc[i], state.layer, state.op_dims[i]) else error("Pooling for $pool not implemented yet") end @@ -59,9 +57,9 @@ function backward(backend::GPUBackend, pool::StdPoolingFunction, state::ChannelP diff = diffs[i] if !isa(diff, NullBlob) if isa(pool, Pooling.Max) - cuda_max_channel_pooling_backward(backend, diff, state.blobs_diff[i], state.etc[i], state.layer) + cuda_max_channel_pooling_backward(backend, diff, state.blobs_diff[i], state.etc[i], state.layer, state.op_dims[i]) elseif isa(pool, Pooling.Mean) - cuda_mean_channel_pooling_backward(backend, diff, state.blobs_diff[i], state.layer) + cuda_mean_channel_pooling_backward(backend, diff, state.blobs_diff[i], state.layer, state.op_dims[i]) else error("Pooling for $pool not implemented yet") end @@ -72,15 +70,14 @@ function backward(backend::GPUBackend, pool::StdPoolingFunction, state::ChannelP end function cuda_mean_channel_pooling_forward{T}(backend::GPUBackend, input::CuTensorBlob{T}, - output::CuTensorBlob{T}, integral::CuPtr, layer) + output::CuTensorBlob{T}, integral::CuPtr, layer, op_dim) - width, height, channels, num = size(input) - pooled_chann = size(output, 3) + spatial_dim_T, channels, num = split_dims(input, op_dim) + pooled_chann = size(output, op_dim) one = convert(T, 1) neg_one = convert(T, -1) scale = convert(T, 1.0/layer.kernel) - spatial_dim_T = width*height spatial_dim = spatial_dim_T * sizeof(T) fea_dim = spatial_dim * channels output_fea_dim = spatial_dim * pooled_chann @@ -118,15 +115,14 @@ function cuda_mean_channel_pooling_forward{T}(backend::GPUBackend, input::CuTens end function cuda_mean_channel_pooling_backward{T}(backend::GPUBackend, input::CuTensorBlob{T}, - output::CuTensorBlob{T}, layer) + output::CuTensorBlob{T}, layer, op_dim) - width, height, channels, num = size(input) - pooled_chann = size(output, 3) + spatial_dim_T, channels, num = split_dims(input, op_dim) + pooled_chann = size(output, op_dim) scale = 1/convert(T, layer.kernel) fill!(input, 0) - spatial_dim_T = width*height spatial_dim = spatial_dim_T * sizeof(T) fea_dim = spatial_dim * channels output_fea_dim = spatial_dim * pooled_chann @@ -159,11 +155,10 @@ function cuda_geometry_max_chann_pool(sp_dim::Int, num::Int) end function cuda_max_channel_pooling_forward{T}(backend::GPUBackend, input::CuTensorBlob{T}, - output::CuTensorBlob{T}, mask::CuPtr, layer) + output::CuTensorBlob{T}, mask::CuPtr, layer, op_dim) - width, height, channels, num = size(input) - sp_dim = width*height - pooled_chann = get_chann(output) + sp_dim, channels, num = split_dims(input, op_dim) + pooled_chann = size(output, op_dim) cuda_dim = cuda_geometry_max_chann_pool(sp_dim, num); if T == Float32 @@ -179,11 +174,10 @@ function cuda_max_channel_pooling_forward{T}(backend::GPUBackend, input::CuTenso end function cuda_max_channel_pooling_backward{T}(backend::GPUBackend, input::CuTensorBlob{T}, - output::CuTensorBlob{T}, mask::CuPtr, layer) + output::CuTensorBlob{T}, mask::CuPtr, layer, op_dim) - width, height, channels, num = size(input) - sp_dim = width*height - pooled_chann = get_chann(output) + sp_dim, channels, num = split_dims(input, op_dim) + pooled_chann = size(output, op_dim) cuda_dim = cuda_geometry_max_chann_pool(sp_dim, num); if T == Float32 diff --git 
a/src/cuda/layers/multinomial-logistic-loss.jl b/src/cuda/layers/multinomial-logistic-loss.jl index 865e220..abd3384 100644 --- a/src/cuda/layers/multinomial-logistic-loss.jl +++ b/src/cuda/layers/multinomial-logistic-loss.jl @@ -3,9 +3,7 @@ function forward(backend::GPUBackend, state::MultinomialLogisticLossLayerState, label = inputs[2] data_type = eltype(pred) - width, height, channels, num = get_whcn(pred) - - spatial_dim = height*width + spatial_dim, channels, num = split_dims(pred, state.op_dim) prob_dim = channels x_block = int(ceil(float64(num)/CUDA.THREADS_PER_BLOCK_X)) diff --git a/src/cuda/layers/softmax-loss.jl b/src/cuda/layers/softmax-loss.jl index 99a0ff6..9422d23 100644 --- a/src/cuda/layers/softmax-loss.jl +++ b/src/cuda/layers/softmax-loss.jl @@ -4,9 +4,7 @@ function backward(backend::GPUBackend, state::SoftmaxLossLayerState, inputs::Vec copy!(diff, state.softmax.blobs[1]) data_type = eltype(diff) - height, width, channels, num = get_whcn(diff) - - spatial_dim = height*width + spatial_dim, channels, num = split_dims(diff, state.logistic.op_dim) prob_dim = channels x_block = int(ceil(float64(num)/CUDA.THREADS_PER_BLOCK_X)) diff --git a/src/cuda/layers/softmax.jl b/src/cuda/layers/softmax.jl index 83061f7..40ad309 100644 --- a/src/cuda/layers/softmax.jl +++ b/src/cuda/layers/softmax.jl @@ -3,12 +3,13 @@ type CuDNNSoftmaxState outputs_desc :: Vector{CuDNN.Tensor4dDescriptor} end -function setup_etc(backend::GPUBackend, layer::SoftmaxLayer, data_type, inputs) +function setup_etc(backend::GPUBackend, layer::SoftmaxLayer, dims::Vector{Int}, data_type, inputs) inputs_desc = Array(CuDNN.Tensor4dDescriptor, length(inputs)) outputs_desc = Array(CuDNN.Tensor4dDescriptor, length(inputs)) for i = 1:length(inputs) - inputs_desc[i] = CuDNN.create_tensor4d_descriptor(data_type, get_whcn(inputs[i])) - outputs_desc[i] = CuDNN.create_tensor4d_descriptor(data_type, get_whcn(inputs[i])) + dim_sp, dim_prob, dim_num = split_dims(inputs[i], dims[i]) + inputs_desc[i] = CuDNN.create_tensor4d_descriptor(data_type, (1,dim_sp,dim_prob,dim_num)) + outputs_desc[i] = CuDNN.create_tensor4d_descriptor(data_type, (1,dim_sp,dim_prob,dim_num)) end etc = CuDNNSoftmaxState(inputs_desc, outputs_desc) return etc diff --git a/src/layers/accuracy.jl b/src/layers/accuracy.jl index e2a7ec1..a24b417 100644 --- a/src/layers/accuracy.jl +++ b/src/layers/accuracy.jl @@ -1,6 +1,7 @@ @defstruct AccuracyLayer Layer ( name :: String = "accuracy", report_error :: Bool = false, + (dim :: Int = -2, dim != 0), (bottoms :: Vector{Symbol} = Symbol[], length(bottoms) == 2), ) @characterize_layer(AccuracyLayer, @@ -11,18 +12,24 @@ type AccuracyLayerState <: LayerState layer :: AccuracyLayer + op_dim :: Int accuracy :: Float64 n_accum :: Int etc :: Any end -function setup_etc(backend::CPUBackend, layer::AccuracyLayer, inputs) +function setup_etc(backend::CPUBackend, layer::AccuracyLayer, op_dim::Int, inputs) nothing end function setup(backend::Backend, layer::AccuracyLayer, inputs::Vector{Blob}, diffs::Vector{Blob}) - etc = setup_etc(backend, layer, inputs) - return AccuracyLayerState(layer, 0.0, 0, etc) + total_dim = ndims(inputs[1]) + dim = layer.dim < 0 ? 
layer.dim + total_dim + 1 : layer.dim + @assert 1 <= dim <= total_dim + @assert dim != total_dim + + etc = setup_etc(backend, layer, dim, inputs) + return AccuracyLayerState(layer, dim, 0.0, 0, etc) end function shutdown(backend::CPUBackend, state::AccuracyLayerState) end @@ -48,20 +55,18 @@ function forward(backend::CPUBackend, state::AccuracyLayerState, inputs::Vector{ pred = inputs[1].data label = inputs[2].data - width, height, channels, num = get_whcn(pred) - canonical_pred = reshape(pred, (width,height,channels,num)) - canonical_label = reshape(label, (width,height,1,num)) + dim_pre, dim_prob, dim_post = split_dims(pred, state.op_dim) accuracy = 0.0 - for w = 1:width - for h = 1:height - for n = 1:num - if int(canonical_label[w,h,1,n])+1 == indmax(canonical_pred[w,h,:,n]) - accuracy += 1.0 - end + for i = 0:dim_pre-1 + for j = 0:dim_post-1 + idx = Int[i + dim_pre*(k + dim_prob*j) for k=0:dim_prob-1] + 1 + @inbounds if int(label[i + dim_pre*j + 1])+1 == indmax(pred[idx]) + accuracy += 1.0 end end end + state.accuracy = float64(state.accuracy * state.n_accum + accuracy) / (state.n_accum + length(label)) state.n_accum += length(label) end diff --git a/src/layers/argmax.jl b/src/layers/argmax.jl index 481f0b6..a2b2ee8 100644 --- a/src/layers/argmax.jl +++ b/src/layers/argmax.jl @@ -1,5 +1,6 @@ @defstruct ArgmaxLayer Layer ( name :: String = "argmax", + (dim :: Int = -2, dim != 0), (tops :: Vector{Symbol} = Symbol[], length(tops) > 0), (bottoms :: Vector{Symbol} = Symbol[], length(bottoms) == length(tops)), ) @@ -7,39 +8,45 @@ type ArgmaxLayerState <: LayerState layer :: ArgmaxLayer blobs :: Vector{Blob} + + dims :: Vector{Int} end function setup(backend::Backend, layer::ArgmaxLayer, inputs::Vector{Blob}, diffs::Vector{Blob}) - blobs = map(inputs) do input - width, height, channels, num = get_whcn(input) - data_type = eltype(input) - - blob = make_blob(backend, data_type, width, height, 1, num) - blob + dims = Array(Int, length(inputs)) + blobs = Array(Blob, length(inputs)) + for i = 1:length(inputs) + total_dim = ndims(inputs[i]) + dim = layer.dim < 0 ? layer.dim + total_dim + 1 : layer.dim + @assert 1 <= dim <= total_dim + @assert dim != total_dim + dims[i] = dim + shape = [size(inputs[i])...] + shape[dim] = 1 + blobs[i] = make_blob(backend, eltype(inputs[i]), shape...) 
end - return ArgmaxLayerState(layer, blobs) + return ArgmaxLayerState(layer, blobs, dims) end function forward(backend::CPUBackend, state::ArgmaxLayerState, inputs::Vector{Blob}) for i = 1:length(inputs) input = inputs[i].data output = state.blobs[i].data - width, height, channels, num = get_whcn(input) - canonical_input = reshape(input, (width,height,channels,num)) - for n = 1:num - for w = 1:width - for h = 1:height - maxc = 1; maxval = canonical_input[w,h,maxc,n] - for c = 2:channels - @inbounds val = canonical_input[w,h,c,n] - if val > maxval - maxval = val - maxc = c - end + pre_dim, mid_dim, post_dim = split_dims(input, state.dims[i]) + for x = 0:pre_dim-1 + for z = 0:post_dim-1 + idx = Int[x + pre_dim*(y + mid_dim*z) for y=0:mid_dim-1] + 1 + maxc = 1 + @inbounds maxval = input[idx[1]] + for y = 2:length(idx) + @inbounds val = input[idx[y]] + if val > maxval + maxval = val + maxc = y end - @inbounds output[w,h,1,n] = maxc-1 end + @inbounds output[x + pre_dim*z + 1] = maxc-1 end end end diff --git a/src/layers/channel-pooling.jl b/src/layers/channel-pooling.jl index 9f25e69..d05751b 100644 --- a/src/layers/channel-pooling.jl +++ b/src/layers/channel-pooling.jl @@ -5,6 +5,7 @@ (kernel :: Int = 1, kernel > 0), (stride :: Int = 1, stride > 0), (pad :: NTuple{2, Int} = (0,0), all([pad...] .>= 0)), + (dim :: Int = -2, dim != 0), pooling :: PoolingFunction = Pooling.Max(), ) @characterize_layer(ChannelPoolingLayer, @@ -16,22 +17,21 @@ type ChannelPoolingLayerState <: LayerState blobs :: Vector{Blob} blobs_diff :: Vector{Blob} + op_dims :: Vector{Int} etc :: Any end -function setup_etc(backend::CPUBackend, layer::ChannelPoolingLayer, inputs, pooled_chann) +function setup_etc(backend::CPUBackend, layer::ChannelPoolingLayer, inputs, blobs) if isa(layer.pooling, Pooling.Max) masks = Array(Array, length(inputs)) for i = 1:length(inputs) - masks[i] = Array(Csize_t, get_width(inputs[i]), get_height(inputs[i]), - pooled_chann[i], get_num(inputs[i])) + masks[i] = Array(Csize_t, size(blobs[i])) end etc = masks elseif isa(layer.pooling, Pooling.Mean) integrals = Array(Array, length(inputs)) for i = 1:length(inputs) - integrals[i] = Array(eltype(inputs[i]), get_width(inputs[i]), get_height(inputs[i]), - get_chann(inputs[i])) + integrals[i] = Array(eltype(inputs[i]), size(inputs[i])[1:end-1]) end etc = integrals else @@ -41,36 +41,44 @@ function setup_etc(backend::CPUBackend, layer::ChannelPoolingLayer, inputs, pool end function setup(backend::Backend, layer::ChannelPoolingLayer, inputs::Vector{Blob}, diffs::Vector{Blob}) - for i = 1:length(inputs) - # currently we only handle 4D-tensor - @assert ndims(inputs[i]) == 4 - end - pooled_chann_all = Array(Int, length(inputs)) blobs = Array(Blob, length(inputs)) blobs_diff = Array(Blob, length(inputs)) + op_dims = Array(Int, length(inputs)) for i = 1:length(inputs) - width, height, channels, num = size(inputs[i]) - pooled_chann = int(ceil(float(channels + layer.pad[1]+layer.pad[2] - layer.kernel) / layer.stride)) + 1 + dim_total = ndims(inputs[i]) + op_dim = layer.dim < 0 ? layer.dim + dim_total+1 : layer.dim + @assert 1 <= op_dim <= dim_total + @assert op_dim != dim_total + + op_dims[i] = op_dim + + dims = [size(inputs[i])...] 
+ pool_dim = dims[op_dim] + pooled_dim = int(ceil(float(pool_dim + layer.pad[1]+layer.pad[2] - layer.kernel) / layer.stride)) + 1 # make sure the last pooling is not purely pooling padded area - if ((pooled_chann-1)*layer.stride >= channels + layer.pad[1]) - pooled_chann -= 1 + if ((pooled_dim-1)*layer.stride >= pool_dim + layer.pad[1]) + pooled_dim -= 1 end - pooled_chann_all[i] = pooled_chann + pooled_chann_all[i] = pooled_dim + + output_dims = copy(dims) + output_dims[op_dim] = pooled_dim + output_dims = tuple(output_dims...) data_type = eltype(inputs[i]) - blobs[i] = make_blob(backend, data_type, (width,height,pooled_chann_all[i],num)) + blobs[i] = make_blob(backend, data_type, output_dims) if isa(diffs[i], NullBlob) blobs_diff[i] = NullBlob() else - blobs_diff[i] = make_blob(backend, data_type, (width,height,pooled_chann_all[i],num)) + blobs_diff[i] = make_blob(backend, data_type, output_dims) end end - etc = setup_etc(backend, layer, inputs, pooled_chann_all) - state = ChannelPoolingLayerState(layer, blobs, blobs_diff, etc) + etc = setup_etc(backend, layer, inputs, blobs) + state = ChannelPoolingLayerState(layer, blobs, blobs_diff, op_dims, etc) end function shutdown_etc(backend::CPUBackend, state::ChannelPoolingLayerState) end @@ -90,12 +98,13 @@ function forward(backend::CPUBackend, pool::StdPoolingFunction, input = inputs[i].data output = state.blobs[i].data + dims_in = split_dims(input, state.op_dims[i]) + dims_out = split_dims(output, state.op_dims[i]) + if isa(pool, Pooling.Max) - max_channel_pooling_forward(input, output, state.etc[i], state.layer) + max_channel_pooling_forward(reshape(input,dims_in), reshape(output,dims_out), reshape(state.etc[i],dims_out), state.layer) elseif isa(pool, Pooling.Mean) - mean_channel_pooling_forward(input, output, state.etc[i], state.layer) - else - error("Pooling for $pool not implemented yet") + mean_channel_pooling_forward(reshape(input,dims_in), reshape(output,dims_out), state.etc[i], state.layer) end end end @@ -109,13 +118,16 @@ function backward(backend::CPUBackend, pool::StdPoolingFunction, state::ChannelP for i = 1:length(inputs) diff = diffs[i] + if !isa(diff, NullBlob) + dims_in = split_dims(inputs[i], state.op_dims[i]) + dims_out = split_dims(state.blobs[i], state.op_dims[i]) + if isa(pool, Pooling.Max) - max_channel_pooling_backward(diff.data, state.blobs_diff[i].data, state.etc[i], state.layer) + max_channel_pooling_backward(reshape(diff.data,dims_in), reshape(state.blobs_diff[i].data,dims_out), + reshape(state.etc[i],dims_out), state.layer) elseif isa(pool, Pooling.Mean) - mean_channel_pooling_backward(diff.data, state.blobs_diff[i].data, state.layer) - else - error("Pooling for $pool not implemented yet") + mean_channel_pooling_backward(reshape(diff.data,dims_in), reshape(state.blobs_diff[i].data,dims_out), state.layer) end end end diff --git a/src/layers/multinomial-logistic-loss.jl b/src/layers/multinomial-logistic-loss.jl index 3fdb796..ce4d4f6 100644 --- a/src/layers/multinomial-logistic-loss.jl +++ b/src/layers/multinomial-logistic-loss.jl @@ -4,6 +4,7 @@ @defstruct MultinomialLogisticLossLayer Layer ( name :: String = "multinomial-logistic-loss", weights :: Array = [], + (dim :: Int = -2, dim != 0), (normalize:: Symbol = :local, in(normalize,[:local,:global,:no])), (bottoms :: Vector{Symbol} = Symbol[], length(bottoms) == 2), ) @@ -16,41 +17,53 @@ type MultinomialLogisticLossLayerState{T} <: LayerState layer :: MultinomialLogisticLossLayer loss :: T + op_dim :: Int weights_blob :: Blob end function setup(backend::Backend, 
layer::MultinomialLogisticLossLayer, inputs::Vector{Blob}, diffs::Vector{Blob}) data_type = eltype(inputs[1]) - width, height, channels, num = get_whcn(inputs[1]) + tensor_dim = ndims(inputs[1]) + dims = size(inputs[1]) + + op_dim = layer.dim + if op_dim < 0 + op_dim += tensor_dim + 1 + end + @assert 1 <= op_dim <= tensor_dim + @assert op_dim != tensor_dim # the last dimension is the mini-batch dimension # weights for each class if isempty(layer.weights) weights_blob = NullBlob() else + @assert op_dim == tensor_dim-1 "When weights provided, LogisticLoss can only operate on the second-to-last dimension" weights = layer.weights if ndims(weights) == 1 - if length(weights) != channels + if length(weights) != dims[op_dim] error("Invalid weights: size should be equal to number of classes") end - weights = repeat(reshape(weights,1,1,channels), inner=[width,height,1]) + new_shape = ones(Int, tensor_dim-1); new_shape[op_dim] = dims[op_dim] + rep_shape = [dims[1:end-1]...]; rep_shape[op_dim] = 1 + weights = repeat(reshape(weights, new_shape...), inner=rep_shape) end - if ndims(weights) != 3 || size(weights) != (width,height,channels) - error("Invalid weights: should be either a 3-tensor of (width,height,channels) or a vector of (channels)") + if ndims(weights) != tensor_dim-1 || size(weights) != dims[1:end-1] + error("Invalid weights: should be either a ND-tensor of one data point or a vector of (classes)") end weights = convert(Array{data_type}, weights) if layer.normalize == :local - weights = weights .* (channels ./ sum(weights, 3)) + weights = weights .* (dims[op_dim] ./ sum(weights, op_dim)) elseif layer.normalize == :global - weights = weights * (width*height*channels / sum(weights)) + weights = weights * (prod(size(weights)) / sum(weights)) else @assert layer.normalize == :no end - weights_blob = make_blob(backend, reshape(weights, width,height,channels,1)) + weights_blob = make_blob(backend, weights) end - state = MultinomialLogisticLossLayerState(layer, convert(data_type, 0), weights_blob) + state = MultinomialLogisticLossLayerState(layer, convert(data_type, 0), op_dim, weights_blob) return state end function shutdown(backend::Backend, state::MultinomialLogisticLossLayerState) @@ -59,23 +72,26 @@ end function forward(backend::CPUBackend, state::MultinomialLogisticLossLayerState, inputs::Vector{Blob}) pred = inputs[1].data label = inputs[2].data - width, height, channels, num = get_whcn(pred) - idx_width = reshape(1:width, (width, 1, 1, 1)) - idx_height = reshape(1:height, (1, height, 1, 1)) - idx_chann = int(label)+1 - idx_num = reshape(1:num, (1, 1, 1, num)) + dims = size(pred) - pred = reshape(pred, (width,height,channels,num)) - label = reshape(label, (width,height,1,num)) + idx_all = map(1:length(dims)) do i + if i == state.op_dim + int(label) + 1 + else + dim = dims[i] + reshape(1:dim, [j == i? dim : 1 for j = 1:length(dims)]...) + end + end if isa(state.weights_blob, NullBlob) - loss = sum(-log(max(broadcast_getindex(pred, idx_width, idx_height, idx_chann, idx_num), 1e-20))) + loss = sum(-log(max(broadcast_getindex(pred, idx_all...), 1e-20))) else - loss = sum(-log(max(broadcast_getindex(pred, idx_width, idx_height, idx_chann, idx_num), 1e-20)) .* - broadcast_getindex(state.weights_blob.data, idx_width, idx_height, idx_chann, reshape([1],1,1,1,1))) + tmp = reshape([1], ones(Int, length(dims))...) 
+ loss = sum(-log(max(broadcast_getindex(pred, idx_all...), 1e-20)) .* + broadcast_getindex(state.weights_blob.data, idx_all[1:end-1]..., tmp)) end - state.loss = loss / (width*height*num) + state.loss = loss / (prod(dims) / dims[state.op_dim]) end function backward(backend::Backend, state::MultinomialLogisticLossLayerState, inputs::Vector{Blob}, diffs::Vector{Blob}) diff --git a/src/layers/pooling/channel-pooling.jl b/src/layers/pooling/channel-pooling.jl index 5522b2c..28906fc 100644 --- a/src/layers/pooling/channel-pooling.jl +++ b/src/layers/pooling/channel-pooling.jl @@ -1,9 +1,9 @@ ################################################################################ # Pooling in channels ################################################################################ -function max_channel_pooling_forward{T}(input::Array{T}, output::Array{T}, mask::Array{Csize_t}, layer) - width, height, channels, num = size(input) - pooled_chann = size(output, 3) +function max_channel_pooling_forward{T}(input::Array{T,3}, output::Array{T,3}, mask::Array{Csize_t,3}, layer) + spatial_dim, channels, num = size(input) + pooled_chann = size(output, 2) for n = 1:num for pc = 1:pooled_chann @@ -11,22 +11,18 @@ function max_channel_pooling_forward{T}(input::Array{T}, output::Array{T}, mask: cend = min(cstart + layer.kernel - 1, channels) cstart = max(1, cstart) - for w = 1:width - for h = 1:height - @inbounds output[w,h,pc,n] = input[w,h,cstart,n] - @inbounds mask[w,h,pc,n] = cstart - end + for s = 1:spatial_dim + @inbounds output[s,pc,n] = input[s,cstart,n] + @inbounds mask[s,pc,n] = cstart end for c = cstart+1:cend - for w = 1:width - for h = 1:height - @inbounds maxval = output[w,h,pc,n] - @inbounds val = input[w,h,c,n] - if val > maxval - @inbounds output[w,h,pc,n] = val - @inbounds mask[w,h,pc,n] = c - end + for s = 1:spatial_dim + @inbounds maxval = output[s,pc,n] + @inbounds val = input[s,c,n] + if val > maxval + @inbounds output[s,pc,n] = val + @inbounds mask[s,pc,n] = c end end end @@ -34,14 +30,13 @@ function max_channel_pooling_forward{T}(input::Array{T}, output::Array{T}, mask: end end -function mean_channel_pooling_forward{T}(input::Array{T}, output::Array{T}, integral::Array{T}, layer) - width, height, channels, num = size(input) - pooled_chann = size(output, 3) +function mean_channel_pooling_forward{T}(input::Array{T,3}, output::Array{T,3}, integral::Array{T}, layer) + spatial_dim_T, channels, num = size(input) + pooled_chann = size(output, 2) one = convert(T, 1) neg_one = convert(T, -1) scale = 1/convert(T, layer.kernel) - spatial_dim_T = width*height spatial_dim = spatial_dim_T * sizeof(T) fea_dim = spatial_dim * channels output_fea_dim = spatial_dim * pooled_chann @@ -78,9 +73,9 @@ function mean_channel_pooling_forward{T}(input::Array{T}, output::Array{T}, inte end end -function max_channel_pooling_backward{T}(input::Array{T}, output::Array{T}, mask::Array{Csize_t}, layer) - width, height, channels, num = size(input) - pooled_chann = size(output, 3) +function max_channel_pooling_backward{T}(input::Array{T,3}, output::Array{T,3}, mask::Array{Csize_t,3}, layer) + spatial_dim, channels, num = size(input) + pooled_chann = size(output, 2) fill!(input, 0) for n = 1:num @@ -89,23 +84,20 @@ function max_channel_pooling_backward{T}(input::Array{T}, output::Array{T}, mask cend = min(cstart + layer.kernel - 1, channels) cstart = max(1, cstart) - for w = 1:width - for h = 1:height - @inbounds input[w,h,mask[w,h,pc,n],n] += output[w,h,pc,n] - end + for s = 1:spatial_dim + @inbounds input[s,mask[s,pc,n],n] += 
output[s,pc,n] end end end end -function mean_channel_pooling_backward{T}(input::Array{T}, output::Array{T}, layer) - width, height, channels, num = size(input) - pooled_chann = size(output, 3) +function mean_channel_pooling_backward{T}(input::Array{T,3}, output::Array{T,3}, layer) + spatial_dim_T, channels, num = size(input) + pooled_chann = size(output, 2) scale = 1/convert(T, layer.kernel) fill!(input, 0) - spatial_dim_T = width*height spatial_dim = spatial_dim_T * sizeof(T) fea_dim = spatial_dim * channels output_fea_dim = spatial_dim * pooled_chann diff --git a/src/layers/softmax-loss.jl b/src/layers/softmax-loss.jl index 5d06933..1b35334 100644 --- a/src/layers/softmax-loss.jl +++ b/src/layers/softmax-loss.jl @@ -5,6 +5,7 @@ name :: String = "softmax-loss", weights :: Array = [], normalize:: Symbol = :local, + (dim :: Int = -2, dim != 0), (bottoms :: Vector{Symbol} = Symbol[], length(bottoms) == 2), ) @characterize_layer(SoftmaxLossLayer, @@ -24,12 +25,11 @@ end function setup(backend::Backend, layer::SoftmaxLossLayer, inputs::Vector{Blob}, diffs::Vector{Blob}) data_type = eltype(inputs[1]) - softmax_layer = SoftmaxLayer(tops=Array(Symbol, length(inputs)), bottoms=Array(Symbol, length(inputs))) - + softmax_layer = SoftmaxLayer(tops=Array(Symbol, length(inputs)), bottoms=Array(Symbol, length(inputs)), dim=layer.dim) softmax = setup(backend, softmax_layer, Blob[inputs[1]], Blob[]) logistic_layer = MultinomialLogisticLossLayer(bottoms=Array(Symbol, 2), - weights=layer.weights, normalize=layer.normalize) + weights=layer.weights, normalize=layer.normalize, dim=layer.dim) logistic = setup(backend, logistic_layer, inputs, Blob[]) state = SoftmaxLossLayerState(layer, convert(data_type, 0), softmax, logistic) @@ -49,31 +49,36 @@ end function backward(backend::CPUBackend, state::SoftmaxLossLayerState, inputs::Vector{Blob}, diffs::Vector{Blob}) diff = diffs[1] if isa(diff, CPUBlob) - width, height, channels, num = get_whcn(diff) + dims = size(diff) - idx_width = reshape(1:width, (width, 1, 1, 1)) - idx_height = reshape(1:height, (1, height, 1, 1)) - idx_chann = int(inputs[2].data)+1 - idx_num = reshape(1:num, (1, 1, 1, num)) + idx_all = map(1:length(dims)) do i + if i == state.logistic.op_dim + int(inputs[2].data) + 1 + else + dim = dims[i] + reshape(1:dim, [j == i? dim : 1 for j = 1:length(dims)]...) + end + end if isa(state.logistic.weights_blob, NullBlob) copy!(diff, state.softmax.blobs[1]) else - idx_num_dumb = reshape([1],1,1,1,1) - copy!(diff, reshape(state.softmax.blobs[1].data, (width,height,channels,num)) .* - broadcast_getindex(state.logistic.weights_blob.data, idx_width, idx_height, idx_chann, idx_num_dumb)) + idx_num_dumb = reshape([1], ones(Int, length(dims))...) + copy!(diff, state.softmax.blobs[1].data .* + broadcast_getindex(state.logistic.weights_blob.data, idx_all[1:end-1]..., idx_num_dumb)) end - index = (idx_width,idx_height,idx_chann,idx_num) - diff_data = reshape(diff.data, (width,height,channels,num)) + diff_data = reshape(diff.data, dims) if isa(state.logistic.weights_blob, NullBlob) - broadcast_setindex!(diff_data, broadcast_getindex(diff_data, index...)-1, index...) + broadcast_setindex!(diff_data, broadcast_getindex(diff_data, idx_all...)-1, idx_all...) else - broadcast_setindex!(diff_data, broadcast_getindex(diff_data, index...) .- - broadcast_getindex(state.logistic.weights_blob.data, idx_width,idx_height,idx_chann,idx_num_dumb), - index...) 
+ # NOTE: here we rely on the fact that op_dim == length(dims)-1, this requirement + # is enforced in MultinomialLogisticLossLayer when weights are provided + broadcast_setindex!(diff_data, broadcast_getindex(diff_data, idx_all...) .- + broadcast_getindex(state.logistic.weights_blob.data, idx_all[1:end-1]...,idx_num_dumb), + idx_all...) end - Vec.mul_scal!(diff.data, 1/(width*height*num)) + Vec.mul_scal!(diff.data, dims[state.logistic.op_dim]/prod(dims)) end end diff --git a/src/layers/softmax.jl b/src/layers/softmax.jl index 9860030..f003f2f 100644 --- a/src/layers/softmax.jl +++ b/src/layers/softmax.jl @@ -3,27 +3,37 @@ ############################################################ @defstruct SoftmaxLayer Layer ( name :: String = "softmax", + (dim :: Int = -2, dim != 0), (tops :: Vector{Symbol} = Symbol[], length(tops) > 0), (bottoms :: Vector{Symbol} = Symbol[], length(bottoms) == length(tops)) ) type SoftmaxLayerState <: LayerState - layer :: SoftmaxLayer - blobs :: Vector{Blob} + layer :: SoftmaxLayer + blobs :: Vector{Blob} - etc :: Any + dims :: Vector{Int} + etc :: Any end -function setup_etc(backend::CPUBackend, layer::SoftmaxLayer, data_type, inputs) +function setup_etc(backend::CPUBackend, layer::SoftmaxLayer, dims::Vector{Int}, data_type, inputs) nothing end function setup(backend::Backend, layer::SoftmaxLayer, inputs::Vector{Blob}, diffs::Vector{Blob}) data_type = eltype(inputs[1]) blobs = Blob[make_blob(backend, data_type, size(input)) for input in inputs] - etc = setup_etc(backend, layer, data_type, inputs) + dims = map(inputs) do input + total_dim = ndims(input) + dim = layer.dim < 0 ? layer.dim + total_dim + 1 : layer.dim + @assert 1 <= dim <= total_dim + @assert dim != total_dim # should not operate on the mini-batch dimension + dim + end + + etc = setup_etc(backend, layer, dims, data_type, inputs) - state = SoftmaxLayerState(layer, blobs, etc) + state = SoftmaxLayerState(layer, blobs, dims, etc) return state end function shutdown(backend::CPUBackend, state::SoftmaxLayerState) @@ -31,31 +41,30 @@ function shutdown(backend::CPUBackend, state::SoftmaxLayerState) end function forward(backend::CPUBackend, state::SoftmaxLayerState, inputs::Vector{Blob}) - for i = 1:length(inputs) - input = inputs[i].data - output = state.blobs[i].data + for ii = 1:length(inputs) + input = inputs[ii].data + output = state.blobs[ii].data + op_dim = state.dims[ii] + + dim_pre, dim_prob, dim_post = split_dims(input, op_dim) - width, height, channels, num = get_whcn(input) - input = reshape(input, (width,height,channels,num)) - output = reshape(output, (width,height,channels,num)) + for i = 0:dim_pre-1 + for j = 0:dim_post-1 + idx = Int[i + dim_pre*(k + dim_prob*j) for k=0:dim_prob-1] + 1 - for w = 1:width - for h = 1:height - for n = 1:num - maxval = -Inf - for c = 1:channels - @inbounds maxval = max(maxval, input[w,h,c,n]) - end - for c = 1:channels - @inbounds output[w,h,c,n] = exp(input[w,h,c,n]-maxval) - end - the_sum = 0.0 - for c = 1:channels - @inbounds the_sum += output[w,h,c,n] - end - for c = 1:channels - @inbounds output[w,h,c,n] /= the_sum - end + maxval = -Inf + for k in idx + @inbounds maxval = max(maxval, input[k]) + end + for k in idx + @inbounds output[k] = exp(input[k]-maxval) + end + the_sum = 0.0 + for k in idx + @inbounds the_sum += output[k] + end + for k in idx + @inbounds output[k] /= the_sum end end end diff --git a/src/utils/tensor.jl b/src/utils/tensor.jl new file mode 100644 index 0000000..abbe1d6 --- /dev/null +++ b/src/utils/tensor.jl @@ -0,0 +1,12 @@ +export split_dims 
+ +# Split the dimension of a ND-tensor into 3 parts: +# (dim_pre, dim_mid, dim_post) +function split_dims{T}(tensor::T, dim::Int) + dims = size(tensor) + dim_pre ::Int = prod(dims[1:dim-1]) + dim_mid ::Int = dims[dim] + dim_post ::Int = prod(dims[dim+1:end]) + + (dim_pre, dim_mid, dim_post) +end diff --git a/test/layers/accuracy.jl b/test/layers/accuracy.jl index 4b9f9ff..487a99c 100644 --- a/test/layers/accuracy.jl +++ b/test/layers/accuracy.jl @@ -1,58 +1,65 @@ -function test_accuracy_layer(backend::Backend, T) +function test_accuracy_layer(backend::Backend, tensor_dim, T) println("-- Testing AccuracyLayer on $(typeof(backend)){$T}...") - tensor_dim = abs(rand(Int)) % 4 + 2 - dims = tuple((abs(rand(Int,tensor_dim)) % 6 + 6)...) - println(" > $dims") + dims = abs(rand(Int,tensor_dim)) % 6 + 6 + op_dim = max(abs(rand(Int)) % tensor_dim, 1) + dims_label = copy(dims); dims_label[op_dim] = 1 + dims = tuple(dims...) + dims_label = tuple(dims_label...) + println(" > $dims (operate on dimension $op_dim)") eps = 1e-5 input = rand(T, dims) input_blob = make_blob(backend, input) - width, height, channels, num = get_whcn(input) - - label = abs(rand(Int, (width, height, 1, num))) % channels + label = abs(rand(Int, dims_label)) % dims[op_dim] label = convert(Array{T}, label) label_blob = make_blob(backend, label) inputs = Blob[input_blob, label_blob] - layer = AccuracyLayer(bottoms=[:pred, :labels]) + layer = AccuracyLayer(bottoms=[:pred, :labels], dim=op_dim) state = setup(backend, layer, inputs, Blob[]) println(" > Forward") forward(backend, state, inputs) - @test state.n_accum == width*height*num + @test state.n_accum == prod(dims_label) + + dim_pre, dim_pred, dim_post = split_dims(input, op_dim) - canonical_input = reshape(input, (width,height,channels,num)) + canonical_input = reshape(input, (dim_pre, dim_pred, dim_post)) + canonical_label = reshape(label, (dim_pre, 1, dim_post)) expected_acc = 0.0 - for n = 1:num - for w = 1:width - for h = 1:height - pred = canonical_input[w, h, :, n] - if indmax(pred) == convert(Int, label[w,h,1,n])+1 - expected_acc += 1 - end + for i = 1:dim_pre + for j = 1:dim_post + pred = canonical_input[i,:,j] + if indmax(pred) == int(canonical_label[i,1,j])+1 + expected_acc += 1 end end end - expected_acc /= (width*height*num) + expected_acc /= prod(dims_label) @test abs(state.accuracy - expected_acc) < eps println(" > Forward Again") forward(backend, state, inputs) - @test state.n_accum == 2*width*height*num + @test state.n_accum == 2*prod(dims_label) @test abs(state.accuracy - expected_acc) < eps println(" > Forward Again and Again") reset_statistics(state) forward(backend, state, inputs) - @test state.n_accum == width*height*num + @test state.n_accum == prod(dims_label) @test abs(state.accuracy - expected_acc) < eps shutdown(backend, state) end +function test_accuracy_layer(backend::Backend, T) + for i = 2:5 + test_accuracy_layer(backend, i, T) + end +end function test_accuracy_layer(backend::Backend) test_accuracy_layer(backend, Float32) test_accuracy_layer(backend, Float64) diff --git a/test/layers/argmax.jl b/test/layers/argmax.jl index 4ca3ca8..20acc4a 100644 --- a/test/layers/argmax.jl +++ b/test/layers/argmax.jl @@ -1,30 +1,32 @@ -function test_argmax_layer(backend::Backend, n_input, T, eps) +function test_argmax_layer(backend::Backend, n_input, tensor_dim, T, eps) println("-- Testing ArgmaxLayer on $(typeof(backend)){$T}...") - tensor_dim = abs(rand(Int)) % 4 + 2 println(" > $tensor_dim-dimensional tensor") dims = [abs(rand(Int, tensor_dim)) % 6 + 1 for i = 
1:n_input] - input = [rand(T, dims[i]...) for i = 1:n_input] - input_blob = Blob[make_blob(backend, x) for x in input] + op_dim = max(abs(rand(Int)) % tensor_dim, 1) + inputs = [rand(T, dims[i]...) for i = 1:n_input] + input_blob = Blob[make_blob(backend, x) for x in inputs] diff_blob = Blob[NullBlob() for i = 1:n_input] println(" > Setup") - layer = ArgmaxLayer(bottoms=Array(Symbol,n_input),tops=Array(Symbol,n_input)) + layer = ArgmaxLayer(bottoms=Array(Symbol,n_input),tops=Array(Symbol,n_input),dim=op_dim) state = setup(backend, layer, input_blob, diff_blob) println(" > Forward") forward(backend, state, input_blob) for i = 1:n_input - width,height,channels,num = get_whcn(input[i]) - got_output = zeros(T, width, height, 1, num) - canonical_input = reshape(input[i], (width,height,channels,num)) - expected_output = similar(got_output) - for n = 1:num - for w = 1:width - for h = 1:height - expected_output[w,h,1,n] = indmax(canonical_input[w,h,:,n])-1 - end + outdim = [size(inputs[i])...] + outdim[op_dim] = 1 + got_output = zeros(T, outdim...) + expected_output = zeros(T, outdim...) + + pre_dim, mid_dim, post_dim = split_dims(inputs[i], op_dim) + input = reshape(inputs[i], pre_dim, mid_dim, post_dim) + output = reshape(expected_output, pre_dim, 1, post_dim) + for x = 1:pre_dim + for z = 1:post_dim + output[x,1,z] = indmax(input[x,:,z])-1 end end @@ -34,6 +36,11 @@ function test_argmax_layer(backend::Backend, n_input, T, eps) shutdown(backend, state) end +function test_argmax_layer(backend::Backend, n_input, T, eps) + for i = 2:6 + test_argmax_layer(backend, n_input, i, T, eps) + end +end function test_argmax_layer(backend::Backend) test_argmax_layer(backend, 3, Float64, 1e-10) test_argmax_layer(backend, 3, Float32, 1e-10) diff --git a/test/layers/channel-pooling.jl b/test/layers/channel-pooling.jl index e8cc30a..c37720d 100644 --- a/test/layers/channel-pooling.jl +++ b/test/layers/channel-pooling.jl @@ -1,14 +1,16 @@ -function test_channel_pooling_layer(backend::Backend, pooling::PoolingFunction, n_input, T, eps) +function test_channel_pooling_layer(backend::Backend, pooling::PoolingFunction, tensor_dim::Int, n_input, T, eps) println("-- Testing ChannelPooling($(typeof(pooling))) on $(typeof(backend)){$T}...") - println(" > Setup") - dims = [abs(rand(Int, 4)) % 7 + 7 for i = 1:n_input] + dims = [abs(rand(Int, tensor_dim)) % 7 + 7 for i = 1:n_input] + op_dim = max(abs(rand(Int)) % tensor_dim, 1) pad = (2,2) kernel = 3 stride = 2 + println(" > Setup (pool along dimension $op_dim for $tensor_dim-D tensors)") + layer = ChannelPoolingLayer(kernel=kernel, stride=stride, pad=pad, pooling=pooling, - tops=Array(Symbol,n_input), bottoms=Array(Symbol,n_input)) + tops=Array(Symbol,n_input), bottoms=Array(Symbol,n_input), dim=op_dim) input = [rand(T, dim...) 
for dim in dims] inputs = Blob[make_blob(backend, x) for x in input] @@ -21,7 +23,7 @@ function test_channel_pooling_layer(backend::Backend, pooling::PoolingFunction, payloads = Array(Any, n_input) for i = 1:n_input - expected_output, payloads[i] = channel_pooling_forward(state, i, input[i]) + expected_output, payloads[i] = channel_pooling_forward(state, i, input[i], op_dim) got_output = to_array(state.blobs[i]) @test all(-eps .< expected_output-got_output .< eps) end @@ -34,7 +36,7 @@ function test_channel_pooling_layer(backend::Backend, pooling::PoolingFunction, backward(backend, state, inputs, diffs) for i = 1:n_input - expected_output = channel_pooling_backward(state, i, input[i], top_diff[i], payloads[i]) + expected_output = channel_pooling_backward(state, i, input[i], top_diff[i], payloads[i], op_dim) got_output = to_array(diffs[i]) @test all(-eps .< expected_output - got_output .< eps) end @@ -42,28 +44,34 @@ function test_channel_pooling_layer(backend::Backend, pooling::PoolingFunction, shutdown(backend, state) end -function channel_pooling_forward(state, i, input::Array) - width, height, channels, num = size(input) - pooled_chann = get_chann(state.blobs[i]) +function channel_pooling_forward(state, i, input::Array, op_dim) + dim_pre, dim_pool, dim_post = split_dims(input, op_dim) + dim_pooled = size(state.blobs[i], op_dim) - output = zeros(eltype(input), width, height, pooled_chann, num) + output = zeros(eltype(input), size(state.blobs[i])) if isa(state.layer.pooling, Pooling.Max) mask = similar(output, Int) end - for n = 1:num - for pc = 1:pooled_chann + canonical_input = reshape(input, dim_pre, dim_pool, dim_post) + canonical_output = reshape(output, dim_pre, dim_pooled, dim_post) + if isa(state.layer.pooling, Pooling.Max) + canonical_mask = reshape(mask, dim_pre, dim_pooled, dim_post) + end + + for n = 1:dim_post + for pc = 1:dim_pooled cstart = (pc-1)*state.layer.stride - state.layer.pad[1] + 1 - cend = min(cstart + state.layer.kernel - 1, channels) + cend = min(cstart + state.layer.kernel - 1, size(input, op_dim)) cstart = max(1, cstart) - region = input[:,:,cstart:cend, n] + region = canonical_input[:,cstart:cend,n] if isa(state.layer.pooling, Pooling.Max) - maxval, maxidx = findmax(region, 3) - output[:,:,pc,n] = maxval - mask[:,:,pc,n] = maxidx + maxval, maxidx = findmax(region, 2) + canonical_output[:,pc,n] = maxval + canonical_mask[:,pc,n] = maxidx elseif isa(state.layer.pooling, Pooling.Mean) - output[:,:,pc,n] = sum(region, 3) / state.layer.kernel + canonical_output[:,pc,n] = sum(region, 2) / state.layer.kernel else error("Unknown pooling $(state.layer.pooling)") end @@ -77,24 +85,30 @@ function channel_pooling_forward(state, i, input::Array) end end -function channel_pooling_backward(state, i, input::Array, diff::Array, payload::Any) - width, height, channels, num = size(input) - pooled_chann = get_chann(state.blobs[i]) +function channel_pooling_backward(state, i, input::Array, diff::Array, payload::Any, op_dim) + dim_pre, dim_pool, dim_post = split_dims(input, op_dim) + dim_pooled = size(state.blobs[i], op_dim) - gradient = zeros(eltype(input), width, height, channels, num) - for n = 1:num - for pc = 1:pooled_chann + gradient = zeros(eltype(input), size(input)) + canonical_input = reshape(input, dim_pre, dim_pool, dim_post) + canonical_gradient = reshape(gradient, dim_pre, dim_pool, dim_post) + canonical_diff = reshape(diff, dim_pre, dim_pooled, dim_post) + if isa(state.layer.pooling, Pooling.Max) + canonical_mask = reshape(payload, dim_pre, dim_pooled, dim_post) + 
end + for n = 1:dim_post + for pc = 1:dim_pooled cstart = (pc-1)*state.layer.stride - state.layer.pad[1] + 1 - cend = min(cstart + state.layer.kernel - 1, channels) + cend = min(cstart + state.layer.kernel - 1, size(input, op_dim)) cstart = max(1, cstart) if isa(state.layer.pooling, Pooling.Max) - region = sub(gradient,1:width,1:height,cstart:cend,n) - maxidx = payload[:,:,pc,n] - region[vec(maxidx)] += vec(diff[:,:,pc,n]) + region = sub(canonical_gradient,1:dim_pre,cstart:cend,n) + maxidx = canonical_mask[:,pc,n] + region[vec(maxidx)] += vec(canonical_diff[:,pc,n]) elseif isa(state.layer.pooling, Pooling.Mean) for c = cstart:cend - gradient[:,:,c,n] += diff[:,:,pc,n] / state.layer.kernel + canonical_gradient[:,c,n] += canonical_diff[:,pc,n] / state.layer.kernel end else error("Unknown pooling $(state.layer.pooling)") @@ -104,6 +118,11 @@ function channel_pooling_backward(state, i, input::Array, diff::Array, payload:: return gradient end +function test_channel_pooling_layer(backend::Backend, pooling::PoolingFunction, n_input, T, eps) + for i = 2:6 + test_channel_pooling_layer(backend, pooling, i, n_input, T, eps) + end +end function test_channel_pooling_layer(backend::Backend, n_input, T, eps) test_channel_pooling_layer(backend, Pooling.Max(), n_input, T, eps) test_channel_pooling_layer(backend, Pooling.Mean(), n_input, T, eps) diff --git a/test/layers/multinomial-logistic-loss.jl b/test/layers/multinomial-logistic-loss.jl index 2c53732..15d9e7d 100644 --- a/test/layers/multinomial-logistic-loss.jl +++ b/test/layers/multinomial-logistic-loss.jl @@ -1,15 +1,21 @@ -function test_multinomial_logistic_loss_layer(backend::Backend, class_weights, T, eps) +function test_multinomial_logistic_loss_layer(backend::Backend, tensor_dim, class_weights, T, eps) println("-- Testing MultinomialLogisticLossLayer{$(class_weights[1]),$(class_weights[2])} on $(typeof(backend)){$T}...") - tensor_dim = abs(rand(Int)) % 4 + 2 - dims = tuple((abs(rand(Int,tensor_dim)) % 6 + 6)...) - println(" > $dims") + dims = abs(rand(Int,tensor_dim)) % 6 + 6 + if class_weights[1] != :no + op_dim = tensor_dim-1 + else + op_dim = max(abs(rand(Int)) % tensor_dim, 1) + end + println(" > $dims (operate on dimension $op_dim)") - prob = abs(rand(T, dims)) + dims_label = copy(dims); dims_label[op_dim] = 1; dims_label = tuple(dims_label...) + dims = tuple(dims...) 
+ channels = dims[op_dim] - width, height, channels, num = get_whcn(prob) + prob = abs(rand(T, dims)) + 0.01 - label = abs(rand(Int, (width, height, 1, num))) % channels + label = abs(rand(Int, dims_label)) % channels label = convert(Array{T}, label) prob_blob = make_blob(backend, prob) @@ -21,57 +27,71 @@ function test_multinomial_logistic_loss_layer(backend::Backend, class_weights, T elseif class_weights[1] == :local weights = rand(T, channels) elseif class_weights[1] == :global - weights = rand(T, width, height, channels) + weights = round(1000*rand(T, dims[1:end-1]))/1000 else @assert class_weights[1] == :no weights = [] end layer = MultinomialLogisticLossLayer(bottoms=[:pred, :labels], - weights=weights, normalize=class_weights[2]) + weights=weights, normalize=class_weights[2], dim=op_dim) state = setup(backend, layer, inputs, Blob[]) forward(backend, state, inputs) if class_weights[1] == :local || class_weights[1] == :equal - weights = repeat(reshape(weights,1,1,channels), inner=[width,height,1]) + new_shape = ones(Int, tensor_dim-1); new_shape[op_dim] = dims[op_dim] + rep_shape = [dims[1:end-1]...]; rep_shape[op_dim] = 1 + weights = repeat(reshape(weights, new_shape...), inner=rep_shape) end if class_weights[2] == :local - weights = weights .* (channels ./ sum(weights,3)) + weights = weights .* (channels ./ sum(weights,op_dim)) elseif class_weights[2] == :global - weights = weights * (width*height*channels / sum(weights)) + weights = weights * (prod(dims[1:end-1]) / sum(weights)) else @assert class_weights[2] == :no end + new_shape = [size(weights)..., 1] + rep_shape = ones(Int, tensor_dim); rep_shape[end] = dims[end] + weights = repeat(reshape(weights, new_shape...), inner=rep_shape) + expected_loss = convert(T, 0) - prob = reshape(prob, (width,height,channels,num)) - for w = 1:width - for h = 1:height - for n = 1:num - if isempty(weights) - expected_loss += -log(prob[w, h, int(label[w, h, 1, n])+1, n]) - else - idx = int(label[w,h,1,n])+1 - expected_loss += -log(prob[w,h,idx,n]) * weights[w,h,idx] - end + dim_pre, dim_prob, dim_post = split_dims(prob, op_dim) + prob = reshape(prob, (dim_pre, dim_prob, dim_post)) + if !isempty(weights) + weights = reshape(weights, (dim_pre, dim_prob, dim_post)) + end + label = reshape(label, (dim_pre, 1, dim_post)) + for i = 1:dim_pre + for j = 1:dim_post + if isempty(weights) + expected_loss += -log(prob[i, int(label[i,1,j])+1, j]) + else + idx = int(label[i,1,j])+1 + expected_loss += -log(prob[i,idx,j]) * weights[i,idx,j] end end end - expected_loss /= (width*height*num) + expected_loss /= prod(dims) / dims[op_dim] @test -eps < state.loss - expected_loss < eps shutdown(backend, state) end +function test_multinomial_logistic_loss_layer(backend::Backend, class_weights, T, eps) + for i = 2:6 + test_multinomial_logistic_loss_layer(backend, i, class_weights, T, eps) + end +end function test_multinomial_logistic_loss_layer(backend::Backend, T, eps) for class_weights in ((:equal,:local),(:local,:local),(:global,:global),(:global,:local),(:no,:no)) test_multinomial_logistic_loss_layer(backend, class_weights, T, eps) end end function test_multinomial_logistic_loss_layer(backend::Backend) - test_multinomial_logistic_loss_layer(backend, Float32, 1e-3) test_multinomial_logistic_loss_layer(backend, Float64, 1e-5) + test_multinomial_logistic_loss_layer(backend, Float32, 1e-2) end if test_cpu diff --git a/test/layers/softmax-loss.jl b/test/layers/softmax-loss.jl index dcaea35..c1d06a0 100644 --- a/test/layers/softmax-loss.jl +++ b/test/layers/softmax-loss.jl @@ 
-1,59 +1,75 @@ -function test_softmax_loss_layer(backend::Backend, use_weights::Bool, T, eps) +function test_softmax_loss_layer(backend::Backend, tensor_dim, use_weights::Bool, T, eps) println("-- Testing SoftmaxLossLayer on $(typeof(backend)){$T} $(use_weights ? "(with weights)" : "")...") - tensor_dim = abs(rand(Int)) % 4 + 2 - dims = tuple((abs(rand(Int,tensor_dim)) % 6 + 6)...) - println(" > $dims") + if use_weights + op_dim = tensor_dim-1 + else + op_dim = max(abs(rand(Int)) % tensor_dim, 1) + end + dims = abs(rand(Int,tensor_dim)) % 6 + 6 + dims_label = copy(dims); dims_label[op_dim] = 1 + dims = tuple(dims...) + dims_label = tuple(dims_label...) + println(" > $dims (operate on dimension $op_dim)") - input = rand(T, dims) - width, height, channels, num = get_whcn(input) + input = rand(T, dims) + 0.01 input_blob = make_blob(backend, input) diff_blob = make_blob(backend, T, size(input)) if use_weights - weights = rand(T, width, height, channels) + 0.1 - weights = weights .* (channels ./ sum(weights,3)) + weights = rand(T, dims[1:end-1]) + 0.1 + weights = weights .* (dims[op_dim] ./ sum(weights,op_dim)) else weights = [] end - label = abs(rand(Int, (width, height, 1, num))) % channels + label = abs(rand(Int, dims_label)) % dims[op_dim] label = convert(Array{T}, label) label_blob = make_blob(backend, label) inputs = Blob[input_blob, label_blob] - layer = SoftmaxLossLayer(bottoms=[:pred, :labels], weights=weights, normalize=:local) + layer = SoftmaxLossLayer(bottoms=[:pred, :labels], weights=weights, normalize=:local, dim=op_dim) state = setup(backend, layer, inputs, Blob[diff_blob]) println(" > Forward") forward(backend, state, inputs) + new_shape = [size(weights)..., 1] + rep_shape = ones(Int, tensor_dim); rep_shape[end] = dims[end] + weights = repeat(reshape(weights, new_shape...), inner=rep_shape) + expected_loss = 0 expected_grad = zeros(T, size(input)) - canonical_input = reshape(input, (width,height,channels,num)) - canonical_grad = reshape(expected_grad, (width,height,channels,num)) - for w = 1:width - for h = 1:height - for n = 1:num - pred = exp(canonical_input[w, h, :, n]) - pred /= sum(pred) - if isempty(weights) - canonical_grad[w, h, :, n] = pred - canonical_grad[w, h, int(label[w,h,1,n])+1, n] -= 1 - expected_loss += -log(pred[int(label[w,h,1,n])+1]) - else - y = int(label[w,h,1,n])+1 - canonical_grad[w, h, :, n] = pred .* weights[w,h,y] - canonical_grad[w, h, y, n] -= weights[w,h,y] - expected_loss += -log(pred[y]) * weights[w,h,y] - end + + dim_pre, dim_prob, dim_post = split_dims(input, op_dim) + canonical_input = reshape(input, dim_pre, dim_prob, dim_post) + canonical_grad = reshape(expected_grad, dim_pre, dim_prob, dim_post) + if !isempty(weights) + weights = reshape(weights, (dim_pre, dim_prob, dim_post)) + end + label = reshape(label, dim_pre, 1, dim_post) + for i = 1:dim_pre + for j = 1:dim_post + pred = exp(canonical_input[i,:,j]) + pred /= sum(pred) + if isempty(weights) + canonical_grad[i,:,j] = pred + canonical_grad[i,int(label[i,1,j])+1,j] -= 1 + expected_loss += -log(pred[int(label[i,1,j])+1]) + else + y = int(label[i,1,j])+1 + canonical_grad[i,:,j] = pred .* weights[i,y,j] + canonical_grad[i,y,j] -= weights[i,y,j] + expected_loss += -log(pred[y]) * weights[i,y,j] end end end - expected_loss /= (width*height*num) - expected_grad /= (width*height*num) + scale = dims[op_dim] / prod(dims) + expected_loss *= scale + expected_grad *= scale + expected_grad = reshape(expected_grad, size(input)) @test -eps < state.loss - expected_loss < eps @@ -66,6 +82,11 @@ function 
test_softmax_loss_layer(backend::Backend, use_weights::Bool, T, eps) shutdown(backend, state) end +function test_softmax_loss_layer(backend::Backend, use_weights::Bool, T, eps) + for i = 2:5 + test_softmax_loss_layer(backend, i, use_weights, T, eps) + end +end function test_softmax_loss_layer(backend::Backend, T, eps) test_softmax_loss_layer(backend, false, T, eps) test_softmax_loss_layer(backend, true, T, eps) diff --git a/test/layers/softmax.jl b/test/layers/softmax.jl index 7c4d011..be35b93 100644 --- a/test/layers/softmax.jl +++ b/test/layers/softmax.jl @@ -1,35 +1,39 @@ -function test_softmax_layer(backend::Backend, n_input, T, eps) +function test_softmax_layer(backend::Backend, tensor_dim, n_input, T, eps) println("-- Testing SoftmaxLayer on $(typeof(backend)){$T}...") - tensor_dim = abs(rand(Int)) % 4 + 2 - println(" > $tensor_dim-dimensional tensor") + norm_dim = max(1, abs(rand(Int)) % tensor_dim) + println(" > $tensor_dim-dimensional input, normalize along dimension $norm_dim") dims = [abs(rand(Int,tensor_dim)) % 6 + 6 for i = 1:n_input] input = [rand(T, dims[i]...) for i = 1:n_input] input_blob = Blob[make_blob(backend, x) for x in input] diff_blob = Blob[NullBlob() for i = 1:n_input] - layer = SoftmaxLayer(tops=Array(Symbol,n_input), bottoms=Array(Symbol,n_input)) + layer = SoftmaxLayer(tops=Array(Symbol,n_input), bottoms=Array(Symbol,n_input), + dim=norm_dim-tensor_dim-1) state = setup(backend, layer, input_blob, diff_blob) forward(backend, state, input_blob) for i = 1:n_input - width, height, channels, num = get_whcn(input[i]) - canonical_input = reshape(input[i], (width, height, channels, num)) + my_dims = size(input[i]) + dim_pre = prod(my_dims[1:norm_dim-1]) + dim_prob = my_dims[norm_dim] + dim_post = prod(my_dims[norm_dim+1:end]) + + canonical_input = reshape(input[i], (dim_pre, dim_prob, dim_post)) output = similar(canonical_input) - for w = 1:width - for h = 1:height - for n = 1:num - preds = canonical_input[w, h, :, n] - preds -= maximum(preds) - preds = exp(preds) - preds /= sum(preds) - output[w, h, :, n] = preds - end + for x = 1:dim_pre + for y = 1:dim_post + preds = canonical_input[x,:,y] + preds -= maximum(preds) + preds = exp(preds) + preds /= sum(preds) + output[x,:,y] = preds end end + output = reshape(output, my_dims) got_output = zeros(T, size(output)) copy!(got_output, state.blobs[i]) @@ -39,6 +43,11 @@ function test_softmax_layer(backend::Backend, n_input, T, eps) shutdown(backend, state) end +function test_softmax_layer(backend::Backend, n_input, T, eps) + for td = 2:6 + test_softmax_layer(backend, td, n_input, T, eps) + end +end function test_softmax_layer(backend::Backend) test_softmax_layer(backend, 3, Float32, 1e-5) test_softmax_layer(backend, 3, Float64, 1e-10)
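The patch centers on the new `split_dims` helper exported from `src/utils/tensor.jl`: it collapses an N-D blob into a `(dim_pre, dim_mid, dim_post)` triple around the chosen operating dimension, so each kernel can address one column at a time regardless of tensor rank. Below is a minimal sketch of that contract, not part of the patch, written against the same Julia 0.3-era syntax used in the diff; the 4-D shape and the operating dimension `3` are arbitrary examples.

```julia
# Illustration only. Assumes split_dims as added in src/utils/tensor.jl.
x = rand(4, 5, 3, 2)                       # e.g. a (width, height, channels, num) blob

dim_pre, dim_mid, dim_post = split_dims(x, 3)
@assert (dim_pre, dim_mid, dim_post) == (4*5, 3, 2)

# Julia arrays are column-major, so the canonical 3-D view is just a reshape
# of the same storage; no data movement is needed.
canonical = reshape(x, dim_pre, dim_mid, dim_post)

# 0-based (i, k, j) maps to the 1-based linear index used in the new loops:
#   idx = i + dim_pre*(k + dim_mid*j) + 1
i, k, j = 1, 0, 1
@assert canonical[i+1, k+1, j+1] == x[i + dim_pre*(k + dim_mid*j) + 1]
```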
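Several CPU kernels in the patch (accuracy, argmax, softmax) now share the same loop shape: iterate over the collapsed pre/post dimensions and materialize the linear indices of one probability column at a time. The sketch below factors that pattern into a hypothetical helper named `for_each_column` (not part of Mocha), again using the patch's Julia 0.3-era vectorized syntax.

```julia
# Illustration only; for_each_column is a hypothetical name, not part of the patch.
function for_each_column(op::Function, tensor::Array, op_dim::Int)
  dim_pre, dim_prob, dim_post = split_dims(tensor, op_dim)
  for i = 0:dim_pre-1, j = 0:dim_post-1
    # 1-based linear indices of the column (i, :, j), matching the new kernels
    idx = Int[i + dim_pre*(k + dim_prob*j) for k = 0:dim_prob-1] + 1
    op(idx)
  end
end

# e.g. a softmax over dimension 3 of a 4-D blob, column by column:
x = rand(2, 3, 5, 4)
for_each_column(x, 3) do idx
  x[idx] = exp(x[idx] - maximum(x[idx]))
  x[idx] /= sum(x[idx])
end
```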
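At the layer level, every affected layer gains a `(dim :: Int = -2, dim != 0)` property; negative values count from the last dimension, so the default `-2` reproduces the old behaviour of operating over the channel dimension of a `(width, height, channels, num)` blob. A hypothetical configuration fragment (the blob symbols are placeholders, not from the patch):

```julia
# Illustration only; :pred, :labels and :prob are placeholder blob names.
softmax_layer  = SoftmaxLayer(tops=[:prob], bottoms=[:pred], dim=-2)     # default: channel dim
loss_layer     = SoftmaxLossLayer(bottoms=[:pred, :labels], dim=-2)
accuracy_layer = AccuracyLayer(bottoms=[:pred, :labels], dim=3)          # same thing, for a 4-D blob
```

Per the asserts added in `setup`, `dim` may not resolve to the last (mini-batch) dimension, and when class weights are supplied the loss layers further require `dim` to be the second-to-last dimension.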