
Calling sparse_coo_tensor in accumulateGrad makes it much slower #36120

@skyw


🐛 Bug

PR #33427 avoids cloning sparse tensors during accumulation of grads. However, it rebuilds the accumulated gradient with at::sparse_coo_tensor, which validates the input indices and values, and these checks are very slow. The validation overhead defeats the purpose of avoiding the clone.
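
As a rough illustration (not part of the original report), the sketch below isolates the constructor overhead. It assumes the internal torch._sparse_coo_tensor_unsafe binding, which skips validation; the checked constructor presumably pays for extra kernels over the indices plus a host synchronization, and exact timings will vary by GPU and build.

from time import time
import torch

nnz = 2048 * 100
indices = torch.randint(0, 100000, size=(1, nnz), device="cuda")
values = torch.rand(nnz, 64, device="cuda")

torch.cuda.synchronize()
start = time()
for _ in range(100):
  # Checked constructor: validates that every index is in range for the
  # given size before wrapping the tensors.
  torch.sparse_coo_tensor(indices, values, (100000, 64))
torch.cuda.synchronize()
print(f"checked:   {(time() - start) * 1000.:.1f} ms")

start = time()
for _ in range(100):
  # Internal unchecked constructor: wraps the same indices/values
  # without revalidating them.
  torch._sparse_coo_tensor_unsafe(indices, values, (100000, 64))
torch.cuda.synchronize()
print(f"unchecked: {(time() - start) * 1000.:.1f} ms")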

To Reproduce

Steps to reproduce the behavior:

  1. Run the following code on PyTorch 1.4 and on master:
from time import time
import torch

def main():
  batch_size = 2048
  num_features = 100000
  query_nnz = 100
  embed_size = 64

  # sparse=True makes backward produce a sparse gradient for the weight,
  # which goes through AccumulateGrad's sparse path.
  ref_embedding = torch.nn.Embedding(num_features, embed_size, sparse=True).cuda()

  indices = torch.randint(0, high=num_features, size=(batch_size, query_nnz), device="cuda")
  grad = torch.rand(batch_size, query_nnz, embed_size, device="cuda")

  # Time 100 forward/backward passes; each backward accumulates a sparse grad.
  torch.cuda.synchronize()
  start = time()
  for _ in range(100):
    ref_embedding.weight.grad = None
    ref_lookup = ref_embedding(indices)
    ref_lookup.backward(grad)
  torch.cuda.synchronize()
  stop = time()
  print(f"Elapsed time {(stop - start) * 1000.:.1f} ms.")

if __name__ == '__main__':
  main()

Expected behavior

Master should be at least as fast as 1.4. Measured on a V100:

PyTorch 1.4: Elapsed time 41.6 ms.
PyTorch master: Elapsed time 692.5 ms.

Environment

Please copy and paste the output from our environment collection script (or fill out the checklist below manually).

You can get the script and run it with:

wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py
  • PyTorch Version (e.g., 1.0): 1.4 and 1.5.0a0+d616cad
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): conda and source
  • Build command you used (if compiling from source):
    pip install --no-cache-dir -v .
  • Python version: 3.6
  • CUDA/cuDNN version: 10.2
  • GPU models and configuration: Tesla V100-DGXS-32GB
  • Any other relevant information:

Additional context

Changing https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/functions/accumulate_grad.h#L78-L82 to use at::_sparse_coo_tensor_unsafe will fix it. There is no need to use the checked constructor in backward, since the indices and values have already been validated. Existing modules, such as embedding, already use the unsafe version:

return at::_sparse_coo_tensor_unsafe(index, values, weight_size);
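
For reference, a minimal Python sketch of the same idea; the helper name rebuild_sparse_grad is made up for illustration, and the real change belongs in the C++ accumulate_grad.h code linked above.

import torch

def rebuild_sparse_grad(new_grad):
  # What master currently does: the checked constructor revalidates the
  # indices even though new_grad was itself produced by autograd from
  # already-valid data.
  #   return torch.sparse_coo_tensor(
  #       new_grad._indices(), new_grad._values(), new_grad.shape)

  # Proposed: skip the revalidation, mirroring what embedding's backward
  # already does with at::_sparse_coo_tensor_unsafe.
  return torch._sparse_coo_tensor_unsafe(
      new_grad._indices(), new_grad._values(), new_grad.shape)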

cc @vincentqb @VitalyFedyunin @ngimel

Labels

  • module: performance: Issues related to performance, either of kernel code or framework glue
  • module: sparse: Related to torch.sparse
  • triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
