In [None]:
import glob
import re
import json
import numpy as np
import ast
import os
from sklearn.metrics import f1_score

In [None]:
# Maps the "Overall Clinical Competence" (criterion 5.1) verdict to an ordinal score.
comp2id = {'unsatisfactory': 0, 'satisfactory': 1, 'excellent': 2}

# Root directory that contains `data/` and `results/`. Override it with the
# MDBENCH_ROOT environment variable to run the notebook from another checkout.
MDBENCH_ROOT = os.environ.get("MDBENCH_ROOT", "/home/jovyan/isviridov/gm/3mdbench")

# NOTE: the lowercase "l" in COMPlAINTS_PATH is a historical typo, kept because
# later cells reference this exact name.
COMPlAINTS_PATH = os.path.join(MDBENCH_ROOT, "data", "complaints.json")
GLOBAL_DIALOGUES_PATH = os.path.join(MDBENCH_ROOT, "results")
GLOBAL_ASSESSMENTS_PATH = os.path.join(GLOBAL_DIALOGUES_PATH, "assessment")
GLOBAL_OBTAINED_DIAGNOSES_PATH = os.path.join(GLOBAL_ASSESSMENTS_PATH, "diags")

In [None]:
# complaints maps a case id to its record; each record carries the ground-truth "diagnosis".
with open(COMPlAINTS_PATH, 'r', encoding='utf-8') as f:
    complaints = json.load(f)

# Distinct lower-cased ground-truth diagnoses. Sorted so the label order is the
# same after every kernel restart (iteration order of a set of str is not).
true_diags = sorted({record['diagnosis'].lower() for record in complaints.values()})

In [None]:
def get_doctor_replics(dialogue):
    """Return the doctor's turns of a transcript, stripped and lower-cased.

    The transcript is cut at every "Patient:", "Doctor:" and "DIAG:" marker,
    and pieces that are empty after stripping are dropped. The remaining pieces
    are assumed to alternate patient/doctor starting with the patient, so the
    pieces at positions 1, 3, 5, ... are returned.
    """
    pieces = (chunk.strip().lower() for chunk in re.split('Patient:|Doctor:|DIAG:', dialogue))
    turns = [piece for piece in pieces if piece]
    return turns[1::2]

In [None]:
def get_replace(answer):
    """Convert one assessor answer to an integer score.

    Anything ``int()`` accepts (e.g. ``2``, ``" 1 "``) is returned as an int.
    Otherwise only the ASCII letters of the lower-cased answer are kept, so
    "No." / " NO " map to 0 and "Yes!" maps to 1.

    Raises:
        NotImplementedError: for any other answer; the message shows the
            offending value so malformed assessments are easy to locate.
    """
    try:
        return int(answer)
    except ValueError:
        # answer is already lower-cased, so [a-z] is enough.
        letters = "".join(re.findall(r"[a-z]+", answer.lower()))
        if letters == "no":
            return 0
        if letters == "yes":
            return 1
        raise NotImplementedError(f"Unrecognised assessment answer: {answer!r}")

In [None]:
def get_diags(preds):
    """Split a comma-separated diagnosis string into whitespace-stripped items.

    No other normalisation is done: case is preserved and empty items are
    kept, e.g. ``"A, b ,"`` -> ``["A", "b", ""]``.
    """
    return list(map(str.strip, preds.split(',')))

In [None]:
# Ordered (matcher, canonical label) rules used by get_correct_diags.
# Each matcher takes an already lower-cased diagnosis. The FIRST matching rule
# wins, so order matters: e.g. "ingrown toenail" must hit the ingrown rule
# before the nail/"fun" (fungus) rule, and "wart" is checked before later rules.
_DIAG_SYNONYM_RULES = (
    (lambda d: "herpes" in d or "cold sore" in d or "hsv" in d, "herpes"),
    (lambda d: "grown" in d and "nail" in d, "ingrown nail"),
    (lambda d: "hives" in d or "urticaria" in d, "hives"),
    (lambda d: "cavities" in d or "caries" in d, "caries"),
    (lambda d: "solar keratosis" in d, "actinic keratosis"),
    (lambda d: "wart" in d, "warts"),
    (lambda d: "atopic dermatitis" in d, "eczema"),
    (lambda d: "gum disease" in d, "gingivitis"),
    (lambda d: "nail" in d and "fun" in d, "onychomycosis"),
    (lambda d: "subungual hematoma" in d or "onychodystrophy" in d, "nail dystrophy"),
    (lambda d: "dandruff" in d, "seborrheic dermatitis"),
    (lambda d: "varicella" in d, "chickenpox"),
    (lambda d: "hordeolum" in d, "stye"),
    (lambda d: "tinea versicolor" in d or "ringworm" in d or "fungal infection" in d, "mycosis"),
    (lambda d: "tartar buildup" in d or "plaque buildup" in d, "dental calculus"),
    (lambda d: "aphthous ulcers" in d, "stomatitis"),
)


def get_correct_diags(preds):
    """Normalise a comma-separated diagnosis string to canonical label names.

    Each comma-separated item is stripped of surrounding whitespace and a
    trailing period, then lower-cased. Lower-casing matters because the rules
    below and the ground-truth labels are lower case. The item is then replaced
    by the label of the first matching rule in ``_DIAG_SYNONYM_RULES``
    (e.g. "Cold sore" -> "herpes", "nail fungus" -> "onychomycosis"). Items
    matching no rule are returned normalised but otherwise unchanged.

    Args:
        preds: Raw model output such as "Herpes simplex, Ingrown toenail".

    Returns:
        list[str]: One label per comma-separated item, in input order
        (an empty input yields ``[""]``).
    """
    normalised = []
    for diag in preds.split(','):
        diag = diag.strip().rstrip('.').strip().lower()
        canonical = next((label for matches, label in _DIAG_SYNONYM_RULES if matches(diag)), diag)
        normalised.append(canonical)
    return normalised

In [None]:
# (section, criterion id) pairs read from every "Doctor assessment", in output-column order.
# Criterion 5.1 is a categorical verdict mapped through comp2id; all others go through get_replace.
_ASSESSMENT_CRITERIA = (
    ("Diagnostic abilities", "0.1"),
    ("Medical Interviewing Skills", "1.1"),
    ("Medical Interviewing Skills", "1.2"),
    ("Medical Interviewing Skills", "1.3"),
    ("Humanistic Care", "3.1"),
    ("Humanistic Care", "3.2"),
    ("Comprehensive Diagnostic and Treatment Abilities", "4.1"),
    ("Comprehensive Diagnostic and Treatment Abilities", "4.2"),
    ("Overall Clinical Competence", "5.1"),
    ("Slayness", "6.1"),
)


def _parse_doctor_assessment(raw_text):
    """Return the "Doctor assessment" dict from raw assessor text, or None if unparsable.

    Everything outside the outermost ``{...}`` (code fences, "json"/"python"
    tags, surrounding prose) is discarded. The rest is read as JSON first, so
    ``true``/``false``/``null`` are accepted, and otherwise as a Python literal.
    A parsed object without a "Doctor assessment" key raises KeyError.
    """
    start, end = raw_text.find('{'), raw_text.rfind('}')
    if start == -1 or end < start:
        return None
    candidate = raw_text[start:end + 1]
    try:
        parsed = json.loads(candidate)
    except ValueError:  # includes json.JSONDecodeError
        try:
            parsed = ast.literal_eval(candidate)
        except (SyntaxError, ValueError):
            return None
    return parsed["Doctor assessment"]


def _score_doctor_assessment(doctor_assessment):
    """Convert one parsed "Doctor assessment" dict into a row of ten integer scores."""
    row = []
    for section, criterion in _ASSESSMENT_CRITERIA:
        answer = doctor_assessment[section][criterion]
        if criterion == "5.1":
            row.append(int(comp2id[answer.lower()]))
        else:
            row.append(int(get_replace(answer)))
    return row


def count_dialogue_metrics(assessments_list, dialogues_path, model_name):
    """Collect the per-dialogue assessment scores of one experiment.

    Each path in ``assessments_list`` is a JSON file named ``..._<case>.<ext>``
    whose ``<case>`` key holds the raw assessor output under ``"assessment"``.
    The ten criteria of ``_ASSESSMENT_CRITERIA`` are extracted from every
    parsable assessment.

    An assessment whose text cannot be parsed is not scored. Instead, the
    matching dialogue ``<dialogues_path>/<model_name>/case_<case>.json`` is
    opened, and unparsable assessments are only tolerated when that dialogue's
    ``"dialogue_ended"`` flag is false.

    Args:
        assessments_list: Paths of assessment JSON files.
        dialogues_path: Root directory of the generated dialogues.
        model_name: Sub-directory of ``dialogues_path`` for this experiment.

    Returns:
        list[list[int]]: One row of ten scores per parsed assessment, columns
        ordered 0.1, 1.1, 1.2, 1.3, 3.1, 3.2, 4.1, 4.2, 5.1, 6.1.

    Raises:
        AssertionError: if a finished dialogue has an unparsable assessment;
            the message lists the offending assessment files.
        KeyError / NotImplementedError: if a parsed assessment lacks a
            criterion or contains an unrecognised answer.
    """
    assessments, finished_unassessed = [], []
    for assessment_path in assessments_list:
        # The case id is the last "_"-separated token of the file name, before the first ".".
        dialogue_case = os.path.basename(assessment_path).split('_')[-1].split('.')[0]
        with open(assessment_path, 'r', encoding='utf-8') as f:
            assessment_data = json.load(f)

        doctor_assessment = _parse_doctor_assessment(assessment_data[dialogue_case]["assessment"])
        if doctor_assessment is not None:
            assessments.append(_score_doctor_assessment(doctor_assessment))
            continue

        # Unparsable assessment: acceptable only if the dialogue itself never finished.
        dialogue_file = os.path.join(dialogues_path, model_name, f'case_{dialogue_case}.json')
        with open(dialogue_file, 'r', encoding='utf-8') as f:
            dialogue_text = json.load(f)
        if dialogue_text[dialogue_case]["dialogue_ended"]:
            finished_unassessed.append(assessment_path)

    assert not finished_unassessed, f"You have some unassessed finished dialogues! {finished_unassessed}"
    return assessments

---

In [None]:
# Result sub-directory used for the dialogues, assessments and diagnoses of this run.
# Override with the MDBENCH_EXPERIMENT environment variable (e.g. when run headless).
experiment_name = os.environ.get("MDBENCH_EXPERIMENT", "llama_test")

#### Dialogue metrics

In [None]:
# Sorted because glob order depends on the filesystem; this keeps runs reproducible.
assessed_dialogues = sorted(glob.glob(os.path.join(GLOBAL_ASSESSMENTS_PATH, experiment_name, "*.json")))
# Fail early with a clear message instead of an opaque numpy error when averaging nothing.
assert assessed_dialogues, f"No assessment files found for experiment {experiment_name!r}"

Results: mean score of each assessment criterion across all assessed dialogues (rounded to 3 decimals)

In [None]:
# One row per parsed assessment; columns follow criteria 0.1, 1.1, 1.2, 1.3, 3.1, 3.2, 4.1, 4.2, 5.1, 6.1.
dialogue_scores = np.array(count_dialogue_metrics(assessed_dialogues, GLOBAL_DIALOGUES_PATH, experiment_name))
# Mean of each criterion over all assessments, rounded to 3 decimals, as plain Python floats.
# Named distinctly so later cells cannot overwrite this result.
dialogue_metrics = np.round(dialogue_scores.mean(axis=0), 3).tolist()
dialogue_metrics

#### Diagnostic metrics

In [None]:
# Per-case files with the model's predicted diagnoses; sorted for a reproducible processing order.
cases = sorted(glob.glob(os.path.join(GLOBAL_OBTAINED_DIAGNOSES_PATH, experiment_name, "*.json")))

In [None]:
# Raw predicted-diagnosis string per case id; the id is the last "_"-token of each file name.
pred_diags = {}

for case_path in cases:
    case_id = os.path.basename(case_path).split('_')[-1].split('.')[0]
    try:
        with open(case_path, 'r', encoding='utf-8') as f:
            case_data = json.load(f)
    except FileNotFoundError:
        # Paths come from glob, so this only happens if a file disappears in between; skip it.
        continue
    # Newlines become spaces (not "") so a diagnosis wrapped across lines keeps its
    # word boundary ("atopic\ndermatitis" -> "atopic dermatitis"); backticks from
    # code fences are dropped.
    pred_diags[case_id] = case_data[case_id]["diags"].replace('\n', ' ').replace('`', '')

In [None]:
model_preds, model_gts, n_failed = [], [], 0

for case_id, raw_diags in pred_diags.items():
    # Lower-case so predictions compare equal to the lower-cased ground truth, and
    # drop empty items (e.g. produced by a trailing comma) — they are not predictions.
    diags_list = [diag.lower() for diag in get_correct_diags(raw_diags) if diag]
    if not diags_list or diags_list == ["none"]:
        # The model gave no usable diagnosis for this case.
        n_failed += 1
        continue
    true_diag = complaints[case_id]['diagnosis'].lower()
    model_preds.extend(diags_list)
    # Repeat the ground-truth label once per prediction so multi-diagnosis answers stay aligned.
    model_gts.extend([true_diag] * len(diags_list))

# Fraction of cases without a usable diagnosis (NaN when there are no cases at all).
failed_diags = n_failed / len(pred_diags) if pred_diags else float('nan')

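Results: weighted F1 of the predicted diagnoses, and the percentage of cases where the model returned no diagnosis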

In [None]:
# Weighted F1 restricted to the ground-truth label set. zero_division=0 yields the
# same value as the default ("warn" scores ill-defined labels as 0) without
# emitting a warning for every label absent from this run.
diag_f1 = f1_score(model_gts, model_preds, average="weighted", labels=list(true_diags), zero_division=0)
print(f"weighted F1: {diag_f1:.4f} | cases without diagnosis: {failed_diags * 100:.2f}%")