Skip to content

Conversation

ruoqianguo
Copy link
Contributor

@ruoqianguo ruoqianguo commented Aug 26, 2022

Description

Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

When i run the below code, i found the shape of trt output mismatched pytorch. Thanks @peri044 for helping me find this bug.

import torch
import torch.nn as nn
import torch_tensorrt

class IndexTest(nn.Module):
    def __init__(self):
        super(IndexTest, self).__init__()

    def forward(self, input: torch.Tensor, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
        return input[x.long(), y.long(), z.long(), :]

def reproduce_error():
    torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Graph)
    model = IndexTest()

    input = torch.randn([4, 8, 8, 4]).cuda()
    x = torch.full([4, 13, 1], 1).int().cuda()
    y = torch.full([4, 13, 1], 2).int().cuda()
    z = torch.full([4, 13, 1], 3).int().cuda()
    test_output = model.forward(input, x, y , z)
    traced_model = torch.jit.trace(model, [input, x, y, z])

    trt_model = torch_tensorrt.compile(traced_model.eval().cuda(), inputs=[input, x, y, z],  **{
            "truncate_long_and_double": True,
        })
    converted_output = trt_model.forward(input, x, y ,z)

    print(test_output.shape)
    print(converted_output.shape)
    print(torch.sum(test_output-converted_output))

reproduce_error()

Fixes #1274

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

Signed-off-by: Ruoqian Guo <ruoqiang@nvidia.com>
Signed-off-by: Ruoqian Guo <ruoqiang@nvidia.com>
Signed-off-by: Ruoqian Guo <ruoqiang@nvidia.com>
@github-actions github-actions bot added component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: core Issues re: The core compiler component: tests Issues re: Tests labels Aug 26, 2022
@ruoqianguo ruoqianguo changed the title Aten index Fix bug: correct the output shape of aten::index.Tensor Aug 26, 2022
Copy link
Collaborator

@narendasan narendasan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@narendasan narendasan merged commit d69651c into pytorch:master Aug 28, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: core Issues re: The core compiler component: tests Issues re: Tests
Projects
None yet
3 participants