New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support int32 index for Gather #2030
Conversation
lib/Graph/Nodes.cpp
Outdated
@@ -818,7 +818,8 @@ bool RowwiseQuantizedFullyConnectedNode::verify() const { | |||
|
|||
bool GatherNode::verify() const { | |||
bool isValid = checkType(getResult(), getData().getElementType(), this); | |||
isValid &= checkType(getIndices(), ElemKind::Int64ITy, this); | |||
isValid &= (checkType(getIndices(), ElemKind::Int64ITy, this) || |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
include/glow/Graph/VerifierHelper.h
Outdated
const char *msg, const InputTy &a, llvm::ArrayRef<InputTy> b, | ||
const Node *parent, | ||
const CompareWithName<InputTy> &comp = CompareOperatorEqual<InputTy>()) { | ||
bool seed = false; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
include/glow/Graph/VerifierHelper.h
Outdated
@@ -153,6 +188,12 @@ bool checkSameShape(NodeValue A, NodeValue B, const Node *parent); | |||
/// \see expectCompareTrue for more details. | |||
bool checkType(NodeValue A, ElemKind expectedType, const Node *parent); | |||
|
|||
/// Check that the element type of the operand \p A matches any of the expected | |||
/// types \p expected Types. \p parent is used to print the context of that |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
lib/Backends/CPU/libjit/libjit.cpp
Outdated
@@ -370,7 +370,7 @@ void libjit_gather(T *dest, const T *data, const size_t *indices, | |||
size_t sampleStart = sample * sampleSize; | |||
|
|||
// For each slice that we fetch: | |||
for (size_t i = 0; i < numIndices; i++) { | |||
for (IDX i = 0; i < numIndices; i++) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
auto *F = getFunction("gather", dest->getElementType()); | ||
llvm::Function *F = nullptr; | ||
if (indices->getElementType() == ElemKind::Int64ITy) { | ||
F = getFunction("gather64", dest->getElementType()); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
Comments addressed. |
isValid &= checkType(getIndices(), ElemKind::Int64ITy, this); | ||
isValid &= checkType( | ||
getIndices(), | ||
llvm::makeArrayRef({ElemKind::Int64ITy, ElemKind::Int32ITy}), this); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
|
||
{ | ||
ONNXModelLoader onnxLD(netFilename, {"data", "indices"}, | ||
{&data.getType(), &indices.getType()}, *F); | ||
output = onnxLD.getSingleOutput(); | ||
} | ||
EE.compile(CompilationMode::Infer, F); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
Description:
The indices input of
Gather
can be eitherint32_t
orint64_t
, we need to support both. This PR does this and adds support for Interpreter and CPU backend.Testing:
Unit test
Documentation:
[Optional Fixes #issue]
Please see a detailed explanation of how to fill out the fields in the relevant sections in PULL_REQUEST.md.