| 
11 | 11 | 
 
  | 
12 | 12 | from enum import Enum  | 
13 | 13 | from pathlib import Path  | 
14 |  | -from typing import Any, Callable, Dict, List, Tuple  | 
 | 14 | +from typing import Any, Callable, Dict, List, Optional, Tuple  | 
15 | 15 | 
 
  | 
16 | 16 | import torch  | 
17 | 17 | 
 
  | 
@@ -77,31 +77,39 @@ def unpack_packed_weights(  | 
77 | 77 | def set_backend(dso, pte, aoti_package):  | 
78 | 78 |     global active_builder_args_dso  | 
79 | 79 |     global active_builder_args_pte  | 
 | 80 | +    global active_builder_args_aoti_package  | 
80 | 81 |     active_builder_args_dso = dso  | 
81 | 82 |     active_builder_args_aoti_package = aoti_package  | 
82 | 83 |     active_builder_args_pte = pte  | 
83 | 84 | 
 
  | 
84 | 85 | 
 
  | 
85 | 86 | class _Backend(Enum):  | 
86 |  | -    AOTI = (0,)  | 
 | 87 | +    AOTI = 0  | 
87 | 88 |     EXECUTORCH = 1  | 
88 | 89 | 
 
  | 
89 | 90 | 
 
  | 
90 |  | -def _active_backend() -> _Backend:  | 
 | 91 | +def _active_backend() -> Optional[_Backend]:  | 
91 | 92 |     global active_builder_args_dso  | 
92 | 93 |     global active_builder_args_aoti_package  | 
93 | 94 |     global active_builder_args_pte  | 
94 | 95 | 
 
  | 
95 |  | -    # eager == aoti, which is when backend has not been explicitly set  | 
96 |  | -    if (not active_builder_args_pte) and (not active_builder_args_aoti_package):  | 
97 |  | -        return True  | 
 | 96 | +    args = (  | 
 | 97 | +        active_builder_args_dso,  | 
 | 98 | +        active_builder_args_pte,  | 
 | 99 | +        active_builder_args_aoti_package,  | 
 | 100 | +    )  | 
 | 101 | + | 
 | 102 | +    # Return None, as default  | 
 | 103 | +    if not any(args):  | 
 | 104 | +        return None  | 
98 | 105 | 
 
  | 
99 |  | -    if active_builder_args_pte and active_builder_args_aoti_package:  | 
 | 106 | +    # Catch more than one arg  | 
 | 107 | +    if sum(map(bool, args)) > 1:  | 
100 | 108 |         raise RuntimeError(  | 
101 |  | -            "code generation needs to choose different implementations for AOTI and PTE path.  Please only use one export option, and call export twice if necessary!"  | 
 | 109 | +            "Code generation needs to choose different implementations.  Please only use one export option, and call export twice if necessary!"  | 
102 | 110 |         )  | 
103 | 111 | 
 
  | 
104 |  | -    return _Backend.AOTI if active_builder_args_pte else _Backend.EXECUTORCH  | 
 | 112 | +    return _Backend.EXECUTORCH if active_builder_args_pte else _Backend.AOTI  | 
105 | 113 | 
 
  | 
106 | 114 | 
 
  | 
107 | 115 | def use_aoti_backend() -> bool:  | 
 | 
0 commit comments