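# Multi-Node Training on a SageMaker Training Job

This notebook does four things:

- Downloads `deepseek-ai/deepseek-coder-6.7b-base` from the Hugging Face Hub.
- Copies the model to the session's default S3 bucket.
- Configures Weights & Biases.
- Launches a SageMaker PyTorch training job that runs `llama_factory/entry.py`.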

In [None]:
# ## Update sagemaker python sdk version
!pip install -U sagemaker

## Set up the SageMaker session, execution role, and default S3 bucket

In [None]:
# Open a SageMaker session and resolve the values later cells depend on:
# - the IAM execution role,
# - the session's default S3 bucket (the staging area for the model files),
# - the AWS region.
import sagemaker
from sagemaker import get_execution_role

sess = sagemaker.Session()
role = get_execution_role()
sagemaker_default_bucket = sess.default_bucket()
region = sess.boto_session.region_name

# Echo the resolved bucket and region so the run log records where artifacts go.
for _label, _value in (
    ("sagemaker_default_bucket:", sagemaker_default_bucket),
    ("sagemaker_region:", region),
):
    print(_label, _value)

## Upload the pretrained model to S3

In [None]:
# Download the base model from the Hugging Face Hub into a local cache folder.
from huggingface_hub import snapshot_download
from pathlib import Path

local_cache_path = Path("./deepseek_coder")
local_cache_path.mkdir(parents=True, exist_ok=True)

model_name = "deepseek-ai/deepseek-coder-6.7b-base"

# Download every file in the repo: weights, config and tokenizer files.
# To save space, narrow this, e.g. to ["*.json", "*.safetensors"].
# Confirm first that the training script can load the narrower file set.
allow_patterns = ["*"]

model_download_path = snapshot_download(
    repo_id=model_name,
    cache_dir=local_cache_path,
    allow_patterns=allow_patterns,
)
# snapshot_download returns the folder for exactly the revision it fetched.
# Globbing the cache (`**/snapshots/*`) could pick a stale revision when
# several are cached, and would raise IndexError if nothing matched.
model_snapshot_path = Path(model_download_path)

In [None]:
!aws s3 cp {model_snapshot_path} s3://{sagemaker_default_bucket}/Foundation-Models/deepseek_coder --recursive

## Set up Weights & Biases (wandb)

In [None]:
!pip install wandb

In [None]:
# Log this notebook session in to Weights & Biases.
# The training-job cell calls wandb.sagemaker_auth(), which presumably needs
# this login to find the API key. Verify against the wandb docs.
import wandb
wandb.login()

## Submit the training job

In [None]:
# Configure and launch the SageMaker PyTorch training job that runs
# llama_factory/entry.py on `instance_count` nodes.
from sagemaker.estimator import Estimator
from sagemaker.pytorch import PyTorch
from datetime import datetime


# ---- Cluster configuration ------------------------------------------------
instance_count = 1  # number of nodes; increase for multi-node training
# instance_type = 'ml.p4d.24xlarge'  ## 8*40G
instance_type = 'ml.g5.48xlarge'  ## 8*24G
max_time = 200000  # passed as max_run, in seconds (~55.5 hours)

# Timestamp printed for reference only; it is not used in the job name.
current_time = datetime.now()
formatted_time = current_time.strftime("%Y%m%d%H%M%S")
print(formatted_time)

# WARNING: according to the wandb SageMaker integration docs, sagemaker_auth
# writes the local W&B API key into a secrets.env file inside llama_factory/.
# That directory is source_dir below, so the key is uploaded to S3 with the code.
# Restrict access to that bucket and never commit the generated file.
wandb.sagemaker_auth(path="llama_factory/")

# ---- Job definition -------------------------------------------------------
base_job_name = 'deepseek6.7B-finetune'
# Exported to the training container as environment variables.
# Presumably consumed by llama_factory/entry.py; verify names against that script.
environment = {
    'NODE_NUMBER': str(instance_count),
    'MODEL_S3_PATH': f's3://{sagemaker_default_bucket}/Foundation-Models/deepseek_coder',  # source model files (uploaded above)
    'MODEL_LOCAL_PATH': '/tmp/pretrain_model',
    'OUTPUT_MODEL_S3_PATH': f's3://{sagemaker_default_bucket}/deepseek-coder-6.7b-base/finetuned_model/',  # destination
}

estimator = PyTorch(
    entry_point='entry.py',
    source_dir='llama_factory/',
    role=role,
    base_job_name=base_job_name,
    environment=environment,
    framework_version='2.1.0',
    py_version='py310',
    script_mode=True,
    instance_count=instance_count,
    instance_type=instance_type,
    max_run=max_time,
)

# No input channels are passed here. The model location reaches the job only
# through MODEL_S3_PATH above.
# To attach a dataset, pass a channel mapping, e.g.
#   estimator.fit({'train': 's3://<bucket>/<prefix>/train.json'})
# SageMaker copies each channel to /opt/ml/input/data/<channel_name>
# (here /opt/ml/input/data/train) on every node.
estimator.fit()