In [None]:
import os

from sagemaker import get_execution_role
from sagemaker.experiments import Run
from sagemaker.huggingface import HuggingFace
from sagemaker.inputs import TrainingInput
from sagemaker.pytorch.estimator import PyTorch

In [None]:
# IAM role used by the training jobs below. A non-empty SM_EXECUTION_ROLE
# environment variable takes precedence (handy when running outside a
# SageMaker-managed notebook); otherwise fall back to get_execution_role().
role = os.environ.get("SM_EXECUTION_ROLE") or get_execution_role()

In [None]:
# Training hyperparameters. The client passes each entry to the training
# script as a command-line argument.
hyperparameters = dict(
    num_train_epochs=2,
    per_device_training_batch_size=16,
    pretrained_model_name_or_path="distilbert-base-cased",
)

In [None]:
# SageMaker metrics automatically parses training job logs for metrics and sends them
# to CloudWatch. If you want SageMaker to parse the logs, you must specify the metric's
# name and a regular expression for SageMaker to use to find the metric.
#
# Every metric uses the same pattern: the metric name, anything up to an "=",
# any non-digit characters, then the captured value up to end of line.
# Raw strings keep "\D" from being treated as an (invalid) string escape.
metric_definitions = [
    {"Name": metric, "Regex": rf"{metric}.*=\D*(.*?)$"}
    for metric in ("train_runtime", "eval_accuracy", "eval_loss")
]

In [None]:
# NOTE(review): despite its name, this is a plain PyTorch framework estimator,
# and it is never fit: the Run cell further down rebinds
# `huggingface_estimator` to a HuggingFace estimator before calling .fit().
# Consider removing this cell or renaming the variable.
huggingface_estimator = PyTorch(
    # Training code: ./src/start.py
    source_dir="./src",
    entry_point="start.py",
    # Framework container versions
    framework_version="1.13.1",
    py_version="py39",
    # Compute and permissions
    instance_count=1,
    instance_type="ml.c5.2xlarge",
    role=role,
    # Script arguments and log-parsed metrics
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
)

In [None]:
# S3 prefix of the preprocessed dataset; the train/ and test/ channels live under it.
# NOTE(review): the last segment matches pretrained_model_name_or_path — presumably
# the data was tokenized for that model; keep the two in sync (TODO confirm).
data_location = "/".join(
    [
        "s3://sagemaker-mlops2023",
        "document-classification",
        "processed",
        "sample",
        "distilbert-base-cased",
    ]
)

In [None]:
# Training inputs, keyed by channel name. Each value may be an S3 URI string,
# a TrainingInput(), or a FileSystemInput(); here every channel is a
# TrainingInput pointing at <data_location>/<channel>/.
# https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Framework.fit
inputs = {
    channel: TrainingInput(s3_data=f"{data_location}/{channel}/")
    for channel in ("train", "test")
}

In [None]:
# Launch training inside a SageMaker Experiments run; the hyperparameters are
# recorded on the run before the job starts.
with Run(experiment_name="document-classification-pm-test") as run:
    run.log_parameters(parameters=hyperparameters)
    huggingface_estimator = HuggingFace(
        entry_point="start.py",
        source_dir="./src",
        instance_type="ml.p2.xlarge",
        instance_count=1,
        role=role,
        transformers_version="4.26.0",
        pytorch_version="1.13.1",
        py_version="py39",
        # Pass the log-parsing regexes so this job (the one actually fit)
        # emits train_runtime / eval_accuracy / eval_loss metrics.
        metric_definitions=metric_definitions,
        hyperparameters=hyperparameters,
    )
    huggingface_estimator.fit(inputs=inputs)