In [None]:
!pip install wandb

In [None]:
import wandb

# Authenticate this notebook with Weights & Biases. `wandb.login()` only stores and
# validates the API key; `wandb.init()` would additionally open a tracking run that
# this launcher notebook never logs to or finishes.
wandb.login()

In [None]:
# Export the W&B credentials for the training container. Per the W&B SageMaker
# integration this writes a `secrets.env` file under `path`; "./" is the same
# directory the estimator uploads as `source_dir`.
# NOTE(review): the generated file holds the API key — keep it out of version control.
wandb.sagemaker_auth(path="./")

In [None]:
# Dataset: the name doubles as the last component of the S3 prefix used as training input.
dataset_name = "alpaca-test"
training_dataset_path = f"s3://unwind.dev.data/llm/{dataset_name}/"

# Base checkpoint to fine-tune; forwarded to the training job as the `model_name`
# hyperparameter (presumably a Hugging Face Hub model id — confirm against train.py).
pretrained_model_name = "decapoda-research/llama-7b-hf"

# W&B project name, forwarded to the training job as the `project_name` hyperparameter.
# NOTE(review): "alpoca" looks like a typo for "alpaca" (compare `dataset_name`);
# left unchanged because renaming would move runs to a different W&B project — confirm.
WANDB_PROJECT_NAME = "alpoca-test"

In [None]:
# Prefix that identifies this experiment series; it becomes the leading part of the job name.
base_job_prefix = "llama-01"

# Values handed to the training job as hyperparameters.
hyperparameters = dict(
    epochs=10,
    model_name=pretrained_model_name,
    learning_rate=1e-6,
    warmup_step_ratio=0.3,
)

In [None]:
from sagemaker import get_execution_role, Session

# from sagemaker.huggingface import HuggingFace

from sagemaker.pytorch import PyTorch

import sagemaker
import boto3
import os

from dotenv import load_dotenv

# Load local settings (such as AWS_ROLE_NAME) from the .env file next to this notebook.
load_dotenv("./.env")

# Resolve the ARN of the IAM role the training job will assume. Fail early with a
# clear message when the variable is missing instead of letting boto3 reject a
# `None` role name with a generic parameter-validation error.
role_name = os.getenv("AWS_ROLE_NAME")
if not role_name:
    raise RuntimeError(
        "AWS_ROLE_NAME is not set; define it in ./.env or in the environment."
    )

iam_client = boto3.client("iam")
role = iam_client.get_role(RoleName=role_name)["Role"]["Arn"]
# Alternative when an execution role is attached to the notebook environment:
# role = get_execution_role()

# SageMaker session, used here to look up the default bucket.
sess = Session()

# Replace None with a bucket name to override the session's default bucket.
sagemaker_session_bucket = None
if sagemaker_session_bucket is None:
    sagemaker_session_bucket = sess.default_bucket()

# SageMaker job names may contain only ASCII letters, digits and hyphens, so map
# every other character (e.g. "/", "_" or "." in a model id) to "-" rather than
# handling "/" alone, and drop leading/trailing hyphens.
raw_job_name = (
    f"{base_job_prefix}-{dataset_name}-{hyperparameters.get('model_name', '')}"
)
base_job_name = "".join(
    ch if ch.isascii() and ch.isalnum() else "-" for ch in raw_job_name
).strip("-")

# Forward the run grouping to the training job as extra hyperparameters
# (presumably consumed by train.py as the W&B group / project — confirm).
hyperparameters["group_name"] = base_job_name
hyperparameters["project_name"] = WANDB_PROJECT_NAME

# Checkpoint prefix in the session bucket; it is derived from the job-name prefix,
# so every job launched with the same prefix uses the same checkpoint location.
checkpoint_s3_uri = f"s3://{sagemaker_session_bucket}/{base_job_name}/checkpoints"

# Container environment: requirements file path, relative to the estimator's `source_dir`.
env = dict(SAGEMAKER_REQUIREMENTS="requirements.txt")

# smdistributed Data Parallel settings; only takes effect if `distribution=distribution`
# is passed to the estimator.
distribution = {
    "smdistributed": {
        "dataparallel": {"enabled": True},
    },
}

# Spot limits in seconds: up to five days of training, plus one extra hour of
# allowance on top of that for the total wait.
max_run = 5 * 24 * 60 * 60
max_wait = 3600 + max_run

# Estimator for a single-instance managed spot training job that runs train.py
# from the current directory with the PyTorch 1.13 / py39 framework settings.
hf_estimator = PyTorch(
    # code, container and job inputs
    source_dir=".",
    entry_point="train.py",
    framework_version="1.13",
    py_version="py39",
    env=env,
    hyperparameters=hyperparameters,
    # identity and naming
    role=role,
    base_job_name=base_job_name,
    # cluster: one instance; for multi-node data parallel training switch to the
    # commented-out settings below
    instance_type="ml.g5.xlarge",
    instance_count=1,
    # instance_count=2,
    # distribution=distribution,
    # managed spot training
    use_spot_instances=True,
    checkpoint_s3_uri=checkpoint_s3_uri,
    max_run=max_run,  # upper bound on training time
    max_wait=max_wait,  # upper bound on spot wait + training time
)

In [None]:
# Display the derived job-name prefix as a quick sanity check before launching the job.
base_job_name

In [None]:
# Launch the training job; the "train" channel points at the dataset prefix in S3.
# wait=False submits the job and returns without blocking the notebook on it.
hf_estimator.fit(
    inputs={"train": training_dataset_path},
    wait=False,
    logs="Rules",
)