Reduce explicit types required when instantiating the SGD optimizer. #28
Previously, in order to instantiate the optimizer, you had to call `type(of: model)` and pass that into the optimizer's initializer in order to get type inference to pick the right type for `Model`. This could be a little confusing for new users.

This commit proposes an alternate way to write this:

```swift
let optimizer = SGD(learningRate: 0.01, for: model)
```

The above formulation is clear and readable. It avoids any unnecessary typing of types, and should be generalizable to the other optimizers. (This is left as an exercise for the reader. :-D)

We rely on the Swift optimizer to do the right thing such that only a reference to the model is passed to the `SGD.init` call, so that we don't pay the cost of a full model copy (which could eventually be very expensive). This is deemed to be a reasonably safe assumption, especially given the ABI standardization in Swift 5.
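For context, the two call styles can be sketched side by side with toy stand-ins. All names below (`modelType:`, the dummy `Model`) are illustrative assumptions, not the exact library signatures:

```swift
// Toy stand-ins for illustration only.
struct Model {}

struct SGD<M> {
    var learningRate: Double
    // Old style: a metatype argument drives inference of `M`.
    init(learningRate: Double, modelType: M.Type) {
        self.learningRate = learningRate
    }
    // Proposed style: passing the model itself drives the same inference.
    init(learningRate: Double, for model: M) {
        self.learningRate = learningRate
    }
}

let model = Model()
let verbose = SGD(learningRate: 0.01, modelType: type(of: model))
let concise = SGD(learningRate: 0.01, for: model)
```

In both cases `M` is inferred as `Model`; the `for:` spelling just removes the explicit metatype from the call site.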
Awesome, so much nicer and less foreign, thanks! Random question for @rxwei and @dan-zheng: would marking the argument explicitly as `__shared` avoid the possibility of a copy?
Yep, that is exactly what I was going to suggest!
rxwei left a comment:
Thanks! This looks like a pretty idiomatic use. My other suggestion is to apply this to all other optimizers at once -- they are still using metatypes right now. It would be quite a hassle to update broken models (and toolchains) in different patches.
Should go in after #28 gets merged.
Any reason not to put the model argument first? I'm OK either way, but having it first would have been my initial intuition.
I think that intuition is definitely plausible, but here are my reasons not to make
```diff
 }
-let optimizer = SGD<Classifier, Float>(learningRate: 0.02)
 var classifier = Classifier(hiddenSize: 4)
+let optimizer = SGD(learningRate: 0.02, for: classifier)
```
There's actually a much more serious issue: Have you actually tested it to make sure this really works?
I don't think this would train at all. The floating-point literal `0.02` defaults to `Double`, so the `Scalar` associated type in `SGD` would be `Double`. Then it'll try to get key paths to `Tensor<Double>`, which do not exist at all in a model where all parameters are `Tensor<Float>`. So my hypothesis is that this will compile but definitely fail to train. This also exposes how confusing it would be for optimizers not to take a currency scalar type explicitly, or for the model not to provide one.
Potential solutions:

- Make the `Layer` protocol actually require a `Scalar` associated type as the currency type, but that would be really unusable because it would trigger a lot of protocol conformance errors when the user didn't specify a `Scalar` type alias.
- Still require a scalar type explicitly in the optimizer. The way to do that is to require a `scalarType` metatype argument as before: when it's specified as an argument, every generic parameter gets inferred; when it's not specified as an argument, the compiler would expect it to be provided as a generic type parameter. The crucial thing here is to make `Scalar` not be inferred from the `learningRate:` argument -- I thought about this earlier, and this is why I added an explicit `scalarType:` argument in the first place.
My conclusion is that our optimizers (without undergoing a massive change in all building blocks) really need to be explicit about what the parameter type is that they are optimizing. Given all the trade-offs, making the scalar type explicit on the optimizer side seems to be the most flexible approach that does not come at the cost of making mixed precision models hard to express.
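The inference trap, and how an explicit metatype avoids it, can be sketched as follows. Every type here is an illustrative stand-in, not the real library code:

```swift
// Illustrative stand-ins only.
struct Tensor<Scalar> {}

struct SGD<Model, Scalar: BinaryFloatingPoint> {
    var learningRate: Scalar
    // Trap: `Scalar` is inferred from the literal, so a call like
    // `SGD(learningRate: 0.02, for: model)` picks Scalar == Double
    // even when the model's parameters are Tensor<Float>.
    init(learningRate: Scalar, for model: Model) {
        self.learningRate = learningRate
    }
    // Sketch of solution 2: an explicit metatype pins `Scalar`,
    // and the literal is converted to that type instead of
    // defaulting to Double.
    init(learningRate: Scalar, for model: Model, scalarType: Scalar.Type) {
        self.learningRate = learningRate
    }
}

struct MyModel { var w = Tensor<Float>() }
let model = MyModel()

let inferred = SGD(learningRate: 0.02, for: model)                          // Scalar == Double
let pinned = SGD(learningRate: 0.02, for: model, scalarType: Float.self)    // Scalar == Float
```

The second initializer works because Swift float literals conform to `ExpressibleByFloatLiteral` for any `BinaryFloatingPoint` type, so once `Scalar` is fixed by the metatype, `0.02` becomes a `Float` rather than defaulting to `Double`.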
I agree with solution 2: with the current design, scalar type should be made explicit for optimizers.
Yup, you're right that what happens is that it doesn't train. (I didn't notice because it compiled, and there was a previously failing test.)
I've filed TF-302 to follow up with the larger usability issue, and fixed up the API and tests so they're all passing now.
- Adding back in the `scalarType`
- Fixing the tests
- Updating the other optimizers
Thanks for the suggestion @jekbradbury. Definitely much better.
Whoops, just saw @rxwei's response to @jekbradbury's comment about ordering (after making the change to the code). Sorry! I'd like to understand this a bit further. Maybe we can all discuss this later today?
The argument ordering is not a big deal as it's mainly about style -- sure, we can chat about it today. But I think it's critical to get the parameter type issue resolved before discussing other minor issues.
This is breaking a lot of existing code that uses
I'll make a PR that does this!