In [12]:
import openai
import json
from typing import List, Tuple, Dict, Optional
from collections import defaultdict
import instructor
import weave
from set_env import set_env
from pydantic import BaseModel, Field
from typing import List, Literal, Union
from pprint import pprint

In [2]:
# Populate the API keys through the project's set_env helper before any client is created.
# The recorded output below shows the helper falling back to a .env file.
set_env("OPENAI_API_KEY")
set_env("WANDB_API_KEY")

        Unable to set WANDB_API_KEY=WANDB_API_KEY,
        not in colab or Secrets not set, not kaggle
        or Secrets not set, no .env/dotenv/env file
        in the current working dir or parent dirs.[0m


loading envfile='/Users/anishshah/Documents/Manual Library/GitHub(1)/improve-evals/.env' with dotenv_values(envfile)


In [3]:
# Wrap the OpenAI client with instructor so the calls below can pass `response_model=`
# and read typed fields (e.g. `response.description`) off the result.
client = instructor.from_openai(openai.OpenAI())

In [4]:
# One annotated example, addressed positionally throughout the notebook:
#   0: input dict
#   1: output dict
#   2: annotation (0 or 1; 1 is rendered as "Correct" in prompts)
#   3: note
#   4: human description for the task or judge
#   5: human description for metric details
# Helpers guard fields 4 and 5 with len() checks, so shorter tuples are tolerated there.
DataPoint = Tuple[dict, dict, Literal[0, 1], Optional[str], Optional[str], Optional[str]]

In [5]:
# Model name passed as `model=` to the chat completion calls in the cells below.
MODEL = "gpt-4o-2024-08-06"

In [6]:
# Selects the dataset for the run: "medical" pulls the published Weave object,
# any other (non-"product") value falls back to the small inline toy set.
TEST_TASK = "medical"
if TEST_TASK == "medical":
    data = weave.ref("weave:///a-sh0ts/medical_data_results/object/medical_data_annotations:4utHXhRnO2oquowlrvJxCztretSxtFavBUHVciMZJBw").get()
elif TEST_TASK == "product":
    # No product dataset is wired up yet. Fail here instead of leaving `data`
    # undefined (or silently reusing a stale `data` from an earlier kernel run).
    raise NotImplementedError("TEST_TASK='product' has no dataset configured yet")
else:
    data = [
        ({"text": "Summarize the impact of climate change on polar bears."}, {"text": "Climate change is reducing sea ice, which polar bears rely on for hunting seals."}, 1, "Accurate and relevant."),
        ({"text": "Explain the process of photosynthesis."}, {"text": "Photosynthesis is the process by which plants use sunlight to synthesize foods from carbon dioxide and water."}, 1, "Correct and detailed."),
        ({"text": "What are the main causes of the American Civil War?"}, {"text": "The main causes were slavery, states' rights, and economic differences."}, 1, "Concise and accurate."),
        ({"text": "Describe the symptoms of COVID-19."}, {"text": "COVID-19 is caused by a virus that originated in bats."}, 0, "Irrelevant and incorrect."),
        ({"text": "What is the significance of the Magna Carta?"}, {"text": "The Magna Carta was a document that limited the power of the king and established certain legal rights."}, 1, "Historically accurate and relevant.")
    ]

In [84]:
# Inspect the first record to see what a datapoint looks like.
pprint(data[0])

WeaveList([{'input': "Dialogue: [doctor] hey dylan what's going on so i lift quite a bit of weights i try to stay in shape as much as i can i'm not like normal people i lift heavy weights and my elbow is extremely sore which elbow is it [patient] actually it's both my elbows but my right elbow is hurting me the most [doctor] okay and you said you lift a lot of weights [patient] mm-hmm [doctor] did you play any sports when you were younger [patient] no anything you can think of primarily it was basketball baseball and football [doctor] okay and did your elbows hurt at that time or is this a a new injury [patient] it's new [doctor] when did it start [patient] probably year and a half ago [doctor] okay on both elbows about a year and a half ago [patient] yeah [doctor] okay have you taken anything for the pain [patient] ibuprofen eight hundred milligrams three times a day [doctor] okay and does anything make it better or worse [patient] the more i use my hands or my arms the more it hurts 

In [21]:
# Structured response schema used by get_task_description (passed as `response_model=`);
# the refined description is read back from `.description`.
class TaskDescription(BaseModel):
    description: str = Field(..., description="A concise yet comprehensive task description")

# TODO: Batch this as opposed to one at a time
# or sample the dataset and ensure that taking into tokens (maybe something fun with a distribution)
# distribution = more stuff we can grab and throw into prompt in smart way
def get_task_description(data: List[DataPoint]) -> str:
    """Derive a task description by showing the LLM one datapoint at a time.

    Starts from an empty description and, for every datapoint, asks the model to
    refine the current description given that datapoint's input, output,
    annotation and note. One completion call is made per datapoint.

    Args:
        data: Annotated datapoints; fields 0-3 (input, output, annotation, note) are read.

    Returns:
        The description after the last accepted refinement ("" if `data` is empty).
    """
    running_description = ""

    for record in data:
        example_input = record[0]
        example_output = record[1]
        label = record[2]
        remark = record[3]

        prompt = f"""
        Current task description: {running_description}

        New datapoint:
        Input: {example_input}
        Output: {example_output}
        Annotation: {"Correct" if label == 1 else "Incorrect"}
        Note: {remark}

        Based on this new datapoint and the current task description, provide an updated, more refined task description. 
        If this is the first datapoint, create an initial task description.
        Focus on:
        1. The nature of the input and output data
        2. The specific information being extracted or transformed
        3. Any formatting or style requirements
        4. Evaluation criteria (based on the annotation and note)

        Keep the description concise yet comprehensive.
        """

        reply = client.chat.completions.create(
            model=MODEL,
            messages=[{"role": "user", "content": prompt}],
            response_model=TaskDescription,
        )
        candidate = reply.description

        # TODO: Add guardrails to prevent LLM from saying no update needed
        # Keep the previous description when the model declines to update it.
        if not candidate.lower().startswith("no update needed"):
            running_description = candidate

    return running_description

In [22]:
# Build the LLM-derived task description; this makes one completion call per datapoint.
llm_task_description = get_task_description(data)

In [26]:
# Structured response schema used by combine_human_and_llm_descriptions (passed as
# `response_model=`); the merged description is read back from `.description`.
class CombinedDescription(BaseModel):
    description: str = Field(..., description="A comprehensive task description that combines LLM and human insights")

def combine_human_and_llm_descriptions(data: List[DataPoint], llm_description: str) -> str:
    """Merge the LLM-generated task description with any human-provided descriptions.

    Human descriptions are read from field 4 of each datapoint (when the datapoint
    is long enough and the field is truthy) and de-duplicated. If none are present
    the LLM description is returned unchanged and no completion call is made.

    Args:
        data: Annotated datapoints; only field 4 is read.
        llm_description: The description produced by the LLM-only pass.

    Returns:
        The combined description from the model, or `llm_description` when there is
        no human context to merge.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order. A set was used
    # before, which made the bullet order (and therefore the prompt) vary between
    # kernel sessions because string hashing is randomized per process.
    human_descriptions = list(dict.fromkeys(
        dp[4] for dp in data if len(dp) > 4 and dp[4]  # Check if human description exists
    ))

    if not human_descriptions:
        return llm_description

    human_context = "\n".join(f"- {desc}" for desc in human_descriptions)

    prompt = f"""
    LLM-generated task description:
    {llm_description}

    Additional human-provided context:
    {human_context}

    Your task is to create a comprehensive, coherent task description that combines insights from both the LLM-generated description and the human-provided context. Ensure that:
    1. The final description is clear and concise.
    2. It incorporates key points from both sources.
    3. Any contradictions are resolved logically.
    4. The description maintains a professional tone.
    5. It provides a complete picture of the task requirements and evaluation criteria.

    Please provide the combined description in a single, well-structured paragraph.
    """

    response = client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": prompt}],
        response_model=CombinedDescription
    )

    return response.description

In [27]:
# Merge the LLM description with any human descriptions carried in field 4 of the datapoints.
finalized_task_description = combine_human_and_llm_descriptions(data, llm_task_description)

In [28]:
# Display the merged task description.
finalized_task_description

"The task involves transforming dialogues between a doctor and a patient into structured medical summaries, combining insights from both a medical note and the dialogue. The goal is to accurately extract key medical information, including: chief complaint, history of present illness, physical examination findings, symptoms, new medications with dosages, and follow-up instructions, using a bullet-point format such as '- key: value'. In situations where data is unavailable, 'N/A' should be used. It's crucial to ensure the summaries are complete, accurate, concise, exclude any personal identifiable information (PII), and adhere to privacy guidelines. The summaries should be well-formatted, focusing on relevance and concise representation of the essential details, avoiding PII such as name, age, gender, or ID. Evaluations will emphasize the completeness, accuracy, conciseness, and adherence to formatting instructions, all while keeping responses within a limit of 150 words, thus necessitat

In [50]:
def format_single_datapoint(dp: DataPoint, finalized_task_description: str) -> str:
    """Render one annotated datapoint as a plain-text block for use in a prompt.

    The block contains the task description, the input and output mappings (one
    indented "Key: value" line per item), the annotation as Correct/Incorrect and
    the note. A "Metrics Details" line is appended when field 5 exists and is truthy.

    Args:
        dp: The datapoint; fields 0-3 are required, field 5 is optional.
        finalized_task_description: Task description placed at the top of the block.

    Returns:
        The newline-joined text block.
    """
    input_data = dp[0]
    output_data = dp[1]
    annotation = dp[2]
    note = dp[3]
    metrics_details = None
    if len(dp) > 5:
        metrics_details = dp[5]

    def render_fields(fields) -> str:
        # One "  Key: value" line per item, key capitalized.
        rendered = []
        for key, value in fields.items():
            rendered.append(f"  {key.capitalize()}: {value}")
        return "\n".join(rendered)

    sections = [f"Task Description: {finalized_task_description}", ""]
    sections += ["Input:", render_fields(input_data), ""]
    sections += ["Output:", render_fields(output_data), ""]
    sections.append(f"Annotation: {'Correct' if annotation == 1 else 'Incorrect'}")
    sections.append(f"Note: {note}")

    if metrics_details:
        sections.append(f"Metrics Details: {metrics_details}")

    return "\n".join(sections)

In [51]:
# Format the first datapoint as a sanity check of the single-datapoint layout.
formatted_dp = format_single_datapoint(data[0], finalized_task_description)

In [52]:
# Show the formatted text block produced above.
pprint(formatted_dp)

('Task Description: The task involves transforming dialogues between a doctor '
 'and a patient into structured medical summaries, combining insights from '
 'both a medical note and the dialogue. The goal is to accurately extract key '
 'medical information, including: chief complaint, history of present illness, '
 'physical examination findings, symptoms, new medications with dosages, and '
 "follow-up instructions, using a bullet-point format such as '- key: value'. "
 "In situations where data is unavailable, 'N/A' should be used. It's crucial "
 'to ensure the summaries are complete, accurate, concise, exclude any '
 'personal identifiable information (PII), and adhere to privacy guidelines. '
 'The summaries should be well-formatted, focusing on relevance and concise '
 'representation of the essential details, avoiding PII such as name, age, '
 'gender, or ID. Evaluations will emphasize the completeness, accuracy, '
 'conciseness, and adherence to formatting instructions, all w

In [58]:
# One evaluation criterion; used as the element type of EvaluationCriteria.criteria.
# `evaluation_method` is restricted to "code" or "llm".
class Criterion(BaseModel):
    criterion: str = Field(..., description="A concise, specific statement describing a single aspect of evaluation")
    explanation: str = Field(..., description="A detailed explanation of the criterion's importance and potential evaluation methods")
    evaluation_method: Literal["code", "llm"] = Field(..., description="The primary method for evaluating this criterion: 'code' for programmatic checks, 'llm' for language model-based assessment")

# Structured response schema used by process_criteria (passed as `response_model=`):
# a list of Criterion items, constrained by the Field to between one and three entries.
class EvaluationCriteria(BaseModel):
    criteria: List[Criterion] = Field(
        ...,
        min_items=1,
        max_items=3,
        description="A list of 1-3 distinct evaluation criteria, each focusing on a different aspect of output quality"
    )

#TODO: take into account any previous iterations of successfully aligned criteria judges
@weave.op()
def process_criteria(formatted_dp: str, finalized_task_description: str) -> EvaluationCriteria:
    """Ask the LLM for 1-3 evaluation criteria based on a single formatted datapoint.

    Args:
        formatted_dp: Text rendering of one annotated datapoint, embedded in the prompt.
        finalized_task_description: Task description placed at the top of the prompt.

    Returns:
        The parsed EvaluationCriteria response; callers read `.criteria` from it.
    """
    prompt = f"""
Task Description: {finalized_task_description}

Analyze the following annotated datapoint:

{formatted_dp}

Generate 1-3 evaluation criteria that can be used to assess the quality of outputs for this task. Consider the following guidelines:

1. If a 'Metrics Details' field is present in the datapoint, prioritize this information as it provides the most important evaluation criteria.
2. Focus on general aspects of quality that can be used across multiple outputs.
3. Consider criteria that address potential misalignment between LLM outputs and human preferences.
4. Include criteria that can be evaluated both by code and by LLM-based evaluators.
5. Think about criteria that might reveal hallucinations, instruction-following, or other common LLM issues.
6. Generate criteria that could help in debugging or improving the LLM pipeline.

Provide each criterion as a concise statement, followed by a brief explanation of why it's important and how it might be evaluated (e.g., via code, LLM evaluator, or human judgment).

Return the criteria in this format:
[Criterion]: [Brief explanation and evaluation method]
[Criterion]: [Brief explanation and evaluation method]
[Criterion]: [Brief explanation and evaluation method]

Aim for a mix of straightforward, code-evaluable criteria and more nuanced criteria that might require LLM or human evaluation.
"""
    # The response is capped at 500 tokens and parsed into EvaluationCriteria.
    response = client.chat.completions.create(
        model=MODEL,
        max_tokens=500,
        messages=[{"role": "user", "content": prompt}],
        response_model=EvaluationCriteria
    )
    return response

@weave.op()
def generate_criteria(data: List[DataPoint], finalized_task_description: str) -> List[Criterion]:
    """Collect evaluation criteria for every datapoint.

    Each datapoint is formatted with format_single_datapoint and sent through
    process_criteria; the criteria from all calls are concatenated in order.
    No de-duplication is performed.

    Args:
        data: Annotated datapoints to derive criteria from.
        finalized_task_description: Task description included in every prompt.

    Returns:
        A flat list of Criterion objects (not strings), in datapoint order.
    """
    all_criteria: List[Criterion] = []

    for dp in data:
        formatted_dp = format_single_datapoint(dp, finalized_task_description)
        response = process_criteria(formatted_dp, finalized_task_description)

        #TODO: Add an additional check to see if a nearly identical criterion is already in the list
        all_criteria.extend(response.criteria)

    return all_criteria

In [59]:
# Generate criteria: one process_criteria call per datapoint, results concatenated without de-duplication.
criteria = generate_criteria(data, finalized_task_description)

In [63]:
# Peek at the first generated criterion as a plain dict.
criteria[0].dict()

{'criterion': 'Completeness',
 'explanation': 'Ensure all required fields (chief complaint, history of present illness, physical examination, symptoms, medications, follow-up instructions) are addressed. This is critical to maintain the integrity and utility of the medical summary.',
 'evaluation_method': 'code'}

In [86]:
# Display every generated criterion.
criteria

[Criterion(criterion='Completeness', explanation='Ensure all required fields (chief complaint, history of present illness, physical examination, symptoms, medications, follow-up instructions) are addressed. This is critical to maintain the integrity and utility of the medical summary.', evaluation_method='code'),
 Criterion(criterion='Presence of Personal Identifiable Information (PII)', explanation='Ensure no PII such as name, age, gender, or ID is included to adhere to privacy guidelines, which is crucial for ethical considerations in handling sensitive medical data.', evaluation_method='llm'),
 Criterion(criterion='Conciseness and Adherence to Word Limit', explanation='Ensure the summary is concise and within a 150-word limit to maintain readability and focus on key medical information. This can be evaluated by checking the word count and assessing if information is essential.', evaluation_method='code'),
 Criterion(criterion='Completeness and Accuracy', explanation='The output must

In [79]:
# A candidate assertion whose check is Python source held in `code`.
# `evaluation_type` is fixed to "python", which distinguishes it from LLMAssertion
# when both appear in the Union used by CriterionAssertions.
class PythonAssertion(BaseModel):
    code: str = Field(..., description="A clear, concise assertion written as a unittest.TestCase method.")
    evaluation_type: Literal["python"]

# A candidate assertion that is checked by an LLM judge; `text` holds the judge prompt.
# `evaluation_type` is fixed to "llm", which distinguishes it from PythonAssertion
# when both appear in the Union used by CriterionAssertions.
class LLMAssertion(BaseModel):
    # The previous description was the placeholder "in the works". It now states what
    # `text` must contain, matching the LLM-assertion guidance in the
    # create_candidate_assertions prompt.
    text: str = Field(..., description="A clear, detailed prompt that guides an LLM judge to return 'PASS' or 'FAIL' for the output being evaluated.")
    evaluation_type: Literal["llm"]

# Structured response schema used by create_candidate_assertions (passed as
# `response_model=`): a list of assertions, each either a PythonAssertion or an
# LLMAssertion, constrained by the Field to between one and three entries.
class CriterionAssertions(BaseModel):
    assertions: List[Union[PythonAssertion, LLMAssertion]] = Field(
        ...,
        min_items=1,
        max_items=3,
        description="Generate 1-3 specific, testable assertions that can be used to evaluate LLM outputs based on the given criterion"
    )

def create_candidate_assertions(formatted_data_string: str, criterion: Criterion) -> List[Union[PythonAssertion, LLMAssertion]]:
    """Ask the LLM for 1-3 testable assertions that operationalize one criterion.

    Args:
        formatted_data_string: Text rendering of the annotated data, embedded in the prompt.
        criterion: The criterion to write assertions for; embedded via `criterion.dict()`.

    Returns:
        The list of parsed assertions (`response.assertions`), i.e. PythonAssertion
        and/or LLMAssertion objects -- not the CriterionAssertions wrapper.
    """
    # The JSON example at the end of the prompt names the fields exactly as the
    # response schema does: Python assertions carry `code`, LLM assertions carry `text`.
    prompt = f"""
Given the following evaluation criterion and annotated data, generate 1-3 specific, testable assertions:

Criterion: {criterion.dict()}

Annotated data: {formatted_data_string}

Your task is to create assertions that can be used to evaluate LLM outputs based on this criterion. Follow these guidelines:

1. Make each assertion clear, concise, and directly related to the criterion
2. For Python assertions:
   - Provide a valid Python method that can be used within a unittest.TestCase class
   - The method should take 'self' and 'output' as parameters, where 'output' is the LLM output being evaluated
   - Use unittest assertion methods (e.g., self.assertTrue, self.assertEqual) to test the output
   - The test should pass if the assertion is met, and fail otherwise
3. For LLM assertions:
   - Provide a clear, detailed prompt for an LLM to evaluate the assertion
   - The prompt should guide the LLM to return "PASS" or "FAIL" based on the evaluation
4. Include a mix of positive and negative assertions where appropriate
5. Consider edge cases and potential failure modes for the criterion
6. Aim for assertions that could be applied across multiple types of outputs

Ensure that your assertions are directly evaluable and avoid vague or subjective language. Focus on creating assertions that align with human preferences and can be used to validate the quality of LLM-generated evaluations.

Format your response as a JSON object with the following structure:
{{
  "assertions": [
    {{
      "code": "Python unittest method source",
      "evaluation_type": "python"
    }},
    {{
      "text": "Prompt for the LLM judge",
      "evaluation_type": "llm"
    }},
    ...
  ]
}}
"""
    response = client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": prompt}],
        response_model=CriterionAssertions
    )
    return response.assertions

In [80]:
#TODO: improve this function
def format_all_datapoints(data: List[DataPoint], finalized_task_description: str) -> str:
    """Render the whole dataset as one text block for use in a prompt.

    The block starts with the task description, followed by one numbered section
    per datapoint (numbering starts at 1) containing the JSON-dumped input and
    output, the annotation as Correct/Incorrect, the note, and a dashed separator.

    Args:
        data: Annotated datapoints; fields 0-3 of each are read.
        finalized_task_description: Task description placed at the top of the block.

    Returns:
        The newline-joined text block.
    """
    divider = "\n" + "-"*50 + "\n"  # Separator between examples

    def render_example(number, dp):
        # Read the four required fields first, then build this example's lines.
        input_data = dp[0]
        output_data = dp[1]
        annotation = dp[2]
        note = dp[3]
        return [
            f"Example {number}:",
            "Input:",
            json.dumps(input_data, indent=2),
            "",
            "Output:",
            json.dumps(output_data, indent=2),
            "",
            f"Annotation: {'Correct' if annotation == 1 else 'Incorrect'}",
            f"Note: {note}",
            divider,
        ]

    lines = [f"Task Description: {finalized_task_description}\n"]
    for number, dp in enumerate(data, 1):
        lines += render_example(number, dp)

    return "\n".join(lines)

In [81]:
# Render the full dataset once; the same text is reused for every criterion's assertion prompt.
formatted_data = format_all_datapoints(data, finalized_task_description)

In [82]:
# Show the rendered dataset text.
pprint(formatted_data)

('Task Description: The task involves transforming dialogues between a doctor '
 'and a patient into structured medical summaries, combining insights from '
 'both a medical note and the dialogue. The goal is to accurately extract key '
 'medical information, including: chief complaint, history of present illness, '
 'physical examination findings, symptoms, new medications with dosages, and '
 "follow-up instructions, using a bullet-point format such as '- key: value'. "
 "In situations where data is unavailable, 'N/A' should be used. It's crucial "
 'to ensure the summaries are complete, accurate, concise, exclude any '
 'personal identifiable information (PII), and adhere to privacy guidelines. '
 'The summaries should be well-formatted, focusing on relevance and concise '
 'representation of the essential details, avoiding PII such as name, age, '
 'gender, or ID. Evaluations will emphasize the completeness, accuracy, '
 'conciseness, and adherence to formatting instructions, all w

In [85]:
# Generate candidate assertions for every criterion and print them for inspection.
# `all_assertions` accumulates the flattened result across criteria.
# NOTE: after the loop, `assertions` holds only the LAST criterion's list.
all_assertions = []
for criterion in criteria:
    assertions = create_candidate_assertions(formatted_data, criterion)
    all_assertions.extend(assertions)
    print(f"\nCriterion: {criterion}")
    print("Candidate assertions:")
    for assertion in assertions:
        print(f"  Type: {assertion.evaluation_type.upper()}")
        # Python assertions carry source in `.code`; LLM assertions carry a prompt in `.text`.
        if assertion.evaluation_type == 'python':
            print("  Code:")
            print("    " + assertion.code.replace('\n', '\n    '))
        else:
            print("  Prompt:")
            print("    " + assertion.text.replace('\n', '\n    '))
    print("-" * 80)


Criterion: criterion='Completeness' explanation='Ensure all required fields (chief complaint, history of present illness, physical examination, symptoms, medications, follow-up instructions) are addressed. This is critical to maintain the integrity and utility of the medical summary.' evaluation_method='code'
Candidate assertions:
  Type: PYTHON
  Code:
    def test_completeness_of_fields(self, output):
        required_fields = ['chief complaint', 'history of present illness', 'physical examination', 'symptoms', 'new medications prescribed or changed', 'follow-up instructions']
        for field in required_fields:
            self.assertIn(f'- {field}', output, f"Missing field: {field}")
  Type: PYTHON
  Code:
    def test_use_of_na_when_data_unavailable(self, output):
        # Check if 'N/A' is present when data is not available for new medications
        if '- new medications prescribed or changed: N/A' not in output:
            self.assertNotIn('new medications prescribed or c

In [78]:
# Display the flattened list of candidate assertions.
all_assertions

[PythonAssertion(text="def has_all_required_fields(output):\n    required_fields = ['chief complaint', 'history of present illness', 'physical examination', 'symptoms experienced by the patient', 'new medications prescribed or changed', 'follow-up instructions']\n    return all(field in output for field in required_fields)\n\nassert has_all_required_fields(output)", evaluation_type='python'),
 PythonAssertion(text="def uses_NA_for_missing_data(output):\n    fields_with_NA = ['new medications prescribed or changed']\n    for field in fields_with_NA:\n        if field in output and output[field] == 'N/A':\n            return True\n    return False\n\nassert not uses_NA_for_missing_data(output)  # Test Negative Case", evaluation_type='python'),
 LLMAssertion(text="Evaluate whether the given medical summary from an LLM addresses all the specified fields: chief complaint, history of present illness, physical examination, symptoms experienced by the patient, new medications with dosages, and

In [None]:
from code_model_wrapper import run_code_evaluation
import asyncio

def evaluate_python_assertion(assertion, data):
    """Run a Python assertion through run_code_evaluation and return its result.

    The result of the evaluation is now returned (it used to be discarded, so
    callers always received None).

    Args:
        assertion: The assertion payload; forwarded unchanged to run_code_evaluation.
        data: The data to evaluate against; forwarded unchanged to run_code_evaluation.

    Returns:
        Whatever the `run_code_evaluation(assertion, data)` coroutine resolves to.
    """
    import concurrent.futures  # local import: only needed for the running-loop fallback

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop in this thread (e.g. plain script): asyncio.run works directly.
        return asyncio.run(run_code_evaluation(assertion, data))

    # An event loop is already running in this thread (the normal situation inside a
    # Jupyter kernel), where asyncio.run raises RuntimeError. Drive the coroutine on a
    # fresh loop in a worker thread and wait for its result.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(lambda: asyncio.run(run_code_evaluation(assertion, data)))
        return future.result()
    

In [None]:
#TODO: use a faster/smaller model gpt4o-mini
def evaluate_llm_assertion(assertion, datum):
    """Placeholder for LLM-judged assertions: not implemented yet, always returns None.

    Args:
        assertion: The LLM assertion payload (currently unused).
        datum: The datapoint to evaluate (currently unused).
    """
    return None

In [None]:
def evaluate_assertions(datum, assertion):
    """Evaluate one assertion against one datum, dispatching on the assertion type.

    Args:
        datum: The datapoint to evaluate.
        assertion: A PythonAssertion or LLMAssertion; `evaluation_type` selects the evaluator.

    Returns:
        The result of evaluate_python_assertion for 'python' assertions, otherwise
        the result of evaluate_llm_assertion.
    """
    if assertion.evaluation_type == 'python':
        # PythonAssertion stores its source in `code`; only LLMAssertion has `text`.
        return evaluate_python_assertion(assertion.code, datum)
    return evaluate_llm_assertion(assertion.text, datum)

In [None]:
# Evaluate every candidate assertion against every datapoint.
# evaluate_assertions takes ONE datum and ONE assertion, so iterate over both here.
# The mapping is keyed by the assertion's source/prompt string and holds one result
# per datapoint (same order as `data`), which is the Dict[str, List[int]] shape that
# calculate_metrics consumes.
# NOTE(review): the metrics code expects each result to be 1 (pass) or 0 (fail);
# confirm the evaluators return that once they are implemented.
assertion_results = {
    (a.code if a.evaluation_type == 'python' else a.text): [evaluate_assertions(datum, a) for datum in data]
    for a in all_assertions
}

1. **Selectivity**:
   ```python
   selectivity = passes / total_outputs
   ```
   Selectivity measures how often an assertion passes LLM outputs. A lower selectivity means the assertion is more "picky" or strict.

2. **Coverage**:
   ```python
   coverage = fails_on_bad / total_bad if total_bad > 0 else 0
   ```
   Coverage measures how well our assertions catch the outputs that humans marked as bad. A higher coverage means we're better at identifying problematic outputs.

3. **False Failure Rate (FFR)**:
   ```python
   ffr = fails_on_good / total_good if total_good > 0 else 0
   ```
   FFR shows how often our assertions incorrectly fail outputs that humans thought were good. A lower FFR is better, as it means we're not being overly strict.

4. **Alignment**:
   ```python
   alignment = 2 * (coverage * (1 - ffr)) / (coverage + (1 - ffr)) if (coverage + (1 - ffr)) > 0 else 0
   ```
   Alignment combines coverage and FFR into a single score: the harmonic mean of coverage and (1 - FFR). It represents how well our automated evaluations match human judgments overall.

These metrics help us refine our assertion set over time, aiming to catch more bad outputs while avoiding false alarms on good ones.

We need to get Alignment score right

In [None]:
def calculate_metrics(data: List["DataPoint"], assertion_results: Dict[str, List[int]]) -> Dict[str, Dict[str, float]]:
    """Compute selectivity, coverage, false-failure rate and alignment per assertion.

    The human annotation is read positionally (field 2) so datapoints of any
    length work; the previous tuple unpacking required 4 fields on one line and
    3 on the next, so it could not succeed for any tuple length.

    Args:
        data: Annotated datapoints; field 2 is the human annotation (0 = bad, 1 = good).
            The annotation is quoted so the alias is not evaluated at definition time.
        assertion_results: Maps each assertion to its per-datapoint results, aligned
            with `data` (1 = pass, 0 = fail).

    Returns:
        For every assertion, a dict with:
          - "selectivity": passes / number of datapoints
          - "coverage": failures on bad outputs / number of bad outputs
          - "ffr": failures on good outputs / number of good outputs
          - "alignment": harmonic mean of coverage and (1 - ffr)
        Any ratio whose denominator is zero is reported as 0.
    """
    annotations = [dp[2] for dp in data]
    total_outputs = len(annotations)
    total_bad = sum(1 for annotation in annotations if annotation == 0)
    total_good = total_outputs - total_bad

    metrics = {}
    for assertion, results in assertion_results.items():
        passes = sum(results)
        fails_on_bad = sum(
            1 for annotation, result in zip(annotations, results)
            if annotation == 0 and result == 0
        )
        fails_on_good = sum(
            1 for annotation, result in zip(annotations, results)
            if annotation == 1 and result == 0
        )

        # Guard every ratio against an empty denominator (e.g. empty dataset).
        selectivity = passes / total_outputs if total_outputs > 0 else 0
        coverage = fails_on_bad / total_bad if total_bad > 0 else 0
        ffr = fails_on_good / total_good if total_good > 0 else 0
        alignment = 2 * (coverage * (1 - ffr)) / (coverage + (1 - ffr)) if (coverage + (1 - ffr)) > 0 else 0
        # human_accuracy =

        metrics[assertion] = {
            "selectivity": selectivity,
            "coverage": coverage,
            "ffr": ffr,
            "alignment": alignment
        }

    return metrics

In [None]:
# Calculate per-assertion metrics and print them as JSON (assertion keys must be strings for json.dumps).
metrics = calculate_metrics(data, assertion_results)
print("Assertion metrics:", json.dumps(metrics, indent=2))

In [None]:
# Rethink this, especially for the Python case, because a unittest may not always be a pure single-assertion test (it can be a whole test suite).
def select_best_assertions(assertions: Dict[str, List[str]], metrics: Dict[str, Dict[str, float]]) -> Dict[str, str]:
    """Pick, for each criterion, the candidate assertion with the highest alignment.

    Args:
        assertions: Maps each criterion to its list of candidate assertions.
        metrics: Maps each assertion to its metrics dict; only 'alignment' is read.

    Returns:
        A dict mapping each criterion to its best assertion. On ties the earliest
        candidate in the list wins.
    """
    def alignment_of(candidate: str) -> float:
        return metrics[candidate]['alignment']

    return {
        criterion: max(candidates, key=alignment_of)
        for criterion, candidates in assertions.items()
    }

In [None]:
# Select the best assertion per criterion and print the selection as JSON.
# NOTE(review): select_best_assertions calls `.items()` on its first argument, but
# `assertions` was last assigned in the generation loop to the list returned by
# create_candidate_assertions. A criterion -> candidates mapping needs to be built
# before this cell can run.
best_assertions = select_best_assertions(assertions, metrics)
print("Best assertions:", json.dumps(best_assertions, indent=2))

In [None]:
def calculate_overall_metrics(data: List["DataPoint"], best_assertions: Dict[str, str], assertion_results: Dict[str, List[int]]) -> Dict[str, float]:
    """Compute coverage, false-failure rate and alignment for the selected assertion set.

    A datapoint counts as failed when ANY selected assertion fails on it. The human
    annotation is read positionally (field 2), so datapoints with more than four
    fields work; the previous 4-name tuple unpacking raised on 6-field DataPoints.

    Args:
        data: Annotated datapoints; field 2 is the human annotation (0 = bad, 1 = good).
            The annotation is quoted so the alias is not evaluated at definition time.
        best_assertions: Maps each criterion to its selected assertion.
        assertion_results: Maps each assertion to its per-datapoint results, aligned
            with `data` (1 = pass, 0 = fail).

    Returns:
        A dict with "coverage", "ffr" and "alignment" (harmonic mean of coverage and
        1 - ffr). Any ratio whose denominator is zero is reported as 0.
    """
    selected = list(best_assertions.values())

    def fails_any(index: int) -> bool:
        # True when at least one selected assertion failed on this datapoint.
        return any(assertion_results[assertion][index] == 0 for assertion in selected)

    total_outputs = len(data)
    total_bad = 0
    fails_on_bad = 0
    fails_on_good = 0
    for index, dp in enumerate(data):
        annotation = dp[2]
        if annotation == 0:
            total_bad += 1
            if fails_any(index):
                fails_on_bad += 1
        elif annotation == 1 and fails_any(index):
            fails_on_good += 1
    total_good = total_outputs - total_bad

    coverage = fails_on_bad / total_bad if total_bad > 0 else 0
    ffr = fails_on_good / total_good if total_good > 0 else 0
    alignment = 2 * (coverage * (1 - ffr)) / (coverage + (1 - ffr)) if (coverage + (1 - ffr)) > 0 else 0

    return {
        "coverage": coverage,
        "ffr": ffr,
        "alignment": alignment
    }

In [None]:
# Calculate metrics for the selected assertion set as a whole
# (a datapoint fails if any selected assertion fails on it) and print them as JSON.
overall_metrics = calculate_overall_metrics(data, best_assertions, assertion_results)
print("Overall metrics:", json.dumps(overall_metrics, indent=2))

In [None]:
# Generate final report: the selected assertion per criterion, the per-assertion
# metrics restricted to those selected assertions, and the overall metrics.
report = {
    "final_assertions": best_assertions,
    # Keyed by assertion (the values of best_assertions), not by criterion.
    "assertion_metrics": {assertion: metrics[assertion] for assertion in best_assertions.values()},
    "overall_metrics": overall_metrics
}

print("\nFinal Report:")
print(json.dumps(report, indent=2))

In [None]:
# Ensure we can improve the selected aligned judges when a new batch of annotations comes through the application

In [None]:
# Make the weave.Scorer workflow work somehow

In [None]:
# Final Dash post alignment
# - Show big number of the final alignment 
# - The list of all assertions -> Weave.object (or Scorer)
# - MATCH THE SCORE FROM PAPER or do better