In [1]:
import os
import re
import ast
from typing import Dict, Any, List

import weave
from weave import Dataset, Evaluation
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from openai import OpenAI
import autopep8
import isort
from autoflake import fix_code

# Load environment variables from a local .env file (if one exists), e.g. the
# OPENAI_API_KEY that is later read with os.getenv when the OpenAI client is built.
load_dotenv()

True

In [2]:
# Name of the Weave project passed to weave.init when tracing is enabled.
WEAVE_PROJECT: str = "codegen-cookbook"
# weave.init is intentionally left commented out; uncomment it to log ops and
# evaluations to the project above.
# weave.init(WEAVE_PROJECT)

In [3]:
# One-row evaluation dataset. The "prompt" key matches the `prompt` parameter of
# CodeGenerationPipeline.predict and TestResultScorer.score.
prompt_dataset = Dataset(
    name="minimal_code_gen_example",
    rows=[
        {"prompt": "Create a Python function that calculates the factorial of a given number."},
    ],
)
# Uncomment to publish the dataset to the Weave project (requires weave.init).
# weave.publish(prompt_dataset)

In [4]:
# Structured-output schema for the code-generation stage: the generated Python
# source. Documented with comments rather than a docstring because pydantic would
# copy a class docstring into the JSON schema sent to the model as response_format.
class GeneratedCode(BaseModel):
    code: str  # full Python source produced for the user's prompt

# Structured-output schema for the runner stage: a main() function plus its
# __main__ guard. Comments instead of a docstring so the schema sent to the
# model as response_format is unchanged.
class ProgramRunner(BaseModel):
    main_function_code: str  # source of main() and the `if __name__ == "__main__"` guard

# Structured-output schema for the test stage: a complete unittest file.
# Comments instead of a docstring so the schema sent to the model as
# response_format is unchanged.
class UnitTest(BaseModel):
    test_code: str  # source of the generated unittest module

class CodeFormatter(BaseModel):
    """Post-processes LLM-generated Python source before it is reused.

    ``lint_code`` does the actual cleanup. The ``format_*`` helpers apply it to
    the source field of each structured pipeline output and return a new object
    of the same type instead of mutating their input.
    """

    @weave.op()
    def lint_code(self, code: str) -> str:
        """Return a cleaned-up copy of ``code``.

        Steps, in order:
          1. If the source arrived on a single line containing escaped newline
             sequences, turn those sequences into real line breaks.
          2. Remove unused imports and variables (autoflake).
          3. Sort imports (isort).
          4. Apply PEP 8 fixes (autopep8, aggressive level 1).

        Args:
            code: Python source text, typically produced by an LLM.

        Returns:
            The formatted source text.
        """
        # Models sometimes return the whole program on one line with literal
        # backslash-n sequences instead of line breaks. Only undo that when the
        # text has no real line breaks at all: in ordinary multi-line code any
        # backslash-n belongs to a string literal (e.g. print("a\nb")), and
        # rewriting it would split the literal across lines -> SyntaxError.
        if '\n' not in code:
            code = code.replace('\\n', '\n')

        # Remove unused imports and variables
        code = fix_code(code, remove_all_unused_imports=True,
                        remove_unused_variables=True)

        # Sort imports
        code = isort.code(code)

        # Apply PEP 8 formatting
        code = autopep8.fix_code(code, options={'aggressive': 1})

        return code

    @weave.op()
    def format_generated_code(self, generated_code: GeneratedCode) -> GeneratedCode:
        """Return a new GeneratedCode whose ``code`` has been passed through lint_code."""
        cleaned_code = self.lint_code(generated_code.code)
        return GeneratedCode(code=cleaned_code)

    @weave.op()
    def format_program_runner(self, program_runner: ProgramRunner) -> ProgramRunner:
        """Return a new ProgramRunner whose ``main_function_code`` has been linted."""
        cleaned_code = self.lint_code(program_runner.main_function_code)
        return ProgramRunner(main_function_code=cleaned_code)

    @weave.op()
    def format_unit_test(self, unit_test: UnitTest) -> UnitTest:
        """Return a new UnitTest whose ``test_code`` has been linted."""
        cleaned_code = self.lint_code(unit_test.test_code)
        return UnitTest(test_code=cleaned_code)

In [5]:
class CodeGenerationPipeline(weave.Model):
    """Generate code, a ``main()`` runner, and unit tests for a prompt.

    ``predict`` runs three chained structured-output calls to the OpenAI chat
    API (code -> runner -> tests). Each stage's result is cleaned with
    ``formatter`` before it is fed to the next stage.
    """

    model_name: str
    formatter: CodeFormatter
    client: OpenAI

    def __init__(self, model_name: str = "gpt-4o", formatter: CodeFormatter | None = None, client: OpenAI | None = None):
        """Create the pipeline.

        Args:
            model_name: OpenAI chat model used for every stage.
            formatter: Formatter applied to each stage's output. A fresh
                CodeFormatter is created when omitted.
            client: OpenAI client to use. When omitted, a new one is built
                from the OPENAI_API_KEY environment variable.
        """
        # Defaults are built here, not in the signature: signature defaults are
        # evaluated once when the class is defined, which would construct an
        # OpenAI client (needing the API key) as soon as this cell runs and
        # share one formatter/client across every instance. A caller-supplied
        # client is now honoured instead of being overwritten.
        super().__init__(
            model_name=model_name,
            formatter=formatter if formatter is not None else CodeFormatter(),
            client=client if client is not None else OpenAI(api_key=os.getenv("OPENAI_API_KEY")),
        )

    @weave.op()
    async def predict(self, prompt: str):
        """Run all three stages for ``prompt``.

        Returns:
            dict with keys "generated_code", "program_runner" and "unit_tests",
            each holding the formatted pydantic object for that stage.

        Raises:
            ValueError: if any stage's completion is refused or not parsed.
        """
        generated_code = self.generate_code(prompt)
        formatted_generated_code = self.formatter.format_generated_code(generated_code)

        program_runner = self.generate_program(formatted_generated_code)
        formatted_program_runner = self.formatter.format_program_runner(program_runner)

        unit_tests = self.generate_tests(formatted_generated_code, formatted_program_runner)
        formatted_unit_tests = self.formatter.format_unit_test(unit_tests)

        return {
            "generated_code": formatted_generated_code,
            "program_runner": formatted_program_runner,
            "unit_tests": formatted_unit_tests,
        }

    def _parse_completion(self, messages: List[Dict[str, str]], response_format: type[BaseModel]) -> BaseModel:
        """Send ``messages`` to the chat API and return the parsed first choice.

        Args:
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            response_format: Pydantic model the response must conform to.

        Returns:
            The parsed ``response_format`` instance.

        Raises:
            ValueError: if the model refused or no parsed object was returned.
        """
        completion = self.client.beta.chat.completions.parse(
            model=self.model_name,
            messages=messages,
            response_format=response_format,
        )
        message = completion.choices[0].message
        if message.parsed:
            return message.parsed
        # message.refusal can be None; avoid raising ValueError(None) with no
        # useful message in that case.
        raise ValueError(message.refusal or f"Model returned no parsed {response_format.__name__}")

    @weave.op()
    def generate_code(self, prompt: str) -> GeneratedCode:
        """Ask the model for Python code that fulfils ``prompt``."""
        return self._parse_completion(
            messages=[
                {"role": "system", "content": "You are an expert Python code generator."},
                {"role": "user", "content": prompt}
            ],
            response_format=GeneratedCode,
        )

    @weave.op()
    def generate_program(self, generated_code: GeneratedCode) -> ProgramRunner:
        """Ask the model for a main() + __main__ guard that drives ``generated_code``."""
        return self._parse_completion(
            messages=[
                {"role": "system", "content": """
You are an expert Python program generator. Create a main function that orchestrates the execution of the given functions. Follow these guidelines:

1. Create a main() function that calls the necessary functions to run the program.
2. Include a proper if __name__ == "__main__": guard to call the main() function.
3. Do not redefine or implement any functions; use only the functions provided.
4. Do not include any imports or package specifications.
5. Use clear and concise code with proper indentation.
6. Do not use escape characters for newlines; write actual line breaks.
7. Keep the main() function simple, calling only the top-level function(s) needed.

Example structure:

def main():
    result = top_level_function()
    print(result)

if __name__ == "__main__":
    main()

Remember, your task is solely to create the main() function and the __main__ guard. All other functions are assumed to be already defined.
"""},
                {"role": "user", "content": f"Generate a main function for this code:\n\n{generated_code.code}"}
            ],
            response_format=ProgramRunner,
        )

    @weave.op()
    def generate_tests(self, generated_code: GeneratedCode, program_runner: ProgramRunner) -> UnitTest:
        """Ask the model for a runnable unittest module covering the generated code."""
        return self._parse_completion(
            messages=[
                {"role": "system", "content": "You are an expert Python unit test generator."},
                {"role": "user", "content": f"""
Generate a complete unittest for the following code:

Context (Surrounding Code):
```python
{generated_code.code}
{program_runner.main_function_code}
```

Requirements:

1. **Structure:** Use `unittest.TestCase` and name the class `Test<FunctionName>`.
2. **Coverage:** Include tests for normal cases, edge cases, and potential errors.
3. **Naming:** Use descriptive test method names (e.g., `test_valid_input`, `test_empty_input`, `test_invalid_input_type`).
4. **Type Hints:** Include type hints for clarity.
5. **Mocking:** Mock external dependencies (e.g., database interactions, API calls) when necessary.
6. **Assertions:** Use appropriate assertions (e.g., `assertEqual`, `assertRaises`, `assertTrue`).
7. **Isolation:** Ensure test isolation to prevent interference between tests.
8. **Executable:** Include a `__main__` block to run tests directly: `if __name__ == '__main__': unittest.main()`
9. **Formatting:** Ensure proper indentation and formatting for readability.
10. **Imports:** Include all necessary imports.
11. **Completeness:** Provide a complete, runnable test file.

Provide only the complete, properly formatted test code, no explanations or markdown.
"""}
            ],
            response_format=UnitTest,
        )


In [6]:
class TestResultScorer(weave.Scorer):
    """Static, heuristic scorer for CodeGenerationPipeline outputs.

    ``score`` combines three sub-scores, each clamped to [0, 1]:
    - code quality: docstrings, type hints, identifier length, line length;
    - test coverage: test functions per code function, assertion usage;
    - functionality: prompt keyword overlap plus structural checks.

    The generated code and tests are only parsed with ``ast``; nothing is executed.
    """

    @weave.op()
    def score(self, model_output: Dict[str, Any], prompt: str) -> Dict[str, Any]:
        """Score one pipeline output against the prompt that produced it.

        Args:
            model_output: The dict returned by ``CodeGenerationPipeline.predict``.
                A falsy value, or one without "generated_code", yields an error dict.
            prompt: The dataset prompt the output was generated for.

        Returns:
            ``{"error": ...}`` when there is no generated code. Otherwise the
            three sub-scores and their unweighted mean as ``overall_score``.
        """
        if not model_output or "generated_code" not in model_output:
            return {"error": "No generated code provided"}

        generated_code = model_output["generated_code"].code
        # Treat a missing "unit_tests" entry as "no tests" (assess_test_coverage
        # scores empty tests as 0.0) instead of raising KeyError.
        unit_tests_output = model_output.get("unit_tests")
        unit_tests = unit_tests_output.test_code if unit_tests_output is not None else ""

        code_quality_score = self.assess_code_quality(generated_code)
        test_coverage_score = self.assess_test_coverage(generated_code, unit_tests)
        functionality_score = self.assess_functionality(generated_code, prompt)

        overall_score = (code_quality_score + test_coverage_score + functionality_score) / 3

        return {
            "code_quality_score": code_quality_score,
            "test_coverage_score": test_coverage_score,
            "functionality_score": functionality_score,
            "overall_score": overall_score
        }

    @weave.op()
    def assess_code_quality(self, code: str) -> float:
        """Heuristic quality score in [0, 1]; 0.0 if ``code`` does not parse.

        Adds:
        - +0.2 per module/class/function (sync or async) with a docstring;
        - +0.2 per function with a return annotation or an annotated entry in
          ``args.args``;
        - +0.1 per ``Name`` node longer than one character not starting with "_".

        Subtracts 0.2 once if any line exceeds 100 characters.
        """
        # Count async functions too; ast.FunctionDef alone misses `async def`.
        function_nodes = (ast.FunctionDef, ast.AsyncFunctionDef)
        score = 0.0
        try:
            tree = ast.parse(code)

            # Check for docstrings
            for node in ast.walk(tree):
                if isinstance(node, (*function_nodes, ast.ClassDef, ast.Module)):
                    if ast.get_docstring(node):
                        score += 0.2

            # Check for type hints
            for node in ast.walk(tree):
                if isinstance(node, function_nodes):
                    if node.returns or any(arg.annotation for arg in node.args.args):
                        score += 0.2

            # Check for meaningful variable names
            for node in ast.walk(tree):
                if isinstance(node, ast.Name):
                    if len(node.id) > 1 and not node.id.startswith('_'):
                        score += 0.1

            # Penalize for excessive line length
            lines = code.split('\n')
            if any(len(line) > 100 for line in lines):
                score -= 0.2

        except SyntaxError:
            return 0.0

        return min(max(score, 0.0), 1.0)

    @weave.op()
    def assess_test_coverage(self, code: str, unit_tests: str) -> float:
        """Coverage proxy in [0, 1].

        The score is the ratio of ``test_*`` functions in ``unit_tests`` to
        functions in ``code`` (capped at 1), plus 0.2 if the tests call any
        ``self.assert*`` method. Returns 0.0 for empty tests or unparseable input.
        """
        function_nodes = (ast.FunctionDef, ast.AsyncFunctionDef)
        score = 0.0

        # Check if unit tests are provided
        if not unit_tests:
            return 0.0

        try:
            code_tree = ast.parse(code)
            test_tree = ast.parse(unit_tests)

            code_functions = [node.name for node in ast.walk(code_tree) if isinstance(node, function_nodes)]
            test_functions = [node.name for node in ast.walk(test_tree) if isinstance(node, function_nodes) and node.name.startswith('test_')]

            # Score based on the number of test functions relative to code functions
            coverage_ratio = len(test_functions) / len(code_functions) if code_functions else 0
            score = min(coverage_ratio, 1.0)

            # Bonus for using assertions
            if 'self.assert' in unit_tests:
                score += 0.2

        except SyntaxError:
            return 0.0

        return min(score, 1.0)

    @weave.op()
    def assess_functionality(self, code: str, prompt: str) -> float:
        """Prompt-match score in [0, 1]; 0.0 if ``code`` does not parse.

        The score is the fraction of the prompt's lowercase word tokens that
        also appear in the code. It gets +0.3 if the prompt mentions "calculate"
        and the code has a return statement, and +0.3 if the prompt mentions
        "function" and the code defines one.
        """
        score = 0.0

        # Check if the code addresses the main points in the prompt
        prompt_keywords = set(re.findall(r'\b\w+\b', prompt.lower()))
        code_keywords = set(re.findall(r'\b\w+\b', code.lower()))

        # An empty / word-less prompt has no keywords to match; skip the ratio
        # rather than dividing by zero.
        if prompt_keywords:
            keyword_match_ratio = len(prompt_keywords & code_keywords) / len(prompt_keywords)
            score += keyword_match_ratio

        # Check if the code contains expected elements based on the prompt
        try:
            tree = ast.parse(code)
            if 'calculate' in prompt.lower() and any(isinstance(node, ast.Return) for node in ast.walk(tree)):
                score += 0.3

            if 'function' in prompt.lower() and any(isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) for node in ast.walk(tree)):
                score += 0.3
        except SyntaxError:
            return 0.0

        return min(score, 1.0)

In [7]:
for model_name in ["gpt-4o"]:
    pipeline = CodeGenerationPipeline(model_name=model_name)
    test_result_scorer = TestResultScorer()
    evaluation = Evaluation(
        name="minimal_code_gen_evaluation",
        dataset=prompt_dataset,
        scorers=[test_result_scorer]
    )
    results = await evaluation.evaluate(pipeline)
    print(results)

Traceback (most recent call last):
  File "/Users/anishshah/Documents/GitHub/weave/.venv/lib/python3.12/site-packages/weave/flow/eval.py", line 164, in predict_and_score
    model_output = await async_call(model_predict, **model_predict_args)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/anishshah/Documents/GitHub/weave/.venv/lib/python3.12/site-packages/weave/trace/op.py", line 326, in wrapper
    return await func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/folders/7l/0d0kns_n6f19hvbjw81shpkh0000gp/T/ipykernel_73119/4120024944.py", line 15, in predict
    generated_code = self.generate_code(prompt)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/anishshah/Documents/GitHub/weave/.venv/lib/python3.12/site-packages/weave/trace/op.py", line 335, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/var/folders/7l/0d0kns_n6f19hvbjw81shpkh0000gp/T/ipykernel_73119/4120024944.py"

{'TestResultScorer': None, 'model_latency': {'mean': 0.39632105827331543}}
