Hello!
I'm working on an optimization algorithm that requires per-sample gradients. With batch size $N$ and $M$ model parameters, I want to compute $\partial \log p(\mathbf{x}^{(i)};\theta)/\partial \theta_j$ for every sample $i$ and every parameter $j$, which gives an $N \times M$ matrix. I found the [PER-SAMPLE-GRADIENTS](https://pytorch.org/tutorials/intermediate/per_sample_grads.html) tutorial and began my own experiments. As a proof of concept, I defined a generative model with a tractable likelihood, such as MADE (Masked Autoencoder for Distribution Estimation), PixelCNN, or an RNN, and specified its `log_prob` and `sample` methods. I used the function transforms described in the tutorial, but so far this only works for MADE. (I believe it would also work for NADE and PixelCNN, since these models need only one forward pass to compute the log-likelihood of $\mathbf{x}$. For an RNN, however, both sampling and inference require one forward pass per dimension of $\mathbf{x}$.)
Below are my code snippets. I'd like to understand why this does not work for the RNN. Making it work for the RNN would significantly reduce the number of parameters I need for my research.
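To make the target quantity concrete, it can be written out as follows. This is only a notational sketch: the symbol $G$ for the per-sample gradient matrix and $n$ for the dimension of $\mathbf{x}$ (the `n` argument in the code) are labels introduced here for illustration.

```latex
% Per-sample gradient matrix (G is a label introduced here)
G \in \mathbb{R}^{N \times M}, \qquad
G_{ij} = \frac{\partial \log p(\mathbf{x}^{(i)};\theta)}{\partial \theta_j}

% Autoregressive factorization over the n dimensions of x
\log p(\mathbf{x};\theta) = \sum_{k=1}^{n} \log p(x_k \mid x_{<k};\theta)

% The usual batch-mean gradient is the column average of G
\nabla_\theta \, \frac{1}{N} \sum_{i=1}^{N} \log p(\mathbf{x}^{(i)};\theta)
  = \frac{1}{N} \, G^{\top} \mathbf{1}_N
```

The last line is why a standard backward pass is not enough here: it returns only the average over rows of $G$, not the rows themselves.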
Thank you!
Describe the bug
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class MADE(nn.Module):
    '''A simple one-layer MADE (Masked Autoencoder for Distribution Estimation)'''

    def __init__(self, n=10, device='cpu', *args, **kwargs):
        super().__init__()
        self.n = n
        self.device = device
        self.weight = nn.Parameter(torch.randn(self.n, self.n) / math.sqrt(self.n))
        self.bias = nn.Parameter(torch.zeros(self.n))
        mask = torch.tril(torch.ones(self.n, self.n), diagonal=-1)
        self.register_buffer('mask', mask)

    def pred_logits(self, x):
        return F.linear(x, self.mask * self.weight, self.bias)

    def forward(self, x):
        logits = self.pred_logits(x)
        log_probs = -F.binary_cross_entropy_with_logits(logits, x, reduction='none')
        return log_probs.sum(-1)

    @torch.no_grad()
    def sample(self, batch_size):
        x = torch.zeros(batch_size, self.n, dtype=torch.float, device=self.device)
        for i in range(self.n):
            logits = self.pred_logits(x)[:, i]
            x[:, i] = torch.bernoulli(torch.sigmoid(logits))
        return x


class GRUModel(nn.Module):
    '''GRU for density estimation'''

    def __init__(self, n=10, input_size=2, hidden_size=8, device='cpu'):
        super().__init__()
        self.n = n
        self.input_size = input_size  # input_size=2 when x is binary
        self.hidden_size = hidden_size
        self.device = device
        self.gru_cell = nn.GRUCell(self.input_size, self.hidden_size)
        self.fc_layer = nn.Linear(self.hidden_size, 1)

    def pred_logits(self, x, h=None):
        x = torch.stack([x, 1 - x], dim=1)  # 1 -> (1, 0), 0 -> (0, 1), (batch_size, 2)
        h_next = self.gru_cell(x, h)  # h_{i+1}
        logits = self.fc_layer(h_next).squeeze(1)
        return h_next, logits

    def forward(self, x):
        log_prob_list = []
        x = torch.cat([torch.zeros(x.shape[0], 1, dtype=torch.float, device=self.device), x], dim=1)  # cat x_0
        h = torch.zeros(x.shape[0], self.hidden_size, dtype=torch.float, device=self.device)  # h_0
        for i in range(self.n):
            h, logits = self.pred_logits(x[:, i], h)
            log_prob = -F.binary_cross_entropy_with_logits(logits, x[:, i + 1], reduction='none')
            log_prob_list.append(log_prob)
        return torch.stack(log_prob_list, dim=1).sum(dim=1)

    @torch.no_grad()
    def sample(self, batch_size):
        x = torch.zeros(batch_size, self.n + 1, dtype=torch.float, device=self.device)
        for i in range(self.n):
            h, logits = self.pred_logits(x[:, i], h=None if i == 0 else h)
            x[:, i + 1] = torch.bernoulli(torch.sigmoid(logits))
        return x[:, 1:]

if __name__ == '__main__':
    model = MADE()
    # model = GRUModel()

    # Sample from the generative model
    samples = model.sample(128)

    # Then I use the function transforms methods mentioned in the tutorial
    # to calculate the per sample mean
    from torch.func import functional_call, grad, vmap

    params = {k: v.detach() for k, v in model.named_parameters()}

    def loss_fn(log_probs):
        return log_probs.mean(0)

    def compute_loss(params, sample):
        batch = sample.unsqueeze(0)
        log_prob = functional_call(model, (params,), (batch,))
        loss = loss_fn(log_prob)
        return loss

    ft_compute_grad = grad(compute_loss)
    ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0))
    ft_per_sample_grads = ft_compute_sample_grad(params, samples)
    print(ft_per_sample_grads)

The above code works for MADE (I also checked the gradient values, and they are correct).
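One way such a correctness check could be done without any autograd machinery is to compare a closed-form per-sample gradient against central finite differences. The sketch below is a dependency-free illustration, not the check that was actually run: it re-implements the one-layer MADE log-likelihood in plain Python and uses the identity $\partial \log p(\mathbf{x})/\partial b_j = x_j - \sigma(z_j)$ for the bias parameters only. The helper names (`made_log_prob`, `analytic_bias_grad`, `finite_diff_bias_grad`) and the small sizes are hypothetical.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def made_log_prob(x, weight, bias):
    """log p(x) for a one-layer MADE with a strictly lower-triangular mask."""
    total = 0.0
    for j in range(len(x)):
        # Masked linear layer: output j only sees inputs x_k with k < j.
        z = bias[j] + sum(weight[j][k] * x[k] for k in range(j))
        p = sigmoid(z)
        total += x[j] * math.log(p) + (1.0 - x[j]) * math.log(1.0 - p)
    return total


def analytic_bias_grad(x, weight, bias):
    """Closed form: d log p(x) / d bias_j = x_j - sigmoid(z_j)."""
    grads = []
    for j in range(len(x)):
        z = bias[j] + sum(weight[j][k] * x[k] for k in range(j))
        grads.append(x[j] - sigmoid(z))
    return grads


def finite_diff_bias_grad(x, weight, bias, eps=1e-6):
    """Central finite differences, used only as an independent reference."""
    grads = []
    for j in range(len(bias)):
        up = bias[:j] + [bias[j] + eps] + bias[j + 1:]
        down = bias[:j] + [bias[j] - eps] + bias[j + 1:]
        diff = made_log_prob(x, weight, up) - made_log_prob(x, weight, down)
        grads.append(diff / (2.0 * eps))
    return grads


random.seed(0)
n, batch_size = 4, 3  # small illustrative sizes, not the ones used above
weight = [[random.gauss(0.0, 1.0) / math.sqrt(n) for _ in range(n)] for _ in range(n)]
bias = [0.0] * n
samples = [[float(random.random() < 0.5) for _ in range(n)] for _ in range(batch_size)]

# One row per sample: the bias columns of the N x M per-sample gradient matrix.
per_sample = [analytic_bias_grad(x, weight, bias) for x in samples]
reference = [finite_diff_bias_grad(x, weight, bias) for x in samples]

max_err = max(
    abs(a - b)
    for row_a, row_b in zip(per_sample, reference)
    for a, b in zip(row_a, row_b)
)
print("max abs difference:", max_err)
```

The same comparison could in principle be made against the `bias` entry of `ft_per_sample_grads`, keeping in mind that the sign and scaling depend on how the loss is defined.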
However, when I use `model = GRUModel()`, the following error arises:
Traceback (most recent call last):
File "per_sample_grads.py", line 100, in <module>
ft_per_sample_grads = ft_compute_sample_grad(params, samples)
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 434, in wrapped
return _flat_vmap(
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 39, in fn
return f(*args, **kwargs)
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 619, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py", line 1380, in wrapper
results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 39, in fn
return f(*args, **kwargs)
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py", line 1245, in wrapper
output = func(*args, **kwargs)
File "per_sample_grads.py", line 94, in compute_loss
log_prob = functional_call(model, (params,), (batch,))
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/functional_call.py", line 143, in functional_call
return nn.utils.stateless._functional_call(
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/utils/stateless.py", line 262, in _functional_call
return module(*args, **kwargs)
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "per_sample_grads.py", line 63, in forward
h, logits = self.pred_logits(x[:, i], h)
File "per_sample_grads.py", line 54, in pred_logits
h_next = self.gru_cell(x, h) # h_{i+1}
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 1327, in forward
ret = _VF.gru_cell(
RuntimeError: output with shape [1, 8] doesn't match the broadcast shape [128, 1, 8]
Describe your environment
Platform: macOS
No CUDA
PyTorch version: 2.0.1
The above code was also tested on Ubuntu 18.04 with PyTorch 2.0.1 and CUDA 11.7/11.8.