🚀 The feature, motivation and pitch
I've been using gaussian_nll_loss and noticed that it can be quite slow due to the check for non-negative variances here: https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L3295, which forces a GPU-CPU sync. See the following screenshot from a profile:

This was reliably costing ~10ms for me (admittedly, that includes some overhead from the profiler).
In my case, the variance I was using was essentially 0.05*torch.ones_like(inputs) and since it was known to be > 0 ahead of time, the slow check was redundant.
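For context, the cost comes from evaluating a data-dependent condition on a device tensor in Python. A minimal sketch of the pattern (assuming a CUDA device is available; on CPU there is no sync to pay for):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
inputs = torch.randn(1024, device=device)
var = 0.05 * torch.ones_like(inputs)

# torch.any(var < 0) produces a 0-dim boolean tensor on the device; using it in
# a Python `if` calls Tensor.__bool__, which copies the result back to the CPU
# and therefore has to wait for all queued kernels to finish.
if torch.any(var < 0):
    raise ValueError("var has negative entry/entries")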
It would be nice if gaussian_nll_loss accepted either a Tensor or a float for var. If it's a float, the check can be done on the Python value directly and it can then be converted to a Tensor, avoiding the expensive sync; otherwise the check is performed as normal. E.g., at the beginning of the function:
if isinstance(var, float):
    if var < 0:
        raise ValueError("var has negative entry/entries")
    var = var * torch.ones_like(inputs)
elif torch.any(var < 0):
    raise ValueError("var has negative entry/entries")
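For comparison, a sketch of how a call site could look under this proposal (the float-var form is hypothetical, since gaussian_nll_loss currently requires var to be a Tensor):

import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
inputs = torch.randn(8, 4, device=device)
target = torch.randn(8, 4, device=device)

# Current API: var must be a Tensor, and the torch.any(var < 0) check forces a
# device-to-host sync.
loss = F.gaussian_nll_loss(inputs, target, 0.05 * torch.ones_like(inputs))

# Proposed (hypothetical): pass a plain float; the value check happens on the
# host, so no sync is needed.
# loss = F.gaussian_nll_loss(inputs, target, 0.05)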
Alternatives
Alternatively, I was wondering whether the non-negative check could be removed entirely, since the documentation already states that the variance should be positive.
The concern is that a negative variance would then be silently set to the (positive) value of eps by the clamp call here: https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L3301
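To illustrate that concern, a small sketch (assuming the clamp behaves like Tensor.clamp(min=eps) with the default eps of 1e-6):

import torch

eps = 1e-6
var = torch.tensor([-0.1, 0.05])

# Without the explicit non-negativity check, the negative entry would be
# silently replaced by eps instead of raising an error.
print(var.clamp(min=eps))  # both entries end up positive: eps and 0.05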
Additional context
If you are happy with the proposed change, I'd be happy to raise a PR for it.
cc @msaroufim @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @ptrblck