# 05. Draft Model training

In [None]:
import os
import sys
from pprint import pprint
import yaml

def _parent_dir(path):
    """Return the absolute, normalized path one level above *path*."""
    return os.path.abspath(os.path.join(path, '..'))


# Put the repository root (two levels above the notebook's working directory)
# on sys.path so the shared `utils` package imported below can be resolved.
current_dir = os.getcwd()
kit_dir = _parent_dir(current_dir)
repo_dir = _parent_dir(kit_dir)
sys.path.append(repo_dir)

from utils.fine_tuning.src.snsdk_wrapper import SnsdkWrapper

## Step-by-step / manual setup

First, instantiate the SambaStudio client.

In [None]:
# Create the SambaStudio client wrapper used by all the following cells.
# NOTE(review): SnsdkWrapper() is called with no arguments, so it presumably
# picks up its connection settings from its own defaults/environment — confirm.
sambastudio_client = SnsdkWrapper()

In [None]:
# Load the draft model training config (path is relative to the notebook's working directory)
config_draft_model_training_yaml = '../05_config_draft_model_training.yaml'

# Open and load the YAML file into a dictionary.
# safe_load only builds plain Python types, never arbitrary objects.
with open(config_draft_model_training_yaml, 'r', encoding='utf-8') as file:
    config_draft_model_training = yaml.safe_load(file)
# Plain print for the heading: pprint would echo the string with quotes.
print('Draft model training:')
pprint(config_draft_model_training)

# Base model checkpoint and dataset used by the training job defined below
model_name = config_draft_model_training['model_checkpoint']['model_name']
dataset_name = config_draft_model_training['dataset']['dataset_name']

### Check model and dataset

In [None]:
# Make sure the configured base model is among the models listed for 'train' jobs.
available_models = [
    model["model_checkpoint_name"]
    for model in sambastudio_client.list_models(filter_job_types=["train"])
]

# Raise instead of assert: asserts are stripped under `python -O` and give no context.
if model_name not in available_models:
    raise ValueError(
        f"Model {model_name!r} is not available for training jobs. "
        f"Available models: {available_models}"
    )

### Check that the dataset is available

In [None]:
# Make sure the configured dataset is among the datasets known to SambaStudio.
available_datasets = [dataset["dataset_name"] for dataset in sambastudio_client.list_datasets()]
# Raise instead of assert: asserts are stripped under `python -O` and give no context.
if dataset_name not in available_datasets:
    raise ValueError(
        f"Dataset {dataset_name!r} was not found. "
        f"Available datasets: {available_datasets}"
    )

### Create a project

#### Set project config

In [None]:
# Project name and description come straight from the `project` section of the config
project = {
    key: config_draft_model_training['project'][key]
    for key in ('project_name', 'project_description')
}

In [None]:
# Create the project. The keys of `project` are exactly the keyword parameters
# of create_project (project_name, project_description), so the dict is unpacked.
sambastudio_client.create_project(**project)

### Set training job config

In [None]:
# Inspect the default hyperparameters reported for the selected model on a training job.
# Uses `model_name` (loaded from the config); the undefined name `model` raised NameError.
hyperparams = sambastudio_client.get_default_hyperparms(model_name, 'train')
pprint(hyperparams)

In [None]:
# Training job definition. `model` and `dataset_name` come from the config loaded
# earlier; the hyperparameters are fixed values for this example run.
job = dict(
    job_name='e2e_draft_model_training_job',
    job_description='Training job description.',
    job_type='train',
    model=model_name,
    model_version='1',
    parallel_instances='1',
    dataset_name=dataset_name,
    load_state=False,
    sub_path='',
    hyperparams=dict(
        batch_size=8,
        do_eval=False,
        eval_steps=50,
        evaluation_strategy='no',
        learning_rate=1e-5,
        logging_steps=1,
        lr_schedule='fixed_lr',
        max_sequence_length=8192,
        num_iterations=100,
        prompt_loss_weight=0.0,
        save_optimizer_state=True,
        save_steps=50,
        skip_checkpoint=False,
        subsample_eval=0.01,
        subsample_eval_seed=123,
        use_token_type_ids=True,
        vocab_size=128256,
        warmup_steps=0,
        weight_decay=0.1,
    ),
)

### Execute training job

In [None]:
# Launch the training job. Every key of `job` matches a keyword parameter of
# run_training_job, so the dict is unpacked. Project and RDU architecture are
# supplied separately.
sambastudio_client.run_training_job(
    project_name=project["project_name"],
    rdu_arch='SN40L-8',
    **job
)

In [None]:
# Report the training job's progress. With wait=False this presumably returns
# right away rather than blocking until completion, so re-run the cell to poll.
sambastudio_client.check_job_progress(
    wait=False,
    verbose=True,
    job_name=job['job_name'],
    project_name=project['project_name'],
)

### Promote Checkpoint

In [None]:
# List the job's checkpoints with sort=True. The checkpoint promoted below is the
# first entry, intended to be the one with the lowest training loss (confirm the
# sort order in SnsdkWrapper.list_checkpoints).
checkpoints = sambastudio_client.list_checkpoints(
    sort=True,
    job_name=job['job_name'],
    project_name=project['project_name'],
)
checkpoints  # displayed as the cell output

#### Promoted checkpoint config

In [None]:
# Set the config of the checkpoint to promote.
# `checkpoints` was listed with sort=True and its first entry is taken as the best one
# (assumes ascending training-loss order — confirm in SnsdkWrapper.list_checkpoints).
# The list is empty while the job has not saved any checkpoint yet (progress is polled
# with wait=False), so fail with a clear message instead of a bare IndexError.
if not checkpoints:
    raise RuntimeError(
        f"No checkpoints found for job {job['job_name']!r}; wait until the training "
        "job has saved at least one checkpoint, then re-run the checkpoint listing."
    )

model_checkpoint = {
    'checkpoint_name': checkpoints[0]['name'],
    'model_name': 'Suzume-Llama-3-8B-Multilingual-Publichealth',
    'model_description': 'finetuned suzume multilingual in public health qa dataset',
    'model_type': 'finetuned'
}

In [None]:
# Promote the selected checkpoint to a model using the parameters in `model_checkpoint`.
# The checkpoint name is stored under 'checkpoint_name'; reading 'name' raised KeyError.
sambastudio_client.promote_checkpoint(
    checkpoint_name=model_checkpoint['checkpoint_name'],
    project_name=project['project_name'],
    job_name=job['job_name'],
    model_name=model_checkpoint['model_name'],
    model_description=model_checkpoint['model_description'],
    model_type=model_checkpoint['model_type']
)

In [None]:
# Confirm the promoted model now shows up in the SambaStudio model list
# (an empty list as cell output means it was not found).
list(
    filter(
        lambda listed: listed['model_checkpoint_name'] == model_checkpoint['model_name'],
        sambastudio_client.list_models(),
    )
)

#### Delete all saved training checkpoints, after promotion (optional)

In [None]:
# Optionally delete every checkpoint listed for the training job.
# NOTE(review): this also deletes the checkpoint that was just promoted —
# confirm the promoted model does not depend on it.
for checkpoint_name in (entry["name"] for entry in checkpoints):
    sambastudio_client.delete_checkpoint(checkpoint_name)