Skip to content

Commit

Permalink
clear map_location for torchxla
Browse files Browse the repository at this point in the history
  • Loading branch information
morgandu committed Mar 29, 2023
1 parent 0b3e1e8 commit f4a9619
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ts/torch_handler/base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def initialize(self, context):
self.map_location = "cuda"
self.device = torch.device(self.map_location + ":" + str(properties.get("gpu_id")))
elif TORCHXLA_AVAILABLE:
self.map_location = "xla"
self.device = xm.xla_device()
else:
self.map_location = "cpu"
Expand Down Expand Up @@ -253,7 +252,8 @@ def _load_pickled_model(self, model_dir, model_file, model_pt_path):
model_class = model_class_definitions[0]
model = model_class()
if model_pt_path:
state_dict = torch.load(model_pt_path, map_location=self.device)
map_location = None if (TORCHXLA_AVAILABLE and self.map_location is None) else self.device
state_dict = torch.load(model_pt_path, map_location=map_location)
model.load_state_dict(state_dict)
return model

Expand Down

0 comments on commit f4a9619

Please sign in to comment.