diff --git a/examples/llm_manual/CMakeLists.txt b/examples/llm_manual/CMakeLists.txt
new file mode 100644
index 00000000000..c605e947409
--- /dev/null
+++ b/examples/llm_manual/CMakeLists.txt
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+cmake_minimum_required(VERSION 3.19)
+project(nanogpt_runner)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED True)
+
+# Set options for the ExecuTorch build.
+option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "" ON)
+option(EXECUTORCH_BUILD_EXTENSION_MODULE "" ON)
+option(EXECUTORCH_BUILD_OPTIMIZED "" ON)
+option(EXECUTORCH_BUILD_XNNPACK "" ON) # Build with the XNNPACK backend
+
+# Include the executorch subdirectory.
+add_subdirectory(
+  ${CMAKE_CURRENT_SOURCE_DIR}/third-party/executorch
+  ${CMAKE_BINARY_DIR}/executorch)
+
+# include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)
+
+add_executable(nanogpt_runner main.cpp)
+target_link_libraries(
+  nanogpt_runner
+  PRIVATE
+  executorch
+  extension_module_static # Provides the Module class
+  optimized_native_cpu_ops_lib # Provides baseline cross-platform kernels
+  xnnpack_backend) # Provides the XNNPACK CPU acceleration backend
diff --git a/examples/llm_manual/README.md b/examples/llm_manual/README.md
new file mode 100644
index 00000000000..0ee6bb6a9f1
--- /dev/null
+++ b/examples/llm_manual/README.md
@@ -0,0 +1,3 @@
+# LLM Manual
+
+This directory stores the files that the [LLM Manual](https://pytorch.org/executorch/main/llm/getting-started.html) needs. Please refer to the documentation website for more information.
diff --git a/examples/llm_manual/basic_sampler.h b/examples/llm_manual/basic_sampler.h
new file mode 100644
index 00000000000..a95b823de8d
--- /dev/null
+++ b/examples/llm_manual/basic_sampler.h
@@ -0,0 +1,20 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <algorithm>
+#include <vector>
+class BasicSampler {
+ public:
+  BasicSampler() {}
+  int64_t sample(std::vector<float> logits) {
+    // Find the token with the highest log probability.
+    int64_t max_index =
+        std::max_element(logits.begin(), logits.end()) - logits.begin();
+    return max_index;
+  }
+};
diff --git a/examples/llm_manual/basic_tokenizer.h b/examples/llm_manual/basic_tokenizer.h
new file mode 100644
index 00000000000..eb51d15fc50
--- /dev/null
+++ b/examples/llm_manual/basic_tokenizer.h
@@ -0,0 +1,192 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cctype>
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+class BasicTokenizer {
+ public:
+  BasicTokenizer(const std::string& filePath) {
+    std::ifstream file(filePath);
+
+    if (!file) {
+      std::cerr << "Unable to open file";
+      exit(9); // return with error code
+    }
+    std::string str(
+        (std::istreambuf_iterator<char>(file)),
+        std::istreambuf_iterator<char>());
+
+    size_t i = 0u;
+    i = consume_whitespace(str, i);
+    i = expect(str, i, '{');
+
+    while (i < str.size() && str[i] != '}') {
+      i = consume_field(str, i);
+    }
+
+    // Build decode map as inverse of encode.
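+    // For example, a vocabulary entry {"hello": 31373} produces
+    // decode_[31373] = "hello", letting decode() map sampled token ids back
+    // to text.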
+    for (auto& i : encode_) {
+      decode_[i.second] = i.first;
+    }
+  }
+
+  std::vector<int64_t> encode(const std::string& prompt) {
+    std::vector<std::string> words = parse_prompt(prompt);
+    std::vector<int64_t> result;
+    for (auto word : words) {
+      result.push_back(encode_[word]);
+    }
+    return result;
+  }
+
+  std::string decode(const std::vector<int64_t>& indices) {
+    std::string result;
+    for (const auto& index : indices) {
+      result += decode_[index];
+    }
+    return result;
+  }
+
+ private:
+  std::unordered_map<std::string, int64_t> encode_;
+  std::unordered_map<int64_t, std::string> decode_;
+
+  // Advance the input string index until a non-whitespace character is found
+  // or it reaches the end of string.
+  size_t consume_whitespace(const std::string& data, size_t i) {
+    while (i < data.size() && std::isspace(data[i])) {
+      i++;
+    }
+
+    return i;
+  }
+
+  // Consumes a JSON field of the form
+  // "str": id,
+  size_t consume_field(const std::string& data, size_t i) {
+    i = consume_whitespace(data, i);
+
+    // Parse the key literal.
+    i = expect(data, i, '"');
+
+    auto in_escape = false;
+    std::string key = "";
+    while (i < data.size()) {
+      if (in_escape) {
+        key += data[i];
+        i++;
+        in_escape = false;
+      } else { // !in_escape
+        if (data[i] == '"') { // End of string literal
+          i++;
+          break;
+        } else if (data[i] == '\\') { // Escaped code point
+          in_escape = true;
+        }
+        key += data[i];
+        i++;
+      }
+    }
+
+    key = post_process_key(key);
+
+    i = expect(data, i, ':');
+    i = consume_whitespace(data, i);
+
+    // Read unsigned integer value
+    auto value_start = i;
+    while (i < data.size() && std::isdigit(data[i])) {
+      i++;
+    }
+    auto value = static_cast<int64_t>(
+        std::stol(data.substr(value_start, i - value_start)));
+
+    encode_[key] = value;
+
+    i = consume_whitespace(data, i);
+    if (i < data.size() && data[i] == ',') {
+      i++;
+    }
+
+    return i;
+  }
+
+  // Assert that the next character in the input string is equal to c.
+  // Increment the input string index by one.
+  size_t expect(const std::string& data, size_t i, char c) {
+    if (i >= data.size() || data[i] != c) {
+      std::cerr << "Invalid tokenizer vocabulary file. Expected '" << c
+                << "' at index " << i << std::endl;
+      exit(1);
+    }
+
+    return i + 1;
+  }
+
+  std::string post_process_key(std::string key) {
+    // Replace the unicode characters with the corresponding byte encoding.
+    // TODO: adopt byte encoder to handle unicode characters in json file.
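+    // GPT-2's byte-level vocabulary encodes a leading space as "\u0120"
+    // ("Ġ") and a newline as "\u010a" ("Ċ"); only these two common cases
+    // are mapped back to raw characters here.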
+
+    std::unordered_map<std::string, std::string> replacements = {
+        {"\\u0120", " "},
+        {"\\u010a", "\n"},
+    };
+
+    for (const auto& replacement : replacements) {
+      size_t pos = 0;
+      // Loop over all instances of the substring in the string.
+      while ((pos = key.find(replacement.first, pos)) != std::string::npos) {
+        key.replace(pos, replacement.first.length(), replacement.second);
+        pos += replacement.second.length();
+      }
+    }
+
+    // remove duplicate backslashes
+    for (size_t idx = 0; idx < key.length(); idx++) {
+      if (key[idx] == '\\') {
+        key.erase(idx, 1);
+        if (key[idx] == '\\') {
+          // If there are two backslashes, keep the second one
+          idx += 1;
+        }
+      }
+    }
+
+    return key;
+  }
+  std::vector<std::string> parse_prompt(const std::string& prompt) {
+    std::vector<std::string> result;
+    std::string word;
+    for (char c : prompt) {
+      if (c == ' ') {
+        if (!word.empty()) {
+          result.push_back(word);
+          word.clear();
+        }
+        word += c;
+      } else if (ispunct(c)) {
+        if (!word.empty()) {
+          result.push_back(word);
+          word.clear();
+        }
+        result.push_back(std::string(1, c));
+      } else {
+        word += c;
+      }
+    }
+    if (!word.empty()) {
+      result.push_back(word);
+    }
+    return result;
+  }
+};
diff --git a/examples/llm_manual/export_nanogpt.py b/examples/llm_manual/export_nanogpt.py
new file mode 100644
index 00000000000..cf29a69c080
--- /dev/null
+++ b/examples/llm_manual/export_nanogpt.py
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# export_nanogpt.py
+
+# Load the partitioner for the XNNPACK backend
+import torch
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+
+# A model delegated to a specific backend should use that backend's edge compile config
+from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
+from executorch.exir import to_edge
+
+from model import GPT
+from torch._export import capture_pre_autograd_graph
+from torch.export import export
+from torch.nn.attention import sdpa_kernel, SDPBackend
+
+model = GPT.from_pretrained("gpt2")  # use GPT-2 weights as the pretrained weights
+example_inputs = (
+    torch.randint(0, 100, (1, model.config.block_size), dtype=torch.long),
+)
+dynamic_shape = ({1: torch.export.Dim("token_dim", max=model.config.block_size)},)
+
+# Trace the model, converting it to a portable intermediate representation.
+# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
+with sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+    m = capture_pre_autograd_graph(model, example_inputs, dynamic_shapes=dynamic_shape)
+    traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)
+
+# Convert the model into a runnable ExecuTorch program.
+# To be lowered to the XNNPACK backend, `traced_model` needs an XNNPACK-specific edge compile config.
+edge_config = get_xnnpack_edge_compile_config()
+edge_manager = to_edge(traced_model, compile_config=edge_config)
+
+# Delegate the exported model to the XNNPACK backend by invoking `to_backend` with the XNNPACK partitioner.
+edge_manager = edge_manager.to_backend(XnnpackPartitioner())
+et_program = edge_manager.to_executorch()
+
+# Save the XNNPACK-delegated ExecuTorch program to a file.
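+# `et_program.buffer` exposes the serialized program as bytes; this is the
+# nanogpt.pte file that the C++ runner below loads at startup.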
+with open("nanogpt.pte", "wb") as file: + file.write(et_program.buffer) diff --git a/examples/llm_manual/main.cpp b/examples/llm_manual/main.cpp new file mode 100644 index 00000000000..2b336059cff --- /dev/null +++ b/examples/llm_manual/main.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// main.cpp + +#include +#include +#include +#include + +#include "basic_sampler.h" +#include "basic_tokenizer.h" +#include "managed_tensor.h" + +#include +#include +#include +#include +#include + +using namespace torch::executor; + +using SizesType = exec_aten::SizesType; +using DimOrderType = exec_aten::DimOrderType; +using StridesType = exec_aten::StridesType; + +// main.cpp + +#define ENDOFTEXT 50256 + +std::string generate( + Module& llm_model, + std::string& prompt, + BasicTokenizer& tokenizer, + BasicSampler& sampler, + size_t max_input_length, + size_t max_output_length) { + // Convert the input text into a list of integers (tokens) that represents + // it, using the string-to-token mapping that the model was trained on. + // Each token is an integer that represents a word or part of a word. + std::vector input_tokens = tokenizer.encode(prompt); + std::vector output_tokens; + + for (auto i = 0u; i < max_output_length; i++) { + // Convert the input_tokens from a vector of int64_t to EValue. + // EValue is a unified data type in the ExecuTorch runtime. + ManagedTensor tensor_tokens( + input_tokens.data(), + {1, static_cast(input_tokens.size())}, + ScalarType::Long); + std::vector inputs = {tensor_tokens.get_tensor()}; + + // Run the model. It will return a tensor of logits (log-probabilities). + Result> logits_evalue = llm_model.forward(inputs); + + // Convert the output logits from EValue to std::vector, which is what + // the sampler expects. + Tensor logits_tensor = logits_evalue.get()[0].toTensor(); + std::vector logits( + logits_tensor.data_ptr(), + logits_tensor.data_ptr() + logits_tensor.numel()); + + // Sample the next token from the logits. + int64_t next_token = sampler.sample(logits); + + // Break if we reached the end of the text. + if (next_token == ENDOFTEXT) { + break; + } + + // Add the next token to the output. + output_tokens.push_back(next_token); + + std::cout << tokenizer.decode({next_token}); + std::cout.flush(); + + // Update next input. + input_tokens.push_back(next_token); + if (input_tokens.size() > max_input_length) { + input_tokens.erase(input_tokens.begin()); + } + } + + std::cout << std::endl; + + // Convert the output tokens into a human-readable string. + std::string output_string = tokenizer.decode(output_tokens); + return output_string; +} + +// main.cpp + +int main() { + // Set up the prompt. This provides the seed text for the model to elaborate. + std::cout << "Prompt: "; + std::string prompt; + std::getline(std::cin, prompt); + + // The tokenizer is used to convert between tokens (used by the model) and + // human-readable strings. + BasicTokenizer tokenizer("vocab.json"); + + // The sampler is used to sample the next token from the logits. + BasicSampler sampler = BasicSampler(); + + // Load the exported nanoGPT program, which was generated via the previous + // steps. 
+  Module model(
+      "nanogpt.pte",
+      torch::executor::Module::MlockConfig::UseMlockIgnoreErrors);
+
+  const auto max_input_tokens = 1024;
+  const auto max_output_tokens = 30;
+  std::cout << prompt;
+  generate(
+      model, prompt, tokenizer, sampler, max_input_tokens, max_output_tokens);
+}
diff --git a/examples/llm_manual/managed_tensor.h b/examples/llm_manual/managed_tensor.h
new file mode 100644
index 00000000000..d401ae4d18b
--- /dev/null
+++ b/examples/llm_manual/managed_tensor.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
+#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
+#include <executorch/runtime/platform/assert.h>
+
+#include <memory>
+
+#pragma once
+
+namespace torch {
+namespace executor {
+
+/**
+ * A tensor wrapper that owns all of the metadata memory needed by
+ * torch::executor::Tensor. Note that it doesn't own the data memory.
+ */
+class ManagedTensor {
+ public:
+  /// The type used for elements of `sizes()`.
+  using SizesType = exec_aten::SizesType;
+  /// The type used for elements of `dim_order()`.
+  using DimOrderType = exec_aten::DimOrderType;
+  /// The type used for elements of `strides()`.
+  using StridesType = exec_aten::StridesType;
+  ManagedTensor() = delete;
+
+  explicit ManagedTensor(
+      void* data,
+      const std::vector<SizesType>& sizes,
+      ScalarType dtype)
+      : dtype_(dtype), sizes_(sizes), data_ptr_(data) {
+    ssize_t dim = sizes.size();
+    dim_order_.resize(dim);
+    strides_.resize(dim);
+    for (size_t i = 0; i < dim; ++i) {
+      dim_order_[i] = i;
+    }
+    dim_order_to_stride_nocheck(
+        sizes.data(), dim_order_.data(), dim, strides_.data());
+    tensor_impl_ = std::make_unique<TensorImpl>(
+        dtype_,
+        dim,
+        sizes_.data(),
+        data_ptr_,
+        dim_order_.data(),
+        strides_.data(),
+        TensorShapeDynamism::DYNAMIC_BOUND);
+  }
+
+  /**
+   * Get the Tensor object managed by this class.
+   */
+  Tensor get_tensor() {
+    return Tensor(tensor_impl_.get());
+  }
+
+ private:
+  void* data_ptr_ = nullptr;
+  std::unique_ptr<TensorImpl> tensor_impl_;
+  std::vector<SizesType> sizes_;
+  std::vector<StridesType> strides_;
+  std::vector<DimOrderType> dim_order_;
+  ScalarType dtype_;
+};
+} // namespace executor
+} // namespace torch