# Conversation Closing Detector

In [49]:
import csv
import pandas as pd
import json
from dotenv import load_dotenv
from langchain import OpenAI, PromptTemplate, LLMChain
from langchain_openai import ChatOpenAI

In [50]:
load_dotenv("../.env")

True

## Load Transcripts

In [51]:
TRANSCRIPT_PATH = "../data/patients/patients_1.0_with_transcripts.json"

In [52]:
with open(TRANSCRIPT_PATH, "r") as file:
    data = json.load(file)

In [53]:
transcripts = {patient['id']: patient['chat_transcript'] for patient in data.values()}

In [54]:
# Function to extract segments where the conversation starts with 'Doctor' then 'Patient'
def extract_segments(data):
    segments = []

    for patient_id, patient_data in data.items():
        transcript = patient_data["chat_transcript"]
        for i in range(1, len(transcript)):
            ai_output = transcript[i - 1]
            user_input = transcript[i]

            # Check if ai_output starts with 'Doctor' and user_input starts with 'Patient'
            if ai_output.startswith("Doctor: ") and user_input.startswith("Patient: "):
                # Remove the 'Doctor: ' and 'Patient: ' prefixes
                ai_output_clean = ai_output[len("Doctor: "):]
                user_input_clean = user_input[len("Patient: "):]

                segments.append({
                    'patient_id': patient_id,
                    'ai_output': ai_output_clean,
                    'user_input': user_input_clean
                })
    
    return segments

# Extract segments
segments = extract_segments(data)

# Save segments to CSV for manual labeling
csv_file_path = '../data/interim/transcript_segments_for_labeling.csv'

with open(csv_file_path, 'w', newline='') as csvfile:
    fieldnames = ['patient_id', 'ai_output', 'user_input']
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

    writer.writeheader()
    for segment in segments:
        writer.writerow(segment)

print(f"Segments have been extracted and saved to {csv_file_path}")

Segments have been extracted and saved to ../data/interim/transcript_segments_for_labeling.csv


Note that the above CSV needs to me manually labeled

In [55]:
# Load the CSV file
csv_file_path = '../data/interim/transcript_segments_labeled.csv'
labeled_segments = pd.read_csv(csv_file_path)

## Define Conversation Closure Detection Model

In [56]:
# Define the function to check conversation closing
model = ChatOpenAI(temperature=0.7, model_name="gpt-3.5-turbo")

conversation_closing_prompt_template = """
You are analyzing a conversation between a doctor and a patient. Based on the last user input and the previous AI output, determine if the conversation is coming to a close. Respond with 'True' or 'False'.

Examples:

Example 1:
AI Output: Based on our conversation, Kevin, it seems you are mainly experiencing tiredness and leg swelling, and you are currently taking Furosemide, Spironolactone, and fish oil for your heart condition. Is there anything else you would like to share regarding your symptoms, vital signs, or medications?
User Input: No, Doctor, I think that covers everything for now. Thank you for checking in on me.
Response: True

Example 2:
AI Output: Based on our conversation, Kevin, it seems like you are mainly experiencing tiredness and leg swelling. Could you please provide your latest vital signs, starting with your temperature?
User Input: My temperature is 97.4 degrees, Doctor.
Response: False

Now, analyze the following conversation and determine if it is coming to a close.

AI Output:
{ai_output}

User Input:
{user_input}
"""

prompt = PromptTemplate(
    input_variables=["ai_output", "user_input"],
    template=conversation_closing_prompt_template,
)

llm_chain = LLMChain(llm=model, prompt=prompt)

def check_conversation_closing(ai_output, user_input):
    # Run the LLM chain with the AI output and user input
    result = llm_chain.run(ai_output=ai_output, user_input=user_input)
    
    return result.strip().lower() == 'true'

# Prepare the data for testing
y_true = labeled_segments['end'].tolist()
y_pred = []

for _, row in labeled_segments.iterrows():
    ai_output = row['ai_output']
    user_input = row['user_input']
    prediction = check_conversation_closing(ai_output, user_input)
    y_pred.append(int(prediction))

  warn_deprecated(


In [59]:
from sklearn.metrics import accuracy_score, confusion_matrix

# Calculate accuracy and confusion matrix
accuracy = accuracy_score(y_true, y_pred)
conf_matrix = confusion_matrix(y_true, y_pred)

print(f"Accuracy: {accuracy}")
print(f"Confusion Matrix:\n{conf_matrix}")

Accuracy: 0.878
Confusion Matrix:
[[423  61]
 [  0  16]]


In [63]:
pd.DataFrame(conf_matrix)

Unnamed: 0,0,1
0,423,61
1,0,16


In [66]:
# Print out cases where y_true and y_pred are not the same
for index, (true_label, pred_label) in enumerate(zip(y_true, y_pred)):
    if true_label != pred_label:
        print(f"Index: {index}")
        print(f"Patient ID: {labeled_segments.loc[index, 'patient_id']}")
        print(f"AI Output: {labeled_segments.loc[index, 'ai_output']}")
        print(f"User Input: {labeled_segments.loc[index, 'user_input']}")
        print(f"True Label: {true_label}, Predicted Label: {pred_label}")
        print("---")

Index: 14
Patient ID: 12305811
AI Output: Thank you for sharing your current medications, Kevin. Is there any other medication you are taking for your heart condition or any other health issue?
User Input: No, Doctor, those are the main ones for my heart. I also take some fish oil for general health.
True Label: 0, Predicted Label: 1
---
Index: 33
Patient ID: 14185111
AI Output: Thank you for sharing your oxygen saturation level. Lastly, could you please provide your blood pressure for today?
User Input: My blood pressure today is 123/56.
True Label: 0, Predicted Label: 1
---
Index: 35
Patient ID: 14185111
AI Output: Based on your responses, it seems like you're experiencing ankle swelling and the need to prop yourself up at night to breathe comfortably. Let's continue with your medications. Besides beta-blockers and diuretics, are you taking any other medications currently?
User Input: No, those are the main ones I'm taking right now. Is there anything specific you want to know about 

It looks like this isn't going to be good enough