
✨[Feature] Support list and namedtuple input types to forward function #798

Closed

chaoz-dev opened this issue Jan 8, 2022 · 9 comments
Assignees
Labels
feature request New feature or request release: v1.2 Tagged to be included in v1.2

Comments

@chaoz-dev
Contributor

chaoz-dev commented Jan 8, 2022

Is your feature request related to a problem? Please describe.

Currently, the forward function only supports tensor input types when compiling. However, sometimes we wish to supply many tensors to the forward function at once (say, more than 10); this results in a very long forward API call where every tensor must be listed individually. It would be helpful if we could instead pass a single container holding all of these tensors, which makes for a much cleaner API call.

For this specific request, I focus on the list and namedtuple input types first, since these should cover most basic use cases (and should functionally satisfy named tensor key-value pair type inputs).

Describe the solution you'd like

Instead of supporting only the following, where we need to supply torch.Tensors to forward individually:

  import torch
  import torch_tensorrt

  DEVICE = torch.device("cuda:0")
  SHAPE = (1, 1)

  torch.manual_seed(0)

  class Model(torch.nn.Module):
      def __init__(self):
          super().__init__()

      def forward(self, a, b):
          return a - b

  if __name__ == "__main__":
      tensor = torch.randn(SHAPE, dtype=torch.float32, device=DEVICE)

      model = Model().eval().to(DEVICE)
      out = model(tensor, tensor)

      model_trt = torch_tensorrt.compile(
          model,
          inputs=[
              torch_tensorrt.Input(shape=SHAPE),
              torch_tensorrt.Input(shape=SHAPE),
          ],
          enabled_precisions={torch.float},
      )
      out_trt = model_trt(tensor, tensor)

      assert torch.max(torch.abs(out - out_trt)) < 1e-6

also support passing a namedtuple or list into forward:

  from collections import namedtuple

  import torch
  import torch_tensorrt

  DEVICE = torch.device("cuda:0")
  SHAPE = (1, 1)

  torch.manual_seed(0)

  Input = namedtuple('Input', ['t1', 't2'])

  class Model(torch.nn.Module):
      def __init__(self):
          super().__init__()

      def forward(self, input_: Input):
          return input_.t1 - input_.t2

  if __name__ == "__main__":
      tensor = torch.randn(SHAPE, dtype=torch.float32, device=DEVICE)
      input_ = Input(tensor, tensor)

      model = Model().eval().to(DEVICE)
      out = model(input_)

      model_trt = torch_tensorrt.compile(
          model,
          inputs=[
              torch_tensorrt.Input(shape=SHAPE),
              torch_tensorrt.Input(shape=SHAPE),
          ],
          enabled_precisions={torch.float},
      )
      out_trt = model_trt(input_)

      assert torch.max(torch.abs(out - out_trt)) < 1e-6

Describe alternatives you've considered

Currently the only alternative is to supply tensors directly into the forward function; supplying namedtuples will cause the compilation to segfault, and supplying lists will cause the compilation to fail to recognize the input altogether.

Additional context

  • For simplicity, the input containers should contain ONLY tensors (implying that we disallow nested containers). Containers with mixed input types are ignored.
  • Furthermore, there must be a bijection between the tensors in the container and the sizes provided to the compile call; i.e., there must be one Input size for each tensor in the container, and both are taken in the same order.
  • We can mix tensors and containers in the forward call (e.g. forward(x: torch.Tensor, y: List[torch.Tensor], z: namedtuple[torch.Tensor])). Any other input types are treated as they are today.
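The constraints above can be sketched in plain Python. This is a minimal sketch, not part of torch_tensorrt; `validate_container` and `count_expected_inputs` are hypothetical helper names:

```python
def validate_container(container):
    """Hypothetical validation reflecting the constraints above: a container
    is accepted only if every element is a leaf (no nested lists/tuples);
    otherwise it is ignored and left as-is."""
    return all(not isinstance(elem, (list, tuple)) for elem in container)


def count_expected_inputs(args):
    """Count how many torch_tensorrt.Input sizes the compile() call would
    need to supply under the bijection rule above: one per plain tensor,
    plus one per element of each accepted container, in argument order."""
    total = 0
    for arg in args:
        # namedtuples are tuple subclasses, so this covers them too
        if isinstance(arg, (list, tuple)) and validate_container(arg):
            total += len(arg)
        else:
            total += 1
    return total
```

For example, a forward call taking one tensor, a list of two tensors, and a two-field namedtuple would require five Input sizes in the compile call.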
@chaoz-dev chaoz-dev added the feature request New feature or request label Jan 8, 2022
@chaoz-dev
Contributor Author

@narendasan Let me know if the behaviors listed under Additional context make sense. In particular, I believe we currently ignore other input types going into forward... if we allow them at all?

I can try taking a crack at the implementation here later when I get a chance.

@chaoz-dev
Contributor Author

chaoz-dev commented Jan 8, 2022

Ah this might be a duplicate of #428, although this request might be slightly less ambitious.

@narendasan
Collaborator

@chaoz-dev Yeah, this is reasonable. We have been working on a design doc for this sort of feature here: #629. @inocsin has been working on the first steps, with arbitrary mixes of tuples (since they are fixed size) and tensors as inputs and outputs. Need to check with him on whether he has a public dev branch, but help here is greatly appreciated.

@chaoz-dev
Contributor Author

Sounds good, I'll take a look at the design doc and make some suggestions there for review. I had a quick look at the code, and my naive first pass would be to unpack input containers in the compile function in torch_tensorrt/ts/_compiler.py, before it hits the actual compilation step, so that compilation always sees a flat list of tensors. I believe this should satisfy the basic aspects of inputting an iterable container of tensors.
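To illustrate the idea (this is a hypothetical sketch, not the actual torch_tensorrt code; `FlatteningWrapper` is an invented name), a wrapper could unpack one level of list/namedtuple arguments into a flat argument list before delegating to the underlying module:

```python
class FlatteningWrapper:
    """Hypothetical sketch: wrap a (compiled) module so that list and
    namedtuple arguments are unpacked into a flat argument list before
    the underlying forward is called."""

    def __init__(self, compiled_module):
        self.compiled_module = compiled_module

    def __call__(self, *args):
        flat = []
        for arg in args:
            # namedtuples are tuple subclasses, so this covers them too
            if isinstance(arg, (list, tuple)):
                flat.extend(arg)
            else:
                flat.append(arg)
        return self.compiled_module(*flat)
```

With this, callers could keep passing containers to the wrapped module while the compiled module itself only ever sees individual tensors.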

@narendasan
Collaborator

Seems reasonable to take the step of adding support for one collection of inputs of any type. But we need to do this in compiler.cpp, since we need to support both the C++ and Python APIs, and we need to be able to construct a new module with the correct interface; otherwise users cannot reuse the same input formatting code in their applications.

@chaoz-dev
Contributor Author

Yeah that makes sense. I'll take a look at this shortly.

@chaoz-dev
Contributor Author

Deferring to @inocsin in #629 here.

@github-actions

This issue has not seen activity for 90 days. Remove the stale label or add a comment, or this will be closed in 10 days.

@ncomly-nvidia ncomly-nvidia added the release: v1.2 Tagged to be included in v1.2 label Jul 26, 2022
@narendasan
Collaborator

Initial feature support has been merged.
