Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
morgandu committed Mar 31, 2023
1 parent 252bde1 commit fa022ec
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions ts/torch_handler/base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,16 @@
PROFILER_AVAILABLE = False

try:
import torch_xla
import torch_xla.core.xla_model as xm

TORCHXLA_AVAILABLE = True
except ImportError as error:
TORCHXLA_AVAILABLE = False


logger = logging.getLogger(__name__)


# Possible values for backend in utils.py
def check_pt2_enabled():
try:
Expand Down Expand Up @@ -111,7 +112,9 @@ def initialize(self, context):
properties = context.system_properties
if torch.cuda.is_available() and properties.get("gpu_id") is not None:
self.map_location = "cuda"
self.device = torch.device(self.map_location + ":" + str(properties.get("gpu_id")))
self.device = torch.device(
self.map_location + ":" + str(properties.get("gpu_id"))
)
elif TORCHXLA_AVAILABLE:
self.device = xm.xla_device()
else:
Expand Down Expand Up @@ -187,8 +190,9 @@ def initialize(self, context):
# Compilation will delay your model initialization
try:
self.model = torch.compile(
self.model, backend=backend,
mode="default" if TORCHXLA_AVAILABLE else "reduce-overhead"
self.model,
backend=backend,
mode="default" if TORCHXLA_AVAILABLE else "reduce-overhead",
)
logger.info(f"Compiled model with backend {backend}")
except:
Expand Down Expand Up @@ -252,7 +256,11 @@ def _load_pickled_model(self, model_dir, model_file, model_pt_path):
model_class = model_class_definitions[0]
model = model_class()
if model_pt_path:
map_location = None if (TORCHXLA_AVAILABLE and self.map_location is None) else self.device
map_location = (
None
if (TORCHXLA_AVAILABLE and self.map_location is None)
else self.device
)
state_dict = torch.load(model_pt_path, map_location=map_location)
model.load_state_dict(state_dict)
return model
Expand Down

0 comments on commit fa022ec

Please sign in to comment.