From 451b882aeb2c2276802049e3f70f993f4233d7cb Mon Sep 17 00:00:00 2001 From: bhsueh Date: Sun, 11 Apr 2021 14:00:59 +0800 Subject: [PATCH] feat: Add v1.0 codes --- .clang-format | 37 + .gitignore | 7 + CMakeLists.txt | 283 ++++ Dockerfile | 140 ++ LICENSE | 25 + README.md | 124 +- all_models/transformer/config.pbtxt | 145 +++ build.env | 39 + cmake/Modules/FindNCCL.cmake | 164 +++ cmake/TritonTransformerBackendConfig.cmake.in | 39 + src/libtransformer.cc | 1155 +++++++++++++++++ src/libtriton_transformer.ldscript | 30 + tools/identity_test.py | 180 +++ tools/kill_server.sh | 30 + tools/run_client.sh | 47 + tools/run_server.sh | 42 + 16 files changed, 2486 insertions(+), 1 deletion(-) create mode 100644 .clang-format create mode 100644 .gitignore create mode 100644 CMakeLists.txt create mode 100644 Dockerfile create mode 100644 LICENSE create mode 100644 all_models/transformer/config.pbtxt create mode 100644 build.env create mode 100644 cmake/Modules/FindNCCL.cmake create mode 100644 cmake/TritonTransformerBackendConfig.cmake.in create mode 100644 src/libtransformer.cc create mode 100644 src/libtriton_transformer.ldscript create mode 100644 tools/identity_test.py create mode 100644 tools/kill_server.sh create mode 100755 tools/run_client.sh create mode 100755 tools/run_server.sh diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..98c6497 --- /dev/null +++ b/.clang-format @@ -0,0 +1,37 @@ +--- +BasedOnStyle: Google + +IndentWidth: 2 +ContinuationIndentWidth: 4 +UseTab: Never +MaxEmptyLinesToKeep: 2 + +SortIncludes: true +CompactNamespaces: true +ReflowComments: true + +DerivePointerAlignment: false +PointerAlignment: Left + +AllowShortIfStatementsOnASingleLine: false +AllowShortBlocksOnASingleLine: false +AllowShortFunctionsOnASingleLine: Inline + +AlwaysBreakAfterReturnType: TopLevelDefinitions +AlignAfterOpenBracket: AlwaysBreak +BreakBeforeBraces: Custom +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: true + AfterNamespace: false + AfterStruct: false + AfterUnion: false + BeforeCatch: true + +BinPackArguments: true +BinPackParameters: true +ConstructorInitializerAllOnOneLineOrOnePerLine: false + +IndentCaseLabels: true \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..65b81a2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +/build +/.vscode +*.so +*.run +.clangd +compile_commands.json +../all_models/transformer/1/* \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..a3add7d --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,283 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cmake_minimum_required (VERSION 3.18) + +project(tritontransformerbackend LANGUAGES C CXX) + +# +# Options +# +option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON) +option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON) +set(TRITON_PYTORCH_INCLUDE_PATHS "" CACHE PATH "Paths to Torch includes") +set(TRITON_PYTORCH_LIB_PATHS "" CACHE PATH "Paths to Torch libraries") + +set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo") +set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo") +set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo") + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules) + +# Python.h needed by torch headers. +find_package(Python3 REQUIRED COMPONENTS Development) + +find_package(FasterTransformer) +find_package(CUDA 10.1 REQUIRED) +find_package(MPI REQUIRED) +find_package(NCCL REQUIRED) + +message(STATUS "Found MPI (include: ${MPI_INCLUDE_DIRS}, library: ${MPI_LIBRARIES})") + +if (${CUDA_VERSION} GREATER_EQUAL 11.0) + message(STATUS "Add DCUDA11_MODE") + add_definitions("-DCUDA11_MODE") +endif() + +# +# Dependencies +# +# FetchContent's composibility isn't very good. We must include the +# transitive closure of all repos so that we can override the tag. 
+# +include(FetchContent) + +FetchContent_Declare( + repo-common + GIT_REPOSITORY https://github.com/triton-inference-server/common.git + GIT_TAG ${TRITON_COMMON_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_Declare( + repo-core + GIT_REPOSITORY https://github.com/triton-inference-server/core.git + GIT_TAG ${TRITON_CORE_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_Declare( + repo-backend + GIT_REPOSITORY https://github.com/triton-inference-server/backend.git + GIT_TAG ${TRITON_BACKEND_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_Declare( + repo-ft + GIT_REPOSITORY https://github.com/NVIDIA/FasterTransformer.git + GIT_TAG main + GIT_SHALLOW ON +) +FetchContent_MakeAvailable(repo-common repo-core repo-backend repo-ft) + +# +# CUDA +# +if(${TRITON_ENABLE_GPU}) + find_package(CUDAToolkit REQUIRED) +endif() # TRITON_ENABLE_GPU + +# +# Shared library implementing the Triton Backend API +# +configure_file(src/libtriton_transformer.ldscript libtriton_transformer.ldscript COPYONLY) + +add_library( + triton-transformer-backend SHARED + src/libtransformer.cc +) + +add_library( + TritonTransformerBackend::triton-transformer-backend ALIAS triton-transformer-backend +) + +#find_package(CUDAToolkit REQUIRED) +find_package(CUDA 10.1 REQUIRED) +find_package(MPI REQUIRED) +##find_package(NCCL REQUIRED) +#if (${CUDA_VERSION} GREATER_EQUAL 11.0) +message(STATUS "Add DCUDA11_MODE") +add_definitions("-DCUDA11_MODE") +#endif() + +set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR}) + +target_compile_definitions(triton-transformer-backend + PUBLIC + USE_TRITONSERVER_DATATYPE + BUILD_GPT) + +target_include_directories( + triton-transformer-backend + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src +# ${CMAKE_CURRENT_SOURCE_DIR}/test + ${TRITON_PYTORCH_INCLUDE_PATHS} + ${Python3_INCLUDE_DIRS} + ${MPI_INCLUDE_PATH} + ${repo-ft_SOURCE_DIR} + ) + +target_link_directories( + triton-transformer-backend + PRIVATE + ${CUDA_PATH}/lib64 + ${MPI_Libraries} + ) + +target_compile_features(triton-transformer-backend PRIVATE cxx_std_14) +target_compile_options( + triton-transformer-backend PRIVATE + $<$,$,$>: + -Wall -Wextra -Wno-unused-parameter -Wno-type-limits -Werror> +) + +if(${TRITON_ENABLE_GPU}) + target_compile_definitions( + triton-transformer-backend + PRIVATE TRITON_ENABLE_GPU=1 + ) +endif() # TRITON_ENABLE_GPU + +set_target_properties( + triton-transformer-backend + PROPERTIES + POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME triton_transformer + SKIP_BUILD_RPATH TRUE + BUILD_WITH_INSTALL_RPATH TRUE + INSTALL_RPATH_USE_LINK_PATH FALSE + INSTALL_RPATH "$\{ORIGIN\}" + LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_transformer.ldscript + LINK_FLAGS "-Wl,--no-as-needed,--version-script libtriton_transformer.ldscript" +) + +# Need to turn off unused-but-set-variable due to Torchvision +# Need to turn off unknown-pragmas due to ATen OpenMP +set_target_properties( + triton-transformer-backend + PROPERTIES COMPILE_FLAGS + "-Wno-unknown-pragmas -Wno-unused-but-set-variable" +) + +set(TRITON_PYTORCH_LDFLAGS "") +FOREACH(p ${TRITON_PYTORCH_LIB_PATHS}) + set(TRITON_PYTORCH_LDFLAGS ${TRITON_PYTORCH_LDFLAGS} "-L${p}") +ENDFOREACH(p) + +target_link_libraries( + triton-transformer-backend + PRIVATE + triton-core-serverapi # from repo-core + triton-core-backendapi # from repo-core + triton-core-serverstub # from repo-core + triton-backend-utils # from repo-backend + transformer-shared # from repo-ft + ${TRITON_PYTORCH_LDFLAGS} + ${NCCL_LIBRARIES} + ${MPI_LIBRARIES} + #-ltorch + #-ltorch_cpu + #-ltorch_cuda + #-ltorchvision + #-lc10 + #-lc10_cuda + 
#-lmkl_core + #-lmkl_gnu_thread + #-lmkl_intel_lp64 + #-lmkl_avx2 + #-lmkl_def + #-liomp5 + #-lmkl_intel_thread + #-lmkl_vml_def + #-lmkl_rt + -lcublas + -lcublasLt + -lcudart + -lcurand + #-lnccl + #-lmpi +) + +if(${TRITON_ENABLE_GPU}) + target_link_libraries( + triton-transformer-backend + PRIVATE + CUDA::cudart + ) +endif() # TRITON_ENABLE_GPU + +# +# Install +# +include(GNUInstallDirs) +set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/TritonTransformerBackend) + +install( + TARGETS + triton-transformer-backend + EXPORT + triton-transformer-backend-targets + LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/transformer + ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/transformer +) + +install( + EXPORT + triton-transformer-backend-targets + FILE + TritonTransformerBackendTargets.cmake + NAMESPACE + TritonTransformerBackend:: + DESTINATION + ${INSTALL_CONFIGDIR} +) + +include(CMakePackageConfigHelpers) +configure_package_config_file( + ${CMAKE_CURRENT_LIST_DIR}/cmake/TritonTransformerBackendConfig.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/TritonTransformerBackendConfig.cmake + INSTALL_DESTINATION ${INSTALL_CONFIGDIR} +) + +install( + FILES + ${CMAKE_CURRENT_BINARY_DIR}/TritonTransformerBackendConfig.cmake + DESTINATION ${INSTALL_CONFIGDIR} +) + +# +# Export from build tree +# +export( + EXPORT triton-transformer-backend-targets + FILE ${CMAKE_CURRENT_BINARY_DIR}/TritonTransformerBackendTargets.cmake + NAMESPACE TritonTransformerBackend:: +) + +export(PACKAGE TritonTransformerBackend) diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..dae7fa1 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,140 @@ + + +ARG BASE_IMAGE=nvcr.io/nvidia/tritonserver:21.02-py3 +ARG SDK_IMAGE=nvcr.io/nvidia/tritonserver:21.02-py3-sdk + +FROM ${SDK_IMAGE} AS sdk_image + +FROM ${BASE_IMAGE} as ftbe_sdk +#RUN mkdir /usr/local/mpi +#COPY --from=mpi_image /usr/local/mpi/ /usr/local/mpi/ + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + software-properties-common \ + autoconf \ + automake \ + build-essential \ + docker.io \ + git \ + libre2-dev \ + libssl-dev \ + libtool \ + libboost-dev \ + libcurl4-openssl-dev \ + libb64-dev \ + patchelf \ + python3-dev \ + python3-pip \ + python3-setuptools \ + rapidjson-dev \ + unzip \ + wget \ + zlib1g-dev \ + pkg-config \ + uuid-dev + +RUN pip3 install --upgrade pip && \ + pip3 install --upgrade wheel setuptools docker && \ + pip3 install grpcio-tools grpcio-channelz + +RUN wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | \ + gpg --dearmor - | \ + tee /etc/apt/trusted.gpg.d/kitware.gpg >/dev/null && \ + apt-add-repository 'deb https://apt.kitware.com/ubuntu/ focal main' && \ + apt-get update && \ + apt-get install -y --no-install-recommends \ + cmake-data=3.18.4-0kitware1ubuntu20.04.1 cmake=3.18.4-0kitware1ubuntu20.04.1 + + +################################################################################ +## COPY from Dockerfile.sdk +################################################################################ +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + software-properties-common \ + autoconf \ + automake \ + build-essential \ + curl \ + git \ + libb64-dev \ + libopencv-dev \ + libopencv-core-dev \ + libssl-dev \ + libtool \ + pkg-config \ + python3 \ + python3-pip \ + python3-dev \ + rapidjson-dev \ + vim \ + wget && \ + pip3 install --upgrade wheel setuptools && \ + pip3 install --upgrade grpcio-tools && \ + pip3 install --upgrade pip + +# Build expects "python" 
executable (not python3). +RUN rm -f /usr/bin/python && \ + ln -s /usr/bin/python3 /usr/bin/python +# Install the dependencies needed to run the client examples. These +# are not needed for building but including them allows this image to +# be used to run the client examples. +RUN pip3 install --upgrade numpy pillow +## find install/python/ -maxdepth 1 -type f -name \ +## "tritonclient-*-manylinux1_x86_64.whl" | xargs printf -- '%s[all]' | \ +## xargs pip3 install --upgrade + +# Install DCGM +# DCGM version to install for Model Analyzer +## ARG DCGM_VERSION=2.0.13 +## RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/datacenter-gpu-manager_${DCGM_VERSION}_amd64.deb && \ +## dpkg -i datacenter-gpu-manager_${DCGM_VERSION}_amd64.deb +COPY --from=sdk_image /workspace/datacenter-gpu-manager_*_amd64.deb /tmp/ +RUN dpkg -i /tmp/datacenter-gpu-manager_*_amd64.deb && rm -rf /tmp/datacenter-gpu-manager_*_amd64.deb + +# Install Model Analyzer +ARG TRITON_MODEL_ANALYZER_REPO_TAG=r20.12 +ARG TRITON_MODEL_ANALYZER_REPO="https://github.com/triton-inference-server/model_analyzer@${TRITON_MODEL_ANALYZER_REPO_TAG}" +RUN pip3 install "git+${TRITON_MODEL_ANALYZER_REPO}" +################################################################################ +## COPY from Dockerfile.QA +################################################################################ +RUN apt-get update && apt-get install -y --no-install-recommends \ + libpng-dev \ + curl \ + libopencv-dev \ + libopencv-core-dev \ + libzmq3-dev \ + python3-dev \ + python3-pip \ + python3-protobuf \ + python3-setuptools \ + swig \ + golang-go \ + nginx \ + protobuf-compiler \ + valgrind + +RUN pip3 install --upgrade wheel setuptools && \ + pip3 install --upgrade numpy pillow future grpcio requests gsutil awscli six boofuzz grpcio-channelz azure-cli + +# need protoc-gen-go to generate go specific gRPC modules +RUN go get github.com/golang/protobuf/protoc-gen-go && \ + go get google.golang.org/grpc + +COPY --from=sdk_image /workspace/install/python/tritonclient-2.7.0-py3-none-manylinux1_x86_64.whl /tmp/ +RUN pip3 install --upgrade /tmp/tritonclient-2.7.0-py3-none-manylinux1_x86_64.whl[all] + +RUN mkdir /opt/tritonserver/backends/transformer && chmod 777 /opt/tritonserver/backends/transformer + +FROM ftbe_sdk as ftbe_work +# for debug +RUN apt update -q && apt install -y --no-install-recommends openssh-server zsh tmux mosh locales-all clangd sudo +RUN sed -i 's/#X11UseLocalhost yes/X11UseLocalhost no/g' /etc/ssh/sshd_config +RUN mkdir /var/run/sshd + +## add user because root cannot access mounted user's directories. +RUN useradd --uid 40235 --shell /bin/bash liweim + +ENTRYPOINT service ssh restart && bash diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..e8584b9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+  * Neither the name of NVIDIA CORPORATION nor the names of its
+    contributors may be used to endorse or promote products derived
+    from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/README.md b/README.md
index 9cc652c..83275d4 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,123 @@
-# fastertransformer_backend
\ No newline at end of file
+
+
+# FasterTransformer Backend
+
+The Triton backend for [FasterTransformer](https://github.com/NVIDIA/FasterTransformer). This repository provides a script and recipe to run the highly optimized transformer-based encoder and decoder components; it is tested and maintained by NVIDIA. FasterTransformer v4.0 supports multi-GPU inference of GPT-3 models, and this backend integrates FasterTransformer into Triton so that large GPT-3 models can be served with Triton. The example below shows how to use the FasterTransformer backend in Triton to run inference on a GPT-3 model with 345M parameters trained with [Megatron-LM](https://github.com/NVIDIA/Megatron-LM).
+
+Note that this is a research and prototyping tool, not a formal product or maintained framework. Users can learn more about Triton backends in the [backend repo](https://github.com/triton-inference-server/backend). Ask questions or report problems on the issues page of this fastertransformer_backend repo.
+
+
+
+## Table Of Contents
+
+- [FasterTransformer Backend](#fastertransformer-backend)
+  - [Table Of Contents](#table-of-contents)
+  - [Setup](#setup)
+  - [Run Serving](#run-serving)
+
+## Setup
+
+* Prepare Machine
+
+We provide a Dockerfile, based on the Triton image `nvcr.io/nvidia/tritonserver:21.02-py3`, to set up the environment.
+
+```bash
+mkdir workspace && cd workspace
+git clone https://gitlab-master.nvidia.com/liweim/transformer_backend.git
+nvidia-docker build --tag ft_backend --file transformer_backend/Dockerfile .
+nvidia-docker run --gpus=all -it --rm --volume $HOME:$HOME --volume $PWD:$PWD -w $PWD --name ft-work ft_backend
+cd workspace
+export WORKSPACE=$(pwd)
+```
+
+* Install libraries for Megatron (optional)
+
+```bash
+pip install torch regex fire
+git clone https://github.com/NVIDIA/apex
+cd apex
+pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
+```
+
+* Build FT backend
+
+```bash
+cd $WORKSPACE
+git clone https://github.com/triton-inference-server/server.git
+export PATH=/usr/local/mpi/bin:$PATH
+source transformer_backend/build.env
+mkdir -p transformer_backend/build && cd $WORKSPACE/transformer_backend/build
+cmake -DCMAKE_EXPORT_COMPILE_COMMANDS=1 .. && make -j32
+```
+
+* Prepare model
+
+```bash
+git clone https://github.com/NVIDIA/Megatron-LM.git
+wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json -P models
+wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt -P models
+wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_lm_345m/versions/v0.0/zip -O megatron_lm_345m_v0.0.zip
+mkdir -p models/megatron-models/345m
+unzip megatron_lm_345m_v0.0.zip -d models/megatron-models/345m
+python ../sample/pytorch/utils/megatron_ckpt_convert.py -i ./models/megatron-models/345m/release/ -o ./models/megatron-models/c-model/345m/ -t_g 1 -i_g 8
+python _deps/repo-ft-src/sample/pytorch/utils/megatron_ckpt_convert.py -i ./models/megatron-models/345m/release/ -o ./models/megatron-models/c-model/345m/ -t_g 1 -i_g 8
+cp ./models/megatron-models/c-model/345m/8-gpu $WORKSPACE/transformer_backend/all_models/transformer/1/ -r
+```
+
+## Run Serving
+
+* Run serving directly
+
+```bash
+cp $WORKSPACE/transformer_backend/build/libtriton_transformer.so $WORKSPACE/transformer_backend/build/lib/libtransformer-shared.so /opt/tritonserver/backends/transformer
+cd $WORKSPACE && ln -s server/qa/common .
+# Recommended: increase SERVER_TIMEOUT in common/utils.sh
+cd $WORKSPACE/transformer_backend/build/
+bash $WORKSPACE/transformer_backend/tools/run_server.sh
+bash $WORKSPACE/transformer_backend/tools/run_client.sh
+python _deps/repo-ft-src/sample/pytorch/utils/convert_gpt_token.py --out_file=triton_out # Used for checking the result
+```
+
+* Modify the model configuration
+
+The model configuration for the Triton server is in `all_models/transformer/config.pbtxt`. Users can modify the following hyper-parameters:
+
+- candidate_num: the k value of top-k sampling
+- probability_threshold: the p value of top-p sampling
+- tensor_para_size: size of tensor parallelism
+- layer_para_size: size of layer parallelism
+- layer_para_batch_size: not used by the Triton backend because it only supports a single node, where tensor parallelism is recommended instead
+- max_seq_len: maximum supported sequence length
+- is_half: whether to use half precision (FP16)
+- head_num: number of attention heads
+- size_per_head: size per attention head
+- vocab_size: vocabulary size
+- decoder_layers: number of transformer layers
+- batch_size: maximum supported batch size
+- is_fuse_QKV: whether to fuse the QKV projections into one matrix multiplication; this also depends on the QKV weights
\ No newline at end of file
diff --git a/all_models/transformer/config.pbtxt b/all_models/transformer/config.pbtxt
new file mode 100644
index 0000000..a8f5461
--- /dev/null
+++ b/all_models/transformer/config.pbtxt
@@ -0,0 +1,145 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#   * Redistributions of source code must retain the above copyright
+#     notice, this list of conditions and the following disclaimer.
+#   * Redistributions in binary form must reproduce the above copyright
+#     notice, this list of conditions and the following disclaimer in the
+#     documentation and/or other materials provided with the distribution.
+#   * Neither the name of NVIDIA CORPORATION nor the names of its
+#     contributors may be used to endorse or promote products derived
+#     from this software without specific prior written permission.
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "transformer" +backend: "transformer" +default_model_filename: "gpt3-c-model-8gpu" +max_batch_size: 128 +input [ + { + name: "INPUT_ID" + data_type: TYPE_UINT32 + dims: [ -1 ] + }, + { + name: "REQUEST_INPUT_LEN" + data_type: TYPE_UINT32 + dims: [ 1 ] + }, + { + name: "REQUEST_OUTPUT_LEN" + data_type: TYPE_UINT32 + dims: [ 1 ] + } +] +output [ + { + name: "OUTPUT0" + data_type: TYPE_UINT32 + dims: [ -1 ] + } +] +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] + +parameters { + key: "candidate_num" + value: { + string_value: "1" + } +} +parameters { + key: "probability_threshold" + value: { + string_value: "0.0" + } +} +parameters { + key: "tensor_para_size" + value: { + string_value: "8" + } +} +parameters { + key: "layer_para_size" + value: { + string_value: "1" + } +} +parameters { + key: "layer_para_batch_size" + value: { + string_value: "8" + } +} +parameters { + key: "max_seq_len" + value: { + string_value: "32" + } +} +parameters { + key: "is_half" + value: { + string_value: "1" + } +} +parameters { + key: "head_num" + value: { + string_value: "16" + } +} +parameters { + key: "size_per_head" + value: { + string_value: "64" + } +} +parameters { + key: "vocab_size" + value: { + string_value: "50304" + } +} +parameters { + key: "decoder_layers" + value: { + string_value: "24" + } +} +parameters { + key: "model_name" + value: { + string_value: "gpt_345M" + } +} +parameters { + key: "batch_size" + value: { + string_value: "128" + } +} +parameters { + key: "is_fuse_QKV" + value: { + string_value: "1" + } +} diff --git a/build.env b/build.env new file mode 100644 index 0000000..705f666 --- /dev/null +++ b/build.env @@ -0,0 +1,39 @@ +export NPP_VERSION=11.1.2.301 +export NVIDIA_VISIBLE_DEVICES=all +export DALI_BUILD=1758882 +export CUSOLVER_VERSION=11.0.1.105 +export CUBLAS_VERSION=11.3.0.106 +export HOSTNAME=triton +#export NVIDIA_REQUIRE_CUDA=cuda>=9.0 +export CUFFT_VERSION=10.3.0.105 +export CUDA_CACHE_DISABLE=1 +export NCCL_VERSION=2.8.3 +export CUSPARSE_VERSION=11.3.0.10 +export ENV=/etc/shinit_v2 +export OPENUCX_VERSION=1.9.0 +export NSIGHT_SYSTEMS_VERSION=2020.3.4.32 +export NVIDIA_DRIVER_CAPABILITIES=compute,utility,video +export OMPI_MCA_pml=^ucx +export TRT_VERSION=7.2.2.1 +export CUDA_VERSION=11.1.1.002 +export CURAND_VERSION=10.2.2.105 +export DLPROF_VERSION=20.12 +export TERM=xterm +export TRITON_SERVER_VERSION=2.6.0 +export OPENMPI_VERSION=4.0.5 +export NVJPEG_VERSION=11.3.0.105 +export LIBRARY_PATH=/usr/local/cuda/lib64/stubs: +export SHLVL=1 +export BASH_ENV=/etc/bash.bashrc +export CUDNN_VERSION=8.0.5.43 +export NSIGHT_COMPUTE_VERSION=2020.2.1.8 +export DALI_VERSION=0.28.0 +export NVIDIA_TRITON_SERVER_VERSION=20.12 +export 
LD_LIBRARY_PATH=/usr/local/cuda/compat/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 +export CUDA_DRIVER_VERSION=455.32.00 +export _CUDA_COMPAT_PATH=/usr/local/cuda/compat +export PATH=/usr/local/mpi/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/ucx/bin +export MOFED_VERSION=5.1-2.3.7 +export TRTOSS_VERSION=20.12 +export DEBIAN_FRONTEND=noninteractive +export _=/usr/bin/env diff --git a/cmake/Modules/FindNCCL.cmake b/cmake/Modules/FindNCCL.cmake new file mode 100644 index 0000000..9b5babe --- /dev/null +++ b/cmake/Modules/FindNCCL.cmake @@ -0,0 +1,164 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# From PyTorch: +# +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) +# +# From Caffe2: +# +# Copyright (c) 2016-present, Facebook Inc. All rights reserved. +# +# All contributions by Facebook: +# Copyright (c) 2016 Facebook Inc. +# +# All contributions by Google: +# Copyright (c) 2015 Google Inc. +# All rights reserved. +# +# All contributions by Yangqing Jia: +# Copyright (c) 2015 Yangqing Jia +# All rights reserved. +# +# All contributions by Kakao Brain: +# Copyright 2019-2020 Kakao Brain +# +# All contributions from Caffe: +# Copyright(c) 2013, 2014, 2015, the respective contributors +# All rights reserved. +# +# All other contributions: +# Copyright(c) 2015, 2016 the respective contributors +# All rights reserved. +# +# Caffe2 uses a copyright model similar to Caffe: each contributor holds +# copyright over their contributions to Caffe2. The project versioning records +# all such contribution and copyright details. If a contributor wants to further +# mark their specific copyright on a particular contribution, they should +# indicate their copyright solely in the commit message of the change when it is +# committed. +# +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America +# and IDIAP Research Institute nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Find the nccl libraries +# +# The following variables are optionally searched for defaults +# NCCL_ROOT: Base directory where all NCCL components are foundHong Xu, 1 year ago: • Let CMake handle NCCL detection instead of ou… +# NCCL_INCLUDE_DIR: Directory where NCCL header is foundPieter Noordhuis, 3 years ago: • Bump gloo +# NCCL_LIB_DIR: Directory where NCCL library is found +# +# The following are set after configuration is done: +# NCCL_FOUND +# NCCL_INCLUDE_DIRS +# NCCL_LIBRARIES +# +# The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks +# install NCCL in the same location as the CUDA toolkit. +# See https://github.com/caffe2/caffe2/issues/1601 + +set(NCCL_INCLUDE_DIR $ENV{NCCL_INCLUDE_DIR} CACHE PATH "Folder contains NVIDIA NCCL headers") +set(NCCL_LIB_DIR $ENV{NCCL_LIB_DIR} CACHE PATH "Folder contains NVIDIA NCCL libraries") +set(NCCL_VERSION $ENV{NCCL_VERSION} CACHE STRING "Version of NCCL to build with") + +if ($ENV{NCCL_ROOT_DIR}) + message(WARNING "NCCL_ROOT_DIR is deprecated. Please set NCCL_ROOT instead.") +endif() +list(APPEND NCCL_ROOT $ENV{NCCL_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}) +# Compatible layer for CMake <3.12. NCCL_ROOT will be accounted in for searching paths and libraries for CMake >=3.12. +list(APPEND CMAKE_PREFIX_PATH ${NCCL_ROOT}) + +find_path(NCCL_INCLUDE_DIRS + NAMES nccl.h + HINTS ${NCCL_INCLUDE_DIR}) + +if (USE_STATIC_NCCL) + MESSAGE(STATUS "USE_STATIC_NCCL is set. Linking with static NCCL library.") + SET(NCCL_LIBNAME "nccl_static") + if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified + set(CMAKE_FIND_LIBRARY_SUFFIXES ".a.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif() +else() + SET(NCCL_LIBNAME "nccl") + if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified + set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif() +endif() + +find_library(NCCL_LIBRARIES + NAMES ${NCCL_LIBNAME} + HINTS ${NCCL_LIB_DIR}) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS NCCL_LIBRARIES) + +if(NCCL_FOUND) # obtaining NCCL version and some sanity checks + set (NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h") + message (STATUS "Determining NCCL version from ${NCCL_HEADER_FILE}...") + set (OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES}) + list (APPEND CMAKE_REQUIRED_INCLUDES ${NCCL_INCLUDE_DIRS}) + include(CheckCXXSymbolExists) + check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED) + + if (NCCL_VERSION_DEFINED) + set(file "${PROJECT_BINARY_DIR}/detect_nccl_version.cc") + file(WRITE ${file} " + #include + #include + int main() + { + std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.' 
<< NCCL_PATCH << std::endl; + int x; + ncclGetVersion(&x); + return x == NCCL_VERSION_CODE; + } +") + try_run(NCCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file} + RUN_OUTPUT_VARIABLE NCCL_VERSION_FROM_HEADER + LINK_LIBRARIES ${NCCL_LIBRARIES}) + if (NOT NCCL_VERSION_MATCHED) + message(FATAL_ERROR "Found NCCL header version and library version do not match! \ +(include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.") + endif() + message(STATUS "NCCL version: ${NCCL_VERSION_FROM_HEADER}") + else() + # message(STATUS "NCCL version < 2.3.5-5") + endif () + set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES}) + + message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})") + mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES) +endif() diff --git a/cmake/TritonTransformerBackendConfig.cmake.in b/cmake/TritonTransformerBackendConfig.cmake.in new file mode 100644 index 0000000..4cf09b7 --- /dev/null +++ b/cmake/TritonTransformerBackendConfig.cmake.in @@ -0,0 +1,39 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +include(CMakeFindDependencyMacro) + +get_filename_component( + TRITONPYTORCHBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH +) + +list(APPEND CMAKE_MODULE_PATH ${TRITONPYTORCHBACKEND_CMAKE_DIR}) + +if(NOT TARGET TritonPyTorchBackend::triton-pytorch-backend) + include("${TRITONPYTORCHBACKEND_CMAKE_DIR}/TritonPyTorchBackendTargets.cmake") +endif() + +set(TRITONPYTORCHBACKEND_LIBRARIES TritonPyTorchBackend::triton-pytorch-backend) diff --git a/src/libtransformer.cc b/src/libtransformer.cc new file mode 100644 index 0000000..4cace68 --- /dev/null +++ b/src/libtransformer.cc @@ -0,0 +1,1155 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include +#include +#include + +#pragma GCC diagnostic push +//#pragma GCC diagnostic ignored "-Wsign-compare" +#pragma GCC diagnostic ignored "-Wcast-function-type" +#pragma warning(push, 0) +#include "fastertransformer/gpt.h" +#pragma warning(pop) +#pragma GCC diagnostic pop + +#include "triton/backend/backend_common.h" +#include "triton/backend/backend_input_collector.h" +#include "triton/backend/backend_memory.h" +#include "triton/backend/backend_model.h" +#include "triton/backend/backend_model_instance.h" +#include "triton/backend/backend_output_responder.h" +#include "triton/core/tritonbackend.h" + +#include "fastertransformer/triton_backend/transformer.hpp" +#include "fastertransformer/triton_backend/gpt_triton_backend.hpp" +#include + +// +// PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API. +// + +namespace triton { namespace backend { namespace pytorch { + +#define RESPOND_ALL_AND_RETURN_IF_ERROR(RESPONSES, RESPONSES_COUNT, X) \ + do { \ + TRITONSERVER_Error* raarie_err__ = (X); \ + if (raarie_err__ != nullptr) { \ + SendErrorForResponses(RESPONSES, RESPONSES_COUNT, raarie_err__); \ + return; \ + } \ + } while (false) + +// +// ModelState +// +// State associated with a model that is using this backend. An object +// of this class is created and associated with each +// TRITONBACKEND_Model. +// +class ModelState : public BackendModel { + public: + static TRITONSERVER_Error* Create( + TRITONBACKEND_Model* triton_model, ModelState** state); + virtual ~ModelState() = default; + + // Load a TorchScript model using 'artifact_name' as the name for the + // TorchScript file. Return in 'model_path' the full path to the + // TorchScript file, return in 'torch_model' the Torch Module + // representing the model. 
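+  // Note: for this backend the artifact is a converted FasterTransformer GPT
+  // checkpoint stored under the model version directory, and the created
+  // instance is returned in 'ft_model_instance'.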
+ TRITONSERVER_Error* LoadModel + (const std::string& artifact_name, + const int32_t node_id, + const int32_t device_id, + const cudaStream_t stream, + std::string* model_path, + std::unique_ptr* ft_model_instance); + + int GetGpuSize() {return gpu_size;}; + + private: + ModelState(TRITONBACKEND_Model* triton_model); + TRITONSERVER_Error* AutoCompleteConfig(); + std::shared_ptr ftModel; + int node_id, gpu_size, world_size; + ncclComm_t tensor_nccl_comms[8]; + ncclComm_t layer_nccl_comms[8]; + cudaStream_t streams_[8]; + std::vector nccl_ids; +}; + + +TRITONSERVER_Error* +ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state) +{ + try { + *state = new ModelState(triton_model); + } + catch (const BackendModelException& ex) { + RETURN_ERROR_IF_TRUE( + ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected nullptr in BackendModelException")); + RETURN_IF_ERROR(ex.err_); + } + + // Auto-complete the configuration if requested... + bool auto_complete_config = false; + RETURN_IF_ERROR(TRITONBACKEND_ModelAutoCompleteConfig( + triton_model, &auto_complete_config)); + if (auto_complete_config) { + RETURN_IF_ERROR((*state)->AutoCompleteConfig()); + + triton::common::TritonJson::WriteBuffer json_buffer; + (*state)->ModelConfig().Write(&json_buffer); + + TRITONSERVER_Message* message; + RETURN_IF_ERROR(TRITONSERVER_MessageNewFromSerializedJson( + &message, json_buffer.Base(), json_buffer.Size())); + RETURN_IF_ERROR(TRITONBACKEND_ModelSetConfig( + triton_model, 1 /* config_version */, message)); + } + + return nullptr; // success +} + +ModelState::ModelState(TRITONBACKEND_Model* triton_model) + : BackendModel(triton_model) +{ + triton::common::TritonJson::WriteBuffer buffer; + ModelConfig().PrettyWrite(&buffer); + LOG_MESSAGE( + TRITONSERVER_LOG_WARN, + (std::string("model configuration:\n") + buffer.Contents()).c_str()); + + common::TritonJson::Value param; + model_config_.MemberAsObject("parameters", ¶m); + auto param_get = [&] (const char* field) { + common::TritonJson::Value key; + std::string value; + param.MemberAsObject(field, &key); + key.MemberAsString("string_value", &value); + return value; + }; + auto param_get_int = [&] (const char* field) { + int ret = 0; + try { + ret = std::stoi(param_get(field)); + } catch (std::invalid_argument& ia) { + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, + (std::string("Invalid configuration argument '") + field + "': " + ia.what()).c_str()); + } + return ret; + }; + auto param_get_float = [&] (const char* field) { + float ret = 0.0; + try { + ret = std::stof(param_get(field)); + } catch (std::invalid_argument& ia) { + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, + (std::string("Invalid configuration argument '") + field + "': " + ia.what()).c_str()); + } + return ret; + }; + + auto modelVersionPath = JoinPath({RepositoryPath(), std::to_string(Version()), "/"}); + + if (param_get_int("is_half")) + ftModel.reset(new GptModel + (param_get_int("batch_size"), + param_get_int("candidate_num"), + param_get_int("head_num"), + param_get_int("size_per_head"), + param_get_int("vocab_size"), + param_get_int("max_seq_len"), + param_get_int("decoder_layers"), + param_get_int("tensor_para_size"), + param_get_int("layer_para_size"), + param_get_int("layer_para_batch_size"), + param_get_float("probability_threshold"), + param_get_int("is_fuse_QKV"), + param_get("model_name"), + modelVersionPath)); + else + ftModel.reset(new GptModel + (param_get_int("batch_size"), + param_get_int("candidate_num"), + param_get_int("head_num"), + 
param_get_int("size_per_head"), + param_get_int("vocab_size"), + param_get_int("max_seq_len"), + param_get_int("decoder_layers"), + param_get_int("tensor_para_size"), + param_get_int("layer_para_size"), + param_get_int("layer_para_batch_size"), + param_get_float("probability_threshold"), + param_get_int("is_fuse_QKV"), + param_get("model_name"), + modelVersionPath)); + + node_id = 0; + int tensor_para_size = ftModel->get_tensor_para_size(); + int layer_para_size = ftModel->get_layer_para_size(); + gpu_size = tensor_para_size * layer_para_size; + world_size = gpu_size; + assert(tensor_para_size <= gpu_size); + + nccl_ids = ftModel->create_nccl_ids(world_size); + + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, + (std::string("Model is loaded as'") + ftModel->to_string()).c_str()); + + NCCLCHECK(ncclGroupStart()); + for (int gid=0; gid < gpu_size; gid++) { + + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, + (std::string("enter nccl group") + std::to_string(gid)).c_str()); + int rank = node_id * 8 + gid; + size_t tensor_para_rank = rank % tensor_para_size; + size_t layer_para_rank = rank / tensor_para_size; + + ncclUniqueId tensor_para_nccl_uid = nccl_ids[rank / tensor_para_size]; + ncclUniqueId layer_para_nccl_uid = nccl_ids[world_size / tensor_para_size + rank % tensor_para_size]; + + CUDACHECK(cudaSetDevice(gid)); + NCCLCHECK( ncclCommInitRank(&tensor_nccl_comms[gid], tensor_para_size, tensor_para_nccl_uid, tensor_para_rank)); + NCCLCHECK( ncclCommInitRank(&layer_nccl_comms[gid], layer_para_size, layer_para_nccl_uid, layer_para_rank)); + } + NCCLCHECK(ncclGroupEnd()); + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, + (std::string("Model is loaded as'") + ftModel->to_string()).c_str()); +} + +TRITONSERVER_Error* +ModelState::LoadModel( + const std::string& artifact_name, + const int32_t node_id, + const int32_t device_id, + const cudaStream_t stream, + std::string* model_path, + std::unique_ptr* ft_model_instance) +{ + // Find the TorchScript file that describes the model. If the model + // configuration doesn't have an explicit model file specified then + // use the default name ("model.pt"). 
+ CUDACHECK(cudaSetDevice(device_id)); + std::string cc_model_filename = artifact_name; + if (cc_model_filename.empty()) { + cc_model_filename = "gpt3-model"; + } + + { + size_t free_bytes, total_bytes; + check_cuda_error(cudaMemGetInfo(&free_bytes, &total_bytes)); + float free = (float)(free_bytes) / 1024.0 / 1024.0 / 1024.0; + float total = (float)(total_bytes) / 1024.0 / 1024.0 / 1024.0; + printf("before allocation, free %.2f GB total %.2f GB\n", free, total); + } + + auto path = JoinPath({RepositoryPath(), std::to_string(Version()), cc_model_filename}); + + LOG_MESSAGE(TRITONSERVER_LOG_WARN, + (std::string("Model path ") + path).c_str()); + + cudaStreamCreate(&streams_[device_id]); + auto modelInstance = ftModel->createModelInstance(node_id, device_id, world_size, streams_[device_id]); + auto param_instance = ftModel->createParamInstance(node_id, device_id, world_size, streams_[device_id], nccl_ids); + param_instance->init_nccl_from_comms(tensor_nccl_comms[device_id], layer_nccl_comms[device_id]); + modelInstance->set_param(param_instance.get()); + //{ cannot do warm test for multi-gpu with nccl + // auto input_tensors = prepareRequest(ftModel->get_max_batch_seqlen(), "./start_ids.csv"); + // check_inputs(input_tensors, "in.warm"); + // auto output_tensors = modelInstance->forward(input_tensors); + // check_outputs(output_tensors, "out.warm"); + //} + ft_model_instance->reset(modelInstance.release()); + + *model_path = JoinPath( + {RepositoryPath(), std::to_string(Version()), cc_model_filename}); + + { + size_t free_bytes, total_bytes; + check_cuda_error(cudaMemGetInfo(&free_bytes, &total_bytes)); + float free = (float)(free_bytes) / 1024.0 / 1024.0 / 1024.0; + float total = (float)(total_bytes) / 1024.0 / 1024.0 / 1024.0; + printf("after allocation, free %.2f GB total %.2f GB\n", free, total); + } + return nullptr; // success +} + +TRITONSERVER_Error* +ModelState::AutoCompleteConfig() +{ + // Auto-complete configuration is not supported since PyTorch does not + // store/capture sufficient model metadata so just log error instead. + LOG_MESSAGE( + TRITONSERVER_LOG_WARN, + (std::string("skipping model configuration auto-complete for '") + + Name() + "': not supported for pytorch backend") + .c_str()); + + return nullptr; // success +} + + +// +// ModelInstanceState +// +// State associated with a model instance. An object of this class is +// created and associated with each TRITONBACKEND_ModelInstance. +// +class ModelInstanceState : public BackendModelInstance { + public: + static TRITONSERVER_Error* Create( + ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state); + virtual ~ModelInstanceState(); + + // Get the state of the model that corresponds to this instance. + ModelState* StateForModel() const { return model_state_; } + + // Execute... 
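+  // Gathers the batched inputs from 'requests', runs the FasterTransformer
+  // forward pass on every GPU of this instance, and sends one response per
+  // request.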
+ void ProcessRequests( + TRITONBACKEND_Request** requests, const uint32_t request_count); + + private: + ModelInstanceState( + ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance); + TRITONSERVER_Error* ValidateInputs(); + TRITONSERVER_Error* ValidateOutputs(); + std::shared_ptr> Execute( + std::vector* responses, + const uint32_t response_count, + std::shared_ptr> input_tensors); + void SetInputTensors( + size_t total_batch_size, TRITONBACKEND_Request** requests, + const uint32_t request_count, + std::vector* responses, + BackendInputCollector* collector, std::vector* input_names, + std::shared_ptr>* input_tensors, + std::vector* input_memories, bool* cuda_copy); + void ReadOutputTensors( + size_t total_batch_size, const std::vector& output_names, + std::shared_ptr> output_tensors, + TRITONBACKEND_Request** requests, const uint32_t request_count, + std::vector* responses); + + ModelState* model_state_; + + // The full path to the TorchScript model file. + std::string model_path_; + + + std::unique_ptr ft_model_instance_[8]; + + // Map from configuration name for an input to the index of + // that input in the model. + std::unordered_map input_index_map_; + + // Map from configuration name for an output to the index of + // that output in the model. + std::unordered_map output_index_map_; + std::unordered_map output_dtype_map_; +}; + +TRITONSERVER_Error* +ModelInstanceState::Create( + ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state) +{ + try { + *state = new ModelInstanceState(model_state, triton_model_instance); + } + catch (const BackendModelInstanceException& ex) { + RETURN_ERROR_IF_TRUE( + ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected nullptr in BackendModelInstanceException")); + RETURN_IF_ERROR(ex.err_); + } + + return nullptr; // success +} + +int ThreadLoadModel(ModelState* model_state, + const std::string& artifact_name, + const int32_t node_id, + const int32_t device_id, + const cudaStream_t stream, + std::string* model_path, + std::unique_ptr* ft_model_instance) +{ + THROW_IF_BACKEND_INSTANCE_ERROR + (model_state->LoadModel + (artifact_name, 0, device_id, stream, model_path, ft_model_instance)); + return 0; +} + +ModelInstanceState::ModelInstanceState( + ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance) + : BackendModelInstance(model_state, triton_model_instance), + model_state_(model_state) +{ + //if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) { + LOG_MESSAGE(TRITONSERVER_LOG_WARN, + (std::string("Faster transformer model instance is created at GPU '") + + std::to_string(DeviceId()) + "'").c_str()); + + + LOG_MESSAGE(TRITONSERVER_LOG_WARN, + (std::string("Model name ") + ArtifactFilename()).c_str()); + //} + + THROW_IF_BACKEND_INSTANCE_ERROR(ValidateInputs()); + THROW_IF_BACKEND_INSTANCE_ERROR(ValidateOutputs()); + + std::vector threads; + for(int gid = 0; gid < model_state->GetGpuSize(); gid ++) { + threads.push_back(std::thread(ThreadLoadModel, + model_state, + ArtifactFilename(), 0, gid, CudaStream(), + &model_path_, &ft_model_instance_[gid])); + } + for(auto & t : threads) { + t.join(); + } + + struct cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, DeviceId()); + LOG_MESSAGE(TRITONSERVER_LOG_INFO, + (std::string("Model instance is created on GPU ") + prop.name).c_str()); +} + +ModelInstanceState::~ModelInstanceState() +{ +#ifdef TRITON_ENABLE_GPU +#endif // TRITON_ENABLE_GPU +} + +TRITONSERVER_Error* 
+ModelInstanceState::ValidateInputs() +{ + triton::common::TritonJson::Value ios; + std::string name, data_type; + triton::common::TritonJson::Value jshape; + model_state_->ModelConfig().MemberAsArray("input", &ios); + + for (size_t size = 0; size < ios.ArraySize(); size++){ + triton::common::TritonJson::Value input; + ios.IndexAsObject(size, &input); + input.MemberAsString("name", &name); + LOG_MESSAGE(TRITONSERVER_LOG_WARN, + (std::string("get input name: " + name).c_str())); + input.MemberAsString("data_type", &data_type); + input.MemberAsArray("dims", &jshape); + //TRITONSERVER_DataType type = TRITONSERVER_StringToDataType(data_type.c_str()); + //assert(type == TRITONSERVER_TYPE_UINT32); + + std::vector shape; + for(size_t size = 0; size < jshape.ArraySize(); size++){ + size_t value; + jshape.IndexAsUInt(size, &value); + shape.push_back(value); + } + + // assert(shape.size() == 2); + + LOG_MESSAGE(TRITONSERVER_LOG_WARN, + (std::string("input: ") + name + + ", type: " + data_type + + ", shape: [" + std::to_string(shape[0]) + ", " + std::to_string(shape[1]) + "]").c_str()); + } + return nullptr; // success +} + +TRITONSERVER_Error* +ModelInstanceState::ValidateOutputs() +{ + triton::common::TritonJson::Value ios; + RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("output", &ios)); + + std::string name, data_type; + triton::common::TritonJson::Value jshape; + model_state_->ModelConfig().MemberAsArray("output", &ios); + for (size_t size = 0; size < ios.ArraySize(); size++){ + triton::common::TritonJson::Value input; + ios.IndexAsObject(size, &input); + input.MemberAsString("name", &name); + LOG_MESSAGE(TRITONSERVER_LOG_WARN, + (std::string("get input name: " + name).c_str())); + input.MemberAsString("data_type", &data_type); + input.MemberAsArray("dims", &jshape); + //TRITONSERVER_DataType type = TRITONSERVER_StringToDataType(data_type.c_str()); + //assert(type == TRITONSERVER_TYPE_UINT32); + + std::vector shape; + for(size_t size = 0; size < jshape.ArraySize(); size++){ + size_t value; + jshape.IndexAsUInt(size, &value); + shape.push_back(value); + } + + // assert(shape.size() == 2); + + LOG_MESSAGE(TRITONSERVER_LOG_WARN, + (std::string("input: ") + name + + ", type: " + data_type + + ", shape: [" + std::to_string(shape[0]) + ", " + std::to_string(shape[1]) + "]").c_str()); + } + + return nullptr; // success +} + +void +ModelInstanceState::ProcessRequests( + TRITONBACKEND_Request** requests, const uint32_t request_count) +{ + LOG_MESSAGE( + TRITONSERVER_LOG_WARN, + (std::string("TRITONBACKEND_ModelExecute: Running ") + Name() + " with " + + std::to_string(request_count) + " requests") + .c_str()); + + uint64_t exec_start_ns = 0; + SET_TIMESTAMP(exec_start_ns); + + const int max_batch_size = model_state_->MaxBatchSize(); + + // For each request collect the total batch size for this inference + // execution. The batch-size, number of inputs, and size of each + // input has already been checked so don't need to do that here. + size_t total_batch_size = 0; + for (size_t i = 0; i < request_count; i++) { + // If we get a nullptr request then something is badly wrong. Fail + // and release all requests. 
+ if (requests[i] == nullptr) { + RequestsRespondWithError( + requests, request_count, + TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string( + "null request given to PyTorch backend for '" + Name() + "'") + .c_str())); + return; + } + + if (max_batch_size > 0) { + // Retrieve the batch size from one of the inputs, if the model + // supports batching, the first dimension size is batch size + TRITONBACKEND_Input* input; + TRITONSERVER_Error* err = + TRITONBACKEND_RequestInputByIndex(requests[i], 0 /* index */, &input); + if (err == nullptr) { + const int64_t* shape; + err = TRITONBACKEND_InputProperties( + input, nullptr, nullptr, &shape, nullptr, nullptr, nullptr); + total_batch_size += shape[0]; + } + if (err != nullptr) { + RequestsRespondWithError(requests, request_count, err); + return; + } + } else { + total_batch_size += 1; + } + } + + // If there are no valid payloads then no need to run the inference. + if (total_batch_size == 0) { + return; + } + + LOG_MESSAGE(TRITONSERVER_LOG_WARN, + (std::string("get total batch_size = ") + + std::to_string(total_batch_size)).c_str()); + + // Make sure the maximum batch size is not exceeded. The + // total_batch_size must be 1 for models that don't support batching + // (i.e. max_batch_size == 0). If max_batch_size is exceeded then + // scheduler has done something badly wrong so fail and release all + // requests. + if ((total_batch_size != 1) && (total_batch_size > (size_t)max_batch_size)) { + RequestsRespondWithError( + requests, request_count, + TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string( + "batch size " + std::to_string(total_batch_size) + " for '" + + Name() + "', max allowed is " + std::to_string(max_batch_size)) + .c_str())); + return; + } + + // At this point we are committed to running inference with all + // 'requests'. Create a response for each request. During input + // processing if there is an error with any request that error will + // be sent immediately with the corresponding response (and the + // response unique_ptr will then be nullptr). The request object + // itself will not be released until after all inferencing is done + // (below) as we may need to access the request object when + // determine how to process outputs (for example, even if we don't + // need the outputs for a request that has an error, we do need to + // know the size of those outputs associated with the request so we + // can skip them in the output tensors). + std::vector responses; + responses.reserve(request_count); + + for (size_t i = 0; i < request_count; i++) { + TRITONBACKEND_Response* response; + auto err = TRITONBACKEND_ResponseNew(&response, requests[i]); + if (err == nullptr) { + responses.emplace_back(response); + } else { + responses.emplace_back(nullptr); + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Fail to create response"); + TRITONSERVER_ErrorDelete(err); + } + } + + std::vector input_names; + std::shared_ptr> input_tensors = std::make_shared>(); + std::vector input_memories; + bool cuda_copy = false; + BackendInputCollector collector( + requests, request_count, &responses, model_state_->TritonMemoryManager(), + model_state_->EnablePinnedInput(), CudaStream()); + SetInputTensors( + total_batch_size, requests, request_count, &responses, &collector, + &input_names, &input_tensors, &input_memories, &cuda_copy); + + // Wait for any in-flight input tensor copies to complete. 
+#ifdef TRITON_ENABLE_GPU
+  if (cuda_copy) {
+    cudaStreamSynchronize(CudaStream());
+  }
+#endif
+
+  uint64_t compute_start_ns = 0;
+  SET_TIMESTAMP(compute_start_ns);
+
+  // Run...
+  auto output_tensors = Execute(&responses, request_count, input_tensors);
+
+  uint64_t compute_end_ns = 0;
+  SET_TIMESTAMP(compute_end_ns);
+
+  // Free BackendMemory used for inputs
+  for (BackendMemory* mem : input_memories) {
+    delete mem;
+  }
+  input_memories.clear();
+
+  // Verify output indices are valid with number of outputs after execution.
+  // If Execute failed it returns a null shared_ptr, so guard against that
+  // before dereferencing it.
+  std::vector<std::string> output_names;
+  output_names.push_back("OUTPUT0");
+  bool invalid_index = false;
+  int max_index =
+      (output_tensors != nullptr) ? ((int)output_tensors->size() - 1) : -1;
+  for (const auto& name : output_names) {
+    int op_index = output_index_map_[name];
+    if ((op_index < 0) || (op_index > max_index)) {
+      SendErrorForResponses(
+          &responses, request_count,
+          TRITONSERVER_ErrorNew(
+              TRITONSERVER_ERROR_INVALID_ARG,
+              std::string(
+                  "The output " + std::string(name) +
+                  " in the model configuration refers to an output index which"
+                  " doesn't exist. This model has " +
+                  std::to_string(max_index + 1) + " outputs")
+                  .c_str()));
+      invalid_index = true;
+      break;
+    }
+  }
+
+  if (!invalid_index) {
+    ReadOutputTensors(
+        total_batch_size, output_names, output_tensors, requests, request_count,
+        &responses);
+  }
+
+  uint64_t exec_end_ns = 0;
+  SET_TIMESTAMP(exec_end_ns);
+
+  LOG_MESSAGE(TRITONSERVER_LOG_WARN,
+              (std::string("get response size = ") +
+               std::to_string(responses.size())).c_str());
+
+  // Send all the responses that haven't already been sent because of
+  // an earlier error. Note that the responses are not set to nullptr
+  // here as we need that indication below to determine if the request
+  // was successful or not.
+  for (auto& response : responses) {
+    if (response != nullptr) {
+      LOG_IF_ERROR(
+          TRITONBACKEND_ResponseSend(
+              response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr),
+          "failed to send transformer backend response");
+      LOG_MESSAGE(TRITONSERVER_LOG_WARN,
+                  (std::string("response is sent")).c_str());
+    }
+    else {
+      LOG_MESSAGE(TRITONSERVER_LOG_WARN,
+                  (std::string("response is nullptr")).c_str());
+    }
+  }
+
+  // Report statistics for each request.
+  for (uint32_t r = 0; r < request_count; ++r) {
+    auto& request = requests[r];
+    LOG_IF_ERROR(
+        TRITONBACKEND_ModelInstanceReportStatistics(
+            TritonModelInstance(), request,
+            (responses[r] != nullptr) /* success */, exec_start_ns,
+            compute_start_ns, compute_end_ns, exec_end_ns),
+        "failed reporting request statistics");
+
+    LOG_IF_ERROR(
+        TRITONBACKEND_RequestRelease(request, TRITONSERVER_REQUEST_RELEASE_ALL),
+        "failed releasing request");
+  }
+
+  // Report the entire batch statistics.
+  LOG_IF_ERROR(
+      TRITONBACKEND_ModelInstanceReportBatchStatistics(
+          TritonModelInstance(), total_batch_size, exec_start_ns,
+          compute_start_ns, compute_end_ns, exec_end_ns),
+      "failed reporting batch request statistics");
+}
+
+int
+ThreadForward(
+    std::unique_ptr<AbstractTransformerModelInstance>* ft_model_instance,
+    std::shared_ptr<std::vector<Tensor>>* input_tensors,
+    std::shared_ptr<std::vector<Tensor>>* output_tensors, const int device_id)
+{
+  CUDACHECK(cudaSetDevice(device_id));
+  LOG_MESSAGE(TRITONSERVER_LOG_WARN,
+              (std::string("Start to forward")).c_str());
+  *output_tensors = (*ft_model_instance)->forward(*input_tensors);
+  LOG_MESSAGE(TRITONSERVER_LOG_WARN,
+              (std::string("Stop to forward")).c_str());
+
+  return 0;
+}
+
+std::shared_ptr<std::vector<Tensor>>
+ModelInstanceState::Execute(
+    std::vector<TRITONBACKEND_Response*>* responses,
+    const uint32_t response_count,
+    std::shared_ptr<std::vector<Tensor>> input_tensors)
+{
+  try {
+    const int gpu_size = model_state_->GetGpuSize();
+    check_inputs(input_tensors);
+    std::vector<std::thread> threads;
+    std::vector<std::shared_ptr<std::vector<Tensor>>> output_tensors_list(gpu_size);
+    for (int gid = 0; gid < gpu_size; gid++) {
+      LOG_MESSAGE(TRITONSERVER_LOG_WARN,
+                  (std::string("before ThreadForward " + std::to_string(gid))).c_str());
+      threads.push_back(std::thread(
+          ThreadForward, &ft_model_instance_[gid], &input_tensors,
+          &output_tensors_list[gid], gid));
+      LOG_MESSAGE(TRITONSERVER_LOG_WARN,
+                  (std::string("after ThreadForward " + std::to_string(gid))).c_str());
+    }
+    for (auto& t : threads) {
+      t.join();
+    }
+
+    // Every GPU runs the same inputs; the output from GPU 0 is returned.
+    auto output_tensors = output_tensors_list[0];
+    check_outputs(output_tensors);
+    return output_tensors;
+  }
+  catch (std::exception& ex) {
+    SendErrorForResponses(
+        responses, response_count,
+        TRITONSERVER_ErrorNew(
+            TRITONSERVER_ERROR_INTERNAL,
+            ("transformer backend execute failure: " + std::string(ex.what())).c_str()));
+    return std::shared_ptr<std::vector<Tensor>>(nullptr);
+  }
+}
+
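+// Editorial note (illustrative sketch, not part of the original change): the
+// try/catch in Execute() only covers the calling thread; an exception thrown
+// inside ThreadForward() runs on a worker thread and would terminate the
+// process instead of being turned into an error response. One hedged
+// alternative is std::async (requires <future>), whose future re-throws the
+// worker's exception at the get() call inside the try block:
+//
+//   std::vector<std::future<std::shared_ptr<std::vector<Tensor>>>> futures;
+//   for (int gid = 0; gid < gpu_size; gid++) {
+//     futures.push_back(std::async(std::launch::async, [&, gid] {
+//       CUDACHECK(cudaSetDevice(gid));
+//       return ft_model_instance_[gid]->forward(input_tensors);
+//     }));
+//   }
+//   auto output_tensors = futures[0].get();  // re-throws here if GPU 0 failed
+//   for (int gid = 1; gid < gpu_size; gid++) {
+//     futures[gid].get();
+//   }
+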
+void
+ModelInstanceState::SetInputTensors(
+    size_t total_batch_size, TRITONBACKEND_Request** requests,
+    const uint32_t request_count,
+    std::vector<TRITONBACKEND_Response*>* responses,
+    BackendInputCollector* collector, std::vector<const char*>* input_names,
+    std::shared_ptr<std::vector<Tensor>>* input_tensors,
+    std::vector<BackendMemory*>* input_memories, bool* cuda_copy)
+{
+  const int max_batch_size = model_state_->MaxBatchSize();
+
+  // All requests must have equally-sized input tensors so use any
+  // request as the representative for the input tensors.
+  uint32_t input_count;
+  RESPOND_ALL_AND_RETURN_IF_ERROR(
+      responses, request_count,
+      TRITONBACKEND_RequestInputCount(requests[0], &input_count));
+
+  LOG_MESSAGE(TRITONSERVER_LOG_WARN,
+              (std::string("get input count = ") +
+               std::to_string(input_count)).c_str());
+
+  for (uint32_t input_idx = 0; input_idx < input_count; input_idx++) {
+    TRITONBACKEND_Input* input;
+    RESPOND_ALL_AND_RETURN_IF_ERROR(
+        responses, request_count,
+        TRITONBACKEND_RequestInputByIndex(requests[0], input_idx, &input));
+
+    const char* input_name;
+    TRITONSERVER_DataType input_datatype;
+    const int64_t* input_shape;
+    uint32_t input_dims_count;
+    RESPOND_ALL_AND_RETURN_IF_ERROR(
+        responses, request_count,
+        TRITONBACKEND_InputProperties(
+            input, &input_name, &input_datatype, &input_shape,
+            &input_dims_count, nullptr, nullptr));
+
+    input_names->emplace_back(input_name);
+
+    // The shape for the entire input batch, [total_batch_size, ...]
+    std::vector<int64_t> batchn_shape(
+        input_shape, input_shape + input_dims_count);
+    if (max_batch_size != 0) {
+      batchn_shape[0] = total_batch_size;
+    }
+
+    // The input must be in contiguous CPU/GPU memory.
+    const int64_t batchn_byte_size = GetByteSize(input_datatype, batchn_shape);
+
+    bool device_is_cpu = true;
+
+    std::vector<BackendMemory::AllocationType> alloc_preference;
+    if (device_is_cpu) {
+      alloc_preference = {BackendMemory::AllocationType::CPU};
+    } else {
+      alloc_preference = {BackendMemory::AllocationType::GPU_POOL,
+                          BackendMemory::AllocationType::GPU};
+    }
+
+    BackendMemory* input_memory;
+    RESPOND_ALL_AND_RETURN_IF_ERROR(
+        responses, request_count,
+        BackendMemory::Create(
+            model_state_->TritonMemoryManager(), alloc_preference,
+            device_is_cpu ? 0 : DeviceId(), batchn_byte_size,
+            &input_memory));
+    input_memories->push_back(input_memory);
+
+    TRITONSERVER_MemoryType memory_type = input_memory->MemoryType();
+    int64_t memory_type_id = input_memory->MemoryTypeId();
+    char* input_buffer = input_memory->MemoryPtr();
+
+    collector->ProcessTensor(
+        input_name, input_buffer, batchn_byte_size, memory_type,
+        memory_type_id);
+
+    LOG_MESSAGE(TRITONSERVER_LOG_WARN,
+                (std::string("collect name: ") + input_name +
+                 " size: " + std::to_string(batchn_byte_size)).c_str());
+    (*input_tensors)->push_back(Tensor{
+        TRITONSERVER_MEMORY_CPU, input_datatype, batchn_shape, input_buffer});
+  }
+
+  LOG_MESSAGE(TRITONSERVER_LOG_WARN,
+              (std::string("the data is in ") +
+               (*cuda_copy ? std::string("GPU") : std::string("CPU"))).c_str());
+  // Finalize...
+  *cuda_copy |= collector->Finalize();
+  LOG_MESSAGE(TRITONSERVER_LOG_WARN,
+              (std::string("the data is in ") +
+               (*cuda_copy ? std::string("GPU") : std::string("CPU"))).c_str());
+}
+
+void
+ModelInstanceState::ReadOutputTensors(
+    size_t total_batch_size, const std::vector<std::string>& output_names,
+    std::shared_ptr<std::vector<Tensor>> output_tensors,
+    TRITONBACKEND_Request** requests, const uint32_t request_count,
+    std::vector<TRITONBACKEND_Response*>* responses)
+{
+  BackendOutputResponder responder(
+      requests, request_count, responses, model_state_->MaxBatchSize(),
+      model_state_->TritonMemoryManager(), model_state_->EnablePinnedInput(),
+      CudaStream());
+
+  bool cuda_copy = false;
+  std::vector<std::vector<char>> string_buffers;
+  LOG_MESSAGE(TRITONSERVER_LOG_WARN,
+              (std::string("output name size: ") +
+               std::to_string(output_names.size()) +
+               ", name: " + output_names[0]).c_str());
+
+  for (size_t idx = 0; idx < output_names.size(); idx++) {
+    std::string name = output_names[idx];
+    int op_index = output_index_map_[name];
+    name = "OUTPUT0";
+    op_index = 0;
+    LOG_MESSAGE(TRITONSERVER_LOG_WARN,
+                (std::string("get output_tensors 0")).c_str());
+    auto& output = output_tensors->at(0);
+
+    // Verify output datatype matches datatype from model config
+    TRITONSERVER_DataType output_dtype = output.type;
+    TRITONSERVER_DataType config_datatype = TRITONSERVER_TYPE_UINT32;
+    LOG_MESSAGE(TRITONSERVER_LOG_WARN,
+                (std::string("get output_type: ") +
+                 TRITONSERVER_DataTypeString(output.type) + ", " +
+                 TRITONSERVER_DataTypeString(config_datatype)).c_str());
+    // if (config_datatype != output_dtype) {
+    //   RESPOND_ALL_AND_RETURN_IF_ERROR(
+    //       responses, request_count,
+    //       TRITONSERVER_ErrorNew(
+    //           TRITONSERVER_ERROR_INVALID_ARG,
+    //           (std::string("unexpected datatype TYPE_") +
+    //            TRITONSERVER_DataTypeString(output_dtype) +
+    //            " for inference output '" + name + "', expecting TYPE_" +
+    //            TRITONSERVER_DataTypeString(config_datatype))
+    //               .c_str()));
+    // }
+
+    const char* output_buffer = static_cast<const char*>(output.data);
+
+    // Set output shape
+    std::vector<int64_t>
batchn_shape(output.shape); + + LOG_MESSAGE(TRITONSERVER_LOG_WARN, + (std::string("output shape: [") + std::to_string(batchn_shape[0]) + ", " + std::to_string(batchn_shape[1]) + "]").c_str()); + + responder.ProcessTensor( + name, output_dtype, batchn_shape, output_buffer, + TRITONSERVER_MEMORY_GPU, + DeviceId()); + } + + // Finalize and wait for any pending buffer copies. + cuda_copy |= responder.Finalize(); + +#ifdef TRITON_ENABLE_GPU + if (cuda_copy) { + cudaStreamSynchronize(stream_); + } +#endif // TRITON_ENABLE_GPU + + LOG_MESSAGE(TRITONSERVER_LOG_WARN, + (std::string("PERFORMED GPU copy: ") + (cuda_copy ? std::string("YES") : std::string("NO")) ).c_str()); + +} + +///////////// + +extern "C" { + +TRITONSERVER_Error* +TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) +{ + const char* cname; + RETURN_IF_ERROR(TRITONBACKEND_BackendName(backend, &cname)); + std::string name(cname); + + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("TRITONBACKEND_Initialize: ") + name).c_str()); + + // Check the backend API version that Triton supports vs. what this + // backend was compiled against. + uint32_t api_version_major, api_version_minor; + RETURN_IF_ERROR( + TRITONBACKEND_ApiVersion(&api_version_major, &api_version_minor)); + + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("Triton TRITONBACKEND API version: ") + + std::to_string(api_version_major) + "." + + std::to_string(api_version_minor)) + .c_str()); + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("'") + name + "' TRITONBACKEND API version: " + + std::to_string(TRITONBACKEND_API_VERSION_MAJOR) + "." + + std::to_string(TRITONBACKEND_API_VERSION_MINOR)) + .c_str()); + + if ((api_version_major != TRITONBACKEND_API_VERSION_MAJOR) || + (api_version_minor < TRITONBACKEND_API_VERSION_MINOR)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + (std::string("Triton TRITONBACKEND API version: ") + + std::to_string(api_version_major) + "." + + std::to_string(api_version_minor) + " does not support '" + name + + "' TRITONBACKEND API version: " + + std::to_string(TRITONBACKEND_API_VERSION_MAJOR) + "." + + std::to_string(TRITONBACKEND_API_VERSION_MINOR)) + .c_str()); + } + +// const int tensor_para_size = 1, tensor_para_rank = 1; +// const int layer_para_size = 1, layer_para_rank = 1; +// ncclUniqueId tensor_para_nccl_uid; +// ncclUniqueId layer_para_nccl_uid; +// +// ncclComm_t tensor_para_nccl_comm, layer_para_nccl_comm; +// NCCLCHECK( ncclCommInitRank(&tensor_para_nccl_comm, tensor_para_size, tensor_para_nccl_uid, tensor_para_rank)); +// NCCLCHECK( ncclCommInitRank(&layer_para_nccl_comm, layer_para_size, layer_para_nccl_uid, layer_para_rank)); + + + return nullptr; // success +} + +TRITONSERVER_Error* +TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model) +{ + const char* cname; + RETURN_IF_ERROR(TRITONBACKEND_ModelName(model, &cname)); + std::string name(cname); + + uint64_t version; + RETURN_IF_ERROR(TRITONBACKEND_ModelVersion(model, &version)); + + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("TRITONBACKEND_ModelInitialize: ") + name + " (version " + + std::to_string(version) + ")") + .c_str()); + + // Create a ModelState object and associate it with the + // TRITONBACKEND_Model. 
+  ModelState* model_state;
+  RETURN_IF_ERROR(ModelState::Create(model, &model_state));
+  RETURN_IF_ERROR(
+      TRITONBACKEND_ModelSetState(model, reinterpret_cast<void*>(model_state)));
+
+  return nullptr;  // success
+}
+
+TRITONSERVER_Error*
+TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model* model)
+{
+  void* vstate;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vstate));
+  ModelState* model_state = reinterpret_cast<ModelState*>(vstate);
+
+  LOG_MESSAGE(
+      TRITONSERVER_LOG_INFO, "TRITONBACKEND_ModelFinalize: delete model state");
+
+  delete model_state;
+
+  return nullptr;  // success
+}
+
+TRITONSERVER_Error*
+TRITONBACKEND_ModelInstanceInitialize(TRITONBACKEND_ModelInstance* instance)
+{
+  const char* cname;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceName(instance, &cname));
+  std::string name(cname);
+
+  int32_t device_id;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceDeviceId(instance, &device_id));
+
+  LOG_MESSAGE(
+      TRITONSERVER_LOG_INFO,
+      (std::string("TRITONBACKEND_ModelInstanceInitialize: ") + name +
+       " (device " + std::to_string(device_id) + ")")
+          .c_str());
+
+  // Get the model state associated with this instance's model.
+  TRITONBACKEND_Model* model;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model));
+
+  void* vmodelstate;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vmodelstate));
+  ModelState* model_state = reinterpret_cast<ModelState*>(vmodelstate);
+
+  // Create a ModelInstanceState object and associate it with the
+  // TRITONBACKEND_ModelInstance.
+  ModelInstanceState* instance_state;
+  RETURN_IF_ERROR(
+      ModelInstanceState::Create(model_state, instance, &instance_state));
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceSetState(
+      instance, reinterpret_cast<void*>(instance_state)));
+
+  return nullptr;  // success
+}
+
+TRITONSERVER_Error*
+TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInstance* instance)
+{
+  void* vstate;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, &vstate));
+  ModelInstanceState* instance_state =
+      reinterpret_cast<ModelInstanceState*>(vstate);
+
+  LOG_MESSAGE(
+      TRITONSERVER_LOG_INFO,
+      "TRITONBACKEND_ModelInstanceFinalize: delete instance state");
+
+  delete instance_state;
+
+  return nullptr;  // success
+}
+
+TRITONSERVER_Error*
+TRITONBACKEND_ModelInstanceExecute(
+    TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests,
+    const uint32_t request_count)
+{
+  // Triton will not call this function simultaneously for the same
+  // 'instance'. But since this backend could be used by multiple
+  // instances from multiple models the implementation needs to handle
+  // multiple calls to this function at the same time (with different
+  // 'instance' objects). Suggested practice for this is to use only
+  // function-local and model-instance-specific state (obtained from
+  // 'instance'), which is what we do here.
+  ModelInstanceState* instance_state;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(
+      instance, reinterpret_cast<void**>(&instance_state)));
+  ModelState* model_state = instance_state->StateForModel();
+
+  // This backend specifies BLOCKING execution policy. That means that
+  // we should not return from this function until execution is
+  // complete. Triton will automatically release 'instance' on return
+  // from this function so that it is again available to be used for
+  // another call to TRITONBACKEND_ModelInstanceExecute.
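+  // Editorial note (illustrative, not part of the original change): for a
+  // single model with one instance, Triton drives the entry points defined in
+  // this file roughly in this order:
+  //
+  //   TRITONBACKEND_Initialize(backend);
+  //   TRITONBACKEND_ModelInitialize(model);
+  //   TRITONBACKEND_ModelInstanceInitialize(instance);
+  //   TRITONBACKEND_ModelInstanceExecute(instance, requests, n);  // repeated per batch
+  //   TRITONBACKEND_ModelInstanceFinalize(instance);
+  //   TRITONBACKEND_ModelFinalize(model);
+  //   // TRITONBACKEND_Finalize(backend) would run last if this backend defined it.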
+ + LOG_MESSAGE( + TRITONSERVER_LOG_WARN, + (std::string("model ") + model_state->Name() + ", instance " + + instance_state->Name() + ", executing " + std::to_string(request_count) + + " requests") + .c_str()); + + // At this point we accept ownership of 'requests', which means that + // even if something goes wrong we must still return success from + // this function. If something does go wrong in processing a + // particular request then we send an error response just for the + // specific request. + instance_state->ProcessRequests(requests, request_count); + + return nullptr; // success +} + +} // extern "C" + +}}} // namespace triton::backend::pytorch diff --git a/src/libtriton_transformer.ldscript b/src/libtriton_transformer.ldscript new file mode 100644 index 0000000..61e9a06 --- /dev/null +++ b/src/libtriton_transformer.ldscript @@ -0,0 +1,30 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +{ + global: + TRITONBACKEND_*; + local: *; +}; diff --git a/tools/identity_test.py b/tools/identity_test.py new file mode 100644 index 0000000..9a9d495 --- /dev/null +++ b/tools/identity_test.py @@ -0,0 +1,180 @@ +#!/usr/bin/python + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import numpy as np +import os +import re +import sys +import requests as httpreq +from builtins import range +import tritongrpcclient as grpcclient +import tritonhttpclient as httpclient +from tritonclientutils import np_to_triton_dtype + +FLAGS = None + +START_LEN = 8 +OUTPUT_LEN = 24 +BATCH_SIZE = 8 + +start_id = 220 +end_id = 50256 +# random_start_ids = np.random.randint(0, 50255, size=(BATCH_SIZE, START_LEN), dtype=np.uint32) +random_start_ids = np.array([[9915, 27221, 59, 77, 383, 1853, 3327, 1462], + [6601, 4237, 345, 460, 779, 284, 787, 257], + [59, 77, 611, 7, 9248, 796, 657, 8], + [38, 10128, 6032, 651, 8699, 4, 4048, 20753], + [21448, 7006, 930, 12901, 930, 7406, 7006, 198], + [13256, 11, 281, 1605, 3370, 11, 1444, 6771], + [9915, 27221, 59, 77, 383, 1853, 3327, 1462], + [6601, 4237, 345, 460, 779, 284, 787, 257]], np.uint32) +input_len = np.array([ [sentence.size] for sentence in random_start_ids ], np.uint32) +output_len = np.ones_like(input_len).astype(np.uint32) * OUTPUT_LEN + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-v', + '--verbose', + action="store_true", + required=False, + default=False, + help='Enable verbose output') + parser.add_argument('-u', + '--url', + type=str, + required=False, + help='Inference server URL.') + parser.add_argument( + '-i', + '--protocol', + type=str, + required=False, + default='http', + help='Protocol ("http"/"grpc") used to ' + + 'communicate with inference service. Default is "http".') + + FLAGS = parser.parse_args() + if (FLAGS.protocol != "http") and (FLAGS.protocol != "grpc"): + print("unexpected protocol \"{}\", expects \"http\" or \"grpc\"".format( + FLAGS.protocol)) + exit(1) + + client_util = httpclient if FLAGS.protocol == "http" else grpcclient + + if FLAGS.url is None: + FLAGS.url = "localhost:8000" if FLAGS.protocol == "http" else "localhost:8001" + + # Run async requests to make sure backend handles request batches + # correctly. We use just HTTP for this since we are not testing the + # protocol anyway. 
+ + # warmup + if FLAGS.protocol == "http": + request_parallelism = 10 + model_name = "transformer" + # shape = [8, 8] + with client_util.InferenceServerClient(FLAGS.url, + concurrency=request_parallelism, + verbose=FLAGS.verbose) as client: + requests = [] + results = [] + for i in range(request_parallelism): + input_data = random_start_ids + inputs = [ + client_util.InferInput("INPUT_ID", input_data.shape, + np_to_triton_dtype(input_data.dtype)), + client_util.InferInput("REQUEST_INPUT_LEN", input_len.shape, + np_to_triton_dtype(input_len.dtype)), + client_util.InferInput("REQUEST_OUTPUT_LEN", output_len.shape, + np_to_triton_dtype(output_len.dtype)) + ] + inputs[0].set_data_from_numpy(input_data) + inputs[1].set_data_from_numpy(input_len) + inputs[2].set_data_from_numpy(output_len) + result = client.infer(model_name, inputs) + results.append(result) + + for i in range(request_parallelism): + # Get the result from the initiated asynchronous inference request. + # Note the call will block till the server responds. + output_data = results[i].as_numpy("OUTPUT0") + + from datetime import datetime + request_parallelism = 10 + start_time = datetime.now() + if FLAGS.protocol == "http": + model_name = "transformer" + # shape = [8, 8] + with client_util.InferenceServerClient(FLAGS.url, + concurrency=request_parallelism, + verbose=FLAGS.verbose) as client: + requests = [] + results = [] + for i in range(request_parallelism): + input_data = random_start_ids + inputs = [ + client_util.InferInput("INPUT_ID", input_data.shape, + np_to_triton_dtype(input_data.dtype)), + client_util.InferInput("REQUEST_INPUT_LEN", input_len.shape, + np_to_triton_dtype(input_len.dtype)), + client_util.InferInput("REQUEST_OUTPUT_LEN", output_len.shape, + np_to_triton_dtype(output_len.dtype)) + ] + inputs[0].set_data_from_numpy(input_data) + inputs[1].set_data_from_numpy(input_len) + inputs[2].set_data_from_numpy(output_len) + #requests.append(client.async_infer(model_name, inputs)) + print("set request") + result = client.infer(model_name, inputs) + print("get request") + results.append(result) + + for i in range(request_parallelism): + # Get the result from the initiated asynchronous inference request. + # Note the call will block till the server responds. + print("wait result return 0000\n") + ##results = requests[i].get_result() + print("wait result return 1111\n") + # print(results[i]) + print("get results\n") + + output_data = results[i].as_numpy("OUTPUT0") + output_data = output_data.reshape([-1, BATCH_SIZE]) + np.savetxt("triton_out", output_data, fmt='%u') + output_data = output_data.T + print("get results as OUTPUT0\n") + if output_data is None: + print("error: expected 'OUTPUT0'") + sys.exit(1) + else: + print("OUTPUT0 is received") + print(output_data.shape) + print(output_data) + stop_time = datetime.now() + print("[INFO] execution time: {} ms".format((stop_time - start_time).total_seconds() * 1000.0 / request_parallelism)) \ No newline at end of file diff --git a/tools/kill_server.sh b/tools/kill_server.sh new file mode 100644 index 0000000..09162d3 --- /dev/null +++ b/tools/kill_server.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. 
+# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +SERVER_PID=`ps -ef | grep '/opt/tritonserver' | awk '{print $2}' | head -n 1` +# echo $test +kill -KILL $SERVER_PID diff --git a/tools/run_client.sh b/tools/run_client.sh new file mode 100755 index 0000000..eb527b3 --- /dev/null +++ b/tools/run_client.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# export CUDA_VISIBLE_DEVICES=0 + +CLIENT_PY=$WORKSPACE/transformer_backend/tools/identity_test.py +CLIENT_LOG="./client.log" + +rm -rf client.log err.log +rm -rf triton_out + +RET=0 + +for PROTOCOL in http; do + set +e + python $CLIENT_PY -i $PROTOCOL -v 2> err.log > $CLIENT_LOG + if [ $? 
-ne 0 ]; then + RET=1 + fi + set -e +done + +exit $RET \ No newline at end of file diff --git a/tools/run_server.sh b/tools/run_server.sh new file mode 100755 index 0000000..c7ced90 --- /dev/null +++ b/tools/run_server.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# export CUDA_VISIBLE_DEVICES=0 + +SERVER=/opt/tritonserver/bin/tritonserver +SERVER_ARGS="--model-repository=$WORKSPACE/transformer_backend/all_models" +SERVER_LOG="./inference_server.log" +source $WORKSPACE/common/util.sh + +rm -fr inference_server.log + +run_server +if [ "$SERVER_PID" == "0" ]; then + echo -e "\n***\n*** Failed to start $SERVER\n***" + cat $SERVER_LOG + exit 1 +fi \ No newline at end of file