diff --git a/continuous_eval/metrics/__init__.py b/continuous_eval/metrics/__init__.py
index 83521ba..f8341eb 100644
--- a/continuous_eval/metrics/__init__.py
+++ b/continuous_eval/metrics/__init__.py
@@ -27,3 +27,7 @@
)
from continuous_eval.metrics.retrieval_precision_recall_f1 import PrecisionRecallF1
from continuous_eval.metrics.retrieval_ranked_metrics import RankedRetrievalMetrics
+from continuous_eval.metrics.code_deterministic_metrics import (
+ CodeStringMatch,
+ PythonASTSimilarity,
+)
\ No newline at end of file
diff --git a/continuous_eval/metrics/code_deterministic_metrics.py b/continuous_eval/metrics/code_deterministic_metrics.py
new file mode 100644
index 0000000..cae242f
--- /dev/null
+++ b/continuous_eval/metrics/code_deterministic_metrics.py
@@ -0,0 +1,319 @@
+import ast
+
+from munkres import Munkres
+from thefuzz import fuzz
+
+from continuous_eval.metrics.base import Metric
+
+
+class CodeStringMatch(Metric):
+ def calculate(self, answer, ground_truths, **kwargs):
+ max_exact_match = 0
+ max_similarity_score = 0
+ for gt in ground_truths:
+ exact_match = float(answer == gt)
+ similarity_score = fuzz.ratio(answer, gt) / 100
+ if exact_match > max_exact_match:
+ max_exact_match = exact_match
+ if similarity_score > max_similarity_score:
+ max_similarity_score = similarity_score
+ return {
+ "Exact_Match_Score": max_exact_match,
+ "Fuzzy_Match_Score": max_similarity_score,
+ }
+
+
+class PythonASTSimilarity(Metric):
+ """
+ The following functions are adapted from python-ast-comparison by Pedro Salazar Paredes
+ Copyright (c) 2023 Pedro Salazar Paredes
+ Licensed under the MIT License
+ Source: https://github.com/PedroSalazarParedes/python-ast-comparison
+ Modifications: Adjusted to be used in the context of generated code evaluation
+ """
+
+ def _compare_ASTs(self, ast_a: ast.AST, ast_b: ast.AST, reorder_depth: int) -> int:
+ """Compare two ASTs corresponding to python programs.
+
+ Args:
+ ast_a: The first program AST to compare.
+ ast_b: The first program AST to compare.
+ reorder_depth: The maximum children reorder depth for better
+ performance.
+
+ Returns:
+ The number of matching nodes in the ASTs.
+ """
+ children_a = list(ast.iter_child_nodes(ast_a))
+ children_b = list(ast.iter_child_nodes(ast_b))
+ if (type(ast_a) == type(ast_b)) and len(list(children_a)) == 0 and len(list(children_b)) == 0:
+ return 1
+
+ if (type(ast_a) != type(ast_b)) or (len(children_a) != len(children_b)):
+ return 0
+
+ if reorder_depth == 0:
+ match_index = sum(
+ map(
+ lambda pairs: self._compare_ASTs(pairs[0], pairs[1], reorder_depth),
+ zip(children_a, children_b),
+ )
+ )
+ return match_index + 1
+
+ elif reorder_depth > 0:
+ match_index = self._reorder_children_compare(ast_a, ast_b, reorder_depth - 1)
+ return match_index + 1
+
+ return 0
+
+ def _reorder_children_compare(self, ast_a: ast.AST, ast_b: ast.AST, reorder_depth: int) -> int:
+ """Reorders child nodes and compares them.
+
+ Args:
+ ast_a: The first AST for child comparison.
+ ast_b: The second AST for child comparison.
+ reorder_depth: The maximum children reorder depth for better
+ performance.
+
+ Returns:
+ True if there is a way to match 1-1 every child node of ast_a
+ with every child node of ast_b, otherwise False.
+ """
+ comparison_matrix = []
+ cost_matrix = []
+ best_match_value = 0
+ children_a = list(ast.iter_child_nodes(ast_a))
+ children_b = list(ast.iter_child_nodes(ast_b))
+
+ if len(children_a) <= 1 or len(children_b) <= 1:
+ for child_a in children_a:
+ for child_b in children_b:
+ best_match_value += self._compare_ASTs(child_a, child_b, reorder_depth)
+ else:
+ for child_a in children_a:
+ row = []
+ cost_row = []
+ for child_b in children_b:
+ similarity = self._compare_ASTs(child_a, child_b, reorder_depth)
+ row.append(similarity)
+ cost_row.append(10000000 - similarity)
+
+ comparison_matrix.append(row)
+ cost_matrix.append(cost_row)
+
+ m = Munkres()
+ indices = m.compute(cost_matrix) # type: ignore
+
+ for row, col in indices:
+ best_match_value += comparison_matrix[row][col]
+
+ return best_match_value
+
+ def _compare_subtrees(self, sig_subtrees_p1: list, sig_subtrees_p2: list, reorder_depth: int) -> tuple:
+ """Compare two significant subtree lists reordering up to a certain depth.
+
+ Args:
+ sig_subtrees_p1: The first significant AST list for comparison.
+ sig_subtrees_p2: The second significant AST list for comparison.
+ reorder_depth: The maximum children reorder depth for better
+ performance.
+
+ Returns:
+ A tuple with the ratio of matching to non-matching nodes of the
+ significant subtrees, and a list with the best matching of subtrees.
+ """
+ comparison_matrix = []
+ cost_matrix = []
+ best_match = []
+ best_match_value = 0
+ best_match_weight = 0
+ children_a = sig_subtrees_p1.copy()
+ children_b = sig_subtrees_p2.copy()
+
+ if len(children_a) <= 1 or len(children_b) <= 1:
+ for child_a in children_a:
+ best_match += [child_a]
+ for child_b in children_b:
+ best_match_value += self._compare_ASTs(child_a, child_b, reorder_depth)
+ best_match += [child_b]
+ else:
+ for child_a in children_a:
+ row = []
+ cost_row = []
+ for child_b in children_b:
+ similarity = self._compare_ASTs(child_a, child_b, reorder_depth)
+ row.append(similarity)
+ cost_row.append(10000000 - similarity)
+
+ comparison_matrix.append(row)
+ cost_matrix.append(cost_row)
+
+ m = Munkres()
+ indices = m.compute(cost_matrix) # type: ignore
+
+ for row, col in indices:
+ best_match_weight += self._apply_weights_to_subtrees_mult(
+ comparison_matrix[row][col],
+ sig_subtrees_p1[row],
+ sig_subtrees_p2[col],
+ )
+ best_match += [sig_subtrees_p1[row], sig_subtrees_p2[col]]
+
+ all_subtrees_weight = sum(
+ map(
+ lambda tree: self._apply_weights_to_subtrees(self._get_num_nodes(tree), tree),
+ sig_subtrees_p1,
+ )
+ ) + sum(
+ map(
+ lambda tree: self._apply_weights_to_subtrees(self._get_num_nodes(tree), tree),
+ sig_subtrees_p2,
+ )
+ )
+
+ similarity = 2 * best_match_weight / all_subtrees_weight
+
+ return round(similarity, 4), best_match
+
+ def _is_significant(self, root: ast.AST) -> bool:
+ """Determine if an AST is significant.
+
+ Args:
+ root: The AST whose significance we want.
+
+ Returns:
+ True for if it is significant, False otherwise.
+ """
+ return (
+ isinstance(root, ast.Import)
+ or isinstance(root, ast.FunctionDef)
+ or isinstance(root, ast.If)
+ or isinstance(root, ast.ClassDef)
+ or isinstance(root, ast.While)
+ or isinstance(root, ast.For)
+ or isinstance(root, ast.comprehension)
+ or isinstance(root, ast.Return)
+ )
+
+ def _get_significant_subtrees(self, root: ast.AST) -> list:
+ """Find the significant subtrees of an AST.
+
+ Args:
+ root: The root of the main AST.
+
+ Returns:
+ A list with all the significant subtrees of root.
+ """
+ significant_subtrees = []
+ for node in ast.walk(root):
+ if self._is_significant(node):
+ significant_subtrees.append(node)
+ return significant_subtrees
+
+ def _get_num_nodes(self, root: ast.AST) -> int:
+ """Find the number of nodes for a given tree.
+
+ Args:
+ root: The root of the tree whose size we want.
+
+ Returns:
+ The number of nodes in the tree.
+ """
+ return len(list(ast.walk(root)))
+
+ def _apply_weights_to_subtrees(self, weight: float, subtree: ast.AST) -> float:
+ """Apply weights to subtrees according to the time por their roots.
+
+ Args:
+ weight: The number of nodes in the subtree.
+ subtree: The subtree.
+
+ Returns:
+ The weighed weight of the tree.
+ """
+ new_weight = weight
+ if isinstance(subtree, ast.Import):
+ new_weight *= 0.3
+ elif isinstance(subtree, ast.Module):
+ new_weight *= 1
+ elif isinstance(subtree, ast.FunctionDef):
+ new_weight *= 1.2
+ elif isinstance(subtree, ast.If):
+ new_weight *= 0.5
+ elif isinstance(subtree, ast.ClassDef):
+ new_weight *= 1
+ elif isinstance(subtree, ast.While):
+ new_weight *= 1
+ elif isinstance(subtree, ast.For):
+ new_weight *= 1
+ elif isinstance(subtree, ast.comprehension):
+ new_weight *= 1
+ elif isinstance(subtree, ast.Return):
+ new_weight *= 1
+ return new_weight
+
+ def _apply_weights_to_subtrees_mult(self, weight: float, ast_1: ast.AST, ast_2: ast.AST) -> float:
+ """Find the average weight of both trees in order to weigh the comparison.
+
+ Args:
+ weight: The weight of the comparison.
+ ast_1: The first compared tree.
+ ast_2: The second compared tree.
+
+ Returns:
+ The average of the subtrees' weights.
+ """
+ if weight == 0:
+ return 0
+ else:
+ return (self._apply_weights_to_subtrees(weight, ast_1) + self._apply_weights_to_subtrees(weight, ast_2)) / 2
+
+ def _compare_many(self, programs: list) -> list:
+ """Compare all of the programs in the list.
+
+ Args:
+ programs: A list of strings with python programs.
+
+ Returns:
+ A matrix with the similarity rating of between all the programs.
+ """
+ tree_list = list(map(lambda prog: self._get_significant_subtrees(ast.parse(prog)), programs))
+
+ matrix = []
+ for program_1_tree_num in range(0, len(tree_list)):
+ for program_2_tree_num in range(program_1_tree_num, len(tree_list)):
+ if program_1_tree_num == program_2_tree_num:
+ continue
+
+ subtrees1 = tree_list[program_1_tree_num]
+ subtrees2 = tree_list[program_2_tree_num]
+
+ result = self._compare_subtrees(subtrees1, subtrees2, 1000)[0]
+
+ matrix.append((program_1_tree_num, program_2_tree_num, result))
+ matrix.append((program_2_tree_num, program_1_tree_num, result))
+
+ return matrix
+
+ def calculate(self, answer, ground_truths, **kwargs):
+
+ try:
+ answer_tree = ast.parse(answer, mode="exec")
+ ground_truth_trees = [ast.parse(gt, mode="exec") for gt in ground_truths]
+ except SyntaxError as e:
+ return {"Python_AST_Similarity": -1.0}
+
+ answer_subtree = self._get_significant_subtrees(answer_tree)
+ ground_truth_subtrees = [
+ self._get_significant_subtrees(ground_truth_tree) for ground_truth_tree in ground_truth_trees
+ ]
+
+ similarity_scores = [
+ self._compare_subtrees(answer_subtree, ground_truth_subtree, 1000)[0]
+ for ground_truth_subtree in ground_truth_subtrees
+ ]
+
+ return {
+ "Python_AST_Similarity": max(similarity_scores),
+ }
diff --git a/docs/README.md b/docs/README.md
index b51abaa..872ced9 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -1,41 +1,7 @@
-# Starlight Starter Kit: Basics
+# Documentation
[![Built with Starlight](https://astro.badg.es/v2/built-with-starlight/tiny.svg)](https://starlight.astro.build)
-```
-npm create astro@latest -- --template starlight
-```
-
-[![Open in StackBlitz](https://developer.stackblitz.com/img/open_in_stackblitz.svg)](https://stackblitz.com/github/withastro/starlight/tree/main/examples/basics)
-[![Open with CodeSandbox](https://assets.codesandbox.io/github/button-edit-lime.svg)](https://codesandbox.io/p/sandbox/github/withastro/starlight/tree/main/examples/basics)
-[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Fwithastro%2Fstarlight%2Ftree%2Fmain%2Fexamples%2Fbasics&project-name=my-starlight-docs&repository-name=my-starlight-docs)
-
-> π§βπ **Seasoned astronaut?** Delete this file. Have fun!
-
-## π Project Structure
-
-Inside of your Astro + Starlight project, you'll see the following folders and files:
-
-```
-.
-βββ public/
-βββ src/
-β βββ assets/
-β βββ content/
-β β βββ docs/
-β β βββ config.ts
-β βββ env.d.ts
-βββ astro.config.mjs
-βββ package.json
-βββ tsconfig.json
-```
-
-Starlight looks for `.md` or `.mdx` files in the `src/content/docs/` directory. Each file is exposed as a route based on its file name.
-
-Images can be added to `src/assets/` and embedded in Markdown with a relative link.
-
-Static assets, like favicons, can be placed in the `public/` directory.
-
## π§ Commands
All commands are run from the root of the project, from a terminal:
@@ -49,6 +15,3 @@ All commands are run from the root of the project, from a terminal:
| `npm run astro ...` | Run CLI commands like `astro add`, `astro check` |
| `npm run astro -- --help` | Get help using the Astro CLI |
-## π Want to learn more?
-
-Check out [Starlightβs docs](https://starlight.astro.build/), read [the Astro documentation](https://docs.astro.build), or jump into the [Astro Discord server](https://astro.build/chat).
diff --git a/docs/astro.config.mjs b/docs/astro.config.mjs
index 945de35..d0e55b3 100644
--- a/docs/astro.config.mjs
+++ b/docs/astro.config.mjs
@@ -60,6 +60,10 @@ export default defineConfig({
},
]
},
+ {
+ label: 'Code',
+ autogenerate: {directory: '/metrics/Code/'}
+ },
{
label: 'Metric Ensembling',
autogenerate: { directory: '/metrics/ensembling/' },
diff --git a/docs/src/content/docs/metrics/Code/Deterministic/python_ast_similarity.md b/docs/src/content/docs/metrics/Code/Deterministic/python_ast_similarity.md
new file mode 100644
index 0000000..b6f6d17
--- /dev/null
+++ b/docs/src/content/docs/metrics/Code/Deterministic/python_ast_similarity.md
@@ -0,0 +1,46 @@
+---
+title: Python AST Similarity
+sidebar:
+ badge:
+ text: new
+ variant: note
+---
+
+### Definitions
+
+**Python AST Similarity** compares the structure of two Python programs (generated code string vs. ground truth code string) by analyzing their Abstract Syntax Trees (ASTs). It evaluates how similar these programs are by matching nodes in the trees, considering both the types of statements and their organization. The comparison can involve reordering certain parts for a deeper match and uses a scoring system to quantify similarity.
+
+
+
+:::note
+The metric depends on syntactically correct Python scripts to produce the Abstract Syntax Trees (ASTs). If the scripts contain syntax errors and cannot be parsed, the metric will yield a score of -1.0.
+:::
+
+
+
+### Example Usage
+
+Required data items: `answer`, `ground_truths`
+
+```python
+from continuous_eval.metrics import PythonASTSimilarity
+
+datum = {
+ "answer": "def function(x, y):\n return x + y",
+ "ground_truths": [
+ "def foo(x, y):\n return x * y",
+ "def foo(x, y):\n return x + y",
+ ],
+},
+
+metric = PythonASTSimilarity()
+print(metric.calculate(**datum))
+```
+
+### Example Output
+
+```JSON
+{
+ "Python_AST_Similarity": 1.0
+}
+```
diff --git a/docs/src/content/docs/metrics/Code/Deterministic/string_match.md b/docs/src/content/docs/metrics/Code/Deterministic/string_match.md
new file mode 100644
index 0000000..f318a1f
--- /dev/null
+++ b/docs/src/content/docs/metrics/Code/Deterministic/string_match.md
@@ -0,0 +1,42 @@
+---
+title: StringMatch
+sidebar:
+ order: 1
+---
+
+### Definitions
+
+**Code String Match** measures how close the generated code string is to the ground truth code string.
+
+It outputs both the binary exact match score and the fuzzy match score in the range of (0.0 - 1.0).
+
+
+
+
+### Example Usage
+
+Required data items: `answer`, `ground_truths`
+
+```python
+from continuous_eval.metrics import CodeStringMatch
+
+datum = {
+ "answer": "def function(x, y):\n return x + y",
+ "ground_truths": [
+ "def foo(x, y):\n return x * y",
+ "def foo(x, y):\n return x + y",
+ ],
+},
+
+metric = CodeStringMatch()
+print(metric.calculate(**datum))
+```
+
+### Example Output
+
+```JSON
+{
+ "Exact_Match_Score": 0,
+ "Fuzzy_Match_Score": 0.89
+}
+```
diff --git a/docs/src/content/docs/metrics/Code/LLM-Based/llm-based.md b/docs/src/content/docs/metrics/Code/LLM-Based/llm-based.md
new file mode 100644
index 0000000..6b46416
--- /dev/null
+++ b/docs/src/content/docs/metrics/Code/LLM-Based/llm-based.md
@@ -0,0 +1,11 @@
+---
+title: LLM-based Code Metrics
+sidebar:
+ order: 1
+---
+
+### Coming Soon
+
+**LLM-Based Code Evaluation Metrics** can measure various types of code quality and correctness for different languages.
+
+Stay tuned as we release these metrics.
\ No newline at end of file
diff --git a/poetry.lock b/poetry.lock
index 6ba2436..db77f6b 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -4,7 +4,7 @@
name = "aiohttp"
version = "3.9.1"
description = "Async http client/server framework (asyncio)"
-optional = false
+optional = true
python-versions = ">=3.8"
files = [
{file = "aiohttp-3.9.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e1f80197f8b0b846a8d5cf7b7ec6084493950d0882cc5537fb7b96a69e3c8590"},
@@ -100,7 +100,7 @@ speedups = ["Brotli", "aiodns", "brotlicffi"]
name = "aiosignal"
version = "1.3.1"
description = "aiosignal: a list of registered asynchronous callbacks"
-optional = false
+optional = true
python-versions = ">=3.7"
files = [
{file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"},
@@ -194,7 +194,7 @@ tests = ["mypy (>=0.800)", "pytest", "pytest-asyncio"]
name = "async-timeout"
version = "4.0.3"
description = "Timeout context manager for asyncio programs"
-optional = false
+optional = true
python-versions = ">=3.7"
files = [
{file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"},
@@ -205,7 +205,7 @@ files = [
name = "attrs"
version = "23.2.0"
description = "Classes Without Boilerplate"
-optional = false
+optional = true
python-versions = ">=3.7"
files = [
{file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"},
@@ -628,7 +628,7 @@ cron = ["capturer (>=2.4)"]
name = "dataclasses-json"
version = "0.6.3"
description = "Easily serialize dataclasses to and from JSON."
-optional = false
+optional = true
python-versions = ">=3.7,<4.0"
files = [
{file = "dataclasses_json-0.6.3-py3-none-any.whl", hash = "sha256:4aeb343357997396f6bca1acae64e486c3a723d8f5c76301888abeccf0c45176"},
@@ -787,7 +787,7 @@ files = [
name = "frozenlist"
version = "1.4.1"
description = "A list-like structure which implements collections.abc.MutableSequence"
-optional = false
+optional = true
python-versions = ">=3.8"
files = [
{file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"},
@@ -1015,7 +1015,7 @@ grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"]
name = "greenlet"
version = "3.0.3"
description = "Lightweight in-process concurrent programming"
-optional = false
+optional = true
python-versions = ">=3.7"
files = [
{file = "greenlet-3.0.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:9da2bd29ed9e4f15955dd1595ad7bc9320308a3b766ef7f837e23ad4b4aac31a"},
@@ -1457,7 +1457,7 @@ files = [
name = "jsonpatch"
version = "1.33"
description = "Apply JSON-Patches (RFC 6902)"
-optional = false
+optional = true
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
files = [
{file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"},
@@ -1482,7 +1482,7 @@ files = [
name = "jsonpointer"
version = "2.4"
description = "Identify specific nodes in a JSON document (RFC 6901)"
-optional = false
+optional = true
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
files = [
{file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"},
@@ -1519,7 +1519,7 @@ adal = ["adal (>=1.0.2)"]
name = "langchain"
version = "0.0.345"
description = "Building applications with LLMs through composability"
-optional = false
+optional = true
python-versions = ">=3.8.1,<4.0"
files = [
{file = "langchain-0.0.345-py3-none-any.whl", hash = "sha256:461a126ec182834c714589ceec47354401d80b903262efab8d669fe941a0a4df"},
@@ -1560,7 +1560,7 @@ text-helpers = ["chardet (>=5.1.0,<6.0.0)"]
name = "langchain-core"
version = "0.0.13"
description = "Building applications with LLMs through composability"
-optional = false
+optional = true
python-versions = ">=3.8.1,<4.0"
files = [
{file = "langchain_core-0.0.13-py3-none-any.whl", hash = "sha256:36d33a3d280877fb29a1f0f292b9b02b9ba29bf43fb54090b7364f00d5925459"},
@@ -1598,7 +1598,7 @@ six = "*"
name = "langsmith"
version = "0.0.83"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
-optional = false
+optional = true
python-versions = ">=3.8.1,<4.0"
files = [
{file = "langsmith-0.0.83-py3-none-any.whl", hash = "sha256:a5bb7ac58c19a415a9d5f51db56dd32ee2cd7343a00825bbc2018312eb3d122a"},
@@ -1810,7 +1810,7 @@ files = [
name = "marshmallow"
version = "3.20.2"
description = "A lightweight library for converting complex datatypes to and from native Python datatypes."
-optional = false
+optional = true
python-versions = ">=3.8"
files = [
{file = "marshmallow-3.20.2-py3-none-any.whl", hash = "sha256:c21d4b98fee747c130e6bc8f45c4b3199ea66bc00c12ee1f639f0aeca034d5e9"},
@@ -1950,7 +1950,7 @@ tests = ["pytest (>=4.6)"]
name = "multidict"
version = "6.0.4"
description = "multidict implementation"
-optional = false
+optional = true
python-versions = ">=3.7"
files = [
{file = "multidict-6.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b1a97283e0c85772d613878028fec909f003993e1007eafa715b24b377cb9b8"},
@@ -2029,6 +2029,17 @@ files = [
{file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"},
]
+[[package]]
+name = "munkres"
+version = "1.1.4"
+description = "Munkres (Hungarian) algorithm for the Assignment Problem"
+optional = false
+python-versions = "*"
+files = [
+ {file = "munkres-1.1.4-py2.py3-none-any.whl", hash = "sha256:6b01867d4a8480d865aea2326e4b8f7c46431e9e55b4a2e32d989307d7bced2a"},
+ {file = "munkres-1.1.4.tar.gz", hash = "sha256:fc44bf3c3979dada4b6b633ddeeb8ffbe8388ee9409e4d4e8310c2da1792db03"},
+]
+
[[package]]
name = "mypy-extensions"
version = "1.0.0"
@@ -3192,7 +3203,6 @@ files = [
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
- {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
{file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
{file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
{file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
@@ -3200,16 +3210,8 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
- {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
{file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
- {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
- {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
- {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
- {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
- {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
- {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
- {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
{file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
@@ -3226,7 +3228,6 @@ files = [
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
- {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
{file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
{file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
{file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
@@ -3234,7 +3235,6 @@ files = [
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
- {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
{file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
{file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
{file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
@@ -3244,7 +3244,7 @@ files = [
name = "rapidfuzz"
version = "3.6.1"
description = "rapid fuzzy string matching"
-optional = true
+optional = false
python-versions = ">=3.8"
files = [
{file = "rapidfuzz-3.6.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ac434fc71edda30d45db4a92ba5e7a42c7405e1a54cb4ec01d03cc668c6dcd40"},
@@ -3863,7 +3863,7 @@ files = [
name = "sqlalchemy"
version = "2.0.25"
description = "Database Abstraction Library"
-optional = false
+optional = true
python-versions = ">=3.7"
files = [
{file = "SQLAlchemy-2.0.25-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4344d059265cc8b1b1be351bfb88749294b87a8b2bbe21dfbe066c4199541ebd"},
@@ -3996,7 +3996,7 @@ widechars = ["wcwidth"]
name = "tenacity"
version = "8.2.3"
description = "Retry code until it succeeds"
-optional = false
+optional = true
python-versions = ">=3.7"
files = [
{file = "tenacity-8.2.3-py3-none-any.whl", hash = "sha256:ce510e327a630c9e1beaf17d42e6ffacc88185044ad85cf74c0a8887c6a0f88c"},
@@ -4006,6 +4006,20 @@ files = [
[package.extras]
doc = ["reno", "sphinx", "tornado (>=4.5)"]
+[[package]]
+name = "thefuzz"
+version = "0.22.1"
+description = "Fuzzy string matching in python"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "thefuzz-0.22.1-py3-none-any.whl", hash = "sha256:59729b33556850b90e1093c4cf9e618af6f2e4c985df193fdf3c5b5cf02ca481"},
+ {file = "thefuzz-0.22.1.tar.gz", hash = "sha256:7138039a7ecf540da323792d8592ef9902b1d79eb78c147d4f20664de79f3680"},
+]
+
+[package.dependencies]
+rapidfuzz = ">=3.0.0,<4.0.0"
+
[[package]]
name = "threadpoolctl"
version = "3.2.0"
@@ -4447,7 +4461,7 @@ files = [
name = "typing-inspect"
version = "0.9.0"
description = "Runtime inspection utilities for typing module."
-optional = false
+optional = true
python-versions = "*"
files = [
{file = "typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f"},
@@ -4973,7 +4987,7 @@ files = [
name = "yarl"
version = "1.9.4"
description = "Yet another URL library"
-optional = false
+optional = true
python-versions = ">=3.7"
files = [
{file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"},
@@ -5090,9 +5104,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[extras]
anthropic = ["anthropic"]
gemini = ["google-generativeai"]
-generators = ["chromadb", "pinecone-client", "tiktoken", "unstructured"]
+generators = ["chromadb", "langchain", "pinecone-client", "tiktoken", "unstructured"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.12"
-content-hash = "654c613bf80cd272953382c858a9f7705833be88f965a5b21de55177127444d9"
+content-hash = "5c92b93d6a18309e82403177b56bb696473866130b2b6b22568a6ca6dfdf6548"
diff --git a/pyproject.toml b/pyproject.toml
index 23e2d0c..9037133 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,8 @@ chromadb = {version = "^0.4.21", optional = true}
tiktoken = {version = "^0.5.2", optional = true}
unstructured = {version = "^0.11.6", optional = true}
appdirs = "^1.4.4"
+munkres = "^1.1.4"
+thefuzz = "^0.22.1"
[tool.poetry.extras]
anthropic = ["anthropic"]
diff --git a/tests/code_metrics_test.py b/tests/code_metrics_test.py
new file mode 100644
index 0000000..f4c8cb6
--- /dev/null
+++ b/tests/code_metrics_test.py
@@ -0,0 +1,37 @@
+import pytest
+
+from continuous_eval.metrics import CodeStringMatch, PythonASTSimilarity
+from tests.helpers import example_datum
+from tests.helpers.utils import all_close
+
+
+def test_code_string_match():
+ expected_results = [
+ {"Exact_Match_Score": 0, "Fuzzy_Match_Score": 0.89},
+ {"Exact_Match_Score": 0, "Fuzzy_Match_Score": 0.73},
+ {"Exact_Match_Score": 0, "Fuzzy_Match_Score": 0.67},
+ {"Exact_Match_Score": 0, "Fuzzy_Match_Score": 0.21},
+ {"Exact_Match_Score": 0, "Fuzzy_Match_Score": 0.9},
+ {"Exact_Match_Score": 0, "Fuzzy_Match_Score": 0.71},
+ ]
+ metric = CodeStringMatch()
+ assert all(
+ all_close(metric.calculate(**datum), expected)
+ for datum, expected in zip(example_datum.PYTHON_CODE_EXAMPLES, expected_results)
+ )
+
+
+def test_python_ast_similarity():
+ expected_results = [
+ {"Python_AST_Similarity": 1.0},
+ {"Python_AST_Similarity": 0.0},
+ {"Python_AST_Similarity": 0.0224},
+ {"Python_AST_Similarity": 0.0},
+ {"Python_AST_Similarity": -1.0},
+ {"Python_AST_Similarity": 0.0937},
+ ]
+ metric = PythonASTSimilarity()
+ assert all(
+ all_close(metric.calculate(**datum), expected)
+ for datum, expected in zip(example_datum.PYTHON_CODE_EXAMPLES, expected_results)
+ )
diff --git a/tests/helpers/example_datum.py b/tests/helpers/example_datum.py
index 0dd6707..9dafa94 100644
--- a/tests/helpers/example_datum.py
+++ b/tests/helpers/example_datum.py
@@ -97,3 +97,41 @@
"Not really, they didn't win for season three.",
],
}
+
+# =====================================================================================
+# CODE METRICS EXAMPLES
+# =====================================================================================
+
+PYTHON_CODE_EXAMPLES = [
+ {
+ "answer": "def function(x, y):\n return x + y",
+ "ground_truths": [
+ "def foo(x, y):\n return x * y",
+ "def foo(x, y):\n return x + y",
+ ],
+ },
+ {
+ "answer": "def foo(x, y):\n print(x + y)",
+ "ground_truths": ["def function(x, y):\n return x + y"],
+ },
+ {
+ "answer": "class MyClass:\n def __init__(self, x):\n self.x = x",
+ "ground_truths": [
+ "class MyClass:\n def __init__(self, x):\n self._x = x\n @property\n def x(self):\n return self._x",
+ ],
+ },
+ {
+ "answer": "print('Hello, World!')",
+ "ground_truths": ["def function(x, y):\n return x + y"],
+ },
+ {
+ "answer": "function(x, y):\nreturn x + y",
+ "ground_truths": ["def function(x, y):\n return x + y"],
+ },
+ {
+ "answer": "def rotate(text, key):\n alpha = string.ascii_lowercase\n alpha_shift = alpha[key:] + alpha[:key]\n table = str.maketrans(alpha + alpha.upper(), alpha_shift + alpha_shift.upper())\n return text.translate(table)",
+ "ground_truths": [
+ "def rotate(text, key):\n newchars = string.ascii_lowercase[key:] + string.ascii_lowercase[:key]\n trans = str.maketrans(string.ascii_lowercase + string.ascii_lowercase.upper(), newchars + newchars.upper())\n return text.translate(trans)"
+ ],
+ },
+]