add torch_xla 2.0 backend; update map_location and device; remove mark_step
morgandu committed Mar 25, 2023
1 parent b38a88e commit 9447d8e
Showing 2 changed files with 20 additions and 39 deletions.
58 changes: 19 additions & 39 deletions ts/torch_handler/base_handler.py
@@ -25,6 +25,13 @@
else:
    PROFILER_AVAILABLE = False

try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    TORCHXLA_AVAILABLE = True
except ImportError as error:
    TORCHXLA_AVAILABLE = False


logger = logging.getLogger(__name__)

@@ -59,24 +66,6 @@ def check_pt2_enabled():
        )


def check_torch_xla_enabled() -> bool:
    try:
        import torch_xla
        import torch_xla.core.xla_model as xm
        torch_xla_enabled = True
    except ImportError as error:
        torch_xla_enabled = False
        logger.info(
            "Proceed without PyTorch/XLA."
        )
    return torch_xla_enabled

torch_xla_enabled = check_torch_xla_enabled()
if torch_xla_enabled:
    import torch_xla
    import torch_xla.core.xla_model as xm


class BaseHandler(abc.ABC):
    """
    Base default handler to load torchscript or eager mode [state_dict] models
@@ -120,19 +109,15 @@ def initialize(self, context):
        )

        properties = context.system_properties
        self.map_location = (
            "cuda"
            if torch.cuda.is_available() and properties.get("gpu_id") is not None
            else "cpu"
        )
        self.device = torch.device(
            self.map_location + ":" + str(properties.get("gpu_id"))
            if torch.cuda.is_available() and properties.get("gpu_id") is not None
            else self.map_location
        )
        if torch_xla_enabled:
            self.map_location = None
        if torch.cuda.is_available() and properties.get("gpu_id") is not None:
            self.map_location = "cuda"
            self.device = torch.device(self.map_location + ":" + str(properties.get("gpu_id")))
        elif TORCHXLA_AVAILABLE:
            self.map_location = "xla"
            self.device = xm.xla_device()
        else:
            self.map_location = "cpu"
            self.device = torch.device(self.map_location)

        self.manifest = context.manifest

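As an illustration outside the diff, the new device-resolution order (CUDA first, then XLA, then CPU) can be exercised standalone; in this minimal sketch the properties dict stands in for TorchServe's context.system_properties:

import torch

try:
    import torch_xla.core.xla_model as xm
    TORCHXLA_AVAILABLE = True
except ImportError:
    TORCHXLA_AVAILABLE = False

def resolve_device(properties):
    # Mirrors the handler logic above: prefer CUDA, then XLA, then CPU.
    if torch.cuda.is_available() and properties.get("gpu_id") is not None:
        return "cuda", torch.device("cuda:" + str(properties.get("gpu_id")))
    if TORCHXLA_AVAILABLE:
        return "xla", xm.xla_device()
    return "cpu", torch.device("cpu")

map_location, device = resolve_device({"gpu_id": None})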
@@ -168,7 +153,7 @@ def initialize(self, context):

        # Convert your model by following instructions: https://pytorch.org/tutorials/intermediate/nvfuser_intro_tutorial.html
        # For TensorRT support follow instructions here: https://pytorch.org/TensorRT/getting_started/getting_started_with_python_api.html#getting-started-with-python-api
        elif self.model_pt_path.endswith(".pt") and not torch_xla_enabled:
        elif self.model_pt_path.endswith(".pt"):
            self.model = self._load_torchscript_model(self.model_pt_path)
            self.model.eval()

@@ -203,7 +188,8 @@ def initialize(self, context):
            # Compilation will delay your model initialization
            try:
                self.model = torch.compile(
                    self.model, backend=backend, mode="reduce-overhead"
                    self.model, backend=backend,
                    mode="default" if TORCHXLA_AVAILABLE else "reduce-overhead"
                )
                logger.info(f"Compiled model with backend {backend}")
            except:
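For context outside the diff: torch_xla 2.0 ships a "torchxla_trace_once" dynamo backend, and "reduce-overhead" mode relies on CUDA graphs, which is presumably why the mode drops to "default" when XLA is available. A minimal sketch of the resulting call; the Linear model and the "inductor" fallback are illustrative stand-ins, since the real backend comes from the model's configuration:

import torch

try:
    import torch_xla  # torch_xla 2.0 registers its dynamo backends on import
    TORCHXLA_AVAILABLE = True
except ImportError:
    TORCHXLA_AVAILABLE = False

model = torch.nn.Linear(4, 2).eval()  # stand-in for the handler's loaded model
backend = "torchxla_trace_once" if TORCHXLA_AVAILABLE else "inductor"
try:
    model = torch.compile(
        model, backend=backend,
        mode="default" if TORCHXLA_AVAILABLE else "reduce-overhead",
    )
except Exception:
    pass  # the handler also treats compilation as best-effort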
@@ -267,10 +253,7 @@ def _load_pickled_model(self, model_dir, model_file, model_pt_path):
        model_class = model_class_definitions[0]
        model = model_class()
        if model_pt_path:
            state_dict = torch.load(
                model_pt_path,
                map_location=self.device if not torch_xla_enabled else None
            )
            state_dict = torch.load(model_pt_path, map_location=self.device)
            model.load_state_dict(state_dict)
        return model

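With map_location now always set to self.device, a checkpoint saved on any device deserializes directly onto the resolved target, XLA included, which this commit relies on. A minimal round-trip sketch; the path and the Linear model are illustrative:

import torch

device = torch.device("cpu")  # in the handler this is the resolved CUDA/XLA/CPU device
model = torch.nn.Linear(4, 2)
torch.save(model.state_dict(), "/tmp/model.pt")
state_dict = torch.load("/tmp/model.pt", map_location=device)
model.load_state_dict(state_dict)
model.to(device)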
@@ -303,9 +286,6 @@ def inference(self, data, *args, **kwargs):
        with torch.no_grad():
            marshalled_data = data.to(self.device)
            results = self.model(marshalled_data, *args, **kwargs)
            if torch_xla_enabled:
                xm.mark_step()

        return results

    def postprocess(self, data):
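The dropped xm.mark_step() call is the step boundary that eager (lazy-tensor) XLA needs in order to flush pending graph execution; under the dynamo backend, torch.compile owns graph capture and execution, so the handler no longer inserts it. The removed pattern, sketched in isolation (runs only where torch_xla is installed):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(4, 2).to(device)
with torch.no_grad():
    results = model(torch.randn(1, 4, device=device))
xm.mark_step()  # flushes the lazily accumulated graph; not needed under torch.compile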
1 change: 1 addition & 0 deletions ts/utils/util.py
@@ -21,6 +21,7 @@ class PT2Backend(str, enum.Enum):
    FX2TRT = "fx2trt"
    ONNXRT = "onnxrt"
    IPEX = "ipex"
    TORCHXLA_TRACE_ONCE = "torchxla_trace_once"


logger = logging.getLogger(__name__)
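Because PT2Backend is a str enum, the new member's value passes straight through to torch.compile as the backend name. A minimal sketch; the IPEX member is copied from the enum above, and the final comment assumes torch_xla 2.0 has registered the backend:

import enum

class PT2Backend(str, enum.Enum):
    IPEX = "ipex"
    TORCHXLA_TRACE_ONCE = "torchxla_trace_once"

backend = PT2Backend.TORCHXLA_TRACE_ONCE
assert backend == "torchxla_trace_once"  # a str enum compares equal to its value
# torch.compile(model, backend=backend) would therefore receive the plain string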
