Skip to content

Commit

Permalink
Add a different code path for catting contiguous tensors along the fi…
Browse files Browse the repository at this point in the history
…rst dimension, for speed reasons.

Fix a bug in cat when catting with an empty tensor along first dim (it added an extra dim).
Fix the ambiguous 'catting along last dimension' sentence in the doc and change the behavior to pick the maximum last dimension over all input tensors.
Now empty tensors are allowed.
  • Loading branch information
nkoumchatzky committed Dec 26, 2016
1 parent 7ca7ec9 commit d41580e
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 29 deletions.
12 changes: 6 additions & 6 deletions TensorMath.lua
Expand Up @@ -9,7 +9,7 @@ local argtypes = wrap.CInterface.argtypes
argtypes['ptrdiff_t'] = {

helpname = function(arg)
return 'ptrdiff_t'
return 'ptrdiff_t'
end,

declare = function(arg)
Expand All @@ -35,21 +35,21 @@ argtypes['ptrdiff_t'] = {
end
end
end,

carg = function(arg)
return string.format('arg%d', arg.i)
end,

creturn = function(arg)
return string.format('arg%d', arg.i)
end,

precall = function(arg)
if arg.returned then
return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
end
end,

postcall = function(arg)
if arg.creturned then
return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
Expand Down Expand Up @@ -738,11 +738,11 @@ wrap("topk",
{{name=Tensor, default=true, returned=true},
{name=Tensor},
{name=Tensor},
{name="index", default=lastdim(2)}},
{name="index", default=-1}},
cname("catArray"),
{{name=Tensor, default=true, returned=true},
{name=Tensor .. "Array"},
{name="index", default=lastdimarray(2)}})
{name="index", default=-1}})

if Tensor == 'ByteTensor' then -- we declare this only once
interface:print(
Expand Down
10 changes: 9 additions & 1 deletion doc/maths.md
Expand Up @@ -60,12 +60,14 @@ The advantage of second case is, same `res2` `Tensor` can be used successively i
<a name="torch.cat"></a>
`x = torch.cat(x_1, x_2, [dimension])` returns a `Tensor` `x` which is the concatenation of `Tensor`s `x_1` and `x_2` along dimension `dimension`.

If `dimension` is not specified it is the last dimension.
If `dimension` is not specified, or if it is `-1`, it defaults to the maximum last dimension over all input tensors; if all of the tensors are empty, it defaults to `1`.

The other dimensions of `x_1` and `x_2` have to be equal.

Also supports arrays with arbitrary numbers of `Tensor`s as inputs.

Empty tensors are ignored during concatenation and therefore do not raise an error. Concatenating only empty tensors always results in an empty tensor.

Examples:
```lua
> torch.cat(torch.ones(3), torch.zeros(2))
Expand Down Expand Up @@ -116,6 +118,12 @@ Examples:
0.2206 0.7449
[torch.DoubleTensor of size 7x2]

> torch.cat({torch.Tensor(), torch.rand(3, 2)}, 1)
0.3227 0.0493
0.9161 0.1086
0.2206 0.7449
[torch.DoubleTensor of size 3x2]

```


Expand Down
98 changes: 78 additions & 20 deletions lib/TH/generic/THTensorMath.c
Expand Up @@ -2035,53 +2035,111 @@ void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int
THLongStorage *size;
int i, j;
long offset;
int ndim = dimension + 1;
int maxDim = dimension + 1;
int allEmpty = 1;
int allContiguous = 1;
int ldimension = dimension;

for (i = 0; i < numInputs; i++)
{
ndim = THMax(ndim, inputs[i]->nDimension);
maxDim = THMax(maxDim, inputs[i]->nDimension);
}

// When the user input dimension is -1 (i.e. -2 in C)
// Then we pick the maximum last dimension across all tensors.
if ( dimension == -2 )
{
ldimension = maxDim?(maxDim-1):0;
}

THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs);
THArgCheck(dimension >= 0, 4, "invalid dimension %d", dimension + TH_INDEX_BASE);
THArgCheck(ldimension >= 0, 4, "invalid dimension %d", dimension + TH_INDEX_BASE);

size = THLongStorage_newWithSize(ndim);
for(i = 0; i < ndim; i++)
size = THLongStorage_newWithSize(maxDim);

for(i = 0; i < maxDim; i++)
{
long dimSize = i < inputs[0]->nDimension ? inputs[0]->size[i] : 1;
if (i == dimension)
// dimSize is either the size of the dim if it exists, either 1 if #dim > 0, otherwise 0
long dimSize = i < inputs[0]->nDimension ? inputs[0]->size[i] : THMin(inputs[0]->nDimension, 1);
if (i == ldimension)
{
for (j = 1; j < numInputs; j++)
{
dimSize += i < inputs[j]->nDimension ? inputs[j]->size[i] : 1;
// accumulate the size over the dimension we want to cat on.
// Empty tensors are allowed
dimSize += i < inputs[j]->nDimension ? inputs[j]->size[i] : THMin(inputs[j]->nDimension, 1);
if(inputs[j]->nDimension)
{
allContiguous = allContiguous && THTensor_(isContiguous)(inputs[j]);
}
}
}
else
{
for (j = 1; j < numInputs; j++)
{
if (dimSize != (i < inputs[j]->nDimension ? inputs[j]->size[i] : 1))
long sz = (i < inputs[j]->nDimension ? inputs[j]->size[i] : THMin(inputs[j]->nDimension, 1));
// If it's a dimension we're not catting on
// Then fail if sizes are different AND > 0
if (dimSize != sz && dimSize && sz)
{
THLongStorage_free(size);
THError("inconsistent tensor sizes");
}
else if(!dimSize)
{
dimSize = sz;
}
}
}
allEmpty = allEmpty && !dimSize;
size->data[i] = dimSize;
}

THTensor_(resize)(result, size, NULL);
THLongStorage_free(size);

offset = 0;
for (j = 0; j < numInputs; j++)
// Initiate catting and resizing
// If at least one of the input is not empty
if (!allEmpty)
{
long dimSize = dimension < inputs[j]->nDimension ? inputs[j]->size[dimension] : 1;
THTensor *nt = THTensor_(newWithTensor)(result);
THTensor_(narrow)(nt, NULL, dimension, offset, dimSize);
THTensor_(copy)(nt, inputs[j]);
THTensor_(free)(nt);
offset += dimSize;
THTensor_(resize)(result, size, NULL);

allContiguous = allContiguous && THTensor_(isContiguous)(result);

// First path is for contiguous inputs along dim 1
// Second path for non-contiguous
if (ldimension == 0 && allContiguous)
{
real* result_data = result->storage->data + result->storageOffset;
offset = 0;
for (j = 0; j < numInputs; j++)
{
if (inputs[j]->nDimension)
{
THTensor* input0 = inputs[j];
real* input0_data = input0->storage->data + input0->storageOffset;
long input0_size = THTensor_(nElement)(input0);
memcpy(result_data + offset, input0_data, input0_size*sizeof(real));
offset += input0_size;
}
}
}
else
{
offset = 0;
for (j = 0; j < numInputs; j++)
{
if (inputs[j]->nDimension)
{
long dimSize = ldimension < inputs[j]->nDimension ? inputs[j]->size[ldimension] : 1;
THTensor *nt = THTensor_(newWithTensor)(result);
THTensor_(narrow)(nt, NULL, ldimension, offset, dimSize);
THTensor_(copy)(nt, inputs[j]);
THTensor_(free)(nt);
offset += dimSize;
}
}
}
}
THLongStorage_free(size);
}

int THTensor_(equal)(THTensor *ta, THTensor* tb)
Expand Down
54 changes: 52 additions & 2 deletions test/test.lua
Expand Up @@ -1827,7 +1827,32 @@ function torchtest.cat()
local mxx = torch.Tensor()
torch.cat(mxx, x, y, dim)
mytester:assertTensorEq(mx, mxx, 0, 'torch.cat value')
end

local x = torch.rand(1,2,3)
local y = torch.Tensor()
local mx = torch.cat(x,y,dim)
mytester:asserteq(mx:size(1),1,'torch.cat size')
mytester:asserteq(mx:size(2),2,'torch.cat size')
mytester:asserteq(mx:size(3),3,'torch.cat size')
mytester:assertTensorEq(mx, x, 0, 'torch.cat value')

local x = torch.Tensor()
local y = torch.Tensor()
local mx = torch.cat(x,y,dim)
mytester:asserteq(mx:dim(),0,'torch.cat dim')
end
local x = torch.Tensor()
local y = torch.rand(1,2,3)
local mx = torch.cat(x,y)
mytester:asserteq(mx:size(1),1,'torch.cat size')
mytester:asserteq(mx:size(2),2,'torch.cat size')
mytester:asserteq(mx:size(3),3,'torch.cat size')
mytester:assertTensorEq(mx, y, 0, 'torch.cat value')

local x = torch.Tensor()
local y = torch.Tensor()
local mx = torch.cat(x,y)
mytester:asserteq(mx:dim(),0,'torch.cat dim')
end
function torchtest.catArray()
for dim = 1, 3 do
Expand All @@ -1849,7 +1874,32 @@ function torchtest.catArray()
mytester:assertTensorEq(mx, mxx, 0, 'torch.cat value')
torch.cat(mxx:double(), {x:double(), y:double(), z:double()}, dim)
mytester:assertTensorEq(mx, mxx, 0, 'torch.cat value')
end

local x = torch.rand(1,2,3)
local y = torch.Tensor()
local mx = torch.cat({x,y},dim)
mytester:asserteq(mx:size(1),1,'torch.cat size')
mytester:asserteq(mx:size(2),2,'torch.cat size')
mytester:asserteq(mx:size(3),3,'torch.cat size')
mytester:assertTensorEq(mx, x, 0, 'torch.cat value')

local x = torch.Tensor()
local y = torch.Tensor()
local mx = torch.cat({x,y},dim)
mytester:asserteq(mx:dim(),0,'torch.cat dim')
end
local x = torch.Tensor()
local y = torch.rand(1,2,3)
local mx = torch.cat({x,y})
mytester:asserteq(mx:size(1),1,'torch.cat size')
mytester:asserteq(mx:size(2),2,'torch.cat size')
mytester:asserteq(mx:size(3),3,'torch.cat size')
mytester:assertTensorEq(mx, y, 0, 'torch.cat value')

local x = torch.Tensor()
local y = torch.Tensor()
local mx = torch.cat({x,y})
mytester:asserteq(mx:dim(),0,'torch.cat dim')
end
function torchtest.sin_2()
local x = torch.rand(msize,msize,msize)
Expand Down

0 comments on commit d41580e

Please sign in to comment.