From 52421cd99fc1c79f14fa212908c15ebd590334c1 Mon Sep 17 00:00:00 2001 From: Roger Zurawicki Date: Mon, 20 Feb 2023 15:51:44 +0100 Subject: [PATCH 1/5] Add LLM direcotry --- src/actions/prepare_commit_msg.rs | 2 +- src/llms/mod.rs | 1 + src/{ => llms}/openai.rs | 0 src/main.rs | 2 +- src/summarize.rs | 2 +- 5 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 src/llms/mod.rs rename src/{ => llms}/openai.rs (100%) diff --git a/src/actions/prepare_commit_msg.rs b/src/actions/prepare_commit_msg.rs index c7a6bcb..caaee9b 100644 --- a/src/actions/prepare_commit_msg.rs +++ b/src/actions/prepare_commit_msg.rs @@ -17,7 +17,7 @@ use tokio::task::JoinSet; use crate::git; use crate::help::print_help_openai_api_key; -use crate::openai::OpenAIClient; +use crate::llms::openai::OpenAIClient; use crate::settings::Settings; use crate::summarize::SummarizationClient; diff --git a/src/llms/mod.rs b/src/llms/mod.rs new file mode 100644 index 0000000..1e30e94 --- /dev/null +++ b/src/llms/mod.rs @@ -0,0 +1 @@ +pub(crate) mod openai; diff --git a/src/openai.rs b/src/llms/openai.rs similarity index 100% rename from src/openai.rs rename to src/llms/openai.rs diff --git a/src/main.rs b/src/main.rs index 23884c4..82df57c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,6 @@ extern crate lazy_static; extern crate log; mod cmd; mod git; -mod openai; mod prompt; mod summarize; mod util; @@ -16,6 +15,7 @@ use clap::{Parser, Subcommand}; mod actions; mod help; +mod llms; mod settings; use log::LevelFilter; diff --git a/src/summarize.rs b/src/summarize.rs index e5d91b9..b44fb5a 100644 --- a/src/summarize.rs +++ b/src/summarize.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use crate::{openai::OpenAIClient, prompt::format_prompt, settings::PromptSettings}; +use crate::{llms::openai::OpenAIClient, prompt::format_prompt, settings::PromptSettings}; use anyhow::Result; #[derive(Clone, Debug)] From ab72d2e7559fd4fea8328f73de4914baba695989 Mon Sep 17 00:00:00 2001 From: Roger Zurawicki Date: Mon, 20 Feb 2023 16:15:24 +0100 Subject: [PATCH 2/5] Add LlmClient trait --- Cargo.lock | 1 + Cargo.toml | 1 + src/actions/prepare_commit_msg.rs | 10 +++++++--- src/llms/base_llm.rs | 8 ++++++++ src/llms/mod.rs | 1 + src/llms/openai.rs | 32 ++++++++++++++++++------------- src/summarize.rs | 15 ++++++--------- 7 files changed, 43 insertions(+), 25 deletions(-) create mode 100644 src/llms/base_llm.rs diff --git a/Cargo.lock b/Cargo.lock index 376c3e3..c0ce5f1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -521,6 +521,7 @@ name = "gptcommit" version = "0.1.17" dependencies = [ "anyhow", + "async-trait", "clap", "colored", "config", diff --git a/Cargo.toml b/Cargo.toml index b23c667..948d6b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ path = "src/main.rs" [dependencies] anyhow = "1.0.69" +async-trait = "0.1.64" clap = { version = "4.1.6", features = ["derive"] } colored = "2.0.0" config = { version = "0.13.3", features = ["toml"] } diff --git a/src/actions/prepare_commit_msg.rs b/src/actions/prepare_commit_msg.rs index caaee9b..3421771 100644 --- a/src/actions/prepare_commit_msg.rs +++ b/src/actions/prepare_commit_msg.rs @@ -17,6 +17,7 @@ use tokio::task::JoinSet; use crate::git; use crate::help::print_help_openai_api_key; +use crate::llms::base_llm::LlmClient; use crate::llms::openai::OpenAIClient; use crate::settings::Settings; @@ -32,8 +33,8 @@ use crate::util::SplitPrefixInclusive; /// The function assumes that the file_diff input is well-formed /// according to the Diff format described in the Git 
documentation: /// https://git-scm.com/docs/git-diff -async fn process_file_diff( - summarize_client: SummarizationClient, +async fn process_file_diff( + summarize_client: SummarizationClient, file_diff: &str, ) -> Option<(String, String)> { if let Some(file_name) = util::get_file_name_from_diff(file_diff) { @@ -81,7 +82,10 @@ pub(crate) struct PrepareCommitMsgArgs { git_diff_content: Option, } -async fn get_commit_message(client: SummarizationClient, diff_as_input: &str) -> Result { +async fn get_commit_message( + client: SummarizationClient, + diff_as_input: &str, +) -> Result { let file_diffs = diff_as_input.split_prefix_inclusive("\ndiff --git "); let mut set = JoinSet::new(); diff --git a/src/llms/base_llm.rs b/src/llms/base_llm.rs new file mode 100644 index 0000000..8f380c1 --- /dev/null +++ b/src/llms/base_llm.rs @@ -0,0 +1,8 @@ +use async_trait::async_trait; +use anyhow::Result; + +#[async_trait] +pub trait LlmClient { + /// It takes a prompt as input, and returns the completion using an external Large Language Model. + async fn completions(&self, prompt: &str) -> Result; +} diff --git a/src/llms/mod.rs b/src/llms/mod.rs index 1e30e94..34b3087 100644 --- a/src/llms/mod.rs +++ b/src/llms/mod.rs @@ -1 +1,2 @@ +pub(crate) mod base_llm; pub(crate) mod openai; diff --git a/src/llms/openai.rs b/src/llms/openai.rs index 5a692e3..ef0565c 100644 --- a/src/llms/openai.rs +++ b/src/llms/openai.rs @@ -2,12 +2,15 @@ use std::time::Duration; use anyhow::{anyhow, bail, Result}; +use async_trait::async_trait; use reqwest::{Client, ClientBuilder}; use serde_json::{json, Value}; use tiktoken_rs::tiktoken::{p50k_base, CoreBPE}; use crate::settings::OpenAISettings; +use super::base_llm::LlmClient; + #[derive(Clone, Debug)] pub(crate) struct OpenAIClient { api_key: String, @@ -35,9 +38,24 @@ impl OpenAIClient { }) } + pub(crate) fn get_prompt_token_limit_for_model(&self) -> usize { + match self.model.as_str() { + "text-davinci-003" => 4097, + "text-curie-001" => 2048, + "text-babbage-001" => 2048, + "text-ada-001" => 2048, + "code-davinci-002" => 8000, + "code-cushman-001" => 2048, + _ => 4097, + } + } +} + +#[async_trait] +impl LlmClient for OpenAIClient { /// Sends a request to OpenAI's API to get a text completion. /// It takes a prompt as input, and returns the completion. - pub(crate) async fn completions(&self, prompt: &str) -> Result { + async fn completions(&self, prompt: &str) -> Result { let prompt_token_limit = self.get_prompt_token_limit_for_model(); lazy_static! 
{ static ref BPE_TOKENIZER: CoreBPE = p50k_base().unwrap(); @@ -87,16 +105,4 @@ impl OpenAIClient { .trim() .to_string()) } - - pub(crate) fn get_prompt_token_limit_for_model(&self) -> usize { - match self.model.as_str() { - "text-davinci-003" => 4097, - "text-curie-001" => 2048, - "text-babbage-001" => 2048, - "text-ada-001" => 2048, - "code-davinci-002" => 8000, - "code-cushman-001" => 2048, - _ => 4097, - } - } } diff --git a/src/summarize.rs b/src/summarize.rs index b44fb5a..6f7bc5e 100644 --- a/src/summarize.rs +++ b/src/summarize.rs @@ -1,22 +1,19 @@ use std::collections::HashMap; -use crate::{llms::openai::OpenAIClient, prompt::format_prompt, settings::PromptSettings}; +use crate::llms::base_llm::LlmClient; +use crate::{prompt::format_prompt, settings::PromptSettings}; use anyhow::Result; - #[derive(Clone, Debug)] -pub(crate) struct SummarizationClient { - client: OpenAIClient, +pub(crate) struct SummarizationClient { + client: T, prompt_file_diff: String, prompt_commit_summary: String, prompt_commit_title: String, } -impl SummarizationClient { - pub(crate) fn new( - settings: PromptSettings, - client: OpenAIClient, - ) -> Result { +impl SummarizationClient { + pub(crate) fn new(settings: PromptSettings, client: T) -> Result { let prompt_file_diff = settings.file_diff.unwrap_or_default(); let prompt_commit_summary = settings.commit_summary.unwrap_or_default(); let prompt_commit_title = settings.commit_title.unwrap_or_default(); From c7bb37e12cb4e25c00b9e3d7c09b0bd0f9360586 Mon Sep 17 00:00:00 2001 From: Roger Zurawicki Date: Mon, 20 Feb 2023 16:38:26 +0100 Subject: [PATCH 3/5] Update testing --- Justfile | 4 ++-- e2e/test_githook.sh | 13 +++++++------ e2e/test_install.sh | 25 +++++++++++++------------ 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/Justfile b/Justfile index d55b997..8a9517a 100644 --- a/Justfile +++ b/Justfile @@ -19,13 +19,13 @@ release: cargo build --release install: - cargo install --path . + cargo install --path . 
--offline e2e: install sh -eux -c 'for i in ./e2e/test_*.sh ; do sh -x "$i" ; done' test *args: e2e - cargo test + cargo test --offline alias t := test lint: diff --git a/e2e/test_githook.sh b/e2e/test_githook.sh index 282c268..7c1363e 100755 --- a/e2e/test_githook.sh +++ b/e2e/test_githook.sh @@ -1,10 +1,11 @@ #!/bin/sh set -eu +DIFF_CONTENT_PATH="$(pwd)/tests/data/example_1.diff" + +export TEMPDIR=$(mktemp -d) ( - rm -rf test_dir_foo4 - mkdir test_dir_foo4 - cd test_dir_foo4 + cd "${TEMPDIR}" git init export TEMPFILE=$(mktemp) @@ -12,10 +13,10 @@ set -eu GPTCOMMIT__OPENAI__MODEL="text-ada-001" \ gptcommit prepare-commit-msg \ - --git-diff-content ../tests/data/example_1.diff \ - --commit-msg-file $TEMPFILE \ + --git-diff-content "${DIFF_CONTENT_PATH}" \ + --commit-msg-file "${TEMPFILE}" \ --commit-source "" cat $TEMPFILE ) -rm -rf test_dir_foo4 +rm -rf "${TEMPDIR}" diff --git a/e2e/test_install.sh b/e2e/test_install.sh index b0f020d..a5be1e8 100755 --- a/e2e/test_install.sh +++ b/e2e/test_install.sh @@ -1,10 +1,9 @@ #!/bin/sh set -eu +export TEMPDIR=$(mktemp -d) ( - rm -rf test_dir_foo - mkdir test_dir_foo - cd test_dir_foo + cd "${TEMPDIR}" git init gptcommit install @@ -12,28 +11,30 @@ set -eu gptcommit install # assert still works ) -rm -rf test_dir_foo ; +rm -rf "${TEMPDIR}" ############################# +export TEMPDIR=$(mktemp -d) ( - rm -rf test_dir_foo2 - mkdir test_dir_foo2 - cd test_dir_foo2 + cd "${TEMPDIR}" git init mkdir a cd a gptcommit install ) -rm -rf test_dir_foo2 +rm -rf "${TEMPDIR}" ############################# +export TEMPDIR=$(mktemp -d) ( - rm -rf test_dir_foo3 - mkdir test_dir_foo3 - cd test_dir_foo3 + cd "${TEMPDIR}" # no git init + set +e gptcommit install ; + # TODO assert output + test $? -ne 0 || exit $? + set -e ) -rm -rf test_dir_foo3 +rm -rf "${TEMPDIR}" From 710578bc323f1bcc7493d8d89a0619b6dff48493 Mon Sep 17 00:00:00 2001 From: Roger Zurawicki Date: Mon, 20 Feb 2023 16:50:35 +0100 Subject: [PATCH 4/5] Create dummy LLM client for testing --- Cargo.lock | 245 ++++++++++++++++++++++++ Cargo.toml | 3 + e2e/test_githook.sh | 2 +- src/actions/prepare_commit_msg.rs | 109 +++-------- src/llms/{base_llm.rs => llm_client.rs} | 6 +- src/llms/mod.rs | 3 +- src/llms/openai.rs | 2 +- src/llms/tester_foobar.rs | 41 ++++ src/settings.rs | 11 +- src/summarize.rs | 84 +++++++- 10 files changed, 406 insertions(+), 100 deletions(-) rename src/llms/{base_llm.rs => llm_client.rs} (78%) create mode 100644 src/llms/tester_foobar.rs diff --git a/Cargo.lock b/Cargo.lock index c0ce5f1..5f2cf6d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -49,6 +49,17 @@ version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800" +[[package]] +name = "async-channel" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf46fee83e5ccffc220104713af3292ff9bc7c64c7de289f66dae8e38d826833" +dependencies = [ + "concurrent-queue", + "event-listener", + "futures-core", +] + [[package]] name = "async-compression" version = "0.3.15" @@ -63,6 +74,97 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-executor" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17adb73da160dfb475c183343c8cccd80721ea5a605d3eb57125f0a7b7a92d0b" +dependencies = [ + "async-lock", + "async-task", + "concurrent-queue", + "fastrand", + "futures-lite", + "slab", +] + +[[package]] +name = "async-global-executor" +version = "2.3.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1b6f5d7df27bd294849f8eec66ecfc63d11814df7a4f5d74168a2394467b776" +dependencies = [ + "async-channel", + "async-executor", + "async-io", + "async-lock", + "blocking", + "futures-lite", + "once_cell", +] + +[[package]] +name = "async-io" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c374dda1ed3e7d8f0d9ba58715f924862c63eae6849c92d3a18e7fbde9e2794" +dependencies = [ + "async-lock", + "autocfg", + "concurrent-queue", + "futures-lite", + "libc", + "log", + "parking", + "polling", + "slab", + "socket2", + "waker-fn", + "windows-sys 0.42.0", +] + +[[package]] +name = "async-lock" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8101efe8695a6c17e02911402145357e718ac92d3ff88ae8419e84b1707b685" +dependencies = [ + "event-listener", + "futures-lite", +] + +[[package]] +name = "async-std" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62565bb4402e926b29953c785397c6dc0391b7b446e45008b0049eb43cec6f5d" +dependencies = [ + "async-channel", + "async-global-executor", + "async-io", + "async-lock", + "crossbeam-utils", + "futures-channel", + "futures-core", + "futures-io", + "futures-lite", + "gloo-timers", + "kv-log-macro", + "log", + "memchr", + "once_cell", + "pin-project-lite", + "pin-utils", + "slab", + "wasm-bindgen-futures", +] + +[[package]] +name = "async-task" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a40729d2133846d9ed0ea60a8b9541bccddab49cd30f0715a1da672fe9a2524" + [[package]] name = "async-trait" version = "0.1.64" @@ -74,6 +176,12 @@ dependencies = [ "syn", ] +[[package]] +name = "atomic-waker" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "debc29dde2e69f9e47506b525f639ed42300fc014a3e007832592448fa8e4599" + [[package]] name = "atty" version = "0.2.14" @@ -133,6 +241,20 @@ dependencies = [ "generic-array", ] +[[package]] +name = "blocking" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c67b173a56acffd6d2326fb7ab938ba0b00a71480e14902b2591c87bc5741e8" +dependencies = [ + "async-channel", + "async-lock", + "async-task", + "atomic-waker", + "fastrand", + "futures-lite", +] + [[package]] name = "brotli" version = "3.3.4" @@ -238,6 +360,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "concurrent-queue" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c278839b831783b70278b14df4d45e1beb1aad306c07bb796637de9a0e323e8e" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "config" version = "0.13.3" @@ -291,6 +422,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-utils" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb766fa798726286dbbb842f174001dab8abc7b627a1dd86e0b7222a95d929f" +dependencies = [ + "cfg-if", +] + [[package]] name = "crypto-common" version = "0.1.6" @@ -301,6 +441,16 @@ dependencies = [ "typenum", ] +[[package]] +name = "ctor" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d2301688392eb071b0bf1a37be05c469d3cc4dbbd95df672fe28ab021e6a096" +dependencies = [ + "quote", + "syn", +] + [[package]] name = "digest" version = "0.10.6" @@ -373,6 +523,12 @@ dependencies = [ "libc", ] +[[package]] +name = "event-listener" +version = 
"2.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" + [[package]] name = "fancy-regex" version = "0.11.0" @@ -447,6 +603,27 @@ version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec90ff4d0fe1f57d600049061dc6bb68ed03c7d2fbd697274c41805dcb3f8608" +[[package]] +name = "futures-io" +version = "0.3.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfb8371b6fb2aeb2d280374607aeabfc99d95c72edfe51692e42d3d7f0d08531" + +[[package]] +name = "futures-lite" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7694489acd39452c77daa48516b894c153f192c3578d5a839b62c58099fcbf48" +dependencies = [ + "fastrand", + "futures-core", + "futures-io", + "memchr", + "parking", + "pin-project-lite", + "waker-fn", +] + [[package]] name = "futures-sink" version = "0.3.26" @@ -516,11 +693,24 @@ dependencies = [ "walkdir", ] +[[package]] +name = "gloo-timers" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b995a66bb87bebce9a0f4a95aed01daca4872c050bfcb21653361c03bc35e5c" +dependencies = [ + "futures-channel", + "futures-core", + "js-sys", + "wasm-bindgen", +] + [[package]] name = "gptcommit" version = "0.1.17" dependencies = [ "anyhow", + "async-std", "async-trait", "clap", "colored", @@ -785,6 +975,15 @@ dependencies = [ "serde", ] +[[package]] +name = "kv-log-macro" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de8b303297635ad57c9f5059fd9cee7a47f8e8daa09df0fcd07dd39fb22977f" +dependencies = [ + "log", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -826,6 +1025,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" dependencies = [ "cfg-if", + "value-bag", ] [[package]] @@ -990,6 +1190,12 @@ version = "6.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee" +[[package]] +name = "parking" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "427c3892f9e783d91cc128285287e70a59e206ca452770ece88a76f7a3eddd72" + [[package]] name = "parking_lot" version = "0.12.1" @@ -1087,6 +1293,20 @@ version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160" +[[package]] +name = "polling" +version = "2.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22122d5ec4f9fe1b3916419b76be1e80bcb93f618d071d2edf841b137b2a2bd6" +dependencies = [ + "autocfg", + "cfg-if", + "libc", + "log", + "wepoll-ffi", + "windows-sys 0.42.0", +] + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -1892,6 +2112,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "value-bag" +version = "1.0.0-alpha.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2209b78d1249f7e6f3293657c9779fe31ced465df091bbd433a1cf88e916ec55" +dependencies = [ + "ctor", + "version_check", +] + [[package]] name = "vcpkg" version = "0.2.15" @@ -1904,6 +2134,12 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "waker-fn" 
+version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d5b2c62b4012a3e1eca5a7e077d13b3bf498c4073e33ccd58626607748ceeca" + [[package]] name = "walkdir" version = "2.3.2" @@ -2026,6 +2262,15 @@ dependencies = [ "webpki", ] +[[package]] +name = "wepoll-ffi" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d743fdedc5c64377b5fc2bc036b01c7fd642205a0d96356034ae3404d49eb7fb" +dependencies = [ + "cc", +] + [[package]] name = "which" version = "4.4.0" diff --git a/Cargo.toml b/Cargo.toml index 948d6b3..e0ef898 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,3 +38,6 @@ tiktoken-rs = "0.1.2" tokio = { version = "1.25.0", features = ["full"] } toml = "0.7.2" which = "4.4.0" + +[dev-dependencies] +async-std = "1.12.0" diff --git a/e2e/test_githook.sh b/e2e/test_githook.sh index 7c1363e..46f27f0 100755 --- a/e2e/test_githook.sh +++ b/e2e/test_githook.sh @@ -11,7 +11,7 @@ export TEMPDIR=$(mktemp -d) export TEMPFILE=$(mktemp) echo "foo" > $TEMPFILE - GPTCOMMIT__OPENAI__MODEL="text-ada-001" \ + GPTCOMMIT__MODEL_PROVIDER="tester-foobar" \ gptcommit prepare-commit-msg \ --git-diff-content "${DIFF_CONTENT_PATH}" \ --commit-msg-file "${TEMPFILE}" \ diff --git a/src/actions/prepare_commit_msg.rs b/src/actions/prepare_commit_msg.rs index 3421771..efbf42c 100644 --- a/src/actions/prepare_commit_msg.rs +++ b/src/actions/prepare_commit_msg.rs @@ -1,52 +1,26 @@ -use anyhow::bail; use anyhow::Result; use clap::arg; use clap::ValueEnum; use colored::Colorize; use clap::Args; -use tokio::try_join; - -use std::collections::HashMap; use std::fs; use std::path::PathBuf; -use tokio::task::JoinSet; use crate::git; use crate::help::print_help_openai_api_key; -use crate::llms::base_llm::LlmClient; -use crate::llms::openai::OpenAIClient; +use crate::llms::{llm_client::LlmClient, openai::OpenAIClient}; +use crate::settings::ModelProvider; use crate::settings::Settings; use crate::summarize::SummarizationClient; -use crate::util; use crate::util::SplitPrefixInclusive; -/// Splits the contents of a git diff by file. -/// -/// The file path is the first string in the returned tuple, and the -/// file content is the second string in the returned tuple. 
-/// -/// The function assumes that the file_diff input is well-formed -/// according to the Diff format described in the Git documentation: -/// https://git-scm.com/docs/git-diff -async fn process_file_diff( - summarize_client: SummarizationClient, - file_diff: &str, -) -> Option<(String, String)> { - if let Some(file_name) = util::get_file_name_from_diff(file_diff) { - let completion = summarize_client.diff_summary(file_name, file_diff).await; - Some(( - file_name.to_string(), - completion.unwrap_or_else(|_| "".to_string()), - )) - } else { - None - } -} +use crate::llms::tester_foobar::FooBarClient; + #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, ValueEnum)] enum CommitSource { #[clap(name = "")] @@ -82,54 +56,28 @@ pub(crate) struct PrepareCommitMsgArgs { git_diff_content: Option, } -async fn get_commit_message( - client: SummarizationClient, - diff_as_input: &str, -) -> Result { - let file_diffs = diff_as_input.split_prefix_inclusive("\ndiff --git "); - - let mut set = JoinSet::new(); - - for file_diff in file_diffs { - let file_diff = file_diff.to_owned(); - let summarize_client = client.clone(); - set.spawn(async move { process_file_diff(summarize_client, &file_diff).await }); - } - - let mut summary_for_file: HashMap = HashMap::with_capacity(set.len()); - while let Some(res) = set.join_next().await { - if let Some((k, v)) = res.unwrap() { - summary_for_file.insert(k, v); +fn get_llm_client(settings: &Settings) -> Box { + match settings { + Settings { + model_provider: Some(ModelProvider::TesterFoobar), + .. + } => Box::new(FooBarClient::new().unwrap()), + Settings { + model_provider: Some(ModelProvider::OpenAI), + openai: Some(openai), + .. + } => { + let client = OpenAIClient::new(openai.to_owned()); + if let Err(_e) = client { + print_help_openai_api_key(); + panic!("OpenAI API key not found in config or environment"); + } + Box::new(client.unwrap()) } + _ => panic!("Could not load LLM Client from config!"), } - - let summary_points = &summary_for_file - .iter() - .map(|(file_name, completion)| format!("[{file_name}]\n{completion}")) - .collect::>() - .join("\n"); - - let (title, completion) = try_join!( - client.commit_title(summary_points), - client.commit_summary(summary_points) - )?; - - let mut message = String::with_capacity(1024); - - message.push_str(&format!("{title}\n\n{completion}\n\n")); - for (file_name, completion) in &summary_for_file { - if !completion.is_empty() { - message.push_str(&format!("[{file_name}]\n{completion}\n")); - } - } - - // split message into lines and uniquefy lines - let mut lines = message.lines().collect::>(); - lines.dedup(); - let message = lines.join("\n"); - - Ok(message) } + pub(crate) async fn main(settings: Settings, args: PrepareCommitMsgArgs) -> Result<()> { match (args.commit_source, settings.allow_amend) { (CommitSource::Empty, _) | (CommitSource::Commit, Some(true)) => {} @@ -143,13 +91,7 @@ pub(crate) async fn main(settings: Settings, args: PrepareCommitMsgArgs) -> Resu } }; - let client = match OpenAIClient::new(settings.openai.unwrap_or_default()) { - Ok(client) => client, - Err(_e) => { - print_help_openai_api_key(); - bail!("OpenAI API key not found in config or environment"); - } - }; + let client = get_llm_client(&settings); let summarization_client = SummarizationClient::new(settings.prompt.unwrap(), client)?; println!("{}", "🤖 Asking GPT-3 to summarize diffs...".green().bold()); @@ -160,7 +102,8 @@ pub(crate) async fn main(settings: Settings, args: PrepareCommitMsgArgs) -> Resu git::get_diffs()? 
}; - let commit_message = get_commit_message(summarization_client, &output).await?; + let file_diffs = output.split_prefix_inclusive("\ndiff --git "); + let commit_message = summarization_client.get_commit_message(file_diffs).await?; // prepend output to commit message let mut original_message: String = if args.commit_msg_file.is_file() { diff --git a/src/llms/base_llm.rs b/src/llms/llm_client.rs similarity index 78% rename from src/llms/base_llm.rs rename to src/llms/llm_client.rs index 8f380c1..d4cfc9b 100644 --- a/src/llms/base_llm.rs +++ b/src/llms/llm_client.rs @@ -1,8 +1,10 @@ -use async_trait::async_trait; +use std::fmt::Debug; + use anyhow::Result; +use async_trait::async_trait; #[async_trait] -pub trait LlmClient { +pub trait LlmClient: Debug + Send + Sync { /// It takes a prompt as input, and returns the completion using an external Large Language Model. async fn completions(&self, prompt: &str) -> Result; } diff --git a/src/llms/mod.rs b/src/llms/mod.rs index 34b3087..7cae2d3 100644 --- a/src/llms/mod.rs +++ b/src/llms/mod.rs @@ -1,2 +1,3 @@ -pub(crate) mod base_llm; +pub(crate) mod llm_client; pub(crate) mod openai; +pub(crate) mod tester_foobar; diff --git a/src/llms/openai.rs b/src/llms/openai.rs index ef0565c..9623cc9 100644 --- a/src/llms/openai.rs +++ b/src/llms/openai.rs @@ -9,7 +9,7 @@ use tiktoken_rs::tiktoken::{p50k_base, CoreBPE}; use crate::settings::OpenAISettings; -use super::base_llm::LlmClient; +use super::llm_client::LlmClient; #[derive(Clone, Debug)] pub(crate) struct OpenAIClient { diff --git a/src/llms/tester_foobar.rs b/src/llms/tester_foobar.rs new file mode 100644 index 0000000..59ae2ae --- /dev/null +++ b/src/llms/tester_foobar.rs @@ -0,0 +1,41 @@ +use anyhow::Result; + +use async_trait::async_trait; + +#[cfg(test)] +use async_std::task; + +use super::llm_client::LlmClient; + +#[derive(Clone, Debug)] +/// Tester LLM client +pub(crate) struct FooBarClient {} + +impl FooBarClient { + pub(crate) fn new() -> Result { + Ok(Self {}) + } +} + +#[async_trait] +impl LlmClient for FooBarClient { + /// Dummy Completion that responds with "foo bar" for prompt + async fn completions(&self, _prompt: &str) -> Result { + Ok("foo bar".to_string()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic() { + task::block_on(async { + let client = FooBarClient::new().unwrap(); + + let result = client.completions("Hi there! 
").await.unwrap(); + assert_eq!(result, "foo bar"); + }); + } +} diff --git a/src/settings.rs b/src/settings.rs index 37dea78..d24cc0e 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -20,11 +20,14 @@ use crate::{ }; #[derive(Debug, Clone, Display, Serialize, Default, EnumString)] -pub enum ModelProvider { +pub(crate) enum ModelProvider { #[default] #[strum(serialize = "openai")] #[serde(rename = "openai")] OpenAI, + #[strum(serialize = "tester-foobar")] + #[serde(rename = "tester-foobar")] + TesterFoobar, } // implement the trait `From` for `ValueKind` @@ -56,7 +59,7 @@ impl<'de> serde::Deserialize<'de> for ModelProvider { } #[derive(Debug, Default, Serialize, Deserialize, Clone)] -pub struct OpenAISettings { +pub(crate) struct OpenAISettings { pub api_key: Option, pub model: Option, } @@ -72,7 +75,7 @@ impl From for config::ValueKind { } #[derive(Debug, Default, Serialize, Deserialize, Clone)] -pub struct PromptSettings { +pub(crate) struct PromptSettings { pub file_diff: Option, pub commit_summary: Option, pub commit_title: Option, @@ -99,7 +102,7 @@ impl From for config::ValueKind { } #[derive(Debug, Default, Serialize, Deserialize, Clone)] -pub struct Settings { +pub(crate) struct Settings { pub model_provider: Option, pub openai: Option, pub prompt: Option, diff --git a/src/summarize.rs b/src/summarize.rs index 6f7bc5e..f3614fa 100644 --- a/src/summarize.rs +++ b/src/summarize.rs @@ -1,32 +1,100 @@ use std::collections::HashMap; +use std::sync::Arc; -use crate::llms::base_llm::LlmClient; +use crate::llms::llm_client::LlmClient; +use crate::util; use crate::{prompt::format_prompt, settings::PromptSettings}; use anyhow::Result; -#[derive(Clone, Debug)] -pub(crate) struct SummarizationClient { - client: T, +use tokio::task::JoinSet; +use tokio::try_join; +#[derive(Debug, Clone)] +pub(crate) struct SummarizationClient { + client: Arc, prompt_file_diff: String, prompt_commit_summary: String, prompt_commit_title: String, } -impl SummarizationClient { - pub(crate) fn new(settings: PromptSettings, client: T) -> Result { +impl SummarizationClient { + pub(crate) fn new(settings: PromptSettings, client: Box) -> Result { let prompt_file_diff = settings.file_diff.unwrap_or_default(); let prompt_commit_summary = settings.commit_summary.unwrap_or_default(); let prompt_commit_title = settings.commit_title.unwrap_or_default(); Ok(Self { - client, + client: client.into(), prompt_file_diff, prompt_commit_summary, prompt_commit_title, }) } - pub(crate) async fn diff_summary(&self, file_name: &str, file_diff: &str) -> Result { + pub(crate) async fn get_commit_message(&self, file_diffs: Vec<&str>) -> Result { + let mut set = JoinSet::new(); + + for file_diff in file_diffs { + let file_diff = file_diff.to_owned(); + let cloned_self = self.clone(); + set.spawn(async move { cloned_self.process_file_diff(&file_diff).await }); + } + + let mut summary_for_file: HashMap = HashMap::with_capacity(set.len()); + while let Some(res) = set.join_next().await { + if let Some((k, v)) = res.unwrap() { + summary_for_file.insert(k, v); + } + } + + let summary_points = &summary_for_file + .iter() + .map(|(file_name, completion)| format!("[{file_name}]\n{completion}")) + .collect::>() + .join("\n"); + + let (title, completion) = try_join!( + self.commit_title(summary_points), + self.commit_summary(summary_points) + )?; + + let mut message = String::with_capacity(1024); + + message.push_str(&format!("{title}\n\n{completion}\n\n")); + for (file_name, completion) in &summary_for_file { + if !completion.is_empty() { + 
message.push_str(&format!("[{file_name}]\n{completion}\n")); + } + } + + // split message into lines and uniquefy lines + let mut lines = message.lines().collect::>(); + lines.dedup(); + let message = lines.join("\n"); + + Ok(message) + } + + /// Splits the contents of a git diff by file. + /// + /// The file path is the first string in the returned tuple, and the + /// file content is the second string in the returned tuple. + /// + /// The function assumes that the file_diff input is well-formed + /// according to the Diff format described in the Git documentation: + /// https://git-scm.com/docs/git-diff + async fn process_file_diff(&self, file_diff: &str) -> Option<(String, String)> { + if let Some(file_name) = util::get_file_name_from_diff(file_diff) { + let completion = self.diff_summary(file_name, file_diff).await; + Some(( + file_name.to_string(), + completion.unwrap_or_else(|_| "".to_string()), + )) + } else { + None + } + } + + async fn diff_summary(&self, file_name: &str, file_diff: &str) -> Result { debug!("summarizing file: {}", file_name); let prompt = format_prompt( From 56fc10f82e6677a41faf7bff167d62f92bd05d50 Mon Sep 17 00:00:00 2001 From: Roger Zurawicki Date: Mon, 20 Feb 2023 16:34:34 -0500 Subject: [PATCH 5/5] Improve cargo install/test commands - Update Justfile to not use `--offline` for `cargo install` - Remove `--offline` from `cargo test` command --- Justfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Justfile b/Justfile index 8a9517a..d55b997 100644 --- a/Justfile +++ b/Justfile @@ -19,13 +19,13 @@ release: cargo build --release install: - cargo install --path . --offline + cargo install --path . e2e: install sh -eux -c 'for i in ./e2e/test_*.sh ; do sh -x "$i" ; done' test *args: e2e - cargo test --offline + cargo test alias t := test lint:
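
Note on the resulting design: after patch 4, SummarizationClient no longer depends on OpenAIClient directly; it holds a boxed LlmClient behind an Arc, and prepare_commit_msg.rs selects the provider (openai or tester-foobar) via GPTCOMMIT__MODEL_PROVIDER. The sketch below shows how one additional provider could plug into that trait. It is illustrative only: EchoClient is not part of gptcommit, and the angle-bracketed type parameters (e.g. Result<String>, Box<dyn LlmClient>) are assumptions, since generic arguments appear to have been stripped from the diffs above.

    use anyhow::Result;
    use async_trait::async_trait;

    // Assumed shape of the trait from src/llms/llm_client.rs.
    #[async_trait]
    pub trait LlmClient: std::fmt::Debug + Send + Sync {
        async fn completions(&self, prompt: &str) -> Result<String>;
    }

    // Hypothetical provider; a real one would call its backend API here.
    #[derive(Clone, Debug)]
    struct EchoClient;

    #[async_trait]
    impl LlmClient for EchoClient {
        async fn completions(&self, prompt: &str) -> Result<String> {
            Ok(format!("echo: {prompt}"))
        }
    }

    #[tokio::main]
    async fn main() -> Result<()> {
        // Dynamic dispatch mirrors how get_llm_client hands a boxed client
        // to SummarizationClient in patch 4.
        let client: Box<dyn LlmClient> = Box::new(EchoClient);
        println!("{}", client.completions("summarize this diff").await?);
        Ok(())
    }

A new provider added this way would also need a ModelProvider variant in src/settings.rs and a matching arm in get_llm_client, following the pattern the tester-foobar client establishes in patch 4.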