In [None]:
from typing import Dict, List, Literal, Optional, Tuple

import instructor
import openai
import pandas as pd
import weave
from pydantic import BaseModel, Field
from set_env import set_env

In [None]:
# Make the OpenAI and Weights & Biases credentials available before any
# client is created or Weave is initialised.
# NOTE(review): set_env comes from the local `set_env` module; presumably it
# loads or prompts for each key and exports it to the environment - confirm.
set_env("OPENAI_API_KEY")
set_env("WANDB_API_KEY")
print("Env set")

In [None]:
from notebooks.utils.config import ENTITY, WEAVE_PROJECT

In [None]:
# Initialise Weave against the "<ENTITY>/<WEAVE_PROJECT>" project taken from
# notebooks.utils.config; the @weave.op() functions and weave.publish calls
# below rely on this having been run first.
weave.init(f"{ENTITY}/{WEAVE_PROJECT}")

In [5]:
# Default number of dataset rows to sample and send through the pipeline;
# used as the default argument of load_medical_data and generate_medical_data.
# NOTE(review): 67 looks like it was chosen to match the size of the train
# split - confirm, since sampling more rows than the CSV contains will fail.
N_SAMPLES = 67

In [None]:
from notebooks.utils.prompts import medical_system_prompt, medical_task

In [7]:
# Plain OpenAI client used by process_medical_record for the extraction calls.
# Note that `client` is rebound further down the notebook to an
# instructor-patched client for the annotation step.
client = openai.OpenAI()

# Raw CSV of the ACI-Bench challenge training data on GitHub; the notebook
# reads its "dialogue" and "note" columns.
medical_dataset_url = "https://raw.githubusercontent.com/wyim/aci-bench/main/data/challenge_data/train.csv"

In [8]:
def load_medical_data(url: str, num_samples: int = N_SAMPLES) -> List[Dict]:
    """Read the dataset CSV and return a reproducible random sample of its rows.

    Args:
        url: Location of the CSV file; passed straight to ``pd.read_csv``.
        num_samples: Number of rows to draw. A request larger than the
            dataset is capped at the number of rows available instead of
            raising.

    Returns:
        One dict per sampled row, keyed by column name.
    """
    df = pd.read_csv(url)
    print(df.shape)
    # DataFrame.sample draws without replacement by default and raises
    # ValueError when n exceeds the row count, so cap the request at the
    # dataset size. The fixed random_state keeps the selection repeatable.
    sample_size = min(num_samples, len(df))
    samples = df.sample(n=sample_size, random_state=42)
    return samples.to_dict("records")

In [None]:
# Draw the default N_SAMPLES rows for a look at the raw data. In the visible
# notebook this variable is only used for inspection; generate_medical_data
# loads its own (identically seeded) sample.
samples = load_medical_data(medical_dataset_url)

In [None]:
# Display one raw record to see which columns the dataset provides.
samples[0]

In [11]:
def format_transcript(record):
    """Build the single text blob that is sent to the extraction model.

    Newlines inside the record's "dialogue" and "note" fields are replaced
    with spaces, and the two flattened fields are joined under the
    "Dialogue:" and "Medical Note:" labels, separated by a blank line.

    Args:
        record: Mapping with string values under "dialogue" and "note".

    Returns:
        The formatted transcript string.
    """
    dialogue, note = (
        record[field].replace("\n", " ") for field in ("dialogue", "note")
    )
    return f"Dialogue: {dialogue}\n\nMedical Note: {note}"


@weave.op()
def process_medical_record(record: Dict) -> Dict:
    """Run the extraction prompt over a single dataset record.

    The record is flattened with format_transcript, inserted into
    ``medical_task`` and sent to gpt-3.5-turbo together with
    ``medical_system_prompt`` as the system message.

    Args:
        record: One dataset row containing "dialogue" and "note".

    Returns:
        A dict with the formatted transcript under "input" and the text of
        the model's first choice under "output".
    """
    transcript = format_transcript(record)
    messages = [
        {"role": "system", "content": medical_system_prompt},
        {"role": "user", "content": medical_task.format(transcript=transcript)},
    ]
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=messages,
    )
    return {
        "input": transcript,
        "output": completion.choices[0].message.content,
    }


@weave.op()
def generate_medical_data(num_samples: int = N_SAMPLES) -> List[Dict]:
    """Sample records from the dataset and run the extraction model on each.

    Args:
        num_samples: How many rows to request from load_medical_data.

    Returns:
        The process_medical_record result for every sampled row, in
        sampling order.
    """
    records = load_medical_data(medical_dataset_url, num_samples)
    # Records are processed one at a time, in order.
    return [process_medical_record(record) for record in records]

In [None]:
# Run the extraction model over the default N_SAMPLES records; this issues
# one chat-completion request per record.
results = generate_medical_data()

In [None]:
# Spot-check the first two input/output pairs.
results[0:2]

In [None]:
# Publish the raw extraction results to Weave under the name "medical_data_raw".
weave.publish(results, name="medical_data_raw")

In [15]:
# Rebind the global `client` to an instructor-patched OpenAI client so that
# process_annotation can pass `response_model=`.
# NOTE(review): this reuses the name of the plain client created earlier, so
# after this cell runs process_medical_record also goes through the patched
# client; a separate name would avoid the hidden-state dependency.
client = instructor.patch(openai.OpenAI())

In [16]:
# Per-criterion binary scores nested inside AnnotationResult; each field is
# 1 when the criterion is met and 0 otherwise. The Field descriptions are the
# instructions attached to each field of the model.
# Notes are kept in comments on purpose: a class docstring would become part
# of the model's JSON schema description.
class MainCriteria(BaseModel):
    # Length check: output stays within the 150-word limit.
    word_count: Literal[0, 1] = Field(
        description="1 if the word count is within the limit of 150 words, 0 otherwise",
    )
    # Structure check: all six targeted keys appear in the output.
    presence_of_keys: Literal[0, 1] = Field(
        description="1 if all the six targeted keys (Chief complaint, History of present illness, Physical examination, Symptoms, New medications with dosages, Follow-up instructions) are present, 0 otherwise",
    )
    # Privacy check: no personally identifiable information in the output.
    absence_of_PII: Literal[0, 1] = Field(
        description="1 if no PII is present, 0 otherwise",
    )

In [17]:
# TODO: Make each desired field a separate annotation


# Structured verdict returned by process_annotation (used there as the
# `response_model`). Holds an overall binary score, the per-criterion
# breakdown and a free-text justification.
# Notes are kept in comments on purpose: a class docstring would become part
# of the model's JSON schema description.
class AnnotationResult(BaseModel):
    # Overall pass/fail: 1 only when every criterion is met.
    annotation: Literal[0, 1] = Field(
        description="Binary score: 1 if the extraction meets all criteria, 0 if it fails on any",
    )
    # Individual scores for word count, key presence and PII absence.
    criteria_annotations: MainCriteria = Field(
        description="A score for each of the main criteria",
    )
    # Short explanation of why the annotation was given.
    note: str = Field(
        description="Brief explanation of the annotation decision, highlighting any issues or exemplary aspects",
    )


# User-message template for the annotation (judge) call. It is filled with
# str.format and expects four fields: medical_system_prompt, medical_task,
# input_text and output_text.
# NOTE(review): the "Provide" section asks only for the annotation and the
# note; the per-criterion scores are requested solely through the
# AnnotationResult schema - confirm that is intended.
annotation_prompt = """
    Review the following medical data extraction task results:

    Task System Prompt:
    {medical_system_prompt}

    Task:
    {medical_task}

    Input:
    {input_text}

    Output:
    {output_text}

    Evaluate the extraction based on these criteria. Only refer to the Output in your evaluation and NOT the Medical Note field:
    1. Completeness: All required fields addressed (Chief complaint, History of present illness, Physical examination, Symptoms, New medications with dosages, Follow-up instructions)
    2. Accuracy: Information correctly extracted from input
    3. Format: Proper bullet list format used (•key: value)
    4. Privacy: No personal identifiable information (PII) included
    5. Conciseness: ~150 words, key information summarized
    6. Use of "N/A" for missing information

    Provide:
    1. Annotation: 1 if the extraction meets all criteria, 0 if it fails on any
    2. Note: Brief explanation of your decision, highlighting any issues or exemplary aspects
"""

# System message for the annotation (judge) call made in process_annotation.
annotation_system_prompt = """
You are an AI assistant tasked with evaluating medical data extraction results.
"""

In [18]:
@weave.op()
def process_annotation(input_text: str, output_text: str) -> AnnotationResult:
    """Ask gpt-4o to judge one extraction and return its structured verdict.

    The annotation prompt is filled with the original task prompts plus the
    given input/output pair and sent with ``annotation_system_prompt`` as the
    system message. ``response_model=AnnotationResult`` is passed to the
    instructor-patched client so the reply comes back as an AnnotationResult.

    Args:
        input_text: The transcript that was given to the extraction model.
        output_text: The extraction produced for that transcript.

    Returns:
        The judge's AnnotationResult.
    """
    filled_prompt = annotation_prompt.format(
        medical_system_prompt=medical_system_prompt,
        medical_task=medical_task,
        input_text=input_text,
        output_text=output_text,
    )
    judge_messages = [
        {"role": "system", "content": annotation_system_prompt},
        {"role": "user", "content": filled_prompt},
    ]
    verdict = client.chat.completions.create(
        model="gpt-4o",
        messages=judge_messages,
        response_model=AnnotationResult,
    )
    return verdict

In [19]:
# One annotated example as built by generate_annotations, in positional order:
#   0. {"input": ...}   - text given to the extraction model
#   1. {"output": ...}  - the model's extraction
#   2. overall binary annotation (1 = correct, 0 = incorrect)
#   3. per-criterion scores as a plain dict; generate_annotations stores
#      MainCriteria.model_dump(), not the MainCriteria instance itself
#   4. annotator's note
#   5. human description of the task / judge
#   6. human description of the metric details
DataPoint = Tuple[
    dict,
    dict,
    Literal[0, 1],
    Dict[str, int],
    str,
    Optional[str],
    Optional[str],
]


@weave.op()
def generate_annotations(results: List[Dict]) -> List[DataPoint]:
    """Annotate every extraction result and package it as a data point.

    Args:
        results: Dicts with "input" and "output" keys, as returned by
            generate_medical_data.

    Returns:
        One tuple per result, in input order, holding the wrapped input, the
        wrapped output, the overall annotation, the per-criterion scores as
        a dict, the annotator's note, the task description and the metric
        description.
    """
    # Both descriptions are the same for every data point, so build them once
    # instead of on each loop iteration.
    combined_task_description = (
        f"System Prompt: {medical_system_prompt}\n\nTask: {medical_task}"
    )
    metric_details = "word count, presence of the six targeted keys, and absence of PII, with the first two implemented via code- based assertions and the last via an LLM evaluator"

    annotations = []

    for result in results:
        input_text = result["input"]
        output_text = result["output"]
        annotation_result = process_annotation(input_text, output_text)

        data_point: DataPoint = (
            {"input": input_text},  # input
            {"output": output_text},  # output
            annotation_result.annotation,  # annotation (1 for correct, 0 for incorrect)
            annotation_result.criteria_annotations.model_dump(),  # criteria_annotations
            annotation_result.note,  # note
            combined_task_description,  # human_description_for_task_or_judge
            metric_details,  # human_description_for_metric_details
        )

        annotations.append(data_point)

    return annotations

In [None]:
# Judge every extraction result; this issues one gpt-4o request per result.
annotations = generate_annotations(results)

In [None]:
# Spot-check the first annotated data point.
annotations[0]

In [None]:
# Publish the annotated data points to Weave under the name
# "medical_data_annotations".
weave.publish(annotations, name="medical_data_annotations")