# WeNet Training on a SageMaker Training Job

In [None]:
# ## Update sagemaker python sdk version
!pip install -U sagemaker

## Package training image
- 复制下面的命令在 SageMaker notebook terminal 界面运行，完成训练镜像的打包

In [None]:
# bash build_and_push.sh

## Set model, Code and data

In [None]:
import sagemaker
from sagemaker import get_execution_role

# Create the SageMaker session and resolve the execution role of this notebook.
sess = sagemaker.Session()
role = get_execution_role()

# Values read from the session; later cells use them to build S3 and ECR URIs.
sagemaker_default_bucket = sess.default_bucket()
region = sess.boto_session.region_name
account_id = sess.account_id()

# Echo the resolved values in the cell output.
for label, value in (
    ("sagemaker_default_bucket:", sagemaker_default_bucket),
    ("sagemaker_region:", region),
    ("account_id:", account_id),
):
    print(label, value)

## Upload pretrained models to S3

In [None]:
!pip install -r wenet_src/requirements.txt

In [None]:
# Code language: python
from huggingface_hub import snapshot_download
from pathlib import Path

model_name = "FireRedTeam/FireRedASR-AED-L"
# Last component of the repo id, i.e. "FireRedASR-AED-L".
model_file = model_name.split("/")[-1]
# Directory passed as --output_dir to the conversion script in a later cell.
wenet_weight_path = f"{model_file}_wenet"
# Local directory used as the Hugging Face cache for this model.
local_cache_path = Path(model_file)
local_cache_path.mkdir(exist_ok=True)

# "*" matches every file in the repo, so the complete snapshot is downloaded.
allow_patterns = ["*"]

model_download_path = snapshot_download(
    repo_id=model_name,
    cache_dir=local_cache_path,
    allow_patterns=allow_patterns,
)
# snapshot_download returns the snapshot folder itself. Use it directly rather
# than globbing the cache and taking the first match, which is arbitrary when
# more than one revision is cached and raises IndexError when nothing matches.
model_snapshot_path = Path(model_download_path)

In [None]:
!pip install git+https://github.com/wenet-e2e/wenet.git

In [None]:
# Run the FireRed-AED-L -> wenet conversion script: it reads the downloaded
# snapshot directory and writes its output into {wenet_weight_path}.
!python wenet_src/wenet/firered/convert_FireRed_AED_L_to_wenet_config_and_ckpt.py --firered_model_dir {model_snapshot_path} --output_dir {wenet_weight_path}

In [None]:
!sed 's|FireRedASR-AED-L_wenet/|/tmp/model/|g' {wenet_weight_path}/train.yaml > {wenet_weight_path}/train_modefied.yaml

In [None]:
# Recursively upload the converted model directory to the default bucket under
# Foundation-Models/{model_file}; the training cell points MODEL_S3_PATH at this prefix.
!aws s3 cp {wenet_weight_path} s3://{sagemaker_default_bucket}/Foundation-Models/{model_file} --recursive

## Submit Training job

In [None]:
# ECR image of the training container (same account and region as this notebook).
REPO_NAME = "sagemaker-training/wenet"
image_uri = f"{account_id}.dkr.ecr.{region}.amazonaws.com/{REPO_NAME}:latest"
prefix = "wenet-ft"

# Cluster size and instance type. These two variables are the single source of
# truth: they feed both the NODE_NUMBER environment variable and the Estimator.
instance_count = 1
# instance_type = 'ml.p4d.24xlarge' # 8*40G
# instance_type = 'ml.g5.48xlarge'  # 8*24G
instance_type = 'ml.g6e.48xlarge'  # 8*48G
model_s3_checkpoint_path = f"s3://{sagemaker_default_bucket}/finetuned_model/{model_file}_checkpoints/"
# Environment variables passed to the training container
# (presumably consumed by entry.py - confirm there).
environment = {
    'NODE_NUMBER': str(instance_count),
    'MODEL_S3_PATH': f's3://{sagemaker_default_bucket}/Foundation-Models/{model_file}', # source model files
    'MODEL_LOCAL_PATH': '/tmp/model',
    'OUTPUT_MODEL_S3_PATH': f's3://{sagemaker_default_bucket}/finetuned_model/{model_file}', # destination
}

est = sagemaker.estimator.Estimator(
    image_uri,
    role,
    entry_point='entry.py',
    source_dir='wenet_src/',
    environment=environment,
    checkpoint_s3_uri=model_s3_checkpoint_path,
    # SDK v2 parameter names (v1 used train_instance_count / train_instance_type).
    # The variables defined above are passed so NODE_NUMBER always matches the
    # real number of instances.
    instance_count=instance_count,
    instance_type=instance_type,
    # keep_alive_period_in_seconds=3600, # sagemaker warmpool setting
    base_job_name=prefix
)

# Training data channel named "zh".
# NOTE(review): the bucket name is hard-coded - confirm the execution role can read it.
input_channel = {'zh': "s3://audio-train-datasets/wenet/zh/"}
est.fit(input_channel)

In [None]:
# List the contents of the S3 prefix that was configured as checkpoint_s3_uri.
!aws s3 ls {model_s3_checkpoint_path}

In [None]:
# Download one checkpoint for local evaluation.
# NOTE(review): epoch_10.pt is hard-coded - pick a file that actually exists under the checkpoint prefix.
!aws s3 cp {model_s3_checkpoint_path}epoch_10.pt ./checkpoints/epoch_10.pt

In [None]:
# Show what is in the local checkpoints directory.
!ls -l ./checkpoints

In [None]:
# Install sox and its development package with yum.
# NOTE(review): `yum update -y` first upgrades every installed system package;
# presumably sox is needed for audio loading during decoding - confirm.
!sudo yum update -y
!sudo yum install -y sox sox-devel

In [None]:
# Full-parameter fine-tuning: decode the test list with the downloaded
# checkpoint (./checkpoints/epoch_10.pt); results are written to ./results.
# The paths inside the yaml file under the checkpoints directory need to be modified first.
!python wenet_src/wenet/bin/recognize.py --config {wenet_weight_path}/train.yaml  --test_data ../wenet_finetuning/data/zh/test_local.list --gpu 0 --device cuda --checkpoint  ./checkpoints/epoch_10.pt --result_dir ./results --modes attention

# LoRA
# ... --use_lora True


In [None]:
# Decode the same test list with the checkpoint from the conversion output
# directory (wenet_firered.pt); results are written to ./results2.
# Presumably this is the baseline before fine-tuning - confirm.
!python wenet_src/wenet/bin/recognize.py --config {wenet_weight_path}/train.yaml  --test_data ../wenet_finetuning/data/zh/test_local.list --gpu 0 --device cuda --checkpoint  {wenet_weight_path}/wenet_firered.pt --result_dir ./results2 --modes attention

In [None]:
import json
def process_test_data(file, output_path="test.ref"):
    """Convert a JSON-lines test list into a tab-separated reference file.

    Every non-blank line of ``file`` is parsed as a JSON object, and its
    ``"key"`` and ``"txt"`` values are written to ``output_path`` as
    ``key<TAB>txt``, one record per line.

    Args:
        file: Path of the JSON-lines input file.
        output_path: Path of the reference file to write. Defaults to
            ``"test.ref"`` in the current working directory.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``"key"`` or ``"txt"`` entry.
    """
    # Both files are opened explicitly as UTF-8 so non-ASCII transcripts do not
    # depend on the locale default. The context managers close both handles
    # even when a line fails to parse.
    with open(file, "r", encoding="utf-8") as rf, \
            open(output_path, "w", encoding="utf-8") as save_ref:
        for line in rf:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no record.
                continue
            data = json.loads(line)
            save_ref.write("\t".join((data["key"], data["txt"])))
            save_ref.write("\n")
    
# Build the reference transcript file (test.ref in the current directory) from the test list.
# NOTE(review): this is an absolute, machine-specific path, and it names test.list
# while the decoding cells use test_local.list - confirm both refer to the same utterances.
# NOTE(review): the compute-wer cells read wenet_src/test.ref - confirm the working directory.
process_test_data("/home/ec2-user/SageMaker/asr_xiaohongshu/wenet_finetuning/data/zh/test.list")

In [None]:
# Score the ./results2 hypotheses (decoded with wenet_firered.pt) against the reference file.
!python wenet_src/tools/compute-wer.py --char=1 --v=1 wenet_src/test.ref ./results2/attention/text

In [None]:
# Score the ./results hypotheses (decoded with ./checkpoints/epoch_10.pt) against the reference file.
!python wenet_src/tools/compute-wer.py --char=1 --v=1 wenet_src/test.ref ./results/attention/text

#### 训练完以后转回 FireRedASR 模型官方格式

In [None]:
# Convert the fine-tuned wenet checkpoint back with the wenet -> FireRed-AED-L script;
# the output goes to weights/full_epoch10.
# NOTE(review): results/training/train.yaml and weights/FireRedwenet_src/ are not
# created by any cell in this notebook - confirm both paths before running.
!python wenet_src/wenet/firered/convert_wenet_to_FireRed_AED_L_ckpt.py --wenet_config_path results/training/train.yaml --wenet_pt_path ./checkpoints/epoch_10.pt --original_fireredaed_dir weights/FireRedwenet_src/ --output_dir weights/full_epoch10