Kl div backward removal #3185
Changes from all commits
22f72f5
8322fdf
3a89b3e
f4b8c89
d46adca
7df13d7
e90e748
```diff
@@ -894,6 +894,26 @@ NodePtr LogicalOr(const Value& input, const Value& other) {
                    std::move(lower_fn));
 }
 
+NodePtr XLogY(const Value& input, const Value& other) {
+  auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+    xla::XlaOp xla_other = loctx->GetOutputOp(node.operand(1));
+    xla::XlaOp xla_output = BuildXLogY(xla_input, xla_other);
+    return node.ReturnOp(xla_output, loctx);
+  };
+  auto lower_for_shape_fn =
+      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    XLA_CHECK_EQ(operands.size(), 2) << "Unexpected number of operands";
+    return BuildXLogY(operands[0], operands[1]);
+  };
+  return GenericOp(OpKind(at::aten::xlogy), {input, other},
+                   [&]() {
+                     return InferOutputShape({input.shape(), other.shape()},
+                                             lower_for_shape_fn);
+                   },
+                   std::move(lower_fn));
+}
```

Review discussion on the `XLA_CHECK_EQ` line:

> Why do we not have the same error check in the lowering function?

> I can add it.

> Actually, the lowering and shape functions share the same operands. The shape function runs first, so there is no need to redo the check in the lowering function.

> Well, that statement is not true: the shape function and the lowering function use different parameters. I guess we don't really have a good reason for doing this only in the shape function. I checked the history, and it seems the check was added for one shape function and then copied along with the code for other ops, which inherited it.

> I would prefer to either leave this check here or delete it. If we decide to add an operand check to the lowering function, we should do that for all ops; it is currently only done for some shape functions, which is very inconsistent.

> Sounds good, let's leave it out.
The hunk continues with the `NanToNum` node (the remainder of the function is cut off in the source):

```diff
+
+NodePtr NanToNum(const Value& input, const Value& nan, const Value& posinf,
+                 const Value& neginf) {
+  auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
```
Review discussion on the NaN constant used in the lowering:

> Let's name this `std::nan("2")` for better debugging differentiation later?

> What's the difference between `nan(1)` and `nan(2)`? I thought they are both just NaN.

> Oh, I was thinking that when something fails, it would show `nan(1)` vs. `nan(2)` in the error message. It's a nit; feel free to ignore it if this is the only change you have left.