Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 31 additions & 14 deletions aten/src/TH/generic/THTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -779,24 +779,41 @@ void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int
}
allContiguous = allContiguous && THTensor_(isContiguous)(result);

// First path is for contiguous inputs along dim 0
// First path is for contiguous inputs
// Second path for non-contiguous
int64_t offset;
if (dimension == 0 && allContiguous) {
if (allContiguous) {
int64_t outer = 1, inner = 1;

// Outer is the product of dimensions from the left up to (but not
// including) the concatenation dimension. This becomes the number of times
// we have to replicate the memcpy call.
for (int i = 0; i < dimension; ++i) {
outer *= size[i];
}

// The product of dimensions to the right of the concatenation dimension.
// We go on to multiply this by the size of the concat dimension for
// each input tensor.
for (int i = dimension + 1; i < size.size(); ++i) {
inner *= size[i];
}

scalar_t* result_data = THStorage_(data)(THTensor_getStoragePtr(result)) + result->storage_offset();
offset = 0;
for (int j = 0; j < numInputs; j++) {
if (!should_skip(inputs[j])) {
THTensor* input0 = inputs[j];
scalar_t* input0_data = THStorage_(data)(THTensor_getStoragePtr(input0)) + input0->storage_offset();
int64_t input0_size = THTensor_(nElement)(input0);
// C standard says you can't pass nullptrs to memcpy, even if the size is 0; ubsan checks this.
if (input0_size != 0) {
memcpy(result_data + offset, input0_data, input0_size*sizeof(scalar_t));
}
offset += input0_size;
}
}
for (int o = 0; o < outer; ++o) {
for (int j = 0; j < numInputs; ++j) {
if (!should_skip(inputs[j])) {
THTensor* input0 = inputs[j];
scalar_t* input0_data = THStorage_(data)(THTensor_getStoragePtr(input0)) + input0->storage_offset();
int local_inner = inner * input0->size(dimension);
if (local_inner != 0) {
memcpy(result_data + offset, input0_data + o*local_inner, local_inner*sizeof(scalar_t));
} // local_inner != 0
offset += local_inner;
} // should_skip
} // for j
} // for o
} else {
offset = 0;
for (int j = 0; j < numInputs; j++) {
Expand Down