
Conversation

@szagoruyko
Member

This introduces sanitization of nn modules so that

torch.save('model.t7', model:clearState())

will save a model without any internal buffers.
nn.Module has its own default clearState that clears output and gradInput.
Any module that uses buffers should have its own clearState. I went over the nn modules and fixed the ones I knew about; it is possible that I missed something.
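
The pattern looks roughly like this. A sketch for illustration only, not the exact diff: the default clearState resets output and gradInput, and a buffer-holding module such as SpatialMaxPooling extends it to drop its own buffer.

require 'nn'

-- default behaviour (sketch): reset output and gradInput to fresh, empty tensors
function nn.Module:clearState()
   if torch.isTensor(self.output) then self.output = self.output.new() end
   if torch.isTensor(self.gradInput) then self.gradInput = self.gradInput.new() end
   return self
end

-- a module with an internal buffer (here `indices`) clears it as well
function nn.SpatialMaxPooling:clearState()
   self.indices = nil
   return nn.Module.clearState(self)
end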

A test was added to Jacobian.testIO that calls clearState before serializing the module. There might be a better solution, because it does not test the old behaviour.

I will add clearState to cudnn, inn and nn.DataParallelTable with cunn tests later.

Comments? The original name was sanitize; we should discuss whether the new name is better.

This was implemented when I was at Facebook; it's been in use for a couple of months and is pretty stable.

Module.lua Outdated
Member

Rather than creating a new tensor, isn't it better to do self.output:set() which resets the tensor? Then all references to it are kept intact.
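
For illustration, a standalone sketch (not code from the diff) of why set() is preferable: emptying the tensor in place is visible through every reference, whereas replacing the tensor leaves old references pointing at the old storage.

require 'torch'

local output = torch.Tensor(4):fill(1)
local ref = output               -- an outside reference to the module's output

-- replacing the tensor: `ref` still points at the old, full storage
output = torch.Tensor()
print(ref:nElement())            -- 4

-- resetting in place: the change is seen through every reference
ref:set()                        -- same object as the original output tensor
print(ref:nElement())            -- 0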

@nicholas-leonard
Member

👍

@andreaskoepf
Contributor

@szagoruyko would it be possible to use t:set() instead of t = nil, at least for some of the tensors in nn.Module? See my last comment here. It would also allow serializing tensor refs.

@andreaskoepf
Contributor

(sorry, it was already mentioned by @dominikgrewe).

@szagoruyko
Member Author

@fmassa I fixed Normalize.
@dominikgrewe @andreaskoepf calling set() on output and gradInput is tricky because it might cause referenced tensors in other places to be reset, which is unexpected. For example, ConcatTable, Narrow or Identity will not work as expected after clearState.
I have changed it to use set() in other places where possible.

Member

this will always malloc a new self.indices on updateOutput. You want to do: self.indices = self.indices or input.new()

Member Author

oops.

@dominikgrewe
Member

> @dominikgrewe @andreaskoepf calling set() on output and gradInput is tricky because it might cause referenced tensors in other places to be reset, which is unexpected. For example, ConcatTable, Narrow or Identity will not work as expected after clearState.

That's because those modules take references to tensors that they don't own, right? I think this should generally be avoided; instead of things like self.output = input we should do self.output:set(input). This also has the advantage that one can keep a reference to the module's output from outside (as is the case for most other modules).
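
A sketch of the difference, using Identity-style forwarding (assuming tensor inputs; real modules also need to handle table inputs):

require 'torch'

-- aliasing: self.output becomes the caller's tensor, so outside references to
-- self.output go stale whenever a new input arrives
local function updateOutputByAlias(self, input)
   self.output = input
   return self.output
end

-- viewing: self.output stays the same object and is pointed at input's storage,
-- so outside references to self.output remain valid
local function updateOutputByView(self, input)
   self.output:set(input)
   return self.output
end

local m = {output = torch.Tensor()}
updateOutputByView(m, torch.randn(3))   -- m.output now views the input's data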

@szagoruyko
Member Author

@dominikgrewe I'd rather merge this and fix those self.output = input cases in a separate PR.

@hughperkins
Contributor

I'm just a bystander, since I'm not using this, but is there any reason you couldn't have a generic method in Module that does something like:

function Module:clearState()
   for k, v in pairs(self) do
      if torch.type(v):find('Tensor') then
         self[k] = nil
      end
   end
end

@fmassa
Contributor

fmassa commented Jan 6, 2016

@hughperkins this would erase all module weights/biases. Some generic solutions were proposed in the past, like sanitize from @soumith, but it turned out that those generic functions never worked in all cases, so a module-specific solution should be implemented for it to always work.

@hughperkins
Contributor

Ah, I see. Good point :-)

@hughperkins
Contributor

Hmmm... is the name misleading, in fact? Should it perhaps be clearBuffers rather than clearState? clearState, to my mind, would imply: clear everything, including weights.

@fmassa
Contributor

fmassa commented Jan 6, 2016

@hughperkins well, it not only clears the internal buffers, but also the state variables output and gradInput. But I agree that the name should be discussed (it's even mentioned in the first post).

@hughperkins
Contributor

(By the way, if anyone is wondering why one would ever want to clear the weights: well, in char-rnn-er I'm saving the current network, and to avoid going down the rabbit hole of figuring out how to convert between a saved cuda-network, cpu-network and cl-network, I save it in two parts:

  • the network architecture, i.e. how many hidden layers and the size of each layer. Essentially this could be saved as the network, without buffers and also without weights, though for now I'm simply saving the network sizes (define netParams, store them)
  • the weights themselves, which I simply clone to Floats, then write to disk (get params, store params); see the sketch after this comment

The first part, saving the network architecture, could plausibly be accomplished by saving the class structure of the network, including the module hyper-parameters (the size of each Linear, for example), but with neither weights nor buffers (nor hidden state, in the case of an LSTM module).
)
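
A hedged sketch of that two-part save (the file names, layer sizes and netParams fields are illustrative, not taken from char-rnn-er):

require 'nn'

-- part 1: the hyper-parameters needed to rebuild the network
local netParams = {inputSize = 32, hiddenSize = 64, outputSize = 16}
torch.save('netparams.t7', netParams)

-- part 2: the weights, flattened and cloned to Floats
local net = nn.Sequential()
   :add(nn.Linear(netParams.inputSize, netParams.hiddenSize))
   :add(nn.Linear(netParams.hiddenSize, netParams.outputSize))
local params = net:getParameters()
torch.save('weights.t7', params:float())

-- loading: rebuild the net from netParams, then copy the weights back in
local p = torch.load('netparams.t7')
local net2 = nn.Sequential()
   :add(nn.Linear(p.inputSize, p.hiddenSize))
   :add(nn.Linear(p.hiddenSize, p.outputSize))
net2:getParameters():copy(torch.load('weights.t7'))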

@fmassa
Contributor

fmassa commented Jan 6, 2016

@hughperkins this issue of converting from cuda/float (with cudnn) is going to be addressed by soumith/cudnn.torch#76.
The problem with your approach (if you save only the definition, without the buffers, weights, etc.) is that you will sometimes need to recreate the sharing when you load the model, which can be painful.

This PR removes the need to hack around saving the parameters separately from the model definition. You simply call clearState before saving (possibly converting to CPU to make it loadable everywhere); this keeps all the necessary information without the buffers.

Both PRs combined make model definition/saving/loading much simpler, without having to write specific code for converting between frameworks (nn/cudnn), as is done here for example.
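
As a usage sketch of that workflow (using a stand-in nn.Sequential as the model and the clearState from this PR; the cudnn-to-nn conversion from soumith/cudnn.torch#76 is not shown):

require 'nn'

local model = nn.Sequential():add(nn.Linear(10, 10))   -- stand-in for a real network
model:forward(torch.randn(2, 10))                      -- leaves output/buffers behind

model:clearState()                                     -- drop buffers, output, gradInput
torch.save('model.t7', model:float())                  -- CPU copy, loadable everywhere

local reloaded = torch.load('model.t7')                -- later, on any machine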

@hughperkins
Contributor

> The problem with your approach (if you save only the definition, without the buffers, weights, etc.) is that you will sometimes need to recreate the sharing when you load the model, which can be painful.

Yes. Hence why I don't attempt to save the network, just the hyperparameters used to create it :-) But you are right: shared weights mean that just storing the network object hierarchy, without weights, is not a generic solution to the problem. (Hmmm, unless we cached the function calls used to define the shared weights, within each object perhaps, so we could rerun those on demand.)

@hughperkins
Contributor

I do however feel, intuitively, that separating the definition of the network from the weight parameters is somehow clean, and relatively painless. Especially since we already have the machinery to pull all the weights out of a network and dump them to disk, using getParameters(). I.e., it'd be nice to be able to do something like:

net:writeDefinition('somefile.lua')

$ ls -lh somefile.lua

somefile.lua  2k

$ cat somefile.lua

local function make_net()
   local net = nn.Sequential()
   net:add(nn.Linear(32,16))
   net:add(nn.LogSoftMax())
   net:shareWeights(... whatever we do to do this...)
   return net
end

@fmassa
Contributor

fmassa commented Jan 6, 2016

@hughperkins for that, each module needs to know how to reconstruct its own constructor call from itself. As we only accept ordered arguments in nn (and some input arguments are not stored in self; see for example the Linear module), this would require manually going through all the modules and writing a :writeDefinition for each, which is about as much work as the current clearState PR.

Contributor

clearState is missing here; it should clear the following fields.

Minor note: for consistency I'd do something like

self.gradInput[1] = self.gradInput[1] or v1.new()
self.gradInput[2] = self.gradInput[2] or v1.new()

@dominikgrewe
Member

I think we should have some clear guidelines though as to how clearState (or clearBuffers or maybe clearTemporaries?) should behave. How about something like this:

  1. After calling the function, all temporary tensors should be in the same state as they were when the module was constructed.
  2. Any outside references to temporary tensors should be preserved.

@davidsaxton
Contributor

Some missing modules / tensors: Cosine, L1Penalty (loss), Linear (addBuffer), LookupTable (_count, _input, copiedInput), Mean (_gradInput), MixtureTable, PairwiseDistance (outExpand, grad, ones), Reshape (_input, _gradOutput), RRelu (noise), SoftMin (mininput), SoftSign, SparseLinear (lastInput), SpatialDivisiveNormalization, SpatialSubtractiveNormalization (don't forget its internal modules).

Also as Dominik says, it would be good to preserve outside references to temporary tensors where possible. So the default behaviour of Module:clearState should be self[f]:set() for tensors (and equivalent for tables). Where this is not appropriate (e.g., nn.Identity()), these modules can override the default behaviour to call new() instead of set().
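
A sketch of that proposal (illustrative; the merged default may differ in details):

require 'nn'

-- proposed default: empty temporaries in place so outside references stay valid
function nn.Module:clearState()
   for _, f in ipairs{'output', 'gradInput'} do
      if torch.isTensor(self[f]) then
         self[f]:set()
      elseif type(self[f]) == 'table' then
         self[f] = {}
      end
   end
   return self
end

-- nn.Identity's output is the caller's own input tensor, so set() would wipe the
-- caller's data; override to drop the reference and allocate a fresh tensor instead
function nn.Identity:clearState()
   self.output = torch.Tensor()
   self.gradInput = torch.Tensor()
   return self
end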

@soumith
Member

soumith commented Jan 19, 2016

@szagoruyko @dominikgrewe's last comment with the guidelines is pretty reasonable. Can we get this PR to a conclusion soon?

@szagoruyko
Member Author

@dominikgrewe @soumith it requires a big change in nn; I am not going to do that. We can wait for someone to do it in a separate PR and then return to clearState.
@davidsaxton thanks, I will try to find time to add these.

@szagoruyko
Member Author

Alright, I addressed @dominikgrewe's and @davidsaxton's comments. There is now only set() everywhere, except in Identity.lua and Container.lua, where I had to add a special clearState using new(). This is not the way I imagined doing it, but it can be fixed later. Tests are passing, rebased on master.

Member

this line should go before the THNN.*updateOutput call

@soumith
Member

soumith commented Feb 2, 2016

There still seem to be changes needed from comments from me / Francisco. Is it the GitHub diff tool, or did you not update? For example in CosineDistance and in DotProduct.

@szagoruyko
Member Author

@soumith thanks, missed that. Updated and added clearState for the Volumetric and Temporal poolings.

Member

Might be worth factoring that out into a utils function, e.g. nn.utils.clearTensors(self, {...}).

Member Author

I had this at some point: https://gist.github.com/szagoruyko/bd39620cdf35964647b2 what do you think?

@szagoruyko
Member Author

Added nn.utils.clear and docs.
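
A quick usage sketch of the helper on a module with a buffer (the addBuffer field comes from @davidsaxton's list above; the exact calling convention is whatever the new doc specifies, here explicit field names are passed):

require 'nn'

local m = nn.Linear(10, 5)
m:forward(torch.randn(4, 10))    -- populates output and Linear's addBuffer

nn.utils.clear(m, 'output', 'gradInput', 'addBuffer')
print(m.output:nElement())       -- 0: cleared in place, outside references stay valid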

@soumith
Member

soumith commented Feb 3, 2016

I just did a pass on it, and it looks good to me. Once Dominik signs off on it, it's ready for squash and merge. You don't need to squash into one commit, but rebase so that the bugfix commits and the "set everywhere" changes are squashed up.

PReLU.lua Outdated

function PReLU:updateOutput(input)
   self.gradWeightBuf = self.gradWeightBuf or input.new()
   self.gradWeightBuf2 = self.gradWeightBuf2 or input.new()

Member

Why is this lazy initialization required?

Member Author

Two reasons:

  • these buffers are needed only because cutorch's reduce is not generic enough and cannot reduce with += into a buffer, so they should be removed sometime later
  • I want to keep all buffers like this out of the constructor, for backend switching like cudnn.convert

Member

So it is unrelated to the introduction of clearState?

Member Author

yes, this is refactoring

Member

Would you mind moving this to a separate PR in that case? This one is already big enough as it is.

Member Author

agree, will factor out now

@dominikgrewe
Member

nn.Min and nn.Max also need a clearState method.

if torch.isTensor(self[f]) then
   self[f]:set()
elseif type(self[f]) == 'table' then
   self[f] = {}

Contributor

For tables, this should probably recurse into the table; that way external references to that table or to the tensors in it remain valid after the clearState() call.

Member Author

Yes, that's true, but I am a little bit afraid of doing that; who knows what's referenced there.

Contributor

Not sure I understand: as long as nn.utils.clear is only applied to fields of the module such as output and gradInput, or to internal buffers, what could be referenced there that would be a problem?

Member Author

For example, if you call clearState on one part of a network that is connected to another part through SelectTable- or NarrowTable-like modules, then a set()-based clearState will destroy the tensors it points to, along with everything else that references them. I think we should add a new()-based clearState to all modules that reference tensors in tables like this before enabling set() in nn.Module.

Contributor

What do you think about calling collectgarbage() at the end of nn.utils.clear() to free tensors that are no longer referenced (e.g. stuff built by zeroTableCopy() in SelectTable) after the = {} replacement?

@szagoruyko
Member Author

factored out lazy inits and added Min and Max.

@andreaskoepf
Contributor

For a v1 of clearState() this PR looks ready to me. +1 for merging this soon.

soumith added a commit that referenced this pull request Feb 9, 2016
@soumith merged commit 948ac6a into torch:master Feb 9, 2016
@soumith
Member

soumith commented Feb 9, 2016

all hail the sergey! thanks for the huge PR

@Atcold
Contributor

Atcold commented Feb 12, 2016

I'm having some trouble with :clearState() and nngraph... it looks like not much is cleared. Perhaps I'm doing something wrong...

th> setprintlevel(3); print {m28.modules[1].modules[1].modules[5].innode.data}
{
  1 : 
    {
      annotations : 
        {
          _debugLabel : "[[C]]:-1"
        }
      input : 
        {
          1 : CudaTensor - size: 16x64x56x56
        }
      mapindex : {...}
      forwardNodeId : 8
      gradOutput : 
        {
          1 : CudaTensor - size: 16x64x56x56
        }
    }
}

Whereas it worked perfectly on the same network implemented with ParallelTables().
I thought that, since nngraph extends nn.Module, it would benefit as well.
