Skip to content

Commit 47fae94

Browse files
Lee, Kyunggeun (quic-kyunggeu)
authored and committed
Minor code rewriting to enable torch 1.13
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com> Co-authored-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
1 parent 331c00b commit 47fae94

File tree

3 files changed

+21
-5
lines changed

3 files changed

+21
-5
lines changed

TrainingExtensions/torch/src/python/aimet_torch/_base/nn/modules/custom.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,17 @@ def create_wrapper_module(
155155
Normalize = create_wrapper_module("Normalize", torch.nn.functional.normalize)
156156
Pad = create_wrapper_module("Pad", torch.nn.functional.pad)
157157
GridSample = create_wrapper_module("GridSample", torch.nn.functional.grid_sample)
158-
ScaledDotProductAttention = create_wrapper_module(
159-
"ScaledDotProductAttention", torch.nn.functional.scaled_dot_product_attention
160-
)
158+
159+
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
160+
ScaledDotProductAttention = create_wrapper_module(
161+
"ScaledDotProductAttention", torch.nn.functional.scaled_dot_product_attention
162+
)
163+
elif hasattr(torch.nn.functional, "_scaled_dot_product_attention"):
164+
ScaledDotProductAttention = create_wrapper_module(
165+
"ScaledDotProductAttention", torch.nn.functional._scaled_dot_product_attention
166+
)
167+
else:
168+
ScaledDotProductAttention = None
161169

162170

163171
# following modules are for overloaded operators like + and *,

TrainingExtensions/torch/src/python/aimet_torch/v2/experimental/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636
# =============================================================================
3737

3838
# pylint: disable=missing-docstring
39-
from . import export
39+
try:
40+
from . import export
41+
except ImportError:
42+
pass
43+
4044
from . import onnx
4145
from .quantsim_utils import *

TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/backends/torch_builtins.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@
4040
from packaging import version
4141
from typing import Callable, Optional, List, Tuple
4242
import torch
43-
import torch.ao.quantization.fx._decomposed
43+
44+
try:
45+
import torch.ao.quantization.fx._decomposed
46+
except ImportError:
47+
pass
4448
from aimet_torch.v2.utils import (
4549
_is_expandable,
4650
_ContextManager,

0 commit comments

Comments (0)