clearState #526
Conversation
Module.lua
Outdated
Rather than creating a new tensor, isn't it better to do self.output:set() which resets the tensor? Then all references to it are kept intact.
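To make the difference concrete, here is a toy sketch (not code from this PR) of why set() preserves outside references while assigning a fresh tensor does not:

require 'torch'

local m   = { output = torch.rand(2, 2) }
local ref = m.output           -- something outside the module keeps a handle on the output

-- m.output = torch.Tensor()   -- assigning a fresh tensor: 'ref' would silently stop tracking m.output

m.output:set()                 -- resetting in place: m.output and 'ref' are still the same object
print(ref:nElement())          -- prints 0; the shared tensor was emptied but the reference stays intact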
👍

@szagoruyko would it be possible to use

(sorry, it was already mentioned by @dominikgrewe).

@fmassa I fixed Normalize.
SpatialMaxPooling.lua
Outdated
this will always malloc a new self.indices on updateOutput. You want to do: self.indices = self.indices or input.new()
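A tiny self-contained stand-in for the lazy-allocation pattern being suggested (illustrative only, not the real pooling code):

require 'torch'

local m = {}
function m:updateOutput(input)
   -- allocate buffers only when they are missing (after construction or after clearState),
   -- instead of calling input.new() unconditionally on every forward pass
   self.indices = self.indices or input.new()
   self.output  = self.output  or input.new()
   return self.output:resizeAs(input):copy(input)   -- stand-in for the real pooling kernel
end

m:updateOutput(torch.rand(3, 3))   -- first call allocates the buffers
m:updateOutput(torch.rand(3, 3))   -- later calls reuse the very same tensors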
oops.
That's because those modules take references to tensors that they don't own, right? I think this should generally be avoided and instead of things like

@dominikgrewe I'd rather merge this and fix these cases

I'm just a bystander, since not using this, but is there any reason you couldn't have a generic method in

@hughperkins this would erase all module weights/biases. Some generic solutions were proposed in the past, like the sanitize from @soumith, but it turned out that those generic functions never worked in all cases, so a module-specific solution should be implemented in order for it to always work.

Ah, I see. Good point :-)

Hmmm... is the name misleading in fact? Should it be perhaps

@hughperkins well, it not only clears the internal buffers, but also the state variables
(By the way, if anyone is wondering why one would ever want to clear the weights: well, in char-rnn-er I'm saving the current network, and in order to avoid going down the rabbit hole of figuring out how to convert between a saved cuda-network, cpu-network and cl-network, I save in two parts:
The first part, saving the network architecture, could plausibly be accomplished by saving the class structure of the network, including the module hyper-parameters (the size of each Linear, for example), but with neither weights nor buffers (nor hidden state, in the case of an LSTM module).

@hughperkins this issue of converting from cuda/float (with cudnn) is going to be addressed by soumith/cudnn.torch#76. This PR addresses the issue of not having to hack around saving parameters separately from the model definition. You simply call clearState(). Both PRs combined make model definition/saving/loading much simpler, without having to write specific code for converting between frameworks (nn/cudnn), as done in here for example.

Yes. Hence why I don't attempt to save the network, just the hyperparameters used to create it :-) But you are right, shared weights mean that just storing the network object hierarchy, without weights, is not a generic solution to the problem. (Hmmm, unless we cached the function calls used to define the shared weights within each object, perhaps, so we can rerun those on demand.)

I do however feel, intuitively, that separating out the definition of the network from the weight parameters is somehow clean, and relatively unpainful. Especially since we already have the technology to suck all the weights from a network, and drain them to disk, using

@hughperkins for that, each module needs to know how to reconstruct its own constructor from itself. As we only accept ordered arguments in nn (and some input arguments are not stored in
clearState is missing here; it should clean the following fields.
A minor note: for consistency I'd do something like
self.gradInput[1] = self.gradInput[1] or v1.new()
self.gradInput[2] = self.gradInput[2] or v1.new()
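For illustration, a clearState for a module with a table-valued gradInput could look roughly like the sketch below; DotProduct is used here only as an example, and this is not the PR's exact code:

require 'nn'

function nn.DotProduct:clearState()
   if torch.isTensor(self.output) then self.output:set() end
   -- gradInput is a table of two tensors; reset each in place so outside
   -- references to the table and to its tensors stay valid
   if type(self.gradInput) == 'table' then
      for i = 1, #self.gradInput do
         if torch.isTensor(self.gradInput[i]) then self.gradInput[i]:set() end
      end
   end
   return self
end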
I think we should have some clear guidelines though as to how
Some missing modules / tensors: Cosine, L1Penalty (loss), Linear (addBuffer), LookupTable (_count, _input, copiedInput), Mean (_gradInput), MixtureTable, PairwiseDistance (outExpand, grad, ones), Reshape (_input, _gradOutput), RRelu (noise), SoftMin (mininput), SoftSign, SparseLinear (lastInput), SpatialDivisiveNormalization, SpatialSubtractiveNormalization (don't forget its internal modules). Also, as Dominik says, it would be good to preserve outside references to temporary tensors where possible. So the default behaviour of Module:clearState should be self[f]:set() for tensors (and the equivalent for tables). Where this is not appropriate (e.g. nn.Identity()), these modules can override the default behaviour to call new() instead of set().
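A rough sketch of the default/override split described above (illustrative only; the PR's final code may differ):

require 'nn'

-- default: reset tensors in place so outside references remain valid
function nn.Module:clearState()
   for _, f in ipairs{'output', 'gradInput'} do
      if torch.isTensor(self[f]) then
         self[f]:set()
      end
   end
   return self
end

-- modules whose output/gradInput reference tensors they do not own
-- (e.g. Identity) override this to swap in fresh tensors instead
function nn.Identity:clearState()
   if torch.isTensor(self.output)    then self.output    = self.output.new()    end
   if torch.isTensor(self.gradInput) then self.gradInput = self.gradInput.new() end
   return self
end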
@szagoruyko @dominikgrewe's last comment with the guidelines is pretty reasonable. Can we get this PR to a conclusion soon?

@dominikgrewe @soumith it requires a big change in nn, I am not going to do that. We can wait for someone to do it in a separate PR and then return to

alright, addressed @dominikgrewe's and @davidsaxton's comments. There is now only
SpatialMaxPooling.lua
Outdated
this line should go before the THNN.*updateOutput call
there still seem to be changes needed from comments from me / francisco. is it the github diff tool, or did you not update? for example in CosineDistance and in DotProduct.

@soumith thanks, missed that. updated and added clearState for Volumetric and Temporal poolings.
BatchNormalization.lua
Outdated
Might be worth factoring that out into a utils function, e.g. nn.utils.clearTensors(self, {...})
I had this at some point: https://gist.github.com/szagoruyko/bd39620cdf35964647b2 what do you think?
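For reference, a minimal sketch of a helper along those lines; the name clearTensors comes from the suggestion above and the BatchNormalization field list is only illustrative, so the final nn.utils API may well differ:

require 'nn'

-- reset tensor-valued fields in place; replace table-valued ones
local function clearTensors(self, fields)
   for _, f in ipairs(fields) do
      if torch.isTensor(self[f]) then
         self[f]:set()
      elseif type(self[f]) == 'table' then
         self[f] = {}
      end
   end
   return self
end

-- a module-specific clearState then reduces to a single call, e.g. (field names illustrative):
function nn.BatchNormalization:clearState()
   return clearTensors(self, {'buffer', 'buffer2', 'centered', 'normalized', 'output', 'gradInput'})
end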
added

i just did a pass on it, and it looks good to me. once dominik signs off on it, it's ready for squash and merge. don't need to squash into 1 commit, but rebase so that the bugfix commits and the "set everywhere" commits are squashed up.
PReLU.lua
Outdated
function PReLU:updateOutput(input)
   self.gradWeightBuf = self.gradWeightBuf or input.new()
   self.gradWeightBuf2 = self.gradWeightBuf2 or input.new()
Why is this lazy initialization required?
Two reasons:
- these buffers are needed only because cutorch's reduce is not generic enough and cannot reduce with += into a buffer, so they should be removed sometime later
- I want to keep all buffers like this out of the constructor, for backend switching like cudnn.convert
So it is unrelated to the introduction of clearState?
yes, this is refactoring
Would you mind moving this to a separate PR in that case? This one is already big enough as it is.
agree, will factor out now
if torch.isTensor(self[f]) then
   self[f]:set()
elseif type(self[f]) == 'table' then
   self[f] = {}
for tables, this should probably recurse into the table - that way external references to that table or the tensors in it remain valid after the clearState() call.
yes, that's true, but I am a little bit afraid of doing that, who knows what's referenced there.
Not sure I understand - as long as nn.utils.clear is only applied to fields of the module such as output and gradInput, or internal buffers, what could be referenced there that would be a problem?
For example, if you call clearState on one part of a network and it's connected to another part through SelectTable- or NarrowTable-like modules, then a set()-based clearState will destroy the tensors it points to, and everything else that was referenced there. Normally we should add a new()-based clearState to all modules that reference tensors in tables like this before enabling set() in nn.Module, I think.
What do you think about calling collectgarbage() at the end of nn.utils.clear() to free tensors that are no longer referenced (e.g. stuff built by zeroTableCopy() in SelectTable) after the = {} replacement?
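For illustration only, a recursive variant along the lines discussed in this thread might look like the sketch below (not the PR's code); the trailing collectgarbage() follows the suggestion above:

require 'torch'

-- reset tensors in place and recurse into tables, so outside references
-- to the table and to the tensors inside it stay valid
local function clearRecursive(field)
   if torch.isTensor(field) then
      field:set()
   elseif type(field) == 'table' then
      for _, v in pairs(field) do
         clearRecursive(v)
      end
   end
   return field
end

-- after a module's fields have been cleared, reclaim storages that lost their last reference
collectgarbage()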
factored out lazy inits and added Min and Max.

For a v1 of clearState() this PR looks ready to me. +1 for merging this soon.

all hail the sergey! thanks for the huge PR
I'm having some troubles with

th> setprintlevel(3); print {m28.modules[1].modules[1].modules[5].innode.data}
{
1 :
{
annotations :
{
_debugLabel : "[[C]]:-1"
}
input :
{
1 : CudaTensor - size: 16x64x56x56
}
mapindex : {...}
forwardNodeId : 8
gradOutput :
{
1 : CudaTensor - size: 16x64x56x56
}
}
}

Whereas it worked perfectly on the same network but implemented with
This introduces sanitization of nn modules, so that calling clearState() before saving will produce a model without any internal buffers.
nn.Module has a default clearState that clears output and gradInput. Any module that uses buffers should have its own clearState; I went over the nn modules and fixed the ones I knew, but it is possible that I missed something.
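Usage then looks roughly like this (the network and file name here are illustrative):

require 'nn'

local net = nn.Sequential():add(nn.Linear(10, 10)):add(nn.ReLU())

net:forward(torch.rand(10))   -- running the net fills outputs and internal buffers
net:clearState()              -- drop that state before serializing
torch.save('net.t7', net)     -- the saved file no longer carries the buffers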
A test was added to Jacobian.testIO that calls clearState before serializing the module. There might be a better solution, because it does not test the old behaviour. I will add clearState to cudnn, inn and nn.DataParallelTable, with cunn tests, later.
Comments? The original name was sanitize; we should discuss if the new name is better. This was implemented when I was at Facebook; it's been in use for a couple of months and is pretty stable.