Speed up gaussian_nll_loss when variance is always the same scalar value #138747

@michael-diggin

🚀 The feature, motivation and pitch

I've been using gaussian_nll_loss and noticed that it can be quite slow because of the check for non-negative variances here: https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L3295, which forces a GPU-CPU sync. See the following screenshot from a profile:
[Screenshot: profiler trace of gaussian_nll_loss]
This was reliably costing ~10ms for me (admittedly, that includes some profiler overhead).
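For readers unfamiliar with why an `if` on a tensor is costly, here is a minimal sketch of the mechanism. It is an illustration only, not the library's code: the tensor shape and the variable names are assumptions. Using a tensor as an `if` condition converts it to a Python bool, and on a CUDA device that conversion has to wait for the queued kernels to finish and copy the result to the host.

```python
import torch

# Falls back to CPU so the sketch runs anywhere; the stall only exists on CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Illustrative variance tensor (shape chosen arbitrarily for this sketch).
var = torch.full((1024,), 0.05, device=device)

# The comparison and reduction stay on the device and return a 0-dim bool tensor.
any_negative = torch.any(var < 0)

# `if any_negative:` implicitly does this conversion. On CUDA the host must
# wait for the result here, which is the GPU-CPU sync seen in the profile.
flag = bool(any_negative)
```

On a CPU-only machine the conversion is cheap, so the cost described above only shows up when the tensors live on the GPU.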

In my case, the variance I was using was essentially 0.05*torch.ones_like(inputs) and since it was known to be > 0 ahead of time, the slow check was redundant.

It would be nice if gaussian_nll_loss accepted either a Tensor or a float for var. If it's a float, validate the value in Python and then convert it to a Tensor, which avoids the expensive sync; otherwise perform the check as normal. E.g., at the beginning of the function:

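if isinstance(var, float):
    # Scalar variance: validate in Python, so no GPU-CPU sync is needed.
    if var < 0:
        raise ValueError("var has negative entry/entries")
    var = var * torch.ones_like(input)
elif torch.any(var < 0):
    # Tensor variance: keep the existing check.
    raise ValueError("var has negative entry/entries")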
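To make the idea concrete outside of the library, here is a rough standalone sketch of the scalar path. The helper name `gaussian_nll_scalar_var` is hypothetical and not part of PyTorch, and the sketch omits the `full` option. It assumes the loss formula `0.5 * (log(var) + (input - target)**2 / var)` with `var` clamped to `eps`, and does the validation and the clamp on the host so that no tensor is ever inspected from Python.

```python
import math

import torch
import torch.nn.functional as F


def gaussian_nll_scalar_var(input, target, var, eps=1e-6, reduction="mean"):
    """Hypothetical helper: Gaussian NLL for one Python-float variance."""
    # Plain Python comparison: nothing is read back from the device.
    if var < 0:
        raise ValueError("var has negative entry/entries")
    # Mirror the eps clamp, but on the host-side scalar.
    var = max(var, eps)
    loss = 0.5 * (math.log(var) + (input - target) ** 2 / var)
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss


# Usage: compare against the built-in loss with an equivalent tensor variance.
torch.manual_seed(0)
x = torch.randn(4, 3)
y = torch.randn(4, 3)
ours = gaussian_nll_scalar_var(x, y, 0.05)
reference = F.gaussian_nll_loss(x, y, torch.full_like(x, 0.05))
```

The two results should agree up to floating-point rounding; the only difference is where the non-negativity check happens.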

Alternatives

Alternatively, I was wondering if the non-negative check could be removed, as it is documented that the variance should be positive.
The concern is that a negative variance would then be silently replaced by the (positive) value of eps due to the clamp call here: https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L3301
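A small sketch of that concern, assuming the clamp behaves like `Tensor.clamp(min=eps)` and using made-up values: without the check, a negative entry does not raise. It simply becomes `eps` and produces a large but finite loss that is easy to miss.

```python
import torch

eps = 1e-6
# Illustrative values: the second variance is invalid (negative).
var = torch.tensor([0.05, -0.2])
diff = torch.tensor([0.1, 0.1])  # stands in for input - target

# With the check removed, clamping is all that happens to the bad entry.
clamped = var.clamp(min=eps)

# Per-element loss: no error is raised, and the result is finite.
loss = 0.5 * (torch.log(clamped) + diff ** 2 / clamped)
```

Here the first entry is left unchanged while the second is replaced by `eps`, so the mistake only surfaces as an unusually large loss value.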

Additional context

If you are happy with the proposed change, I'd be happy to raise a PR for it.

cc @msaroufim @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @ptrblck

Labels: module: cuda, module: nn, module: performance, triaged
