diff --git a/api/ai.js b/api/ai.js new file mode 100644 index 000000000..d06414157 --- /dev/null +++ b/api/ai.js @@ -0,0 +1,110 @@ +// @ts-check +/** + * @module AI + * + * Provides high level classes for common AI tasks + * + * Example usage: + * ```js + * import { LLM } from 'socket:ai' + * ``` + */ +import ipc from './ipc.js' +import gc from './gc.js' +import { EventEmitter } from './events.js' +import { rand64 } from './crypto.js' +import * as exports from './ai.js' + +/** + * A class to interact with large language models (using llama.cpp) + * @extends EventEmitter + */ +export class LLM extends EventEmitter { + /** + * Constructs an LLM instance. + * @param {Object} options - The options for initializing the LLM. + * @param {string} options.path - The path to a valid model (.gguf). + * @param {string} options.prompt - The query that guides the model to generate a relevant and coherent responses. + * @param {string} [options.id] - The optional ID for the LLM instance. + * @throws {Error} If the model path is not provided. + */ + constructor (options = {}) { + super() + + if (!options.path) { + throw new Error('expected a path to a valid model (.gguf)') + } + + this.path = options.path + this.prompt = options.prompt + this.id = options.id || rand64() + + const opts = { + id: this.id, + prompt: this.prompt, + path: this.path + } + + globalThis.addEventListener('data', ({ detail }) => { + const { err, data, source } = detail.params + + if (err && BigInt(err.id) === this.id) { + return this.emit('error', err) + } + + if (!data || BigInt(data.id) !== this.id) return + + if (source === 'ai.llm.chat') { + if (data.complete) { + return this.emit('end') + } + + this.emit('data', data.token) + } + }) + + const result = ipc.sendSync('ai.llm.create', opts) + + if (result.err) { + throw result.err + } + } + + /** + * Tell the LLM to stop after the next token. + * @returns {Promise} A promise that resolves when the LLM stops. + */ + async stop () { + return ipc.request('ai.llm.stop', { id: this.id }) + } + + /** + * Implements `gc.finalizer` for gc'd resource cleanup. + * @param {Object} options - The options for finalizer. + * @returns {gc.Finalizer} The finalizer object. + * @ignore + */ + [gc.finalizer] (options) { + return { + args: [this.id, options], + async handle (id) { + if (process.env.DEBUG) { + console.warn('Closing Socket on garbage collection') + } + + await ipc.request('ai.llm.destroy', { id }, options) + } + } + } + + /** + * Send a message to the chat. + * @param {string} message - The message to send to the chat. + * @returns {Promise} A promise that resolves with the response from the chat. 
+ */ + async chat (message) { + return ipc.request('ai.llm.chat', { id: this.id, message }) + } +} + +export default exports diff --git a/api/commonjs/builtins.js b/api/commonjs/builtins.js index 80f2f9d4d..066d1de15 100644 --- a/api/commonjs/builtins.js +++ b/api/commonjs/builtins.js @@ -12,6 +12,7 @@ import _async, { } from '../async.js' // eslint-disable-next-line +import * as ai from '../ai.js' import * as application from '../application.js' import assert from '../assert.js' import * as buffer from '../buffer.js' diff --git a/bin/build-runtime-library.sh b/bin/build-runtime-library.sh index 9c3208362..c1b632fe0 100755 --- a/bin/build-runtime-library.sh +++ b/bin/build-runtime-library.sh @@ -105,6 +105,11 @@ declare sources=( $(find "$root"/src/ipc/*.cc) $(find "$root"/src/platform/*.cc) $(find "$root"/src/serviceworker/*.cc) + "$root/build/llama/common/common.cpp" + "$root/build/llama/common/sampling.cpp" + "$root/build/llama/common/json-schema-to-grammar.cpp" + "$root/build/llama/common/grammar-parser.cpp" + "$root/build/llama/llama.cpp" "$root/src/window/manager.cc" "$root/src/window/dialog.cc" "$root/src/window/hotkey.cc" @@ -129,9 +134,11 @@ if [[ "$platform" = "android" ]]; then sources+=("$root/src/window/android.cc") elif [[ "$host" = "Darwin" ]]; then sources+=("$root/src/window/apple.mm") - if (( TARGET_OS_IPHONE)) || (( TARGET_IPHONE_SIMULATOR )); then - cflags=("-sdk" "iphoneos" "$clang") - clang="xcrun" + + if (( TARGET_OS_IPHONE)); then + clang="xcrun -sdk iphoneos "$clang"" + elif (( TARGET_IPHONE_SIMULATOR )); then + clang="xcrun -sdk iphonesimulator "$clang"" else sources+=("$root/src/core/process/unix.cc") fi @@ -154,11 +161,23 @@ mkdir -p "$output_directory" cd "$(dirname "$output_directory")" +sources+=("$output_directory/llama/build-info.cpp") + echo "# building runtime static libary ($arch-$platform)" for source in "${sources[@]}"; do declare src_directory="$root/src" + declare object="${source/.cc/$d.o}" - declare object="${object/$src_directory/$output_directory}" + object="${object/.cpp/$d.o}" + + declare build_dir="$root/build" + + if [[ "$object" =~ ^"$src_directory" ]]; then + object="${object/$src_directory/$output_directory}" + else + object="${object/$build_dir/$output_directory}" + fi + objects+=("$object") done @@ -166,6 +185,38 @@ if [[ -z "$ignore_header_mtimes" ]]; then test_headers+="$(find "$root/src"/core/*.hh)" fi +function generate_llama_build_info () { + build_number="0" + build_commit="unknown" + build_compiler="unknown" + build_target="unknown" + + if out=$(git rev-list --count HEAD); then + # git is broken on WSL so we need to strip extra newlines + build_number=$(printf '%s' "$out" | tr -d '\n') + fi + + if out=$(git rev-parse --short HEAD); then + build_commit=$(printf '%s' "$out" | tr -d '\n') + fi + + if out=$($clang --version | head -1); then + build_compiler=$out + fi + + if out=$($clang -dumpmachine); then + build_target=$out + fi + + echo "# generating llama build info" + cat > "$output_directory/llama/build-info.cpp" << LLAMA_BUILD_INFO + int LLAMA_BUILD_NUMBER = $build_number; + char const *LLAMA_COMMIT = "$build_commit"; + char const *LLAMA_COMPILER = "$build_compiler"; + char const *LLAMA_BUILD_TARGET = "$build_target"; +LLAMA_BUILD_INFO +} + function main () { trap onsignal INT TERM local i=0 @@ -177,6 +228,8 @@ function main () { cp -rf "$root/include"/* "$output_directory/include" rm -f "$output_directory/include/socket/_user-config-bytes.hh" + generate_llama_build_info + for source in "${sources[@]}"; do if (( ${#pids[@]} > 
max_concurrency )); then wait "${pids[0]}" 2>/dev/null @@ -185,9 +238,20 @@ function main () { { declare src_directory="$root/src" + declare object="${source/.cc/$d.o}" + object="${object/.cpp/$d.o}" + declare header="${source/.cc/.hh}" - declare object="${object/$src_directory/$output_directory}" + header="${header/.cpp/.h}" + + declare build_dir="$root/build" + + if [[ "$object" =~ ^"$src_directory" ]]; then + object="${object/$src_directory/$output_directory}" + else + object="${object/$build_dir/$output_directory}" + fi if (( force )) || ! test -f "$object" || @@ -197,7 +261,7 @@ function main () { then mkdir -p "$(dirname "$object")" echo "# compiling object ($arch-$platform) $(basename "$source")" - quiet "$clang" "${cflags[@]}" -c "$source" -o "$object" || onsignal + quiet $clang "${cflags[@]}" -c "$source" -o "$object" || onsignal echo "ok - built ${source/$src_directory\//} -> ${object/$output_directory\//} ($arch-$platform)" fi } & pids+=($!) diff --git a/bin/cflags.sh b/bin/cflags.sh index 806220e32..0f1c5ded8 100755 --- a/bin/cflags.sh +++ b/bin/cflags.sh @@ -52,6 +52,9 @@ cflags+=( -std=c++2a -I"$root/include" -I"$root/build/uv/include" + -I"$root/build" + -I"$root/build/llama" + -I"$root/build/llama/common" -I"$root/build/include" -DSOCKET_RUNTIME_BUILD_TIME="$(date '+%s')" -DSOCKET_RUNTIME_VERSION_HASH=$(git rev-parse --short=8 HEAD) diff --git a/bin/install.sh b/bin/install.sh index ccc496758..7c5a01fa5 100755 --- a/bin/install.sh +++ b/bin/install.sh @@ -220,12 +220,15 @@ function _build_cli { # local libs=($("echo" -l{socket-runtime})) local libs="" + # + # Add libuv, socket-runtime and llama + # if [[ "$(uname -s)" != *"_NT"* ]]; then - libs=($("echo" -l{uv,socket-runtime})) + libs=($("echo" -l{uv,llama,socket-runtime})) fi if [[ -n "$VERBOSE" ]]; then - echo "# cli libs: $libs, $(uname -s)" + echo "# cli libs: ${libs[@]}, $(uname -s)" fi local ldflags=($("$root/bin/ldflags.sh" --arch "$arch" --platform $platform ${libs[@]})) @@ -281,6 +284,7 @@ function _build_cli { test_sources+=("$static_libs") elif [[ "$(uname -s)" == "Linux" ]]; then static_libs+=("$BUILD_DIR/$arch-$platform/lib/libuv.a") + static_libs+=("$BUILD_DIR/$arch-$platform/lib/libllama.a") static_libs+=("$BUILD_DIR/$arch-$platform/lib/libsocket-runtime.a") fi @@ -512,7 +516,14 @@ function _prepare { mv "$tempmkl" "$BUILD_DIR/uv/CMakeLists.txt" fi - die $? "not ok - unable to clone. See trouble shooting guide in the README.md file" + die $? "not ok - unable to clone libuv. See trouble shooting guide in the README.md file" + fi + + if [ ! -d "$BUILD_DIR/llama" ]; then + git clone --depth=1 https://github.com/socketsupply/llama.cpp.git "$BUILD_DIR/llama" > /dev/null 2>&1 + rm -rf $BUILD_DIR/llama/.git + + die $? "not ok - unable to clone llama. See trouble shooting guide in the README.md file" fi echo "ok - directories prepared" @@ -764,6 +775,95 @@ function _compile_libuv_android { fi } +function _compile_llama { + target=$1 + hosttarget=$1 + platform=$2 + + if [ -z "$target" ]; then + target="$(host_arch)" + platform="desktop" + fi + + echo "# building llama.cpp for $platform ($target) on $host..." + STAGING_DIR="$BUILD_DIR/$target-$platform/llama" + + if [ ! -d "$STAGING_DIR" ]; then + mkdir -p "$STAGING_DIR" + cp -r "$BUILD_DIR"/llama/* "$STAGING_DIR" + cd "$STAGING_DIR" || exit 1 + else + cd "$STAGING_DIR" || exit 1 + fi + + mkdir -p "$STAGING_DIR/build/" + + if [ "$platform" == "desktop" ]; then + if [[ "$host" != "Win32" ]]; then + quiet cmake -S . 
-B build -DCMAKE_INSTALL_PREFIX="$BUILD_DIR/$target-$platform" + die $? "not ok - desktop configure" + + quiet cmake --build build --target clean + quiet cmake --build build -- -j"$CPU_CORES" + quiet cmake --install build + else + if ! test -f "$BUILD_DIR/$target-$platform/lib$d/libllama.lib"; then + local config="Release" + if [[ -n "$DEBUG" ]]; then + config="Debug" + fi + cd "$STAGING_DIR/build/" || exit 1 + quiet cmake -S .. -B . -DBUILD_TESTING=OFF -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_BUILD_SHARED=OFF + quiet cmake --build . --config $config + mkdir -p "$BUILD_DIR/$target-$platform/lib$d" + quiet echo "cp -up $STAGING_DIR/build/$config/libllama.lib "$BUILD_DIR/$target-$platform/lib$d/libllama.lib"" + cp -up "$STAGING_DIR/build/$config/libllama.lib" "$BUILD_DIR/$target-$platform/lib$d/libllama.lib" + if [[ -n "$DEBUG" ]]; then + cp -up "$STAGING_DIR"/build/$config/llama_a.pdb "$BUILD_DIR/$target-$platform/lib$d/llama_a.pdb" + fi; + fi + fi + + rm -f "$root/build/$(host_arch)-desktop/lib$d"/*.{so,la,dylib}* + return + fi + + if [ "$hosttarget" == "arm64" ]; then + hosttarget="arm" + fi + + local sdk="iphoneos" + [[ "$platform" == "iPhoneSimulator" ]] && sdk="iphonesimulator" + + export PLATFORM=$platform + export CC="$(xcrun -sdk $sdk -find clang)" + export CXX="$(xcrun -sdk $sdk -find clang++)" + export STRIP="$(xcrun -sdk $sdk -find strip)" + export LD="$(xcrun -sdk $sdk -find ld)" + export CPP="$CC -E" + export CFLAGS="-fembed-bitcode -arch ${target} -isysroot $PLATFORMPATH/$platform.platform/Developer/SDKs/$platform$SDKVERSION.sdk -m$sdk-version-min=$SDKMINVERSION" + export AR=$(xcrun -sdk $sdk -find ar) + export RANLIB=$(xcrun -sdk $sdk -find ranlib) + export CPPFLAGS="-fembed-bitcode -arch ${target} -isysroot $PLATFORMPATH/$platform.platform/Developer/SDKs/$platform$SDKVERSION.sdk -m$sdk-version-min=$SDKMINVERSION" + export LDFLAGS="-Wc,-fembed-bitcode -arch ${target} -isysroot $PLATFORMPATH/$platform.platform/Developer/SDKs/$platform$SDKVERSION.sdk" + + #if ! test -f CMakeLists.txt; then + quiet cmake -S . -B build -DCMAKE_INSTALL_PREFIX="$BUILD_DIR/$target-$platform" -DCMAKE_SYSTEM_NAME=iOS -DCMAKE_OSX_ARCHITECTURES="$target" -DCMAKE_OSX_SYSROOT=$(xcrun --sdk $sdk --show-sdk-path) + #fi + + if [ ! $? = 0 ]; then + echo "WARNING! - iOS will not be enabled. iPhone simulator not found, try \"sudo xcode-select --switch /Applications/Xcode.app\"." + return + fi + + cmake --build build -- -j"$CPU_CORES" + cmake --install build + + cd "$BUILD_DIR" || exit 1 + rm -f "$root/build/$target-$platform/lib$d"/*.{so,la,dylib}* + echo "ok - built for $target" +} + function _compile_libuv { target=$1 hosttarget=$1 @@ -840,6 +940,7 @@ function _compile_libuv { export PLATFORM=$platform export CC="$(xcrun -sdk $sdk -find clang)" + export CXX="$(xcrun -sdk $sdk -find clang++)" export STRIP="$(xcrun -sdk $sdk -find strip)" export LD="$(xcrun -sdk $sdk -find ld)" export CPP="$CC -E" @@ -912,6 +1013,11 @@ cd "$BUILD_DIR" || exit 1 trap onsignal INT TERM +{ + _compile_llama + echo "ok - built llama for $platform ($target)" +} & _compile_llama_pid=$! + # Although we're passing -j$CPU_CORES on non Win32, we still don't get max utiliztion on macos. Start this before fat libs. { _compile_libuv @@ -931,9 +1037,14 @@ if [[ "$(uname -s)" == "Darwin" ]] && [[ -z "$NO_IOS" ]]; then _setSDKVersion iPhoneOS _compile_libuv arm64 iPhoneOS & pids+=($!) + _compile_llama arm64 iPhoneOS & pids+=($!) + _compile_libuv x86_64 iPhoneSimulator & pids+=($!) 
+ _compile_llama x86_64 iPhoneSimulator & pids+=($!) + if [[ "$arch" = "arm64" ]]; then _compile_libuv arm64 iPhoneSimulator & pids+=($!) + _compile_llama arm64 iPhoneSimulator & pids+=($!) fi for pid in "${pids[@]}"; do wait "$pid"; done @@ -951,6 +1062,7 @@ fi if [[ "$host" != "Win32" ]]; then # non windows hosts uses make -j$CPU_CORES, wait for them to finish. wait $_compile_libuv_pid + wait $_compile_llama_pid fi if [[ -n "$BUILD_ANDROID" ]]; then @@ -974,6 +1086,7 @@ _get_web_view2 if [[ "$host" == "Win32" ]]; then # Wait for Win32 lib uv build wait $_compile_libuv_pid + wait $_compile_llama_pid fi _check_compiler_features diff --git a/bin/ldflags.sh b/bin/ldflags.sh index 40097f8e1..abfdb1d14 100755 --- a/bin/ldflags.sh +++ b/bin/ldflags.sh @@ -109,6 +109,8 @@ if [[ "$host" = "Darwin" ]]; then ldflags+=("-framework" "Network") ldflags+=("-framework" "UniformTypeIdentifiers") ldflags+=("-framework" "WebKit") + ldflags+=("-framework" "Metal") + ldflags+=("-framework" "Accelerate") ldflags+=("-framework" "UserNotifications") ldflags+=("-framework" "OSLog") ldflags+=("-ldl") @@ -118,7 +120,7 @@ elif [[ "$host" = "Linux" ]]; then elif [[ "$host" = "Win32" ]]; then if [[ -n "$DEBUG" ]]; then # https://learn.microsoft.com/en-us/cpp/c-runtime-library/crt-library-features?view=msvc-170 - # TODO(@mribbons): Populate from vcvars64.bat + # TODO(@heapwolf): Populate from vcvars64.bat IFS=',' read -r -a libs <<< "$WIN_DEBUG_LIBS" for (( i = 0; i < ${#libs[@]}; ++i )); do ldflags+=("${libs[$i]}") diff --git a/socket-runtime.pc.in b/socket-runtime.pc.in index a608c699f..fd473a5c9 100644 --- a/socket-runtime.pc.in +++ b/socket-runtime.pc.in @@ -3,5 +3,5 @@ Version: {{VERSION}} Description: Build and package lean, fast, native desktop and mobile applications using the web technologies you already know. 
URL: https://github.com/socketsupply/socket Requires: {{DEPENDENCIES}} -Libs: -L{{LIB_DIRECTORY}} -lsocket-runtime -luv {{LDFLAGS}} +Libs: -L{{LIB_DIRECTORY}} -lsocket-runtime -luv -lllama {{LDFLAGS}} Cflags: -I{{INCLUDE_DIRECTORY}} {{CFLAGS}} diff --git a/src/cli/cli.cc b/src/cli/cli.cc index 2ea1e0b50..8e5bb63da 100644 --- a/src/cli/cli.cc +++ b/src/cli/cli.cc @@ -3075,6 +3075,8 @@ int main (const int argc, const char* argv[]) { flags += " -framework Network"; flags += " -framework UserNotifications"; flags += " -framework WebKit"; + flags += " -framework Metal"; + flags += " -framework Accelerate"; flags += " -framework Carbon"; flags += " -framework Cocoa"; flags += " -framework OSLog"; @@ -3089,6 +3091,7 @@ int main (const int argc, const char* argv[]) { flags += " -L" + prefixFile("lib/" + platform.arch + "-desktop"); flags += " -lsocket-runtime"; flags += " -luv"; + flags += " -lllama"; flags += " -I\"" + Path(paths.platformSpecificOutputPath / "include").string() + "\""; files += prefixFile("objects/" + platform.arch + "-desktop/desktop/main.o"); files += prefixFile("src/init.cc"); @@ -4411,6 +4414,8 @@ int main (const int argc, const char* argv[]) { " -framework CoreLocation" + " -framework Network" + " -framework UserNotifications" + + " -framework Metal" + + " -framework Accelerate" + " -framework WebKit" + " -framework Cocoa" + " -framework OSLog" @@ -4632,6 +4637,7 @@ int main (const int argc, const char* argv[]) { << " -L" + libdir << " -lsocket-runtime" << " -luv" + << " -lllama" << " -isysroot " << iosSdkPath << "/" << " -iframeworkwithsysroot /System/Library/Frameworks/" << " -F " << iosSdkPath << "/System/Library/Frameworks/" @@ -4641,6 +4647,8 @@ int main (const int argc, const char* argv[]) { << " -framework Foundation" << " -framework Network" << " -framework UserNotifications" + << " -framework Metal" + << " -framework Accelerate" << " -framework WebKit" << " -framework UIKit" << " -fembed-bitcode" @@ -5722,6 +5730,8 @@ int main (const int argc, const char* argv[]) { compilerFlags += " -framework CoreLocation"; compilerFlags += " -framework Network"; compilerFlags += " -framework UserNotifications"; + compilerFlags += " -framework Metal"; + compilerFlags += " -framework Accelerate"; compilerFlags += " -framework WebKit"; compilerFlags += " -framework Cocoa"; compilerFlags += " -framework OSLog"; @@ -6051,6 +6061,7 @@ int main (const int argc, const char* argv[]) { << " " << extraFlags #if defined(__linux__) << " -luv" + << " -lllama" << " -lsocket-runtime" #endif << (" -L" + quote + trim(prefixFile("lib/" + platform.arch + "-desktop")) + quote) diff --git a/src/core/core.hh b/src/core/core.hh index 8ac1e7027..bbd35ab12 100644 --- a/src/core/core.hh +++ b/src/core/core.hh @@ -22,6 +22,7 @@ #include "version.hh" #include "webview.hh" +#include "modules/ai.hh" #include "modules/child_process.hh" #include "modules/dns.hh" #include "modules/fs.hh" @@ -44,9 +45,9 @@ namespace SSC { class Core { public: - #if !SOCKET_RUNTIME_PLATFORM_IOS - using ChildProcess = CoreChildProcess; - #endif + #if !SOCKET_RUNTIME_PLATFORM_IOS + using ChildProcess = CoreChildProcess; + #endif using DNS = CoreDNS; using FS = CoreFS; using Geolocation = CoreGeolocation; @@ -56,6 +57,7 @@ namespace SSC { using Platform = CorePlatform; using Timers = CoreTimers; using UDP = CoreUDP; + using AI = CoreAI; #if !SOCKET_RUNTIME_PLATFORM_IOS ChildProcess childProcess; @@ -69,6 +71,7 @@ namespace SSC { Platform platform; Timers timers; UDP udp; + AI ai; Posts posts; @@ -86,25 +89,26 @@ namespace SSC { 
uv_async_t eventLoopAsync; Queue eventLoopDispatchQueue; - #if SOCKET_RUNTIME_PLATFORM_APPLE - dispatch_queue_attr_t eventLoopQueueAttrs = dispatch_queue_attr_make_with_qos_class( - DISPATCH_QUEUE_SERIAL, - QOS_CLASS_DEFAULT, - -1 - ); - - dispatch_queue_t eventLoopQueue = dispatch_queue_create( - "socket.runtime.core.loop.queue", - eventLoopQueueAttrs - ); - #else - Thread *eventLoopThread = nullptr; - #endif + #if SOCKET_RUNTIME_PLATFORM_APPLE + dispatch_queue_attr_t eventLoopQueueAttrs = dispatch_queue_attr_make_with_qos_class( + DISPATCH_QUEUE_SERIAL, + QOS_CLASS_DEFAULT, + -1 + ); + + dispatch_queue_t eventLoopQueue = dispatch_queue_create( + "socket.runtime.core.loop.queue", + eventLoopQueueAttrs + ); + #else + Thread *eventLoopThread = nullptr; + #endif Core () : - #if !SOCKET_RUNTIME_PLATFORM_IOS - childProcess(this), - #endif + #if !SOCKET_RUNTIME_PLATFORM_IOS + childProcess(this), + #endif + ai(this), dns(this), fs(this), geolocation(this), diff --git a/src/core/json.cc b/src/core/json.cc index 119720228..dcba8178f 100644 --- a/src/core/json.cc +++ b/src/core/json.cc @@ -75,6 +75,7 @@ namespace SSC::JSON { SSC::String String::str () const { auto escaped = replace(this->data, "\"", "\\\""); + escaped = replace(escaped, "\\n", "\\\\n"); return "\"" + replace(escaped, "\n", "\\n") + "\""; } diff --git a/src/core/modules/ai.cc b/src/core/modules/ai.cc new file mode 100644 index 000000000..26603f6b6 --- /dev/null +++ b/src/core/modules/ai.cc @@ -0,0 +1,759 @@ +#include "../core.hh" +#include "../resource.hh" +#include "ai.hh" + +namespace SSC { + static JSON::Object::Entries ERR_AI_LLM_NOEXISTS ( + const String& source, + CoreAI::ID id + ) { + return JSON::Object::Entries { + {"source", source}, + {"err", JSON::Object::Entries { + {"id", std::to_string(id)}, + {"type", "InternalError"}, + {"code", "ERR_AI_LLM_NOEXISTS"}, + {"message", "The requested LLM does not exist"} + }} + }; + } + + static JSON::Object::Entries ERR_AI_LLM_MESSAGE ( + const String& source, + CoreAI::ID id, + const String& message + ) { + return JSON::Object::Entries { + {"source", source}, + {"err", JSON::Object::Entries { + {"id", std::to_string(id)}, + {"type", "InternalError"}, + {"code", "ERR_AI_LLM_MESSAGE"}, + {"message", message} + }} + }; + } + + static JSON::Object::Entries ERR_AI_LLM_EXISTS ( + const String& source, + CoreAI::ID id + ) { + return JSON::Object::Entries { + {"source", source}, + {"err", JSON::Object::Entries { + {"id", std::to_string(id)}, + {"type", "InternalError"}, + {"code", "ERR_AI_LLM_EXISTS"}, + {"message", "The requested LLM already exists"} + }} + }; + } + + SharedPointer CoreAI::getLLM (ID id) { + if (!this->hasLLM(id)) return nullptr; + Lock lock(this->mutex); + return this->llms.at(id); + } + + bool CoreAI::hasLLM (ID id) { + Lock lock(this->mutex); + return this->llms.find(id) != this->llms.end(); + } + + void CoreAI::createLLM( + const String& seq, + ID id, + LLMOptions options, + const CoreModule::Callback& callback + ) { + if (this->hasLLM(id)) { + auto json = ERR_AI_LLM_EXISTS("ai.llm.create", id); + return callback(seq, json, Post{}); + } + + this->core->dispatchEventLoop([=, this] { + auto llm = new LLM(options); + if (llm->err.size()) { + auto json = ERR_AI_LLM_MESSAGE("ai.llm.create", id, llm->err); + return callback(seq, json, Post{}); + return; + } + + const auto json = JSON::Object::Entries { + {"source", "ai.llm.create"}, + {"data", JSON::Object::Entries { + {"id", std::to_string(id)}, + }} + }; + + callback(seq, json, Post{}); + Lock lock(this->mutex); + 
this->llms[id].reset(llm); + }); + }; + + void CoreAI::chatLLM( + const String& seq, + ID id, + String message, + const CoreModule::Callback& callback + ) { + this->core->dispatchEventLoop([=, this] { + if (!this->hasLLM(id)) { + auto json = ERR_AI_LLM_NOEXISTS("ai.llm.chat", id); + return callback(seq, json, Post{}); + } + + auto llm = this->getLLM(id); + + llm->chat(message, [=](auto self, auto token, auto isComplete) { + const auto json = JSON::Object::Entries { + {"source", "ai.llm.chat"}, + {"data", JSON::Object::Entries { + {"id", std::to_string(id)}, + {"token", token}, + {"complete", isComplete} + }} + }; + + callback("-1", json, Post{}); + + return isComplete; + }); + }); + }; + + void CoreAI::destroyLLM( + const String& seq, + ID id, + const CoreModule::Callback& callback + ) { + this->core->dispatchEventLoop([=, this] { + if (!this->hasLLM(id)) { + auto json = ERR_AI_LLM_NOEXISTS("ai.llm.destroy", id); + return callback(seq, json, Post{}); + } + + Lock lock(this->mutex); + auto llm = this->getLLM(id); + llm->stopped = true; + + while (llm->interactive) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + + this->llms.erase(id); + }); + }; + + void CoreAI::stopLLM( + const String& seq, + ID id, + const CoreModule::Callback& callback + ) { + this->core->dispatchEventLoop([=, this] { + if (!this->hasLLM(id)) { + auto json = ERR_AI_LLM_NOEXISTS("ai.llm.stop", id); + return callback(seq, json, Post{}); + } + + auto llm = this->getLLM(id); + llm->stopped = true; // remains stopped until chat is called again. + }); + }; + + LLM::Logger LLM::log = nullptr; + + void LLM::tramp(ggml_log_level level, const char* message, void* user_data) { + if (LLM::log) LLM::log(level, message, user_data); + } + + void LLM::escape(String& input) { + std::size_t input_len = input.length(); + std::size_t output_idx = 0; + + for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) { + if (input[input_idx] == '\\' && input_idx + 1 < input_len) { + switch (input[++input_idx]) { + case 'n': input[output_idx++] = '\n'; break; + case 'r': input[output_idx++] = '\r'; break; + case 't': input[output_idx++] = '\t'; break; + case '\'': input[output_idx++] = '\''; break; + case '\"': input[output_idx++] = '\"'; break; + case '\\': input[output_idx++] = '\\'; break; + case 'x': + // Handle \x12, etc + if (input_idx + 2 < input_len) { + const char x[3] = { input[input_idx + 1], input[input_idx + 2], 0 }; + char *err_p = nullptr; + const long val = std::strtol(x, &err_p, 16); + if (err_p == x + 2) { + input_idx += 2; + input[output_idx++] = char(val); + break; + } + } + // fall through + default: { + input[output_idx++] = '\\'; + input[output_idx++] = input[input_idx]; break; + } + } + } else { + input[output_idx++] = input[input_idx]; + } + } + + input.resize(output_idx); + } + + LLM::LLM (const LLMOptions options) { + // + // set up logging + // + LLM::log = [&](ggml_log_level level, const char* message, void* user_data) { + // std::cout << message << std::endl; + }; + + llama_log_set(LLM::tramp, nullptr); + + // + // set params and init the model and context + // + llama_backend_init(); + llama_numa_init(this->params.numa); + + llama_sampling_params& sparams = this->params.sparams; + this->sampling = llama_sampling_init(sparams); + + if (!this->sampling) this->err = "failed to initialize sampling subsystem"; + if (this->params.seed == LLAMA_DEFAULT_SEED) this->params.seed = time(nullptr); + + this->params.chatml = true; + this->params.prompt = "<|im_start|>system\n" + options.prompt + 
"<|im_end|>\n\n"; + this->params.n_ctx = 2048; + + FileResource resource(options.path); + + if (!resource.exists()) { + this->err = "Unable to access the model file due to permissions"; + return; + } + + this->params.model = options.path; + + std::tie(this->model, this->ctx) = llama_init_from_gpt_params(this->params); + + this->embd_inp = ::llama_tokenize(this->ctx, this->params.prompt.c_str(), true, true); + + // + // create a guidance context + // + if (sparams.cfg_scale > 1.f) { + struct llama_context_params lparams = llama_context_params_from_gpt_params(this->params); + this->guidance = llama_new_context_with_model(this->model, lparams); + } + + if (this->model == nullptr) { + this->err = "unable to load model"; + return; + } + + // + // determine the capacity of the model + // + const int n_ctx_train = llama_n_ctx_train(this->model); + const int n_ctx = llama_n_ctx(this->ctx); + + if (n_ctx > n_ctx_train) { + LOG("warning: model was trained on only %d context tokens (%d specified)\n", n_ctx_train, n_ctx); + } + + this->n_ctx = n_ctx; + + if (this->guidance) { + this->guidance_inp = ::llama_tokenize(this->guidance, sparams.cfg_negative_prompt, true, true); + std::vector original_inp = ::llama_tokenize(ctx, params.prompt.c_str(), true, true); + original_prompt_len = original_inp.size(); + guidance_offset = (int)this->guidance_inp.size() - original_prompt_len; + } + + // + // number of tokens to keep when resetting context + // + const bool add_bos = llama_should_add_bos_token(this->model); + GGML_ASSERT(llama_add_eos_token(this->model) != 1); + + if (this->params.n_keep < 0 || this->params.n_keep > (int)this->embd_inp.size() || this->params.instruct || this->params.chatml) { + this->params.n_keep = (int)this->embd_inp.size(); + } else if (add_bos) { + this->params.n_keep += add_bos; + } + + if (this->params.instruct) { + this->params.interactive_first = true; + this->params.antiprompt.emplace_back("### Instruction:\n\n"); + } else if (this->params.chatml) { + this->params.interactive_first = true; + this->params.antiprompt.emplace_back("<|im_start|>user\n"); + } else if (this->params.conversation) { + this->params.interactive_first = true; + } + + if (params.interactive_first) { + params.interactive = true; + } + + if (params.interactive) { + if (!params.antiprompt.empty()) { + for (const auto & antiprompt : params.antiprompt) { + LOG("Reverse prompt: '%s'\n", antiprompt.c_str()); + + if (params.verbose_prompt) { + auto tmp = ::llama_tokenize(ctx, antiprompt.c_str(), false, true); + + for (int i = 0; i < (int) tmp.size(); i++) { + LOG("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str()); + } + } + } + } + + if (params.input_prefix_bos) { + LOG("Input prefix with BOS\n"); + } + + if (!params.input_prefix.empty()) { + LOG("Input prefix: '%s'\n", params.input_prefix.c_str()); + + if (params.verbose_prompt) { + auto tmp = ::llama_tokenize(ctx, params.input_prefix.c_str(), true, true); + + for (int i = 0; i < (int) tmp.size(); i++) { + LOG("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str()); + } + } + } + + if (!params.input_suffix.empty()) { + LOG("Input suffix: '%s'\n", params.input_suffix.c_str()); + + if (params.verbose_prompt) { + auto tmp = ::llama_tokenize(ctx, params.input_suffix.c_str(), false, true); + + for (int i = 0; i < (int) tmp.size(); i++) { + LOG("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str()); + } + } + } + } + + // + // initialize any anti-prompts sent in as params + // + 
this->antiprompt_ids.reserve(this->params.antiprompt.size()); + + for (const String& antiprompt : this->params.antiprompt) { + this->antiprompt_ids.emplace_back(::llama_tokenize(this->ctx, antiprompt.c_str(), false, true)); + } + + this->path_session = params.path_prompt_cache; + }; + + LLM::~LLM () { + llama_free(this->ctx); + llama_free(this->guidance); + llama_free_model(this->model); + llama_sampling_free(this->sampling); + llama_backend_free(); + }; + + void LLM::chat (String buffer, const Cb cb) { + this->stopped = false; + int ga_i = 0; + + const int ga_n = this->params.grp_attn_n; + const int ga_w = this->params.grp_attn_w; + + if (ga_n != 1) { + GGML_ASSERT(ga_n > 0 && "grp_attn_n must be positive"); + GGML_ASSERT(ga_w % ga_n == 0 && "grp_attn_w must be a multiple of grp_attn_n"); + } + + this->interactive = this->params.interactive_first = true; + + bool display = true; + bool is_antiprompt = false; + bool input_echo = true; + + int n_past = 0; + int n_remain = this->params.n_predict; + int n_consumed = 0; + int n_session_consumed = 0; + int n_past_guidance = 0; + + std::vector input_tokens; + this->input_tokens = &input_tokens; + + std::vector output_tokens; + this->output_tokens = &output_tokens; + + std::ostringstream output_ss; + this->output_ss = &output_ss; + + std::vector embd; + std::vector embd_guidance; + + const int n_ctx = this->n_ctx; + const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true, true); + const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true); + + LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str()); + LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str()); + + const auto cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", true, true); + const auto cml_sfx = ::llama_tokenize(ctx, "<|im_end|>\n<|im_start|>assistant\n", false, true); + + while ((n_remain != 0 && !is_antiprompt) || params.interactive) { + if (!embd.empty()) { + int max_embd_size = n_ctx - 4; + + if ((int) embd.size() > max_embd_size) { + const int skipped_tokens = (int)embd.size() - max_embd_size; + embd.resize(max_embd_size); + LOG("<>", skipped_tokens, skipped_tokens != 1 ? 
"s" : ""); + } + + if (ga_n == 1) { + if (n_past + (int) embd.size() + std::max(0, guidance_offset) >= n_ctx) { + if (params.n_predict == -2) { + LOG("\n\ncontext full and n_predict == -%d => stopping\n", this->params.n_predict); + break; + } + + const int n_left = n_past - this->params.n_keep; + const int n_discard = n_left/2; + + LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", + n_past, n_left, n_ctx, params.n_keep, n_discard); + + llama_kv_cache_seq_rm (ctx, 0, this->params.n_keep, this->params.n_keep + n_discard); + llama_kv_cache_seq_add(ctx, 0, this->params.n_keep + n_discard, n_past, -n_discard); + + n_past -= n_discard; + + if (this->guidance) { + n_past_guidance -= n_discard; + } + + LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); + LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(this->ctx, embd).c_str()); + LOG("clear session path\n"); + this->path_session.clear(); + } + } else { + while (n_past >= ga_i + ga_w) { + const int ib = (ga_n*ga_i)/ga_w; + const int bd = (ga_w/ga_n)*(ga_n - 1); + const int dd = (ga_w/ga_n) - ib*bd - ga_w; + + LOG("\n"); + LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past, ib*bd, ga_i + ib*bd, n_past + ib*bd); + LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); + LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); + + llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd); + llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); + llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); + + n_past -= bd; + + ga_i += ga_w/ga_n; + + LOG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past + bd, n_past, ga_i); + } + } + + if (n_session_consumed < (int) this->session_tokens.size()) { + size_t i = 0; + + for ( ; i < embd.size(); i++) { + if (embd[i] != this->session_tokens[n_session_consumed]) { + this->session_tokens.resize(n_session_consumed); + break; + } + + n_past++; + n_session_consumed++; + + if (n_session_consumed >= (int) this->session_tokens.size()) { + ++i; + break; + } + } + + if (i > 0) { + embd.erase(embd.begin(), embd.begin() + i); + } + } + + if (this->guidance) { + int input_size = 0; + llama_token * input_buf = nullptr; + + if (n_past_guidance < (int)this->guidance_inp.size()) { + embd_guidance = this->guidance_inp; + + if (embd.begin() + original_prompt_len < embd.end()) { + embd_guidance.insert(embd_guidance.end(), embd.begin() + original_prompt_len, embd.end()); + } + + input_buf = embd_guidance.data(); + input_size = embd_guidance.size(); + + LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str()); + } else { + input_buf = embd.data(); + input_size = embd.size(); + } + + for (int i = 0; i < input_size; i += params.n_batch) { + int n_eval = std::min(input_size - i, params.n_batch); + + if (llama_decode(this->guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) { + LOG("failed to eval\n"); + return; + } + + n_past_guidance += n_eval; + } + } + + for (int i = 0; i < (int) embd.size(); i += params.n_batch) { + int n_eval = (int) embd.size() - i; + + if (n_eval > params.n_batch) { + n_eval = params.n_batch; + } + + LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); + + if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) { + LOG("%s : failed to eval\n", __func__); + return; + 
} + + n_past += n_eval; + } + + if (!embd.empty() && !this->path_session.empty()) { + this->session_tokens.insert(this->session_tokens.end(), embd.begin(), embd.end()); + n_session_consumed = this->session_tokens.size(); + } + } + + embd.clear(); + embd_guidance.clear(); + + if ((int)this->embd_inp.size() <= n_consumed && !interactive) { + const llama_token id = llama_sampling_sample(this->sampling, this->ctx, this->guidance); + llama_sampling_accept(this->sampling, this->ctx, id, true); + + LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(this->ctx, this->sampling->prev).c_str()); + embd.push_back(id); + + input_echo = true; + --n_remain; + + LOG("n_remain: %d\n", n_remain); + } else { + LOG("embd_inp.size(): %d, n_consumed: %d\n", (int)this->embd_inp.size(), n_consumed); + + while ((int)this->embd_inp.size() > n_consumed) { + embd.push_back(this->embd_inp[n_consumed]); + llama_sampling_accept(this->sampling, this->ctx, this->embd_inp[n_consumed], false); + + ++n_consumed; + if ((int) embd.size() >= params.n_batch) { + break; + } + } + } + + if (input_echo && display) { + for (auto id : embd) { + const String token_str = llama_token_to_piece(ctx, id, !params.conversation); + if (this->stopped) { + this->interactive = false; + return; + } + + cb(this, token_str, false); + + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + + if (embd.size() > 1) { + input_tokens.push_back(id); + } else { + output_tokens.push_back(id); + output_ss << token_str; + } + } + } + + if (input_echo && (int)this->embd_inp.size() == n_consumed) { + display = true; + } + + if ((int)this->embd_inp.size() <= n_consumed) { + if (!params.antiprompt.empty()) { + const int n_prev = 32; + const String last_output = llama_sampling_prev_str(this->sampling, this->ctx, n_prev); + + is_antiprompt = false; + + for (String & antiprompt : this->params.antiprompt) { + size_t extra_padding = this->params.interactive ? 0 : 2; + size_t search_start_pos = last_output.length() > static_cast(antiprompt.length() + extra_padding) + ? 
last_output.length() - static_cast(antiprompt.length() + extra_padding) + : 0; + + if (last_output.find(antiprompt, search_start_pos) != String::npos) { + if (this->params.interactive) { + this->interactive = true; + } + + is_antiprompt = true; + break; + } + } + + llama_token last_token = llama_sampling_last(this->sampling); + for (std::vector ids : antiprompt_ids) { + if (ids.size() == 1 && last_token == ids[0]) { + if (this->params.interactive) { + this->interactive = true; + } + + is_antiprompt = true; + break; + } + } + + if (is_antiprompt) { + LOG("found antiprompt: %s\n", last_output.c_str()); + } + } + + if (llama_token_is_eog(model, llama_sampling_last(this->sampling))) { + LOG("found an EOG token\n"); + + if (this->params.interactive) { + if (!this->params.antiprompt.empty()) { + const auto first_antiprompt = ::llama_tokenize(this->ctx, this->params.antiprompt.front().c_str(), false, true); + this->embd_inp.insert(this->embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end()); + is_antiprompt = true; + } + + this->interactive = true; + } else if (this->params.instruct || this->params.chatml) { + this->interactive = true; + } + } + + if (n_past > 0 && this->interactive) { + LOG("waiting for user input\n"); + + if (this->params.input_prefix_bos) { + LOG("adding input prefix BOS token\n"); + this->embd_inp.push_back(llama_token_bos(this->model)); + } + + if (!params.input_prefix.empty() && !params.conversation) { + LOG("appending input prefix: '%s'\n", this->params.input_prefix.c_str()); + } + + if (buffer.length() > 1) { + if (!this->params.input_suffix.empty() && !this->params.conversation) { + LOG("appending input suffix: '%s'\n", this->params.input_suffix.c_str()); + } + + LOG("buffer: '%s'\n", buffer.c_str()); + + const size_t original_size = this->embd_inp.size(); + + if (this->params.instruct && !is_antiprompt) { + LOG("inserting instruction prefix\n"); + n_consumed = this->embd_inp.size(); + embd_inp.insert(this->embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); + } + + if (params.chatml && !is_antiprompt) { + LOG("inserting chatml prefix\n"); + n_consumed = this->embd_inp.size(); + embd_inp.insert(this->embd_inp.end(), cml_pfx.begin(), cml_pfx.end()); + } + + if (params.escape) { + this->escape(buffer); + } + + const auto line_pfx = ::llama_tokenize(this->ctx, this->params.input_prefix.c_str(), false, true); + const auto line_inp = ::llama_tokenize(this->ctx, buffer.c_str(), false, params.interactive_specials); + const auto line_sfx = ::llama_tokenize(this->ctx, this->params.input_suffix.c_str(), false, true); + + LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(this->ctx, line_inp).c_str()); + + this->embd_inp.insert(this->embd_inp.end(), line_pfx.begin(), line_pfx.end()); + this->embd_inp.insert(this->embd_inp.end(), line_inp.begin(), line_inp.end()); + this->embd_inp.insert(this->embd_inp.end(), line_sfx.begin(), line_sfx.end()); + + if (this->params.instruct) { + LOG("inserting instruction suffix\n"); + this->embd_inp.insert(this->embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); + } + + if (this->params.chatml) { + LOG("inserting chatml suffix\n"); + this->embd_inp.insert(this->embd_inp.end(), cml_sfx.begin(), cml_sfx.end()); + } + + for (size_t i = original_size; i < this->embd_inp.size(); ++i) { + const llama_token token = this->embd_inp[i]; + this->output_tokens->push_back(token); + output_ss << llama_token_to_piece(this->ctx, token); + } + + n_remain -= line_inp.size(); + LOG("n_remain: %d\n", n_remain); + } else { + LOG("empty line, passing control 
back\n"); + } + + input_echo = false; + } + + if (n_past > 0) { + if (this->interactive) { + llama_sampling_reset(this->sampling); + } + this->interactive = false; + } + } + + if (llama_token_is_eog(this->model, embd.back())) { + if (this->stopped) { + this->interactive = false; + return; + } + + if (cb(this, "", true)) return; + } + + if (this->params.interactive && n_remain <= 0 && this->params.n_predict >= 0) { + n_remain = this->params.n_predict; + this->interactive = true; + } + } + } +} diff --git a/src/core/modules/ai.hh b/src/core/modules/ai.hh new file mode 100644 index 000000000..a204bf89c --- /dev/null +++ b/src/core/modules/ai.hh @@ -0,0 +1,114 @@ +#ifndef SOCKET_RUNTIME_CORE_AI_H +#define SOCKET_RUNTIME_CORE_AI_H + +#include "../module.hh" + +#include "llama/common/common.h" +#include "llama/llama.h" + +// #include +// #include +// #include +// #include + +#if defined (_WIN32) + #define WIN32_LEAN_AND_MEAN + + #ifndef NOMINMAX + #define NOMINMAX + #endif +#endif + +namespace SSC { + class LLM; + class Core; + + struct LLMOptions { + int attentionCapacity; + int seed; + String path; + String prompt; + }; + + class CoreAI : public CoreModule { + public: + using ID = uint64_t; + using LLMs = std::map>; + + Mutex mutex; + LLMs llms; + + void chatLLM ( + const String& seq, + ID id, + String message, + const CoreModule::Callback& callback + ); + + void createLLM ( + const String& seq, + ID id, + LLMOptions options, + const CoreModule::Callback& callback + ); + + void destroyLLM ( + const String& seq, + ID id, + const CoreModule::Callback& callback + ); + + void stopLLM ( + const String& seq, + ID id, + const CoreModule::Callback& callback + ); + + bool hasLLM (ID id); + SharedPointer getLLM (ID id); + + CoreAI (Core* core) + : CoreModule(core) + {} + }; + + class LLM { + using Cb = std::function; + using Logger = std::function; + + gpt_params params; + llama_model* model; + llama_context* ctx; + llama_context* guidance = nullptr; + struct llama_sampling_context* sampling; + + std::vector* input_tokens; + std::ostringstream* output_ss; + std::vector* output_tokens; + std::vector session_tokens; + std::vector embd_inp; + std::vector guidance_inp; + std::vector> antiprompt_ids; + + String path_session = ""; + int guidance_offset = 0; + int original_prompt_len = 0; + int n_ctx = 0; + + public: + String err = ""; + bool stopped = false; + bool interactive = false; + + void chat (String input, const Cb cb); + void escape (String& input); + + LLM(const LLMOptions options); + ~LLM(); + + static void tramp(ggml_log_level level, const char* message, void* user_data); + static Logger log; + }; +} + +#endif diff --git a/src/ipc/routes.cc b/src/ipc/routes.cc index 0e74c69f1..c5ee4cc95 100644 --- a/src/ipc/routes.cc +++ b/src/ipc/routes.cc @@ -50,6 +50,50 @@ static void mapIPCRoutes (Router *router) { ); #endif + /** + * AI + */ + router->map("ai.llm.create", [](auto message, auto router, auto reply) { + auto err = validateMessageParameters(message, {"id", "path", "prompt"}); + + if (err.type != JSON::Type::Null) { + return reply(Result::Err { message, err }); + } + + SSC::LLMOptions options; + options.path = message.get("path"); + options.prompt = message.get("prompt"); + + uint64_t modelId = 0; + REQUIRE_AND_GET_MESSAGE_VALUE(modelId, "id", std::stoull); + + router->bridge->core->ai.createLLM(message.seq, modelId, options, RESULT_CALLBACK_FROM_CORE_CALLBACK(message, reply)); + }); + + router->map("ai.llm.destroy", [](auto message, auto router, auto reply) { + auto err = 
validateMessageParameters(message, {"id"}); + uint64_t modelId = 0; + REQUIRE_AND_GET_MESSAGE_VALUE(modelId, "id", std::stoull); + router->bridge->core->ai.destroyLLM(message.seq, modelId, RESULT_CALLBACK_FROM_CORE_CALLBACK(message, reply)); + }); + + router->map("ai.llm.stop", [](auto message, auto router, auto reply) { + auto err = validateMessageParameters(message, {"id"}); + uint64_t modelId = 0; + REQUIRE_AND_GET_MESSAGE_VALUE(modelId, "id", std::stoull); + router->bridge->core->ai.stopLLM(message.seq, modelId, RESULT_CALLBACK_FROM_CORE_CALLBACK(message, reply)); + }); + + router->map("ai.llm.chat", [](auto message, auto router, auto reply) { + auto err = validateMessageParameters(message, {"id", "message"}); + + uint64_t modelId = 0; + REQUIRE_AND_GET_MESSAGE_VALUE(modelId, "id", std::stoull); + + auto value = message.get("message"); + router->bridge->core->ai.chatLLM(message.seq, modelId, value, RESULT_CALLBACK_FROM_CORE_CALLBACK(message, reply)); + }); + /** * Attemps to exit the application * @param value The exit code
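For reference, a minimal usage sketch of the `socket:ai` LLM API introduced in this diff. The model path and prompt below are hypothetical placeholders; the constructor options, the `data`/`end`/`error` events, and the `chat()`/`stop()` methods follow the `api/ai.js` implementation added above, which forwards to the `ai.llm.*` IPC routes handled by `CoreAI` in `src/core/modules/ai.cc`.

```js
import { LLM } from 'socket:ai'

// Hypothetical path to a local GGUF model; the constructor throws if no path is given.
const llm = new LLM({
  path: '/path/to/model.gguf',
  prompt: 'You are a helpful assistant.'
})

// Tokens stream in via 'data' events; 'end' fires once the backend reports completion.
llm.on('data', (token) => console.log(token))
llm.on('end', () => console.log('chat complete'))
llm.on('error', (err) => console.error(err))

// Sends an 'ai.llm.chat' IPC request for this instance's id.
await llm.chat('Write a haiku about sockets.')

// Ask the model to stop after the next token ('ai.llm.stop').
await llm.stop()
```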