diff --git a/.ci/scripts/test_llama.sh b/.ci/scripts/test_llama.sh
index 30bbefb78b6..4fa8c94905f 100644
--- a/.ci/scripts/test_llama.sh
+++ b/.ci/scripts/test_llama.sh
@@ -9,7 +9,7 @@ set -exu
 # shellcheck source=/dev/null
 source "$(dirname "${BASH_SOURCE[0]}")/utils.sh"
 
-MODEL_NAME=$1 # stories110M.pt
+MODEL_NAME=$1 # stories110M
 BUILD_TOOL=$2 # buck2 or cmake
 DTYPE=$3 # fp16 or fp32
 MODE=${4:-"xnnpack+custom"} # portable or xnnpack+custom or xnnpack+custom+qe
@@ -140,7 +140,7 @@ cmake_build_llama_runner() {
 
 cleanup_files() {
   echo "Deleting downloaded and generated files"
-  rm "${MODEL_NAME}"
+  rm "${CHECKPOINT_FILE_NAME}"
   rm tokenizer.model
   rm tokenizer.bin
   rm "${EXPORTED_MODEL_NAME}"
@@ -159,8 +159,10 @@ prepare_artifacts_upload() {
 
 # Download and create artifacts.
 PARAMS="params.json"
+CHECKPOINT_FILE_NAME=""
 touch "${PARAMS}"
-if [[ "${MODEL_NAME}" == "stories110M.pt" ]]; then
+if [[ "${MODEL_NAME}" == "stories110M" ]]; then
+  CHECKPOINT_FILE_NAME="stories110M.pt"
   download_stories_model_artifacts
 else
   echo "Unsupported model name ${MODEL_NAME}"
@@ -181,7 +183,7 @@ fi
 # Export model.
 EXPORTED_MODEL_NAME="${EXPORTED_MODEL_NAME}.pte"
 echo "Exporting ${EXPORTED_MODEL_NAME}"
-EXPORT_ARGS="-c stories110M.pt -p ${PARAMS} -d ${DTYPE} -n ${EXPORTED_MODEL_NAME} -kv"
+EXPORT_ARGS="-c ${CHECKPOINT_FILE_NAME} -p ${PARAMS} -d ${DTYPE} -n ${EXPORTED_MODEL_NAME} -kv"
 if [[ "${XNNPACK}" == "ON" ]]; then
   EXPORT_ARGS="${EXPORT_ARGS} -X -qmode 8da4w -G 128"
 fi
diff --git a/.github/workflows/android-perf.yml b/.github/workflows/android-perf.yml
index e1c6206123a..78cd342c874 100644
--- a/.github/workflows/android-perf.yml
+++ b/.github/workflows/android-perf.yml
@@ -156,14 +156,14 @@ jobs:
         BUILD_MODE="cmake"
         DTYPE="fp32"
 
-        if [[ ${{ matrix.model }} == "llama*" ]]; then
+        if [[ ${{ matrix.model }} == "stories*" ]]; then
           # Install requirements for export_llama
           PYTHON_EXECUTABLE=python bash examples/models/llama2/install_requirements.sh
           # Test llama2
           if [[ ${{ matrix.delegate }} == "xnnpack" ]]; then
             DELEGATE_CONFIG="xnnpack+custom+qe"
           fi
-          PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh "${{ matrix.model }}.pt" "${BUILD_MODE}" "${DTYPE}" "${DELEGATE_CONFIG}" "${ARTIFACTS_DIR_NAME}"
+          PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh "${{ matrix.model }}" "${BUILD_MODE}" "${DTYPE}" "${DELEGATE_CONFIG}" "${ARTIFACTS_DIR_NAME}"
         else
           PYTHON_EXECUTABLE=python bash .ci/scripts/test.sh "${{ matrix.model }}" "${BUILD_MODE}" "${{ matrix.delegate }}" "${ARTIFACTS_DIR_NAME}"
         fi
diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index 61d79340e6f..3e346c716e7 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -112,7 +112,7 @@ jobs:
         # Install requirements for export_llama
         PYTHON_EXECUTABLE=python bash examples/models/llama2/install_requirements.sh
         # Test llama2
-        PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh stories110M.pt "${BUILD_TOOL}" "${DTYPE}" "${MODE}"
+        PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh stories110M "${BUILD_TOOL}" "${DTYPE}" "${MODE}"
 
   test-llama-runner-linux-android:
     name: test-llama-runner-linux-android
@@ -406,4 +406,4 @@ jobs:
         # Install requirements for export_llama
         PYTHON_EXECUTABLE=python bash examples/models/llama2/install_requirements.sh
         # Test llama2
-        PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh stories110M.pt "${BUILD_TOOL}" "${DTYPE}" "${MODE}"
+        PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh stories110M "${BUILD_TOOL}" "${DTYPE}" "${MODE}"
diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml
index 969ea3d361c..450f393687e 100644
--- a/.github/workflows/trunk.yml
+++ b/.github/workflows/trunk.yml
@@ -269,7 +269,7 @@ jobs:
         # Install requirements for export_llama
         PYTHON_EXECUTABLE=python ${CONDA_RUN} bash examples/models/llama2/install_requirements.sh
         # Test llama2
-        PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/test_llama.sh stories110M.pt "${BUILD_TOOL}" "${DTYPE}" "${MODE}"
+        PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/test_llama.sh stories110M "${BUILD_TOOL}" "${DTYPE}" "${MODE}"
 
   test-qnn-model:
     name: test-qnn-model