Enable batch processing in scriptable tokenizer example
mreso committed Feb 15, 2023
1 parent 7e65972 commit 856b077
Showing 2 changed files with 12 additions and 12 deletions.
20 changes: 10 additions & 10 deletions examples/text_classification_with_scriptable_tokenizer/handler.py
@@ -1,6 +1,5 @@
 """
 Module for text classification with scriptable tokenizer
-DOES NOT SUPPORT BATCH!
 """
 import logging
 from abc import ABC
@@ -51,18 +50,19 @@ def preprocess(self, data):
 
         # Compat layer: normally the envelope should just return the data
         # directly, but older versions of Torchserve didn't have envelope.
-        # Processing only the first input, not handling batch inference
 
-        line = data[0]
-        text = line.get("data") or line.get("body")
-        # Decode text if not a str but bytes or bytearray
-        if isinstance(text, (bytes, bytearray)):
-            text = text.decode("utf-8")
+        text_batch = []
+        for line in data:
+            text = line.get("data") or line.get("body")
+            # Decode text if not a str but bytes or bytearray
+            if isinstance(text, (bytes, bytearray)):
+                text = text.decode("utf-8")
 
-        text = remove_html_tags(text)
-        text = text.lower()
+            text = remove_html_tags(text)
+            text = text.lower()
+            text_batch.append(text)
 
-        return text
+        return text_batch
 
     def inference(self, data, *args, **kwargs):
         """The Inference Request is made through this function and the user
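For context, the batched preprocess this diff introduces can be exercised on its own. A minimal sketch, assuming TorchServe's handler contract (one dict per request, with the payload under "data" or "body", possibly as raw bytes); remove_html_tags is stubbed with a simple regex here since its definition lies outside the diff:

import re
from typing import List

def remove_html_tags(text: str) -> str:
    # Stand-in for the example's helper: replace anything tag-shaped with a space.
    return re.sub(r"<[^>]+>", " ", text)

def preprocess(data: List[dict]) -> List[str]:
    # Mirror of the updated handler: collect one cleaned string per request.
    text_batch = []
    for line in data:
        text = line.get("data") or line.get("body")
        if isinstance(text, (bytes, bytearray)):
            text = text.decode("utf-8")
        text = remove_html_tags(text)
        text = text.lower()
        text_batch.append(text)
    return text_batch

# A batch of two requests, one of them sent as raw bytes:
print(preprocess([{"data": b"<b>Great movie!</b>"}, {"body": "Terrible plot"}]))

Returning a list instead of a single string is what lets TorchServe hand the handler more than one request at a time; inference and postprocess then see one entry per request.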
4 changes: 2 additions & 2 deletions
@@ -76,7 +76,7 @@ def main(args):
     model = XLMR_BASE_ENCODER.get_model(head=classifier_head)
 
     # Load trained parameters and load them into the model
-    model.load_state_dict(torch.load(args.input_file))
+    model.load_state_dict(torch.load(args.input_file, map_location=torch.device("cpu")))
 
     # Chain the tokenizer, the adapter and the model
     combi_model = T.Sequential(
@@ -88,7 +88,7 @@
     combi_model.eval()
 
     # Make sure to move the model to CPU to avoid placement error during loading
-    combi_model.to("cpu")
+    combi_model.to(torch.device("cpu"))
 
     combi_model_jit = torch.jit.script(combi_model)
 
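Both edits in this file concern device placement when exporting. A minimal standalone sketch of the same pattern, using a tiny stand-in module and a hypothetical checkpoint path model.pt (the real script loads XLMR_BASE_ENCODER with a classifier head from args.input_file):

import torch
import torch.nn as nn

# Tiny stand-in for the XLM-R encoder plus classification head.
model = nn.Linear(4, 2)

# map_location remaps every saved tensor onto the CPU, so a checkpoint
# written on a CUDA machine still loads on a CPU-only host.
state_dict = torch.load("model.pt", map_location=torch.device("cpu"))
model.load_state_dict(state_dict)

# Move the module to CPU before scripting so the serialized TorchScript
# carries no hard-coded CUDA placement.
model.to(torch.device("cpu"))
model.eval()

torch.jit.script(model).save("model_jit.pt")

The explicit torch.device("cpu") mirrors the map_location argument above; together they keep the exported artifact loadable on machines without a GPU.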
