diff --git a/benchmarks/jit_tensorwise.py b/benchmarks/jit_tensorwise.py
new file mode 100644
index 00000000..c18fc677
--- /dev/null
+++ b/benchmarks/jit_tensorwise.py
@@ -0,0 +1,68 @@
+import torch
+import nestedtensor
+import utils
+import time
+
+
+@nestedtensor._C.jit_tensorwise()
+@torch.jit.script
+def f(i, w):
+    return torch.conv2d(i, w)
+
+def loop_f(inp1, w):
+    for inp in inp1:
+        torch.conv2d(inp, w)
+
+
+if __name__ == "__main__":
+    w = torch.randn(64, 3, 9, 9).cuda()
+    inp1 = list(torch.randn(128, 1, 3, 16, 16).cuda().unbind())
+    inp3 = nestedtensor.as_nested_tensor(inp1)._impl
+    # print(sum(inp.numel() for inp in inp1))
+    # print(inp3.numel())
+
+    fc = nestedtensor._C.jit_tensorwise()(torch.conv2d)
+
+    t0 = time.time()
+    count = 0
+    while(time.time() - t0 < 5.0):
+        r2 = fc(inp3, w)
+        torch.cuda.synchronize()
+        count += 1
+    print("jit: " + str(count))
+
+    t0 = time.time()
+    count = 0
+    while(time.time() - t0 < 5.0):
+        loop_f(inp1, w)
+        torch.cuda.synchronize()
+        count += 1
+    print("for loop: " + str(count))
+
+
+    # print(r.nested_size())
+
+    # na = nestedtensor._C.jit_tensorwise()(torch.mul)
+
+    # print("111")
+    # out = nestedtensor.as_nested_tensor([torch.randn(1, 2)])
+    # print(na(
+    #     nestedtensor.as_nested_tensor([torch.randn(1, 2)])._impl,
+    #     4.0,
+    # ))
+    # print("222")
+    # print('out')
+    # print(out)
+
+    # nv = nestedtensor._C.jit_tensorwise()(torch.mv)
+    # print(nv(
+    #     nestedtensor._C._ListNestedTensor([torch.randn(1, 2)]),
+    #     nestedtensor._C._ListNestedTensor([torch.randn(2)]),
+    # ))
+
+    # print("333")
+    # print(na(
+    #     torch.randn(1, 2),
+    #     torch.randn(1, 2),
+    # ))
+    # print("444")
diff --git a/benchmarks/nearest_neighbors.py b/benchmarks/nearest_neighbors.py
index be8cf5a9..26bd3f05 100644
--- a/benchmarks/nearest_neighbors.py
+++ b/benchmarks/nearest_neighbors.py
@@ -1,11 +1,11 @@
-from nestedtensor import torch
 import nestedtensor
+import torch
 import argparse
 import time
 import random
 import pprint
 
-EMBED_DIM = 1024
+EMBED_DIM = 128
 
 SEED = 0
 
@@ -60,8 +60,8 @@ def gen_algorithm_nested_mv(keys, sub_clusters):
     for sub_cluster in sub_clusters:
         new_sub_cluster = [torch.tensor(list(map(list, cluster))) for cluster in sub_cluster]
         new_sub_clusters.append(new_sub_cluster)
-    nested_sub_clusters = torch.nested_tensor(sub_clusters).to_tensor(2)
-    nested_keys = torch.nested_tensor(keys)
+    nested_sub_clusters = nestedtensor.nested_tensor(sub_clusters).to_tensor(2)
+    nested_keys = nestedtensor.nested_tensor(keys)
     def _nested_mv():
         return torch.mv(nested_sub_clusters, nested_keys)
     return _nested_mv
@@ -74,18 +74,16 @@ def gen_algorithm_nested_jit_mv(keys, sub_clusters):
         for cluster in sub_cluster:
             new_sub_cluster.append(torch.stack(cluster))
         new_sub_clusters.append(new_sub_cluster)
-    nested_sub_clusters = nestedtensor._ListNestedTensor(new_sub_clusters)
-    print("HERE")
-    print(nested_sub_clusters.nested_size())
-    nested_keys = nestedtensor._ListNestedTensor(keys)
-    print(nested_keys.nested_size())
+    nested_sub_clusters = nestedtensor.as_nested_tensor(new_sub_clusters)
+    nested_keys = nestedtensor.as_nested_tensor(keys)
 
+    @nestedtensor._C.jit_tensorwise()
     @torch.jit.script
     def my_fun(x, y):
         return torch.mv(x, y)
 
     def _nested_jit_mv():
-        return nestedtensor._C.jit_apply_function((nested_sub_clusters, nested_keys), my_fun)
+        return my_fun(nested_sub_clusters, nested_keys)
     return _nested_jit_mv
 
 
@@ -139,12 +137,12 @@ def benchmark_fn(fn, run_time = 15.0):
 gen_results_naive = gen_algorithm_naive(keys, sub_clusters)
 gen_results_mv = gen_algorithm_mv(keys, sub_clusters)
 gen_results_nested_mv = gen_algorithm_nested_mv(keys, sub_clusters)
-gen_results_nested_jit_mv = gen_algorithm_nested_jit_mv(keys, sub_clusters)
+# gen_results_nested_jit_mv = gen_algorithm_nested_jit_mv(keys, sub_clusters)
 
-# print(benchmark_fn(gen_results_naive))
-# print(benchmark_fn(gen_results_mv))
-# print(benchmark_fn(gen_results_nested_mv))
-print(benchmark_fn(gen_results_nested_jit_mv))
+print(benchmark_fn(gen_results_nested_mv))
+print(benchmark_fn(gen_results_naive))
+print(benchmark_fn(gen_results_mv))
+# print(benchmark_fn(gen_results_nested_jit_mv))
 # import cProfile, pstats, io
 # pr = cProfile.Profile()
 # pr.enable()
diff --git a/nestedtensor/csrc/buffer_nested_tensor.cpp b/nestedtensor/csrc/buffer_nested_tensor.cpp
index 6026a3c1..45123a05 100644
--- a/nestedtensor/csrc/buffer_nested_tensor.cpp
+++ b/nestedtensor/csrc/buffer_nested_tensor.cpp
@@ -71,8 +71,8 @@ std::pair _build_structure(
   for (size_t i = 0; i < nested_size.degree(); i++) {
     std::pair result_i = _build_structure(
         index, buffers, nested_size.children(i), nested_stride.children(i));
+    index = std::get<0>(result_i);
     result.push_back(std::get<1>(result_i));
-    index++;
   }
   return std::pair(index, TensorNode(result));
 }
diff --git a/nestedtensor/csrc/buffer_nested_tensor.h b/nestedtensor/csrc/buffer_nested_tensor.h
index 70ceba3d..62820cec 100644
--- a/nestedtensor/csrc/buffer_nested_tensor.h
+++ b/nestedtensor/csrc/buffer_nested_tensor.h
@@ -125,8 +125,11 @@ struct TORCH_API _BufferNestedTensor {
       new_size.push_back(start->degree());
       start = start->children_data(0);
     }
-    for (size_t i = 0; i < start->payload(0).size(); i++) {
-      new_size.push_back(start->payload(0)[i]);
+    new_size.push_back(start->size());
+    if (start->size() > 0) {
+      for (size_t i = 0; i < start->payload(0).size(); i++) {
+        new_size.push_back(start->payload(0)[i]);
+      }
     }
     return _buffer.reshape(at::IntArrayRef(new_size));
   }
diff --git a/nestedtensor/csrc/jit_list_apply.cpp b/nestedtensor/csrc/jit_list_apply.cpp
index 6c99a8c3..cf7f3431 100644
--- a/nestedtensor/csrc/jit_list_apply.cpp
+++ b/nestedtensor/csrc/jit_list_apply.cpp
@@ -1,45 +1,66 @@
 #include
+#include
+#include
 
 namespace torch {
 namespace nested_tensor {
 
+namespace py = pybind11;
+
 using namespace torch::jit;
+using namespace torch::jit::script;
+
+// TODO Expand to IValues to support generic lists?
+at::Tensor run_function(Stack&& stack, Function& fn) {
+  fn(stack);
+  return std::move(stack.front().toTensor());
+}
+
+at::Tensor run_function(Stack&& stack, Operation& fn) {
+  fn(stack);
+  return std::move(stack.front().toTensor());
+}
 
+// TODO: Assert that one arg must be a nestedtensor?
+template <class F>
 static TensorNode apply_jit_function(
-    const std::vector<TensorNode>& nested_nodes,
-    Function& fn) {
+    Stack& stack_template,
+    const std::set<size_t>& tensor_node_i,
+    const std::vector<TensorNode>& tensor_nodes,
+    F& fn) {
   bool all_leaf = true;
-  for (size_t i = 0; i < nested_nodes.size(); i++) {
-    all_leaf = all_leaf && nested_nodes[i].is_leaf();
+  for (const auto& node : tensor_nodes) {
+    all_leaf = all_leaf && node.is_leaf();
   }
   if (all_leaf) {
-    // NOTE: Assuming this is a pure function not a method (no self!)
-    // NOTE: We assume there is only one Tensor inputs.
     // NOTE: We assume no named tensors and no sparse variables as
-    // appropriate
-    // for TorchScript. NOTE: We know the IValues of the argument, there is
-    // no
-    // need to cast around.
-    c10::List<at::Tensor> result;
-    for (size_t j = 0; j < nested_nodes[0].size(); j++) {
-      Stack stack;
-      for (size_t i = 0; i < nested_nodes.size(); i++) {
-        push(stack, nested_nodes[i].payload(j));
+    // appropriate for TorchScript.
+    // TODO: Assert leaf sizes match and are non-zero, otherwise this isn't
+    // a NestedTensor function.
+    size_t leaf_size = tensor_nodes[0].size();
+    c10::List<at::Tensor> results;
+    results.reserve(leaf_size);
+    for (size_t j = 0; j < leaf_size; j++) {
+      Stack stack(stack_template);
+      size_t ni = 0;
+      for (size_t i = 0; i < stack.size(); i++) {
+        if (tensor_node_i.count(i)) {
+          stack[i] = tensor_nodes[ni].payload(j);
+          ni++;
+        }
       }
-      fn.run(stack);
-      result.push_back(stack.back().toTensor());
+      results.push_back(run_function(std::move(stack), fn));
     }
-    return TensorNode(result);
+    return TensorNode(results);
   } else {
     bool broadcastable = true;
     size_t num_children = 0;
-    for (size_t i = 0; i < nested_nodes.size(); i++) {
-      if (!nested_nodes[i].is_leaf()) {
+    for (const auto& node : tensor_nodes) {
+      if (!node.is_leaf()) {
         if (num_children > 0) {
-          broadcastable =
-              broadcastable && (num_children == nested_nodes[i].degree());
+          broadcastable = broadcastable && (num_children == node.degree());
         } else {
-          num_children = nested_nodes[i].degree();
+          num_children = node.degree();
         }
       }
     }
@@ -47,42 +68,206 @@ static TensorNode apply_jit_function(
     std::vector<TensorNode> result;
     for (size_t i = 0; i < num_children; i++) {
       std::vector<TensorNode> local_args;
-      for (size_t j = 0; j < nested_nodes.size(); j++) {
-        if (nested_nodes[j].is_leaf()) {
-          local_args.push_back(nested_nodes[j]);
+      for (const auto& node : tensor_nodes) {
+        if (node.is_leaf()) {
+          local_args.push_back(node);
         } else {
-          local_args.push_back(nested_nodes[j].children(i));
+          local_args.push_back(node.children(i));
         }
       }
-      result.push_back(apply_jit_function(local_args, fn));
+      result.push_back(
+          apply_jit_function(stack_template, tensor_node_i, local_args, fn));
     }
     return TensorNode(result);
   }
 }
-THPNestedTensor jit_apply_function(
-    std::vector<THPNestedTensor> nts_,
-    py::object fn) {
-  std::vector<_ListNestedTensor> nts;
-  for (size_t i = 0; i < nts_.size(); i++) {
-    nts.push_back(nts_[i].data().left());
+
+c10::optional<Symbol> is_builtin(py::object fn) {
+  py::object builtin_name =
+      py::module::import("torch.jit").attr("_find_builtin")(fn);
+  Symbol name = c10::Symbol::fromQualString(py::str(builtin_name));
+
+  // TODO: Is there a cheaper way to do this?
+  const auto& variants = getAllOperatorsFor(name);
+  if (variants.size() == 0) {
+    return c10::nullopt;
+  }
+  return name;
+}
+
+c10::optional<TensorNode> try_nested_node(
+    Argument argument,
+    py::object py_arg) {
+  InferredType inferred_type = tryToInferType(py_arg);
+  // Nestedtensor must not be a valid IValue
+  if (inferred_type.success()) {
+    return c10::nullopt;
+  }
+  if (argument.type()->kind() == TypeKind::TensorType &&
+      py::isinstance<THPNestedTensor>(py_arg)) {
+    TensorNode node = py::cast<THPNestedTensor>(py_arg).get_structure();
+    return node;
+  }
+  return c10::nullopt;
+}
+
+inline c10::optional<
+    std::tuple<Stack, std::set<size_t>, std::vector<TensorNode>>>
+my_createStackForSchema(
+    const FunctionSchema& schema,
+    const tuple_slice& args,
+    const py::kwargs& kwargs,
+    c10::optional<IValue> self) {
+  size_t all_arguments = (self ? 1 : 0) + args.size() + kwargs.size();
+  if (all_arguments > schema.arguments().size()) {
+    return c10::nullopt;
+  }
+  Stack stack;
+  stack.reserve(schema.arguments().size());
+
+  std::set<size_t> tensor_node_i;
+  std::vector<TensorNode> tensor_nodes;
+
+  if (self) {
+    // NOTE: self cannot be a NestedTensor because it cannot be an ivalue.
+    push(stack, std::move(*self));
   }
-  auto sfn = py::cast<StrongFunctionPtr>(fn);
-  auto tracing_state = tracer::getTracingState();
-  TORCH_CHECK(!tracing_state, "doesnt support tracing");
-  Function& callee = *sfn.function_;
-  auto schema = callee.getSchema();
-  TORCH_CHECK(
-      schema.arguments().size() == nts.size(),
-      "Give NestedTensors don't match function args.");
-  std::vector<TensorNode> nested_nodes;
-  for (size_t i = 0; i < nts.size(); i++) {
-    nested_nodes.push_back(nts[i].get_structure());
+  // First push all positional args.
+  for (size_t i = 0; i < args.size(); i++) {
+    // Use the type information from the schema to convert the PyObject.
+    const auto& schema_arg = schema.arguments()[i];
+    if (auto tensor_node = try_nested_node(schema_arg, args[i])) {
+      tensor_nodes.push_back(*tensor_node);
+      tensor_node_i.insert(stack.size());
+      push(stack, torch::jit::IValue(torch::zeros({})));
+    } else {
+      // TODO: This is expensive because argumentToIValue constructs an error
+      // message.
+      try {
+        push(stack, argumentToIValue(schema, i, args[i]));
+      } catch (const std::runtime_error& e) {
+        return c10::nullopt;
+      }
+    }
+  }
+
+  // Now for every remaining non-positional argument in the schema, look for it
+  // in the kwargs dict and push it if found, or use its default value if it
+  // has one.
+  size_t consumed_kwargs = 0;
+  for (size_t i = stack.size(); i < schema.arguments().size(); ++i) {
+    const auto& schema_arg = schema.arguments()[i];
+    if (kwargs.contains(schema_arg.name().c_str())) {
+      auto kwarg = kwargs[schema_arg.name().c_str()];
+      if (auto tensor_node = try_nested_node(schema_arg, kwarg)) {
+        tensor_nodes.push_back(*tensor_node);
+        tensor_node_i.insert(stack.size());
+        push(stack, torch::jit::IValue(torch::zeros({})));
+      } else {
+        // TODO: This is expensive because argumentToIValue constructs an error
+        // message.
+        try {
+          push(stack, argumentToIValue(schema, i, kwarg));
+        } catch (const std::runtime_error& e) {
+          return c10::nullopt;
+        }
+      }
+      consumed_kwargs += 1;
+    } else if (schema_arg.default_value()) {
+      push(stack, *schema_arg.default_value());
+    } else {
+      return c10::nullopt;
+    }
   }
-  py::gil_scoped_release release;
-  TensorNode nested_node = apply_jit_function(nested_nodes, callee);
-  py::gil_scoped_acquire acquire;
-  return THPNestedTensor(_ListNestedTensor(nested_node));
+
+  if (consumed_kwargs != kwargs.size()) {
+    std::vector<std::string> names;
+    for (const auto& kwarg : kwargs) {
+      names.emplace_back(py::cast<std::string>(kwarg.first));
+    }
+    try {
+      schema.findErrorInKwargs(names);
+    } catch (const std::runtime_error& e) {
+      return c10::nullopt;
+    }
+  }
+
+  return std::make_tuple(stack, tensor_node_i, tensor_nodes);
 }
+
+// TODO: This should support 3 types of functions
+// fn might be scripted (i.e. StrongFunctionPtr)
+// fn might be a builtin (need to resolve!)
+// fn might be neither, so we just dispatch to some regular python for-loops
+// (not fast!)
+// TODO: Support for no NestedTensor arguments
+// NOTE: For now this is a private function
+py::cpp_function jit_tensorwise() {
+  return py::cpp_function([](py::object fn) {
+    return py::cpp_function([fn](py::args args, py::kwargs kwargs) {
+      if (py::isinstance<StrongFunctionPtr>(fn)) {
+        auto sfn = py::cast<StrongFunctionPtr>(fn);
+        Function& operation = *sfn.function_;
+        if (auto pack = my_createStackForSchema(
+                operation.getSchema(), args, kwargs, c10::nullopt)) {
+          py::gil_scoped_release release;
+          THPNestedTensor result =
+              THPNestedTensor(_ListNestedTensor(apply_jit_function(
+                  std::get<0>(*pack),
+                  std::get<1>(*pack),
+                  std::get<2>(*pack),
+                  operation)));
+          return result;
+        }
+      }
+      if (auto name = is_builtin(fn)) {
+        // TODO: Why doesn't argumentToIValue deal with NoneType for a kwarg?
+        // See also
+        // https://github.com/pytorch/pytorch/blob/7d630278daee00ea2db6bc01e8a2a5f160bd8e81/torch/csrc/jit/pybind_utils.h#L778
+        // If out is NoneType for a builtin we'll simply remove it.
+        bool out_is_none = false;
+        for (const auto& kwarg : kwargs) {
+          if (py::cast<std::string>(kwarg.first) == "out") {
+            auto inferred_type = tryToInferType(kwarg.second);
+            if (inferred_type.success() &&
+                inferred_type.type()->kind() == TypeKind::NoneType) {
+              out_is_none = true;
+            }
+          }
+        }
+        if (out_is_none) {
+          py::dict new_kwargs;
+          for (const auto& kwarg : kwargs) {
+            if (py::cast<std::string>(kwarg.first) == "out") {
+              continue;
+            }
+            new_kwargs[kwarg.first] = kwarg.second;
+          }
+          kwargs = py::kwargs(new_kwargs);
+        }
+        for (std::shared_ptr<Operator> op : getAllOperatorsFor(*name)) {
+          if (auto pack = my_createStackForSchema(
+                  op->schema(), args, kwargs, c10::nullopt)) {
+            auto operation = op->getOperation();
+            py::gil_scoped_release release;
+            THPNestedTensor result =
+                THPNestedTensor(_ListNestedTensor(apply_jit_function(
+                    std::get<0>(*pack),
+                    std::get<1>(*pack),
+                    std::get<2>(*pack),
+                    operation)));
+            return result;
+          }
+        }
+      }
+      // TODO: Need implementation of generic python version.
+      std::stringstream ss;
+      ss << "FAIL! Can't find something for " << fn;
+      TORCH_CHECK(false, ss.str());
+      TensorNode result;
+      return THPNestedTensor(_ListNestedTensor(result));
+    });
+  });
+}
 } // namespace nested_tensor
 } // namespace torch
diff --git a/nestedtensor/csrc/jit_list_apply.h b/nestedtensor/csrc/jit_list_apply.h
index bab595e1..defc5e6f 100644
--- a/nestedtensor/csrc/jit_list_apply.h
+++ b/nestedtensor/csrc/jit_list_apply.h
@@ -1,10 +1,9 @@
-#pragma once
 #include
 
 namespace torch {
 namespace nested_tensor {
-THPNestedTensor jit_apply_function(
-    std::vector<THPNestedTensor> nts_,
-    py::object fn);
-}
+
+pybind11::cpp_function jit_tensorwise();
+
+} // namespace nested_tensor
 } // namespace torch
diff --git a/nestedtensor/csrc/py_init.cpp b/nestedtensor/csrc/py_init.cpp
index 36ec3dab..c7bc4f46 100644
--- a/nestedtensor/csrc/py_init.cpp
+++ b/nestedtensor/csrc/py_init.cpp
@@ -110,7 +110,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       .def("__str__", &torch::nested_tensor::THPNestedTensor::str)
       .def("__repr__", &torch::nested_tensor::THPNestedTensor::str);
 
-  m.def("jit_apply_function", &torch::nested_tensor::jit_apply_function);
+  // NOTE: This is a private function until it is feature complete
+  m.def("_jit_tensorwise", &torch::nested_tensor::jit_tensorwise);
   m.def("as_nested_tensor", &torch::nested_tensor::as_nested_tensor);
   m.def("nested_tensor", &torch::nested_tensor::nested_tensor);
 }
diff --git a/nestedtensor/csrc/python_nested_tensor.h b/nestedtensor/csrc/python_nested_tensor.h
index 56805e79..a8e01285 100644
--- a/nestedtensor/csrc/python_nested_tensor.h
+++ b/nestedtensor/csrc/python_nested_tensor.h
@@ -94,6 +94,10 @@ struct THPNestedTensor {
     return data_map(
         _data, [](auto data) { return data.is_contiguous(); });
   }
+  TensorNode get_structure() {
+    return data_map(
+        _data, [](auto data) { return data.get_structure(); });
+  }
 
  private:
  c10::either<_ListNestedTensor, _BufferNestedTensor> _data;
diff --git a/nestedtensor/nested/monkey_patch.py b/nestedtensor/nested/monkey_patch.py
index 6bbe308d..036a5c56 100644
--- a/nestedtensor/nested/monkey_patch.py
+++ b/nestedtensor/nested/monkey_patch.py
@@ -12,8 +12,10 @@ def monkey_patch(NestedTensor):
     from nestedtensor.nested import functions
     import torch
     from nestedtensor.nested import utils
+    from nestedtensor import _C
 
     function_dispatch = {}
+    jit_function_dispatch = {}
 
     def _check_meaningful_overwrite(cls, method_name):
         import os
@@ -34,6 +36,10 @@ def set_wrapped_torch_function(function_name, wrapper):
         function_dispatch[getattr(torch, function_name)] = wrapper(
             getattr(torch, function_name))
 
+    def set_wrapped_jit_torch_function(function_name, wrapper):
+        jit_function_dispatch[getattr(torch, function_name)] = wrapper(
+            getattr(torch, function_name))
+
     def set_function(key, function):
         function_dispatch[key] = function
 
@@ -82,7 +88,7 @@ def set_function(key, function):
             set_nt_method(function_name + '_', utils.tensorwise())
         if function_name in ['fill']:
             continue
-        set_wrapped_torch_function(function_name, utils.tensorwise())
+        set_wrapped_jit_torch_function(function_name, _C._jit_tensorwise())
         set_nt_method(function_name, utils.tensorwise())
 
     # <
@@ -222,3 +228,4 @@ def set_function(key, function):
     # module.NestedTensor = NestedTensor
 
     setattr(NestedTensor, '_NestedTensor__function_dispatch', function_dispatch)
+    setattr(NestedTensor, '_NestedTensor__jit_function_dispatch', jit_function_dispatch)
diff --git a/nestedtensor/nested/nested.py b/nestedtensor/nested/nested.py
index 637f1244..d1ffbbdc 100644
--- a/nestedtensor/nested/nested.py
+++ b/nestedtensor/nested/nested.py
@@ -320,9 +320,16 @@ def nested_stride(self, dim=None):
 
     def __torch_function__(self, func, args=(), kwargs=None):
         _local_func = None
+        if kwargs is None:
+            kwargs = {}
+        if func in NestedTensor.__jit_function_dispatch:
+            _jit_local_func = NestedTensor.__jit_function_dispatch[func]
+            impl_args = [a._impl if isinstance(a, NestedTensor) else a for a in args]
+            impl_kwargs = {k: v._impl if isinstance(v, NestedTensor) else v for (k, v) in kwargs.items()}
+            return NestedTensor(_jit_local_func(*impl_args, **impl_kwargs))
         if func in NestedTensor.__function_dispatch:
             _local_func = NestedTensor.__function_dispatch[func]
-            return _local_func(*args) if kwargs is None else _local_func(*args, **kwargs)
+            return _local_func(*args, **kwargs)
         raise NotImplementedError("NestedTensor doesn't support function {}".format(func))
 
     def __bool__(self):
diff --git a/nestedtensor/nested/utils.py b/nestedtensor/nested/utils.py
index 7ab3bd27..5e971c9f 100644
--- a/nestedtensor/nested/utils.py
+++ b/nestedtensor/nested/utils.py
@@ -163,9 +163,8 @@ def match_type_signature_prefix(types, args):
 # and calls f tensor-wise
 # Make nested_stride optional (cont. by default)
 # Return flattened tensor pairs, then create _BufferNestedTensor impl directly
-
-
 def tensorwise(unbind_args=None, dim_args=None, wrap_dim_args=True):
+
     if unbind_args is None:
         unbind_args = []
     if dim_args is None:
@@ -176,7 +175,6 @@ def wrapper(f):
         def decorator(*_args, **_kwargs):
             def _func(*args, **kwargs):
                 if find_nested_tensor_dispatch_key(*args) is None:
-                    # import pdb; pdb.set_trace()
                     result = f(*args, **kwargs)
                     if not torch.is_tensor(result):
                         return tuple(result)
diff --git a/nestedtensor/version.py b/nestedtensor/version.py
index 44ec9fff..9eefa823 100644
--- a/nestedtensor/version.py
+++ b/nestedtensor/version.py
@@ -1,5 +1,5 @@
-__version__ = '0.0.1.dev20201919+d8734b8'
-git_version = 'd8734b84bcdd5dd1c74b2c1f48f8c890c783925a'
+__version__ = '0.0.1.dev20201122+00d5796'
+git_version = '00d579661b2e93046c7666d752ca8e0e063e2f9d'
 from nestedtensor import _C
 if hasattr(_C, 'CUDA_VERSION'):
     cuda = _C.CUDA_VERSION