Skip to content

Commit

Permalink
[3/N] Move c10::variant to std::variant (#110141)
Browse files Browse the repository at this point in the history
This PR moves more c10::variant calls to std::variant

Pull Request resolved: #110141
Approved by: https://github.com/Skylion007
  • Loading branch information
cyyever authored and pytorchmergebot committed Sep 28, 2023
1 parent 84c5435 commit 168f516
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 18 deletions.
16 changes: 8 additions & 8 deletions c10/core/ConstantSymNodeImpl.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include <c10/core/SymNodeImpl.h>
#include <c10/util/variant.h>
#include <variant>

namespace c10 {

Expand Down Expand Up @@ -39,11 +39,11 @@ class C10_API ConstantSymNodeImpl : public SymNodeImpl {
}
int64_t int_() override {
TORCH_CHECK(is_int(), "not an int");
return c10::get<int64_t>(value_);
return std::get<int64_t>(value_);
}
bool bool_() override {
TORCH_CHECK(is_bool(), "not a bool");
return c10::get<bool>(value_);
return std::get<bool>(value_);
}
bool has_hint() override {
return true;
Expand All @@ -56,21 +56,21 @@ class C10_API ConstantSymNodeImpl : public SymNodeImpl {
c10::SymNode gt(const c10::SymNode& other) override;
std::string str() override {
if constexpr (is_int_()) {
return std::to_string(c10::get<int64_t>(value_));
return std::to_string(std::get<int64_t>(value_));
} else {
return c10::get<bool>(value_) ? "true" : "false";
return std::get<bool>(value_) ? "true" : "false";
}
}
c10::optional<int64_t> constant_int() override {
if constexpr (is_int_()) {
return c10::get<int64_t>(value_);
return std::get<int64_t>(value_);
} else {
return c10::nullopt;
}
}
c10::optional<bool> constant_bool() override {
if constexpr (is_bool_()) {
return c10::get<bool>(value_);
return std::get<bool>(value_);
} else {
return c10::nullopt;
}
Expand All @@ -83,7 +83,7 @@ class C10_API ConstantSymNodeImpl : public SymNodeImpl {
}

private:
c10::variant<int64_t, bool> value_;
std::variant<int64_t, bool> value_;

static constexpr bool is_int_() {
return std::is_same<T, int64_t>::value;
Expand Down
6 changes: 3 additions & 3 deletions torch/csrc/autograd/profiler_kineto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ auto parseArgData(
std::vector<c10::IValue> concrete_inputs_list;

for (const auto& i : c10::irange(input_shapes.size())) {
c10::visit(
std::visit(
c10::overloaded(
[&](const TensorMetadata& t) {
shapes[i] = t.sizes_;
Expand All @@ -127,12 +127,12 @@ auto parseArgData(
concrete_inputs_list.resize(input_shapes.size());

for (const auto& i : c10::irange(input_shapes.size())) {
c10::visit(
std::visit(
c10::overloaded(
[&](const c10::IValue& val) { concrete_inputs_list[i] = val; },
[&](const auto&) {}),
input_shapes[i]);
c10::visit(
std::visit(
c10::overloaded(
[&](const c10::IValue& val) {
concrete_inputs_list[i] = val;
Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/profiler/collection.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <mutex>
#include <type_traits>
#include <utility>
#include <variant>

#include <ATen/Context.h>
#include <c10/core/Device.h>
Expand Down Expand Up @@ -85,7 +86,7 @@ struct TORCH_API TensorMetadata : public RawTensorMetadataBase {
c10::optional<AllocationID> allocation_id_;
};

using op_input_t = c10::variant<
using op_input_t = std::variant<
TensorMetadata,
std::vector<TensorMetadata>,
c10::IValue,
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/profiler/data_flow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void calculateUniqueTensorIDs(
result->visit(c10::overloaded(
[&](ExtraFields<EventType::TorchOp>& torch_op) {
for (auto& i : torch_op.inputs_) {
c10::visit(raw_tensors, i);
std::visit(raw_tensors, i);
}
},
[&](ExtraFields<EventType::PyCall>& py_call) {
Expand Down
15 changes: 10 additions & 5 deletions torch/csrc/profiler/python/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ static void THPCapturedTraceback_dealloc(PyObject* self_) {

PyTypeObject THPCapturedTracebackType = {
PyVarObject_HEAD_INIT(
NULL,
nullptr,
0) "torch._C._profiler.CapturedTraceback", /* tp_name */
sizeof(THPCapturedTraceback), /* tp_basicsize */
0, /* tp_itemsize */
Expand Down Expand Up @@ -162,7 +162,12 @@ int RecordFunctionFast_init(
constexpr const char* kwlist[] = {"name", nullptr};
PyObject* name = nullptr;
if (!PyArg_ParseTupleAndKeywords(
args, kwargs, "O", const_cast<char**>(kwlist), &name)) {
args,
kwargs,
"O",
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<char**>(kwlist),
&name)) {
return -1;
}
if (name) {
Expand Down Expand Up @@ -383,7 +388,7 @@ void initPythonBindings(PyObject* module) {
[](const torch_op_t& op) {
py::list out;
for (const auto& input : op.inputs_) {
c10::visit(
std::visit(
c10::overloaded(
[&](const c10::IValue& v) {
out.append(torch::jit::toPyObject(v));
Expand Down Expand Up @@ -536,11 +541,11 @@ void initPythonBindings(PyObject* module) {
static PyMethodDef RecordFunctionFast_methods[] = {
{"__enter__", RecordFunctionFast_enter, METH_NOARGS, nullptr},
{"__exit__", RecordFunctionFast_exit, METH_VARARGS, nullptr},
{NULL},
{nullptr},
};

static PyTypeObject RecordFunctionFast_Type = {
PyVarObject_HEAD_INIT(NULL, 0)};
PyVarObject_HEAD_INIT(nullptr, 0)};

RecordFunctionFast_Type.tp_name = "torch._C._profiler.RecordFunctionFast",
RecordFunctionFast_Type.tp_basicsize = sizeof(RecordFunctionFast);
Expand Down

0 comments on commit 168f516

Please sign in to comment.