
Multi-gpu example freeze and is not killable #24081

Open · Dubrzr opened this issue Aug 9, 2019 · 54 comments
Labels
- has workaround
- module: cuda (Related to torch.cuda, and CUDA support in general)
- module: data parallel
- module: deadlock (Problems related to deadlocks, i.e. hangs without exiting)
- module: dependency bug (Problem is not caused by us, but caused by an upstream library we use)
- module: multi-gpu (Problem is related to running on multiple GPUs)
- module: multiprocessing (Related to torch.multiprocessing)
- quansight-nack (High-prio issues that have been reviewed by Quansight and are judged to be not actionable)
- triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


Dubrzr commented Aug 9, 2019

🐛 Bug

Running PyTorch with multiple P40 GPUs freezes, and the process is not killable (not even with kill -9 as root). Only a reboot removes the process.

Inside a Docker container (with nvidia-docker2) it freezes Docker. NVIDIA/nvidia-docker#1010

To Reproduce

Steps to reproduce the behavior:

  1. Install pytorch 1.0.2
  2. Run the following code on multiple P40 GPUs
import os

# Tutorial from https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
# No error with only 1 GPU:
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# To reproduce the error, allow multiple GPUs:
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'

import torch

torch.cuda.device_count()

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Parameters and DataLoaders
input_size = 5000 #increased input size (works with 500 on multi gpu)
output_size = 2000 #increased output size (works with 200 on multi gpu)

batch_size = 300
data_size = 100

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),
                         batch_size=batch_size, shuffle=True)

class Model(nn.Module):
    # Our model

    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        output = self.fc(input)
        return output

model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  model = nn.DataParallel(model)

model.to(device)

for i in range(10000):
    for data in rand_loader:
        input = data.to(device)
        output = model(input)

Expected behavior

The training loop runs to completion without hanging.

Environment

Collecting environment information...
PyTorch version: 1.0.1.post2
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.2 LTS
GCC version: (crosstool-NG fa8859cb) 7.2.0
CMake version: Could not collect

Python version: 3.5
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: Tesla P40
GPU 1: Tesla P40
GPU 2: Tesla P40
GPU 3: Tesla P40
GPU 4: Tesla P40
GPU 5: Tesla P40
GPU 6: Tesla P40
GPU 7: Tesla P40

Nvidia driver version: 410.79
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.4.2

Versions of relevant libraries:
[pip3] numpy==1.15.2
[conda] mkl 2018.0.3 1 defaults
[conda] mkl_fft 1.0.6 py35_0 conda-forge
[conda] mkl_random 1.0.1 py35_0 conda-forge
[conda] nomkl 2.0 0 defaults
[conda] numexpr 2.6.5 py35_nomklhaa809a4_0 [nomkl] defaults
[conda] pytorch 1.0.1 py3.5_cuda10.0.130_cudnn7.4.2_2 pytorch
[conda] torch 0.4.1
[conda] torchvision 0.2.2 py_3 pytorch

cc @ezyang @gchanan @zou3519 @ngimel


Dubrzr commented Aug 9, 2019

I've just tested with PyTorch 1.2.0; it freezes in the same way.

@ezyang ezyang added module: data parallel needs reproduction Someone else needs to try reproducing the issue given the instructions. No action needed from user module: deadlock Problems related to deadlocks (hang without exiting) triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Aug 9, 2019

ezyang commented Aug 9, 2019

If someone wants to help us debug this, trying to reproduce the issue would be helpful; it would also be helpful to know whether the problem goes away if you (1) upgrade your CUDA version and (2) run on different GPUs.

BTW, I notice in your environment, you have

[conda] pytorch 1.0.1 py3.5_cuda10.0.130_cudnn7.4.2_2 pytorch
[conda] torch 0.4.1 

You should delete the old torch 0.4.1 at some point.


rrkarim commented Aug 9, 2019

Same on PyTorch 1.1.0 and two GTX 1080 Ti cards. Hard to say; maybe the issue is related to the setup. Need more responses.

@ezyang ezyang added critical triage review high priority and removed needs reproduction Someone else needs to try reproducing the issue given the instructions. No action needed from user critical labels Aug 9, 2019

ezyang commented Aug 9, 2019

Upgrading priority

@raethlein

We face a similar issue.

The environment below is with a single GPU mounted into the container (the setup currently being tested); before, 2 GPUs were mounted in the container. When some GPU-specific code was run, such as MyDecoder().to(device), the Jupyter notebook server became completely unresponsive when we tried to interrupt or restart the kernel, and docker stop / docker rm no longer worked.
The only way to remove the container was to kill the container process, which left zombie and orphaned processes blocking the GPU; that can only be fixed by a system restart.
The last log output seems to be ... NotebookApp] Starting buffering for <kernel-id>.

I am not completely sure whether this is the same issue or whether we have a hardware failure; it just sounded similar enough that I wanted to provide some information. I hope it helps.

Environment
Collecting environment information...
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: version 3.14.0

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: Tesla P100-PCIE-16GB
Nvidia driver version: 418.67
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.0

Versions of relevant libraries:
[pip3] gpytorch==0.3.4
[pip3] msgpack-numpy==0.4.4.3
[pip3] numpy==1.16.4
[pip3] pytorch-ignite==0.2.0
[pip3] pytorch-nlp==0.4.1
[pip3] pytorch-pretrained-bert==0.6.2
[pip3] torch==1.1.0
[pip3] torch-cluster==1.4.3
[pip3] torch-geometric==1.3.0
[pip3] torch-scatter==1.3.1
[pip3] torch-sparse==0.4.0
[pip3] torchbiggraph==1.dev1
[pip3] torchfile==0.1.0
[pip3] torchstat==0.0.7
[pip3] torchtext==0.3.1
[pip3] torchvision==0.3.0
[conda] _tflow_select 2.3.0 mkl defaults
[conda] blas 1.0 mkl defaults
[conda] faiss-cpu 1.5.3 py36h6bb024c_0 pytorch
[conda] gpytorch 0.3.4 pypi_0 pypi
[conda] mkl 2019.4 243 defaults
[conda] mkl-include 2019.4 243 defaults
[conda] mkl-service 2.0.2 py36h7b6447c_0 defaults
[conda] mkl_fft 1.0.12 py36ha843d7b_0 defaults
[conda] mkl_random 1.0.2 py36hd81dba3_0 defaults
[conda] pytorch 1.1.0 py3.6_cuda10.0.130_cudnn7.5.1_0 pytorch
[conda] pytorch-cpu 1.1.0 py3.6_cpu_0 pytorch
[conda] pytorch-ignite 0.2.0 pypi_0 pypi
[conda] pytorch-nlp 0.4.1 pypi_0 pypi
[conda] pytorch-pretrained-bert 0.6.2 pypi_0 pypi
[conda] tensorflow 1.14.0 mkl_py36h2526735_0 defaults
[conda] torch-cluster 1.4.3 pypi_0 pypi
[conda] torch-geometric 1.3.0 pypi_0 pypi
[conda] torch-scatter 1.3.1 pypi_0 pypi
[conda] torch-sparse 0.4.0 pypi_0 pypi
[conda] torchbiggraph 1.dev1 pypi_0 pypi
[conda] torchfile 0.1.0 pypi_0 pypi
[conda] torchstat 0.0.7 pypi_0 pypi
[conda] torchtext 0.3.1 pypi_0 pypi
[conda] torchvision 0.3.0 py36_cu10.0.130_1 pytorch
[conda] torchvision-cpu 0.3.0 py36_cuNone_1 pytorch


Dubrzr commented Aug 19, 2019

I've run strace on this script: https://gist.github.com/Dubrzr/b058b54947b2688b7e02e64f6bdf78b8

It froze after these lines:

ioctl(4, _IOC(_IOC_READ|_IOC_WRITE, 0x46, 0x2a, 0x20), 0x7ffe185b07d0) = 0
futex(0x7fe8ba7a5e60, FUTEX_WAKE_PRIVATE, 2147483647) = 0
futex(0x7fe8ba7c1ff8, FUTEX_WAKE_PRIVATE, 2147483647) = 0
futex(0x55735fbc5450, FUTEX_WAKE_PRIVATE, 2147483647) = 0
futex(0x55735fbc5454, FUTEX_WAKE_PRIVATE, 2147483647) = 0
futex(0x55735fbc5458, FUTEX_WAKE_PRIVATE, 2147483647) = 0
futex(0x55735fbc545c, FUTEX_WAKE_PRIVATE, 2147483647) = 0
futex(0x7fe932502424, FUTEX_WAKE_PRIVATE, 2147483647) = 0
futex(0x7fe92f71e520, FUTEX_WAKE_PRIVATE, 2147483647) = 0
futex(0x7fe8f77a62e8, FUTEX_WAKE_PRIVATE, 2147483647) = 0
pipe2([12, 13], O_CLOEXEC)              = 0
fcntl(12, F_SETFL, O_RDONLY|O_NONBLOCK) = 0
clone(child_stack=0x7fe8cdef2fb0, flags=CLONE_VM|CLONE_FS|CLONE_FILES|CLONE_SIGHAND|CLONE_THREAD|CLONE_SYSVSEM|CLONE_SETTLS|CLONE_PARENT_SETTID|CLONE_CHILD_CLEARTID, parent_tidptr=0x7fe8cdef39d0, tls=0x7fe8cdef3700, child_tidptr=0x7fe8cdef39d0) = 47266
ioctl(4, _IOC(_IOC_READ|_IOC_WRITE, 0x46, 0x2a, 0x20), 0x7ffe185af8d0) = 0
ioctl(5, _IOC(0, 0x00, 0x19, 0x00)

Hope this will help you :)

@VitalyFedyunin VitalyFedyunin added module: multi-gpu Problem is related to running on multiple GPUs module: cuda Related to torch.cuda, and CUDA support in general module: multiprocessing Related to torch.multiprocessing labels Aug 19, 2019

rrkarim commented Aug 22, 2019

Is this one related: #1637? The solution: #1637 (comment)


izdeby commented Aug 26, 2019

Keeping high pri as investigation is required.


jpellman commented Sep 9, 2019

I'm encountering the same issue as well: not even SIGKILL will stop a script using PyTorch. Honestly, if SIGKILL isn't working, I doubt that this issue is specific to PyTorch. I'd be more inclined to believe that the NVIDIA driver / tainted Linux kernel is at fault; anything outside kernel space really shouldn't be able to interfere with SIGKILL. Similar issues have occurred in the past with nvidia-smi (see here, here, here, and here).

[edit]: fastai also seemed to exhibit a similar problem around October, 2018 (see here)
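
As a quick check of the above, here is a minimal sketch (Linux /proc only; the PID is a hypothetical placeholder) that tells you whether the hung process is in uninterruptible 'D' sleep, the one state SIGKILL cannot interrupt:

# Sketch (Linux only): check whether a hung process sits in uninterruptible
# 'D' sleep, which would explain why SIGKILL has no effect.
pid = 12345  # hypothetical placeholder; substitute the PID of the stuck PyTorch process

with open("/proc/%d/status" % pid) as f:
    for line in f:
        if line.startswith("State:"):
            # "State:\tD (disk sleep)"            -> blocked in the kernel, unkillable
            # "State:\tR (running)" / "S (sleeping)" -> userspace, SIGKILL should work
            print(line.strip())
            break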

For added context, my strace output is similar to @Dubrzr's:

futex(0x7f6f6f90c1a8, FUTEX_WAKE_PRIVATE, 2147483647) = 0
pipe2([32, 33], O_CLOEXEC)              = 0
fcntl(32, F_SETFL, O_RDONLY|O_NONBLOCK) = 0
fcntl(33, F_SETFL, O_RDONLY|O_NONBLOCK) = 0
mmap(NULL, 8392704, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS|MAP_STACK, -1, 0) = 0x7f6ef5ef9000
mprotect(0x7f6ef5ef9000, 4096, PROT_NONE) = 0
clone(child_stack=0x7f6ef66f8fb0, flags=CLONE_VM|CLONE_FS|CLONE_FILES|CLONE_SIGHAND|CLONE_THREAD|CLONE_SYSVSEM|CLONE_SETTLS|CLONE_PARENT_SETTID|CLONE_CHILD_CLEARTID, parent_tidptr=0x7f6ef66f99d0, tls=0x7f6ef66f9700, child_tidptr=0x7f6ef66f99d0) = 26182
futex(0x560c1db6f668, FUTEX_WAKE_PRIVATE, 1) = 1
ioctl(4, _IOC(_IOC_READ|_IOC_WRITE, 0x46, 0x2a, 0x20), 0x7ffe49a26910) = 0
ioctl(5, _IOC(0, 0x00, 0x19, 0x00)

The NVIDIA driver version is 418.87.00, cuda is 10.0.130, cudnn is 7.6.1.34-10.0, and PyTorch is 1.1.0. I'm not the chief developer of the script producing the behavior so I don't know many details about how PyTorch is being used.


jpellman commented Sep 9, 2019

For quick reference, here's a set of tallies for some of the relevant software/version pairings:

NVIDIA Driver Summary Count:

Version      Count
410.79       1
418.67       1
418.87       2

CUDA Summary Count:

Version      Count
10.0.130     3
9.2.88       1

cuDNN Summary Count:

Version      Count
7.4.2        1
7.6.0        1
7.6.1        2

PyTorch Summary Count:

Version      Count
1.0.1.post2  1
1.2.0        1
1.1.0        4

Based on these frequencies, CUDA 10.0.130 is the most common factor.

@jpellman

I was able to replicate this issue using an environment identical to my earlier one except for the CUDA version, which was 9.2.88.


jpellman commented Sep 10, 2019

As another data point, the version of fastai exhibiting a similar issue that I linked to in my earlier comment would also have used CUDA 9.2 (based on a contemporary copy of the README here).


Dubrzr commented Sep 12, 2019

@ezyang @gchanan
Is there a way to see what cuda calls are made during the execution of this pytorch code?
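
(Not an answer from the thread, but for operator-level visibility, i.e. which PyTorch ops dispatch CUDA kernels rather than the raw driver ioctls, the autograd profiler could be used. A minimal sketch, assuming the model, rand_loader and device objects from the repro script above:)

import torch

# Sketch: profile one forward pass of the repro script to see which operators
# dispatch CUDA kernels. This gives op-level visibility, not the underlying
# CUDA driver ioctl calls.
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    for data in rand_loader:
        output = model(data.to(device))
        break  # one batch is enough for a trace

# If the freeze reproduces, this line is never reached, which at least
# localizes the hang to the forward pass above.
print(prof.key_averages().table(sort_by="cuda_time_total"))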


jpellman commented Sep 12, 2019

Well, the syscall where PyTorch is getting stuck is in our strace output already:

ioctl(5, _IOC(0, 0x00, 0x19, 0x00)

ioctl is used for performing device-specific I/O operations that fall outside the universal file I/O calls; it is documented here.

File descriptor 5 was pointing at /dev/nvidia-uvm on the host I'm working with. This is the unified virtual memory module, which seems to be described in more detail here and here.

Based on section 6.1.1 here:

0 is the direction, defined as _IOC_NONE, which means no data transfer is occurring.
0x00 is the magic number, defined here; NVIDIA probably does not care about the magic number.
0x19 is the number, 25 in base 10.
0x00 is the size; in this case it seems to be ignored, since no data transfer is occurring.

The number (25) seems to correspond to the following operation (from the MIT-licensed UVM source code- located at /usr/src/nvidia-*/nvidia-uvm/uvm_ioctl.h on a Linux install):

#define UVM_REGISTER_GPU_VASPACE                                      UVM_IOCTL_BASE(25)

typedef struct
{
    NvProcessorUuid gpuUuid;  // IN
    NvS32           rmCtrlFd; // IN
    NvHandle        hClient;  // IN
    NvHandle        hVaSpace; // IN
    NV_STATUS       rmStatus; // OUT
} UVM_REGISTER_GPU_VASPACE_PARAMS;

As near as I can tell, this means the hang happens while registering a GPU virtual address space with the UVM driver, i.e. while setting up GPU memory management. I'm not a kernel programmer (and I'm definitely not intimately familiar with GPU programming beyond this), so take all of what I just said with a grain of salt.
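
To double-check the field breakdown above, here is a small sketch that decodes a raw ioctl request number using the standard Linux asm-generic _IOC bit layout (a generic kernel convention, not anything NVIDIA-specific):

# Sketch: decode a raw Linux ioctl request number into the _IOC fields
# discussed above, using the standard asm-generic bit layout
# (nr: 8 bits, type: 8 bits, size: 14 bits, dir: 2 bits).
_IOC_NRSHIFT, _IOC_TYPESHIFT, _IOC_SIZESHIFT, _IOC_DIRSHIFT = 0, 8, 16, 30

def decode_ioctl(request):
    return {
        "dir":  (request >> _IOC_DIRSHIFT) & 0x3,
        "type": (request >> _IOC_TYPESHIFT) & 0xFF,
        "nr":   (request >> _IOC_NRSHIFT) & 0xFF,
        "size": (request >> _IOC_SIZESHIFT) & 0x3FFF,
    }

# The request from the strace output, _IOC(0, 0x00, 0x19, 0x00), is 0x19:
print(decode_ioctl(0x19))  # dir=0, type=0, nr=25 (UVM_IOCTL_BASE(25)), size=0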


jpellman commented Sep 12, 2019

As yet another bit of info, I ran memtestG80 on each of the GPUs on my system. Pretty much all GPUs were fine except for the first one (index 0). On this one, memtestG80 hangs. When I run strace on it, I get some familiar output:

mmap(NULL, 8392704, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS|MAP_STACK, -1, 0) = 0x7fcf7c7e8000
mprotect(0x7fcf7c7e8000, 4096, PROT_NONE) = 0
clone(child_stack=0x7fcf7cfe7fb0, flags=CLONE_VM|CLONE_FS|CLONE_FILES|CLONE_SIGHAND|CLONE_THREAD|CLONE_SYSVSEM|CLONE_SETTLS|CLONE_PARENT_SETTID|CLONE_CHILD_CLEARTID, parent_tidptr=0x7fcf7cfe89d0, tls=0x7fcf7cfe8700, child_tidptr=0x7fcf7cfe89d0) = 17099
futex(0x8e21d8, FUTEX_WAKE_PRIVATE, 1)  = 1
ioctl(3, _IOC(_IOC_READ|_IOC_WRITE, 0x46, 0x2a, 0x20), 0x7fff4a589d40) = 0
ioctl(4, _IOC(0, 0x00, 0x19, 0x00)

File descriptor 4, in this case, is also /dev/nvidia-uvm.

After a reboot, memtestG80 runs fine on this GPU. Then I run the PyTorch program and it gets stuck again at the same syscall. Consequently, if I run memtestG80 again on GPU 0 after PyTorch gets stuck, it also gets stuck at ioctl.


jpellman commented Sep 12, 2019

I ran pdb on one of the consistently unkillable PyTorch programs. The point at which the script became unkillable was when it ran something similar to model.to(device). The function in the stack where it ultimately got stuck was around here:

        def convert(t):
            return t.to(device, dtype if t.is_floating_point() else None, non_blocking)

Here's the exact output of pdb before it became unkillable:

(Pdb) s
--Call--
> /scratch/nklab/users/lms2308/ananconda/envs/hbox/lib/python3.7/site-packages/torch/nn/modules/module.py(383)convert()
-> def convert(t):
(Pdb) s
> /scratch/nklab/users/lms2308/ananconda/envs/hbox/lib/python3.7/site-packages/torch/nn/modules/module.py(384)convert()
-> return t.to(device, dtype if t.is_floating_point() else None, non_blocking)
(Pdb) s
--Call--
> /scratch/nklab/users/lms2308/ananconda/envs/hbox/lib/python3.7/site-packages/torch/cuda/__init__.py(148)_lazy_init()
-> def _lazy_init():
(Pdb) s
> /scratch/nklab/users/lms2308/ananconda/envs/hbox/lib/python3.7/site-packages/torch/cuda/__init__.py(150)_lazy_init()
-> if _initialized:
(Pdb) s
> /scratch/nklab/users/lms2308/ananconda/envs/hbox/lib/python3.7/site-packages/torch/cuda/__init__.py(151)_lazy_init()
-> return
(Pdb) s
--Return--
> /scratch/nklab/users/lms2308/ananconda/envs/hbox/lib/python3.7/site-packages/torch/cuda/__init__.py(151)_lazy_init()->None
-> return
(Pdb) s
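
Since pdb stops being useful once the process wedges in the kernel, a complementary trick (a sketch, not something used in this report) is to arm faulthandler's watchdog before the suspect call; the dump is written by a helper thread, so it still fires while the main thread is blocked. This assumes the model and device objects from the repro script:

import sys
import faulthandler

# Sketch: arm a watchdog that dumps every Python thread's stack if the process
# is still stuck after 60 seconds. The dump comes from a helper thread, so it
# still works when the main thread is blocked inside a CUDA ioctl.
faulthandler.dump_traceback_later(60, repeat=True, file=sys.stderr, exit=False)

model.to(device)  # the call that wedged in this report (model/device as in the repro)

faulthandler.cancel_dump_traceback_later()  # disarm once we get past it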


rathken commented Nov 4, 2019

From the SuperMicro X11DPG-OT-CPU manual: "Select Enable to program Access Control Services (ACS) to Chipset PCI-E Root Port Bridges. Select Disable to program Access Control Services to all PCI-E Root Port Bridges. The options are Enable and Disable." Are you saying that for this board it doesn't work? Just curious; I have several X10DRG-OT+-CPU boards and it works fine there. BIOS Version: 3.1, Release Date: 07/13/2018.


Dubrzr commented Nov 5, 2019

If it helps, it also freezes with a ProLiant XL270d Gen10 motherboard.
Thanks


rathken commented Nov 5, 2019

Is your ProLiant XL270d Gen10 with the NVLink tray or the PCIe tray?

Is it p2pBandwidthLatencyTest that is hanging? I think this is no longer really a PyTorch issue at this point, but a system setup issue for P2P...


jpellman commented Nov 6, 2019

The "Select Enable to program Access Control Services (ACS)" option only shows up if VT-d is enabled, which it isn't on my system. It would seem logical that since the ACS toggle is nested under the VT-d option (and disappears when VT-d is disabled), the ACS toggle would inherit the VT-d toggle's disabled state. This doesn't seem to be the case, however. This is confirmed by a fresh boot with my configuration (VT-d disabled, ACS shows as enabled when VT-d enabled) indicating that the ACS register values are populated:

(base) [jsp2205@ax08 ~]$ for plx in $(lspci | grep -i plx | awk '{print $1}'); do sudo setpci -s ${plx} f2a.w; done
[sudo] password for jsp2205:
0000
001d
001d
0000
001d
001d
0000
001d
001d
0000
001d
001d

If I go into the UEFI/BIOS, temporarily enable VT-d to make the ACS toggle visible, disable ACS, and then toggle VT-d off again, I get the expected results after reboot:

(base) [jsp2205@ax08 ~]$ for plx in $(lspci | grep -i plx | awk '{print $1}'); do sudo setpci -s ${plx} f2a.w; done
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000

Overall, seems to be a mixture of user error and interface confusion :)

@artemZholus

Same problem. I've tested on a Ryzen 3700 with two GTX 1060 6GB cards. DataParallel freezes at forward when using two GPUs, while behaving normally with one. strace also showed me a deadlock at ioctl:

ioctl(6, _IOC(_IOC_READ|_IOC_WRITE, 0x46, 0x4a, 0xc0), 0x7ffe8b783960) = 0
ioctl(7, _IOC(0, 0, 0x21, 0), 0x7ffe8b783380) = 0
ioctl(6, _IOC(_IOC_READ|_IOC_WRITE, 0x46, 0x2a, 0x20), 0x7ffe8b7836b0) = 0
ioctl(6, _IOC(_IOC_READ|_IOC_WRITE, 0x46, 0x2a, 0x20), 0x7ffe8b7836e0) = 0
ioctl(6, _IOC(_IOC_READ|_IOC_WRITE, 0x46, 0x2a, 0x20), 0x7ffe8b7836c0) = 0
ioctl(6, _IOC(_IOC_READ|_IOC_WRITE, 0x46, 0x2a, 0x20), 0x7ffe8b783520) = 0
ioctl(6, _IOC(_IOC_READ|_IOC_WRITE, 0x46, 0x2b, 0x20), 0x7ffe8b7836b0) = 0
ioctl(7, _IOC(0, 0, 0x1d, 0), 0x7ffe8b783710) = 0
ioctl(7, _IOC(0, 0, 0x21, 0), 0x7ffe8b7829a0) = 0
ioctl(6, _IOC(_IOC_READ|_IOC_WRITE, 0x46, 0x4a, 0xc0), 0x7ffe8b783960) = 0
ioctl(7, _IOC(0, 0, 0x21, 0), 0x7ffe8b783380) = 0
ioctl(7, _IOC(0, 0, 0x21, 0)

I'm using PyTorch==1.3.0, driver version 418.87 (also tried with 435.21 but unsuccessfully), cuda==10.1, cudnn==7.6.5

@artemZholus

I managed to solve this issue in my case by disabling IOMMU in bios.


bes-dev commented Dec 28, 2019

I have the same issue with 2x GTX 1080 Ti + an AMD CPU.
To solve this issue you should (a verification sketch follows these steps):

  1. sudo nano /etc/default/grub
  2. GRUB_CMDLINE_LINUX_DEFAULT="amd_iommu=soft"
  3. sudo update-grub
  4. reboot
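
After applying the steps above and rebooting, a quick way to verify the result is to look at /sys/kernel/iommu_groups, which is generally only populated while hardware IOMMU translation is active. A minimal sketch, assuming a Linux host:

import os

# Sketch: after changing the GRUB options above and rebooting, check whether
# the kernel still has IOMMU groups registered. An empty (or missing)
# /sys/kernel/iommu_groups generally means no IOMMU translation is active.
groups_dir = "/sys/kernel/iommu_groups"
groups = os.listdir(groups_dir) if os.path.isdir(groups_dir) else []
print("IOMMU groups found:", len(groups))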


y-x-c commented Jan 8, 2020

I managed to solve this issue in my case by disabling IOMMU in bios.

Works for me!

@rgommers rgommers added the quansight-nack High-prio issues that have been reviewed by Quansight and are judged to be not actionable. label Jan 27, 2020

ezyang commented Jan 27, 2020

I'm downgrading the priority of this issue, as it looks like people have found a workaround and it seems there is not much we can do about it on the PyTorch side. I'll keep it open for visibility, though.


fpoms commented Feb 6, 2020

Can confirm the workaround of disabling IOMMU also works for me.


Dubrzr commented Feb 17, 2020

Disabling the IOMMU can be a security issue: by disabling it, you authorize GPUs to access memory addresses of all other GPUs.

This bug is probably in Nvidia drivers.
@jpellman : do you have any news on your NVBug post?

Thanks.


jpellman commented Feb 19, 2020

DMA attacks aren't (at present) as much of a concern in our environment (the machines have very limited access and aren't publicly accessible), although it's certainly far from an ideal scenario. I closed out the NVBug post a while ago. I'd suggest that you re-open an issue with them and frame this as a security issue.


Dubrzr commented Feb 27, 2020

I've posted an issue; here is the answer from NVIDIA:

Thanks you for reporting. I'm afraid IOMMU is not supported on Linux.
As per https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#iommu-on-linux, On Linux only, CUDA and the display driver does not support IOMMU-enabled bare-metal PCIe peer to peer memory copy. However, CUDA and the display driver does support IOMMU via VM pass through. As a consequence, users on Linux, when running on a native bare metal system, should disable the IOMMU. The IOMMU should be enabled and the VFIO driver be used as a PCIe pass through for virtual machines.
On Windows the above limitation does not exist."

@jpellman

¯\_(ツ)_/¯

Lock up your Linux gamer workstations in a closet, I guess. If you're concerned about a DMA attack from some sort of rogue peripheral, it's probably easier to block peripheral ports in some way, either via epoxy resin on the low-tech end of things or at the firmware/driver level (I believe this is fairly common practice in fintech and some healthcare institutions). It seems that there are a couple of other steps you could take, even if you can't do anything with the IOMMU.


fushuyue commented Mar 3, 2020

I managed to solve this issue in my case by disabling IOMMU in bios.

Works for me. Many thanks!


leolb-aphp commented Mar 3, 2020

@jpellman

¯\_(ツ)_/¯

Lock up your Linux gamer workstations in a closet, I guess. If you're concerned about a DMA attack from some sort of rogue peripheral, it's probably easier to block peripheral ports in some way, either via epoxy resin on the low-tech end of things or at the firmware/driver level (I believe this is fairly common practice in fintech and some healthcare institutions). It seems that there are a couple of other steps you could take, even if you can't do anything with the IOMMU.

The concern is not only about plugging a rogue device. It's also about the device itself failing to respect the host system security policy. We have no information on the NVIDIA GPU internals and how it enforces security constraints to user-submitted GPU workloads.

If not enforced correctly, a user submitted GPU task could write to kernel memory at arbitrary addresses and escalate privileges.

I guess the security minded solution here is to use VMs as the isolation mechanism. A VM for each user, or a VM for several users, depending on the information that can or can't be shared between them per their access rights.


leolb-aphp commented Mar 3, 2020

@Dubrzr

Disabling the IOMMU can be a security issue: by disabling it, you authorize GPUs to access memory addresses of all other GPUs.

This bug is probably in Nvidia drivers.
@jpellman : do you have any news on your NVBug post?

Thanks.

Not all other GPUs, but rather all available main memory on the host system. GPUs have their own separate memory subsystem.

@vikramg1

On my HP Envy dual-boot laptop, I had to disable virtualization in the BIOS settings. Seems to work now.

atalyaalon added a commit to MatanAvitan/openpifpaf that referenced this issue Apr 21, 2020
@jramapuram

Is there a better solution here than disabling IOMMU? Does upgrading CUDA/NCCL help?


nfrumkin commented Nov 22, 2021

Try NCCL_P2P_DISABLE=1 in front of your training script. See: #1637

Example: NCCL_P2P_DISABLE=1 python3 train.py
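
If you would rather keep it inside the script, the same variable can be set from Python; a minimal sketch (setting it at the very top is a precaution so NCCL sees it when it initializes):

import os

# Same effect as prefixing the command with NCCL_P2P_DISABLE=1:
# disable NCCL peer-to-peer transfers for this process.
# Set it before any multi-GPU work so NCCL sees it when it initializes.
os.environ["NCCL_P2P_DISABLE"] = "1"

import torch  # imported after setting the variable, as a precaution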

@oytunturk

Thanks, setting NCCL_P2P_DISABLE=1 solved a similar issue I was having with RTX A5000s.


Borodin commented Jan 26, 2023

I faced the same problem; none of the solutions described here work for me.

I have AMD FX(tm)-4100 Quad-Core Processor
RAM: 8 GB
2x NVIDIA GeForce GTX 1060 6GB
PyTorch 1.13.1
Ubuntu 22.04.1 LTS

Any help is much appreciated!
