In [1]:
import pandas
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from openai import OpenAI
import utils

In [2]:
disease='JDM'
symp_type='formal'
model_old='claude-3-opus'
model_new='claude-3.5-sonnet'
model_new_name = 'claude-3-5-sonnet-20240620'

#model_old='chatgpt-3.5'
#model_new='chatgpt-4o'
#model_new_name = 'gpt-4o'
filename=f'data/{disease}-{symp_type}-{model_old}.txt'
output_filename = f'data/{disease}-{symp_type}-{model_new}.txt'
keywords_list = ['dermatomyositis', 'jdm']

## Run More Complex Models

The goal is to run more complex models and compare against the outputs of the baseline model.

In [3]:

all_patients = utils.process_file(filename)

In [7]:
print(all_patients[0].symptom_list)
utils.mk_prompt(all_patients[0])

['Progressive proximal muscles weakness', 'Easy fatigue', 'Gottron Papule', 'Helioptrope rash', 'Malar rash', 'Muscle pain or tenderness', 'Weight loss', 'Falling episodes', 'Arthritis', 'Lymphadenopathy', 'Cutaneous ulceration']


'The patient is a 10 year old female with a weight of 54.0 lbs and height of 46.0 inches. The list of symptoms include  (1) Progressive proximal muscles weakness (2) Easy fatigue (3) Gottron Papule (4) Helioptrope rash (5) Malar rash (6) Muscle pain or tenderness (7) Weight loss (8) Falling episodes (9) Arthritis (10) Lymphadenopathy and (11) Cutaneous ulceration.'

In [5]:
utils.run_llm_and_parse_response(all_patients[1], model_new_name)


Patient ID: 16
claude-3-5-sonnet-20240620


In [6]:
resps = [ utils.run_llm_and_parse_response(p, model_new_name) for p in all_patients]

Patient ID: 1
Patient ID: 2
Patient ID: 3
Patient ID: 4
Patient ID: 5
Patient ID: 6
Patient ID: 7
Patient ID: 8
Patient ID: 9
Patient ID: 10
Patient ID: 11
Patient ID: 12
Patient ID: 13
Patient ID: 14
Patient ID: 15
Patient ID: 16
Patient ID: 17
Patient ID: 18
Patient ID: 19
Patient ID: 20
Patient ID: 21
Patient ID: 22
Patient ID: 23
Patient ID: 24
Patient ID: 25
Patient ID: 26
Patient ID: 27
Patient ID: 28
Patient ID: 29
Patient ID: 30
Patient ID: 31
Patient ID: 32
Patient ID: 33
Patient ID: 34
Patient ID: 35
Patient ID: 36
Patient ID: 37
Patient ID: 38
Patient ID: 39
Patient ID: 40
Patient ID: 41
Patient ID: 42
Patient ID: 43
Patient ID: 44
Patient ID: 45
Patient ID: 46
Patient ID: 47
Patient ID: 48
Patient ID: 49
Patient ID: 50
Patient ID: 51
Patient ID: 52
Patient ID: 53
Patient ID: 54
Patient ID: 55
Patient ID: 56
Patient ID: 57
Patient ID: 58
Patient ID: 59
Patient ID: 60
Patient ID: 61
Patient ID: 62
Patient ID: 63
Patient ID: 64
Patient ID: 65
Patient ID: 66
Patient ID: 67
Pati

In [7]:
print(resps)

[[('Juvenile Dermatomyositis', '85', ' Classic skin rashes (Gottron papules, heliotrope rash) combined with muscle weakness and pain are hallmark symptoms of this condition in children.'), ('Systemic Lupus Erythematosus (SLE)', '60', ' Malar rash, arthritis, and muscle weakness are common in SLE, though some symptoms are more typical of dermatomyositis.'), ('Mixed Connective Tissue Disease', '40', ' Combines features of various autoimmune diseases, explaining the diverse symptoms including muscle weakness, arthritis, and rashes.'), ('Polymyositis', '30', ' Shares many symptoms with dermatomyositis, but typically lacks the characteristic skin rashes seen in this patient.'), ('Juvenile Idiopathic Arthritis', '20', " Can explain arthritis and fatigue, but doesn't typically cause the specific rashes and severe muscle weakness described.")], [('Juvenile Dermatomyositis', '85', ' Characteristic skin rashes (Gottron papules, malar rash, V sign) combined with muscle weakness and pain strongly 

In [8]:

def dump_response_to_file(filehandle, patient, resp):
    print(f'P: {patient.n}, {patient.gender}, {patient.age}, {patient.weight}, {patient.height}', file=filehandle)
    for s in patient.symptom_list:
        print(f'S:{s}', file=filehandle)
    for (j, r) in enumerate(resp):
        (disease_name, prob, descr) = r
        print(f'R: {j+1}, {disease_name}, {prob}, {descr}', file=filehandle)

# dump responses to output file
fhandle = open(output_filename,'w')
for (pt, resp) in zip(all_patients, resps):
    dump_response_to_file(fhandle, pt, resp)
fhandle.close()

In [9]:
def analyze_responses_for_top_three(resps, keywords_list, k=3):
    num_matches = 0
    for resp in resps:
        for (j, r) in enumerate(resp):
            if (j < k):
                (disease_name, prob, descr) = r
                if utils.approx_matches(keywords_list, [disease_name]):
                    num_matches = num_matches + 1
                    break 
    return num_matches

In [13]:
matches_1 = analyze_responses_for_top_three(resps, keywords_list, k=1)
print(f'{model_new} is able to find JDM in top 1 for {matches_1} out of {len(resps)} cases')
matches_3 = analyze_responses_for_top_three(resps, keywords_list, k=3)
print(f'{model_new} is able to find JDM in top 3 for {matches_3} out of {len(resps)} cases ')

claude-3.5-sonnet is able to find JDM in top 1 for 249 out of 249 cases
claude-3.5-sonnet is able to find JDM in top 3 for 249 out of 249 cases 


In [14]:
not_in_top_3 = utils.select_where_x_is_not_top(all_patients, keywords_list)
print(len(not_in_top_3))

0


In [15]:
chat_gpt_4_responses = [resps[pt.n] for pt in not_in_top_3]
matches_1 = analyze_responses_for_top_three(chat_gpt_4_responses, keywords_list, k=1)
print(f'{model_new} is able to find JDM in top 1 for {matches_1} out of {len(not_in_top_3)} cases where {model_old} fails')
matches_3 = analyze_responses_for_top_three(chat_gpt_4_responses, keywords_list, k=3)
print(f'{model_new} is able to find JDM in top 3 for {matches_3} out of {len(not_in_top_3)} cases where {model_old} fails ')

claude-3.5-sonnet is able to find JDM in top 1 for 0 out of 0 cases where claude-3-opus fails
claude-3.5-sonnet is able to find JDM in top 3 for 0 out of 0 cases where claude-3-opus fails 
