[Static Runtime] Fix a bug that aten::full_like reuses a tensor that does not match arguments (#74255)

Summary:
Pull Request resolved: #74255

This change fixes a bug where `aten::full_like` reuses a previously allocated tensor that does not match the requested one when the arguments to `aten::full_like` change dynamically between runs.
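To make the failure mode concrete, the sketch below (illustrative only, not the actual Static Runtime code; `cached_out` and `full_like_buggy` are made-up names) shows how an out-variant that checks only whether an output tensor already exists keeps serving a stale tensor after the requested dtype changes:

// Illustrative sketch of the bug, not the actual Static Runtime code.
// Models an out-variant that caches its output tensor across runs.
#include <ATen/ATen.h>
#include <c10/util/Optional.h>

static c10::optional<at::Tensor> cached_out; // stands in for p_node->Output(0)

at::Tensor full_like_buggy(const at::Tensor& self, const at::Scalar& value,
                           at::ScalarType dtype) {
  if (!cached_out.has_value()) { // only checks existence, not options
    cached_out = at::full_like(self, value, self.options().dtype(dtype));
  } else {
    cached_out->fill_(value); // reuses the old tensor, whatever its dtype
  }
  return *cached_out;
}

int main() {
  auto a = at::randn({2, 3});
  full_like_buggy(a, 4, at::kInt);               // first run: Int tensor
  auto out = full_like_buggy(a, 4, at::kFloat);  // dtype changed dynamically
  TORCH_CHECK(out.scalar_type() == at::kInt);    // bug: stale Int tensor reused
  return 0;
}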

Test Plan:
- Enhanced `StaticRuntime.FullLike` to cover the modified code path.

Reviewed By: mikeiovine

Differential Revision: D34863639

fbshipit-source-id: ca6d4ee3c039e263cc3a4f643d949cea59381608
(cherry picked from commit ae7db0a)
d1jang authored and pytorchmergebot committed Apr 5, 2022
1 parent c2c260b commit 85e163c
Showing 2 changed files with 32 additions and 6 deletions.
32 changes: 29 additions & 3 deletions benchmarks/static_runtime/test_static_runtime.cc
@@ -1307,13 +1307,39 @@ TEST(StaticRuntime, FullLike) {

   auto a = at::randn({2, 3});
   auto b = at::randn({3, 4, 2});
-  auto dtype = at::ScalarType::Int;
   auto cpu = at::Device(DeviceType::CPU);
   std::vector<IValue> args{
-      a, 4, dtype, at::kStrided, cpu, false, c10::MemoryFormat::Contiguous};
+      a,
+      4,
+      at::ScalarType::Int,
+      at::kStrided,
+      cpu,
+      false,
+      c10::MemoryFormat::Contiguous};
+  std::vector<IValue> args1{
+      a,
+      4,
+      at::ScalarType::Float,
+      at::kStrided,
+      cpu,
+      false,
+      c10::MemoryFormat::Contiguous};
   std::vector<IValue> args2{
-      b, 4, dtype, at::kStrided, cpu, false, c10::MemoryFormat::Contiguous};
+      b,
+      4,
+      at::ScalarType::Float,
+      at::kStrided,
+      cpu,
+      false,
+      c10::MemoryFormat::Contiguous};
   testStaticRuntime(full_like_script, args);
+  testStaticRuntime(
+      full_like_script,
+      args,
+      args1,
+      /*use_allclose=*/false,
+      /*use_equalnan=*/false,
+      /*check_resize=*/false);
   testStaticRuntime(full_like_script, args, args2);
 }
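`full_like_script` is defined earlier in the test file and is not shown in this diff; a plausible shape (an assumption, not copied from the source) is a TorchScript graph whose seven parameters bind, in order, to the IValues in the args vectors above:

// Assumed shape of full_like_script (defined earlier in
// test_static_runtime.cc; not part of this diff). Each IValue in the
// args vectors above binds to one of these parameters in order.
const auto full_like_script = R"JIT(
  def forward(self,
              a: Tensor,
              fill_value: int,
              dtype: Optional[int],
              layout: Optional[int],
              device: Optional[Device],
              pin_memory: Optional[bool],
              memory_format: Optional[int]):
      b = torch.full_like(a,
                          fill_value,
                          dtype=dtype,
                          layout=layout,
                          device=device,
                          pin_memory=pin_memory,
                          memory_format=memory_format)
      return b.clone()
)JIT";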

6 changes: 3 additions & 3 deletions torch/csrc/jit/runtime/static/ops.cpp
@@ -2304,9 +2304,9 @@ REGISTER_OPERATOR_FUNCTOR(aten::full_like, aten_full_like, [](Node* n) -> SROper
   return [](ProcessedNode* p_node) {
     const auto in1_s = p_node->Input(1).toScalar();
     const auto& in0_t = p_node->Input(0).toTensor();
-    if (p_node->Output(0).isNone()) {
-      const auto dtype = p_node->Input(2).toOptional<c10::ScalarType>();
-      const auto layout = p_node->Input(3).toOptional<c10::Layout>();
+    const auto dtype = p_node->Input(2).toOptional<c10::ScalarType>();
+    const auto layout = p_node->Input(3).toOptional<c10::Layout>();
+    if (!hasTensorWithOptions(p_node->Output(0), dtype, layout)) {
       const auto device = p_node->Input(4).toOptional<c10::Device>();
       const auto pin_memory = p_node->Input(5).toOptional<bool>();
       const auto memory_format =
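The fix replaces the bare `isNone()` check with `hasTensorWithOptions`, which allows reuse only when the cached tensor's options match the requested ones. A sketch of what such a helper looks like (assumed shape; the real implementation lives in the Static Runtime sources):

// Sketch of the helper used by the fix (assumed shape, not copied from
// torch/csrc/jit/runtime/static/ops.h/.cpp). Reuse is allowed only when
// the cached tensor's dtype and layout match the requested options.
bool hasTensorWithOptions(
    const c10::IValue& ivalue,
    c10::optional<c10::ScalarType> dtype,
    c10::optional<c10::Layout> layout) {
  if (!ivalue.isTensor()) {
    return false; // nothing cached yet; the caller must allocate
  }
  const auto& tensor = ivalue.toTensor();
  return dtype == tensor.dtype().toScalarType() &&
      layout == tensor.options().layout_opt();
}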
