
[POC] Dynamic expand with SymInt implementation #3558

Draft · wants to merge 78 commits into base: master

Conversation

miladm (Collaborator)

@miladm miladm commented May 10, 2022

This is the POC implementation of the torch.Tensor.expand op, based on the PyTorch SymInt POC implementation PR.

Action items to unblock:

@miladm miladm added this to the Dynamic Shape milestone May 10, 2022
@miladm miladm added the DO_NOT_MERGE_YET (For PRs which cannot be merged, despite tests passing) label May 10, 2022
@miladm miladm self-assigned this May 10, 2022
@miladm miladm requested a review from JackCaoG May 10, 2022 04:06
@miladm miladm changed the title from "dynamic expand with symint implementation POC" to "[POC] Dynamic expand with SymInt implementation" May 10, 2022
@miladm miladm requested review from Krovatkin and ezyang May 10, 2022 07:01
@JackCaoG (Collaborator)

Does this PR build locally on your end? The CI build failed with conflicts.

torch_xla/csrc/aten_xla_type.cpp (outdated review thread, resolved)
Comment on lines 1399 to 1386
for (int i = 0; i < sizes.size(); i++) {
  // For each SymInt dimension, record its static upper bound and whether it is dynamic.
  auto _symbolicIntNode = sizes[i].toSymbolicIntNode();
  auto _sizenode = MakeNode<SizeNode>(_symbolicIntNode);
  upper_bound.push_back(_sizenode.getStaticValue());
  dynamic_dims.push_back(_sizenode.isDynamic());
}
Collaborator

We can probably put these in a helper function for now. I feel like we should be able to codegen these in the future too.


/* Construct Upper Bound Tensor Shape */
xla::XlaOp upper_bound_size_input =
    xla::Parameter(loctx->builder(), 0, target_shape, "upper_bound_size");
Collaborator

Hmm, it is rare to use xla::Parameter in the lowering; you can use xla::Zero and xla::Broadcast to achieve the same.
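For illustration, a minimal sketch of this suggestion (not code from the PR), assuming the same loctx and target_shape as in the snippet above:

// Hedged sketch: materialize a zero scalar and broadcast it to the
// upper-bound (static) dimensions, instead of adding a graph parameter.
xla::XlaOp zero = xla::Zero(loctx->builder(), target_shape.element_type());
xla::XlaOp upper_bound_size_input =
    xla::Broadcast(zero, target_shape.dimensions());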

Collaborator Author

@miladm miladm May 10, 2022

Why would I use two API calls to reach the same goal instead of one?

torch_xla/csrc/data_ops.cpp (outdated review thread, resolved)
torch_xla/csrc/aten_xla_type.cpp (outdated review thread, resolved)
@miladm (Collaborator Author)

miladm commented May 10, 2022

This PR doesn't build at the moment because the upstream LTC layer doesn't yet have API support for expand with SymInt. @JackCaoG @Krovatkin

@miladm miladm requested a review from wconstab May 10, 2022 18:53
@miladm miladm marked this pull request as draft May 10, 2022 18:58
@miladm (Collaborator Author)

miladm commented May 16, 2022

Update: the current unit test checks the expand.SymInt code path. It does not check dynamic dimension propagation across a SymInt op, since the DimensionNode::isDynamic implementation is currently WIP.

CC @JackCaoG @Krovatkin

std::vector<torch::lazy::NodePtr> size_nodes;
std::vector<int64_t> upper_bounds;
std::vector<bool> dynamic_dims;
/* TODO: move this code to a helper function */
Collaborator Author

I'm considering moving this code to a helper function. My candidates for where to put it are a new helper file, helpers.h, or tensor_util.h. @JackCaoG wdyt?

Contributor

You could also consider putting a helper in dynamic_ir.h, and I can make a copy for the TS backend. Code duplication is a disadvantage, but hopefully we won't need to do it often.

torch_xla/csrc/aten_xla_type.cpp (outdated review thread, resolved)
    absl::Span<const xla::XlaOp> output_sizes) {
  for (int i = 0; i < output_sizes.size(); i++) {
    xla::Shape dim_shape = XlaHelpers::ShapeOfXlaOp(output_sizes[i]);
    if (dim_shape.is_dynamic()) {
Contributor

And if the shape isn't dynamic, we will be using static dimensions from InferShape? I'm asking because we would like shapes to appear symbolically in the graph, and I'm not sure that would happen if the dimensions are static. @miladm @JackCaoG

Collaborator Author

@miladm miladm May 16, 2022

Yes, the dimensions from InferShape set the shape. This means that when a dimension is static, the upper_bound values are expected to equal the "true" dimensions. Does this align with your understanding, @Krovatkin?

Can you elaborate on "we would like shapes to appear symbolically in a graph and I'm not sure this would happen if dimensions are static"? @Krovatkin

@@ -116,6 +116,7 @@ supported:
- erfinv
- exp
- expand
- expand.SymInt
Contributor

I'm actually not seeing the exact place where you're hooking XLATensor::expand into the dispatcher. Is the plan to use codegen for that?

Collaborator Author

@miladm miladm May 16, 2022

Our current codegen creates XLANativeFunctions.h for aten_xla_type.cpp. In it, you'll find the following declarations. Does that answer your question, @Krovatkin?

static at::Tensor expand(const at::Tensor & self, at::IntArrayRef size, bool implicit);
static at::Tensor expand(const at::Tensor & self, c10::SymIntArrayRef size, bool implicit);
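For illustration only, a hedged sketch of how such codegen'd entry points are typically wired into the dispatcher; the registration file and the TORCH_LIBRARY_IMPL block shown here are assumptions, not something taken from this PR:

// Hypothetical registration (normally emitted by codegen into a
// RegisterXLA.cpp-style file); the static_cast selects the SymInt overload.
TORCH_LIBRARY_IMPL(aten, XLA, m) {
  m.impl("expand.SymInt",
         static_cast<at::Tensor (*)(const at::Tensor&, c10::SymIntArrayRef,
                                    bool)>(&XLANativeFunctions::expand));
}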

Collaborator

@JackCaoG JackCaoG left a comment

Still WIP for the review.

-DPYTHON_INCLUDE_DIR=$(python -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())") \
-DPYTHON_LIBRARY=$(python -c "import distutils.sysconfig as sysconfig; print(sysconfig.get_config_var('LIBDIR') + '/' + sysconfig.get_config_var('LDLIBRARY'))")
-DPYTHON_INCLUDE_DIR=$(python3 -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())") \
-DPYTHON_LIBRARY=$(python3 -c "import distutils.sysconfig as sysconfig; print(sysconfig.get_config_var('LIBDIR') + '/' + sysconfig.get_config_var('LDLIBRARY'))")
Collaborator

There was an issue with forcing python to python3; @yeounoh might have more context. If this is just for testing purposes, you can use

sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.8 100

on your TPU VM.

test/cpp/test_aten_xla_tensor.cpp (review thread, resolved)
torch_xla/csrc/data_ops.cpp (review thread, resolved)
torch_xla/csrc/data_ops.cpp (outdated review thread, resolved)
torch_xla/csrc/ops/dynamic_ir.cpp (outdated review thread, resolved)

std::string SizeNode::ToString() const { return "SizeNode"; }

SizeAdd::SizeAdd(XlaValue a, XlaValue b)
Collaborator

Should we check that the XlaValue actually contains a DimensionNode and store them as DimensionNode directly? That could save us from doing multiple dynamic_casts in getStaticValue.
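As an illustration of this suggestion (not code from the PR), a rough sketch in which the operands are validated and cached as DimensionNode pointers at construction time; the lhs_/rhs_ members, the aten::add OpKind, and the XLA_CHECK are assumptions:

// Hypothetical: cast and validate the operands once in the constructor...
SizeAdd::SizeAdd(XlaValue a, XlaValue b)
    : DimensionNode(
          torch::lazy::OpKind{c10::Symbol::fromQualString("aten::add")},
          {a, b}, torch::lazy::MHash(1)),
      lhs_(dynamic_cast<const DimensionNode*>(a.node.get())),
      rhs_(dynamic_cast<const DimensionNode*>(b.node.get())) {
  XLA_CHECK(lhs_ != nullptr && rhs_ != nullptr)
      << "SizeAdd operands must be DimensionNodes";
}

// ...so getStaticValue can use them directly, without a dynamic_cast per call.
int64_t SizeAdd::getStaticValue() const {
  return lhs_->getStaticValue() + rhs_->getStaticValue();
}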

Collaborator Author

I am flexible. @Krovatkin wdyt?

SizeAdd(XlaValue a, XlaValue b);
int64_t getStaticValue() const override;
std::string ToString() const override;
};
Collaborator

Why do these operators not have Lower?

Collaborator Author

@miladm miladm May 19, 2022

I'm not sure we need to lower these ops, since they return upper-bound information, unless we directly pass SizeAdd to downstream ops like expand. @Krovatkin wdyt?

To help the discussion, I will push a lowering implementation shortly. Let's continue discussing.
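Purely as an illustrative sketch (not the implementation being pushed), one minimal lowering would materialize the node's static upper bound as a scalar constant; XlaOpVector, LoweringContext, and ReturnOp are assumed to be the usual torch_xla lowering helpers:

// Hypothetical lowering: emit the upper-bound value as an S64 scalar constant.
XlaOpVector SizeAdd::Lower(LoweringContext* loctx) const {
  xla::XlaOp result =
      xla::ConstantR0<int64_t>(loctx->builder(), getStaticValue());
  return ReturnOp(result, loctx);
}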

SizeDiv::SizeDiv(XlaValue a, XlaValue b)
    : DimensionNode(
          torch::lazy::OpKind{c10::Symbol::fromQualString("aten::div")},
          {a, b}, torch::lazy::MHash(1)) {}
Collaborator

I take it this hash is a hack for now.

torch_xla/csrc/ops/ops.cpp (outdated review threads, resolved)
@miladm miladm changed the title from "[POC] Dynamic expand with SymInt implementation" to "[POC][WIP] Dynamic expand with SymInt implementation" May 17, 2022
c10::SymInt xla_y0_size = xla_y.sym_sizes()[0];
torch::Tensor xla_a = CopyToDevice(a, device);
torch::Tensor xla_b = xla_a.expand(
    c10::SymIntArrayRef({xla_y0_size, c10::SymInt(3), c10::SymInt(4)}),
Collaborator Author

@miladm miladm May 19, 2022

@Krovatkin I hope constructing c10::SymIntArrayRef makes sense (ref). Let me know if it should be done differently.

Contributor

It looks good. I think plain xla_a.expand_symint({xla_y0_size, c10::SymInt(3), c10::SymInt(4)}, ...) should've worked. If not, we should look into that.

Collaborator Author

@miladm miladm Jul 7, 2022

@Krovatkin as discussed this morning, it looks like this line causes the failure raised here. What's your guidance?

@miladm miladm changed the title from "[POC][WIP] Dynamic expand with SymInt implementation" to "[POC] Dynamic expand with SymInt implementation" Jul 12, 2022
Labels: BLOCKED, DO_NOT_MERGE_YET (For PRs which cannot be merged, despite tests passing), dynamism (Dynamic Shape Features)