Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support int32 index for Gather #2030

Merged
merged 3 commits into from Nov 15, 2018
Merged

Support int32 index for Gather #2030

merged 3 commits into from Nov 15, 2018

Conversation

yinghai
Copy link
Contributor

@yinghai yinghai commented Nov 15, 2018

Description:
The indices input of Gather can be either int32_t or int64_t, we need to support both. This PR does this and adds support for Interpreter and CPU backend.

Testing:
Unit test
Documentation:
[Optional Fixes #issue]

Please see a detailed explanation of how to fill out the fields in the relevant sections in PULL_REQUEST.md.

@@ -818,7 +818,8 @@ bool RowwiseQuantizedFullyConnectedNode::verify() const {

bool GatherNode::verify() const {
bool isValid = checkType(getResult(), getData().getElementType(), this);
isValid &= checkType(getIndices(), ElemKind::Int64ITy, this);
isValid &= (checkType(getIndices(), ElemKind::Int64ITy, this) ||

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

const char *msg, const InputTy &a, llvm::ArrayRef<InputTy> b,
const Node *parent,
const CompareWithName<InputTy> &comp = CompareOperatorEqual<InputTy>()) {
bool seed = false;

This comment was marked as off-topic.

This comment was marked as off-topic.

@@ -153,6 +188,12 @@ bool checkSameShape(NodeValue A, NodeValue B, const Node *parent);
/// \see expectCompareTrue for more details.
bool checkType(NodeValue A, ElemKind expectedType, const Node *parent);

/// Check that the element type of the operand \p A matches any of the expected
/// types \p expected Types. \p parent is used to print the context of that

This comment was marked as off-topic.

@@ -370,7 +370,7 @@ void libjit_gather(T *dest, const T *data, const size_t *indices,
size_t sampleStart = sample * sampleSize;

// For each slice that we fetch:
for (size_t i = 0; i < numIndices; i++) {
for (IDX i = 0; i < numIndices; i++) {

This comment was marked as off-topic.

auto *F = getFunction("gather", dest->getElementType());
llvm::Function *F = nullptr;
if (indices->getElementType() == ElemKind::Int64ITy) {
F = getFunction("gather64", dest->getElementType());

This comment was marked as off-topic.

@yinghai
Copy link
Contributor Author

yinghai commented Nov 15, 2018

Comments addressed.

isValid &= checkType(getIndices(), ElemKind::Int64ITy, this);
isValid &= checkType(
getIndices(),
llvm::makeArrayRef({ElemKind::Int64ITy, ElemKind::Int32ITy}), this);

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

Copy link
Contributor

@rdzhabarov rdzhabarov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!


{
ONNXModelLoader onnxLD(netFilename, {"data", "indices"},
{&data.getType(), &indices.getType()}, *F);
output = onnxLD.getSingleOutput();
}
EE.compile(CompilationMode::Infer, F);

This comment was marked as off-topic.

This comment was marked as off-topic.

@yinghai yinghai merged commit f9ee324 into pytorch:master Nov 15, 2018
@yinghai yinghai deleted the g2 branch November 15, 2018 21:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants