In [1]:
import openai
import pandas as pd
from typing import List, Dict, Tuple, Literal, Optional
from pydantic import BaseModel, Field
import weave
import instructor
from set_env import set_env

In [2]:
# Load the API keys needed by the OpenAI and Weights & Biases clients used below.
# NOTE(review): set_env is a project-local helper; judging by this cell's output it
# looks in Colab/Kaggle secrets and .env files — confirm against its source.
set_env("OPENAI_API_KEY")
set_env("WANDB_API_KEY")

        Unable to set WANDB_API_KEY=WANDB_API_KEY,
        not in colab or Secrets not set, not kaggle
        or Secrets not set, no .env/dotenv/env file
        in the current working dir or parent dirs.[0m


loading envfile='/Users/anishshah/Documents/Manual Library/GitHub(1)/improve-evals/.env' with dotenv_values(envfile)


In [3]:
# Initialise Weave for the "medical_data_results" project; the call links printed by
# later cells (for the @weave.op functions) point at this project.
weave.init("medical_data_results")

Logged in as Weights & Biases user: a-sh0ts.
View Weave data at https://wandb.ai/a-sh0ts/medical_data_results/weave


<weave.weave_client.WeaveClient at 0x168ed70b0>

In [4]:
# Number of dataset records to sample and send through the pipeline.
# It is used as the default `num_samples` of load_medical_data and generate_medical_data;
# Python binds defaults when the `def` runs, so re-run those cells after changing it.
N_SAMPLES = 3

In [5]:
# OpenAI client used by process_medical_record for the extraction calls.
# NOTE: the name `client` is rebound further down to an instructor-patched client.
# Presumably picks up OPENAI_API_KEY from the environment — confirm.
client = openai.OpenAI()

# User-prompt template for the extraction task. `{transcript}` is the only placeholder
# and is filled with str.format in process_medical_record.
medical_task = """
You are extracting insights from some medical records.
The records contain a medical note and a
dialogue between a doctor and a patient. You need
to extract values for the following: Chief
complaint, History of present illness, Physical
examination, symptoms experienced by the patient,
New medications prescribed or changed, including
dosages (N/A if not provided), and Follow-up
instructions (N/A if not provided). Your answer
should not include any personal identifiable
information (PII) such as name, age, gender, or
ID. Use "the patient" instead of their name, for
example. Return your answer as a bullet list,
where each bullet is formatted like •chief
complaint: xx. If there is no value for the key,
the value should be N/A. Keep your response
around 150 words (you may have to summarize some
extracted values to stay within the word limit).
{transcript}
"""

# System message sent with every extraction request (see process_medical_record).
medical_system_prompt = """
You are a medical data extraction AI assistant. Your task is to accurately extract and summarize key medical information from patient records, adhering strictly to privacy guidelines and formatting instructions provided in the user's prompt. Focus on relevance and conciseness while ensuring all required fields are addressed.
"""

# CSV that load_medical_data reads; the code below uses its 'dialogue' and 'note' columns.
medical_dataset_url = "https://raw.githubusercontent.com/wyim/aci-bench/main/data/challenge_data/train.csv"


In [6]:
def load_medical_data(
    url: str,
    num_samples: int = N_SAMPLES,
    random_state: Optional[int] = 42,
) -> List[Dict]:
    """Read the CSV at ``url`` and return a random subset of its rows as dicts.

    Args:
        url: Location of the CSV; passed unchanged to ``pd.read_csv``.
        num_samples: How many rows to draw (without replacement). The default is the
            value ``N_SAMPLES`` had when this cell was executed.
        random_state: Seed forwarded to ``DataFrame.sample``. The default of 42
            reproduces the previously hard-coded seed; pass ``None`` for a fresh
            random draw on every call.

    Returns:
        One dict per sampled row, keyed by column name.

    Raises:
        ValueError: If ``num_samples`` is larger than the number of rows in the CSV.
    """
    df = pd.read_csv(url)

    # pandas would also raise ValueError here, but without saying how large the
    # dataset actually is; fail early with a message that names both numbers.
    if num_samples > len(df):
        raise ValueError(
            f"Requested {num_samples} samples but the dataset at {url!r} "
            f"only has {len(df)} rows"
        )

    sampled_rows = df.sample(n=num_samples, random_state=random_state)
    return sampled_rows.to_dict('records')

In [7]:
# Sample N_SAMPLES records for inspection. In the cells shown, `samples` is only used
# for the preview below; generate_medical_data() loads its own copy of the data.
samples = load_medical_data(medical_dataset_url)

In [8]:
# Preview the first sampled record (a dict of the CSV columns for one row).
samples[0]

{'dataset': 'aci',
 'encounter_id': 'D2N037',
 'dialogue': "[doctor] hey dylan what's going on so i lift quite a bit of weights i try to stay in shape as much as i can i'm not like normal people i lift heavy weights and my elbow is extremely sore which elbow is it\n[patient] actually it's both my elbows but my right elbow is hurting me the most\n[doctor] okay and you said you lift a lot of weights\n[patient] mm-hmm\n[doctor] did you play any sports when you were younger\n[patient] no anything you can think of primarily it was basketball baseball and football\n[doctor] okay and did your elbows hurt at that time or is this a a new injury\n[patient] it's new\n[doctor] when did it start\n[patient] probably year and a half ago\n[doctor] okay on both elbows about a year and a half ago\n[patient] yeah\n[doctor] okay have you taken anything for the pain\n[patient] ibuprofen eight hundred milligrams three times a day\n[doctor] okay and does anything make it better or worse\n[patient] the more i

In [9]:
def format_transcript(record):
    """Build the single text blob that is inserted into the extraction prompt.

    Newlines inside the record's 'dialogue' and 'note' fields are replaced with
    spaces, and the two flattened texts are joined under "Dialogue:" and
    "Medical Note:" headings separated by a blank line.

    Args:
        record: Mapping with string values under the keys 'dialogue' and 'note'.

    Returns:
        The combined transcript string.
    """
    flat_dialogue, flat_note = (
        record[field].replace('\n', ' ') for field in ('dialogue', 'note')
    )
    return f"Dialogue: {flat_dialogue}\n\nMedical Note: {flat_note}"

@weave.op()
def process_medical_record(record: Dict) -> Dict:
    """Run the extraction prompt on one dataset record.

    The record is flattened with format_transcript, inserted into the
    ``medical_task`` template, and sent to gpt-4o together with
    ``medical_system_prompt`` through the module-level ``client``.

    Args:
        record: One dataset row; must provide the keys format_transcript reads.

    Returns:
        A dict with "input" (the formatted transcript) and "output" (the text
        content of the first completion choice).
    """
    transcript = format_transcript(record)

    chat_messages = [
        {"role": "system", "content": medical_system_prompt},
        {"role": "user", "content": medical_task.format(transcript=transcript)},
    ]
    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=chat_messages,
    )

    return {
        "input": transcript,
        "output": completion.choices[0].message.content,
    }

@weave.op()
def generate_medical_data(num_samples: int = N_SAMPLES) -> List[Dict]:
    """Sample records from the medical dataset and run the extraction on each.

    Args:
        num_samples: Number of records to load via load_medical_data. The default
            is the value ``N_SAMPLES`` had when this cell was executed.

    Returns:
        The process_medical_record result for every sampled record, in sample order.
    """
    return [
        process_medical_record(raw_record)
        for raw_record in load_medical_data(medical_dataset_url, num_samples)
    ]

In [10]:
# Runs the extraction pipeline: one gpt-4o chat completion per sampled record.
results = generate_medical_data()

🍩 https://wandb.ai/a-sh0ts/medical_data_results/r/call/0191e175-3379-76a2-9067-72a010f27202


In [11]:
# Preview the first two {"input", "output"} pairs produced above.
results[0:2]

[{'input': "Dialogue: [doctor] hey dylan what's going on so i lift quite a bit of weights i try to stay in shape as much as i can i'm not like normal people i lift heavy weights and my elbow is extremely sore which elbow is it [patient] actually it's both my elbows but my right elbow is hurting me the most [doctor] okay and you said you lift a lot of weights [patient] mm-hmm [doctor] did you play any sports when you were younger [patient] no anything you can think of primarily it was basketball baseball and football [doctor] okay and did your elbows hurt at that time or is this a a new injury [patient] it's new [doctor] when did it start [patient] probably year and a half ago [doctor] okay on both elbows about a year and a half ago [patient] yeah [doctor] okay have you taken anything for the pain [patient] ibuprofen eight hundred milligrams three times a day [doctor] okay and does anything make it better or worse [patient] the more i use my hands or my arms the more it hurts [doctor] o

In [12]:
# Rebinds the global `client` to an instructor-patched OpenAI client so that
# process_annotation can pass `response_model=...`.
# NOTE: process_medical_record looks up `client` at call time, so once this cell has
# run it will use the patched client as well.
client = instructor.patch(openai.OpenAI())

In [13]:
# Structured verdict for one extraction; passed as `response_model` in process_annotation.
# Documented with comments on purpose: pydantic appears to include a model's docstring
# in its JSON schema, which would change what the judge model is sent — TODO confirm.
class AnnotationResult(BaseModel):
    # Pass/fail flag: 1 only when every criterion is met, otherwise 0.
    annotation: Literal[0, 1] = Field(
        description="Binary score: 1 if the extraction meets all criteria, 0 if it fails on any"
    )
    # Short free-text justification for the flag above.
    note: str = Field(
        description="Brief explanation of the annotation decision, highlighting any issues or exemplary aspects"
    )

# User-prompt template for the judge. process_annotation fills the four placeholders
# {medical_system_prompt}, {medical_task}, {input_text} and {output_text} via str.format.
# The inserted medical_task still contains its own literal "{transcript}" placeholder,
# because str.format does not re-expand substituted values.
annotation_prompt = """
    Review the following medical data extraction task results:

    Task System Prompt:
    {medical_system_prompt}

    Task:
    {medical_task}

    Input:
    {input_text}

    Output:
    {output_text}

    Evaluate the extraction based on these criteria:
    1. Completeness: All required fields addressed (Chief complaint, History of present illness, Physical examination, Symptoms, New medications with dosages, Follow-up instructions)
    2. Accuracy: Information correctly extracted from input
    3. Format: Proper bullet list format used (•key: value)
    4. Privacy: No personal identifiable information (PII) included
    5. Conciseness: ~150 words, key information summarized
    6. Use of "N/A" for missing information

    Provide:
    1. Annotation: 1 if the extraction meets all criteria, 0 if it fails on any
    2. Note: Brief explanation of your decision, highlighting any issues or exemplary aspects
"""

# System message sent with every judge request (see process_annotation).
annotation_system_prompt = """
You are an AI assistant tasked with evaluating medical data extraction results.
"""

In [14]:
@weave.op()
def process_annotation(input_text: str, output_text: str) -> AnnotationResult:
    """Have gpt-4o judge a single extraction against the task criteria.

    The judge prompt is built from ``annotation_prompt`` using the extraction
    task's own system prompt and task text plus the given input/output pair.

    Args:
        input_text: The formatted transcript that was given to the extractor.
        output_text: The extractor's response for that transcript.

    Returns:
        Whatever the module-level ``client`` returns for
        ``response_model=AnnotationResult`` (annotated as an AnnotationResult).
    """
    judge_messages = [
        {"role": "system", "content": annotation_system_prompt},
        {
            "role": "user",
            "content": annotation_prompt.format(
                medical_system_prompt=medical_system_prompt,
                medical_task=medical_task,
                input_text=input_text,
                output_text=output_text,
            ),
        },
    ]

    return client.chat.completions.create(
        model="gpt-4o",
        messages=judge_messages,
        response_model=AnnotationResult,
    )

In [15]:
# One annotated example, as built in generate_annotations. Positions:
#   0: {"input": ...}   1: {"output": ...}   2: annotation (1 correct / 0 incorrect)
#   3: judge's note     4: task description  5: metric details
DataPoint = Tuple[dict, dict, Literal[0, 1], str, Optional[str], Optional[str]]

@weave.op()
def generate_annotations(results: List[Dict]) -> List[DataPoint]:
    """Judge every extraction result and package it as a DataPoint tuple.

    Args:
        results: Dicts with "input" and "output" keys, as returned by
            generate_medical_data.

    Returns:
        One DataPoint per result, in the same order: the input and output wrapped
        in single-key dicts, the judge's 0/1 annotation and note, the combined
        task description, and a fixed description of the metric details.
    """
    # These two strings are identical for every data point, so name them once.
    task_description = f"System Prompt: {medical_system_prompt}\n\nTask: {medical_task}"
    metric_details = "word count, presence of the six targeted keys, and absence of PII, with the first two implemented via code- based assertions and the last via an LLM evaluator"

    data_points = []
    for item in results:
        source_text, extracted_text = item["input"], item["output"]
        verdict = process_annotation(source_text, extracted_text)
        data_points.append(
            (
                {"input": source_text},
                {"output": extracted_text},
                verdict.annotation,  # 1 for correct, 0 for incorrect
                verdict.note,
                task_description,
                metric_details,
            )
        )

    return data_points

In [16]:
# Runs the judge: one additional gpt-4o call per extraction result.
annotations = generate_annotations(results)

🍩 https://wandb.ai/a-sh0ts/medical_data_results/r/call/0191e175-6375-70c0-b091-9c4a90970f8d


In [17]:
# Preview the first DataPoint tuple.
annotations[0]

({'input': "Dialogue: [doctor] hey dylan what's going on so i lift quite a bit of weights i try to stay in shape as much as i can i'm not like normal people i lift heavy weights and my elbow is extremely sore which elbow is it [patient] actually it's both my elbows but my right elbow is hurting me the most [doctor] okay and you said you lift a lot of weights [patient] mm-hmm [doctor] did you play any sports when you were younger [patient] no anything you can think of primarily it was basketball baseball and football [doctor] okay and did your elbows hurt at that time or is this a a new injury [patient] it's new [doctor] when did it start [patient] probably year and a half ago [doctor] okay on both elbows about a year and a half ago [patient] yeah [doctor] okay have you taken anything for the pain [patient] ibuprofen eight hundred milligrams three times a day [doctor] okay and does anything make it better or worse [patient] the more i use my hands or my arms the more it hurts [doctor] o

In [18]:
# Publish the annotated data points to Weave as the object "medical_data_annotations";
# the returned ObjectRef identifies the stored version.
weave.publish(annotations, name="medical_data_annotations")

📦 Published to https://wandb.ai/a-sh0ts/medical_data_results/weave/objects/medical_data_annotations/versions/4utHXhRnO2oquowlrvJxCztretSxtFavBUHVciMZJBw


ObjectRef(entity='a-sh0ts', project='medical_data_results', name='medical_data_annotations', digest='4utHXhRnO2oquowlrvJxCztretSxtFavBUHVciMZJBw', extra=())