Skip to content

Commit

Permalink
[refactor] combine sequence and request outputs (#170)
Browse files Browse the repository at this point in the history
  • Loading branch information
guocuimi committed May 2, 2024
1 parent 54569d6 commit 4e4eecc
Show file tree
Hide file tree
Showing 21 changed files with 339 additions and 432 deletions.
33 changes: 17 additions & 16 deletions cmake/pybind11_module.cmake → cmake/pybind_extension.cmake
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
include(CMakeParseArguments)

# pybind11_module()
# pybind_extension()
#
# Parameters:
# NAME: name of module
Expand All @@ -11,7 +11,7 @@ include(CMakeParseArguments)
# DEFINES: List of public defines
# LINKOPTS: List of link options
#
# pybind11_module(
# pybind_extension(
# NAME
# awesome
# HDRS
Expand Down Expand Up @@ -53,42 +53,43 @@ if(NOT DEFINED PYTHON_MODULE_EXTENSION OR NOT DEFINED PYTHON_MODULE_DEBUG_POSTFI
endif()
endif()

function(pybind11_module)
function(pybind_extension)
cmake_parse_arguments(
PYBIND11 # prefix
PY # prefix
"TESTONLY" # options
"NAME" # one value args
"HDRS;SRCS;COPTS;DEFINES;LINKOPTS;DEPS" # multi value args
${ARGN}
)

if(PYBIND11_TESTONLY AND (NOT BUILD_TESTING))
if(PY_TESTONLY AND (NOT BUILD_TESTING))
return()
endif()

add_library(${PYBIND11_NAME} SHARED)
target_sources(${PYBIND11_NAME}
PRIVATE ${PYBIND11_SRCS} ${PYBIND11_HDRS})
target_link_libraries(${PYBIND11_NAME}
PUBLIC ${PYBIND11_DEPS}
PRIVATE ${PYBIND11_LINKOPTS}
add_library(${PY_NAME} SHARED)
target_sources(${PY_NAME}
PRIVATE ${PY_SRCS} ${PY_HDRS}
)
target_compile_options(${PYBIND11_NAME} PRIVATE ${PYBIND11_COPTS})
target_compile_definitions(${PYBIND11_NAME} PUBLIC ${PYBIND11_DEFINES})
target_link_libraries(${PY_NAME}
PUBLIC ${PY_DEPS}
PRIVATE ${PY_LINKOPTS}
)
target_compile_options(${PY_NAME} PRIVATE ${PY_COPTS})
target_compile_definitions(${PY_NAME} PUBLIC ${PY_DEFINES})

# -fvisibility=hidden is required to allow multiple modules compiled against
# different pybind versions to work properly, and for some features (e.g.
# py::module_local).
if(NOT DEFINED CMAKE_CXX_VISIBILITY_PRESET)
set_target_properties(${PYBIND11_NAME} PROPERTIES CXX_VISIBILITY_PRESET "hidden")
set_target_properties(${PY_NAME} PROPERTIES CXX_VISIBILITY_PRESET "hidden")
endif()

if(NOT DEFINED CMAKE_CUDA_VISIBILITY_PRESET)
set_target_properties(${PYBIND11_NAME} PROPERTIES CUDA_VISIBILITY_PRESET "hidden")
set_target_properties(${PY_NAME} PROPERTIES CUDA_VISIBILITY_PRESET "hidden")
endif()

set_target_properties(
${PYBIND11_NAME}
${PY_NAME}
PROPERTIES PREFIX ""
DEBUG_POSTFIX "${PYTHON_MODULE_DEBUG_POSTFIX}"
SUFFIX "${PYTHON_MODULE_EXTENSION}")
Expand Down
31 changes: 31 additions & 0 deletions python/scalellm/_C/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import List

# Defined in scalellm/csrc/scalellm.cpp
class SamplingParameter:
    # Type stub for the SamplingParameter class exposed by the compiled
    # extension module. The default constructor takes no arguments.
    def __init__(self) -> None: ...

    # Attributes bound as read/write fields on the native object.
    frequency_penalty: float
    presence_penalty: float
    repetition_penalty: float
    temperature: float
    top_p: float
    top_k: int

# Defined in scalellm/csrc/scalellm.cpp
class StoppingCriteria:
    # Type stub for the StoppingCriteria class exposed by the compiled
    # extension module. The default constructor takes no arguments.
    def __init__(self) -> None: ...

    # Attributes bound as read/write fields on the native object.
    max_tokens: int
    eos_token_id: int
    ignore_eos_token: bool
    stop_token_ids: List[int]

# Defined in scalellm/csrc/llm.h
class LLM:
    # Type stub for the LLM class exposed by the compiled extension module.
    # The constructor mirrors the five-argument native constructor: a model
    # path, sampling parameters, stopping criteria, a maximum sequence length
    # and a device string.
    # NOTE(review): the accepted format of `devices` is not visible here;
    # confirm against the native implementation.
    def __init__(
        self,
        model_path: str,
        sampling_parameter: SamplingParameter,
        stopping_criteria: StoppingCriteria,
        max_seq_len: int,
        devices: str,
    ) -> None: ...

    # Runs generation for a batch of prompts; declared as returning nothing.
    def generate(self, batched_prompt: List[str]) -> None: ...
4 changes: 3 additions & 1 deletion python/scalellm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
__version__ = "0.0.5"
__version__ = "0.0.9"

from scalellm._C import LLM, SamplingParameter, StoppingCriteria

__all__ = ["LLM", "SamplingParameter", "StoppingCriteria"]
8 changes: 5 additions & 3 deletions python/scalellm/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
include(pybind11_module)
include(pybind_extension)

pybind11_module(
pybind_extension(
NAME
wrapper
_C
COPTS
-DPY_MODULE_NAME=_C
HDRS
llm.h
SRCS
Expand Down
46 changes: 20 additions & 26 deletions python/scalellm/csrc/scalellm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,39 @@
#include "request/stopping_criteria.h"
#include "sampling/parameters.h"

namespace llm {
namespace py = pybind11;

PYBIND11_MODULE(wrapper, m) {
PYBIND11_MODULE(PY_MODULE_NAME, m) {
// class SamplingParameter
py::class_<llm::SamplingParameter, std::shared_ptr<llm::SamplingParameter>>(
py::class_<SamplingParameter, std::shared_ptr<SamplingParameter>>(
m, "SamplingParameter")
.def(py::init())
.def_readwrite("frequency_penalty",
&llm::SamplingParameter::frequency_penalty)
.def_readwrite("presence_penalty",
&llm::SamplingParameter::presence_penalty)
.def_readwrite("frequency_penalty", &SamplingParameter::frequency_penalty)
.def_readwrite("presence_penalty", &SamplingParameter::presence_penalty)
.def_readwrite("repetition_penalty",
&llm::SamplingParameter::repetition_penalty)
.def_readwrite("temperature", &llm::SamplingParameter::temperature)
.def_readwrite("top_p", &llm::SamplingParameter::top_p)
.def_readwrite("top_k", &llm::SamplingParameter::top_k)
.def_readwrite("do_sample", &llm::SamplingParameter::do_sample)
.def_readwrite("seed", &llm::SamplingParameter::do_sample);
&SamplingParameter::repetition_penalty)
.def_readwrite("temperature", &SamplingParameter::temperature)
.def_readwrite("top_p", &SamplingParameter::top_p)
.def_readwrite("top_k", &SamplingParameter::top_k);

// class StoppingCriteria
py::class_<llm::StoppingCriteria, std::shared_ptr<llm::StoppingCriteria>>(
py::class_<StoppingCriteria, std::shared_ptr<StoppingCriteria>>(
m, "StoppingCriteria")
.def(py::init())
.def_readwrite("max_tokens", &llm::StoppingCriteria::max_tokens)
.def_readwrite("eos_token_id", &llm::StoppingCriteria::eos_token_id)
.def_readwrite("ignore_eos_token",
&llm::StoppingCriteria::ignore_eos_token)
.def_readwrite("stop_token_ids", &llm::StoppingCriteria::stop_token_ids)
.def_readwrite("stop_sequences", &llm::StoppingCriteria::stop_sequences);
.def_readwrite("max_tokens", &StoppingCriteria::max_tokens)
.def_readwrite("eos_token_id", &StoppingCriteria::eos_token_id)
.def_readwrite("ignore_eos_token", &StoppingCriteria::ignore_eos_token)
.def_readwrite("stop_token_ids", &StoppingCriteria::stop_token_ids);

// class LLM
py::class_<llm::LLM, std::shared_ptr<llm::LLM>>(m, "LLM")
py::class_<LLM, std::shared_ptr<LLM>>(m, "LLM")
.def(py::init<const std::string&,
const llm::SamplingParameter&,
const llm::StoppingCriteria&,
const SamplingParameter&,
const StoppingCriteria&,
int64_t,
const std::string>())
.def("generate", &llm::LLM::generate);

// function add
// m.def("add", &add, "A function which adds two numbers");
.def("generate", &LLM::generate);
}

} // namespace llm
149 changes: 0 additions & 149 deletions python/scalellm/llm.py

This file was deleted.

9 changes: 9 additions & 0 deletions python/scalellm/llm_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class LLMEngine:
    """Placeholder engine exposing asynchronous request hooks.

    Both coroutine methods are stubs that currently do nothing and
    resolve to ``None``.
    """

    def __init__(self):
        """Create an engine; no state is initialised yet."""

    async def schedule_request(self) -> None:
        """Schedule a request (not implemented yet)."""
        # ``self`` is required: without it, calling the method on an instance
        # raises TypeError because the instance is passed implicitly.

    async def cancel_request(self) -> None:
        """Cancel a request (not implemented yet)."""
20 changes: 20 additions & 0 deletions python/scalellm/output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

# TODO: define following classes in c++ code
class SequenceOutput:
    """Plain container describing one generated sequence.

    Attributes are initialised to neutral defaults and are expected to be
    filled in by the caller after construction.
    """

    def __init__(self):
        self.index = 0
        self.text = ""
        self.finish_reason = None

    def __str__(self):
        """Return a one-line, human-readable summary of all three fields."""
        return (
            f"index: {self.index}, text: {self.text}, "
            f"finish_reason: {self.finish_reason}"
        )


class RequestOutput:
    """Plain container describing the result of one request.

    Attributes are initialised to neutral defaults and are expected to be
    filled in by the caller after construction.
    """

    def __init__(self):
        self.id = ""
        # Correctly spelled name for the list of per-sequence outputs.
        self.sequence_outputs = []
        self.status = None
        self.usage = None

    @property
    def sequence_outpus(self):
        """Misspelled legacy alias of ``sequence_outputs``.

        Kept so that code written against the original attribute name keeps
        working; it reads and writes the same underlying list.
        """
        return self.sequence_outputs

    @sequence_outpus.setter
    def sequence_outpus(self, value):
        self.sequence_outputs = value


2 changes: 1 addition & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def build_extension(self, ext: CMakeExtension):
packages=[
"scalellm",
],
ext_modules=[CMakeExtension("wrapper", "scalellm/")],
ext_modules=[CMakeExtension("_C", "scalellm/")],
cmdclass={"build_ext": CMakeBuild},
zip_safe=False,
package_data={
Expand Down
Loading

0 comments on commit 4e4eecc

Please sign in to comment.