
illegal memory access for torch.sparse.mm(src, other) / deg.view(-1, 1).clamp_(min=1) #111574

Closed
puririshi98 opened this issue Oct 19, 2023 · 11 comments
Labels: high priority; module: crash (problem manifests as a hard crash, as opposed to a RuntimeError); module: cuda (related to torch.cuda and CUDA support in general); module: sparse (related to torch.sparse); triaged (this issue has been looked at by a team member and triaged into an appropriate module)

puririshi98 (Contributor) commented Oct 19, 2023

🐛 Describe the bug

Original Issue from PyG: pyg-team/pytorch_geometric#8213
Failing example: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/rev_gnn.py

CUDA_LAUNCH_BLOCKING=1 python3 /workspace/examples/rev_gnn.py
Traceback (most recent call last):
  File "/workspace/examples/rev_gnn.py", line 187, in <module>
    loss = train(epoch)
  File "/workspace/examples/rev_gnn.py", line 125, in train
    out = model(data.x, data.adj_t)[data.train_mask]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/examples/rev_gnn.py", line 76, in forward
    x = conv(x, edge_index, mask)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/nn/models/rev_gnn.py", line 166, in forward
    return self._fn_apply(args, self._forward, self._inverse)
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/nn/models/rev_gnn.py", line 181, in _fn_apply
    out = InvertibleFunction.apply(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/nn/models/rev_gnn.py", line 52, in forward
    outputs = ctx.fn(*x)
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/nn/models/rev_gnn.py", line 283, in _forward
    y_in = xs[i] + self.convs[i](y_in, edge_index, *args[i])
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/examples/rev_gnn.py", line 35, in forward
    return self.conv(x, edge_index)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/nn/conv/sage_conv.py", line 130, in forward
    out = self.propagate(edge_index, x=x, size=size)
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/nn/conv/message_passing.py", line 431, in propagate
    out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/nn/conv/sage_conv.py", line 149, in message_and_aggregate
    return spmm(adj_t, x[0], reduce=self.aggr)
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/utils/spmm.py", line 99, in spmm
    return torch.sparse.mm(src, other) / deg.view(-1, 1).clamp_(min=1)
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
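For context, the failing line computes mean aggregation via SpMM: multiply the sparse adjacency by the node-feature matrix, then divide each row by its degree, clamped to a minimum of 1 to avoid dividing by zero for isolated nodes. A minimal dense NumPy stand-in of that arithmetic (a hypothetical 3-node graph; dense `adj` substitutes for the sparse CSR tensor, and `np.clip` plays the role of `clamp_(min=1)`):

```python
import numpy as np

# Hypothetical 3-node graph: node 0 -> {1, 2}, node 1 -> {}, node 2 -> {0}.
adj = np.array([[0., 1., 1.],
                [0., 0., 0.],
                [1., 0., 0.]], dtype=np.float32)
x = np.array([[1.], [2.], [3.]], dtype=np.float32)  # one feature per node

deg = adj.sum(axis=1)  # row degree, like the scatter(...) call in spmm()
out = (adj @ x) / np.clip(deg, 1, None).reshape(-1, 1)

# Node 0 averages neighbors 1 and 2 -> (2 + 3) / 2; node 1 has no
# neighbors, so the clamp keeps its row sum of 0 from producing 0/0.
assert np.allclose(out, [[2.5], [0.], [1.]])
```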

Versions

python collect_env.py
Collecting environment information...
PyTorch version: 2.1.0a0+32f93b1
Is debug build: False
CUDA used to build PyTorch: 12.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.27.6
Libc version: glibc-2.35

Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-150-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA RTX A5000
GPU 1: NVIDIA RTX A5000

Nvidia driver version: 530.41.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   46 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          16
On-line CPU(s) list:             0-15
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Core(TM) i7-9800X CPU @ 3.80GHz
CPU family:                      6
Model:                           85
Thread(s) per core:              2
Core(s) per socket:              8
Socket(s):                       1
Stepping:                        4
CPU max MHz:                     4500.0000
CPU min MHz:                     1200.0000
BogoMIPS:                        7599.80
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single pti ssbd mba ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req md_clear flush_l1d arch_capabilities
Virtualization:                  VT-x
L1d cache:                       256 KiB (8 instances)
L1i cache:                       256 KiB (8 instances)
L2 cache:                        8 MiB (8 instances)
L3 cache:                        16.5 MiB (1 instance)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-15
Vulnerability Itlb multihit:     KVM: Mitigation: Split huge pages
Vulnerability L1tf:              Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds:               Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown:          Mitigation; PTI
Vulnerability Mmio stale data:   Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed:          Mitigation; IBRS
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Mitigation; Clear CPU buffers; SMT vulnerable

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.22.2
[pip3] onnx==1.14.0
[pip3] pytorch-quantization==2.1.2
[pip3] torch==2.1.0a0+32f93b1
[pip3] torch_geometric==2.4.0
[pip3] torch-tensorrt==0.0.0
[pip3] torchdata==0.6.0+5bbcd77
[pip3] torchmetrics==1.2.0
[pip3] torchtext==0.16.0a0
[pip3] torchvision==0.16.0a0
[pip3] triton==2.1.0+e621604
[pip3] tritonclient==2.38.0.69485441
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kadeng @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer @ptrblck

colesbury added the module: crash, module: cuda, module: sparse, triaged, and high priority labels on Oct 19, 2023
cpuhrsch (Contributor) commented:

@pearu @amjames - Please prioritize and look into this asap. This is marked as high priority.

pearu (Collaborator) commented Oct 20, 2023

With

$ git diff
diff --git a/torch_geometric/utils/spmm.py b/torch_geometric/utils/spmm.py
index d951b7e0e..dbdd80e93 100644
--- a/torch_geometric/utils/spmm.py
+++ b/torch_geometric/utils/spmm.py
@@ -96,7 +96,7 @@ def spmm(src: Adj, other: Tensor, reduce: str = "sum") -> Tensor:
                 deg = scatter(torch.ones_like(src.values()), src.row_indices(),
                               dim=0, dim_size=src.size(0), reduce='sum')
 
-            return torch.sparse.mm(src, other) / deg.view(-1, 1).clamp_(min=1)
+            return torch.sparse.mm(src.cpu(), other.cpu()).cuda() / deg.view(-1, 1).clamp_(min=1)

(that works fine), I confirm that the issue is in torch.sparse.mm when applied to CUDA tensors.

Aidyn-A (Collaborator) commented Oct 20, 2023

I think there is something wrong with the way sparse CSR matrix is stored. If I do something like this:

torch.sparse.mm(src.to_sparse_csc(), other) / deg.view(-1, 1).clamp_(min=1)

or even this:

torch.sparse.mm(src.to_sparse_csc().to_sparse_csr(), other) / deg.view(-1, 1).clamp_(min=1)

it works fine as well. However, the same trick using to_sparse_coo() does not work.

pearu (Collaborator) commented Oct 20, 2023

> I think there is something wrong with the way sparse CSR matrix is stored.

With the current reports, I think we can only say that something is wrong. :)

Doing

return torch.sparse.mm(src.clone(), other) / deg.view(-1, 1).clamp_(min=1)

also works fine.

@puririshi98, you can use this as a sensible workaround within pytorch_geometric until we fix the bug reported here. UPDATE: A better workaround is

$ git diff
diff --git a/torch_geometric/utils/sparse.py b/torch_geometric/utils/sparse.py
index fa15e547d..2db918108 100644
--- a/torch_geometric/utils/sparse.py
+++ b/torch_geometric/utils/sparse.py
@@ -264,7 +264,7 @@ def to_torch_csr_tensor(
     adj = torch.sparse_csr_tensor(
         crow_indices=index2ptr(edge_index[0], size[0]),
         col_indices=edge_index[1],
-        values=edge_attr,
+        values=edge_attr.contiguous(),
         size=tuple(size) + edge_attr.size()[1:],
         device=edge_index.device,
     )

@pearu pearu self-assigned this Oct 21, 2023
pearu (Collaborator) commented Oct 21, 2023

The issue was caused by src.values() not being a contiguous tensor: it is created in
https://github.com/pyg-team/pytorch_geometric/blob/cd6d1f2b614aa411aa2eb07b96f8387e30b5225b/torch_geometric/utils/sparse.py#L256-257
which produces a tensor with zero strides.
On the other hand, torch.sparse.mm (or any code path that uses CuSparseSpMatCsrDescriptor) expects a CSR tensor's values to be contiguous, per

TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_contiguous());
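The zero-stride situation can be illustrated outside torch with NumPy's analogous stride machinery (a sketch of the general mechanism, not the PyG code path): broadcasting a scalar to a vector yields an array whose elements all alias a single memory location, and such an array is not contiguous until it is materialized into a real buffer, which is what the workarounds above (`edge_attr.contiguous()`, `src.clone()`) accomplish on the torch side.

```python
import numpy as np

# Broadcasting a scalar to shape (5,) gives stride 0: all five elements
# alias the same 4 bytes of memory, so the array is not contiguous.
values = np.broadcast_to(np.float32(1.0), (5,))
assert values.strides == (0,)
assert not values.flags["C_CONTIGUOUS"]

# Materializing a contiguous copy (the NumPy analogue of calling
# .contiguous() or .clone() on the values tensor) yields a real buffer
# with a proper 4-byte element stride.
fixed = np.ascontiguousarray(values)
assert fixed.strides == (4,)
assert fixed.flags["C_CONTIGUOUS"]
```

A kernel that indexes the values buffer assuming dense packing (as the cuSPARSE descriptor path does) will read or write out of bounds when handed the zero-stride view, which matches the illegal-memory-access symptom reported here.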

puririshi98 (Contributor, author) commented:

@pearu thanks for the fix :). How soon do you expect the fix PRs linked above to land? If they could land by the end of this week, it would probably be simpler to have @Aidyn-A cherry-pick the fix into our PyTorch containers instead of changing PyG to use the temporary workaround and then reverting it once the real fix lands.

pearu (Collaborator) commented Oct 23, 2023

@puririshi98 The PR is now approved and is in the process of merging. If CI is all green, the PR should land shortly.

malfet (Contributor) commented Oct 23, 2023

@puririshi98 Can you please clarify whether or not this is a regression? (I.e., does the same sample pass with torch 2.0.0, or does it crash as well?)

@pearu pearu added this to To do in Sparse tensors via automation Oct 23, 2023
pearu (Collaborator) commented Oct 23, 2023

FWIW, from reading pytorch_geometric code, the crash should occur with torch-2.0 as well.

puririshi98 (Contributor, author) commented:

I just confirmed the same issue occurs with torch 2.0.

Sparse tensors automation moved this from To do to Done Oct 23, 2023
andreigh pushed a commit to andreigh/pytorch that referenced this issue Oct 26, 2023
xuhancn pushed a commit to xuhancn/pytorch that referenced this issue Nov 7, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this issue Nov 14, 2023
@atalman atalman added this to the 2.1.2 milestone Nov 22, 2023
@atalman atalman modified the milestones: 2.1.2, 2.2.0 Dec 29, 2023
atalman (Contributor) commented Jan 18, 2024

Validated this issue with final rc:

CUDA_LAUNCH_BLOCKING=1 python3 examples/rev_gnn.py
This will download 1.38GB. Will you proceed? (y/N)
y
Downloading http://snap.stanford.edu/ogb/data/nodeproppred/products.zip
Downloaded 1.38 GB: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1414/1414 [00:41<00:00, 33.99it/s]
Extracting /data/home/atalman/temp33/pytorch_geometric/examples/../data/products/products.zip
Processing...
Loading necessary files...
This might take a while.

Processing graphs...
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.62s/it]
Converting graphs into PyG objects...
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 90.46it/s]
Saving...
Done!
Training epoch: 001:   0%|                                                                                                                                                          | 0/10 [00:00<?, ?it/s]/data/home/atalman/miniconda3/envs/jan12/lib/python3.11/site-packages/torch_geometric/utils/sparse.py:264: UserWarning: Sparse CSR tensor support is in beta state. If you miss a functionality in the sparse tensor support, please submit a feature request to https://github.com/pytorch/pytorch/issues. (Triggered internally at ../aten/src/ATen/SparseCsrTensorImpl.cpp:53.)
  adj = torch.sparse_csr_tensor(
Training epoch: 001: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:12<00:00,  1.25s/it]
Evaluating epoch: 001: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:15<00:00, 15.49s/it]
Loss: 2.8235, Train: 0.3085, Val: 0.3084, Test: 0.2695
Training epoch: 002: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.21it/s]
Evaluating epoch: 002: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:15<00:00, 15.66s/it]
Loss: 2.0747, Train: 0.6023, Val: 0.5991, Test: 0.4749
Training epoch: 003: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.19it/s]
Evaluating epoch: 003: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:15<00:00, 15.52s/it]
Loss: 1.5033, Train: 0.7203, Val: 0.7130, Test: 0.5634
Training epoch: 004: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.20it/s]
Evaluating epoch: 004: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:15<00:00, 15.76s/it]
Loss: 1.1805, Train: 0.7664, Val: 0.7614, Test: 0.6023
Training epoch: 005: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.21it/s]
Evaluating epoch: 005: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:15<00:00, 15.61s/it]
Loss: 0.9739, Train: 0.8197, Val: 0.8153, Test: 0.6498
Training epoch: 006: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.19it/s]
Evaluating epoch: 006: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:15<00:00, 15.51s/it]
Loss: 0.8369, Train: 0.8490, Val: 0.8451, Test: 0.6818
Training epoch: 007: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.19it/s]
Evaluating epoch: 007: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:15<00:00, 15.73s/it]
Loss: 0.7415, Train: 0.8661, Val: 0.8641, Test: 0.7089
Training epoch: 008: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.20it/s]
Evaluating epoch: 008: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:15<00:00, 15.60s/it]
Loss: 0.6818, Train: 0.8762, Val: 0.8746, Test: 0.7210
Training epoch: 009: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.22it/s]
Evaluating epoch: 009: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:15<00:00, 15.76s/it]
Loss: 0.6383, Train: 0.8815, Val: 0.8809, Test: 0.7289
Training epoch: 010: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.21it/s]
Evaluating epoch: 010: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:15<00:00, 15.78s/it]
Loss: 0.6057, Train: 0.8853, Val: 0.8824, Test: 0.7359
Training epoch: 011: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.21it/s]
Evaluating epoch: 011: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:16<00:00, 16.15s/it]
Loss: 0.5810, Train: 0.8874, Val: 0.8847, Test: 0.7393
Training epoch: 012: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.20it/s]
Evaluating epoch: 012: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:15<00:00, 15.65s/it]
Loss: 0.5637, Train: 0.8909, Val: 0.8876, Test: 0.7413
Training epoch: 013: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.20it/s]
Evaluating epoch: 013: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:15<00:00, 15.82s/it]
Loss: 0.5481, Train: 0.8927, Val: 0.8901, Test: 0.7453
Training epoch: 014: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:09<00:00,  1.10it/s]
Evaluating epoch: 014: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:16<00:00, 16.25s/it]
Loss: 0.5391, Train: 0.8938, Val: 0.8895, Test: 0.7435
Training epoch: 015: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.19it/s]
Evaluating epoch: 015:   0%|                    
