# base_handler.py
"""
Base default handler to load torchscript or eager mode [state_dict] models
Also, provides handle method per torch serve custom model specification
"""
import abc
import logging
import os
import importlib.util
import torch
from ..utils.util import list_classes_from_module, load_label_mapping
logger = logging.getLogger(__name__)


class BaseHandler(abc.ABC):
    """
    Base default handler to load TorchScript or eager-mode (state_dict) models.
    Also provides the handle method per the TorchServe custom model specification.
    """

    def __init__(self):
        self.model = None
        self.mapping = None
        self.device = None
        self.initialized = False
        self.context = None
        self.manifest = None
        self.map_location = None

    def initialize(self, context):
        """Load an eager-mode (state_dict) model if the manifest names a model
        definition file; otherwise load a TorchScript model."""
        properties = context.system_properties
        self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(self.map_location + ":" + str(properties.get("gpu_id"))
                                   if torch.cuda.is_available() else self.map_location)
        self.manifest = context.manifest

        model_dir = properties.get("model_dir")
        serialized_file = self.manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the serialized model file: {}".format(model_pt_path))

        # The model definition file is only present for eager-mode models.
        model_file = self.manifest['model'].get('modelFile', '')

        if model_file:
            logger.debug('Loading eager model')
            self.model = self._load_pickled_model(model_dir, model_file, model_pt_path)
        else:
            logger.debug('Loading torchscript model')
            self.model = self._load_torchscript_model(model_pt_path)

        self.model.to(self.device)
        self.model.eval()

        logger.debug('Model file %s loaded successfully', model_pt_path)

        # Load class mapping for classifiers, if one is packaged with the model
        mapping_file_path = os.path.join(model_dir, "index_to_name.json")
        self.mapping = load_label_mapping(mapping_file_path)

        self.initialized = True
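
    # For reference, the context fields consumed by initialize() look roughly
    # like this (a sketch inferred from the accesses above, not an exhaustive
    # schema):
    #
    #   context.system_properties -> {"model_dir": "/path/to/model", "gpu_id": 0, ...}
    #   context.manifest          -> {"model": {"serializedFile": "model.pt",
    #                                           "modelFile": "model.py", ...}, ...}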

    def _load_torchscript_model(self, model_pt_path):
        return torch.jit.load(model_pt_path, map_location=self.map_location)

    def _load_pickled_model(self, model_dir, model_file, model_pt_path):
        model_def_path = os.path.join(model_dir, model_file)
        if not os.path.isfile(model_def_path):
            raise RuntimeError("Missing the model definition file: {}".format(model_def_path))

        module = importlib.import_module(model_file.split(".")[0])
        model_class_definitions = list_classes_from_module(module)
        if len(model_class_definitions) != 1:
            raise ValueError("Expected only one class as model definition. {}".format(
                model_class_definitions))

        model_class = model_class_definitions[0]
        state_dict = torch.load(model_pt_path, map_location=self.map_location)
        model = model_class()
        model.load_state_dict(state_dict)
        return model
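
    # For reference, a minimal eager-mode model file that _load_pickled_model
    # could load (hypothetical example; the module must define exactly one
    # class, and that class must be constructible with no arguments):
    #
    #   # model.py
    #   import torch.nn as nn
    #
    #   class ExampleNet(nn.Module):
    #       def __init__(self):
    #           super().__init__()
    #           self.fc = nn.Linear(10, 2)
    #
    #       def forward(self, x):
    #           return self.fc(x)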

    def preprocess(self, data):
        """
        Override to customize the pre-processing
        :param data: Python list of data items
        :return: input tensor on a device
        """
        return torch.as_tensor(data, device=self.device)

    def inference(self, data, *args, **kwargs):
        """
        Override to customize the inference
        :param data: Torch tensor, matching the model input shape
        :return: Prediction output as Torch tensor
        """
        marshalled_data = data.to(self.device)
        with torch.no_grad():
            results = self.model(marshalled_data, *args, **kwargs)
        return results

    def postprocess(self, data):
        """
        Override to customize the post-processing
        :param data: Torch tensor, containing prediction output from the model
        :return: Python list
        """
        return data.tolist()

    def handle(self, data, context):
        """
        Entry point for the default handler
        """
        # Store the context so that custom pre- or post-processing can use the
        # additional request information it carries.
        self.context = context

        data = self.preprocess(data)
        data = self.inference(data)
        data = self.postprocess(data)
        return data
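

# A minimal sketch of a custom handler built on BaseHandler (hypothetical
# example, not part of this module). It assumes each request item is a dict
# whose "data" or "body" key holds a list of floats; actual request shapes
# depend on how the client calls the TorchServe inference API.
class ExampleArgmaxHandler(BaseHandler):
    """Hypothetical handler: batches float rows, returns per-row argmax."""

    def preprocess(self, data):
        # Extract the payload from each request item.
        rows = [row.get("data") or row.get("body") for row in data]
        return torch.as_tensor(rows, dtype=torch.float32, device=self.device)

    def postprocess(self, data):
        # One predicted class index per request row, as plain Python ints.
        return data.argmax(dim=1).tolist()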