In [None]:
import openai
import json
from typing import List, Tuple, Dict, Optional
from collections import defaultdict
import instructor
import weave
from set_env import set_env
from pydantic import BaseModel, Field
from typing import List, Literal, Union, Any
from pprint import pprint
from instructor_models import TaskDescription, CombinedTaskDescription, Criterion, EvaluationCriteria, PythonAssertion, LLMAssertion, CriterionAssertions
import asyncio
import nest_asyncio

In [None]:
# Make the OpenAI (LLM calls) and W&B (Weave tracing) API keys available via the
# project's `set_env` helper, in this order.
for _env_var in ("OPENAI_API_KEY", "WANDB_API_KEY"):
    set_env(_env_var)

In [None]:
# Detect whether we are running inside an IPython/Jupyter kernel. Merely being able to
# `import IPython` is not enough (it is often installed where plain scripts run), so check
# for an active shell instead.
try:
    from IPython import get_ipython
    in_jupyter = get_ipython() is not None
except ImportError:
    in_jupyter = False

# A notebook kernel already runs an event loop; nest_asyncio lets the later
# `asyncio.run(...)` calls re-enter that loop instead of raising.
if in_jupyter:
    nest_asyncio.apply()

In [None]:
import random

# Random numeric suffix so each run (very likely) logs to its own Weave project.
_run_suffix = random.randint(0, 1_000_000)
weave.init(f"evalgen_test_{_run_suffix}")

In [None]:
# Async OpenAI client wrapped by instructor so the calls below can pass `response_model=`.
_async_openai_client = openai.AsyncOpenAI()
client = instructor.from_openai(_async_openai_client)

In [None]:
# One annotated example. In practice only the first four fields are always present
# (the toy dataset uses 4-tuples); readers check len(dp) before touching the optional tail.
DataPoint = Tuple[
    dict,            # input payload
    dict,            # model output payload
    Literal[0, 1],   # human annotation: 1 = correct, 0 = incorrect
    Optional[str],   # free-form annotator note
    Optional[str],   # human description of the task / judge
    Optional[str],   # human description of metric details
]

In [None]:
# Pinned OpenAI model snapshot used for every LLM call in this notebook.
MODEL: str = "gpt-4o-2024-08-06"

In [None]:
# Which annotated dataset to run the pipeline on:
#   "medical" -> annotations stored as a Weave object
#   "product" -> not wired up yet
#   anything else -> a small built-in toy dataset of 4-tuples (input, output, annotation, note)
TEST_TASK = "medical"
if TEST_TASK == "medical":
    data = weave.ref("weave:///a-sh0ts/medical_data_results/object/medical_data_annotations:4utHXhRnO2oquowlrvJxCztretSxtFavBUHVciMZJBw").get()
elif TEST_TASK == "product":
    # Fail loudly here rather than leave `data` undefined and hit a NameError in a later cell.
    raise NotImplementedError("The 'product' dataset is not available yet; choose a different TEST_TASK.")
else:
    data = [
        ({"text": "Summarize the impact of climate change on polar bears."}, {"text": "Climate change is reducing sea ice, which polar bears rely on for hunting seals."}, 1, "Accurate and relevant."),
        ({"text": "Explain the process of photosynthesis."}, {"text": "Photosynthesis is the process by which plants use sunlight to synthesize foods from carbon dioxide and water."}, 1, "Correct and detailed."),
        ({"text": "What are the main causes of the American Civil War?"}, {"text": "The main causes were slavery, states' rights, and economic differences."}, 1, "Concise and accurate."),
        ({"text": "Describe the symptoms of COVID-19."}, {"text": "COVID-19 is caused by a virus that originated in bats."}, 0, "Irrelevant and incorrect."),
        ({"text": "What is the significance of the Magna Carta?"}, {"text": "The Magna Carta was a document that limited the power of the king and established certain legal rights."}, 1, "Historically accurate and relevant.")
    ]

In [None]:
# Peek at one annotated datapoint to check its structure before running the pipeline.
first_datapoint = data[0]
pprint(first_datapoint)

In [None]:
# TODO: Batch datapoints instead of making one LLM call per datapoint, or sample the dataset
# within a token budget (e.g. drawing from the annotation distribution) so the prompt sees
# a representative mix of examples.
@weave.op()
async def get_task_description(data: List[DataPoint]) -> str:
    """Infer a task description by iteratively refining it over annotated datapoints.

    Each datapoint triggers one LLM call, awaited sequentially because every call builds on
    the description produced so far. Replies that are blank, or that start with
    "no update needed" (case-insensitive), leave the current description unchanged.

    Args:
        data: Annotated datapoints; only fields 0-3 (input, output, annotation, note) are used.

    Returns:
        The final refined description, or "" when ``data`` is empty.
    """
    task_description = ""

    for datapoint in data:
        input_data, output_data, annotation, note = datapoint[0], datapoint[1], datapoint[2], datapoint[3]

        prompt = f"""
        Current task description: {task_description}

        New datapoint:
        Input: {input_data}
        Output: {output_data}
        Annotation: {"Correct" if annotation == 1 else "Incorrect"}
        Note: {note}

        Based on this new datapoint and the current task description, provide an updated, more refined task description. 
        If this is the first datapoint, create an initial task description.
        Focus on:
        1. The nature of the input and output data
        2. The specific information being extracted or transformed
        3. Any formatting or style requirements
        4. Evaluation criteria (based on the annotation and note)

        Keep the description concise yet comprehensive.
        """

        response = await client.chat.completions.create(
            model=MODEL,
            messages=[{"role": "user", "content": prompt}],
            response_model=TaskDescription
        )

        # Trim surrounding whitespace so the "no update" sentinel is still recognised when the
        # model pads its reply, and so a blank reply never wipes out a good description.
        new_description = response.description.strip()

        # TODO: Add guardrails to prevent LLM from saying no update needed
        if not new_description or new_description.lower().startswith("no update needed"):
            continue

        task_description = new_description

    return task_description

In [None]:
# Build the LLM-inferred task description, one sequential LLM call per datapoint.
# (Top-level `await` requires an IPython/Jupyter kernel.)
llm_task_description = await get_task_description(data)

In [None]:
@weave.op()
async def combine_human_and_llm_descriptions(data: List[DataPoint], llm_description: str) -> str:
    human_descriptions = set()
    for dp in data:
        if len(dp) > 4 and dp[4]:  # Check if human description exists
            human_descriptions.add(dp[4])
    
    if not human_descriptions:
        return llm_description
    
    human_context = "\n".join(f"- {desc}" for desc in human_descriptions)
    
    prompt = f"""
    LLM-generated task description:
    {llm_description}

    Additional human-provided context:
    {human_context}

    Your task is to create a comprehensive, coherent task description that combines insights from both the LLM-generated description and the human-provided context. Ensure that:
    1. The final description is clear and concise.
    2. It incorporates key points from both sources.
    3. Any contradictions are resolved logically.
    4. The description maintains a professional tone.
    5. It provides a complete picture of the task requirements and evaluation criteria.

    Please provide the combined description in a single, well-structured paragraph.
    """

    response = await client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": prompt}],
        response_model=CombinedTaskDescription
    )
    
    return response.description

In [None]:
# Merge in any human-written task descriptions (the LLM description is returned unchanged
# when none exist). Top-level `await` requires an IPython/Jupyter kernel.
finalized_task_description = await combine_human_and_llm_descriptions(data, llm_task_description)

In [None]:
# Display the merged task description that the criteria and assertion prompts below build on.
finalized_task_description

In [None]:
def format_single_datapoint(dp: DataPoint, finalized_task_description: str) -> str:
    """Render one annotated datapoint as a plain-text block for an LLM prompt.

    The block contains, in order:
    - the task description;
    - each input and output field as an indented ``Key: value`` line (key capitalized);
    - the human verdict ("Correct" when the annotation is 1, otherwise "Incorrect");
    - the note.

    A "Metrics Details" line is appended only when the datapoint has a truthy 6th field.

    Args:
        dp: Annotated datapoint; fields 0-3 are required, field 5 is optional.
        finalized_task_description: Task description placed at the top of the block.

    Returns:
        The newline-joined text block.
    """

    def _field_lines(payload) -> str:
        # One indented "Key: value" line per entry of a dict-like payload.
        return "\n".join(f"  {name.capitalize()}: {val}" for name, val in payload.items())

    verdict = "Correct" if dp[2] == 1 else "Incorrect"
    sections = [
        f"Task Description: {finalized_task_description}",
        "",
        "Input:",
        _field_lines(dp[0]),
        "",
        "Output:",
        _field_lines(dp[1]),
        "",
        f"Annotation: {verdict}",
        f"Note: {dp[3]}",
    ]

    extra = dp[5] if len(dp) > 5 else None
    if extra:
        sections.append(f"Metrics Details: {extra}")

    return "\n".join(sections)

In [None]:
# Render the first datapoint (prefixed by the task description) for inspection.
formatted_dp = format_single_datapoint(
    dp=data[0],
    finalized_task_description=finalized_task_description,
)

In [None]:
# `pprint` on a str shows its repr split into quoted fragments with escaped newlines;
# `print` renders the prompt block as readable multi-line text.
print(formatted_dp)

In [None]:
@weave.op()
async def process_criteria(formatted_dp: str, finalized_task_description: str) -> EvaluationCriteria:
    """Ask the LLM for 1-3 evaluation criteria grounded in a single annotated datapoint.

    Args:
        formatted_dp: Text produced by ``format_single_datapoint`` for the datapoint.
        finalized_task_description: The merged task description.

    Returns:
        The structured ``EvaluationCriteria`` object parsed by instructor (callers read
        ``.criteria``), not a string.
    """
    # NOTE(review): `formatted_dp` already starts with the task description, so it appears
    # twice in this prompt; consider dropping one copy to save tokens.
    prompt = f"""
Task Description: {finalized_task_description}

Analyze the following annotated datapoint:

{formatted_dp}

Generate 1-3 evaluation criteria that can be used to assess the quality of outputs for this task. Consider the following guidelines:

1. If a 'Metrics Details' field is present in the datapoint, prioritize this information as it provides the most important evaluation criteria.
2. Focus on general aspects of quality that can be used across multiple outputs.
3. Consider criteria that address potential misalignment between LLM outputs and human preferences.
4. Include criteria that can be evaluated both by code and by LLM-based evaluators.
5. Think about criteria that might reveal hallucinations, instruction-following, or other common LLM issues.
6. Generate criteria that could help in debugging or improving the LLM pipeline.

Provide each criterion as a concise statement, followed by a brief explanation of why it's important and how it might be evaluated (e.g., via code, LLM evaluator, or human judgment).

Return the criteria in this format:
[Criterion]: [Brief explanation and evaluation method]
[Criterion]: [Brief explanation and evaluation method]
[Criterion]: [Brief explanation and evaluation method]

Aim for a mix of straightforward, code-evaluable criteria and more nuanced criteria that might require LLM or human evaluation.
"""
    response = await client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": prompt}],
        response_model=EvaluationCriteria
    )
    return response

@weave.op()
async def generate_criteria(data: List[DataPoint], finalized_task_description: str) -> List[Criterion]:
    """Generate evaluation criteria for every datapoint concurrently and flatten them.

    Each datapoint is formatted with ``format_single_datapoint`` and sent through
    ``process_criteria``. The per-datapoint criteria lists are concatenated in datapoint
    order.

    Args:
        data: Annotated datapoints.
        finalized_task_description: The merged task description.

    Returns:
        All generated criteria; duplicates across datapoints are kept.
    """

    async def _criteria_for(datapoint):
        prompt_block = format_single_datapoint(datapoint, finalized_task_description)
        evaluation = await process_criteria(prompt_block, finalized_task_description)
        return evaluation.criteria

    per_datapoint = await asyncio.gather(*(_criteria_for(dp) for dp in data))

    # TODO: Add an additional check to see if a nearly identical criterion is already in the list
    return [criterion for batch in per_datapoint for criterion in batch]

In [None]:
# Generate 1-3 evaluation criteria per datapoint (all datapoints concurrently) and flatten
# them into one list. Top-level `await` requires an IPython/Jupyter kernel.
criteria = await generate_criteria(data, finalized_task_description)

In [None]:
# Inspect the first generated criterion.
criteria[0]

In [None]:
# All generated criteria (near-duplicates across datapoints are not yet removed).
criteria

In [None]:
@weave.op()
async def create_candidate_assertions(formatted_data_string: str, criterion: Criterion) -> CriterionAssertions:
    """Ask the LLM for 1-3 Python or LLM-judged assertions that test one criterion.

    Args:
        formatted_data_string: All annotated datapoints rendered as text
            (see ``format_all_datapoints``), used as grounding context.
        criterion: The evaluation criterion to turn into assertions.

    Returns:
        The structured ``CriterionAssertions`` object parsed by instructor (callers read
        ``.assertions``).
    """
    # Pydantic v2 deprecates `.dict()` in favour of `.model_dump()`; prefer the new API and
    # fall back for v1-style models, so the prompt always receives a plain dict of fields.
    criterion_payload = criterion.model_dump() if hasattr(criterion, "model_dump") else criterion.dict()

    prompt = f"""
Given the following evaluation criterion and annotated data, generate 1-3 specific, testable assertions:

Criterion: {criterion_payload}

Annotated data: {formatted_data_string}

Your task is to create assertions that can be used to evaluate LLM outputs based on this criterion. Follow these guidelines:

1. Make each assertion clear, concise, and directly related to the criterion
2. For Python assertions:
   - Provide a valid Python method that can be used within a unittest.TestCase class
   - Ensure the method name is in snake case and starts with test_
   - The method should take 'self' as the only input, where 'self.output' is a dictionary containing the LLM output being evaluated
   - The 'self.output' dictionary will have the same keys and shape as the output in the annotated data
   - Use unittest assertion methods (e.g., self.assertTrue, self.assertEqual) to test the output
   - The test should pass if the assertion is met, and fail otherwise
   - Only use the keys and shapes present in the annotated data output for your assertions
3. For LLM assertions:
   - Provide a clear, detailed prompt for an LLM to evaluate the assertion
   - The prompt should guide the LLM to return "PASS" or "FAIL" based on the evaluation
4. Include a mix of positive and negative assertions where appropriate
5. Consider edge cases and potential failure modes for the criterion
6. Aim for assertions that could be applied across multiple types of outputs

Ensure that your assertions are directly evaluable and avoid vague or subjective language. Focus on creating assertions that align with human preferences and can be used to validate the quality of LLM-generated evaluations.

Format your response as a JSON object with the following structure:
{{
  "assertions": [
    {{
      "test_name": "Name of the test case method in snake case",
      "text" or "code": "Assertion text or code",
      "evaluation_type": "python" or "llm"
    }},
    ...
  ]
}}
"""
    response = await client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": prompt}],
        response_model=CriterionAssertions
    )
    return response

In [None]:
# TODO: improve this function
def format_all_datapoints(data: List[DataPoint], finalized_task_description: str) -> str:
    """Render every annotated datapoint into one prompt-ready text block.

    Layout:
    - a "Task Description:" header;
    - then, per datapoint, an "Example N:" section (1-based) containing:
      - pretty-printed JSON input and output;
      - the human verdict and the note;
      - a "Metrics Details" line when the datapoint has a truthy 6th field, mirroring
        ``format_single_datapoint``;
      - a dashed separator.

    Args:
        data: Annotated datapoints; fields 0-3 are required, field 5 is optional.
        finalized_task_description: Task description placed at the top.

    Returns:
        The newline-joined text block (just the header line when ``data`` is empty).
    """

    def _to_json(payload) -> str:
        # ensure_ascii=False keeps non-ASCII text readable for the LLM instead of \uXXXX
        # escapes; default=str stops non-JSON values (sets, dates, custom objects) from
        # aborting the whole render.
        return json.dumps(payload, indent=2, ensure_ascii=False, default=str)

    formatted = [f"Task Description: {finalized_task_description}\n"]

    for i, dp in enumerate(data, 1):
        input_data, output_data, annotation, note = dp[0], dp[1], dp[2], dp[3]
        metrics_details = dp[5] if len(dp) > 5 else None

        formatted.extend([
            f"Example {i}:",
            "Input:",
            _to_json(input_data),
            "",
            "Output:",
            _to_json(output_data),
            "",
            f"Annotation: {'Correct' if annotation == 1 else 'Incorrect'}",
            f"Note: {note}",
        ])
        if metrics_details:
            formatted.append(f"Metrics Details: {metrics_details}")
        formatted.append("\n" + "-" * 50 + "\n")  # Separator between examples

    return "\n".join(formatted)

In [None]:
# Render every annotated datapoint into one string, used as grounding context when
# generating assertions.
formatted_data = format_all_datapoints(
    data=data,
    finalized_task_description=finalized_task_description,
)

In [None]:
# `pprint` on a str shows its repr split into quoted fragments with escaped newlines;
# `print` renders the multi-example prompt block as readable text.
print(formatted_data)

In [None]:
@weave.op()
async def generate_all_assertions(criteria, formatted_data):
    """Generate candidate assertions for every criterion concurrently and flatten them.

    Args:
        criteria: Criteria to turn into assertions (one LLM call each).
        formatted_data: All annotated datapoints rendered as text, passed as context.

    Returns:
        A single list of assertion objects, grouped in the same order as ``criteria``.
    """

    async def _assertions_for(criterion):
        candidate_set = await create_candidate_assertions(formatted_data, criterion)
        return candidate_set.assertions

    # Every criterion is processed concurrently; gather preserves input order.
    per_criterion = await asyncio.gather(*(_assertions_for(c) for c in criteria))

    return [assertion for batch in per_criterion for assertion in batch]


In [None]:
# Generate candidate assertions for every criterion concurrently, using all annotated
# examples as context. Top-level `await` requires an IPython/Jupyter kernel.
all_assertions = await generate_all_assertions(criteria, formatted_data)

In [None]:
# Generated assertions of the Python (code) kind.
[assertion for assertion in all_assertions if isinstance(assertion, PythonAssertion)]

In [None]:
# Generated assertions of the LLM-judged kind.
[assertion for assertion in all_assertions if isinstance(assertion, LLMAssertion)]

In [None]:
def convert_datapoint_to_example(task_description: str, data: List[DataPoint]) -> List[Dict[str, Any]]:
    """Turn annotated datapoints into row dicts used as a ``weave.Evaluation`` dataset.

    Args:
        task_description: Copied verbatim into every row.
        data: Annotated datapoints; only fields 0-3 are read.

    Returns:
        One dict per datapoint with keys ``task_description``, ``input_data``,
        ``model_output`` (the output wrapped as ``{"output": ...}``), ``annotation`` and
        ``note``.
    """
    return [
        {
            "task_description": task_description,
            "input_data": dp[0],
            "model_output": {"output": dp[1]},
            "annotation": dp[2],
            "note": dp[3],
        }
        for dp in data
    ]

In [None]:
annotation_examples = convert_datapoint_to_example(finalized_task_description, data)

In [None]:
from combined_scorer import AssertionScorer, predict_passthrough

# Scorer over every generated assertion (its results expose separate LLM and code
# assertion outcomes). Reuse MODEL so the judge stays on the same pinned snapshot as the
# generation steps instead of a duplicated literal that can silently drift.
scorer = AssertionScorer(
    assertions=all_assertions,
    llm_model=MODEL,
    prompt_template="""
Task Description:
{task_description}

Evaluate the following output based on the given task, input, and assertion:

Input:
{input_data}

Output:
{model_output}

Assertion:
{assertion_text}

Consider the task description and input when evaluating the output against the assertion.
Respond with either 'PASS' if the output meets the assertion criteria in the context of the task and input, or 'FAIL' if it does not.
""",
    system_prompt="You are an AI assistant evaluating the quality of text outputs based on given tasks, inputs, and assertions."
)

# TODO: figure out how to get each example's individual results as opposed to aggregate
# Create a custom summarize function?
evaluation = weave.Evaluation(
    scorers=[scorer],
    dataset=annotation_examples,
)

# asyncio.run() inside Jupyter's already-running loop relies on nest_asyncio having been
# applied in the setup cell.
assertion_results = asyncio.run(evaluation.evaluate(predict_passthrough))



In [None]:
# Summary returned by `evaluation.evaluate` (aggregate, not per example). A later cell
# rebinds `assertion_results` to per-assertion results.
assertion_results

In [None]:
@weave.op()
async def evaluate(scorer: AssertionScorer, annotation_examples: List[Dict[str, Any]]) -> Dict[str, List[Tuple[int, int, str]]]:
    """Score every annotated example and regroup the results per assertion.

    All examples are scored concurrently with ``scorer.score``. LLM assertion results are
    taken as-is. Code assertion results are read from ``test_results``, keeping only keys
    ending in ``_score``, with that suffix stripped from the assertion name.

    Args:
        scorer: The assertion scorer to run.
        annotation_examples: Rows as produced by ``convert_datapoint_to_example``.

    Returns:
        Mapping of assertion name -> list of (assertion score, human annotation,
        "llm" | "python") tuples, in example order.
    """

    async def _score_one(example):
        scored = await scorer.score(
            model_output={"output": example["model_output"]["output"]},
            task_description=example["task_description"],
            input_data=example["input_data"]
        )
        return scored, example["annotation"]

    scored_examples = await asyncio.gather(*map(_score_one, annotation_examples))

    grouped: Dict[str, List[Tuple[int, int, str]]] = {}
    suffix = "_score"

    for scored, human_label in scored_examples:
        llm_outcomes = scored.get('llm_assertion_results', {})
        code_outcomes = scored.get('code_assertion_results', {}).get('test_results', {})

        for name, llm_score in llm_outcomes.items():
            grouped.setdefault(name, []).append((llm_score, human_label, "llm"))

        for key, code_score in code_outcomes.items():
            if not key.endswith(suffix):
                continue
            grouped.setdefault(key[:-len(suffix)], []).append((code_score, human_label, "python"))

    return grouped

In [None]:
# Re-score every example directly (not via weave.Evaluation) to get per-assertion
# (score, human_annotation, evaluation_type) tuples. This rebinds `assertion_results`,
# replacing the evaluation summary from the earlier cell. asyncio.run() inside Jupyter's
# running loop relies on nest_asyncio having been applied.
assertion_results = asyncio.run(evaluate(scorer, annotation_examples))

In [None]:
# Inspect the per-assertion (score, human_annotation, evaluation_type) tuples.
assertion_results

1. **Selectivity**:
   ```python
   selectivity = passes / total_outputs
   ```
   Selectivity is the fraction of LLM outputs that an assertion passes. A lower selectivity means the assertion is more "picky" or strict.

2. **Coverage**:
   ```python
   coverage = fails_on_bad / total_bad if total_bad > 0 else 1
   ```
   Coverage measures how well our assertions catch the outputs that humans marked as bad. A higher coverage means we're better at identifying problematic outputs. When there are no bad outputs at all, coverage is treated as perfect (1), matching `calculate_metrics` below.

3. **False Failure Rate (FFR)**:
   ```python
   ffr = fails_on_good / total_good if total_good > 0 else 0
   ```
   FFR shows how often our assertions incorrectly fail outputs that humans thought were good. A lower FFR is better, as it means we're not being overly strict.

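4. **Alignment**:
   ```python
   alignment = 2 * (coverage * (1 - ffr)) / (coverage + (1 - ffr)) if (coverage + (1 - ffr)) > 0 else 0
   ```
   Alignment combines coverage and FFR into a single score: it is the harmonic mean of coverage and $1 - \text{FFR}$ (an F1-style score). It is therefore high only when an assertion both catches bad outputs and passes good ones. It represents how well our automated evaluations match human judgments overall.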

These metrics help us refine our assertion set over time, aiming to catch more bad outputs while avoiding false alarms on good ones.

In [None]:
def calculate_metrics(assertion_results: Dict[str, List[Tuple[int, int, str]]]) -> Dict[str, Dict[str, Union[float, str]]]:
    """Compute per-assertion agreement metrics against human annotations.

    For each assertion, the (score, human_annotation, evaluation_type) tuples are
    summarised. Score 1 = pass, 0 = fail; annotation 1 = good, 0 = bad.

    - selectivity: passes / total (0 when there are no results)
    - coverage: fails on bad / bad (1 when there are no bad outputs)
    - ffr: fails on good / good (0 when there are no good outputs)
    - alignment: harmonic mean of coverage and (1 - ffr), or 0 when both are 0

    Args:
        assertion_results: Mapping of assertion name -> result tuples.

    Returns:
        Mapping of assertion name -> metric dict. It also carries the raw counts and the
        evaluation type, taken from the first tuple or "unknown" when there are no results.
    """

    def _summarize(rows):
        n_total = len(rows)
        n_bad = n_pass = bad_failed = good_failed = 0
        for score, human, _kind in rows:
            if human == 0:
                n_bad += 1
            if score == 1:
                n_pass += 1
            if score == 0:
                if human == 0:
                    bad_failed += 1
                elif human == 1:
                    good_failed += 1
        # Anything not annotated 0 counts towards the "good" total.
        n_good = n_total - n_bad

        selectivity = n_pass / n_total if n_total > 0 else 0
        coverage = bad_failed / n_bad if n_bad > 0 else 1  # no bad outputs -> perfect coverage
        ffr = good_failed / n_good if n_good > 0 else 0  # no good outputs -> no false failures

        keep_rate = 1 - ffr
        denom = coverage + keep_rate
        alignment = 2 * (coverage * keep_rate) / denom if denom > 0 else 0

        return {
            "type": rows[0][2] if rows else "unknown",
            "selectivity": selectivity,
            "coverage": coverage,
            "ffr": ffr,
            "alignment": alignment,
            "total_outputs": n_total,
            "total_good": n_good,
            "total_bad": n_bad,
            "passes": n_pass,
            "fails": n_total - n_pass,
            "fails_on_bad": bad_failed,
            "fails_on_good": good_failed
        }

    return {name: _summarize(rows) for name, rows in assertion_results.items()}

In [None]:
# Per-assertion selectivity / coverage / FFR / alignment against the human labels.
metrics = calculate_metrics(assertion_results)
metrics_json = json.dumps(metrics, indent=2)
print("Assertion metrics:", metrics_json)

In [None]:
def select_best_assertions(
    metrics: Dict[str, Dict[str, Union[float, str]]],
    num_llm_tests: Optional[int] = None,
    num_code_tests: Optional[int] = None,
    alignment_threshold: Optional[float] = None
) -> Dict[str, str]:
    """Pick the highest-alignment assertions, optionally per evaluation type.

    Steps:
    1. Drop assertions whose alignment is below ``alignment_threshold`` (when given).
    2. Rank the remaining "llm" and "python" assertions separately by alignment,
       descending; ties keep their original order.
    3. Keep the top ``num_llm_tests`` and ``num_code_tests`` of each type.
    4. If nothing was selected (both counts None/0, or no assertion of the requested
       types), fall back to the single best remaining assertion of any type.

    Args:
        metrics: Output of ``calculate_metrics``.
        num_llm_tests: How many LLM assertions to keep; None means none.
        num_code_tests: How many Python assertions to keep; None means none.
        alignment_threshold: Minimum alignment to be eligible; None disables filtering.

    Returns:
        Mapping of selected assertion name -> evaluation type; empty when nothing passes
        the threshold.
    """
    best_assertions = {}

    # First, filter assertions based on the alignment threshold
    if alignment_threshold is not None:
        filtered_metrics = {
            a: m for a, m in metrics.items()
            if m['alignment'] >= alignment_threshold
        }
    else:
        filtered_metrics = metrics

    # Separate assertions by type
    llm_assertions = [a for a, m in filtered_metrics.items() if m['type'] == 'llm']
    code_assertions = [a for a, m in filtered_metrics.items() if m['type'] == 'python']

    # Sort assertions by alignment score (stable, so ties keep insertion order)
    llm_assertions.sort(key=lambda a: filtered_metrics[a]['alignment'], reverse=True)
    code_assertions.sort(key=lambda a: filtered_metrics[a]['alignment'], reverse=True)

    # Select top N assertions for each type
    if num_llm_tests is not None:
        best_assertions.update({a: 'llm' for a in llm_assertions[:num_llm_tests]})

    if num_code_tests is not None:
        best_assertions.update({a: 'python' for a in code_assertions[:num_code_tests]})

    # If no counts were provided or nothing was selected, fall back to the best assertion overall
    if not best_assertions and filtered_metrics:
        best_assertion = max(filtered_metrics.keys(), key=lambda a: filtered_metrics[a]['alignment'])
        best_assertions[best_assertion] = filtered_metrics[best_assertion]['type']

    return best_assertions

In [None]:
# Keep the top-2 LLM and the top-1 code assertion by alignment. A threshold of 0.0 keeps
# every assertion eligible, since alignment is never negative.
best_assertions = select_best_assertions(
    metrics, num_llm_tests=2, num_code_tests=1, alignment_threshold=0.0,
)

In [None]:
# TODO: Also filter by criterion so that no two selected assertions test the same criterion.
# Display the selected assertion names mapped to their evaluation type ("llm" / "python").
best_assertions

In [None]:
def get_best_assertion_details(best_assertions: Dict[str, str], all_assertions: List[Union[PythonAssertion, LLMAssertion]]) -> List[Union[PythonAssertion, LLMAssertion]]:
    """Resolve selected assertion names back to their full assertion objects.

    For each (name, type) pair, the first assertion in ``all_assertions`` whose
    ``test_name`` and ``evaluation_type`` both match is kept. Pairs with no match are
    reported with a printed warning and skipped.

    Args:
        best_assertions: Mapping of assertion name -> evaluation type.
        all_assertions: Every generated assertion.

    Returns:
        The matched assertion objects, in ``best_assertions`` order.
    """
    resolved = []

    for name, kind in best_assertions.items():
        candidates = [
            candidate for candidate in all_assertions
            if candidate.test_name == name and candidate.evaluation_type == kind
        ]

        if not candidates:
            print(f"Warning: No matching assertion found for {name} of type {kind}")
            continue

        resolved.append(candidates[0])

    return resolved

# Resolve the selected names back to the full assertion objects.
best_assertion_details = get_best_assertion_details(best_assertions, all_assertions)

# Show each selected assertion: name, type, then its code (Python) or prompt text (LLM).
divider = "-" * 50
for assertion in best_assertion_details:
    print(f"Test Name: {assertion.test_name}")
    print(f"Evaluation Type: {assertion.evaluation_type}")
    if isinstance(assertion, PythonAssertion):
        body_line = f"Code:\n{assertion.code}"
    elif isinstance(assertion, LLMAssertion):
        body_line = f"Text: {assertion.text}"
    else:
        body_line = None
    if body_line is not None:
        print(body_line)
    print(divider)

In [None]:
# Display the selected assertion objects themselves.
best_assertion_details

In [None]:
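# NOTE(review): fix before re-enabling. `best_assertions.values()` yields evaluation types ("llm"/"python"), not assertion names; iterate the keys instead. Also, `assertion_results[name]` holds (score, human_annotation, type) tuples, so compare `[i][0]`, not the tuple, to 0. Finally, datapoints may have 6 fields, so the 4-tuple unpacking of `data` will fail.
# def calculate_overall_metrics(data: List[DataPoint], best_assertions: Dict[str, str], assertion_results: Dict[str, List[int]]) -> Dict[str, float]: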
#     total_outputs = len(data)
#     total_bad = sum(1 for _, _, annotation, _ in data if annotation == 0)
#     total_good = total_outputs - total_bad

#     fails_on_bad = sum(1 for i, (_, _, annotation, _) in enumerate(data) 
#                        if annotation == 0 and any(assertion_results[assertion][i] == 0 for assertion in best_assertions.values()))
#     fails_on_good = sum(1 for i, (_, _, annotation, _) in enumerate(data) 
#                         if annotation == 1 and any(assertion_results[assertion][i] == 0 for assertion in best_assertions.values()))

#     coverage = fails_on_bad / total_bad if total_bad > 0 else 0
#     ffr = fails_on_good / total_good if total_good > 0 else 0
#     alignment = 2 * (coverage * (1 - ffr)) / (coverage + (1 - ffr)) if (coverage + (1 - ffr)) > 0 else 0

#     return {
#         "coverage": coverage,
#         "ffr": ffr,
#         "alignment": alignment
#     }

In [None]:
# # Calculate overall metrics
# overall_metrics = calculate_overall_metrics(data, best_assertions, assertion_results)
# print("Overall metrics:", json.dumps(overall_metrics, indent=2))

In [None]:
# # Generate final report
# report = {
#     "final_assertions": best_assertions,
#     "assertion_metrics": {assertion: metrics[assertion] for assertion in best_assertions.values()},
#     "overall_metrics": overall_metrics
# }

# print("\nFinal Report:")
# print(json.dumps(report, indent=2))

In [None]:
# TODO: Ensure we can improve the selected aligned judges when a new batch of annotations comes through the application.

In [None]:
# TODO: Make the weave.Scorer workflow work end to end.

In [None]:
# TODO: Final dashboard after alignment
# - Show a big number for the final alignment
# - The list of all assertions -> weave.Object (or Scorer)
# - Match the score from the paper, or do better