# Hugging Face Sentiment Classification
__Binary Classification with `Trainer` and `sst2` dataset__

## Runtime

This notebook takes approximately 45 minutes to run.

## Contents

1. [Introduction](#Introduction)  
2. [Development environment and permissions](#Development-environment-and-permissions)
    1. [Installation](#Installation)  
    2. [Development environment](#Development-environment)  
    3. [Permissions](#Permissions)
3. [Pre-processing](#Pre-processing)
    1. [Download the dataset](#Download-the-dataset)
    2. [Tokenize sentences](#Tokenize-sentences)
    3. [Upload data to sagemaker_session_bucket](#Upload-data-to-sagemaker_session_bucket)
4. [Fine-tune the model and start a SageMaker training job](#Fine-tune-the-model-and-start-a-SageMaker-training-job)
    1. [Create an Estimator and start a training job](#Create-an-Estimator-and-start-a-training-job)
    2. [Deploy the endpoint](#Deploy-the-endpoint)
5. [Extras](#Extras)
    1. [Estimator Parameters](#Estimator-Parameters)
    2. [Attach a previous training job to an estimator](#Attach-a-previous-training-job-to-an-estimator)

## Introduction

Welcome to our end-to-end binary text classification example. This notebook uses Hugging Face's `transformers` library with the Hugging Face extension of the Amazon SageMaker Python SDK to fine-tune a pre-trained transformer for binary text classification on the `sst2` dataset. To get started, we need to set up the environment with a few prerequisite steps for permissions, configuration, and so on.

This notebook is adapted from Hugging Face's notebook [Huggingface Sagemaker-sdk - Getting Started Demo](https://github.com/huggingface/notebooks/blob/master/sagemaker/01_getting_started_pytorch/sagemaker-notebook.ipynb) and provided here courtesy of Hugging Face.

<img src="text_classification.png" width="700"/>

<i>NOTE: You can run this notebook in SageMaker Studio, a SageMaker notebook instance, or your local machine. This notebook was tested in a notebook instance using the conda\_pytorch\_p36 kernel.</i>


## Development environment and permissions 

### Installation

_*Note:* We install the required libraries from Hugging Face and AWS. You also need PyTorch if it isn't already installed._

In [1]:
!pip install "sagemaker" "transformers" "datasets[s3]" "s3fs" --upgrade

Collecting sagemaker
  Using cached sagemaker-2.103.0-py2.py3-none-any.whl
Collecting transformers
  Downloading transformers-4.18.0-py3-none-any.whl (4.0 MB)
     |████████████████████████████████| 4.0 MB 4.7 MB/s            
Collecting datasets[s3]
  Downloading datasets-2.4.0-py3-none-any.whl (365 kB)
     |████████████████████████████████| 365 kB 75.0 MB/s
Collecting s3fs
  Downloading s3fs-2022.1.0-py3-none-any.whl (25 kB)
Collecting sacremoses
  Downloading sacremoses-0.0.53.tar.gz (880 kB)
     |████████████████████████████████| 880 kB 73.9 MB/s
  Preparing metadata (setup.py) ... done
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
     |████████████████████████████████| 67 kB 8.2 MB/s             
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp36-cp36m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
     |████████████████████████████████| 6.6 MB

  Attempting uninstall: boto3
    Found existing installation: boto3 1.23.10
    Uninstalling boto3-1.23.10:
      Successfully uninstalled boto3-1.23.10
  Attempting uninstall: sagemaker
    Found existing installation: sagemaker 2.102.0
    Uninstalling sagemaker-2.102.0:
      Successfully uninstalled sagemaker-2.102.0
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
fastai 1.0.61 requires nvidia-ml-py3, which is not installed.
spacy 3.0.6 requires pydantic<1.8.0,>=1.7.1, but you have pydantic 1.8.2 which is incompatible.
awscli 1.24.10 requires botocore==1.26.10, but you have botocore 1.23.24 which is incompatible.
Successfully installed aiobotocore-2.1.2 boto3-1.20.24 botocore-1.23.24 datasets-2.4.0 fsspec-2022.1.0 huggingface-hub-0.4.0 importlib-resources-5.4.0 responses-0.17.0 s3fs-2022.1.0 sacremoses-0.0.53 sagemaker-2.103.0 tokenizers-0.12.1 tq

### Development environment 

In [2]:
import sagemaker.huggingface

### Permissions

_If you are going to use SageMaker in a local environment, you need access to an IAM Role with the required permissions for SageMaker. You can read more at [SageMaker Roles](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html)._

In [3]:
import sagemaker

sess = sagemaker.Session()
# The SageMaker session bucket is used for uploading data, models and logs
# SageMaker will automatically create this bucket if it doesn't exist
sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    # Set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

role = sagemaker.get_execution_role()
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"Role arn: {role}")
print(f"Bucket: {sess.default_bucket()}")
print(f"Region: {sess.boto_region_name}")

Role arn: arn:aws:iam::485013493900:role/UserSageMakerRole
Bucket: sagemaker-us-east-2-485013493900
Region: us-east-2


## Pre-processing

We use `pandas` and the `datasets` library to pre-process the `sst2` dataset (Stanford Sentiment Treebank). After pre-processing, the dataset is uploaded to the `sagemaker_session_bucket` for use within the training job. The [sst2](https://nlp.stanford.edu/sentiment/index.html) dataset consists of 67349 training samples and _ testing samples, each a sentence from a movie review labeled as positive or negative.

### Download the dataset

In [4]:
from datasets import Dataset
from transformers import AutoTokenizer
import pandas as pd

# Tokenizer used in pre-processing
tokenizer_name = "distilbert-base-uncased"

# S3 key prefix for the data
s3_prefix = "DEMO-samples/datasets/sst"

# Download the SST2 data from s3
!curl https://sagemaker-sample-files.s3.amazonaws.com/datasets/text/SST2/sst2.test > ./sst2.test
!curl https://sagemaker-sample-files.s3.amazonaws.com/datasets/text/SST2/sst2.train > ./sst2.train
!curl https://sagemaker-sample-files.s3.amazonaws.com/datasets/text/SST2/sst2.val > ./sst2.val

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  189k  100  189k    0     0  1308k      0 --:--:-- --:--:-- --:--:-- 1312k
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 3716k  100 3716k    0     0  20.2M      0 --:--:-- --:--:-- --:--:-- 20.2M
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 94916  100 94916    0     0   907k      0 --:--:-- --:--:-- --:--:--  908k
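Judging from how the next cell parses them, the downloaded files are plain text with one example per line: a numeric label, a single space, and then the sentence. The stdlib-only sketch below mirrors that parsing for a single line. The sample line is invented for illustration, and we assume `0` means negative and `1` means positive, as in the standard SST-2 labels.

```python
from typing import Tuple


def parse_sst2_line(line: str) -> Tuple[int, str]:
    """Split one SST2 line into (label, text) on the first space only."""
    # maxsplit=1 keeps every later space inside the sentence text
    label, text = line.rstrip("\n").split(" ", 1)
    return int(label), text


# Hypothetical line; real lines come from ./sst2.train or ./sst2.test
label, text = parse_sst2_line("1 a gripping , funny film .\n")
print(label, "|", text)  # 1 | a gripping , funny film .
```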


### Tokenize sentences

In [5]:
# Download tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

# Tokenizer helper function
def tokenize(batch):
    return tokenizer(batch["text"], padding="max_length", truncation=True)


# Load dataset. sep="delimiter" does not occur in the data, so each whole line lands in the "line" column
test_df = pd.read_csv("sst2.test", sep="delimiter", header=None, engine="python", names=["line"])
train_df = pd.read_csv("sst2.train", sep="delimiter", header=None, engine="python", names=["line"])

# Split each line on the first space into the label and the sentence text
test_df[["label", "text"]] = test_df["line"].str.split(" ", n=1, expand=True)
train_df[["label", "text"]] = train_df["line"].str.split(" ", n=1, expand=True)

test_df.drop("line", axis=1, inplace=True)
train_df.drop("line", axis=1, inplace=True)

test_df["label"] = pd.to_numeric(test_df["label"], downcast="integer")
train_df["label"] = pd.to_numeric(train_df["label"], downcast="integer")

train_dataset = Dataset.from_pandas(train_df)
test_dataset = Dataset.from_pandas(test_df)

# Tokenize dataset
train_dataset = train_dataset.map(tokenize, batched=True)
test_dataset = test_dataset.map(tokenize, batched=True)

# Set format for pytorch
train_dataset = train_dataset.rename_column("label", "labels")
train_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

test_dataset = test_dataset.rename_column("label", "labels")
test_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

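The `tokenize` helper above calls the tokenizer with `padding="max_length"` and `truncation=True`, so every example ends up with the same number of token ids plus an `attention_mask` that separates real tokens from padding. The sketch below illustrates only that padding and truncation step on already-tokenized ids. It is a simplification, not the tokenizer's implementation: the ids are made up, we assume a padding id of `0`, and unlike the real tokenizer it does not preserve special tokens such as `[SEP]` when truncating. Without an explicit `max_length`, the real tokenizer pads to the model's maximum length (we assume 512 for `distilbert-base-uncased`).

```python
from typing import List, Tuple


def pad_and_truncate(token_ids: List[int], max_length: int, pad_id: int = 0) -> Tuple[List[int], List[int]]:
    """Return (input_ids, attention_mask), both exactly max_length long."""
    # truncation=True: drop any tokens beyond max_length
    ids = token_ids[:max_length]
    # Real tokens get attention_mask 1
    mask = [1] * len(ids)
    # padding="max_length": fill the remainder with pad_id and attention_mask 0
    n_pad = max_length - len(ids)
    return ids + [pad_id] * n_pad, mask + [0] * n_pad


# Illustrative ids only; they do not come from the real vocabulary
ids, mask = pad_and_truncate([101, 7, 8, 102], max_length=8)
print(ids)   # [101, 7, 8, 102, 0, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0, 0, 0]
```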

### Upload data to `sagemaker_session_bucket`

After processing the `datasets`, we use the `FileSystem` [integration](https://huggingface.co/docs/datasets/filesystems.html) to upload the dataset to S3.

In [6]:
import botocore
from datasets.filesystems import S3FileSystem

s3 = S3FileSystem()

# save train_dataset to s3
training_input_path = f"s3://{sess.default_bucket()}/{s3_prefix}/train"
train_dataset.save_to_disk(training_input_path, fs=s3)

# save test_dataset to s3
test_input_path = f"s3://{sess.default_bucket()}/{s3_prefix}/test"
test_dataset.save_to_disk(test_input_path, fs=s3)

In [7]:
print("training_input_path:", training_input_path)
print("test_input_path:", test_input_path)


training_input_path: s3://sagemaker-us-east-2-485013493900/DEMO-samples/datasets/sst/train
test_input_path: s3://sagemaker-us-east-2-485013493900/DEMO-samples/datasets/sst/test


## Fine-tune the model and start a SageMaker training job

To create a SageMaker training job, we need a `HuggingFace` estimator. The estimator handles end-to-end Amazon SageMaker training and deployment tasks. In the estimator, we define which fine-tuning script to use as `entry_point`, which `instance_type` to train on, which `hyperparameters` to pass in, and so on:



```python
from sagemaker.huggingface import HuggingFace

hf_estimator = HuggingFace(
    entry_point="train.py",
    source_dir="./scripts",
    base_job_name="huggingface-sdk-extension",
    instance_type="ml.p3.2xlarge",
    instance_count=1,
    transformers_version="4.4",
    pytorch_version="1.6",
    py_version="py36",
    role=role,
    hyperparameters={
        "epochs": 1,
        "train_batch_size": 32,
        "model_name": "distilbert-base-uncased",
    },
)
```

When we create a SageMaker training job, SageMaker takes care of starting and managing all the required EC2 instances for us with the `huggingface` container, uploads the provided fine-tuning script `train.py`, and downloads the data from the `sagemaker_session_bucket` into the container at `/opt/ml/input/data`. Then, it starts the training job by running:

```bash
/opt/conda/bin/python train.py --epochs 1 --model_name distilbert-base-uncased --train_batch_size 32
```

The `hyperparameters` defined in the `HuggingFace` estimator are passed to the script as named command-line arguments.
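In the command above, the hyperparameter dictionary appears as `--key value` pairs in alphabetical order of the keys. The hypothetical helper below reproduces that mapping for scalar values so the correspondence is easy to check. It is a minimal sketch, not the SageMaker training toolkit's code, and it makes no claim about how the toolkit serializes booleans, lists, or nested values.

```python
from typing import Dict, List, Union

Hyperparameter = Union[int, float, str]


def to_cli_args(hyperparameters: Dict[str, Hyperparameter]) -> List[str]:
    """Turn {"epochs": 1} into ["--epochs", "1"], with keys in sorted order."""
    args: List[str] = []
    for key in sorted(hyperparameters):
        args.extend([f"--{key}", str(hyperparameters[key])])
    return args


hyperparameters = {"epochs": 1, "train_batch_size": 32, "model_name": "distilbert-base-uncased"}
print(" ".join(["train.py", *to_cli_args(hyperparameters)]))
# train.py --epochs 1 --model_name distilbert-base-uncased --train_batch_size 32
```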

SageMaker provides useful properties about the training environment through various environment variables, including the following:

* `SM_MODEL_DIR`: A string representing the path where the training job writes the model artifacts. After training, artifacts in this directory are uploaded to S3 for model hosting.

* `SM_NUM_GPUS`: An integer representing the number of GPUs available to the host.

* `SM_CHANNEL_XXXX`: A string representing the path to the directory that contains the input data for the specified channel. For example, if you specify two input channels named `train` and `test` in the Hugging Face estimator's `fit()` call, the environment variables `SM_CHANNEL_TRAIN` and `SM_CHANNEL_TEST` are set.
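A training script typically combines both sources of configuration: hyperparameters arrive as command-line arguments, and the environment variables listed above supply defaults for paths and hardware. The sketch below shows one way to do that with `argparse`. The notebook's `train.py` is only partially shown, so the argument names `--model_dir`, `--n_gpus`, `--training_dir`, and `--test_dir` and the local fallback paths are our own assumptions, not the script's actual interface.

```python
import argparse
import os
from typing import List, Optional


def parse_training_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    # Hyperparameters passed by SageMaker as command-line arguments
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--train_batch_size", type=int, default=32)
    parser.add_argument("--model_name", type=str, default="distilbert-base-uncased")
    # SageMaker environment variables; the fallbacks are hypothetical local paths
    parser.add_argument("--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR", "./model"))
    parser.add_argument("--n_gpus", type=int, default=int(os.environ.get("SM_NUM_GPUS", "0")))
    parser.add_argument("--training_dir", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "./data/train"))
    parser.add_argument("--test_dir", type=str, default=os.environ.get("SM_CHANNEL_TEST", "./data/test"))
    return parser.parse_args(argv)


args = parse_training_args(["--epochs", "1", "--model_name", "distilbert-base-uncased", "--train_batch_size", "32"])
print(args.epochs, args.train_batch_size, args.model_name)  # 1 32 distilbert-base-uncased
```

Because the environment is read each time the function is called, the same script can run inside a SageMaker container or locally without changes.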


To run the training job locally, you can define `instance_type="local"` or `instance_type="local_gpu"` for GPU usage.

_Note: local mode is not supported in SageMaker Studio._


In [8]:
!pygmentize ./scripts/train.py

[34mfrom[39;49;00m [04m[36mtransformers[39;49;00m [34mimport[39;49;00m AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer
[34mfrom[39;49;00m [04m[36msklearn[39;49;00m[04m[36m.[39;49;00m[04m[36mmetrics[39;49;00m [34mimport[39;49;00m accuracy_score, precision_recall_fscore_support
[34mfrom[39;49;00m [04m[36mdatasets[39;49;00m [34mimport[39;49;00m load_from_disk
[34mimport[39;49;00m [04m[36mrandom[39;49;00m
[34mimport[39;49;00m [04m[36mlogging[39;49;00m
[34mimport[39;49;00m [04m[36msys[39;49;00m
[34mimport[39;49;00m [04m[36margparse[39;49;00m
[34mimport[39;49;00m [04m[36mos[39;49;00m
[34mimport[39;49;00m [04m[36mtorch[39;49;00m

[34mif[39;49;00m [31m__name__[39;49;00m == [33m"[39;49;00m[33m__main__[39;49;00m[33m"[39;49;00m:

    parser = argparse.ArgumentParser()

    [37m# Hyperparameters sent by the client are passed as command-line arguments to the script[39;49;00m
    parser.add_argument([

### Create an Estimator and start a training job

In [9]:
from sagemaker.huggingface import HuggingFace

# Hyperparameters which are passed into the training job
hyperparameters = {"epochs": 1, "train_batch_size": 32, "model_name": "distilbert-base-uncased"}
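With these hyperparameters, the number of optimization steps can be worked out before launching the job. One epoch over 67,349 training examples at a batch size of 32 on the single GPU of an `ml.p3.2xlarge`, with no gradient accumulation, takes ⌈67349 / 32⌉ = 2105 steps, which matches `Total optimization steps = 2105` in the training log. The rate of about 2.08 iterations per second is the approximate value observed in the original run, so the time estimate below is a rough figure for the training loop only, excluding instance startup and image download.

```python
import math

num_train_examples = 67349      # "Num examples" in the training log
train_batch_size = 32           # hyperparameters["train_batch_size"]
epochs = 1                      # hyperparameters["epochs"]
gradient_accumulation_steps = 1
num_devices = 1                 # ml.p3.2xlarge has one GPU

effective_batch_size = train_batch_size * gradient_accumulation_steps * num_devices
steps_per_epoch = math.ceil(num_train_examples / effective_batch_size)
total_steps = steps_per_epoch * epochs

iterations_per_second = 2.08    # approximate rate observed in the original run
minutes = total_steps / iterations_per_second / 60
print(total_steps, round(minutes, 1))  # 2105 16.9
```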

In [10]:
hf_estimator = HuggingFace(
    entry_point="train.py",
    source_dir="./scripts",
    instance_type="ml.p3.2xlarge",
    instance_count=1,
    role=role,
    transformers_version="4.12",
    pytorch_version="1.9",
    py_version="py38",
    hyperparameters=hyperparameters,
)

In [11]:
# Start the training job with the uploaded dataset as input
hf_estimator.fit({"train": training_input_path, "test": test_input_path})

2022-08-08 15:11:13 Starting - Starting the training job...
2022-08-08 15:11:40 Starting - Preparing the instances for training
ProfilerReport-1659971473: InProgress
.........
2022-08-08 15:12:59 Downloading - Downloading input data...
2022-08-08 15:13:37 Training - Downloading the training image..............................
2022-08-08 15:18:38 Training - Training image download completed. Training in progress.
bash: cannot set terminal process group (-1): Inappropriate ioctl for device
bash: no job control in this shell
2022-08-08 15:18:36,617 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training
2022-08-08 15:18:36,640 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.
2022-08-08 15:18:36,648 sagemaker_pytorch_container.training INFO     Invoking user training script.
2022-08-08 15:18:37,304 sagemaker-training-toolkit INFO     Invoking user script
Train

The following columns in the training set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text.
***** Running training *****
  Num examples = 67349
  Num Epochs = 1
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 1
  Total optimization steps = 2105
0%|          | 0/2105 [00:00<?, ?it/s]
96%|█████████▋| 2029/2105 [16:23<00:36,  2.09it/s]
100%|██████████| 29/29 [00:09<00:00,  3.17it/s]
***** Eval results *****
Saving model checkpoint to /opt/ml/model
Configuration saved in /opt/ml/model/config.json
Model weights saved in /opt/ml/model/pytorch_model.bin
tokenizer config file saved in /opt/ml/model/tokenizer_

### Deploy the endpoint

To deploy the endpoint, call `deploy()` on the HuggingFace estimator object, passing in the desired number of instances and instance type.

In [None]:
predictor = hf_estimator.deploy(initial_instance_count=1, instance_type="ml.p3.2xlarge")

------------

Then use the returned predictor object to perform inference.

In [None]:
sentiment_input = {"inputs": "I love using the new Inference DLC."}

predictor.predict(sentiment_input)

In the original run of this notebook, the fine-tuned model classified the test sentence "I love using the new Inference DLC." as having positive sentiment with 98% probability.
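For a two-class model, a probability like this is usually obtained by applying a softmax to the model's two output logits; we assume the score returned by the endpoint is such a softmax probability. The sketch below is a worked example of that conversion, not the inference toolkit's code. The logits are made up and chosen so that the gap between them is ln(49) ≈ 3.89, which gives 49:1 odds, that is, 0.98 for the positive class.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    # Subtract the largest logit first so math.exp cannot overflow
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits in the order [negative, positive]
probs = softmax([0.0, math.log(49)])
print([round(p, 2) for p in probs])  # [0.02, 0.98]
```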

Finally, delete the endpoint.

In [None]:
predictor.delete_endpoint()

## Extras

### Estimator Parameters

In [None]:
print(f"Container image used for training job: \n{hf_estimator.image_uri}\n")
print(f"S3 URI where the trained model is located: \n{hf_estimator.model_data}\n")
print(f"Latest training job name for this estimator: \n{hf_estimator.latest_training_job.name}\n")

In [None]:
hf_estimator.sagemaker_session.logs_for_job(hf_estimator.latest_training_job.name)

### Attach a previous training job to an estimator

In SageMaker, you can attach a previous training job to an estimator, for example to retrieve its results, such as the S3 location of its model artifacts, without running training again.

In [None]:
from sagemaker.estimator import Estimator

# Uncomment the following lines and supply your training job name

# old_training_job_name = "<your-training-job-name>"
# hf_estimator_loaded = Estimator.attach(old_training_job_name)
# hf_estimator_loaded.model_data