
[JIT] nested dictionaries are not traced correctly. #75012

Open
randolf-scholz opened this issue Mar 31, 2022 · 3 comments
Assignees
Labels
oncall: jit Add this issue/PR to JIT oncall triage queue
Projects

Comments

@randolf-scholz
Contributor

🐛 Describe the bug

Second- and third-level dictionaries are only accepted if all of their values have the same type.

from typing import Any, Dict

from torch import jit, nn

class Demo(nn.Module):
    config: Dict[str, Any]
    """Chosen parameters."""

    def __init__(self) -> None:
        super().__init__()
        self.config = {
            "a": True,
            "b": 42,
            "c": "foo",
            "foo": {"a": 1, "b": 2, "c": 3},
            "bar": {"a": {"a": 1}, "b": {"a": 1}, "c": {"a": 1}},
            "baz": {"a": True, "b": 2},  # <-- does not work!
        }


model = Demo()
scripted_model = jit.script(model)
scripted_model.save("scripted_model")
scripted_model = jit.load("scripted_model")
print(scripted_model.config)

Raises

Could not cast attribute 'config' to type Dict[str, Any]: Tracer cannot infer type of {'a': True, 'b': 2}
:Dictionary inputs to traced functions must have consistent type. Found bool and int

This is odd, because top-level dictionaries with heterogeneous value types are traced just fine. Commenting out "baz" also makes it compile, since then all nested dictionaries share the same type.
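One workaround (not suggested in the thread, just a sketch) is to sidestep `Dict[str, Any]` entirely and keep the heterogeneous config as a single JSON string, since a plain `str` attribute needs no recursive type inference. This assumes the config only needs to be read back from eager Python code, not indexed inside scripted methods:

```python
import json

# The heterogeneous config from the issue, including the nested dict
# with mixed value types that jit.script rejects.
config = {
    "a": True,
    "b": 42,
    "c": "foo",
    "baz": {"a": True, "b": 2},
}

# Store it as one string (e.g. a str attribute on the module),
# which TorchScript handles, instead of a Dict[str, Any].
config_json = json.dumps(config)

# Parse it back whenever the full structure is needed in eager code.
restored = json.loads(config_json)
assert restored == config
```

The obvious trade-off is that the scripted module can no longer look up individual config entries by key; the string is opaque until deserialized.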

Versions

PyTorch version: 1.11.0
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1 
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.9.12 | packaged by conda-forge | (main, Mar 24 2022, 23:25:59)  [GCC 10.3.0] (64-bit runtime)
Python platform: Linux-5.13.0-37-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.5.119
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 510.47.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.3.3
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.3.3
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.3.3
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.3.3
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.3.3
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.3.3
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.3.3
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn.so.8.2.1
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.2.1
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.2.1
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.2.1
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.2.1
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.2.1
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.2.1
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn.so.8.2.1
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.2.1
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.2.1
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.2.1
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.2.1
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.2.1
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.2.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn.so.8.2.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.2.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.2.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.2.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.2.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.2.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.2.1
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.2.2
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] functorch==0.0.1a0+8ffa328
[pip3] mypy==0.942
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.5
[pip3] numpydoc==1.2.1
[pip3] pytorch-forecasting==0.10.1
[pip3] pytorch-lightning==1.6.0
[pip3] torch==1.11.0
[pip3] torchaudio==0.11.0
[pip3] torchdiffeq==0.2.2
[pip3] torchinfo==1.6.5
[pip3] torchmetrics==0.7.3
[pip3] torchtext==0.12.0
[pip3] torchvision==0.12.0
[conda] blas                      2.113                       mkl    conda-forge
[conda] blas-devel                3.9.0            13_linux64_mkl    conda-forge
[conda] cudatoolkit               11.3.1               ha36c431_9    nvidia
[conda] libblas                   3.9.0            13_linux64_mkl    conda-forge
[conda] libcblas                  3.9.0            13_linux64_mkl    conda-forge
[conda] liblapack                 3.9.0            13_linux64_mkl    conda-forge
[conda] liblapacke                3.9.0            13_linux64_mkl    conda-forge
[conda] mkl                       2022.0.1           h06a4308_117  
[conda] mkl-devel                 2022.0.1           h66538d2_117  
[conda] mkl-include               2022.0.1           h06a4308_117  
[conda] mkl-service               2.4.0            py39h404a4ab_0    conda-forge
[conda] mypy                      0.942            py39hb9d737c_0    conda-forge
[conda] mypy_extensions           0.4.3            py39h06a4308_1  
[conda] numexpr                   2.8.0           mkl_py39hba4d566_1    conda-forge
[conda] numpy                     1.21.5           py39haac66dc_0    conda-forge
[conda] numpydoc                  1.2.1              pyhd8ed1ab_0    conda-forge
[conda] pytorch                   1.11.0          py3.9_cuda11.3_cudnn8.2.0_0    pytorch
[conda] pytorch-forecasting       0.10.1             pyhd8ed1ab_0    conda-forge
[conda] pytorch-lightning         1.6.0              pyhd8ed1ab_0    conda-forge
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                0.11.0               py39_cu113    pytorch
[conda] torchdiffeq               0.2.2              pyhd8ed1ab_0    conda-forge
[conda] torchinfo                 1.6.5              pyhd8ed1ab_0    conda-forge
[conda] torchmetrics              0.7.3              pyhd8ed1ab_0    conda-forge
[conda] torchtext                 0.12.0                     py39    pytorch
[conda] torchvision               0.12.0               py39_cu113    pytorch
@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Mar 31, 2022
@github-actions github-actions bot added this to Need triage in JIT Triage Mar 31, 2022
@davidberard98 davidberard98 self-assigned this Apr 8, 2022
@nimaous

nimaous commented Nov 4, 2022

Hi.

I have the same problem. Have you found any solution?

@randolf-scholz
Contributor Author

randolf-scholz commented Nov 4, 2022

@nimaous Currently I just flatten and unflatten the dictionary as needed:

from typing import Any, Callable, Sequence


def flatten_dict(
    d: dict[str, Any],
    /,
    *,
    recursive: bool = True,
    join_fn: Callable[[Sequence[str]], str] = ".".join,
) -> dict[str, Any]:
    r"""Flatten dictionaries recursively."""
    result: dict[str, Any] = {}
    for key, item in d.items():
        if isinstance(item, dict) and recursive:
            subdict = flatten_dict(item, recursive=True, join_fn=join_fn)
            for subkey, subitem in subdict.items():
                result[join_fn((key, subkey))] = subitem
        else:
            result[key] = item
    return result


def unflatten_dict(
    d: dict[str, Any],
    /,
    *,
    recursive: bool = True,
    split_fn: Callable[[str], Sequence[str]] = lambda s: s.split(".", maxsplit=1),
) -> dict[str, Any]:
    r"""Unflatten dictionaries recursively."""
    result = {}
    for key, item in d.items():
        split = split_fn(key)
        result.setdefault(split[0], {})
        if len(split) > 1 and recursive:
            assert len(split) == 2
            subdict = unflatten_dict(
                {split[1]: item}, recursive=recursive, split_fn=split_fn
            )
            result[split[0]] |= subdict
        else:
            result[split[0]] = item
    return result



a = {
    "a": True,
    "b": 42,
    "c": "foo",
    "foo": {"a": 1, "b": 2, "c": 3},
    "bar": {"a": {"a": 1}, "b": {"a": 1}, "c": {"a": 1}},
    "baz": {"a": True, "b": 2},  # <-- the entry that failed under jit.script
}

print(flat := flatten_dict(a))
print(unflat := unflatten_dict(flat))
assert unflat == a
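An alternative sketch (hypothetical, not from the thread): besides flattening the keys, also render every leaf value as a string, so the attribute can be annotated with the homogeneous type `Dict[str, str]` that TorchScript infers without complaint. The helper name `stringify_leaves` is made up for illustration:

```python
from typing import Any, Dict


def stringify_leaves(d: Dict[str, Any], prefix: str = "") -> Dict[str, str]:
    """Flatten a nested dict and render every leaf value with repr()."""
    result: Dict[str, str] = {}
    for key, value in d.items():
        full_key = f"{prefix}{key}"
        if isinstance(value, dict):
            # Recurse into sub-dictionaries, extending the dotted key path.
            result.update(stringify_leaves(value, prefix=full_key + "."))
        else:
            result[full_key] = repr(value)
    return result


config = {"a": True, "b": 42, "baz": {"a": True, "b": 2}}
flat = stringify_leaves(config)
# All values are now str, so the dict type is consistent.
assert flat == {"a": "True", "b": "42", "baz.a": "True", "baz.b": "2"}
```

The cost is that consumers must parse the string representation back into the original types themselves (e.g. via `ast.literal_eval`), so this only pays off when the config is read rarely.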

@ecolss

ecolss commented Mar 21, 2023

Same problem. Will the PyTorch team support this?
