[JIT] torch.jit.optimized_execution(True) greatly slows down some operations in PyTorch 1.8.0 #53824

Open
kuynzereb opened this issue Mar 11, 2021 · 1 comment
Labels
oncall: jit Add this issue/PR to JIT oncall triage queue

kuynzereb commented Mar 11, 2021

🐛 Bug

In PyTorch 1.8.0, the JIT appears to recompile some scripted functions on every call when the input tensor's contents change, even though its shape stays the same.

To Reproduce

If I run the following code, toggling optimize between True and False:

optimize = True
device = 'cuda:0'
num_runs = 5


import time

import torch


def func(mask: torch.Tensor):
    H, W = mask.size()

    tensor = torch.zeros([H, W], device=mask.device)
    masked_view = tensor[mask]
    output = torch.stack([masked_view, masked_view + W + 1], dim=1)

    return output


jit_func = torch.jit.script(func)


def get_random_mask():
    mask = torch.randint(2, size=[1000, 1000], dtype=torch.bool, device=device)

    return mask


with torch.jit.optimized_execution(optimize):
    times = []
    for i in range(num_runs):
        mask = get_random_mask()

        torch.cuda.synchronize(device)
        start = time.perf_counter()

        _ = jit_func(mask)

        torch.cuda.synchronize(device)
        elapsed_time = time.perf_counter() - start

        times.append(elapsed_time)

print(f'PyTorch version: {torch.__version__}')
print(f'Optimized execution: {optimize}')
print("Times:")
print("\n".join([f"{x:.4f} sec." for x in times]))

I get the following results:

PyTorch version: 1.8.0
Optimized execution: False
Times:
0.0007 sec.
0.0002 sec.
0.0002 sec.
0.0002 sec.
0.0002 sec.

PyTorch version: 1.8.0
Optimized execution: True
Times:
0.0402 sec.
0.1237 sec.
0.1194 sec.
0.1202 sec.
0.1204 sec.

PyTorch version: 1.7.1+cu110
Optimized execution: True
Times:
0.0024 sec.
0.1230 sec.
0.0003 sec.
0.0002 sec.
0.0002 sec.

PyTorch version: 1.7.1+cu110
Optimized execution: False
Times:
0.0007 sec.
0.0003 sec.
0.0002 sec.
0.0002 sec.
0.0002 sec.

Evidently, PyTorch 1.8.0 recompiles this function for every new random mask, even though its shape is unchanged.
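
To put a number on this, here is a minimal sketch that compares steady-state timings with optimization on and off. It only uses the timings pasted above; the helper name `steady_state_ratio` and the choice of skipping the first 2 calls as warm-up are my own assumptions, not anything from PyTorch.

```python
from statistics import median


def steady_state_ratio(optimized, baseline, warmup=2):
    """Ratio of median post-warm-up time (optimized / baseline).

    `warmup` is an assumed number of initial calls to ignore, so that
    one-off compilation cost does not count against the optimized run.
    """
    opt = optimized[warmup:]
    base = baseline[warmup:]
    if not opt or not base:
        raise ValueError("need at least one timing after warm-up")
    return median(opt) / median(base)


# Timings in seconds, copied from the results above.
v180_on = [0.0402, 0.1237, 0.1194, 0.1202, 0.1204]
v180_off = [0.0007, 0.0002, 0.0002, 0.0002, 0.0002]
v171_on = [0.0024, 0.1230, 0.0003, 0.0002, 0.0002]
v171_off = [0.0007, 0.0003, 0.0002, 0.0002, 0.0002]

ratio_180 = steady_state_ratio(v180_on, v180_off)
ratio_171 = steady_state_ratio(v171_on, v171_off)

print(f"1.8.0 steady-state slowdown: {ratio_180:.0f}x")
print(f"1.7.1 steady-state slowdown: {ratio_171:.0f}x")
```

On 1.7.1 the ratio settles at about 1x once the first two calls are skipped, while on 1.8.0 every later call stays roughly as slow as the single slow call on 1.7.1, which is what makes me think it is recompiling each time.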

Expected behavior

JIT should not recompile this function for each new mask.

Environment

PyTorch version: 1.8.0
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.10.2

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: GeForce RTX 2080 SUPER
Nvidia driver version: 460.32.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.1
[pip3] pytorch-lightning==1.1.4
[pip3] torch==1.8.0
[pip3] torchvision==0.9.0
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.1.1 h6406543_8 conda-forge
[conda] mkl 2020.2 256
[conda] mkl-service 2.3.0 py38he904b0f_0
[conda] mkl_fft 1.3.0 py38h54f3939_0
[conda] mkl_random 1.1.1 py38h0573a6f_0
[conda] numpy 1.20.1 pypi_0 pypi
[conda] pytorch 1.8.0 py3.8_cuda11.1_cudnn8.0.5_0 pytorch
[conda] pytorch-lightning 1.1.4 pypi_0 pypi
[conda] torchvision 0.9.0 py38_cu111 pytorch

cc @gmagogsfm

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Mar 11, 2021
@github-actions github-actions bot added this to Need triage in JIT Triage Mar 11, 2021
gmagogsfm (Contributor) commented

cc @eellison

@eellison eellison added this to Needs triage in NNC via automation Mar 12, 2021
@eellison eellison removed this from Need triage in JIT Triage Mar 12, 2021
@ZolotukhinM ZolotukhinM moved this from Needs triage to High priority in NNC Mar 23, 2021