In [17]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import KFold
from skmultilearn.model_selection import iterative_train_test_split
from sklearn.metrics import f1_score,  precision_score, recall_score, hamming_loss

import openai
import json

from dotenv import load_dotenv
load_dotenv()

True

## **Load Cleaned Data**

In [2]:
# Load the cleaned datasets
df_user_inputs = pd.read_csv('../dataset/user_inputs_cleaned.csv')
df_labels = pd.read_csv('../dataset//labels_cleaned.csv')

# Remove unnecessary index columns
df_user_inputs.drop(df_user_inputs.columns[0], axis=1, inplace=True)
df_labels.drop(df_labels.columns[0], axis=1, inplace=True)

# Ensure alignment
assert len(df_labels) == len(df_user_inputs), "Datasets do not align!"

print(df_user_inputs.shape)
df_user_inputs.head(10)

(3974, 1)


Unnamed: 0,text
0,er is een teek op mijn been ik ben bang dat di...
1,er is een teek op mijn rug en ik krijg hem er ...
2,op mijn been zit een teek ik heb hem geprobeer...
3,ik heb allergieen
4,huid
5,roodheid
6,schilfering
7,ik heb wratten onder mijn voet
8,ik heb gisteren naar het bos geweest en zie nu...
9,ik voelde iets prikken


In [3]:
df_labels.head()

Unnamed: 0,"Niet lekker voelen, algehele malaise",Beenklachten,Bloedneus,Misselijkheid en overgeven,Brandwond,Buikpijn,Suikerziekte (ontregeld),Diarree,Duizeligheid,Gebitsklachten,...,Coronavirus,Knieklachten,Liesklachten,Elleboogklachten,Schouderklachten,Oorsuizen,Hand- en polsklachten,Enkelklachten,Dikke enkels of voeten,Vingerklachten
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [4]:
# Determine model baseline
label_frequencies = df_labels.sum().sort_values(ascending=False)
label_frequencies / df_labels.shape[0]

Huidklachten                            0.088827
Beenklachten                            0.072723
Buikpijn                                0.062154
Oorklachten                             0.052843
Misselijkheid en overgeven              0.037242
                                          ...   
Liesklachten                            0.005033
Tekenbeet                               0.005033
Verdrinking                             0.005033
Verwonding aan de buik                  0.005033
Niet lekker voelen, algehele malaise    0.002516
Length: 74, dtype: float64

We see that if we predict every time the label with the highest frequency (Huidklachten), our model will be correct around 9% of the time. We want our model to perform at least better than this 9% threshold

## **Prepare Data to Model Format**

To accomodate the nature of multi-label classification. Instead of using the traditional method `train_test_split`, we employ iterative stratified sampling `iterative_train_test_split`, to provide a well-balanced distribution of all label combinations in both training and test sets.

In [5]:
## Split data to train:val:test

# Prepare data for iterative train test split
# X must be 2D np.ndarray and y must be 2D binary np.ndarray
X_texts = df_user_inputs['text'].values
X_texts = X_texts.reshape(-1, 1)
y = df_labels.values

# Split the data 60:20:20 with multi-label stratification
train_texts, y_train, test_texts, y_test = iterative_train_test_split(X_texts, y, test_size = 0.2)
#val_texts, y_val, test_texts, y_test = iterative_train_test_split(tmp_texts, y_tmp, test_size = 0.5)

# Sanity checks to confirm the shapes of the datasets
assert train_texts.shape[0] == y_train.shape[0], "Mismatch in train data and labels"
assert test_texts.shape[0] == y_test.shape[0], "Mismatch in test data and labels"

train_texts, test_texts = train_texts.ravel(), test_texts.ravel()
#val_texts = val_texts.ravel()

print(train_texts.shape, y_train.shape, test_texts.shape)
train_texts

(3175,) (3175, 74) (799,)


array(['er is een teek op mijn been ik ben bang dat die er al een tijdje op heeft gezeten',
       'huid', 'roodheid', ...,
       'vannacht met slapen denk ik gekke beweging gemaakt want mn nek is nu helemaal stijf kan niet meer naar rechts kijken',
       'heb al langere tijd pijn in mn nek krijg dan soms tintelingen over mijn arm heb dan ook minder kracht in mijn arm',
       'doet zeer als ik mn hoofd beweeg'], dtype=object)

## **Model Building: GPT-4**

### **A. Data & Model Preparation**

In [117]:
# Load the API key from an environment variable
openai_api_key = os.getenv('OPENAI_API_KEY')

def classify_medical_text(user_input):
    """
    Classify a medical text into possible chief complaints using GPT-4.

    Args:
    user_input (str): The medical text to classify.
    model (str): The model identifier for OpenAI API.

    Returns:
    dict: Top chief complaints with confidence scores.
    """
    # Define the prompt
    prompt = """Classify the user medical situation (in Dutch), as having the following medical chief complaints (separated by ;):
              "Niet lekker voelen, algehele malaise;Beenklachten;Bloedneus;Misselijkheid en overgeven;Brandwond;Buikpijn;Suikerziekte (ontregeld);Diarree;Duizeligheid;Gebitsklachten;Geslachtsorgaanklachten;Hartkloppingen, overslaan van het hart;Hoesten;Hoofdpijn;Eczeem;Epileptische aanval of insult;Vergiftiging of intoxicatie;Keelklachten;Koorts bij volwassenen;Kortademigheid;Nekklachten;Neurologische uitval, klachten van de zenuwbanen;Verstopping, obstipatie;Oogklachten of beschadigingen aan het oog;Oorklachten;Pijn op de borst of pijn in de borstkas;Anusklachten;Rugpijn;Verwonding aan armen of benen;Verwonding aan de buik;Verwonding aan de rug;Verwonding aan de borstkas;Urinewegproblemen;Vaginaal bloedverlies;Verdrinking;Flauwvallen, wegraking;Zwangerschap en bevalling;Pijnlijke mondhoeken of kloofjes in de mondhoeken;Stemklachten of heesheid;Slechter horen;Verkoudheid;Hooikoorts of neusklachten;Huidschimmel;Voetblaar;Voetwrat;Ouderdomswrat;Acne, jeugdpuistjes;Pijnlijke teen door stoten;Gewrichtsklachten;Ooglidklachten;Vaginale afscheiding;Vaginale schimmelinfectie;No complaint;Vragen over drain, sonde of katheter;Huidklachten;Tekenbeet;Borstontsteking ;Allergische reactie;Insectenbeet;Armklachten;Wonden of bijtwonden;Verwonding aan de nek;Verwonding aan het hoofd en/of het gezicht;Voorwerp in lichaam;ICD (apparaat voor hartritmestoornissen);Coronavirus;Knieklachten;Liesklachten;Elleboogklachten;Schouderklachten;Oorsuizen;Hand- en polsklachten;Enkelklachten;Dikke enkels of voeten;Vingerklachten"
              Please return in JSON the top five most possible chief complaints along with confidence scores for each, with key fields 'chief_complaint' and 'conf'.
              Output format as follows:
                [{'chief_complaint': ___, 'conf': ___}, ... 5 times]
              """
    model_name = "gpt-4-1106-preview"

    # Get the API key from an environment variable
    openai.api_key = openai_api_key
    

    # Make the request to the OpenAI API
    response = openai.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": user_input},
        ],
        temperature=0
    )

    # Extracting and cleaning the response
    try:
        chief_complaints = response.choices[0].message.content

        # Remove placeholders and strip whitespace
        chief_complaints = chief_complaints.replace('json', '').replace('`', '').strip()

        # Parse the cleaned response into JSON
        chief_complaints = json.loads(chief_complaints)
    except json.JSONDecodeError:
        chief_complaints = {"error": "Response parsing failed"}

    return chief_complaints

user_input = "Goedendag Ik heb last van een droge eikel en kleine scheurtjes in de voorhuid ze zijn niet diep maar voelt heel irriterent aan is er iets wat ik kan doen om het te verhelpen zo der daarvoor naar een dokter te moeten?"
chief_complaints = classify_medical_text(user_input)
chief_complaints

{'top_five_chief_complaints': [{'chief_complaint': 'Geslachtsorgaanklachten',
   'conf': 0.85},
  {'chief_complaint': 'Vaginale schimmelinfectie', 'conf': 0.75},
  {'chief_complaint': 'Huidklachten', 'conf': 0.7},
  {'chief_complaint': 'Eczeem', 'conf': 0.65},
  {'chief_complaint': 'Vaginale afscheiding', 'conf': 0.6}]}

### **B. Evaluate model**

In [55]:
# Create a mapping of chief complaints to their indices
complaints_to_index = {complaint: index for index, complaint in enumerate(df_labels.columns)}
complaints_to_index

{'Niet lekker voelen, algehele malaise': 0,
 'Beenklachten': 1,
 'Bloedneus': 2,
 'Misselijkheid en overgeven': 3,
 'Brandwond': 4,
 'Buikpijn': 5,
 'Suikerziekte (ontregeld)': 6,
 'Diarree': 7,
 'Duizeligheid': 8,
 'Gebitsklachten': 9,
 'Geslachtsorgaanklachten': 10,
 'Hartkloppingen, overslaan van het hart': 11,
 'Hoesten': 12,
 'Hoofdpijn': 13,
 'Eczeem': 14,
 'Epileptische aanval of insult': 15,
 'Vergiftiging of intoxicatie': 16,
 'Keelklachten': 17,
 'Koorts bij volwassenen': 18,
 'Kortademigheid': 19,
 'Nekklachten': 20,
 'Neurologische uitval, klachten van de zenuwbanen': 21,
 'Verstopping, obstipatie': 22,
 'Oogklachten of beschadigingen aan het oog': 23,
 'Oorklachten': 24,
 'Pijn op de borst of pijn in de borstkas': 25,
 'Anusklachten': 26,
 'Rugpijn': 27,
 'Verwonding aan armen of benen': 28,
 'Verwonding aan de buik': 29,
 'Verwonding aan de rug': 30,
 'Verwonding aan de borstkas': 31,
 'Urinewegproblemen': 32,
 'Vaginaal bloedverlies': 33,
 'Verdrinking': 34,
 'Flauwvalle

In [111]:
# Step 1: Predictions Generation
predicted_labels = []
for i, text in enumerate(test_texts):
    print(i)
    predicted_complaints = classify_medical_text(text)

    # Step 2: Prediction Parsing
    # Check for error in response and skip to the next iteration if found
    if isinstance(predicted_complaints, dict) and predicted_complaints.get('error') == 'Response parsing failed':
        continue
    # Check if the result is a dictionary with 'chief_complaints' key or a list
    if isinstance(predicted_complaints, dict) and 'chief_complaints' in predicted_complaints:
        predicted_complaints = predicted_complaints['chief_complaints']
    elif isinstance(predicted_complaints, dict) and 'chief_complaint' in predicted_complaints:
        predicted_complaints = [predicted_complaints]  # Wrap single entry in a list]

    # Initialize an array of zeros for each prediction
    processed_result = np.zeros(len(complaints_to_index))

    # Set 1 for complaints with conf >= 0.8
    for complaint in predicted_complaints:
        if complaint['conf'] >= 0.7:
            index = complaints_to_index.get(complaint['chief_complaint'])
            if index is not None:
                processed_result[index] = 1

    predicted_labels.append(processed_result)

predicted_labels = np.array(predicted_labels)


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151


KeyboardInterrupt: 

In [115]:
# Step 3: Evaluation (remains the same)
f1 = f1_score(y_test[:len(predicted_labels)], predicted_labels, average='micro')
hamming = hamming_loss(y_test[:len(predicted_labels)], predicted_labels)

print("F1 Score:", f1)
print("Hamming Loss:", hamming)

F1 Score: 0.49115044247787604
Hamming Loss: 0.021143592572164
