In [None]:
import os

from sagemaker import get_execution_role
from sagemaker.huggingface import HuggingFace

In [None]:
# IAM role the training job runs under. Setting SM_EXECUTION_ROLE (e.g. to a
# role ARN when running outside a SageMaker notebook/Studio session) overrides
# the default; otherwise the role is resolved with get_execution_role().
# Note: os.environ is a mapping, so it must be read with .get()/[] — calling it
# like a function raises TypeError.
role = os.environ.get("SM_EXECUTION_ROLE") or get_execution_role()

In [None]:
# Training-script arguments: SageMaker forwards each key/value pair below to the
# entry-point script as a command-line argument.
hyperparameters = dict(
    epochs=1,
    per_device_train_batch_size=32,
    model_name_or_path="distilbert-base-cased",
)

In [None]:
# SageMaker can scrape a training job's logs for metrics and publish them to
# CloudWatch. For each metric to collect, give its name and a regular expression
# that locates its value in the log output. These definitions only take effect
# when passed to the estimator via `metric_definitions=`.
# Raw strings keep `\D` as a literal regex escape rather than an invalid Python
# string escape (DeprecationWarning; SyntaxWarning since Python 3.12).
metric_definitions = [
    {"Name": "train_runtime", "Regex": r"train_runtime.*=\D*(.*?)$"},
    {"Name": "eval_accuracy", "Regex": r"eval_accuracy.*=\D*(.*?)$"},
    {"Name": "eval_loss", "Regex": r"eval_loss.*=\D*(.*?)$"},
]

In [None]:
# Hugging Face estimator: runs ./src/start.py on a single ml.p3.2xlarge instance
# in the Hugging Face training container selected by the version triple below.
huggingface_estimator = HuggingFace(
    entry_point="start.py",
    source_dir="./src",
    instance_type="ml.p3.2xlarge",
    instance_count=1,
    role=role,
    # The transformers 4.28.1 / PyTorch 2.0 training image is published for
    # Python 3.10 only; "py39" fails when the SDK resolves the image URI.
    transformers_version="4.28.1",
    pytorch_version="2.0",
    py_version="py310",
    hyperparameters=hyperparameters,
    # Without this argument the log-scraping regexes defined earlier are never
    # used, so the custom train/eval metrics are not collected.
    metric_definitions=metric_definitions,
)

In [None]:
# Root of the processed datasets; each model has its own sub-prefix holding
# `train/` and `test/` splits (presumably tokenized for that model — confirm
# against the processing job).
DATA_ROOT = "s3://sagemaker-project-p-lo6kmrzwou9t/processed/sample"

# Derive the model-specific prefix from the hyperparameters so the data
# channels cannot drift out of sync with the model being fine-tuned.
data_prefix = f"{DATA_ROOT}/{hyperparameters['model_name_or_path']}"

# Launch the training job with the two input channels.
huggingface_estimator.fit(
    {
        "train": f"{data_prefix}/train/",
        "test": f"{data_prefix}/test/",
    }
)