torch.cond operator not supported on a simple example #4028
Comments
I ran the code and got a similar error message. Correct me if I'm wrong:
Yes, I think so. After working on this, the best workaround I found is to split the model into two models and then write the "if" logic at the application level. This is not ideal for me because I want to compile the IR afterwards, but it solves the issue.
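For illustration, a minimal sketch of that application-level split (hypothetical module and function names, not the exact code from this thread) could look like:

```python
import torch
import torch.nn as nn

# Hypothetical illustration: each branch is its own module that can be
# exported/compiled separately, and the routing "if" stays in plain Python.
class BranchA(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3072, 3)

    def forward(self, feature):
        return self.linear(feature)

class BranchB(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3072, 3)

    def forward(self, feature):
        return self.linear(feature)

def run(x, branch_a, branch_b, threshold=2):
    feature = x.flatten()
    # The condition is evaluated outside the exported graphs, so neither
    # graph needs torch.cond support.
    if torch.mean(x) > threshold:
        return branch_a(feature)
    return branch_b(feature)
```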
Hello @JibAxelera and @amemov, I believe that under the hood, torch_mlir.fx relies on FX symbolic tracing of an nn.Module. A potential workaround is to replace the data-dependent branch with torch.where, which computes both branches and selects between the results:

```python
import torch
import torch.nn as nn

class CondNetwork(nn.Module):
    def __init__(self):
        super(CondNetwork, self).__init__()
        self.confidence_threshold = 2
        self.linear1 = nn.Linear(3072, 3)
        self.linear2 = nn.Linear(3072, 3)

    def forward(self, x):
        # Scalar boolean condition derived from the input.
        condition = torch.mean(x) > self.confidence_threshold
        feature = x.clone().flatten()
        feature1 = self.linear1(feature)
        feature2 = self.linear2(feature)
        # Both branches are computed; torch.where broadcasts the scalar
        # condition and selects one of the two results.
        return torch.where(condition, feature1, feature2)
```

With this code, I can successfully run:

```python
module = export_and_import(model, torch.ones(1, 3, 32, 32), output_type="torch")
```

Final caveat: if the sub-networks are large, computing both branches might be inefficient. In principle, TorchScript (torch.jit.script) can handle truly dynamic data-based branching, but I'm not sure if Torch-MLIR still supports direct ScriptModule importing. As far as I know, the FX-based path is the recommended or primary route at the moment.
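For reference, a minimal sketch of that TorchScript route (my own illustration, not code from this thread) could look like the following; whether the resulting ScriptModule can still be imported by Torch-MLIR is the open question:

```python
import torch
import torch.nn as nn

class ScriptedCondNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.confidence_threshold = 2
        self.linear1 = nn.Linear(3072, 3)
        self.linear2 = nn.Linear(3072, 3)

    def forward(self, x):
        feature = x.flatten()
        # TorchScript preserves this data-dependent branch instead of
        # tracing a single path, so only one sub-network runs per sample.
        if torch.mean(x) > self.confidence_threshold:
            return self.linear1(feature)
        return self.linear2(feature)

scripted = torch.jit.script(ScriptedCondNetwork())
print(scripted(torch.ones(1, 3, 32, 32)))
```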
Very clear response, thank you. It would be nice if we could route samples dynamically and export. Is this a feature request that makes sense? Or is it out of the scope of torch-mlir?
@JibAxelera I think it should be in the scope of torch-mlir, since semantically your code above is acceptable for PyTorch. I wouldn't mind implementing this feature myself.
I'm trying to implement this (my first attempt at MLIR work); it generates MLIR for torch.cond (although not necessarily technically correct 😄). I tested it in IREE: it compiles and works in the runtime for the example given in this issue. FX graph and generated MLIR graph: https://gist.github.com/thomasverelst/3e4d564a0ad6aebadd227c8c5b8cadac
Definitely open for feedback!
Issue:
I am trying to implement logic in a neural network that dynamically routes a sample based on some condition. I built a dummy example of what the network should look like, and I would like to export this model to MLIR. When I try to do so using torch-mlir, I get an error. I would like to know whether the torch.cond operator is not supported or whether my implementation is just wrong.
Steps to reproduce:
Just run this code:
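(The original snippet is not preserved in this copy of the issue. A hypothetical reconstruction of a torch.cond-based model with the shapes discussed in the comments, using the export_and_import helper from the workaround above and an assumed torch_mlir.fx import path, might look like this.)

```python
import torch
import torch.nn as nn
from torch_mlir import fx  # assumed import path for the FX-based importer

class CondNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.confidence_threshold = 2
        self.linear1 = nn.Linear(3072, 3)
        self.linear2 = nn.Linear(3072, 3)

    def forward(self, x):
        feature = x.clone().flatten()

        def true_fn(feature):
            return self.linear1(feature)

        def false_fn(feature):
            return self.linear2(feature)

        # Data-dependent branch: torch.cond routes the sample to one
        # sub-network based on the runtime condition.
        return torch.cond(torch.mean(x) > self.confidence_threshold,
                          true_fn, false_fn, (feature,))

model = CondNetwork()
module = fx.export_and_import(model, torch.ones(1, 3, 32, 32), output_type="torch")
```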
You should get this error:
Additional information
torch version: 2.7.0.dev20250210+cpu
torchvision version: torchvision-0.22.0.dev20250210+cpu
torch_mlir version: 20250127.357