
C++ API 'nn::Sequential' has inconsistent behavior with Python counterpart #19499

Open
FluorineDog opened this issue Apr 19, 2019 · 3 comments
Labels
module: cpp (Related to C++ API)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@FluorineDog

🐛 Bug

To Reproduce

C++ version

#include <torch/torch.h>

namespace nn = torch::nn;

class FakeNet : public torch::nn::Module {
  public:
    FakeNet() {
        // Register an empty Sequential, intended to act as an identity mapping
        id = register_module("id", nn::Sequential());
    }
    torch::Tensor forward(torch::Tensor x) {
        // Runtime Error: Cannot call forward() on an empty Sequential
        return id->forward(x);
    }
  private:
    nn::Sequential id;
};

Python version

class FakeNet(nn.Module):
    def __init___(self):
        self.id = nn.Sequential() 
    def forward(self, x):
        x = id(x) # OK
        return x

Unlike other modules (e.g., nn::Conv2d, nn::BatchNorm), nn::Sequential doesn't overload operator(). In addition, its forward() has an AT_CHECK requiring is_empty() == false, which makes it impossible to use an empty nn::Sequential as an identity function.

Suggestion

The C++ version should implement nn::Sequential::operator()(...), and it should silently return its input x unchanged when is_empty() == true.
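The requested semantics can be illustrated without libtorch. The sketch below is a minimal, hypothetical stand-in: MiniSequential is not the libtorch class, and it models each submodule as a double-to-double function. Its forward() mimics the reported check by throwing on an empty container (reusing the error text quoted above), while operator() follows the suggestion: like Python's nn.Sequential, it feeds the input through each submodule in order, so an empty container returns x unchanged.

```cpp
#include <functional>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for a sequential container; NOT the libtorch API.
// Each "submodule" is modeled as a function from double to double.
class MiniSequential {
 public:
  using Module = std::function<double(double)>;

  void push_back(Module m) { modules_.push_back(std::move(m)); }
  bool is_empty() const noexcept { return modules_.empty(); }

  // Mirrors the behavior reported in the issue: refuses to run when empty.
  double forward(double x) const {
    if (is_empty()) {
      throw std::runtime_error("Cannot call forward() on an empty Sequential");
    }
    return apply(x);
  }

  // Suggested behavior: an empty container acts as the identity function.
  double operator()(double x) const { return apply(x); }

 private:
  // Feed the input through each submodule in order; with no submodules,
  // the loop body never runs and x is returned unchanged.
  double apply(double x) const {
    for (const auto& m : modules_) x = m(x);
    return x;
  }

  std::vector<Module> modules_;
};
```

With real libtorch, a call-site workaround along the same lines would be to test id->is_empty() in FakeNet::forward and return x directly when it is true, calling id->forward(x) otherwise.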

Environment

  • PyTorch Version (e.g., 1.0.1):
@FluorineDog changed the title from "C++ API 'nn::Sequential' has inconsistant behavior with python version" to "C++ API 'nn::Sequential' has inconsistant behavior with python conterpart" Apr 19, 2019
@FluorineDog changed the title from "C++ API 'nn::Sequential' has inconsistant behavior with python conterpart" to "C++ API 'nn::Sequential' has inconsistent behavior with python conterpart" Apr 19, 2019
@driazati added the module: cpp label Apr 19, 2019
@ailzhang
Contributor

Yeah, our C++ API is not consistent with the Python API on this. Note that your Python example was actually calling Python's built-in id() instead of self.id().
A proper Python example is:

import torch
import torch.nn as nn


class FakeNet(nn.Module):
    def __init__(self):
        super(FakeNet, self).__init__()
        # self.conv = nn.Conv2d(1, 1, 1)
        # An empty Sequential passes its input through unchanged
        self.conv = nn.Sequential()

    def forward(self, x):
        y = self.conv(x)
        return y


input = torch.rand(1, 1, 2, 2)
m = FakeNet()
print(m(input))  # prints the same values as input

I think we should match this behavior.

@ailzhang
Contributor

We're happy to accept a PR for it; would you mind sending one? @FluorineDog

@ailzhang added the triaged label Apr 19, 2019
@FluorineDog
Author

Gladly, I will have a look tomorrow, but I have to go to bed now :)
BTW, I'm working on a C++ ResNet50 demo; it currently runs but still needs refinement. I will send a PR when I'm done.
Cheers.
