In [None]:
!pip install transformers==3.3.1 sagemaker==2.15.0  --quiet

In [None]:
from transformers import (RobertaForSequenceClassification,
                          RobertaTokenizer,
                          AdamW)


# Load the pretrained "roberta-base" tokenizer and a RoBERTa model with a
# sequence-classification head sized for two labels. The forward pass is
# configured not to return attention weights or hidden states.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = RobertaForSequenceClassification.from_pretrained(
    "roberta-base",
    num_labels=2,
    output_attentions=False,
    output_hidden_states=False,
)


In [None]:
import os

# Local directories:
# - `model_path` receives the saved Hugging Face artifacts.
# - `code_path` holds the inference script (displayed and packed in later cells).
model_path = 'model/'
code_path = 'code/'

# Create the output directory if it is missing. `exist_ok=True` keeps the cell
# safe to re-run and avoids the gap between an existence check and the create call.
os.makedirs(model_path, exist_ok=True)

# Write the model and tokenizer files into `model_path`.
model.save_pretrained(save_directory=model_path)
tokenizer.save_pretrained(save_directory=model_path)

In [None]:
!pygmentize code/inference.py

In [None]:
import tarfile

# Bundle everything the endpoint needs into one gzip-compressed tarball:
# - `model.pth` from the working directory,
# - the saved Hugging Face model directory,
# - the directory containing the inference script.
# NOTE(review): nothing in this notebook creates `model.pth`. It must already
# exist in the working directory, or `tar.add` raises on a fresh kernel.
# Confirm where it comes from.
# NOTE(review): the archive is created inside `model_path`, the same directory
# that is being archived.
zipped_model_path = os.path.join(model_path, "model.tar.gz")

with tarfile.open(zipped_model_path, "w:gz") as tar:
    for member in ('model.pth', model_path, code_path):
        tar.add(member)

In [None]:
from sagemaker.pytorch import PyTorchModel
from sagemaker import get_execution_role

# Name of the real-time endpoint; also used by the invocation cell below.
endpoint_name = 'roberta-project-final'

# Wrap the packed tarball as a SageMaker PyTorch model.
# - A distinct variable name keeps the Hugging Face `model` from earlier intact.
#   Rebinding `model` here would make the save cell fail if it were re-run.
# - `source_dir` points at the local directory that holds `inference.py`. The
#   SDK resolves `entry_point` relative to it, rather than looking for
#   `inference.py` in the working directory.
pytorch_model = PyTorchModel(entry_point='inference.py',
                             source_dir=code_path,
                             model_data=zipped_model_path,
                             role=get_execution_role(),
                             framework_version='1.5',
                             py_version='py3')

# Create the endpoint backed by a single ml.m5.xlarge instance.
predictor = pytorch_model.deploy(initial_instance_count=1,
                                 instance_type='ml.m5.xlarge',
                                 endpoint_name=endpoint_name)



In [None]:
import boto3

# Runtime client used to send inference requests to a deployed endpoint.
sm = boto3.client('sagemaker-runtime')

# NOTE(review): "coronavius" looks like a typo for "coronavirus"; left as-is
# because it is the sample input.
prompt = "coronavius is very deadly"

# Send the prompt as UTF-8 bytes. It is presumably parsed by the handler in
# code/inference.py. NOTE(review): confirm that the handler expects 'text/csv'
# for raw text.
response = sm.invoke_endpoint(EndpointName=endpoint_name,
                              Body=prompt.encode(encoding='UTF-8'),
                              ContentType='text/csv')

# `read()` returns the full response payload as bytes. Keep `result` as bytes,
# but decode for display so the prediction prints as text, not a b'...' repr.
result = response['Body'].read()
print(result.decode('utf-8', errors='replace'))