[C++ API] Implement builder style construction #7597
Conversation
Reset parameters disappeared somewhere. I'm curious why we're departing from PyTorch in this department? I wonder why it's even there in the first place. @apaszke maybe you can shed some light on this?
variable_list CUDNN_forward(variable_list);
variable_list autograd_forward(variable_list);

void flatten_parameters_for_cudnn();
enum RNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };
// These must line up with the CUDNN mode codes:
// https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t
enum class CuDNNMode { RNN_RELU, RNN_TANH, LSTM, GRU };
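Since the comment requires the two enums to line up, converting between them can be a plain cast guarded by a compile-time check. A hypothetical sketch (the helper name and the static_assert are mine, not from the diff):

```cpp
// Hypothetical helper: valid only because RNNMode and CuDNNMode declare
// their modes in the same order as cuDNN's mode codes.
static_assert(static_cast<int>(CuDNNMode::GRU) == static_cast<int>(GRU),
              "CuDNNMode enumerators must line up with RNNMode");

inline CuDNNMode to_cudnn_mode(RNNMode mode) {
  return static_cast<CuDNNMode>(static_cast<int>(mode));
}
```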
std::vector<Variable> hhb_;

size_t number_of_gates_;
bool has_cell_state_;
}} // namespace torch::nn

#define TORCH_PARAMETER(T, name) \
Okay, besides my nits above.
    return this->name##_; \
  } \
 \
 protected: \
I'm not sure why certain things changed as they are. The weight names are now inconsistent with those in Python, and I really don't think we should mess with visibility in our macros. The fact that it automagically changes just because you declared a parameter will be a constant source of unclear errors (the user doesn't have the line that makes things protected in their code!)
std::shared_ptr<Derived> build() {
  auto module = std::make_shared<Derived>(static_cast<Derived&&>(*this));
  module->reset();
  return std::move(module);
}
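For context, construction through this `build()` would look roughly like the following; a sketch assuming the `Dropout` module with the macro-generated `rate` setter that appears further down in this diff:

```cpp
// Sketch: setters return *this so calls chain; build() moves the fully
// configured object into a shared_ptr and runs reset() on it.
void example(variable_list inputs) {
  auto dropout = Dropout().rate(0.3).build();  // std::shared_ptr<Dropout>
  auto outputs = dropout->forward(std::move(inputs));
}
```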
#define TORCH_PARAMETER(T, name) \
 public: \
  auto name(T new_##name)->decltype(*this) { \
    return this->name##_; \
  } \
 \
 protected: \
@@ -19,32 +19,32 @@ class ContainerListImpl : public CloneableModule<Derived> {
   }

   std::shared_ptr<Module> add(std::shared_ptr<Module> m) {
-    return append(m).children_.back();
+    return append(m).modules_.back();
test/cpp/api/container.cpp
REQUIRE(model->param("test.l1.bias").size(0) == 3); | ||
REQUIRE(model->param("test.l1.weight").size(0) == 3); | ||
REQUIRE(model->param("test.l1.weight").size(1) == 10); | ||
REQUIRE(model->param("test.l1.weights").size(0) == 3); |
Sorry for the weight naming changes; I didn't know that's what they were called in Python. "weight" sounded more like a single float to me than a tensor of "weights", but I'll change it back. As for the visibility of the value inside …
(Force-pushed from 0028a70 to db43a50.)
I actually think there's an argument to be made for making things as public as possible anyway, if we're to encourage hackability. Not changing the visibility matters more: the silent change is a bigger source of bugs than people directly using the private variables would be. The point of private variables, getters, and setters is that you can do arbitrary logic inside them, so you shouldn't touch the private variables directly, but presumably with a parameter macro like this we never worry about that issue at all.
Yes, make 'em public! If the user really cares they can stop using the macro and write the getters/properties themselves.
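For concreteness, an all-public variant of the macro along these lines might look as follows (a sketch of the suggestion, not the code that eventually landed; assumes `<utility>` for `std::move`):

```cpp
// Sketch: no visibility switch, so declaring a parameter can never
// silently make later member declarations protected. The macro ends with
// the member declaration itself (no semicolon), so a default can be
// attached at the use site: TORCH_PARAMETER(double, rate) = 0.5;
#define TORCH_PARAMETER(T, name)              \
 public:                                      \
  auto name(T new_##name)->decltype(*this) {  \
    this->name##_ = std::move(new_##name);    \
    return *this;                             \
  }                                           \
  const T& name() const {                     \
    return this->name##_;                     \
  }                                           \
  T name##_
```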
-  std::unique_ptr<Module> clone() const override {
-    auto ptr = std::unique_ptr<Module>(
-        new Derived(*static_cast<const Derived*>(this)));
+  virtual void reset() = 0;
    return std::move(module);
  }

  std::shared_ptr<Module> clone() const override {
    bool transposed = false,
    bool with_bias = true,
    int groups = 1);

struct ExpandingSize {
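For readers following along: an "expanding size" type conventionally lets one integer stand in for every spatial dimension of a conv argument. A self-contained sketch of that idea (hypothetical; the PR's actual `ExpandingSize` may differ):

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <initializer_list>

// Sketch: one value expanded to all D spatial dimensions, or an
// explicit per-dimension list.
template <std::size_t D>
struct ExpandingSizeSketch {
  /* implicit */ ExpandingSizeSketch(std::int64_t single) {
    sizes_.fill(single);
  }
  ExpandingSizeSketch(std::initializer_list<std::int64_t> list) : sizes_{} {
    assert(list.size() == D);
    std::size_t i = 0;
    for (std::int64_t s : list) sizes_[i++] = s;
  }
  std::array<std::int64_t, D> sizes_;
};

// Usage: ExpandingSizeSketch<2> square = 3;    // means {3, 3}
//        ExpandingSizeSketch<2> rect = {3, 5};
```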
variable_list forward(variable_list input) override;

TORCH_PARAMETER(double, rate) = 0.5;
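The trailing `= 0.5` works because the macro's last token is the member declaration itself. A hand-written equivalent of what the expansion amounts to (schematic sketch, not generated output):

```cpp
// Roughly what TORCH_PARAMETER(double, rate) = 0.5; amounts to:
class DropoutSchematic {
 public:
  DropoutSchematic& rate(double new_rate) {  // chainable setter
    rate_ = new_rate;
    return *this;
  }
  const double& rate() const { return rate_; }  // getter
  double rate_ = 0.5;  // the "= 0.5" from the use site lands here
};
```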
This is not the fault of this patch, but I was thinking about safety: shouldn't the parameters array store … so that the parameters array is updated properly?
(Force-pushed from db43a50 to 2226f3f.)
@ezyang Oh man, the fact that …
* upstream/master:
  Makes AccumulateGrad high priority in backwards passes (pytorch#7604)
  [C++ API] Implement builder style construction (pytorch#7597)
  C10D: Added TCPStore to support C10D store interface (pytorch#7560)
  [auto] Update onnx to ba86ec2 - Protobuf typing (onnx/onnx#982) onnx/onnx@ba86ec2
  Add LBFGS optimization algorithm to C++ API (pytorch#7596)
* Implemented fused builder based construction mechanism
* "weights" -> "weight"
* Use int64_t instead of size_t everywhere in RNN
* Extracted Conv::ExpandingSize into its own thing
* Rename TORCH_PARAMETER to TORCH_ATTR
* Added documentation
* Fix weight names in batchnorm module
This PR implements our discussed "builder" style construction mechanism, where the class is fused with the builder itself. It is similar to the `KWARGS` mechanism in autogradpp, with some differences. Largely, this PR (a sketch of the combined mechanism follows the list):

- Adds a `TORCH_PARAMETER` macro (name up for discussion, maybe `TORCH_PROPERTY`?) used to give modules parameters/properties. It is written in a way so that it requires only 2 arguments instead of 5, like `AUTOGRAD_KWARG` did
- Adds a `reset()` function; `reset()` is called inside `build()`, which finalizes construction, and `clone()`
- For `Conv`, the call to `at::conv<dimension>d(...)` is now left to a virtual function, implemented in `Conv1d`, `Conv2d` and `Conv3d`
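To make this concrete, here is a sketch of a module written against the mechanism described above (the `MyModule` name and its parameters are hypothetical; `TORCH_PARAMETER`, `CloneableModule`, `build()`, and `reset()` come from the diff hunks in this PR):

```cpp
#include <cstdint>
#include <memory>

// Hypothetical module using the fused builder (a sketch, not PR code).
class MyModule : public CloneableModule<MyModule> {
 public:
  TORCH_PARAMETER(int64_t, hidden_size) = 64;
  TORCH_PARAMETER(double, rate) = 0.5;

 public:  // the macro may leave the class in a non-public section
  void reset() override {
    // Create parameters/buffers here; build() calls this once all
    // properties have been set, and clone() calls it again on the copy.
  }
};

void example() {
  // The module is its own builder: chain setters, then finalize.
  std::shared_ptr<MyModule> m = MyModule().hidden_size(128).rate(0.1).build();
}
```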
Recommended review order:

1. `include/torch/nn/module.h`
2. `include/torch/nn/modules/*.h` and `src/nn/modules/*.cpp`
CC @ezyang @ebetica @apaszke @jgehring