nn.Conv3d is not accelerated with tensorcores (using autocast/AMP) #57115

Closed
FabianIsensee opened this issue Apr 28, 2021 · 7 comments
Labels
high priority · module: binaries (Anything related to official binaries that we release to users) · module: cuda (Related to torch.cuda, and CUDA support in general) · module: cudnn (Related to torch.backends.cudnn, and cuDNN support) · module: performance (Issues related to performance, either of kernel code or framework glue) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@FabianIsensee

FabianIsensee commented Apr 28, 2021

🐛 Bug

When installing either the current nightly or version 1.8.1 via pip/conda, the Tensor Core acceleration of nn.Conv3d on Nvidia GPUs does not work. However, when compiling pytorch from source it works as intended and gives a ~3x speedup relative to regular fp32 training. Given that compiling from source is something not all users are able (or willing) to do, it would be nice if this were fixed in the pip/conda packages.

To Reproduce

Steps to reproduce the behavior:

  1. Run the following python script on different GPUs and pytorch versions.
  2. The script prints two numbers to stdout: the seconds per iteration for mixed precision and for fp32, respectively. So if the output is (0.2, 0.5), mixed precision took 0.2 s/iter whereas fp32 took 0.5 s/iter. Lower is better.
from time import time
from torch import nn
import torch
from torch.backends import cudnn
import numpy as np
from torch.cuda.amp import GradScaler, autocast


def run_timed_iterations_fp32(n_steps, batch, gt, loss, optimizer, model, n_warmup):
    # Plain fp32 training loop; returns the mean wall-clock time per iteration.
    for n in range(n_warmup):
        optimizer.zero_grad()
        out = model(batch)
        l = loss(out, gt)
        l.backward()
        optimizer.step()

    times = []
    for n in range(n_steps):
        st = time()

        optimizer.zero_grad()
        out = model(batch)
        l = loss(out, gt)
        l.backward()
        optimizer.step()
        times.append(time() - st)
    return np.mean(times)


def run_timed_iterations_fp16(n_steps, batch, gt, loss, optimizer, model, n_warmup):
    # Mixed-precision (autocast + GradScaler) training loop; returns the mean
    # wall-clock time per iteration.
    scaler = GradScaler()

    for n in range(n_warmup):
        optimizer.zero_grad()
        with autocast():
            out = model(batch)
            l = loss(out, gt)
        scaler.scale(l).backward()
        scaler.step(optimizer)
        scaler.update()

    times = []
    for n in range(n_steps):
        st = time()

        optimizer.zero_grad()
        with autocast():
            out = model(batch)
            l = loss(out, gt)
        scaler.scale(l).backward()
        scaler.step(optimizer)
        scaler.update()
        times.append(time() - st)
    return np.mean(times)


class VGG3D(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Sequential(nn.Conv3d(3, 32, 3, 1, 1, bias=False),
                                   nn.BatchNorm3d(32),
                                   nn.ReLU(True))
        self.conv2 = nn.Sequential(nn.Conv3d(32, 32, 3, 1, 1, bias=False),
                                   nn.BatchNorm3d(32),
                                   nn.ReLU(True))

        self.conv3 = nn.Sequential(nn.Conv3d(32, 64, 3, 2, 1, bias=False),
                                   nn.BatchNorm3d(64),
                                   nn.ReLU(True))
        self.conv4 = nn.Sequential(nn.Conv3d(64, 64, 3, 1, 1, bias=False),
                                   nn.BatchNorm3d(64),
                                   nn.ReLU(True))

        self.conv5 = nn.Sequential(nn.Conv3d(64, 128, 3, 2, 1, bias=False),
                                   nn.BatchNorm3d(128),
                                   nn.ReLU(True))
        self.conv6 = nn.Sequential(nn.Conv3d(128, 128, 3, 1, 1, bias=False),
                                   nn.BatchNorm3d(128),
                                   nn.ReLU(True))
        self.conv7 = nn.Sequential(nn.Conv3d(128, 128, 3, 1, 1, bias=False),
                                   nn.BatchNorm3d(128),
                                   nn.ReLU(True))

        self.conv8 = nn.Sequential(nn.Conv3d(128, 256, 3, 2, 1, bias=False),
                                   nn.BatchNorm3d(256),
                                   nn.ReLU(True))
        self.conv9 = nn.Sequential(nn.Conv3d(256, 256, 3, 1, 1, bias=False),
                                   nn.BatchNorm3d(256),
                                   nn.ReLU(True))
        self.conv10 = nn.Sequential(nn.Conv3d(256, 256, 3, 1, 1, bias=False),
                                    nn.BatchNorm3d(256),
                                    nn.ReLU(True))

        self.conv11 = nn.Sequential(nn.Conv3d(256, 512, 3, 2, 1, bias=False),
                                    nn.BatchNorm3d(512),
                                    nn.ReLU(True))
        self.conv12 = nn.Sequential(nn.Conv3d(512, 512, 3, 1, 1, bias=False),
                                    nn.BatchNorm3d(512),
                                    nn.ReLU(True))
        self.conv13 = nn.Sequential(nn.Conv3d(512, 512, 3, 1, 1, bias=False),
                                    nn.BatchNorm3d(512),
                                    nn.ReLU(True))

        self.conv14 = nn.Sequential(nn.Conv3d(512, 512, 3, 2, 1, bias=False),
                                    nn.BatchNorm3d(512),
                                    nn.ReLU(True))
        self.conv15 = nn.Sequential(nn.Conv3d(512, 512, 3, 1, 1, bias=False),
                                    nn.BatchNorm3d(512),
                                    nn.ReLU(True))
        self.conv16 = nn.Sequential(nn.Conv3d(512, 512, 3, 1, 1, bias=False),
                                    nn.BatchNorm3d(512),
                                    nn.ReLU(True))

        self.gap = nn.AdaptiveAvgPool3d(output_size=1)
        self.classifier = nn.Conv3d(512, 10, 1, 1, 0, bias=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        x = self.conv8(x)
        x = self.conv9(x)
        x = self.conv10(x)
        x = self.conv11(x)
        x = self.conv12(x)
        x = self.conv13(x)
        x = self.conv14(x)
        x = self.conv15(x)
        x = self.conv16(x)
        x = self.gap(x)
        out = self.classifier(x).squeeze()
        return out


if __name__ == '__main__':
    net = VGG3D().cuda()
    data = torch.rand((4, 3, 128, 128, 128)).cuda()
    gt = torch.randint(10, (4, )).cuda()

    loss = nn.CrossEntropyLoss()

    cudnn.benchmark = True
    torch.cuda.empty_cache()

    optim = torch.optim.SGD(net.parameters(), 0.01)

    ret16 = run_timed_iterations_fp16(50, data, gt, loss, optim, net, 10)
    torch.cuda.empty_cache()
    ret32 = run_timed_iterations_fp32(50, data, gt, loss, optim, net, 10)
    print(ret16, ret32)
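
A note on the timing itself: CUDA kernels are launched asynchronously, so wrapping the python calls in time() can misattribute where the time is actually spent. A more robust per-step timer would use CUDA events; the sketch below shows the idea (the helper name timed_step is purely illustrative and not part of the script above). Averaged over 50 iterations the wall-clock numbers reported in this issue should still be a reasonable proxy, but the event-based version removes that ambiguity.

import torch

def timed_step(model, batch, gt, loss, optimizer):
    # CUDA events measure time on the GPU timeline instead of the host clock
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    optimizer.zero_grad()
    out = model(batch)
    l = loss(out, gt)
    l.backward()
    optimizer.step()
    end.record()

    # elapsed_time() is only valid once both events have completed
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000.0  # ms -> s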

Expected behavior

fp16/mixed precision should be ~3x faster than fp32. This number comes from my own experiments with pytorch versions that I compiled myself. I tested both Turing and Ampere GPUs. Here are my results:

pytorch 1.7.1 + cuDNN 8.1.0.77; compiled myself
3090: 0.18401368618011474 0.5708892440795899
2080ti: 0.2557570123672485 0.6653153610229492

pytorch 1.8.1 + cuDNN 8.2.0.53; compiled myself
3090: 0.1797804880142212 0.5501960325241089
2080ti: 0.21456592559814452 0.7236361074447631

pytorch '1.9.0.dev20210427+cu111' + cuDNN "8005"; installed with pip
3090: 0.685448751449585 0.7036943531036377
2080ti: 0.9805365180969239 0.7098741006851196

As you can see, both pytorch 1.7.1 + cuDNN 8.1.0.77 and pytorch 1.8.1 + cuDNN 8.2.0.53 have the expected speedup when using mixed precision. pytorch '1.9.0.dev20210427+cu111' + cuDNN "8005" does not.

(I no longer have a self-compiled pytorch build with cuDNN 8005, but I used to, and I know that it worked at least on Turing.)
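
One way to check whether Tensor Core kernels are actually being picked (a sketch, not part of the report above; the layer and input sizes here are arbitrary) is to profile a single autocast forward/backward pass and look at the kernel names cuDNN dispatches:

import torch
from torch.cuda.amp import autocast

# Profile one mixed-precision Conv3d forward/backward; Tensor Core kernels
# typically show up with half-precision names (e.g. containing "884"/"1688").
conv = torch.nn.Conv3d(32, 64, 3, padding=1, bias=False).cuda()
x = torch.rand(2, 32, 64, 64, 64, device="cuda", requires_grad=True)

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    with autocast():
        out = conv(x)
    out.sum().backward()
    torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))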

Environment

(This is the RTX 3090 system. If you need info for the RTX 2080ti system as well, let me know.)

Collecting environment information...
PyTorch version: 1.8.0a0+56b43f4
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.19.6

Python version: 3.9 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 11.3.58
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 465.19.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] efficientnet-pytorch==0.7.1
[pip3] numpy==1.20.1
[pip3] torch==1.8.0a0+56b43f4
[pip3] torchvision==0.9.1
[conda] blas 1.0 mkl
[conda] efficientnet-pytorch 0.7.1 pypi_0 pypi
[conda] magma-cuda112 2.5.2 1 pytorch
[conda] mkl 2021.2.0 h06a4308_296
[conda] mkl-include 2021.2.0 h06a4308_296
[conda] mkl-service 2.3.0 py39h27cfd23_1
[conda] mkl_fft 1.3.0 py39h42c9631_2
[conda] mkl_random 1.2.1 py39ha9443f7_2
[conda] numpy 1.20.1 py39h93e21f0_0
[conda] numpy-base 1.20.1 py39h7d8b39e_0
[conda] torch 1.8.0a0+56b43f4 pypi_0 pypi
[conda] torchvision 0.9.1 pypi_0 pypi

Thank you!
Best,
Fabian

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @seemethere @malfet @walterddr @ngimel @csarofeen @ptrblck @xwang233 @VitalyFedyunin

@zou3519 zou3519 added the module: binaries, module: cuda, module: performance, triaged, and high priority labels on Apr 28, 2021
@zou3519
Contributor

zou3519 commented Apr 28, 2021

Tentatively marking as hi-pri because it sounds like we have incorrect behavior (if the source build of pytorch utilizes tensor cores, the nightly should as well). Unless this has something to do with the cuda versions we build binaries for?

@FabianIsensee
Author

The 2080ti system I tested only has cuda 11.0. Binaries are provided with cuda 11.1 (those are the ones I used), so that's probably not it. Thanks for looking into this!

@ngimel ngimel added the module: cudnn label on Apr 28, 2021
@ptrblck
Collaborator

ptrblck commented Apr 28, 2021

Thanks for creating this issue. You are most likely hitting this bug.

@FabianIsensee
Author

That could indeed be the case. I see the same behavior with nn.Conv2d (not sure why I thought this was a Conv3d problem):

nn.Conv2d instead of Conv3d, RTX 3090 GPU:

1.8.0a0+56b43f4 8200 (built from source)
0.1864670705795288 0.28574702739715574

1.9.0.dev20210427+cu111 8005 (binary, pip install)
0.27138203144073486 0.30072999477386475

It's interesting that the problem only affects mixed-precision training; fp32 has almost the same speed when I compare source vs. binary.
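
The exact 2D script is not included in the issue; the substitution meant is roughly the following sketch, with every 3D op replaced by its 2D counterpart and a 4D input tensor (the batch size and resolution in the comment are only a guess and would differ from the 3D case):

import torch
from torch import nn

def block2d(c_in, c_out, stride=1):
    # 2D counterpart of the Conv3d/BatchNorm3d/ReLU blocks used above
    return nn.Sequential(nn.Conv2d(c_in, c_out, 3, stride, 1, bias=False),
                         nn.BatchNorm2d(c_out),
                         nn.ReLU(True))

class VGG2D(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            block2d(3, 32), block2d(32, 32),
            block2d(32, 64, 2), block2d(64, 64),
            block2d(64, 128, 2), block2d(128, 128), block2d(128, 128),
            block2d(128, 256, 2), block2d(256, 256), block2d(256, 256),
            block2d(256, 512, 2), block2d(512, 512), block2d(512, 512),
            block2d(512, 512, 2), block2d(512, 512), block2d(512, 512),
        )
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)
        self.classifier = nn.Conv2d(512, 10, 1, 1, 0, bias=True)

    def forward(self, x):
        x = self.features(x)
        x = self.gap(x)
        return self.classifier(x).squeeze()

# e.g. data = torch.rand((32, 3, 512, 512)).cuda() in place of the 5D input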

@karthikkrishnan81

Is this fixed?

@FabianIsensee
Author

RTX 3090, pip install, 3D network

In [2]: torch.backends.cudnn.version()
Out[2]: 8005
In [3]: torch.__version__
Out[3]: '1.9.0+cu111'

0.2136370611190796 0.5353927850723267

RTX 3090, compiled myself, 3D network

In [2]: torch.backends.cudnn.version()
Out[2]: 8200
In [3]: torch.__version__
Out[3]: '1.8.0a0+56b43f4'

0.17133057594299317 0.5131863641738892

Looks like it's fixed, but best test it yourself as well. Note that the binaries were always slower than what I compiled myself.
Best,
Fabian
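
For anyone re-running this check, the two values quoted above can be printed with plain Python (equivalent to the IPython session shown):

import torch

# prints the installed pytorch build and the cuDNN version it ships with
print(torch.__version__, torch.backends.cudnn.version())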

@karthikkrishnan81

@FabianIsensee Indeed, you are right. I can verify this:

$ python3 verify.py
0.385063009262085 1.1304261589050293

CUDA version: 11.3, GeForce 3060
Torch version 1.10.0.dev20210717+cu113
torch.backends.cudnn.version() = 8200

I suppose you can remove the requirement to build pytorch from source from your nnUNet page.
As a final note, a big thank you for nnUNet!
