
[Dynamo][Compile]Torch compile with dynamic shapes not working #105279

@bohnstingl

Description


🐛 Describe the bug

My networks rely on inputs of varying shape during both training and inference. I therefore tried torch.compile(..., dynamic=True) as well as torch._dynamo.optimize(..., dynamic=True). Unfortunately, I can't get either to work properly: the function is recompiled every time the input shape changes. I tried this with the latest nightly build, version 2.1.0.dev20230712, and with the current main branch as of Friday, July 14th, 2023. I printed the guards that trigger the recompilation, and the culprit is a guard on exactly the size that the dynamic feature is meant to make dynamic:
GuardFail(reason="tensor 'L['input']' size mismatch at index 0. expected 146, actual 149", orig_code=<code object RNNScript at 0x2758840 ...)

I also tried adding a warm-up phase in which the input length is varied, but the function is still recompiled whenever the input shape changes.
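A plausible cause, assumed here rather than confirmed in this report: Dynamo unrolls Python for loops while tracing, and iterating over a tensor (as zip(Wx_inp, Wx_rec) does in RNNScript) needs the concrete length of dimension 0, so that dimension gets specialized and guarded even with dynamic=True. The sketch below checks this on two toy functions, assuming a recent PyTorch with torch.compile. The names make_counting_backend, scale and loop_sum are hypothetical; the custom backend only counts how many graphs Dynamo hands it and runs each one unchanged.

```python
import torch
import torch._dynamo
import torch.fx


def make_counting_backend():
    """Hypothetical helper: a torch.compile backend that counts the graphs it receives."""
    calls = {"compiles": 0}

    def backend(gm: torch.fx.GraphModule, example_inputs):
        calls["compiles"] += 1  # one call per (re)compilation
        return gm.forward  # run the captured FX graph unchanged, no Inductor

    return backend, calls


def scale(x):
    # No Python loop: the traced graph does not depend on x.shape[0].
    return x * 2.0


def loop_sum(x):
    # Iterating over a tensor is unrolled while tracing, so the graph has one
    # add per row and depends on the concrete value of x.shape[0].
    acc = torch.zeros_like(x[0])
    for row in x:
        acc = acc + row
    return acc


torch._dynamo.reset()  # start from an empty compile cache
scale_backend, scale_calls = make_counting_backend()
loop_backend, loop_calls = make_counting_backend()
compiled_scale = torch.compile(scale, backend=scale_backend, dynamic=True)
compiled_loop = torch.compile(loop_sum, backend=loop_backend, dynamic=True)

for length in (4, 6, 9):  # lengths >= 2 avoid the 0/1 size specialization
    x = torch.randn(length, 3)
    assert torch.allclose(compiled_scale(x), scale(x))
    assert torch.allclose(compiled_loop(x), loop_sum(x))

print("graphs compiled for scale:   ", scale_calls["compiles"])
print("graphs compiled for loop_sum:", loop_calls["compiles"])
```

If the assumption holds, scale is compiled once for all three lengths while loop_sum is compiled again for each new length, which mirrors the GuardFail on index 0 reported above.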

@ezyang, could you please advise here?

Error logs

GuardFail(reason="tensor 'L['input']' size mismatch at index 0. expected 146, actual 149", orig_code=<code object RNNScript at 0x2758840 ...)

Minified repro

import torch
import numpy as np
import time
import torch.optim as optim
from torch.nn.parameter import Parameter

def RNNScript(
    input,
    param1,
    param2,
    ):

    state1 = torch.zeros(32, 340, dtype=input.dtype, device=input.device)
    
    outs = []

    Wx = input @ param1
    Wx_inp, Wx_rec = torch.tensor_split(Wx, 2, 2)
    for wt_inp, wt_rec in zip(Wx_inp, Wx_rec):
        rec_mul_inp, rec_mul_rec = torch.tensor_split(state1 @ param2, 2, 1)
        input_prev = (wt_inp + rec_mul_inp)
        output_gate = (wt_rec + rec_mul_rec)

        state1 = input_prev * torch.sigmoid(output_gate)
        outs.append(state1)
    
    outs = torch.stack(outs)

    return outs, (outs)

if __name__ == "__main__":

    input_size = 140
    hidden_size = 340
    num_layers = 1
    num_timesteps = 111
    batch_size = 32

    bi_dir = True 
    rnnt_input = False
    num_threads = -1
    use_gpu = True
    load_weights = False

    forward_times = []
    backward_times = []

    if use_gpu:
        device = torch.device('cuda:0')
    else:
        device = None

    parameters = []

    w_ih = torch.empty((input_size, hidden_size), device=device)
    w_io = torch.empty((input_size, hidden_size), device=device)
    w_i_comb = Parameter(torch.cat([w_ih,w_io],1))
    parameters.append(w_i_comb)

    w_hh = torch.empty((hidden_size, hidden_size), device=device)
    w_ho = torch.empty((hidden_size, hidden_size), device=device)
    w_h_comb = Parameter(torch.cat([w_hh,w_ho],1))
    parameters.append(w_h_comb)
    
    def count_kernels(guard):
        print("[pt2_compile] guard failed: ", guard)

    rnnscript = torch.compile(RNNScript, mode='reduce-overhead', dynamic=True, fullgraph=True)
    #backend = torch._TorchCompileInductorWrapper('reduce-overhead', None, True)
    #rnnscript = torch._dynamo.optimize(backend=backend, nopython=True, dynamic=True, guard_fail_fn=count_kernels)(RNNScript)
    #rnnscript = RNNScript
    snu = lambda x: rnnscript(x, w_i_comb, w_h_comb)

    optimizer = optim.SGD(parameters, 0.1)

    inp = torch.rand((num_timesteps, batch_size, input_size))

    if use_gpu:
        inp = inp.cuda()

    optimizer.zero_grad()
    for execution in range(5):
        start_forward = time.time_ns()
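        # At least 2 timesteps: an empty sequence makes torch.stack fail, and
        # sizes 0 and 1 are always specialized by dynamic shapes.
        t_rnd = np.random.randint(2, 200)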
        inp = torch.rand((t_rnd, batch_size, input_size))
        if use_gpu:
            inp = inp.cuda()
        out, state = snu(inp)

        if use_gpu:
            torch.cuda.synchronize()
        stop_forward = time.time_ns()
        forward_times.append((stop_forward - start_forward) / (10 ** 9))

        loss = 1. - torch.sum(out)

        start_time_backward = time.time_ns()
        #loss.backward()
        if use_gpu:
            torch.cuda.synchronize()
        stop_time_backward = time.time_ns()
        backward_times.append((stop_time_backward - start_time_backward) / (10 ** 9))

    print('================================================================')
    print('Model with sSNU-os:')
    print('# Layers: ' + str(num_layers))
    print('# Units per layer: ' + str(hidden_size))
    print('Bidirectional: ' + str(bi_dir))
    print('Load weights: '  + str(load_weights))
    print('RNN-T input: ' + str(rnnt_input))
    print('# CPU threads: ' + str(num_threads))
    print('GPU support: ' + str(use_gpu))
    print('----------------------------------------------------------------')
    print('Timing summary')
    print('Time of forward computation: {:.4f} +- {:.4f} s'.format(np.mean(np.array(forward_times)), np.std(np.array(forward_times))))
    print('Time of backward computation: {:.4f} +- {:.4f} s'.format(np.mean(np.array(backward_times)), np.std(np.array(backward_times))))
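Assuming the recompilation comes from Dynamo unrolling the for loop over timesteps in RNNScript (a plausible but unconfirmed cause), one possible workaround, sketched here and not proposed in this report, is to compile only the per-timestep cell and keep the loop over time in eager Python, so that no compiled graph depends on the sequence length. The names rnn_cell, run_rnn and make_counting_backend and the small sizes are hypothetical; the cell mirrors the body of RNNScript's loop, with the input projection moved into the loop for simplicity.

```python
import torch
import torch._dynamo


def make_counting_backend():
    """Hypothetical helper: a torch.compile backend that counts the graphs it receives."""
    calls = {"compiles": 0}

    def backend(gm, example_inputs):
        calls["compiles"] += 1
        return gm.forward  # run the captured FX graph unchanged

    return backend, calls


def rnn_cell(wt, state, param2):
    # One timestep of the recurrence in RNNScript. All shapes are
    # (batch, ...), independent of the number of timesteps.
    wt_inp, wt_rec = torch.tensor_split(wt, 2, 1)
    rec_mul_inp, rec_mul_rec = torch.tensor_split(state @ param2, 2, 1)
    return (wt_inp + rec_mul_inp) * torch.sigmoid(wt_rec + rec_mul_rec)


def run_rnn(cell, seq, param1, param2, hidden_size):
    # Eager Python loop over time; only `cell` may be compiled.
    state = torch.zeros(seq.shape[1], hidden_size, dtype=seq.dtype, device=seq.device)
    outs = []
    for x_t in seq:
        state = cell(x_t @ param1, state, param2)
        outs.append(state)
    return torch.stack(outs)


torch._dynamo.reset()
torch.manual_seed(0)
backend, calls = make_counting_backend()
compiled_cell = torch.compile(rnn_cell, backend=backend)

batch_size, input_size, hidden_size = 4, 6, 5  # small illustrative sizes
param1 = torch.nn.Parameter(torch.randn(input_size, 2 * hidden_size))
param2 = torch.nn.Parameter(torch.randn(hidden_size, 2 * hidden_size))

compiles_after_each_length = []
for num_timesteps in (3, 7, 12):
    seq = torch.randn(num_timesteps, batch_size, input_size)
    out = run_rnn(compiled_cell, seq, param1, param2, hidden_size)
    ref = run_rnn(rnn_cell, seq, param1, param2, hidden_size)  # eager reference
    assert out.shape == (num_timesteps, batch_size, hidden_size)
    assert torch.allclose(out, ref, atol=1e-6)
    compiles_after_each_length.append(calls["compiles"])

print("graphs compiled after each length:", compiles_after_each_length)
```

The count may already be 2 after the first sequence, because the initial zero state does not require grad while later states do, and requires_grad is part of the guards; it should then stay flat as the length changes. The trade-off is one compiled call per timestep instead of a single graph for the whole sequence, so per-step overhead remains and the benefit over recompiling per length depends on the workload.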

Versions

Collecting environment information...
PyTorch version: 2.1.0a0+gitfb376f8
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Red Hat Enterprise Linux release 8.7 (Ootpa) (x86_64)
GCC version: (GCC) 8.5.0 20210514 (Red Hat 8.5.0-16)
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.28

Python version: 3.11.3 (main, May 15 2023, 15:45:52) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-4.18.0-425.10.1.el8_7.x86_64-x86_64-with-glibc2.28
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-40GB
Nvidia driver version: 525.60.13
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Thread(s) per core: 1
Core(s) per socket: 64
Socket(s): 2
NUMA node(s): 2
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD EPYC 7742 64-Core Processor
Stepping: 0
CPU MHz: 3361.974
CPU max MHz: 2250.0000
CPU min MHz: 1500.0000
BogoMIPS: 4499.91
Virtualization: AMD-V
L1d cache: 32K
L1i cache: 32K
L2 cache: 512K
L3 cache: 16384K
NUMA node0 CPU(s): 0-63
NUMA node1 CPU(s): 64-127
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es

Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] pytorch-triton==2.1.0+3c400e7818
[pip3] torch==2.1.0a0+gitfb376f8
[pip3] torchaudio==2.1.0a0+cf53a48
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.8.0 h6a678d5_0
[conda] magma-cuda118 2.6.1 1 pytorch
[conda] mkl 2023.1.0 h6d00ec8_46342
[conda] mkl-include 2023.1.0 h06a4308_46342
[conda] mkl-service 2.4.0 py311h5eee18b_1
[conda] mkl_fft 1.3.6 py311ha02d727_1
[conda] mkl_random 1.2.2 py311ha02d727_1
[conda] numpy 1.24.3 py311h08b1b3b_1
[conda] numpy-base 1.24.3 py311hf175353_1
[conda] pytorch-triton 2.1.0+3c400e7818 pypi_0 pypi
[conda] torch 2.1.0a0+gitfb376f8 pypi_0 pypi
[conda] torchaudio 2.1.0a0+cf53a48 dev_0

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @wconstab

Metadata

Labels: module: dynamic shapes, oncall: pt2, triaged
