diff --git a/TensorMath.lua b/TensorMath.lua index 682de23a..5971a7b8 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -9,7 +9,7 @@ local argtypes = wrap.CInterface.argtypes argtypes['ptrdiff_t'] = { helpname = function(arg) - return 'ptrdiff_t' + return 'ptrdiff_t' end, declare = function(arg) @@ -35,7 +35,7 @@ argtypes['ptrdiff_t'] = { end end end, - + carg = function(arg) return string.format('arg%d', arg.i) end, @@ -43,13 +43,13 @@ argtypes['ptrdiff_t'] = { creturn = function(arg) return string.format('arg%d', arg.i) end, - + precall = function(arg) if arg.returned then return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i) end end, - + postcall = function(arg) if arg.creturned then return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i) @@ -738,11 +738,11 @@ wrap("topk", {{name=Tensor, default=true, returned=true}, {name=Tensor}, {name=Tensor}, - {name="index", default=lastdim(2)}}, + {name="index", default=-1}}, cname("catArray"), {{name=Tensor, default=true, returned=true}, {name=Tensor .. "Array"}, - {name="index", default=lastdimarray(2)}}) + {name="index", default=-1}}) if Tensor == 'ByteTensor' then -- we declare this only once interface:print( diff --git a/doc/maths.md b/doc/maths.md index 252b52dc..44e5ea6d 100755 --- a/doc/maths.md +++ b/doc/maths.md @@ -60,12 +60,14 @@ The advantage of second case is, same `res2` `Tensor` can be used successively i `x = torch.cat(x_1, x_2, [dimension])` returns a `Tensor` `x` which is the concatenation of `Tensor`s `x_1` and `x_2` along dimension `dimension`. -If `dimension` is not specified it is the last dimension. +If `dimension` is not specified or if it is `-1`, it is the maximum last dimension over all input tensors, except if all tensors are empty, then it is `1`. The other dimensions of `x_1` and `x_2` have to be equal. Also supports arrays with arbitrary numbers of `Tensor`s as inputs. +Empty tensors are ignored during catting, and thus do not throw an error. Performing cat on empty tensors only will always result in an empty tensor. + Examples: ```lua > torch.cat(torch.ones(3), torch.zeros(2)) @@ -116,6 +118,12 @@ Examples: 0.2206 0.7449 [torch.DoubleTensor of size 7x2] +> torch.cat({torch.Tensor(), torch.rand(3, 2)}, 1) + 0.3227 0.0493 + 0.9161 0.1086 + 0.2206 0.7449 +[torch.DoubleTensor of size 3x2] + ``` diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c index e04d3b6d..9fc15773 100644 --- a/lib/TH/generic/THTensorMath.c +++ b/lib/TH/generic/THTensorMath.c @@ -2035,53 +2035,111 @@ void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int THLongStorage *size; int i, j; long offset; - int ndim = dimension + 1; + int maxDim = dimension + 1; + int allEmpty = 1; + int allContiguous = 1; + int ldimension = dimension; + for (i = 0; i < numInputs; i++) { - ndim = THMax(ndim, inputs[i]->nDimension); + maxDim = THMax(maxDim, inputs[i]->nDimension); + } + + // When the user input dimension is -1 (i.e. -2 in C) + // Then we pick the maximum last dimension across all tensors. + if ( dimension == -2 ) + { + ldimension = maxDim?(maxDim-1):0; } THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs); - THArgCheck(dimension >= 0, 4, "invalid dimension %d", dimension + TH_INDEX_BASE); + THArgCheck(ldimension >= 0, 4, "invalid dimension %d", dimension + TH_INDEX_BASE); - size = THLongStorage_newWithSize(ndim); - for(i = 0; i < ndim; i++) + size = THLongStorage_newWithSize(maxDim); + + for(i = 0; i < maxDim; i++) { - long dimSize = i < inputs[0]->nDimension ? inputs[0]->size[i] : 1; - if (i == dimension) + // dimSize is either the size of the dim if it exists, either 1 if #dim > 0, otherwise 0 + long dimSize = i < inputs[0]->nDimension ? inputs[0]->size[i] : THMin(inputs[0]->nDimension, 1); + if (i == ldimension) { for (j = 1; j < numInputs; j++) { - dimSize += i < inputs[j]->nDimension ? inputs[j]->size[i] : 1; + // accumulate the size over the dimension we want to cat on. + // Empty tensors are allowed + dimSize += i < inputs[j]->nDimension ? inputs[j]->size[i] : THMin(inputs[j]->nDimension, 1); + if(inputs[j]->nDimension) + { + allContiguous = allContiguous && THTensor_(isContiguous)(inputs[j]); + } } } else { for (j = 1; j < numInputs; j++) { - if (dimSize != (i < inputs[j]->nDimension ? inputs[j]->size[i] : 1)) + long sz = (i < inputs[j]->nDimension ? inputs[j]->size[i] : THMin(inputs[j]->nDimension, 1)); + // If it's a dimension we're not catting on + // Then fail if sizes are different AND > 0 + if (dimSize != sz && dimSize && sz) { THLongStorage_free(size); THError("inconsistent tensor sizes"); } + else if(!dimSize) + { + dimSize = sz; + } } } + allEmpty = allEmpty && !dimSize; size->data[i] = dimSize; } - THTensor_(resize)(result, size, NULL); - THLongStorage_free(size); - - offset = 0; - for (j = 0; j < numInputs; j++) + // Initiate catting and resizing + // If at least one of the input is not empty + if (!allEmpty) { - long dimSize = dimension < inputs[j]->nDimension ? inputs[j]->size[dimension] : 1; - THTensor *nt = THTensor_(newWithTensor)(result); - THTensor_(narrow)(nt, NULL, dimension, offset, dimSize); - THTensor_(copy)(nt, inputs[j]); - THTensor_(free)(nt); - offset += dimSize; + THTensor_(resize)(result, size, NULL); + + allContiguous = allContiguous && THTensor_(isContiguous)(result); + + // First path is for contiguous inputs along dim 1 + // Second path for non-contiguous + if (ldimension == 0 && allContiguous) + { + real* result_data = result->storage->data + result->storageOffset; + offset = 0; + for (j = 0; j < numInputs; j++) + { + if (inputs[j]->nDimension) + { + THTensor* input0 = inputs[j]; + real* input0_data = input0->storage->data + input0->storageOffset; + long input0_size = THTensor_(nElement)(input0); + memcpy(result_data + offset, input0_data, input0_size*sizeof(real)); + offset += input0_size; + } + } + } + else + { + offset = 0; + for (j = 0; j < numInputs; j++) + { + if (inputs[j]->nDimension) + { + long dimSize = ldimension < inputs[j]->nDimension ? inputs[j]->size[ldimension] : 1; + THTensor *nt = THTensor_(newWithTensor)(result); + THTensor_(narrow)(nt, NULL, ldimension, offset, dimSize); + THTensor_(copy)(nt, inputs[j]); + THTensor_(free)(nt); + offset += dimSize; + } + } + } } + THLongStorage_free(size); } int THTensor_(equal)(THTensor *ta, THTensor* tb) diff --git a/test/test.lua b/test/test.lua index 3eb119f0..eb7cf0ae 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1827,7 +1827,32 @@ function torchtest.cat() local mxx = torch.Tensor() torch.cat(mxx, x, y, dim) mytester:assertTensorEq(mx, mxx, 0, 'torch.cat value') - end + + local x = torch.rand(1,2,3) + local y = torch.Tensor() + local mx = torch.cat(x,y,dim) + mytester:asserteq(mx:size(1),1,'torch.cat size') + mytester:asserteq(mx:size(2),2,'torch.cat size') + mytester:asserteq(mx:size(3),3,'torch.cat size') + mytester:assertTensorEq(mx, x, 0, 'torch.cat value') + + local x = torch.Tensor() + local y = torch.Tensor() + local mx = torch.cat(x,y,dim) + mytester:asserteq(mx:dim(),0,'torch.cat dim') + end + local x = torch.Tensor() + local y = torch.rand(1,2,3) + local mx = torch.cat(x,y) + mytester:asserteq(mx:size(1),1,'torch.cat size') + mytester:asserteq(mx:size(2),2,'torch.cat size') + mytester:asserteq(mx:size(3),3,'torch.cat size') + mytester:assertTensorEq(mx, y, 0, 'torch.cat value') + + local x = torch.Tensor() + local y = torch.Tensor() + local mx = torch.cat(x,y) + mytester:asserteq(mx:dim(),0,'torch.cat dim') end function torchtest.catArray() for dim = 1, 3 do @@ -1849,7 +1874,32 @@ function torchtest.catArray() mytester:assertTensorEq(mx, mxx, 0, 'torch.cat value') torch.cat(mxx:double(), {x:double(), y:double(), z:double()}, dim) mytester:assertTensorEq(mx, mxx, 0, 'torch.cat value') - end + + local x = torch.rand(1,2,3) + local y = torch.Tensor() + local mx = torch.cat({x,y},dim) + mytester:asserteq(mx:size(1),1,'torch.cat size') + mytester:asserteq(mx:size(2),2,'torch.cat size') + mytester:asserteq(mx:size(3),3,'torch.cat size') + mytester:assertTensorEq(mx, x, 0, 'torch.cat value') + + local x = torch.Tensor() + local y = torch.Tensor() + local mx = torch.cat({x,y},dim) + mytester:asserteq(mx:dim(),0,'torch.cat dim') + end + local x = torch.Tensor() + local y = torch.rand(1,2,3) + local mx = torch.cat({x,y}) + mytester:asserteq(mx:size(1),1,'torch.cat size') + mytester:asserteq(mx:size(2),2,'torch.cat size') + mytester:asserteq(mx:size(3),3,'torch.cat size') + mytester:assertTensorEq(mx, y, 0, 'torch.cat value') + + local x = torch.Tensor() + local y = torch.Tensor() + local mx = torch.cat({x,y}) + mytester:asserteq(mx:dim(),0,'torch.cat dim') end function torchtest.sin_2() local x = torch.rand(msize,msize,msize)