## Data Preparation
The medical conversation data comes from https://huggingface.co/datasets/yfyeung/medical. The dataset consists of realistic conversations between a real doctor and simulated patients; see [the dataset paper](https://www.nature.com/articles/s41597-022-01423-1) for details.
 

In [None]:
!pip install -U openai
!pip install jinja2
!pip install datasets

In [1]:
import pandas as pd
import numpy as np
import os
import jinja2 as j2
import json, logging

In [2]:
GEMINI_PRO_MODEL = "models/gemini-2.5-pro-preview-05-06"
GEMINI_FLASH_MODEL = "models/gemini-2.5-flash-preview-04-17-thinking"

In [3]:
# load the Jinja template prompts/agent_prompt.j2 and render it into the system prompt
prompts_loader = j2.FileSystemLoader('prompts')
SYSTEM_PROMPT = j2.Environment(loader=prompts_loader).get_template('agent_prompt.j2').render(thinking_model=True)
DATA_DIR = "../../data/"

In [4]:
# iterate through the files in the data/raw directory
data_dir = f'{DATA_DIR}/raw'
files = os.listdir(data_dir)
# For each file, read it line by line and aggregate every 2 lines into 1 row.
# Each row is [conversation_id, doctor_response, patient_response]; conversation_id is the file name without ".txt".
# All files start with a doctor line, so even-indexed lines are doctor turns and odd-indexed lines are patient turns.
data = []
for i_file, file in enumerate(files):
    try:
        with open(os.path.join(data_dir, file), 'r') as f:
            lines = f.readlines()
            for i in range(0, len(lines), 2):
                doctor_response = lines[i].strip()
                patient_response = lines[i+1].strip() if i+1 < len(lines) else ''
                # the file name is the last part of the file path
                file_name = os.path.basename(file).replace(".txt", "")
                data.append([file_name, doctor_response, patient_response])
    except Exception as e:
        print(f"Error processing file {file}: {e}")
        continue
print(f"Processed {i_file} files; {len(data)} turns.")
# create a dataframe from the data
df = pd.DataFrame(data, columns=['conversation_id','doctor_response', 'patient_response'])
df.to_csv(f'{DATA_DIR}/processed/doctor_patient_conversations.csv', index=False)

Error processing file RES0054.txt: 'utf-8' codec can't decode byte 0xff in position 0: invalid start byte
Error processing file RES0002.txt: 'utf-8' codec can't decode byte 0xff in position 0: invalid start byte
Processed 271 files; 13288 turns.
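
The two failures above report byte 0xff at position 0, which is consistent with a UTF-16 byte-order mark, although the notebook does not confirm the actual encoding of those files. A minimal sketch of a reader that falls back to UTF-16 when UTF-8 decoding fails, assuming that guess is right; `read_lines_with_fallback` is a hypothetical helper, not part of the notebook, and the demo file contents are made up:

```python
import os
import tempfile

def read_lines_with_fallback(path, encodings=("utf-8", "utf-16")):
    """Return the stripped lines of a text file, trying each encoding in order."""
    last_error = None
    for encoding in encodings:
        try:
            with open(path, "r", encoding=encoding) as f:
                return [line.strip() for line in f]
        except UnicodeError as err:
            last_error = err  # remember the failure and try the next encoding
    raise ValueError(f"Could not decode {path} with any of {encodings}") from last_error

# Demo on a temporary UTF-16 file (assumed encoding, made-up contents).
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "demo.txt")
    with open(demo_path, "w", encoding="utf-16") as f:
        f.write("HOW MAY I HELP YOU\nI HAVE A SORE THROAT\n")
    demo_lines = read_lines_with_fallback(demo_path)

print(demo_lines)
```

Swapping the `open(...)`/`readlines()` pair in the loop above for a helper like this would let the skipped files be parsed, provided they really are UTF-16.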


In [5]:
df = pd.read_csv(f"{DATA_DIR}/processed/doctor_patient_conversations.csv")

df.head()

Unnamed: 0,conversation_id,doctor_response,patient_response
0,RES0181,HOW MAY I HELP YOU,HI UMM SO IVE HAD A SORE THROAT FOR THE PAST T...
1,RES0181,YEAH FOR SURE SO YOU SAID THAT THE SORE THROAT...,NEITHER ITS BEEN THE SAME
2,RES0181,OK IS IT PAINFUL TO SWALLOW FOOD OR LIQUIDS,ITS PAINFUL TO SWALLOW SOLIDS YEAH
3,RES0181,OK UH AND HOW IS YOUR UMM HAVE YOU HAD ANY VOI...,NO NOT REALLY NO
4,RES0181,OK AND HAVE YOU NOTICED ANY LIKE NECK SWELLING,NO


In [6]:
df.conversation_id.nunique(), df.shape[0]
n_training_examples = df.conversation_id.nunique()

In [7]:
# sample n_training_examples random IDs from the dataframe
def sample_ids(df, n_training_examples):
    # get the unique IDs in the dataframe
    ids = df['conversation_id'].unique()
    ids = np.random.choice(ids, n_training_examples, replace=False)
    # filter the dataframe to only include rows with the sampled IDs
    df_chosen_ids = df[df['conversation_id'].isin(ids)]
    return df_chosen_ids.conversation_id.unique()
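
`np.random.choice` is called without a seed, so each run of the notebook picks a different set of conversations. A minimal sketch of a reproducible variant follows. It swaps in the standard-library `random` module in place of NumPy, and both the function name `sample_ids_seeded` and the `seed` parameter are assumptions for illustration:

```python
import random

def sample_ids_seeded(conversation_ids, n_examples, seed=0):
    """Return n_examples distinct conversation IDs, reproducibly for a given seed."""
    # Deduplicate and sort so the result does not depend on row order.
    unique_ids = sorted(set(conversation_ids))
    if n_examples > len(unique_ids):
        raise ValueError("n_examples exceeds the number of unique conversation IDs")
    rng = random.Random(seed)  # local generator; the global random state is untouched
    return rng.sample(unique_ids, n_examples)

# Toy column with one ID per turn, shaped like df['conversation_id'].
ids_column = ["RES0181", "RES0181", "MSK0032", "GAS0002", "GAS0002", "RES0035"]
picked = sample_ids_seeded(ids_column, 2, seed=42)
print(picked)
```

Calling the function again with the same seed returns the same IDs, which makes a generated training set easier to regenerate or debug.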

In [15]:
from types import SimpleNamespace
config = SimpleNamespace()
config.conversation_length_mean = 20
config.conversation_length_sd = 10
config.diagnosis_model = GEMINI_PRO_MODEL
config.random_exit_prob = 0.1

n_training_examples = 35 #df.conversation_id.nunique()
ids = sample_ids(df, n_training_examples)

In [16]:
from training_data_helper import create_training_chat_conversations
from rubric_helper import create_rubrics
training_conversations, metadata = create_training_chat_conversations(df, ids, config, system_prompt=SYSTEM_PROMPT)
rubric = create_rubrics(training_conversations, prompts_loader=prompts_loader, metadata=metadata)
training_conversations, metadata


Processing conversation ID: MSK0032, predicted length: 25.872491365416288
Processing conversation ID: RES0194, predicted length: 29.12743891588255
Processing conversation ID: GAS0002, predicted length: 22.156833096359655
Processing conversation ID: RES0035, predicted length: 39.056842592020075
Processing conversation ID: MSK0024, predicted length: 31.68337279194011
Processing conversation ID: MSK0031, predicted length: 30.651495646739733
Processing conversation ID: RES0034, predicted length: 25.253453863851067
Processing conversation ID: RES0187, predicted length: 23.54635377916527
Processing conversation ID: RES0018, predicted length: 32.49900888530745
Processing conversation ID: RES0185, predicted length: 13.545358044961704
Processing conversation ID: MSK0050, predicted length: 18.47608679414702
Processing conversation ID: RES0120, predicted length: 16.463121398390914
Processing conversation ID: RES0123, predicted length: 15.475048278036905
Processing conversation ID: RES0132, predic

({'MSK0032_standard_0': [{'role': 'system',
    'content': '/think You are a helpful assistant that helps users with understanding the clinical diagnosis of their medical conditions.\nIn interacting with the patient try to obtain as much critical information as you can\n\n\n[Thought Instructions]\n\n* At each step determine the differential diagnosis. Do not share this with the user but use it to determine the next question to ask.\n* Identify the top question to ask the member to gain higher confidence on the most likely   diagnosis.\n\n\n[Response Instructions]\n\n* Ask the member 1 question at a time. Based on the diagnosis, pick the question that will provide you most information to narrow down the diagnosis.\n* Once you have asked ~25 questions, provide your best diagnosis and what extra information you might need to narrow down.\n\nExamples:\n---------\nUser: I have a cold and fever.\nAssistant: <think> A partial differential diagnosos for these symptoms: a. Cold b. Influenza c. 

In [37]:
modified_rubric_dicts = {}
for cid, irubric in rubric.items():
    # strip the markdown code fence that the model may wrap around its JSON
    cleaned_rubric_str = irubric.replace("```json\n", "").replace("```", "")
    rubric_dict, parsed_dict = {}, None
    try:
        parsed_dict = json.loads(cleaned_rubric_str)
    except json.JSONDecodeError:
        # not valid JSON; fall back to the raw text below
        pass
    if parsed_dict is not None:
        rubric_dict['think'] = parsed_dict['differential_diagnosis']
        rubric_dict['answer'] = parsed_dict['clarifying_questions']
    else:
        rubric_dict['think'] = []
        rubric_dict['answer'] = cleaned_rubric_str
    modified_rubric_dicts[cid] = json.dumps(rubric_dict)
# spot check: print the rubric of the last conversation processed
print(f"Modified rubric dicts: {modified_rubric_dicts[cid]}")

Modified rubric dicts: {"think": [], "answer": "Thanks for your patience so far. I would like to ask you a few questions to better understand your situation. Is that OK?"}
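
The output above shows the fallback path: that response was not valid JSON, so the raw text ended up in `answer`. Stripping the fence markers with `replace` also fails when the model wraps the JSON in extra prose. A slightly more tolerant sketch extracts the contents of the first fenced block with a regular expression and falls back to the raw text otherwise. `parse_rubric` is a hypothetical helper, and using `.get` with empty-list defaults for missing keys is an assumption, not the notebook's behavior:

```python
import json
import re

FENCE = "`" * 3  # markdown code-fence marker, built this way to keep the example readable
FENCE_RE = re.compile(FENCE + r"(?:json)?\s*(.*?)" + FENCE, re.DOTALL)

def parse_rubric(raw):
    """Map a raw model response to {'think': ..., 'answer': ...}."""
    match = FENCE_RE.search(raw)
    candidate = match.group(1) if match else raw
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return {
            "think": parsed.get("differential_diagnosis", []),
            "answer": parsed.get("clarifying_questions", []),
        }
    # Not a JSON object: keep the text so nothing is silently lost.
    return {"think": [], "answer": raw.strip()}

fenced = (
    "Here is the rubric:\n" + FENCE + "json\n"
    + '{"differential_diagnosis": ["flu"], "clarifying_questions": ["Any fever?"]}'
    + "\n" + FENCE
)
print(parse_rubric(fenced))
print(parse_rubric("Is that OK?"))
```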


In [38]:
## Create a huggingface dataset from the training conversations and the rubric
from datasets import Dataset

def create_dataset(training_conversations, metadata, rubric):
    # build one record per conversation; `rubric` maps conversation ID to a JSON string
    data = []
    for cid, training_conversation in training_conversations.items():
        data.append({
            "conversation_id": cid,
            "conversation": json.dumps(training_conversation),
            "rubric": rubric[cid],
            "metadata": metadata[cid],
        })
    # create a huggingface dataset from the data
    dataset = Dataset.from_list(data)
    return dataset


# pass the cleaned rubrics explicitly instead of reading the global inside the function
dataset = create_dataset(training_conversations, metadata, modified_rubric_dicts)
# save_to_disk writes an Arrow dataset directory (not a JSONL file), despite the ".jsonl" suffix in the path
dataset.save_to_disk(f"{DATA_DIR}/training/dataset_multi_state_{n_training_examples}.jsonl")

Saving the dataset (0/1 shards):   0%|          | 0/105 [00:00<?, ? examples/s]

In [241]:
# write the conversations to a jsonl file
import json
with open(f"data/training_conversations_{n_training_examples}.jsonl", "w") as f:
    for cid, conversation in training_conversations.items():
        line = {'cid': cid, 'conversation': conversation}
        f.write(json.dumps(line) + "\n")

In [242]:
# Notes on the rubric system prompt. The prompt states that:
# 1) the data consists of a conversation between a doctor and a patient; the doctor is represented by role "assistant" and the patient by role "user".
# 2) the LLM should
#    * create a differential diagnosis based on the conversation history.
#    * generate the list of all the clarifying questions to ask the patient to narrow down the diagnosis.
# 3) the result for the conversation is returned in JSON format with the following keys:
#    * "differential_diagnosis" - a list; contains possible diagnoses
#    * "clarifying_questions" - a list; contains all the clarifying questions to ask the patient
# write the rubric to a jsonl file, one {conversation_id: rubric} object per line
with open(f"data/rubric_{n_training_examples}.jsonl", "w") as f:
    for key, value in rubric.items():
        f.write(json.dumps({key: value}) + "\n")
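
To check the files written above, the JSONL can be read back one record per line. A minimal sketch, assuming each line is a standalone JSON object as written by the cells above; `read_jsonl` is a hypothetical helper and the demo records are made up:

```python
import json
import os
import tempfile

def read_jsonl(path):
    """Parse a JSONL file into a list of Python objects, skipping blank lines."""
    records = []
    with open(path, "r", encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as err:
                raise ValueError(f"{path}:{line_number}: invalid JSON ({err})") from err
    return records

# Round-trip demo using the same one-object-per-line layout as the rubric file.
demo_rubric = {"conv_a": "rubric text a", "conv_b": "rubric text b"}
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "rubric_demo.jsonl")
    with open(demo_path, "w", encoding="utf-8") as f:
        for key, value in demo_rubric.items():
            f.write(json.dumps({key: value}) + "\n")
    loaded = read_jsonl(demo_path)

# Merge the single-key records back into one dict.
restored = {key: value for record in loaded for key, value in record.items()}
print(restored == demo_rubric)
```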

In [112]:
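# list the models exposed by the Gemini OpenAI-compatible endpoint (OPENAI_API_KEY must hold a Gemini API key here)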
from openai import OpenAI
client = OpenAI(
    api_key=os.environ["OPENAI_API_KEY"],
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
)
models = client.models.list()
for m in models:
    print(m.id)

models/chat-bison-001
models/text-bison-001
models/embedding-gecko-001
models/gemini-1.0-pro-vision-latest
models/gemini-pro-vision
models/gemini-1.5-pro-latest
models/gemini-1.5-pro-001
models/gemini-1.5-pro-002
models/gemini-1.5-pro
models/gemini-1.5-flash-latest
models/gemini-1.5-flash-001
models/gemini-1.5-flash-001-tuning
models/gemini-1.5-flash
models/gemini-1.5-flash-002
models/gemini-1.5-flash-8b
models/gemini-1.5-flash-8b-001
models/gemini-1.5-flash-8b-latest
models/gemini-1.5-flash-8b-exp-0827
models/gemini-1.5-flash-8b-exp-0924
models/gemini-2.5-pro-exp-03-25
models/gemini-2.5-pro-preview-03-25
models/gemini-2.5-flash-preview-04-17
models/gemini-2.5-flash-preview-04-17-thinking
models/gemini-2.5-pro-preview-05-06
models/gemini-2.0-flash-exp
models/gemini-2.0-flash
models/gemini-2.0-flash-001
models/gemini-2.0-flash-exp-image-generation
models/gemini-2.0-flash-lite-001
models/gemini-2.0-flash-lite
models/gemini-2.0-flash-preview-image-generation
models/gemini-2.0-flash-lite-p
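
The listing can be used to confirm that the model IDs configured at the top of the notebook are actually served. A minimal sketch of that check as a pure function, so it runs without network access; `missing_models` is a hypothetical helper, and the `available` list below is a hand-copied subset of the output above rather than a live API response:

```python
def missing_models(configured, available):
    """Return the configured model IDs that are absent from the available IDs."""
    available_set = set(available)
    return [model_id for model_id in configured if model_id not in available_set]

configured = [
    "models/gemini-2.5-pro-preview-05-06",
    "models/gemini-2.5-flash-preview-04-17-thinking",
]
available = [
    "models/gemini-2.5-flash-preview-04-17",
    "models/gemini-2.5-flash-preview-04-17-thinking",
    "models/gemini-2.5-pro-preview-05-06",
    "models/gemini-2.0-flash",
]
print(missing_models(configured, available))
```

In the notebook itself, `available` could be built as `[m.id for m in client.models.list()]`, and a non-empty result would flag a preview model that is no longer offered.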