
🐛 [Bug] Cannot save models using ExportedProgram if the model has weighted layers #2341

Closed · Tracked by #2262
peri044 opened this issue Sep 24, 2023 · 5 comments
Assignees: peri044
Labels: Blocked [PyTorch] (Issue is blocked by some limitation of PyTorch), bug (Something isn't working)

Comments

@peri044 (Collaborator) commented Sep 24, 2023

Bug Description

If the graph has weighted layers, loading the model via ExportedProgram and running inference fails because the deserialized weights and the inputs end up on different devices.

To Reproduce

import torch
import torch_tensorrt
import unittest.mock  # plain "import unittest" does not load unittest.mock

from torch._export import export
from torch_tensorrt.dynamo.lowering import get_decompositions

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        conv = self.conv(x)
        relu = self.relu(conv)
        mul = relu * 0.5
        return mul

input = torch.randn((1, 3, 224, 224), dtype=torch.float).to("cuda")
model = MyModule().eval().cuda()

# Export with Torch-TensorRT's decomposition table patched in.
with unittest.mock.patch(
    "torch._export.DECOMP_TABLE", get_decompositions(True)
):
    trt_exp_program = export(model, (input,))

# Round-trip the ExportedProgram through serialization.
torch._export.save(trt_exp_program, "./trt.ep")
deserialized_prog = torch._export.load("./trt.ep")

out_pyt = model(input)
# Fails: the deserialized weights are on CPU while the input is on CUDA.
out_trt_ser = deserialized_prog(input).cuda()

Error message:

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same 
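
Until this is fixed upstream, a minimal workaround sketch is to unlift the deserialized program back into an nn.Module and move its weights to the input's device before running inference. This assumes ExportedProgram exposes a module() method, which newer torch.export builds do; older torch._export versions may not.

# Workaround sketch, not the fix: manually move the deserialized
# weights onto the GPU. Assumes ExportedProgram.module() is available.
workaround_module = deserialized_prog.module().to("cuda")
out_trt_ser = workaround_module(input)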

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

@peri044 added the bug (Something isn't working) label on Sep 24, 2023
@peri044 self-assigned this on Sep 24, 2023
@narendasan added the Blocked [PyTorch] (Issue is blocked by some limitation of PyTorch) label on Nov 6, 2023
@peri044 (Collaborator, Author) commented Nov 17, 2023

TODO: file an issue in the PyTorch repo.

@peri044 (Collaborator, Author) commented Nov 28, 2023

Related PyTorch issue: pytorch/pytorch#114000

@angelayi commented Dec 8, 2023

Should this be closed, since pytorch/pytorch#114000 has been fixed?

@peri044 (Collaborator, Author) commented Dec 16, 2023

Verified and this bug can be closed.

@peri044 closed this as completed on Dec 16, 2023
@johnzlli commented

> Verified and this bug can be closed.

It seems that pytorch/pytorch#114000 has been fixed. However, the PR pytorch/pytorch#114695 was closed, so the bug is still present in the latest version of PyTorch. Are there any updates?
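
A quick way to check whether a given PyTorch build still drops the device information is to inspect where the deserialized weights land. A sketch, assuming the torch.export save/load entry points, which replaced torch._export in newer releases:

import torch

ep = torch.export.load("./trt.ep")

# If the program was saved with CUDA weights but these print "cpu",
# the device information is still being dropped during serialization.
for name, tensor in ep.state_dict.items():
    print(name, tensor.device)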
