diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp
index 8c0c0f511f2..2d55bc09344 100644
--- a/backends/vulkan/runtime/VulkanBackend.cpp
+++ b/backends/vulkan/runtime/VulkanBackend.cpp
@@ -312,10 +312,16 @@ class GraphBuilder {
       add_value_to_graph(fb_id, value);
     }
 
-    // Parse the inputs
+    // Parse the inputs, which will be tensors most of the time but can also
+    // be symints and tensorrefs (which will be the case if the original
+    // graph had mutable buffers).
     for (const uint32_t fb_id : *flatbuffer_->input_ids()) {
       const ValueRef ref = get_fb_id_valueref(fb_id);
-      compute_graph_->set_input_tensor(ref);
+      if (compute_graph_->val_is_tensor(ref)) {
+        compute_graph_->set_input_tensor(ref);
+      } else {
+        compute_graph_->set_val_as_input(ref);
+      }
     }
 
     // Parse the operators
@@ -354,10 +360,15 @@
       }
     }
 
-    // Parse the outputs
+    // Parse the outputs, which will be mostly tensors. For some reason,
+    // mutable buffers appear among the outputs of the fx.Graph but are not
+    // returned by the delegate; this may be an implementation detail of how
+    // the ExecuTorch emitter handles mutable buffers.
     for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
       const ValueRef ref = get_fb_id_valueref(fb_id);
-      compute_graph_->set_output_tensor(ref);
+      if (compute_graph_->val_is_tensor(ref)) {
+        compute_graph_->set_output_tensor(ref);
+      }
     }
   }
 };
@@ -401,6 +412,26 @@ bool maybe_resize_input(
   return should_resize;
 }
 
+bool maybe_update_scalar_tensor(
+    ComputeGraph* graph,
+    const ValueRef ref,
+    executorch::aten::Tensor& scalar_tensor_src) {
+  const int32_t cur_val = graph->read_symint(ref);
+  int32_t scalar_tensor_val = 0;
+  exec_aten::ScalarType dtype = scalar_tensor_src.scalar_type();
+  if (dtype == exec_aten::ScalarType::Int) {
+    scalar_tensor_val = *scalar_tensor_src.const_data_ptr<int32_t>();
+  } else if (dtype == exec_aten::ScalarType::Long) {
+    scalar_tensor_val = int32_t(*scalar_tensor_src.const_data_ptr<int64_t>());
+  }
+  bool was_updated = false;
+  if (scalar_tensor_val != cur_val) {
+    graph->set_symint(ref, scalar_tensor_val);
+    was_updated = true;
+  }
+  return was_updated;
+}
+
 void maybe_resize_output(
     ComputeGraph* graph,
     const size_t output_i,
@@ -487,7 +518,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
 
     Error err = compileModel(processed->data(), compute_graph);
 
-    // This backend does not need its processed data after compiling the model.
+    // This backend does not need its processed data after compiling the
+    // model.
     processed->Free();
 
     if (err != Error::Ok) {
@@ -508,13 +540,31 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
     const size_t num_inputs = compute_graph->inputs().size();
     bool should_propagate_resize = false;
     for (size_t i = 0; i < num_inputs; i++) {
-      bool was_resized =
-          maybe_resize_input(compute_graph, i, args[i]->toTensor());
-      should_propagate_resize = should_propagate_resize || was_resized;
-      compute_graph->copy_into_staging(
-          compute_graph->inputs()[i].staging,
-          args[i]->toTensor().const_data_ptr(),
-          args[i]->toTensor().numel());
+      const ValueRef iref = compute_graph->inputs()[i].value;
+      if (compute_graph->val_is_tensor(iref)) {
+        VK_CHECK_COND(args[i]->isTensor());
+        bool was_resized =
+            maybe_resize_input(compute_graph, i, args[i]->toTensor());
+        should_propagate_resize = should_propagate_resize || was_resized;
+        compute_graph->copy_into_staging(
+            compute_graph->inputs()[i].staging,
+            args[i]->toTensor().const_data_ptr(),
+            args[i]->toTensor().numel());
+      } else if (compute_graph->val_is_symint(iref)) {
+        VK_CHECK_COND(
+            args[i]->isTensor(),
+            "Cannot handle symint arg to graph that is not derived from a "
+            "scalar tensor at the moment.");
+        bool was_updated = maybe_update_scalar_tensor(
+            compute_graph, iref, args[i]->toTensor());
+        // Since symint inputs may impact tensor sizes, trigger a resize if
+        // any symbolic integer was updated.
+        should_propagate_resize = should_propagate_resize || was_updated;
+      } else {
+        VK_THROW(
+            "Could not handle input with type ",
+            compute_graph->get_val_type(iref));
+      }
     }
 
     if (should_propagate_resize) {
@@ -523,13 +573,21 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
     compute_graph->execute();
 
     for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
-      maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
-      // args holds inputs directly followed by outputs, so the i'th output
-      // for compute_graph corresponds to the (i + num_inputs)'th arg
-      compute_graph->copy_from_staging(
-          compute_graph->outputs()[i].staging,
-          args[num_inputs + i]->toTensor().mutable_data_ptr(),
-          args[num_inputs + i]->toTensor().numel());
+      const ValueRef oref = compute_graph->outputs()[i].value;
+      if (compute_graph->val_is_tensor(oref)) {
+        VK_CHECK_COND(args[num_inputs + i]->isTensor());
+        maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
+        // args holds inputs directly followed by outputs, so the i'th output
+        // for compute_graph corresponds to the (i + num_inputs)'th arg
+        compute_graph->copy_from_staging(
+            compute_graph->outputs()[i].staging,
+            args[num_inputs + i]->toTensor().mutable_data_ptr(),
+            args[num_inputs + i]->toTensor().numel());
+      } else {
+        VK_THROW(
+            "Could not handle output with type ",
+            compute_graph->get_val_type(oref));
+      }
     }
 
 #ifdef ET_EVENT_TRACER_ENABLED
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index 57cc5316612..397e3514153 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -555,6 +555,14 @@ class ComputeGraph final {
 
   int32_t read_symint(const ValueRef idx);
 
+  inline void set_val_as_input(const ValueRef idx) {
+    inputs_.push_back({idx, kDummyValueRef});
+  }
+
+  inline void set_val_as_output(const ValueRef idx) {
+    outputs_.push_back({idx, kDummyValueRef});
+  }
+
   /*
    * Convenience function to add an input tensor along with its staging buffer
   */
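
---
Editor's note (illustrative sketch, not part of the patch): with this change,
a non-tensor input occupies an IOValueRef whose staging slot is kDummyValueRef,
so execute() refreshes it by value instead of copying into a staging buffer.
A minimal caller-side view using only the APIs referenced above; `graph` and
`seq_len_tensor` are hypothetical names:

    // Assume inputs()[0] was registered via set_val_as_input(), i.e. its
    // staging slot holds kDummyValueRef and there is nothing to copy.
    const ValueRef iref = graph->inputs()[0].value;
    if (graph->val_is_symint(iref)) {
      const int32_t prev = graph->read_symint(iref);
      const int32_t next = *seq_len_tensor.const_data_ptr<int32_t>();
      if (next != prev) {
        // Tensor sizes that depend on this symint are only refreshed once
        // the resize is propagated before the next execute().
        graph->set_symint(iref, next);
      }
    }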
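A related caveat (editorial observation, not from the patch author):
maybe_update_scalar_tensor() narrows Long scalar tensors to int32_t and
silently leaves scalar_tensor_val at 0 for any other dtype. If stricter
behavior is preferred, a guard in the style of the patch's existing checks
could precede the comparison:

    VK_CHECK_COND(
        dtype == exec_aten::ScalarType::Int ||
            dtype == exec_aten::ScalarType::Long,
        "Unexpected dtype for symint source tensor.");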