torch.normal accepts Variables but does not propagate gradients #4620
Comments
Even worse, somehow I've found a way to successfully run
Ok, here's another example of the behavior described in the previous comment. It's something like a VAE, and it demonstrates the problem:

```python
import torch
from torch.autograd import Variable
import math

torch.manual_seed(0)

LOG2PI = torch.log(torch.FloatTensor([2 * math.pi]))[0]

class DiagonalMVN(object):
    def __init__(self, mean, log_stddev):
        assert mean.size() == log_stddev.size()
        self.mean = mean
        self.log_stddev = log_stddev

    def sample_bad(self):
        return torch.normal(self.mean, torch.exp(self.log_stddev))

    def sample_good(self):
        return self.mean + torch.exp(self.log_stddev) * Variable(torch.randn(self.mean.size()))

    def logprob(self, x):
        return -0.5 * (
            self.mean.numel() * LOG2PI
            + 2 * torch.sum(self.log_stddev)
            + torch.sum((x - self.mean) * torch.exp(-2 * self.log_stddev) * (x - self.mean))
        )

class FixedVarianceNet(object):
    def __init__(self, mean_net, log_stddev):
        self.mean_net = mean_net
        self.log_stddev = log_stddev

    def __call__(self, x):
        return DiagonalMVN(self.mean_net(x), self.log_stddev)

inference_net = FixedVarianceNet(torch.nn.Linear(2, 1), Variable(torch.zeros(1)))
generative_net = FixedVarianceNet(torch.nn.Linear(1, 2), Variable(torch.zeros(2)))

print('### torch.normal broken!')
Xvar = Variable(torch.randn(2))
vi_posterior = inference_net(Xvar)
loss = -generative_net(vi_posterior.sample_bad()).logprob(Xvar)
loss.backward()
print('inference_net.mean_net.bias.grad ==', inference_net.mean_net.bias.grad)

print()
print('### Custom sample() works...')
Xvar = Variable(torch.randn(2))
vi_posterior = inference_net(Xvar)
loss = -generative_net(vi_posterior.sample_good()).logprob(Xvar)
loss.backward()
print('inference_net.mean_net.bias.grad ==', inference_net.mean_net.bias.grad)
```

It produces:
I'm using
It becomes backprop-able because an

I don't think
But I do agree that this is confusing.
But I've also marked

Personally I would like to lobby for
I recently tracked down a particularly nasty bug that was due to the behavior shown in the second example: no error, but also no gradients.
I believe
It is useful in the case of the reparameterization trick. However, it still doesn't quite make sense for backprop to work on sampling methods in general. For example, what would you say the "gradient" is for discrete distributions? And we should definitely not allow backprop through things like an entire MCMC trace, naive RL rewards, etc. When users want noisy gradients, the best way, in my opinion, is to let them manually write out something like N(0, 1). I'm not sure that we should throw an error; a warning might be a good solution. But I definitely agree that we should make the docs clearer.
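For reference, the "manually write out N(0, 1)" approach is exactly what `sample_good` above does. A minimal standalone sketch of the reparameterization trick (hypothetical parameter names, written against the modern tensor API rather than `Variable`):

```python
import torch

torch.manual_seed(0)

# Hypothetical distribution parameters; requires_grad so we can check gradient flow.
mu = torch.zeros(3, requires_grad=True)
log_std = torch.zeros(3, requires_grad=True)

# Reparameterization trick: draw eps ~ N(0, 1) outside the graph, then
# build the sample as a deterministic function of mu and log_std.
eps = torch.randn(3)
x = mu + torch.exp(log_std) * eps

loss = (x ** 2).sum()
loss.backward()

print(mu.grad is not None)       # gradients flow to the mean
print(log_std.grad is not None)  # and to the (log) standard deviation
```

Had we instead written `x = torch.normal(mu, torch.exp(log_std))`, the sample would be cut off from the graph and no gradients would reach `mu` or `log_std`.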
I'm still confused as to why one would error and the other would not. Shouldn't I be able to write my own linear layer
That makes sense. In that case
nn.Linear has no special treatment. Here is what happens when it is involved.
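To illustrate that nn.Linear gets no special treatment, here is a hand-rolled linear layer (hypothetical names, modern tensor API): it is just a matmul plus a bias add, so gradients flow through it like any other deterministic op.

```python
import torch

torch.manual_seed(0)

# Parameters of a hand-rolled linear layer (what nn.Linear stores internally).
W = torch.randn(2, 3, requires_grad=True)
b = torch.zeros(2, requires_grad=True)

def my_linear(x):
    # Deterministic ops: autograd records matmul and add in the graph.
    return W @ x + b

x = torch.randn(3)
loss = my_linear(x).pow(2).sum()
loss.backward()

print(W.grad.shape)  # torch.Size([2, 3])
print(b.grad.shape)  # torch.Size([2])
```

Replacing the deterministic forward pass with a call to `torch.normal` is where the chain breaks, not anything specific to nn.Linear.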
Maybe, but Tensor and Variable will soon become the same thing :P
Ok, so this all makes sense to me. But I don't see why

```python
mu = Variable(torch.Tensor([1]), requires_grad=True)
sigma = Variable(torch.Tensor([1]), requires_grad=True)
x = torch.normal(mu, sigma)
loss = torch.pow(x, 2)
loss.backward()
```

should produce a RuntimeError.
Exciting!
Without
If you are on master, you can do
The following code doesn't cause an error because
@talesa Yeah, but mu and sigma won't have grads.
After some discussion we decided that random ops in autograd should never propagate gradients. If you want a reparametrized sampler, use
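The suggestion is cut off above; one way to get a reparametrized sampler in PyTorch 0.4 and later is the torch.distributions API, whose rsample() uses the reparameterization trick while sample() deliberately blocks gradients. A hedged sketch (not necessarily the exact API the comment referred to):

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

mu = torch.zeros(4, requires_grad=True)
sigma = torch.ones(4, requires_grad=True)

dist = Normal(mu, sigma)

# rsample() draws eps ~ N(0, 1) and returns mu + sigma * eps,
# so the sample stays differentiable w.r.t. the parameters.
x = dist.rsample()
loss = x.pow(2).sum()
loss.backward()

print(mu.grad is not None)     # True
print(sigma.grad is not None)  # True
```

Using `dist.sample()` instead would leave both grads as None, matching the decision described above.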
@apaszke Would it be possible to have a warning or error when you try to use Variables with random ops, then? That way it becomes much harder to fall into this trap.
Yeah, the "fix" should be to throw an error in the backwards of:
Basically, anywhere Generator appears in derivatives.yaml.
Just encountered this trap. I came from TensorFlow and was hoping to write a policy-gradient agent for reinforcement learning, where I need to sample an action tensor from a normal distribution whose mean is the output of a network and whose deviation is a variable. I need to propagate gradients to both the mean and the variance.
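For that policy-gradient use case, sampling is intentionally non-differentiable, so gradients come from the score-function (REINFORCE) estimator via log_prob rather than through the sample itself. A minimal sketch under assumed toy names (mean_net, log_std, and the reward are all hypothetical):

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Toy Gaussian policy: mean from a small network, learned log std.
mean_net = torch.nn.Linear(3, 1)
log_std = torch.zeros(1, requires_grad=True)

state = torch.randn(3)
dist = Normal(mean_net(state), torch.exp(log_std))

# sample() blocks gradients by design; the score-function estimator
# recovers them through log_prob of the sampled action instead.
action = dist.sample()
reward = -(action - 1.0).pow(2).sum()  # hypothetical scalar reward

loss = -dist.log_prob(action).sum() * reward
loss.backward()

print(mean_net.bias.grad is not None)  # True
print(log_std.grad is not None)        # True
```

This is the standard workaround when the reparameterization trick does not apply (e.g. discrete action spaces), at the cost of higher-variance gradient estimates.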
@colesbury Is there a reason that you say the error should be thrown when calling backwards, as opposed to when calling these functions with
Closing this now that 0.4 has come out. |
Simple example:
produces: