1 change: 1 addition & 0 deletions setup.py
@@ -787,6 +787,7 @@ def run(self):
"torch/csrc/jit/script/module.cpp",
"torch/csrc/jit/script/init.cpp",
"torch/csrc/jit/script/python_tree_views.cpp",
"torch/csrc/jit/batched/BatchTensor.cpp",
"torch/csrc/autograd/init.cpp",
"torch/csrc/autograd/aten_variable_hooks.cpp",
"torch/csrc/autograd/grad_mode.cpp",
17 changes: 16 additions & 1 deletion test/test_jit.py
@@ -24,7 +24,7 @@
from test_autograd import method_tests, create_input, unpack_variables, \
exclude_tensor_method, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
from copy import deepcopy

import random

from torch.jit.frontend import NotSupportedError

@@ -1079,6 +1079,21 @@ def test_fn(ten, mask):
self.assertEqual(test_fn(ten, mask), traced_test_fn(ten, mask))


class TestBatched(TestCase):
    # generate random examples and create a BatchTensor from them
def rand_batch(self, *dims):
dims = [dim for dim in dims if dim != ()]
xs = [torch.rand(1, *(random.randint(1, size) if b else size for b, size in dims[1:])) for i in range(dims[0])]
xb = torch.BatchTensor(xs, torch.tensor([b for b, d in dims[1:]]))
return xs, xb

def test_create_batchtensor(self):
xs, batch = self.rand_batch(4, (True, 3), (False, 2), (True, 5))
self.assertEqual(xs, batch.examples())
batch2 = torch.BatchTensor(batch.get_data(), batch.get_mask(), batch.get_dims())
self.assertEqual(xs, batch2.examples())


class TestScript(JitTestCase):

@contextmanager
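
The new TestBatched case above round-trips a list of variable-size examples through torch.BatchTensor. A minimal standalone sketch of that round trip, assuming a build of this branch that exposes the binding as torch.BatchTensor and assuming a uint8 dims tensor (the dtype is chosen to match the toByteData() accesses in the C++ below):

import torch

# Two examples with a leading singleton batch dim; dim 1 is dynamic
# (lengths 2 and 3), dim 2 is static (size 2).
xs = [torch.rand(1, 2, 2), torch.rand(1, 3, 2)]
dims = torch.tensor([1, 0], dtype=torch.uint8)  # 1 = dynamic, 0 = static
batch = torch.BatchTensor(xs, dims)

# examples() strips the padding again, recovering the original list.
ys = batch.examples()
assert all(torch.equal(x, y) for x, y in zip(xs, ys))
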
78 changes: 78 additions & 0 deletions torch/csrc/jit/batched/BatchTensor.cpp
@@ -0,0 +1,78 @@
#include "BatchTensor.h"

namespace torch { namespace jit {

BatchTensor::BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims){
if(data.dim() != mask.dim() || mask.dim() != dims.size(0) + 1){
throw std::runtime_error("malformed MaskedBatch with data.dim(): "
+ std::to_string(data.dim()) + ", mask.dim(): " + std::to_string(mask.dim())
+ ", dims.size(0): " + std::to_string(dims.size(0)));
}
this->data = data;
this->mask = mask;
this->dims = dims;
}

BatchTensor::BatchTensor(const std::vector<at::Tensor> datalist, at::Tensor dims) {
auto bs = datalist.size();
std::vector<int64_t> sizes(dims.size(0) + 1, 0), mask_sizes(dims.size(0) + 1, 0);
sizes[0] = bs;
mask_sizes[0] = bs;
for(int64_t i = 1; i < dims.size(0) + 1; i++){
for(auto x : datalist){
sizes[i] = std::max(sizes[i], x.size(i));
}
mask_sizes[i] = *dims[i - 1].toByteData() ? sizes[i] : 1;
}
data = datalist[0].type().tensor(sizes);
data.fill_(0);
mask = datalist[0].type().toScalarType(at::kByte).tensor(mask_sizes);
mask.fill_(0);
for(std::size_t i = 0; i < datalist.size(); i++){
auto data_item = data.narrow(0, i, 1);
auto mask_item = mask.narrow(0, i, 1);
for(int64_t j = 0; j < dims.size(0); j++){
if(*dims[j].toByteData()){
data_item = data_item.narrow(j + 1, 0, datalist[i].size(j + 1));
mask_item = mask_item.narrow(j + 1, 0, datalist[i].size(j + 1));
}
}
data_item += datalist[i];
mask_item.fill_(1);
}
  this->dims = dims;
}

std::vector<at::Tensor> BatchTensor::examples() {
  std::vector<at::Tensor> result;
// calculate number of valid entries in dth dimension of data
auto mask_sum = [](at::Tensor data, int d) -> int64_t{
data = data.sum(d, /*keepdim=*/true);
while(data.dim() >= 1)
data = data[0];
return *data.toLongData();
};
for(int64_t i = 0; i < data.size(0); i++){
auto data_tmp = data.narrow(0, i, 1);
for(int64_t d = 0; d < dims.size(0); d++){
if(*dims[d].toByteData()){
data_tmp = data_tmp.narrow(d + 1, 0, mask_sum(mask[i], d));
}
}
result.push_back(data_tmp);
}
return result;
}

void initBatchTensorBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
py::class_<BatchTensor>(m, "BatchTensor")
.def(py::init<at::Tensor, at::Tensor, at::Tensor>())
.def(py::init<std::vector<at::Tensor>, at::Tensor>())
.def("examples", &BatchTensor::examples)
.def("get_data", &BatchTensor::get_data)
.def("get_mask", &BatchTensor::get_mask)
.def("get_dims", &BatchTensor::get_dims);
}

}} // namespace torch::jit
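
For reference, here is a plain-Python sketch of the padding and unpadding that the constructor and examples() above implement, using only stock tensor ops. pad_batch and unpad_batch are illustrative helper names, not part of this PR:

import torch

def pad_batch(datalist, dims):
    # Mirrors the datalist constructor: zero-pad every dynamic dimension to
    # the largest example in the batch and record validity in a byte mask.
    bs = len(datalist)
    sizes = [bs] + [max(x.size(i + 1) for x in datalist) for i in range(len(dims))]
    mask_sizes = [bs] + [sizes[i + 1] if dims[i] else 1 for i in range(len(dims))]
    data = datalist[0].new_zeros(sizes)
    mask = torch.zeros(mask_sizes, dtype=torch.uint8)
    for i, x in enumerate(datalist):
        data_item, mask_item = data.narrow(0, i, 1), mask.narrow(0, i, 1)
        for j, dynamic in enumerate(dims):
            if dynamic:  # shrink the view to this example's actual extent
                data_item = data_item.narrow(j + 1, 0, x.size(j + 1))
                mask_item = mask_item.narrow(j + 1, 0, x.size(j + 1))
        data_item += x
        mask_item.fill_(1)
    return data, mask

def unpad_batch(data, mask, dims):
    # Mirrors examples(): narrow each dynamic dimension of every row back to
    # the number of valid entries counted by the mask.
    result = []
    for i in range(data.size(0)):
        item = data.narrow(0, i, 1)
        for d, dynamic in enumerate(dims):
            if dynamic:
                valid = int(mask[i].sum(d, keepdim=True).view(-1)[0])
                item = item.narrow(d + 1, 0, valid)
        result.append(item)
    return result

xs = [torch.rand(1, 2, 4), torch.rand(1, 3, 4)]
data, mask = pad_batch(xs, [True, False])
assert all(torch.equal(x, y) for x, y in zip(xs, unpad_batch(data, mask, [True, False])))
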
51 changes: 51 additions & 0 deletions torch/csrc/jit/batched/BatchTensor.h
@@ -0,0 +1,51 @@
#pragma once
#include "ATen/Tensor.h"
#include "torch/csrc/jit/pybind.h"
#include "ATen/ATen.h"
#include <iostream>
#include <vector>

namespace torch { namespace jit {
struct BatchTensor {
public:
BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims);
BatchTensor(const std::vector<at::Tensor> datalist, at::Tensor dims);
~BatchTensor(){};
const char * toString() const {
return "BatchTensor";
}
at::IntList sizes() const {
return data.sizes();
}
int64_t dim() const {
return data.dim();
}
std::vector<at::Tensor> examples();
at::Tensor get_data(){
return data;
}
at::Tensor get_mask(){
return mask;
}
at::Tensor get_dims(){
return dims;
}

public:
// data is a Tensor whose size is the batch size in the batch dimension,
// the size of all examples in static dimensions,
// and at least as large as the largest example in the batch in dynamic dimensions.
at::Tensor data;
// mask is a Tensor whose size is the batch size in the batch dimension,
// one in static dimensions,
// and at least as large as the largest example in the batch in dynamic dimensions.
// Each entry in the mask corresponds to one or more entries in the data array (singleton, i.e., static, dimensions are broadcasted),
// with a one in the mask denoting that the corresponding data entries represent valid, meaningful data and a zero denoting that they do not.
at::Tensor mask;
// dims is a 1-dimensional tensor with a bool for each non-batch dimension,
// representing whether that dimension is static (False) or dynamic (True).
at::Tensor dims;
};

void initBatchTensorBindings(PyObject* module);
}} // namespace torch::jit
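
As a concrete reading of the invariants documented above (same torch.BatchTensor exposure and uint8 dims assumption as in the test sketch earlier):

import torch

xs = [torch.rand(1, 2, 3), torch.rand(1, 4, 3)]  # lengths 2 and 4, static feature size 3
b = torch.BatchTensor(xs, torch.tensor([1, 0], dtype=torch.uint8))
print(b.get_data().size())  # torch.Size([2, 4, 3]): batch, padded dynamic dim, static dim
print(b.get_mask().size())  # torch.Size([2, 4, 1]): singleton in the static dimension
print(b.get_dims())         # dims tensor [1, 0]: dim 1 is dynamic, dim 2 is static
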
2 changes: 2 additions & 0 deletions torch/csrc/jit/init.cpp
@@ -21,6 +21,7 @@
#include "torch/csrc/jit/graph_executor.h"
#include "torch/csrc/jit/script/init.h"
#include "torch/csrc/jit/script/python_tree_views.h"
#include "torch/csrc/jit/batched/BatchTensor.h"
#include "torch/csrc/jit/python_interpreter.h"
#include "torch/csrc/jit/pybind_utils.h"

@@ -200,6 +201,7 @@ void initJITBindings(PyObject *module) {
tracer::initPythonTracerBindings(module);
script::initTreeViewBindings(module);
script::initJitScriptBindings(module);
initBatchTensorBindings(module);
registerPythonInterpreterOps();
}
