
[QUESTION] LossModule's functional parameters duplicate weights in memory #1769

Closed
wbinventor opened this issue Jan 3, 2024 · 10 comments

@wbinventor

wbinventor commented Jan 3, 2024

The LossModule.convert_to_functional(...) method creates a deep copy of the parameters. If I understand correctly, this leads to the parameters being duplicated in memory and a larger memory footprint than necessary. Is my understanding correct? If so, why is this necessary? Is there any way for the LossModule to simply hold a single reference to the weights of, e.g., its actor and critic TorchModules?

This can be seen in the following line:

self.__dict__[module_name] = deepcopy(module)

This is the specific snippet of code containing the deep copy that this question pertains to:

        # set the functional module: we need to convert the params to non-differentiable params
        # otherwise they will appear twice in parameters
        with params.apply(
            self._make_meta_params, device=torch.device("meta")
        ).to_module(module):
            # avoid buffers and params being exposed
            self.__dict__[module_name] = deepcopy(module)
@wbinventor wbinventor changed the title LossModule's functional parameters duplicate weights in memory [QUESTION] LossModule's functional parameters duplicate weights in memory Jan 3, 2024
@wbinventor
Author

I am confused by this deepcopy of the weights. It seems that after just one gradient update, loss_module.parameters() will differ from the actor/critic TensorDictModule used to generate rollouts to insert into a replay buffer, since two separate copies of the weights are used to generate rollouts vs. compute the loss. I'm sure I must be missing something here, so any clarifications would be greatly appreciated!

@vmoens
Contributor

vmoens commented Jan 4, 2024

Thanks for posting this!
The important bit in the code snippet you linked is the context manager: we take the parameters, send them to the "meta" device (i.e., create a stateless copy of the parameters), and then temporarily populate the module with them.
This is the module instance we will copy.
After we exit the context manager, the module retrieves its original parameters, and only a stateless module is stored.

It seems that after just one gradient update, loss_module.parameters() will differ from the actor/critic TensorDictModule used to generate rollouts to insert into a replay buffer, since two separate copies of the weights are used to generate rollouts vs. compute the loss.

Not really, since we always call the stored module with a functional call!
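
A minimal sketch of what such a functional call looks like, using tensordict's TensorDict.from_module / to_module (this illustrates the general pattern only; the exact code path inside LossModule may differ):

import torch
from torch import nn
from tensordict import TensorDict

module = nn.Linear(3, 4)

# an independent set of parameters, e.g. the ones registered on the loss module
params = TensorDict.from_module(module).clone()

x = torch.randn(2, 3)
with params.to_module(module):  # temporarily swap `params` into the module
    y = module(x)               # this forward pass reads `params`, not the module's own weights
# on exit, the module's original parameters are restored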

I hope that clarifies things!

@vmoens vmoens closed this as completed Jan 4, 2024
@wbinventor
Author

wbinventor commented Jan 4, 2024

Thanks, @vmoens! That clarification about the context manager is helpful 🙏

However, can you confirm if my understanding is correct that as a result of the deepcopy in LossModule, the memory footprint is (approximately) twice as large since the actor/critic model parameters are duplicated?

At least when I instantiate a LossModule, I notice that the memory footprint (e.g., CUDA memory) approximately doubles as a result of this deepcopy. This is an issue for large models, so I'm trying to understand whether this duplication of parameters can be avoided.

@vmoens
Contributor

vmoens commented Jan 4, 2024

However, can you confirm if my understanding is correct that as a result of the deepcopy in LossModule, the memory footprint is (approximately) twice as large since the actor/critic model parameters are duplicated?

No, a tensor on the "meta" device has no content, so it has (approximately) zero memory footprint.
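
(A minimal standalone sketch of this in plain PyTorch: a "meta" tensor keeps its shape and dtype but allocates no storage.)

import torch
from torch import nn

layer = nn.Linear(1024, 1024)
meta_weight = layer.weight.to("meta")  # stateless copy: shape/dtype kept, no storage allocated
print(meta_weight.is_meta)             # True
print(meta_weight.shape)               # torch.Size([1024, 1024])
# There is no data to read back: e.g. meta_weight[0, 0].item() raises a RuntimeError.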

If the memory increases by a factor of 2x, there must be an issue somewhere; this isn't the intended behaviour (it's a bug).

@wbinventor
Author

Ok, that's what I thought as well re: "meta" device behavior.

I can very clearly see the memory footprint double when the deepcopy on line 290 of LossModule is called. Should this be reported separately as a bug?

@vmoens vmoens reopened this Jan 4, 2024
@vmoens
Contributor

vmoens commented Jan 4, 2024

Nope, I will have a look and push a patch!

@vmoens
Contributor

vmoens commented Jan 4, 2024

Do you have any way to check that the memory doubles?
This piece of code indicates that the parameters are on the "meta" device, as expected:

from torchrl.modules import MLP, QValueActor
from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.objectives import DQNLoss
n_obs, n_act = 4, 3
value_net = MLP(in_features=n_obs, out_features=n_act)
spec = OneHotDiscreteTensorSpec(n_act)
actor = QValueActor(value_net, in_keys=["observation"], action_space=spec)
loss = DQNLoss(actor, action_space=spec)
list(loss.value_network.parameters())
[Parameter containing:
 tensor(..., device='meta', size=(32, 4)),
 Parameter containing:
 tensor(..., device='meta', size=(32,)),
 Parameter containing:
 tensor(..., device='meta', size=(32, 32)),
 Parameter containing:
 tensor(..., device='meta', size=(32,)),
 Parameter containing:
 tensor(..., device='meta', size=(32, 32)),
 Parameter containing:
 tensor(..., device='meta', size=(32,)),
 Parameter containing:
 tensor(..., device='meta', size=(3, 32)),
 Parameter containing:
 tensor(..., device='meta', size=(3,))]
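
(One way to check, as a sketch: assuming a CUDA device is available and reusing actor and spec from the snippet above, compare the allocated memory before and after building the loss.)

import torch

actor = actor.cuda()
before = torch.cuda.memory_allocated()
loss = DQNLoss(actor, action_space=spec)
after = torch.cuda.memory_allocated()
print(f"extra memory from building the loss: {(after - before) / 2**20:.2f} MiB")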

@vmoens
Contributor

vmoens commented Jan 4, 2024

Are you sure the memory footprint isn't twice as big because of the target parameters?
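
(For context, a sketch: when delayed/target networks are enabled, the loss module keeps a real, detached copy of the parameters, and that copy does consume memory. The delay_value argument and the target_value_network_params attribute below follow torchrl's naming for DQNLoss, reusing actor and spec from the earlier snippet; treat the exact names as an assumption.)

loss = DQNLoss(actor, action_space=spec, delay_value=True)
# target parameters are real (non-meta), detached tensors, so they do occupy memory
print(loss.target_value_network_params)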

@wbinventor
Author

I was unable to reproduce this with some simplified code (e.g., the example scripts), and have determined that my full code contains some other, non-parameter tensor attributes on my nn.Module that are duplicated by the deepcopy. I'm closing this issue since the deepcopy of "meta" parameters is not the reason that the memory footprint doubles.

@vmoens
Contributor

vmoens commented Jan 5, 2024

Got it
Unfortunately, as of now, we have to copy the module. We took care of making this work for parameters and buffers, but tensors stored as plain attributes (neither buffers nor parameters) are tricky to handle in general: they usually won't be part of your state dict, they won't be cast when you call module.cuda(), etc. They are generally regarded as improper usage of nn.Module.
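
(A minimal sketch of this standard PyTorch behaviour, unrelated to torchrl internals: a registered buffer is tracked by nn.Module, while a plain tensor attribute is not.)

import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.register_buffer("scale", torch.ones(4))  # tracked: in state_dict, moved by .cuda()
        self.mask = torch.ones(4)                      # plain attribute: not tracked

net = Net()
print("scale" in net.state_dict())  # True
print("mask" in net.state_dict())   # False
# net.cuda() would move `linear` and `scale` to the GPU, but leave `mask` on the CPU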

We could design an ad-hoc strategy to avoid deepcopying the non-parameter, non-buffer tensors, but I'm not sure that this is what users will want (I can imagine that some users will want them to be copied, while others won't).

If you feel like this deepcopy is causing trouble in your use case, I'd be happy to look into an adequate solution.
