From 1c734a159497ffcb07857fd897ab8908155912cc Mon Sep 17 00:00:00 2001 From: Roger Zurawicki Date: Wed, 8 Feb 2023 10:38:50 -0500 Subject: [PATCH] Refactor prompt and summary files for tera - Renamed and updated files to use tera for prompt formatting --- Cargo.lock | 167 ++++++++++++++++-- Cargo.toml | 3 +- e2e/test_githook.sh | 3 +- e2e/test_install.sh | 3 + ...ommit.prompt.txt => summarize_commit.tera} | 2 +- ...ff.prompt.txt => summarize_file_diff.tera} | 2 +- ...le_commit.prompt.txt => title_commit.tera} | 5 +- src/prompt.rs | 45 +---- src/summarize.rs | 6 +- 9 files changed, 175 insertions(+), 61 deletions(-) rename prompts/{summarize_commit.prompt.txt => summarize_commit.tera} (97%) rename prompts/{summarize_file_diff.prompt.txt => summarize_file_diff.tera} (99%) rename prompts/{title_commit.prompt.txt => title_commit.tera} (94%) diff --git a/Cargo.lock b/Cargo.lock index e5030f6..19e5ba7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -492,6 +492,30 @@ dependencies = [ "wasi", ] +[[package]] +name = "globset" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "029d74589adefde59de1a0c4f4732695c32805624aec7b68d91503d4dba79afc" +dependencies = [ + "aho-corasick", + "bstr", + "fnv", + "log", + "regex", +] + +[[package]] +name = "globwalk" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93e3af942408868f6934a7b85134a3230832b9977cf66125df2f9edcfce4ddcc" +dependencies = [ + "bitflags", + "ignore", + "walkdir", +] + [[package]] name = "gptcommit" version = "0.1.15" @@ -511,9 +535,10 @@ dependencies = [ "simple_logger", "strum", "strum_macros", + "tera", "tiktoken-rs", "tokio", - "toml 0.7.1", + "toml 0.7.2", "which", ] @@ -669,6 +694,23 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "ignore" +version = "0.4.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbe7873dab538a9a44ad79ede1faf5f30d49f9a5c883ddbab48bce81b64b7492" +dependencies = [ + "globset", + "lazy_static", + "log", + "memchr", + "regex", + "same-file", + "thread_local", + "walkdir", + "winapi-util", +] + [[package]] name = "indexmap" version = "1.9.2" @@ -984,9 +1026,9 @@ checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" [[package]] name = "pest" -version = "2.5.4" +version = "2.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ab62d2fa33726dbe6321cc97ef96d8cde531e3eeaf858a058de53a8a6d40d8f" +checksum = "028accff104c4e513bad663bbcd2ad7cfd5304144404c31ed0a77ac103d00660" dependencies = [ "thiserror", "ucd-trie", @@ -994,9 +1036,9 @@ dependencies = [ [[package]] name = "pest_derive" -version = "2.5.4" +version = "2.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bf026e2d0581559db66d837fe5242320f525d85c76283c61f4d51a1238d65ea" +checksum = "2ac3922aac69a40733080f53c1ce7f91dcf57e1a5f6c52f421fadec7fbdc4b69" dependencies = [ "pest", "pest_generator", @@ -1004,9 +1046,9 @@ dependencies = [ [[package]] name = "pest_generator" -version = "2.5.4" +version = "2.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b27bd18aa01d91c8ed2b61ea23406a676b42d82609c6e2581fba42f0c15f17f" +checksum = "d06646e185566b5961b4058dd107e0a7f56e77c3f484549fb119867773c0f202" dependencies = [ "pest", "pest_meta", @@ -1017,9 +1059,9 @@ dependencies = [ [[package]] name = "pest_meta" -version = "2.5.4" +version = "2.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f02b677c1859756359fc9983c2e56a0237f18624a3789528804406b7e915e5d" +checksum = "e6f60b2ba541577e2a0c307c8f39d1439108120eb7903adeb6497fa880c59616" dependencies = [ "once_cell", "pest", @@ -1271,6 +1313,15 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.21" @@ -1485,6 +1536,22 @@ dependencies = [ "winapi", ] +[[package]] +name = "tera" +version = "1.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3df578c295f9ec044ff1c829daf31bb7581d5b3c2a7a3d87419afe1f2531438c" +dependencies = [ + "globwalk", + "lazy_static", + "pest", + "pest_derive", + "regex", + "serde", + "serde_json", + "unic-segment", +] + [[package]] name = "termcolor" version = "1.2.0" @@ -1514,6 +1581,15 @@ dependencies = [ "syn", ] +[[package]] +name = "thread_local" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5516c27b78311c50bf42c071425c560ac799b11c30b31f87e3081965fe5e0180" +dependencies = [ + "once_cell", +] + [[package]] name = "tiktoken-rs" version = "0.1.2" @@ -1606,9 +1682,9 @@ dependencies = [ [[package]] name = "tokio-native-tls" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7d995660bd2b7f8c1568414c1126076c13fbb725c40112dc0120b78eb9b717b" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" dependencies = [ "native-tls", "tokio", @@ -1650,9 +1726,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "772c1426ab886e7362aedf4abc9c0d1348a979517efedfc25862944d10137af0" +checksum = "f7afcae9e3f0fe2c370fd4657108972cbb2fa9db1b9f84849cefd80741b01cb6" dependencies = [ "serde", "serde_spanned", @@ -1671,9 +1747,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.19.1" +version = "0.19.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90a238ee2e6ede22fb95350acc78e21dc40da00bb66c0334bde83de4ed89424e" +checksum = "5e6a7712b49e1775fb9a7b998de6635b299237f48b404dde71704f2e0e7f37e5" dependencies = [ "indexmap", "nom8", @@ -1726,6 +1802,56 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e79c4d996edb816c91e4308506774452e55e95c3c9de07b6729e17e15a5ef81" +[[package]] +name = "unic-char-property" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8c57a407d9b6fa02b4795eb81c5b6652060a15a7903ea981f3d723e6c0be221" +dependencies = [ + "unic-char-range", +] + +[[package]] +name = "unic-char-range" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0398022d5f700414f6b899e10b8348231abf9173fa93144cbc1a43b9793c1fbc" + +[[package]] +name = "unic-common" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80d7ff825a6a654ee85a63e80f92f054f904f21e7d12da4e22f9834a4aaa35bc" + +[[package]] +name = "unic-segment" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4ed5d26be57f84f176157270c112ef57b86debac9cd21daaabbe56db0f88f23" +dependencies = [ + "unic-ucd-segment", +] + +[[package]] +name = "unic-ucd-segment" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2079c122a62205b421f499da10f3ee0f7697f012f55b675e002483c73ea34700" +dependencies = [ + "unic-char-property", + "unic-char-range", + "unic-ucd-version", +] + +[[package]] +name = "unic-ucd-version" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96bd2f2237fe450fcd0a1d2f5f4e91711124f7857ba2e964247776ebeeb7b0c4" +dependencies = [ + "unic-common", +] + [[package]] name = "unicode-bidi" version = "0.3.10" @@ -1776,6 +1902,17 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "walkdir" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56" +dependencies = [ + "same-file", + "winapi", + "winapi-util", +] + [[package]] name = "want" version = "0.3.0" diff --git a/Cargo.toml b/Cargo.toml index 6e24dcb..7f872ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,8 @@ serde_json = "1.0.92" simple_logger = "4.0.0" strum = "0.24.1" strum_macros = "0.24.3" +tera = { version = "1.17.1", default-features = false } tiktoken-rs = "0.1.2" tokio = { version = "1.25.0", features = ["full"] } -toml = "0.7.1" +toml = "0.7.2" which = "4.4.0" diff --git a/e2e/test_githook.sh b/e2e/test_githook.sh index 40427b4..282c268 100755 --- a/e2e/test_githook.sh +++ b/e2e/test_githook.sh @@ -2,6 +2,7 @@ set -eu ( + rm -rf test_dir_foo4 mkdir test_dir_foo4 cd test_dir_foo4 git init @@ -17,4 +18,4 @@ set -eu cat $TEMPFILE ) -rm -rf test_dir_foo4 +rm -rf test_dir_foo4 diff --git a/e2e/test_install.sh b/e2e/test_install.sh index e8eaf66..b0f020d 100755 --- a/e2e/test_install.sh +++ b/e2e/test_install.sh @@ -2,6 +2,7 @@ set -eu ( + rm -rf test_dir_foo mkdir test_dir_foo cd test_dir_foo git init @@ -16,6 +17,7 @@ rm -rf test_dir_foo ; ############################# ( + rm -rf test_dir_foo2 mkdir test_dir_foo2 cd test_dir_foo2 git init @@ -28,6 +30,7 @@ rm -rf test_dir_foo2 ############################# ( + rm -rf test_dir_foo3 mkdir test_dir_foo3 cd test_dir_foo3 # no git init diff --git a/prompts/summarize_commit.prompt.txt b/prompts/summarize_commit.tera similarity index 97% rename from prompts/summarize_commit.prompt.txt rename to prompts/summarize_commit.tera index 768205d..aa83edd 100644 --- a/prompts/summarize_commit.prompt.txt +++ b/prompts/summarize_commit.tera @@ -9,7 +9,7 @@ Write the most important bullet points. The list should not be more than a few b THE FILE SUMMARIES: ``` - +{{ summary_points }} ``` Remember to write only the most important points and do not write more than a few bullet points. diff --git a/prompts/summarize_file_diff.prompt.txt b/prompts/summarize_file_diff.tera similarity index 99% rename from prompts/summarize_file_diff.prompt.txt rename to prompts/summarize_file_diff.tera index b6c110f..4adfdd1 100644 --- a/prompts/summarize_file_diff.prompt.txt +++ b/prompts/summarize_file_diff.tera @@ -41,7 +41,7 @@ It is given only as an example of appropriate comments. THE GIT DIFF TO BE SUMMARIZED: ``` - +{{ file_diff }} ``` THE SUMMARY: diff --git a/prompts/title_commit.prompt.txt b/prompts/title_commit.tera similarity index 94% rename from prompts/title_commit.prompt.txt rename to prompts/title_commit.tera index 2a2d071..3953dde 100644 --- a/prompts/title_commit.prompt.txt +++ b/prompts/title_commit.tera @@ -17,9 +17,8 @@ Schedule all GitHub actions on all OSs THE FILE SUMMARIES: ``` - +{{ summary_points }} ``` Remember to write only one line, no more than 50 characters. -THE PULL REQUEST TITLE: - +THE PULL REQUEST TITLE: \ No newline at end of file diff --git a/src/prompt.rs b/src/prompt.rs index 2a78b9f..0ef10c6 100644 --- a/src/prompt.rs +++ b/src/prompt.rs @@ -1,42 +1,15 @@ -use anyhow::bail; -use anyhow::Result; -use regex::Regex; +use tera::{Context, Error}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; +use tera::Tera; -pub fn format_prompt(prompt: &str, map: HashMap<&str, &str>) -> Result { - lazy_static! { - static ref RE: Regex = Regex::new("<([A-Z_]+)>").unwrap(); - } +pub fn format_prompt(prompt: &str, map: HashMap<&str, &str>) -> Result { + let context = Context::from_serialize(map)?; - let required_keys: HashSet = RE - .captures_iter(prompt) - .map(|cap| cap[1].to_string()) - .collect(); - let provided_keys: HashSet = map.keys().map(|s| s.to_string()).collect(); - - if !required_keys.eq(&provided_keys) { - bail!( - r#"Required keys did not match provided keys. - Required: {:?} - Provided: {:?} - Prompt: {}"#, - required_keys, - provided_keys, - prompt - ); - } - - let mut result = prompt.to_string(); - for (key, value) in map { - result = result.replace(&format!("<{key}>"), value); - } - Ok(result) + Tera::one_off(prompt, &context, false) } -pub static PROMPT_TO_SUMMARIZE_DIFF: &str = - include_str!("../prompts/summarize_file_diff.prompt.txt"); +pub static PROMPT_TO_SUMMARIZE_DIFF: &str = include_str!("../prompts/summarize_file_diff.tera"); pub static PROMPT_TO_SUMMARIZE_DIFF_SUMMARIES: &str = - include_str!("../prompts/summarize_commit.prompt.txt"); -pub static PROMPT_TO_SUMMARIZE_DIFF_TITLE: &str = - include_str!("../prompts/title_commit.prompt.txt"); + include_str!("../prompts/summarize_commit.tera"); +pub static PROMPT_TO_SUMMARIZE_DIFF_TITLE: &str = include_str!("../prompts/title_commit.tera"); diff --git a/src/summarize.rs b/src/summarize.rs index 3caf29e..e5d91b9 100644 --- a/src/summarize.rs +++ b/src/summarize.rs @@ -34,7 +34,7 @@ impl SummarizationClient { let prompt = format_prompt( &self.prompt_file_diff, - HashMap::from([("FILE_DIFF", file_diff)]), + HashMap::from([("file_diff", file_diff)]), )?; let completion = self.client.completions(&prompt).await; @@ -44,7 +44,7 @@ impl SummarizationClient { pub(crate) async fn commit_summary(&self, summary_points: &str) -> Result { let prompt = format_prompt( &self.prompt_commit_summary, - HashMap::from([("SUMMARY_POINTS", summary_points)]), + HashMap::from([("summary_points", summary_points)]), )?; let completion = self.client.completions(&prompt).await; @@ -54,7 +54,7 @@ impl SummarizationClient { pub(crate) async fn commit_title(&self, summary_points: &str) -> Result { let prompt = format_prompt( &self.prompt_commit_title, - HashMap::from([("SUMMARY_POINTS", summary_points)]), + HashMap::from([("summary_points", summary_points)]), )?; let completion = self.client.completions(&prompt).await;