64 changes: 64 additions & 0 deletions benchmarks/static_runtime/test_static_module.cc
@@ -1305,3 +1305,67 @@ TEST(AssignStorageToManagedTensors, MultipleUnused) {
testAssignStorageToManagedTensors(
src, std::move(managed_tensor_name_to_tensor), min_reused_tensors);
}

namespace {
void testStaticModuleThrows(
const std::string& src,
const std::vector<IValue>& args,
const std::unordered_map<std::string, IValue>& kwargs) {
auto static_module = makeStaticModuleFromScript(src);
EXPECT_THROW(static_module(args, kwargs), c10::Error);
}
} // namespace

TEST(StaticModule, IncorrectTypesPassed) {
const std::string args_bool_script = R"JIT(
def forward(self, x: bool):
return x
)JIT";
testStaticModuleThrows(args_bool_script, {at::randn({1})}, {});

const std::string args_tensor_script = R"JIT(
def forward(self, x: Tensor):
return x
)JIT";
testStaticModuleThrows(args_tensor_script, {false}, {});

const std::string kwargs_bool_script = R"JIT(
def forward(self, x: bool = True):
return x
)JIT";
testStaticModuleThrows(kwargs_bool_script, {}, {{"x", at::randn({1})}});

const std::string kwargs_tensor_script = R"JIT(
def forward(self, x: Tensor = torch.randn((1, ))):
return x
)JIT";
testStaticModuleThrows(kwargs_tensor_script, {}, {{"x", 1.0}});
}

TEST(StaticModule, TooManyArgs) {
const std::string args_src = R"JIT(
def forward(self, x: int):
return x
)JIT";
testStaticModuleThrows(args_src, {0, 1}, {});

const std::string kwargs_src = R"JIT(
def forward(self, x: int = 1):
return x
)JIT";
testStaticModuleThrows(kwargs_src, {}, {{"y", 0}, {"x", 1}});
}

TEST(StaticModule, NotEnoughArgs) {
const std::string args_src = R"JIT(
def forward(self, x: int):
return x
)JIT";
testStaticModuleThrows(args_src, {}, {});

const std::string kwargs_src = R"JIT(
def forward(self, *, x: int):
return x
)JIT";
testStaticModuleThrows(kwargs_src, {}, {});
}
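The tests above only assert that a c10::Error is raised. For contrast, here is a minimal positive-path sketch (not part of this diff): it reuses the same makeStaticModuleFromScript helper used throughout this test file, and the test name CorrectTypesPassedSketch is hypothetical. It only illustrates that well-typed inputs are accepted whether they are passed positionally or by name.

TEST(StaticModule, CorrectTypesPassedSketch) {
  const std::string src = R"JIT(
    def forward(self, x: Tensor, y: bool):
        return x
  )JIT";
  auto static_module = makeStaticModuleFromScript(src);
  // All schema arguments supplied positionally with matching types.
  EXPECT_NO_THROW(static_module({at::randn({1}), true}, {}));
  // The trailing argument may instead be supplied as a kwarg by name.
  EXPECT_NO_THROW(static_module({at::randn({1})}, {{"y", true}}));
}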
159 changes: 82 additions & 77 deletions torch/csrc/jit/runtime/static/impl.cpp
@@ -136,7 +136,7 @@ void OptimizeGraph(
}

// remove unused input 0 from graph
bool RemoveSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
bool removeSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
if (graph->inputs().at(0)->type()->is_module()) {
if (graph->inputs().at(0)->hasUses()) {
return false;
@@ -146,13 +146,6 @@ bool RemoveSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
return true;
}

// remove "self" from function schema
c10::FunctionSchema RemoveSelfFromSchema(const c10::FunctionSchema& s) {
TORCH_CHECK(s.arguments().size() >= 1 && s.arguments()[0].name() == "self");
std::vector<Argument> args({s.arguments().begin() + 1, s.arguments().end()});
return s.cloneWithArguments(args);
}

std::vector<Value*> valueVecFromFastSet(const FastSet<const Value*>& s) {
std::vector<Value*> result;
result.reserve(s.size());
@@ -445,7 +438,8 @@ StaticModule::StaticModule(
const StaticModuleOptions& opts)
: opts_(opts),
graph_(std::move(graph_and_module.first)),
module_(std::move(graph_and_module.second)) {
module_(std::move(graph_and_module.second)),
num_inputs_(graph_->inputs().size()) {
// check opt flags
if (opts.manage_output_tensors) {
TORCH_CHECK(
@@ -461,11 +455,12 @@ StaticModule::StaticModule(
// handle schema
if (module_.has_value()) {
Method method = module_->get_method("forward");
if (RemoveSelfFromGraphInput(graph_)) {
schema_ = RemoveSelfFromSchema(method.function().getSchema());
schema_ = method.function().getSchema();
const auto num_schema_args = schema_->arguments().size();
DCHECK(num_schema_args > 0);
if (removeSelfFromGraphInput(graph_)) {
module_ = c10::nullopt;
} else {
schema_ = method.function().getSchema();
num_inputs_ = num_schema_args - 1;
}
}

Expand Down Expand Up @@ -697,7 +692,7 @@ size_t StaticModule::num_outputs() const {
}

size_t StaticModule::num_inputs() const {
return graph_->inputs().size();
return num_inputs_;
}

StaticRuntime& StaticModule::runtime() {
Expand Down Expand Up @@ -730,6 +725,7 @@ c10::IValue StaticModule::operator()(

StaticRuntime::StaticRuntime(const StaticModule& sm)
: static_module_(sm),
first_input_is_self_(static_module_.first_input_is_self()),
manage_output_tensors_enabled_(sm.opts().manage_output_tensors),
nodes_(sm.nodes()) {
values_.resize(sm.total_num_values());
@@ -751,81 +747,90 @@ StaticRuntime::StaticRuntime(const StaticModule& sm)

StaticRuntime::~StaticRuntime() = default;

void StaticRuntime::set_inputs(
const std::vector<IValue>& args,
const KeywordArgs& kwargs) {
if (!kwargs.empty()) {
// This is not ideal
TORCH_CHECK(
static_module_.schema(),
"Schema is not available. Consider creating the Static Runtime "
"with StaticModule(const torch::jit::Module& m) instead.");
std::vector<c10::IValue> stack;
stack.reserve(static_module_.num_inputs());
if (static_module_.first_input_is_self()) {
stack.emplace_back(static_module_.module()._ivalue());
}
stack.insert(stack.end(), args.begin(), args.end());
void StaticRuntime::set_arg(const size_t idx, std::vector<IValue>&& args) {
DCHECK(idx < args.size());
Input(idx + first_input_is_self_) = std::move(args[idx]);
}

static_module_.schema()->checkAndNormalizeInputs(stack, kwargs);
DCHECK_EQ(static_module_.num_inputs(), stack.size());
for (const auto i : c10::irange(stack.size())) {
Input(i) = std::move(stack[i]);
}
} else {
if (static_module_.first_input_is_self()) {
Input(0) = static_module_.module()._ivalue();
DCHECK_EQ(static_module_.num_inputs(), args.size() + 1);
for (const auto i : c10::irange(args.size())) {
Input(i + 1) = args[i];
}
} else {
DCHECK_EQ(static_module_.num_inputs(), args.size());
for (const auto i : c10::irange(args.size())) {
Input(i) = args[i];
}
}
void StaticRuntime::set_arg(const size_t idx, const std::vector<IValue>& args) {
DCHECK(idx < args.size());
Input(idx + first_input_is_self_) = args[idx];
}

void StaticRuntime::set_arg(const size_t idx, const IValue& arg) {
Input(idx + first_input_is_self_) = arg;
}

namespace {
void check_type(const Argument& schema_arg, const IValue& arg) {
// Fast path for most common case
if (arg.isTensor() &&
schema_arg.type()->kind() == c10::TypeKind::TensorType) {
return;
}
TORCH_CHECK(arg.type()->isSubtypeOf(schema_arg.type()));
}
} // namespace

template <typename IValueList>
void StaticRuntime::set_inputs(
std::vector<IValue>&& args,
const KeywordArgs& kwargs) {
if (!kwargs.empty()) {
// This is not ideal
IValueList&& args,
const std::unordered_map<std::string, c10::IValue>& kwargs) {
const auto total_num_inputs =
args.size() + kwargs.size() + first_input_is_self_;
TORCH_CHECK(total_num_inputs == static_module_.num_inputs());

const auto& schema = static_module_.schema();
if (first_input_is_self_) {
Input(0) = static_module_.module()._ivalue();
}

if (C10_UNLIKELY(!schema)) {
TORCH_CHECK(
static_module_.schema(),
"Schema is not available. Consider creating the Static Runtime "
kwargs.empty(),
"Schema is not available, but StaticRuntime got kwargs. "
"Consider creating the Static Runtime instance "
"with StaticModule(const torch::jit::Module& m) instead.");
std::vector<c10::IValue> stack;
stack.reserve(static_module_.num_inputs());
if (static_module_.first_input_is_self()) {
stack.emplace_back(static_module_.module()._ivalue());
for (size_t i = 0; i < args.size(); ++i) {
set_arg(i, std::forward<IValueList>(args));
}
stack.insert(
stack.end(),
std::make_move_iterator(args.begin()),
std::make_move_iterator(args.end()));
return;
}

const auto& schema_args = schema->arguments();
size_t consumed_kwargs = 0;
DCHECK(schema_args.size() > 0);

for (size_t i = 0; i < schema_args.size() - 1; ++i) {
// Start at 1 since the schema always contains `self`.
const auto& schema_arg = schema_args[i + 1];

static_module_.schema()->checkAndNormalizeInputs(stack, kwargs);
DCHECK_EQ(static_module_.num_inputs(), stack.size());
for (const auto i : c10::irange(stack.size())) {
Input(i) = std::move(stack[i]);
if (i < args.size()) {
check_type(schema_arg, args[i]);
set_arg(i, std::forward<IValueList>(args));
continue;
}
} else {
if (static_module_.first_input_is_self()) {
Input(0) = static_module_.module()._ivalue();
DCHECK_EQ(static_module_.num_inputs(), args.size() + 1);
for (const auto i : c10::irange(args.size())) {
Input(i + 1) = std::move(args[i]);
}
} else {
DCHECK_EQ(static_module_.num_inputs(), args.size());
for (const auto i : c10::irange(args.size())) {
Input(i) = std::move(args[i]);
}

auto it = kwargs.find(schema_arg.name());
if (it != kwargs.end()) {
check_type(schema_arg, it->second);
set_arg(i, it->second);
++consumed_kwargs;
continue;
}

auto maybe_default_val = schema_arg.default_value();
if (maybe_default_val) {
set_arg(i, *maybe_default_val);
continue;
}

TORCH_CHECK(
false, "Static runtime is missing required kwarg ", schema_arg.name());
}
TORCH_CHECK(
consumed_kwargs == kwargs.size() &&
args.size() + consumed_kwargs == schema_args.size() - 1);
}

void StaticRuntime::create_memory_planner() {
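To summarize the new control flow in set_inputs(): each schema argument after `self` is resolved from the positional args first, then from a kwarg with a matching name, then from the schema default, and anything left over is an error. The standalone sketch below restates only that resolution order using plain standard-library types; Arg and Value are stand-ins for c10::Argument and c10::IValue (not the real API), and the exact count and type checks in the actual implementation are stricter than shown here.

#include <cstddef>
#include <optional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-ins for c10::Argument and c10::IValue; illustration only.
struct Arg {
  std::string name;
  std::optional<int> default_value;
};
using Value = int;

// schema[0] plays the role of `self` and is skipped, mirroring the
// "Start at 1 since the schema always contains `self`" loop above.
std::vector<Value> resolve_inputs(
    const std::vector<Arg>& schema,
    const std::vector<Value>& args,
    const std::unordered_map<std::string, Value>& kwargs) {
  if (args.size() + 1 > schema.size()) {
    throw std::runtime_error("too many positional arguments");
  }
  std::vector<Value> inputs;
  std::size_t consumed_kwargs = 0;
  for (std::size_t i = 0; i + 1 < schema.size(); ++i) {
    const Arg& schema_arg = schema[i + 1];
    if (i < args.size()) {
      inputs.push_back(args[i]);  // 1) a positional argument wins
      continue;
    }
    auto it = kwargs.find(schema_arg.name);
    if (it != kwargs.end()) {
      inputs.push_back(it->second);  // 2) then a kwarg matched by name
      ++consumed_kwargs;
      continue;
    }
    if (schema_arg.default_value) {
      inputs.push_back(*schema_arg.default_value);  // 3) then the default
      continue;
    }
    throw std::runtime_error("missing required kwarg " + schema_arg.name);
  }
  if (consumed_kwargs != kwargs.size()) {
    throw std::runtime_error("unexpected keyword argument(s)");
  }
  return inputs;
}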
22 changes: 19 additions & 3 deletions torch/csrc/jit/runtime/static/impl.h
@@ -387,6 +387,12 @@ class TORCH_API StaticModule {
ManagedTensorRanges managed_tensor_ranges_{};

size_t num_intermediate_values_ = 0;

// Includes self if module_ != nullopt.
// Note that we might have num_inputs_ == 0 even if the schema has a `self`
// argument. In this case, `self` isn't used in the graph, but the schema
// includes it anyway to be consistent with the JIT interpreter.
size_t num_inputs_;
};

class TORCH_API StaticRuntime {
@@ -537,10 +543,18 @@ class TORCH_API StaticRuntime {
const KeywordArgs& kwargs);

// helper method for copying input args/kwargs into inputs_
template <typename IValueList>
void set_inputs(
const std::vector<c10::IValue>& args,
const KeywordArgs& kwargs);
void set_inputs(std::vector<c10::IValue>&& args, const KeywordArgs& kwargs);
IValueList&& args,
const std::unordered_map<std::string, c10::IValue>& kwargs);

// Set Input(idx) to args[idx]. Invoked by set_inputs. Copies or moves
// depending on overload.
void set_arg(const size_t idx, std::vector<IValue>&& args);
void set_arg(const size_t idx, const std::vector<IValue>& args);

// Set Input(idx) to arg. Always copies. Used for kwargs.
void set_arg(const size_t idx, const IValue& arg);

void verify_and_correct_memory_overlap(ProcessedNode& n);

@@ -571,6 +585,8 @@ class TORCH_API StaticRuntime {
// Otherwise, the memory used by activations is cached inside the static
// runtime.
const StaticModule& static_module_;
// Cache this so we don't have to call static_module_.first_input_is_self()
const bool first_input_is_self_;
bool manage_output_tensors_enabled_ = false;
std::unique_ptr<MemoryPlanner> planner_;
// first static_module_.num_inputs() slots are inputs, next
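For callers, the practical consequence of keeping forward()'s schema around even after `self` is dropped from the graph is that kwargs are usable whenever the StaticModule was built from a torch::jit::Module. The sketch below is illustrative only: the model path, its assumed forward(self, x: Tensor, scale: float) signature, and the "scale" kwarg are assumptions, not part of this PR.

#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/script.h>

int main() {
  // Assumed: model.pt was scripted with forward(self, x: Tensor, scale: float).
  torch::jit::Module module = torch::jit::load("model.pt");
  module.eval();

  // Constructing from the Module keeps forward()'s schema available, which the
  // new set_inputs() requires before it will accept kwargs.
  torch::jit::StaticModule static_module(module);

  // "scale" is matched against the schema by name and type-checked.
  c10::IValue out = static_module({at::randn({2, 3})}, {{"scale", 2.0}});
  (void)out;
  return 0;
}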