In [2]:
from pathlib import Path

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from max import engine

In [3]:
HF_MODEL_NAME = "cardiffnlp/twitter-roberta-base-emotion-multilabel-latest"
hf_model = AutoModelForSequenceClassification.from_pretrained(HF_MODEL_NAME)

# Export the Hugging Face model to a TorchScript file on disk.
model_path = Path("roberta.torchscript")

# Fixed (batch, seqlen) shape of the example tensors used for tracing.
batch = 1
seqlen = 128

# All-zero int64 placeholders, one per keyword input passed to the tracer.
inputs = {
    name: torch.zeros((batch, seqlen), dtype=torch.int64)
    for name in ("input_ids", "attention_mask")
}

# Trace without gradient tracking. `strict=False` relaxes the tracer's
# strict mode (per the torch.jit.trace docs this is what permits dict/list
# containers in the traced outputs).
with torch.no_grad():
    traced_model = torch.jit.trace(
        hf_model,
        example_kwarg_inputs=dict(inputs),
        strict=False,
    )

torch.jit.save(traced_model, model_path)



config.json:   0%|          | 0.00/1.20k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/499M [00:00<?, ?B/s]



In [4]:
# Build one engine input spec per placeholder tensor in `inputs` (the same
# dict that was used above to trace the model). Each spec takes its shape
# from the tensor; the dtype is fixed to int64, which is the dtype the
# placeholders were created with.
input_spec_list = [
    engine.TorchInputSpec(
        shape=example.size(),
        dtype=engine.DType.int64,
    )
    for example in inputs.values()
]

In [5]:
# Open an inference session, then load the saved TorchScript file into it
# together with the input specs built from the tracing tensors.
session = engine.InferenceSession()
model = session.load(
    model_path,
    input_specs=input_spec_list,
)

Compiling model     
Done!


In [6]:
# Show the name, shape and dtype the loaded model reports for each input.
for tensor in model.input_metadata:
    print(f"name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}")

name: input_ids, shape: [1, 128], dtype: DType.int64
name: attention_mask, shape: [1, 128], dtype: DType.int64


In [7]:
INPUT = "There are many exciting developments in the field of AI Infrastructure!"

tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)

# Pad (and, if needed, truncate) to `seqlen` tokens so the encoded tensors
# have the same length as the tensors the model was traced with.
# NOTE: this rebinds `inputs`, replacing the tracing placeholders.
inputs = tokenizer(
    INPUT,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=seqlen,
)
print(inputs)



tokenizer_config.json:   0%|          | 0.00/409 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/280 [00:00<?, ?B/s]

{'input_ids': tensor([[    0,   970,    32,   171,  3571,  5126,    11,     5,   882,     9,
          4687, 13469,   328,     2,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,  

In [8]:
# Run the loaded model on the tokenized sample; every entry of `inputs` is
# forwarded as a keyword argument, so its keys are presumably expected to
# match the model's input names (`input_ids`, `attention_mask`) — verify
# against `model.input_metadata`.
outputs = model.execute(**inputs)
# The printed result nests the model's `logits` array under `result0`.
print(outputs)

{'result0': {'logits': array([[-3.798778  ,  0.49929348, -4.28773   , -2.5863957 ,  2.9503946 ,
        -2.1120906 ,  2.5074234 , -4.4121103 , -4.901352  , -2.1473591 ,
        -0.5741749 ]], dtype=float32)}}
