🐛 Bug
PR #33427 avoids cloning sparse tensors during gradient accumulation. However, `at::sparse_coo_tensor` performs value checks on its inputs that are very slow, which defeats the goal of avoiding the clone to make accumulation faster.
To Reproduce
Steps to reproduce the behavior:
- Run the following code on both 1.4 and master:
```python
from time import time

import torch

def main():
    batch_size = 2048
    num_features = 100000
    query_nnz = 100
    embed_size = 64

    ref_embedding = torch.nn.Embedding(num_features, embed_size, sparse=True).cuda()
    indices = torch.randint(0, high=num_features, size=(batch_size, query_nnz), device="cuda")
    grad = torch.rand(batch_size, query_nnz, embed_size, device="cuda")

    torch.cuda.synchronize()
    start = time()
    for _ in range(100):
        ref_embedding.weight.grad = None
        ref_lookup = ref_embedding(indices)
        ref_lookup.backward(grad)
    torch.cuda.synchronize()
    stop = time()
    print(f"Elapsed time {(stop - start) * 1000.:.1f} ms.")

if __name__ == '__main__':
    main()
```
Expected behavior
Tested on a V100:
- PyTorch 1.4: Elapsed time 41.6 ms.
- PyTorch master: Elapsed time 692.5 ms.
Environment
Please copy and paste the output from our environment collection script (or fill out the checklist below manually). You can get the script and run it with:
```shell
wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py
```
- PyTorch Version (e.g., 1.0): 1.4 and 1.5.0a0+d616cad
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, source): conda and source
- Build command you used (if compiling from source): pip install --no-cache-dir -v .
- Python version: 3.6
- CUDA/cuDNN version: 10.2
- GPU models and configuration: Tesla V100-DGXS-32GB
- Any other relevant information:
Additional context
Changing https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/functions/accumulate_grad.h#L78-L82 to use `at::_sparse_coo_tensor_unsafe` fixes it. There is no need for the safe version in backward, since the values and indices have already been checked. Existing modules, such as embedding, already use the unsafe version:
```cpp
return at::_sparse_coo_tensor_unsafe(index, values, weight_size);
```
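For reference, the checked and unchecked constructors produce the same tensor when the inputs are valid; the unchecked one simply skips the per-element validation. A minimal CPU sketch (the indices, values, and sizes here are illustrative, not from the report; `torch._sparse_coo_tensor_unsafe` is a private binding whose availability varies across releases, so it is probed with `hasattr`):

```python
import torch

# Illustrative indices/values for a 3x3 sparse COO tensor.
indices = torch.tensor([[0, 1, 2],
                        [2, 0, 1]])
values = torch.ones(3)
size = (3, 3)

# Checked constructor: validates, e.g., that every index lies inside `size`.
checked = torch.sparse_coo_tensor(indices, values, size)

# Unchecked constructor skips those per-element checks; that is safe here
# because the inputs are known-valid, the same argument made for backward above.
if hasattr(torch, "_sparse_coo_tensor_unsafe"):
    unchecked = torch._sparse_coo_tensor_unsafe(indices, values, size)
    assert torch.equal(checked.to_dense(), unchecked.to_dense())
```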