## Hierarchical Coding with Small Language Models  
  
**Goal of this notebook**  
- Mitigate the challenge of a vast label space by splitting classification into a sequence of small decisions, each restricted to a specific set of candidate classes. In some cases, this reduces the label space for a single decision from thousands of codes to fewer than 20.

**Methodology** 
   
The approach is to leverage small language model(s) to traverse a hierarchy tree of ICD-9 codes, asking many small, simple questions to classify a note event from the MIMIC-III dataset.

For example, the diagram below represents a small portion of the ICD-9 code tree: a subset of the 'Infectious and Parasitic Diseases' chapter. [View a full JSON representation of the taxonomy here.](./icd9_full.json) In this implementation, ICD-9 code levels are broken down into Chapters, Blocks, and Categories. To extend this implementation, the tree can be broken down further into Sub-categories, Extension I, and Extension II codes.
  
The coding algorithm walks the tree recursively, starting at the top level and continuing down whichever branch(es) the small model selects until the final codes are returned. 
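To illustrate only the control flow, here is a minimal offline sketch of such a walk. It assumes a hypothetical nested-dict tree (`CHILDREN`) and replaces the language model with a hypothetical scripted chooser (`SCRIPTED`, `scripted_chooser`) that returns fixed picks per node; none of these names are part of the notebook's `TaxonomyParser` or model code.

```python
from typing import Callable, Dict, List

# Hypothetical miniature of the chapter 0 branch: code -> child codes
CHILDREN: Dict[str, List[str]] = {
    "0": ["00", "09"],
    "00": ["001", "003"],
    "09": ["087", "088"],
}

# Hypothetical scripted answers standing in for the model's choice at each node
SCRIPTED: Dict[str, List[str]] = {"0": ["00", "09"], "00": ["003"], "09": ["087"]}


def scripted_chooser(parent: str, options: List[str]) -> List[str]:
    """Return the scripted picks for `parent`, keeping only codes that were actually offered."""
    return [code for code in SCRIPTED.get(parent, []) if code in options]


def walk(code: str, chooser: Callable[[str, List[str]], List[str]]) -> List[str]:
    """Descend into every child the chooser selects and return the leaf codes reached."""
    options = CHILDREN.get(code, [])
    if not options:  # no children: this is a final code
        return [code]
    leaves: List[str] = []
    for child in chooser(code, options):  # one "small question" per visited node
        leaves.extend(walk(child, chooser))
    return leaves


print(walk("0", scripted_chooser))  # ['003', '087']
```

Each chooser call corresponds to one small question to the model, and only selected branches are visited, so codes under unselected blocks are never offered. Unlike the notebook's implementation, this sketch returns an empty list rather than a placeholder when nothing is selected.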

Sample subset (condensed Chapter 0 branch) of the ICD9 code tree: 
```
0        Infectious and Parasitic Diseases  
├── 00   Intestinal infectious diseases  
│   ├── 001 Cholera  
│   ├── 002 Typhoid and paratyphoid fevers  
│   ├── 003 Salmonella  
│   ├── 004 Shigellosis  
│   ├── 005 Other food poisoning (bacterial)  
│   ├── 006 Amebiasis  
│   ├── 007 Other protozoal intestinal diseases  
│   ├── 008 Intestinal infections due to other organisms  
│   └── 009 Ill-defined intestinal infections  
...  
├── 09   Rickettsioses and other arthropod-borne diseases  
│   ├── ...   
│   ├── 087 Relapsing fever  
│   └── 088 Other arthropod-borne diseases

```
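A rough way to see the label-space reduction is to compare the number of leaves (what a flat classifier must choose among) with the largest set of options offered in any single question. The sketch below assumes a hypothetical, evenly shaped tree of one chapter with 10 blocks of 9 categories each; the sizes are illustrative and are not real ICD-9 counts.

```python
from typing import Dict, List


def leaf_count(children: Dict[str, List[str]], root: str) -> int:
    """Number of leaves under `root`: the label space a flat classifier would face."""
    kids = children.get(root, [])
    if not kids:
        return 1
    return sum(leaf_count(children, kid) for kid in kids)


def max_options_per_question(children: Dict[str, List[str]]) -> int:
    """Largest number of candidates offered in any single question during the walk."""
    return max(len(kids) for kids in children.values())


# Hypothetical tree: one chapter with 10 blocks, each holding 9 categories
toy: Dict[str, List[str]] = {"0": [f"0{b}" for b in range(10)]}
for b in range(10):
    toy[f"0{b}"] = [f"0{b}{c}" for c in range(1, 10)]

print(leaf_count(toy, "0"))           # 90
print(max_options_per_question(toy))  # 10
```

Along one path the model therefore answers two questions with at most 10 options each, instead of one question over 90 labels.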

**Further Reading**  
  
This approach is explored further in this paper:  
 [Automated clinical coding using off-the-shelf large language models](https://arxiv.org/pdf/2310.06552) (Boyle et al.)

In [None]:
from src.tree import TaxonomyParser
from zensols.mednlp import ApplicationFactory

from nltk import flatten
from tqdm import tqdm
from dotenv import load_dotenv, find_dotenv
from textwrap import dedent
from openai import AzureOpenAI
from typing import List, Dict, Any

import pandas as pd
import ast
import functools
import logging
import os


doc_parser = ApplicationFactory.get_doc_parser()
logger = logging.getLogger(__name__)

load_dotenv(find_dotenv(), override=True)
print(os.getenv("AZURE_OPENAI_BASE"))

pd.set_option('display.max_colwidth', None)

In [None]:
# Initialize Code Tree
code_tree = TaxonomyParser()
code_tree.read_from_json("icd9_tax.json")

print(code_tree.find_by_name("00"))

In [None]:
# View Tree
code_tree.visualize("0")

#### Define Helper Functions

In [None]:
# Note Parsing Functions

def format_icd9(x):
    """Zero-pad each code in a stringified list to three characters (e.g. '8' -> '008')."""
    new_codes = []
    code_list = ast.literal_eval(x)
    for code in code_list:
        new_codes.append(f"{str(code):0>3}")

    return str(new_codes)

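# UMLS semantic types (TUIs) to keep: T184 Sign or Symptom, T047 Disease or Syndrome,
# T046 Pathologic Function, T033 Finding, T037 Injury or Poisoning, T191 Neoplastic Process,
# T005 Virus, T004 Fungus, T007 Bacterium, T008 Animal
KEEP_TUIS = ['T184', 'T047', 'T046', 'T033', 'T037', 'T191', 'T005', 'T004', 'T007', 'T008']

def parse_note(note: str) -> str:
    """Reduce a note to the disorder- and organism-related concepts detected by MedNLP."""
    doc = doc_parser(note)

    new_note = set()
    for tok in doc.tokens:
        if tok.is_concept and tok.tuis_ in KEEP_TUIS:
            # Keep both the surface form and the lower-cased UMLS preferred name
            new_note.add(tok.detected_name_.replace("~", " "))
            new_note.add(tok.pref_name_.lower())

    logger.info("Note parsing complete.")

    return " ".join(new_note)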
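The concept filter in `parse_note` can be sketched without MedNLP. The `Concept` records below are hypothetical stand-ins for detected tokens (surface text, preferred name, and a single TUI); only the filter-and-deduplicate logic mirrors `parse_note`. One deliberate difference: the kept terms are sorted so the output is reproducible, whereas joining a `set` directly gives an arbitrary order.

```python
from dataclasses import dataclass
from typing import Iterable

# Same semantic types as parse_note keeps (disorders, findings, organisms)
KEEP_TUIS = {"T184", "T047", "T046", "T033", "T037", "T191", "T005", "T004", "T007", "T008"}


@dataclass
class Concept:
    """Hypothetical stand-in for a concept token: surface text, preferred name, TUI."""
    detected: str
    preferred: str
    tui: str


def concept_text(concepts: Iterable[Concept]) -> str:
    """Keep disorder-like concepts, add surface and preferred names, de-duplicate."""
    kept = set()
    for c in concepts:
        if c.tui in KEEP_TUIS:
            kept.add(c.detected.replace("~", " "))
            kept.add(c.preferred.lower())
    return " ".join(sorted(kept))  # sorted for a stable, reproducible order


sample = [
    Concept("pulmonary~tuberculosis", "Tuberculosis, Pulmonary", "T047"),
    Concept("aspirin", "Aspirin", "T109"),  # drug concept, not in KEEP_TUIS
    Concept("fever", "Fever", "T184"),
]
print(concept_text(sample))
```

The drug concept is dropped because its semantic type is not in the kept list, and "Fever" collapses into "fever" after lower-casing.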

In [None]:
# Scoring Functions
# Each argument is a stringified list of codes, e.g. "['00', '09']"; duplicate codes are ignored.

def recall_score(truth, generated):
    actual = set(ast.literal_eval(truth))
    predicted = set(ast.literal_eval(generated))

    if len(actual) == 0:
        return 0

    return len(actual & predicted) / len(actual)

def precision_score(truth, generated):
    actual = set(ast.literal_eval(truth))
    predicted = set(ast.literal_eval(generated))

    if len(predicted) == 0:
        return 0

    return len(actual & predicted) / len(predicted)

def f1_score(truth, generated):
    precision = precision_score(truth, generated)
    recall = recall_score(truth, generated)

    if precision + recall == 0:
        return 0
    return 2 * (precision * recall) / (precision + recall)
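A small worked example of the set-based metrics, written as a self-contained sketch (the helper name `prf` is hypothetical and simply combines the three functions above). It also shows that a placeholder such as `'XX'`, which the tree walk emits when the model returns nothing usable, counts as a false positive.

```python
import ast
from typing import Tuple


def prf(truth: str, generated: str) -> Tuple[float, float, float]:
    """Set-based precision, recall and F1 between two stringified code lists."""
    t = set(ast.literal_eval(truth))
    g = set(ast.literal_eval(generated))
    hits = len(t & g)
    precision = hits / len(g) if g else 0.0
    recall = hits / len(t) if t else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


print(prf("['00', '01']", "['00', '09', 'XX']"))
```

Here one of two true codes is found (recall 1/2) among three predictions (precision 1/3), giving F1 = 2 × (1/3 × 1/2) / (1/3 + 1/2) = 0.4.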

In [None]:
# Make Call to AOAI

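def call_aoai(system_prompt: str, prompt: str) -> List:
    """Send one classification question to Azure OpenAI and parse the reply as a Python list."""
    aoai_client = AzureOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_BASE"),
        api_key=os.getenv("AZURE_OPENAI_KEY"),
        api_version="2024-02-01"
    )

    response = aoai_client.chat.completions.create(
        model=os.getenv("AZURE_DEPLOYMENT_NAME"),  # Azure deployment name, not the model family
        messages=[
            {"role": "system", "content": dedent(system_prompt)},
            {"role": "user", "content": dedent(prompt)}
        ],
    )

    content = response.choices[0].message.content
    try:
        output = ast.literal_eval(content)
    except Exception as e:
        logger.warning(f"Could not parse model output {content!r}: {e}")
        return []

    # The prompt asks for a list; treat anything else as "no answer"
    return list(output) if isinstance(output, (list, tuple)) else []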
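Because the system prompt invites the model to think through its answer, replies may contain prose around the list, which makes a bare `ast.literal_eval` fail. A more tolerant parser is sketched below under the assumption that the answer is the last bracketed, non-nested list in the reply; `parse_code_list` is a hypothetical helper that could replace the `ast.literal_eval` call inside `call_aoai`.

```python
import ast
import re
from typing import List


def parse_code_list(text: str) -> List[str]:
    """Return the items of the last parseable [...] list in a model reply, as strings.

    Returns [] when the reply is empty or contains no list literal that parses.
    """
    if not text:
        return []
    # Candidate list literals without nested brackets, tried from last to first
    for candidate in reversed(re.findall(r"\[[^\[\]]*\]", text)):
        try:
            parsed = ast.literal_eval(candidate)
        except (ValueError, SyntaxError):
            continue
        if isinstance(parsed, (list, tuple)):
            return [str(item).strip() for item in parsed]
    return []


print(parse_code_list("The note mentions TB and HIV.\nAnswer: ['01', '04']"))
```

Normalizing items to strings also covers replies such as `[0, 2]`, where the model returns numbers instead of quoted codes.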

In [None]:
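# Build Prompt Dynamically

def get_options(tree, parent_code):
    """Format the children of `parent_code` as 'code: description' entries for the prompt."""
    children = tree.get_children(parent_code)
    options = []
    for child in children:
        options.append(f"{child.name}: {child.description}")

    # ' | ' matches the separator used in the few-shot example of the system prompt
    return ' | '.join(options)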

def build_prompt(tree, parent_code, note, categories):
    # `categories` is the output of get_options(); `tree` and `parent_code` are currently unused
    system_prompt = """
    You are a medical expert. Your job is to classify clinical notes into one or more categories. ACCURACY is VERY IMPORTANT to your job.
    Choose the best option(s) based on the categories offered. ALWAYS return at least one code. ONLY choose from the categories listed.
    Respond with a list of quoted strings containing the codes of the categories the note belongs to.
    Think through your answer, but respond with the list only.

    ### EXAMPLE ###
    Categories = 0: Infectious and Parasitic Diseases | 1: Neoplasms | 2: Endocrine, Nutritional and Metabolic Diseases, and Immunity Disorders
    Note = Patient has Tuberculosis and an Immunity Disorder
    Answer: ['0','2']
    ## END EXAMPLE ##
    """

    prompt = f"""
    Categories = {categories}
    Note = {note}
    Answer:
    """

    return system_prompt, prompt
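To see what the model is actually asked at each node, the sketch below renders a node's children the same way `get_options` does, with the separator exposed as a parameter. `Node` and `StubTree` are hypothetical stand-ins that provide only the `get_children()` call and the `name`/`description` attributes used above; the real `TaxonomyParser` is not reproduced.

```python
from typing import Dict, List, NamedTuple


class Node(NamedTuple):
    name: str
    description: str


class StubTree:
    """Hypothetical tree exposing only get_children(), as get_options relies on."""

    def __init__(self, children: Dict[str, List[Node]]):
        self._children = children

    def get_children(self, code: str) -> List[Node]:
        return self._children.get(code, [])


def format_options(tree, parent_code: str, sep: str = " | ") -> str:
    """Render a node's children as 'code: description' entries joined by `sep`."""
    return sep.join(f"{child.name}: {child.description}" for child in tree.get_children(parent_code))


tree = StubTree({"00": [Node("001", "Cholera"), Node("002", "Typhoid and paratyphoid fevers")]})
print(format_options(tree, "00"))
# 001: Cholera | 002: Typhoid and paratyphoid fevers
```

The model is expected to answer with the code shown before each colon (for example `['001']`), and `get_codes_for_note` compares the length of those codes with `level` to decide whether to descend further.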

In [None]:
# Recursive Walk of tree and call aoai to get codes

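def get_codes_for_note(parent_code, tree, note, level=3):
    """Walk the tree below `parent_code`, asking the model which children apply.

    Returns a (possibly nested) list of codes whose names are `level` characters long,
    or a placeholder of 'X' * level when the model picks nothing valid.
    """
    categories = get_options(tree, parent_code)
    system_prompt, prompt = build_prompt(tree, parent_code, note, categories)

    codes = [str(code).strip() for code in call_aoai(system_prompt, prompt)]

    logger.info(f"Parent Code: {parent_code} | Found: {codes}")
    logger.info(f"Prompt: {prompt}")

    # Drop anything that is not a child of the current node; this also stops
    # infinite recursion if the model repeats the parent or invents a code
    valid = {child.name for child in tree.get_children(parent_code)}
    codes = [code for code in codes if code in valid]

    if not codes:
        return ['X' * level]
    elif all(len(code) == level for code in codes):
        return codes
    else:
        return [get_codes_for_note(code, tree, note, level=level) for code in codes]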
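`get_codes_for_note` adds one level of list nesting for every branch it descends into, which is why the calling code below wraps it in `nltk.flatten`. For readers who want to avoid the NLTK dependency, a plain-Python equivalent for these list-of-lists results might look like this (`flatten_codes` is a hypothetical name).

```python
from typing import Any, List


def flatten_codes(nested: Any) -> List[str]:
    """Recursively flatten the nested lists produced by the tree walk into one flat list."""
    if isinstance(nested, (list, tuple)):
        flat: List[str] = []
        for item in nested:
            flat.extend(flatten_codes(item))
        return flat
    return [nested]


# Shape returned when the model picks two blocks, then categories under each
nested = [["003", "008"], ["087"]]
print(flatten_codes(nested))  # ['003', '008', '087']
```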
    

## Prepare Data

In [None]:
# df = transform_data("data/") # Only re-run if change in preparation logic
df = pd.read_csv("data/joined/dataset_single_001_088.csv.gz")
print(df.shape)
display(df.dtypes)

In [None]:
# Get L1 and L2 codes for grading purposes

def get_parent_codes(code_tree, codes):
    code_list = ast.literal_eval(codes)
    parent_codes = []
    for code in code_list:
        parent_codes.append(code_tree.find_by_name(code).parent.name)
    
    parent_codes = list(set(parent_codes))
    return str(parent_codes)

df['L2_CODES'] = df['ICD9_CODE'].apply(lambda x: get_parent_codes(code_tree, x))
df['L1_CODES'] = df['L2_CODES'].apply(lambda x: get_parent_codes(code_tree, x))
display(df[['ICD9_CODE', 'L2_CODES', 'L1_CODES']].head(5))
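The roll-up from category to block (`L2_CODES`) to chapter (`L1_CODES`) is a parent lookup followed by de-duplication. The sketch below assumes a hypothetical child-to-parent dictionary in place of `code_tree.find_by_name(code).parent.name`, and sorts the result so the stringified list is stable, which `list(set(...))` does not guarantee.

```python
import ast
from typing import Dict

# Hypothetical child -> parent map for a few chapter 0 codes
PARENT: Dict[str, str] = {"001": "00", "003": "00", "087": "09", "00": "0", "09": "0"}


def parent_codes(codes: str) -> str:
    """Map each code in a stringified list to its parent, de-duplicate, and sort."""
    return str(sorted({PARENT[code] for code in ast.literal_eval(codes)}))


l2 = parent_codes("['001', '003', '087']")
l1 = parent_codes(l2)
print(l2, l1)  # ['00', '09'] ['0']
```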

In [None]:
# Take Final Subset

df = df.iloc[:10].copy()  # copy so later column assignments don't trigger SettingWithCopyWarning
print(df.shape)

In [None]:
# Add Parsed Text field
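tqdm.pandas()
# TEXT holds a stringified list of notes; parse the first one, as the generation loop does
df['PARSED_TEXT'] = df['TEXT'].progress_apply(lambda x: parse_note(ast.literal_eval(x)[0]))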

In [None]:
print(df.shape)
display(df.head(2))

## Get ICD9 Codes

### Part 1 - Get Codes from GPT-4o mini

In [None]:
### SIMPLE TEST ###
"""
res = flatten(get_codes_for_note("root", code_tree, "Tuberculosis of the bones and joints and HIV"))
print(res)
"""
#### END SIMPLE TEST ###

In [None]:
df['Generated'] = ""
for index, row in df.iterrows():

    # TEXT holds a stringified list of notes; classify the first one
    note = ast.literal_eval(row['TEXT'])[0]
    print(f"Note: {note}")

    # Walk the tree from chapter "0"; level=2 stops at blocks (use level=3 for categories)
    result = flatten(get_codes_for_note("0", code_tree, note, level=2))

    # Add result to DF
    df.at[index, 'Generated'] = str(result)

In [None]:
# View Results

display(df[['ICD9_CODE','L1_CODES','L2_CODES', 'Generated']].head(10))

## Score Results

#### Grade L2 Output

In [None]:
results = pd.DataFrame()


results['ICD9_CODE'] = df['ICD9_CODE'].apply(format_icd9)
results['Recall'] = df.apply(lambda x: recall_score(x['L2_CODES'], x['Generated']), axis=1)
results['Precision'] = df.apply(lambda x: precision_score(x['L2_CODES'], x['Generated']), axis=1)
results['F1 Score'] = df.apply(lambda x: f1_score(x['L2_CODES'], x['Generated']), axis=1)
display(results[['Recall', 'Precision', 'F1 Score']].mean(axis=0)*100)

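#### Grade Final ICD-9 Code Output

Note: the generation loop above stops at level 2 (blocks), so these category-level scores are only meaningful after re-running `get_codes_for_note` with `level=3`.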

In [None]:
results = pd.DataFrame()

# Compare against zero-padded codes so they match the tree's code names (e.g. '8' -> '008')
results['ICD9_CODE'] = df['ICD9_CODE'].apply(format_icd9)
results['Recall'] = [recall_score(t, g) for t, g in zip(results['ICD9_CODE'], df['Generated'])]
results['Precision'] = [precision_score(t, g) for t, g in zip(results['ICD9_CODE'], df['Generated'])]
results['F1 Score'] = [f1_score(t, g) for t, g in zip(results['ICD9_CODE'], df['Generated'])]

display(results[['Recall', 'Precision', 'F1 Score']].mean(axis=0)*100)

#### Results Summary

In [None]:
print(f"Recall = {round(results['Recall'].mean(),2)}")
print(f"Precision = {round(results['Precision'].mean(),2)}")

## Implement Med NLP Note Parsing

In [None]:
df['Parsed_Generated'] = ""
for index, row in df.iterrows():

    # Use the concept-only text produced by parse_note
    note = row['PARSED_TEXT']
    print(f"Note: {note}")

    # Get Codes
    result = flatten(get_codes_for_note("0", code_tree, note, level=2)) # Change level here if needed

    # Add result to DF
    df.at[index, 'Parsed_Generated'] = str(result)

In [None]:
display(df[['ICD9_CODE','L1_CODES','L2_CODES', 'Parsed_Generated']].head(10))

In [None]:
results = pd.DataFrame()


results['ICD9_CODE'] = df['ICD9_CODE'].apply(format_icd9)
results['Recall'] = df.apply(lambda x: recall_score(x['L2_CODES'], x['Parsed_Generated']), axis=1)
results['Precision'] = df.apply(lambda x: precision_score(x['L2_CODES'], x['Parsed_Generated']), axis=1)
results['F1 Score'] = df.apply(lambda x: f1_score(x['L2_CODES'], x['Parsed_Generated']), axis=1)
display(results[['Recall', 'Precision', 'F1 Score']].mean(axis=0)*100)