[profiler] skip flop compute for Nested tensor (#112767)
Summary:
Since nested tensors don't support size(), the profiler throws an exception in saveExtraArgs() when with_flops is turned on.

Supporting FLOP computation for nested tensors is tricky because they have dynamic shapes, so for now the FLOP computation is skipped for nested tensors instead of throwing an exception.
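
For context, a minimal sketch of the scenario this change addresses, modeled on the updated test below (the op and shapes are arbitrary choices for illustration):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Nested tensors are ragged, so aten::size is undefined for them.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(3, 5)], layout=torch.jagged
)

# Before this change, saveExtraArgs() raised while recording extra args for
# the FLOP computation; now it emits a UserWarning and skips the estimate.
with profile(
    activities=[ProfilerActivity.CPU], record_shapes=True, with_flops=True
) as prof:
    nt = nt + nt

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```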

Test Plan:
Ran the profiler with a nested tensor; the log shows this warning instead of an exception:

```
/torch/nested/_internal/nested_tensor.py:205: UserWarning: Failed to save extra arguments for flops computation of op aten::add with input[0] as nested tensor. (Triggered internally at fbcode/caffe2/torch/csrc/profiler/util.cpp:433.)
```
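
Because the skip surfaces as a standard Python UserWarning, callers that consider it noisy can filter it with the warnings module; a small sketch (the message substring is taken from the log above):

```python
import warnings

# Ignore only the nested-tensor FLOP warning; other profiler warnings
# still surface normally.
warnings.filterwarnings(
    "ignore",
    message=".*Failed to save extra arguments for flops computation.*",
    category=UserWarning,
)
```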

Differential Revision: D50919789

Pull Request resolved: #112767
Approved by: https://github.com/aaronenyeshi
YuqingJ authored and pytorchmergebot committed Nov 3, 2023
1 parent 43fb514 commit 2c3ab60
Showing 2 changed files with 60 additions and 54 deletions.
test/profiler/test_profiler.py (26 changes: 19 additions & 7 deletions)
```diff
@@ -1159,22 +1159,34 @@ def test_flops(self):
             nn.ReLU(),
         )
         inputs = torch.randn(40, 16, 18, 260)
-        with _profile(record_shapes=True, with_flops=True, use_kineto=kineto_available()) as prof:
+        nested_tensor = torch.nested.nested_tensor(
+            [torch.randn((2, 5)), torch.randn((3, 5))], layout=torch.jagged
+        )
+        with _profile(
+            record_shapes=True, with_flops=True, use_kineto=kineto_available()
+        ) as prof:
             model(inputs)
-        profiler_output = prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10)
+            # test that nested tensor won't cause exception during flop compute
+            nested_tensor = nested_tensor + nested_tensor
+        profiler_output = prof.key_averages(group_by_input_shape=True).table(
+            sort_by="cpu_time_total", row_limit=10
+        )
         self.assertIn("Total MFLOPs", profiler_output)
         if not (kineto_available() and torch.cuda.is_available()):
             return
 
-        with profile(activities=[
+        with profile(
+            activities=[
                 torch.profiler.ProfilerActivity.CPU,
-                torch.profiler.ProfilerActivity.CUDA],
-                record_shapes=True,
-                with_flops=True,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+            record_shapes=True,
+            with_flops=True,
         ) as kineto_profiler:
             model(inputs)
         profiler_output = kineto_profiler.key_averages().table(
-            sort_by="self_cuda_time_total", row_limit=-1)
+            sort_by="self_cuda_time_total", row_limit=-1
+        )
         self.assertIn("Total MFLOPs", profiler_output)
 
     def test_kineto_profiler_api(self):
```
torch/csrc/profiler/util.cpp (88 changes: 41 additions & 47 deletions)
```diff
@@ -404,7 +404,7 @@ static constexpr auto kMatSize = "mat_size";
 static constexpr auto kMat1Size = "mat1_size";
 static constexpr auto kMat2Size = "mat2_size";
 
-static bool validateInput(
+static std::vector<c10::IntArrayRef> getInputSizes(
     const std::string& op_name,
     size_t min_size,
     c10::ArrayRef<const c10::IValue> inputs,
@@ -415,17 +415,26 @@ static bool validateInput(
        << op_name << ", min size: " << min_size
        << ", actual size: " << inputs.size();
     TORCH_WARN(ss.str());
-    return false;
+    return {};
   }
+  std::vector<c10::IntArrayRef> inputSizes = {};
   for (auto index : should_be_tensor) {
     if (!inputs[index].isTensor()) {
       ss << "Failed to save extra arguments for flops computation of op "
          << op_name << ", input[" << index << "] must be a tensor.";
       TORCH_WARN(ss.str());
-      return false;
+      return {};
     }
+    at::Tensor t = inputs[index].toTensor();
+    if (t.is_nested()) {
+      ss << "Failed to save extra arguments for flops computation of op "
+         << op_name << " with input[" << index << "] as nested tensor.";
+      TORCH_WARN(ss.str());
+      return {};
+    }
+    inputSizes.emplace_back(t.sizes());
   }
-  return true;
+  return inputSizes;
 }
 
 std::unordered_map<std::string, c10::IValue> saveExtraArgs(
@@ -441,88 +450,73 @@ std::unordered_map<std::string, c10::IValue> saveExtraArgs(
   }
 
   if (fname == kConv2dOp) {
-    bool check = validateInput(fname, kConv2dGroups + 1, inputs, {0, 1});
-    if (!check) {
+    const auto inputSizes =
+        getInputSizes(fname, kConv2dGroups + 1, inputs, {0, 1});
+    if (inputSizes.empty()) {
       return map;
     }
-
-    at::Tensor input = inputs[0].toTensor();
-    at::Tensor weight = inputs[1].toTensor();
-    if (weight.sizes().size() != 4) {
+    if (inputSizes[1].size() != 4) {
       TORCH_WARN(
           "Failed to compute flops for op aten::conv2d because it requires a 4D kernel tensor.");
       return map;
     }
-    map[kInputSize] = at::IValue(input.sizes());
-    map[kWeightSize] = at::IValue(weight.sizes());
+    map[kInputSize] = at::IValue(inputSizes[0]);
+    map[kWeightSize] = at::IValue(inputSizes[1]);
     map[kStride] = inputs[kConv2dStride];
     map[kPadding] = inputs[kConv2dPadding];
     map[kDilation] = inputs[kConv2dDilation];
     map[kGroups] = inputs[kConv2dGroups];
   } else if (fname == kMMOp) {
-    bool check = validateInput(fname, 2, inputs, {0, 1});
-    if (!check) {
+    const auto inputSizes = getInputSizes(fname, 2, inputs, {0, 1});
+    if (inputSizes.empty()) {
       return map;
     }
-
-    at::Tensor left = inputs[0].toTensor();
-    at::Tensor right = inputs[1].toTensor();
-    map[kMat1Size] = at::IValue(left.sizes());
-    map[kMat2Size] = at::IValue(right.sizes());
+    map[kMat1Size] = at::IValue(inputSizes[0]);
+    map[kMat2Size] = at::IValue(inputSizes[1]);
   } else if (fname == kAddMMOp) {
-    bool check = validateInput(fname, 3, inputs, {0, 1, 2});
-    if (!check) {
+    const auto inputSizes = getInputSizes(fname, 3, inputs, {0, 1, 2});
+    if (inputSizes.empty()) {
       return map;
     }
-
     // Exact FLOP count depends on scaling factors alpha and beta but
     // just assume these are +=1.
     // (similar to http://www.netlib.org/lapack/lawnspdf/lawn41.pdf,
     // "Operations Count for the BLAS and LAPACK", Table 3, SGEMM)
-    at::Tensor left = inputs[1].toTensor();
-    at::Tensor right = inputs[2].toTensor();
-    map[kMat1Size] = at::IValue(left.sizes());
-    map[kMat2Size] = at::IValue(right.sizes());
+    map[kMat1Size] = at::IValue(inputSizes[1]);
+    map[kMat2Size] = at::IValue(inputSizes[2]);
   } else if (fname == kMulOp) {
-    bool check = validateInput(fname, 1, inputs, {0});
-    if (!check) {
+    const auto inputSizes = getInputSizes(fname, 1, inputs, {0});
+    if (inputSizes.empty()) {
       return map;
     }
-
-    at::Tensor mat = inputs[0].toTensor();
-    map[kMatSize] = at::IValue(mat.sizes());
+    map[kMatSize] = at::IValue(inputSizes[0]);
   } else if (fname == kAddOp) {
-    bool check = validateInput(fname, 1, inputs, {0});
-    if (!check) {
+    const auto inputSizes = getInputSizes(fname, 1, inputs, {0});
+    if (inputSizes.empty()) {
       return map;
     }
-
-    at::Tensor mat = inputs[0].toTensor();
-    map[kMatSize] = at::IValue(mat.sizes());
+    map[kMatSize] = at::IValue(inputSizes[0]);
   } else if (fname == kBMMOp) {
-    bool check = validateInput(fname, 2, inputs, {0, 1});
-    if (!check) {
+    const auto inputSizes = getInputSizes(fname, 2, inputs, {0, 1});
+    if (inputSizes.empty()) {
       return map;
     }
-
-    at::Tensor left = inputs[0].toTensor();
-    at::Tensor right = inputs[1].toTensor();
-    map[kMat1Size] = at::IValue(left.sizes());
-    map[kMat2Size] = at::IValue(right.sizes());
+    map[kMat1Size] = at::IValue(inputSizes[0]);
+    map[kMat2Size] = at::IValue(inputSizes[1]);
   } else if (fname == kBAddBMMOp) {
-    bool check = validateInput(fname, 3, inputs, {0, 1, 2});
-    if (!check) {
+    const auto inputSizes = getInputSizes(fname, 3, inputs, {0, 1, 2});
+    if (inputSizes.empty()) {
      return map;
     }
-
     // Exact FLOP count depends on scaling factors alpha and beta but
     // just assume these are +=1.
     // (similar to http://www.netlib.org/lapack/lawnspdf/lawn41.pdf,
     // "Operations Count for the BLAS and LAPACK", Table 3, SGEMM)
-    at::Tensor left = inputs[1].toTensor();
-    at::Tensor right = inputs[2].toTensor();
-    map[kMat1Size] = at::IValue(left.sizes());
-    map[kMat2Size] = at::IValue(right.sizes());
+    map[kMat1Size] = at::IValue(inputSizes[1]);
+    map[kMat2Size] = at::IValue(inputSizes[2]);
   }
 
   return map;
```
