10 changes: 7 additions & 3 deletions test/cpp/lazy/test_ir.cpp
@@ -17,8 +17,12 @@ namespace lazy {
 
 class TestLeafNode : public Node {
  public:
+  static OpKind ClassOpKind() {
+    return OpKind();
+  }
+
   explicit TestLeafNode(size_t param)
-      : Node(OpKind(), /* num_outputs */ 1),
+      : Node(ClassOpKind(), /* num_outputs */ 1),
         hash_(Hash(param)),
         param_(param) {}
   ~TestLeafNode() override = default;
@@ -45,7 +49,7 @@ TEST(IrTest, BasicTest) {
 
   EXPECT_EQ(node1->num_outputs(), 1);
 
-  const TestLeafNode* leafptr = NodeCast<TestLeafNode>(node1.get(), OpKind());
+  const TestLeafNode* leafptr = NodeCast<TestLeafNode>(node1.get());
   EXPECT_TRUE(leafptr != nullptr);
 }
 
@@ -102,7 +106,7 @@ TEST(IrTest, TsNodeTest) {
 
   EXPECT_EQ(node1->num_outputs(), 1);
 
-  const TsNode* leafptr = NodeCast<TsNode>(node1.get(), OpKind(at::aten::view));
+  const TsNode* leafptr = dynamic_cast<const TsNode*>(node1.get());
   EXPECT_TRUE(leafptr != nullptr);
 }
 
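The test above shows the core of this change: each Node subclass now publishes its op kind through a static ClassOpKind(), so NodeCast<T> can read it from the type instead of taking it as an argument. A minimal sketch of the contract (MyLeafNode is a hypothetical class; this assumes the in-tree torch/csrc/lazy headers and elides the shape/hash overrides a real node also supplies):

#include <torch/csrc/lazy/core/ir.h>

namespace torch {
namespace lazy {

class MyLeafNode : public Node {
 public:
  // The op kind is a static property of the class...
  static OpKind ClassOpKind() {
    return OpKind(at::aten::relu);
  }

  MyLeafNode() : Node(ClassOpKind(), /* num_outputs */ 1) {}
};

const MyLeafNode* AsMyLeaf(const Node* node) {
  // ...so the cast no longer needs an OpKind argument; it returns
  // nullptr when node->op() != MyLeafNode::ClassOpKind().
  return NodeCast<MyLeafNode>(node);
}

} // namespace lazy
} // namespace torch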
42 changes: 23 additions & 19 deletions test/cpp/lazy/test_trie_cache.cpp
@@ -13,8 +13,12 @@ namespace lazy {
 
 class TrieCacheNode : public Node {
  public:
+  static OpKind ClassOpKind() {
+    return OpKind();
+  }
+
   explicit TrieCacheNode(size_t id)
-      : Node(OpKind(), /* num_outputs */ 1), id_(id), hash_(Hash(id_)) {}
+      : Node(ClassOpKind(), /* num_outputs */ 1), id_(id), hash_(Hash(id_)) {}
   ~TrieCacheNode() override = default;
 
   bool Equal(size_t id) const {
@@ -40,14 +44,14 @@ TEST(TrieCacheTest, TestSinglePath) {
   FLAGS_torch_lazy_reuse_ir = true;
   TrieCache::Get()->Clear();
 
-  NodePtr a = MakeNode<TrieCacheNode>(0);
-  NodePtr b = MakeNode<TrieCacheNode>(1);
-  NodePtr c = MakeNode<TrieCacheNode>(2);
+  NodePtr a = ReuseOrMakeNode<TrieCacheNode>(0);
+  NodePtr b = ReuseOrMakeNode<TrieCacheNode>(1);
+  NodePtr c = ReuseOrMakeNode<TrieCacheNode>(2);
   TrieCache::Get()->ResetCurrent(); // MarkStep
 
-  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
-  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
-  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 2).get(), c.get());
+  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
+  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
+  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(2).get(), c.get());
   TrieCache::Get()->ResetCurrent(); // MarkStep
 }
 
@@ -62,25 +66,25 @@ TEST(TrieCacheTest, TestTwoPaths) {
   FLAGS_torch_lazy_reuse_ir = true;
   TrieCache::Get()->Clear();
 
-  NodePtr a = MakeNode<TrieCacheNode>(0);
-  NodePtr b = MakeNode<TrieCacheNode>(1);
-  NodePtr c = MakeNode<TrieCacheNode>(2);
+  NodePtr a = ReuseOrMakeNode<TrieCacheNode>(0);
+  NodePtr b = ReuseOrMakeNode<TrieCacheNode>(1);
+  NodePtr c = ReuseOrMakeNode<TrieCacheNode>(2);
   TrieCache::Get()->ResetCurrent(); // MarkStep
 
-  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
-  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
-  NodePtr d = ReuseOrMakeNode<TrieCacheNode>(OpKind(), 3);
+  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
+  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
+  NodePtr d = ReuseOrMakeNode<TrieCacheNode>(3);
   EXPECT_NE(d.get(), c.get());
   TrieCache::Get()->ResetCurrent(); // MarkStep
 
-  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
-  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
-  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 3).get(), d.get());
+  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
+  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
+  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(3).get(), d.get());
   TrieCache::Get()->ResetCurrent(); // MarkStep
 
-  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
-  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
-  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 2).get(), c.get());
+  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
+  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
+  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(2).get(), c.get());
   TrieCache::Get()->ResetCurrent(); // MarkStep
 }
 
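TestTwoPaths is worth a sketch: reusing the a, b prefix and then asking for id 3 does not overwrite c, it grows a sibling branch, which is why the final step can still replay the original a, b, c sequence. Roughly (an illustration of the cached structure, not a real API dump):

// After TestTwoPaths, the trie holds one shared prefix and two branches:
//
//   root -> TrieCacheNode(0) -> TrieCacheNode(1) -> TrieCacheNode(2)  // step 1
//                                                \-> TrieCacheNode(3) // step 2 diverges here
//
// ResetCurrent() only moves the lookup cursor back to the root; it does not
// evict nodes, so later steps can walk and reuse either branch.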
14 changes: 14 additions & 0 deletions torch/csrc/lazy/core/ir.h
@@ -175,6 +175,8 @@ inline std::ostream& operator<<(std::ostream& stream, const Node& node) {
   return stream;
 }
 
+// Note: Keep this version of NodeCast for smooth PyTorch/XLA migration, and
+// clean up once the migration is done.
 template <typename T>
 const T* NodeCast(const Node* node, OpKind op) {
   if (op != node->op()) {
@@ -187,6 +189,18 @@ const T* NodeCast(const Node* node, OpKind op) {
 #endif
 }
 
+template <typename T>
+const T* NodeCast(const Node* node) {
+  if (T::ClassOpKind() != node->op()) {
+    return nullptr;
+  }
+#ifdef NDEBUG
+  return static_cast<const T*>(node);
+#else
+  return &dynamic_cast<const T&>(*node);
+#endif
+}
+
 
 // Represents a specific output produced by a node. Since the output of a node
 // can be composed by multiple outputs, the node+index coordinates fully qualify
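Both overloads now coexist: the two-argument form stays for PyTorch/XLA call sites that still pass an OpKind, while the new form reads it from T::ClassOpKind(). In NDEBUG builds the up-front OpKind equality check is what makes the cheap static_cast safe; debug builds keep dynamic_cast as a verification net. A hedged call-site comparison (AsDeviceData is a hypothetical helper; DeviceData gains a ClassOpKind() later in this diff):

#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/ts_backend/ops/device_data.h>

namespace torch {
namespace lazy {

const DeviceData* AsDeviceData(const Node* node) {
  // Old overload, kept for the migration: the caller restates the kind.
  //   return NodeCast<DeviceData>(node, ltc_device_data);
  // New overload: the kind comes from DeviceData::ClassOpKind().
  return NodeCast<DeviceData>(node);
}

} // namespace lazy
} // namespace torch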
10 changes: 5 additions & 5 deletions torch/csrc/lazy/core/ir_builder.h
@@ -15,9 +15,9 @@ namespace torch {
 namespace lazy {
 
 template <typename T, typename... Args>
-NodePtr ReuseNode(OpKind op, Args&&... args) {
+NodePtr ReuseNode(Args&&... args) {
   if (FLAGS_torch_lazy_reuse_ir) {
-    return LookupNodeFromTrieCache<T>(op, std::forward<Args>(args)...);
+    return LookupNodeFromTrieCache<T>(std::forward<Args>(args)...);
   }
   return nullptr;
 }
@@ -27,16 +27,16 @@ template <typename T, typename... Args>
 NodePtr MakeNode(Args&&... args) {
   NodePtr node = std::make_shared<T>(std::forward<Args>(args)...);
   if (FLAGS_torch_lazy_reuse_ir) {
-    // If ir caching is enabled, we need to record all new nodes
+    // If ir caching is enabled, we need to record all new nodes
     TrieCache::Get()->Insert(node);
   }
   return node;
 }
 
-// op is passed in for a more efficient node casting, see the implementation of NodeCast
 template <typename T, typename... Args>
-NodePtr ReuseOrMakeNode(OpKind op, Args&&... args) {
-  NodePtr node = ReuseNode<T>(op, std::forward<Args>(args)...);
+NodePtr ReuseOrMakeNode(Args&&... args) {
+  NodePtr node = ReuseNode<T>(std::forward<Args>(args)...);
   if (!node) {
     node = MakeNode<T>(std::forward<Args>(args)...);
   }
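With OpKind gone from the signature, ReuseOrMakeNode forwards exactly the node's constructor arguments, and when FLAGS_torch_lazy_reuse_ir is set, the same argument list is what the trie lookup passes to Equal. A sketch of a typical call site after this change (MakeExpandNode is a hypothetical helper, and this assumes Expand also provides an Equal matching this parameter list, which the reuse path requires):

#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/ts_backend/ops/expand.h>

namespace torch {
namespace lazy {

// The argument list mirrors Expand's constructor; no OpKind at the call site.
NodePtr MakeExpandNode(const Value& input, std::vector<int64_t> size) {
  return ReuseOrMakeNode<Expand>(input, std::move(size),
                                 /*is_scalar_expand=*/false);
}

} // namespace lazy
} // namespace torch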
4 changes: 2 additions & 2 deletions torch/csrc/lazy/core/trie.h
@@ -52,11 +52,11 @@ class TORCH_API TrieCache {
 };
 
 template <typename T, typename... Args>
-NodePtr LookupNodeFromTrieCache(OpKind op, Args&&... args) {
+NodePtr LookupNodeFromTrieCache(Args&&... args) {
   auto& successors = TrieCache::Get()->Current()->successors;
   for (auto it = successors.begin(); it != successors.end(); it++) {
     NodePtr ir_node = (*it)->ir_node;
-    const T* concrete_node = NodeCast<T>(ir_node.get(), op);
+    const T* concrete_node = NodeCast<T>(ir_node.get());
     if (concrete_node && concrete_node->Equal(std::forward<Args>(args)...)) {
       TORCH_LAZY_COUNTER("IrNodeReused::" + std::string(typeid(T).name()), 1);
       TrieCache::Get()->SetCurrent(it);
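Taken together, reuse imposes a two-part contract on every node type T that flows through this function. A sketch of the minimum a class must provide (ReusableNode is a hypothetical name, mirroring TrieCacheNode from the tests):

class ReusableNode : public Node {
 public:
  // 1. Filter: used by NodeCast<T>(node) to reject other op kinds cheaply.
  static OpKind ClassOpKind() {
    return OpKind();
  }

  explicit ReusableNode(size_t id)
      : Node(ClassOpKind(), /* num_outputs */ 1), id_(id) {}

  // 2. Confirm: same parameter list as the constructor, since
  //    LookupNodeFromTrieCache forwards its args here unchanged.
  bool Equal(size_t id) const {
    return id_ == id;
  }

 private:
  size_t id_;
};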
8 changes: 8 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h
@@ -8,6 +8,10 @@ namespace lazy {
 // Node for the backward batch norm operator.
 class TSNativeBatchNormBackward : public torch::lazy::TsNode {
  public:
+  static OpKind ClassOpKind() {
+    return OpKind(at::aten::native_batch_norm_backward);
+  }
+
   TSNativeBatchNormBackward(const torch::lazy::Value& grad_out, const torch::lazy::Value& input,
                             const torch::lazy::Value& weight, const torch::lazy::Value& running_mean,
                             const torch::lazy::Value& running_var, const torch::lazy::Value& save_mean,
@@ -35,6 +39,10 @@ class TSNativeBatchNormBackward : public torch::lazy::TsNode {
 
 class TSNativeBatchNormForward : public torch::lazy::TsNode {
  public:
+  static OpKind ClassOpKind() {
+    return OpKind(at::aten::native_batch_norm);
+  }
+
   TSNativeBatchNormForward(const torch::lazy::Value& input, const torch::lazy::Value& weight,
                            const torch::lazy::Value& bias, const torch::lazy::Value& running_mean,
                            const torch::lazy::Value& running_var, bool training,
4 changes: 2 additions & 2 deletions torch/csrc/lazy/ts_backend/ops/cast.cpp
@@ -1,6 +1,5 @@
 #include <torch/csrc/lazy/ts_backend/ops/cast.h>
 
-#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
 #include <torch/csrc/lazy/core/tensor_util.h>
 
 namespace torch {
@@ -15,12 +14,13 @@ Shape NodeOutputShape(const Value& input, c10::ScalarType type) {
 }
 
 } // namespace
+
 Cast::Cast(
     const Value& input,
     at::ScalarType dtype,
     c10::optional<at::ScalarType> stype)
     : TsNode(
-          ltc_cast,
+          ClassOpKind(),
           {input},
           {NodeOutputShape(input, dtype)},
          /*num_outputs=*/1,
5 changes: 5 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/cast.h
@@ -2,13 +2,18 @@
 
 #include <c10/core/ScalarType.h>
 #include <c10/util/Optional.h>
+#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
 #include <torch/csrc/lazy/ts_backend/ts_node.h>
 
 namespace torch {
 namespace lazy {
 
 class TORCH_API Cast : public TsNode {
  public:
+  static OpKind ClassOpKind() {
+    return ltc_cast;
+  }
+
   Cast(
       const Value& input,
       at::ScalarType dtype,
4 changes: 2 additions & 2 deletions torch/csrc/lazy/ts_backend/ops/device_data.cpp
@@ -9,7 +9,7 @@ namespace lazy {
 
 DeviceData::DeviceData(std::shared_ptr<BackendData> data)
     : TsNode(
-          ltc_device_data,
+          ClassOpKind(),
           data->shape(),
           /*num_outputs=*/1,
           /*hash_seed=*/static_cast<uint32_t>(101)),
@@ -22,7 +22,7 @@ std::string DeviceData::ToString() const {
 }
 
 const DeviceData* DeviceData::Cast(const Node* node) {
-  return NodeCast<DeviceData>(node, ltc_device_data);
+  return NodeCast<DeviceData>(node);
 }
 
 } // namespace lazy
5 changes: 5 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/device_data.h
@@ -1,13 +1,18 @@
 #pragma once
 
 #include <torch/csrc/lazy/backend/backend_data.h>
+#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
 #include <torch/csrc/lazy/ts_backend/ts_node.h>
 
 namespace torch {
 namespace lazy {
 
 class TORCH_API DeviceData : public TsNode {
  public:
+  static OpKind ClassOpKind() {
+    return ltc_device_data;
+  }
+
   explicit DeviceData(std::shared_ptr<BackendData> data);
 
   std::string ToString() const override;
2 changes: 1 addition & 1 deletion torch/csrc/lazy/ts_backend/ops/expand.cpp
@@ -8,7 +8,7 @@ Expand::Expand(
     std::vector<int64_t> size,
     bool is_scalar_expand)
     : TsNode(
-          OpKind(at::aten::expand),
+          ClassOpKind(),
          {input},
          /*num_outputs=*/1,
          MHash(size, is_scalar_expand)),
4 changes: 4 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/expand.h
@@ -9,6 +9,10 @@ namespace lazy {
 
 class TORCH_API Expand : public TsNode {
  public:
+  static OpKind ClassOpKind() {
+    return OpKind(at::aten::expand);
+  }
+
   Expand(const Value& input, std::vector<int64_t> size, bool is_scalar_expand);
 
   std::string ToString() const override;
2 changes: 1 addition & 1 deletion torch/csrc/lazy/ts_backend/ops/random_ops.cpp
@@ -5,7 +5,7 @@ namespace torch {
 namespace lazy {
 
 Normal::Normal(const torch::lazy::Value& self, const double& mean, const double& std, std::vector<torch::lazy::Shape>&& shapes)
-    : torch::lazy::TsNode(torch::lazy::OpKind(c10::Symbol::fromQualString("aten::normal_")),
+    : torch::lazy::TsNode(ClassOpKind(),
                           {self}, std::move(shapes),
                           /* num_outputs */ 1,
                           torch::lazy::MHash(mean, std)),
4 changes: 4 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/random_ops.h
@@ -7,6 +7,10 @@ namespace lazy {
 
 class Normal : public torch::lazy::TsNode {
  public:
+  static OpKind ClassOpKind() {
+    return OpKind::Get("aten::normal_");
+  }
+
   Normal(const torch::lazy::Value& self, const double& mean, const double& std, std::vector<torch::lazy::Shape>&& shapes);
 
   std::string ToString() const override;
4 changes: 2 additions & 2 deletions torch/csrc/lazy/ts_backend/ops/scalar.cpp
@@ -12,15 +12,15 @@ using at::operator<<;
 
 Scalar::Scalar(const at::Scalar& value, Shape shape)
     : TsNode(
-          OpKind(at::prim::Constant),
+          ClassOpKind(),
           std::move(shape),
           /*num_outputs=*/1,
           ScalarHash(value)),
       value_(value) {}
 
 Scalar::Scalar(const at::Scalar& value, c10::ScalarType type)
     : TsNode(
-          OpKind(at::prim::Constant),
+          ClassOpKind(),
          {Shape(type, {})},
          /*num_outputs=*/1,
          ScalarHash(value)),
4 changes: 4 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/scalar.h
@@ -12,6 +12,10 @@ namespace lazy {
 // computation graph.
 class TORCH_API Scalar : public TsNode {
  public:
+  static OpKind ClassOpKind() {
+    return OpKind(at::prim::Constant);
+  }
+
   Scalar(const at::Scalar& value, Shape shape);
   Scalar(const at::Scalar& value, c10::ScalarType type);
 
7 changes: 6 additions & 1 deletion torch/csrc/lazy/ts_backend/ops/to_copy.h
@@ -12,8 +12,12 @@ namespace lazy {
 // the aten/eager fallback necessitating directly implementing the right to(device) behavior
 class ToCopy : public torch::lazy::TsNode {
  public:
+  static OpKind ClassOpKind() {
+    return OpKind(at::aten::_to_copy);
+  }
+
   ToCopy(const torch::lazy::Value& self, const c10::optional<at::ScalarType>& dtype, const c10::optional<at::Layout>& layout, const c10::optional<at::Device>& device, const c10::optional<bool>& pin_memory, const bool& non_blocking, const c10::optional<at::MemoryFormat>& memory_format, std::vector<torch::lazy::Shape>&& shapes)
-      : torch::lazy::TsNode(torch::lazy::OpKind(at::aten::_to_copy),
+      : torch::lazy::TsNode(ClassOpKind(),
                             {self}, std::move(shapes),
                             /* num_outputs */ 1,
                             torch::lazy::MHash(dtype, layout, device, pin_memory, non_blocking, memory_format)),
@@ -85,5 +89,6 @@ class ToCopy : public torch::lazy::TsNode {
   bool non_blocking;
   c10::optional<at::MemoryFormat> memory_format;
 };
+
 } // namespace lazy
 } // namespace torch
2 changes: 1 addition & 1 deletion torch/csrc/lazy/ts_backend/ts_lowering_context.cpp
@@ -31,7 +31,7 @@ TSLoweringContext::TSLoweringContext(
 void TSLoweringContext::AssignOutputOp(
     const Output& output,
     torch::jit::Value* op) {
-  auto ts_node = NodeCast<TsNode>(output.node, output.node->op());
+  const TsNode* ts_node = static_cast<const TsNode*>(output.node);
   std::string stack_trace = ts_node->getPythonStacktrace();
   if (!stack_trace.empty()) {
     op->node()->s_(c10::Symbol::attr("source"), stack_trace);
2 changes: 1 addition & 1 deletion torch/csrc/lazy/ts_backend/ts_node.cpp
@@ -78,7 +78,7 @@ TSOpVector TsNode::Lower(std::shared_ptr<torch::jit::GraphFunction> function,
 }
 
 TensorList::TensorList(OpList values)
-    : TsNode(/*op=*/tensor_list_opkind,
+    : TsNode(/*op=*/ClassOpKind(),
              /*operands=*/values,
              /*shapes=*/std::vector<Shape>(),
              /*num_outputs=*/1,
4 changes: 4 additions & 0 deletions torch/csrc/lazy/ts_backend/ts_node.h
@@ -64,6 +64,10 @@ const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list");
 // TODO(whc) once Shape() API is moved to Node base, also make it virtual, and then implement it as NotImplemented for
 // TensorList, also fixing the assertion that would fail.
 struct TORCH_API TensorList : public TsNode {
+  static OpKind ClassOpKind() {
+    return tensor_list_opkind;
+  }
+
   TensorList() = delete;
   TensorList(OpList values);
 