## Hierarchical Coding with Small Language Models  
  
**Goal of this notebook**  
- Mitigate the challenge of a vast label space by narrowing the candidate classes at each step to a small set of targets. In some cases, this reduces the label space from thousands of codes to fewer than 20.
  
**Methodology** 
   
The approach is to leverage small language model(s) to traverse a hierarchy tree of ICD-9 codes and ask many small, simple questions to classify a Note Event from the MIMIC-III dataset.

For example, the diagram below shows a small portion of the ICD-9 code tree: a subset of the 'Infectious and Parasitic Diseases' chapter. For this implementation, ICD-9 code levels are broken down into Chapters, Blocks, and Categories. To expand on this implementation, the tree can be broken down further into Sub-categories, Extension I, and Extension II codes.
  
The coding algorithm recursively walks the tree, starting at the top level and continuing down any branch(es) directed by the mini model until the final codes are returned. 

Sample subset (condensed Chapter 0 branch) of the ICD9 code tree: 
<small>
```markdown
001-139 Infectious and Parasitic Diseases  
├── 001-009 Intestinal infectious diseases  
│   ├── 001 Cholera  
│   ├── 002 Typhoid and paratyphoid fevers  
│   ├── 003 Salmonella  
│   ├── 004 Shigellosis  
│   ├── 005 Other food poisoning (bacterial)  
│   ├── 006 Amebiasis  
│   ├── 007 Other protozoal intestinal diseases  
│   ├── 008 Intestinal infections due to other organisms  
│   └── 009 Ill-defined intestinal infections  
...  
├── 080-088 Rickettsioses and other arthropod-borne diseases  
│   ├── ...   
│   ├── 087 Relapsing fever  
│   └── 088 Other arthropod-borne diseases

```
</small>  
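
The recursive walk described above can be sketched without any model calls. This is a minimal illustration under stated assumptions, not the notebook's implementation: `TOY_TREE` is a hand-written fragment of the tree, and `scripted_chooser` is a hypothetical stand-in for the model that returns pre-written answers for each parent code.

```python
# Toy fragment of the ICD-9 tree: parent code -> {child code: description}.
# Hand-written for illustration; the notebook loads the full tree from JSON.
TOY_TREE = {
    "ROOT": {"001-139": "Infectious and Parasitic Diseases"},
    "001-139": {
        "001-009": "Intestinal infectious diseases",
        "010-018": "Tuberculosis",
    },
    "001-009": {"001": "Cholera", "004": "Shigellosis"},
    "010-018": {
        "011": "Pulmonary tuberculosis",
        "015": "Tuberculosis of bones and joints",
    },
}

# Hypothetical, pre-written "model answers" keyed by the parent being expanded.
SCRIPTED_ANSWERS = {
    "ROOT": ["001-139"],
    "001-139": ["001-009", "010-018"],
    "001-009": ["001"],
    "010-018": ["015"],
}


def scripted_chooser(note, parent, options):
    """Stand-in for the model: ignores the note and returns scripted picks."""
    # Keep only answers that were actually offered as options.
    return [code for code in SCRIPTED_ANSWERS.get(parent, []) if code in options]


def walk(tree, note, parent="ROOT", choose=scripted_chooser):
    """Follow every branch the chooser selects and return the leaf codes."""
    options = tree.get(parent)
    if options is None:  # no children recorded: treat the code as a leaf
        return [parent]
    leaves = []
    for code in choose(note, parent, options):
        leaves.extend(walk(tree, note, code, choose))
    return leaves


result = walk(TOY_TREE, "Cholera and tuberculosis of the bones and joints")
print(result)  # ['001', '015']
```

Each call only ever sees the children of one node, which is what keeps the label space small at every step.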

**Further Reading**  
  
This approach is explored further in this paper: [Automated clinical coding using off-the-shelf large language models](https://arxiv.org/pdf/2310.06552) (Boyle et al.)

In [1]:
from src.icd9_tree import ICD9

from dotenv import load_dotenv, find_dotenv
from textwrap import dedent
from openai import AzureOpenAI
from nltk import flatten

import pandas as pd
import ast
import functools
import logging
import os


logger = logging.getLogger(__name__)
load_dotenv(find_dotenv(), override=True)
pd.set_option('display.max_colwidth', None)

---
#### Build Code Tree Data Structure

In [None]:
# Initialize Code Tree
tree = ICD9('icd9_codes_full.json')

# View Chapter codes (ignore E and V codes)
for chapter in tree.children:
    if chapter.code[0] not in ['E', 'V']:
        print(f"{chapter.code} - {chapter.description}")
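
The `ICD9` class comes from `src.icd9_tree` and is not shown in this notebook. The sketch below illustrates the minimal interface the later cells appear to rely on (`code`, `description`, `children`, and `find`). `CodeNode` is a hypothetical stand-in written for illustration, and the real class may differ.

```python
from dataclasses import dataclass, field


@dataclass
class CodeNode:
    """Hypothetical tree node exposing the attributes used in this notebook."""
    code: str
    description: str
    children: list = field(default_factory=list)

    def find(self, code):
        """Depth-first search for a node by code; returns None if absent."""
        if self.code == code:
            return self
        for child in self.children:
            found = child.find(code)
            if found is not None:
                return found
        return None


root = CodeNode("ROOT", "ROOT", [
    CodeNode("001-139", "Infectious and Parasitic Diseases", [
        CodeNode("001-009", "Intestinal infectious diseases"),
        CodeNode("010-018", "Tuberculosis"),
    ]),
])

# Same 'code: description' format, joined by '|', that the prompt uses.
options = "|".join(
    f"{child.code}: {child.description}"
    for child in root.find("001-139").children
)
print(options)  # 001-009: Intestinal infectious diseases|010-018: Tuberculosis
```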

---
#### Prepare Data

In [None]:
# df = transform_data("data/") # Only re-run if change in preparation logic
df = pd.read_csv("data/joined/dataset_single_notes_full.csv.gz")
print(df.shape)
display(df.dtypes)

In [None]:
# Take Final Subset

df = df.sample(10, replace=False, random_state=123)
print(df.shape)

---
#### Define Helper Functions

In [5]:
# Make Call to AOAI

def call_aoai(sys, prompt):

    # Init AOAI client
    aoai_client = AzureOpenAI(
        azure_endpoint = os.getenv('AZURE_OPENAI_BASE'), 
        api_key=os.getenv('AZURE_OPENAI_KEY'),
        api_version="2024-02-01"
    )
    
    response = aoai_client.chat.completions.create(
        model=os.getenv("AOAI_MINI_DEPLOYMENT_NAME"), # model = "deployment_name".
        messages=[
            {"role": "system", "content": dedent(sys)},
            {"role": "user", "content": dedent(prompt)}
        ],
    )

    try:
        output = ast.literal_eval(response.choices[0].message.content)
        return output
    except Exception as e:
        logger.warning(f"{e}")
        return []
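
`call_aoai` returns whatever `ast.literal_eval` parses, so a reply with extra text around the list, or a literal that is not a list, would need handling elsewhere. One possible stricter parser is sketched below. `parse_code_list` is a hypothetical helper, not part of the notebook, and it assumes the reply contains at most one bracketed list of quoted codes.

```python
import ast
import re


def parse_code_list(raw):
    """Return a list of code strings parsed from model output, or [] if none."""
    # Take the first [...] span so surrounding text such as "Answer = " is ignored.
    match = re.search(r"\[.*?\]", raw, flags=re.DOTALL)
    if match is None:
        return []
    try:
        value = ast.literal_eval(match.group(0))
    except (ValueError, SyntaxError):
        return []
    if not isinstance(value, list):
        return []
    # Keep only non-empty strings, stripped of surrounding whitespace.
    return [item.strip() for item in value if isinstance(item, str) and item.strip()]


print(parse_code_list("Answer = ['010-018', '042']"))  # ['010-018', '042']
print(parse_code_list("I am not sure."))               # []
```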

In [6]:
# Helper Functions to Build Prompt Dynamically

def get_options(tree, parent_code):
    children = tree.find(parent_code).children
    options = []
    for child in children:
        options.append(f"{child.code}: {child.description}")
    
    return '|'.join(options)

def build_prompt(note, categories):
    sys = """
    You are a medical expert. Your job is to classify notes of an event into one or more categories. ACCURACY is VERY IMPORTANT to your job.
    Choose the best option(s) based on the categories offered. ALWAYS return at least one index. ONLY choose from categories listed. 
    Respond with a list of quoted string indices of the categories the note belongs to.
    Think through your answer. 
    
    ### EXAMPLE ###
    Categories = 001-009: INTESTINAL INFECTIOUS DISEASES |010-018: TUBERCULOSIS |020-027: ZOONOTIC BACTERIAL DISEASES |030-041: OTHER BACTERIAL DISEASES |042: Human immunodeficiency virus [HIV] disease|045-049: POLIOMYELITIS AND OTHER NON-ARTHROPOD-BORNE VIRAL DISEASES AND PRION DISEASES OF CENTRAL NERVOUS SYSTEM |050-059: VIRAL DISEASES ACCOMPANIED BY EXANTHEM |060-066: ARTHROPOD-BORNE VIRAL DISEASES |070-079: OTHER DISEASES DUE TO VIRUSES AND CHLAMYDIAE |080-088: RICKETTSIOSES AND OTHER ARTHROPOD-BORNE DISEASES |090-099: SYPHILIS AND OTHER VENEREAL DISEASES |100-104: OTHER SPIROCHETAL DISEASES |110-118: MYCOSES |120-129: HELMINTHIASES |130-136: OTHER INFECTIOUS AND PARASITIC DISEASES |137-139: LATE EFFECTS OF INFECTIOUS AND PARASITIC DISEASES 
    Note = Tuberculosis of the bones and joints and HIV
    Answer = ['010-018', '042']
    ## END EXAMPLE ##
    """
    
    
    prompt = f"""
    Categories = {categories}
    Note = {note}
    Answer: 
    """

    return sys, prompt

In [7]:
# Recursive Walk of tree and call aoai to get codes

def get_codes_for_note(tree, note, level, parent_code):
    
    categories = get_options(tree, parent_code)
    sys, prompt = build_prompt(note, categories)

    
    codes = call_aoai(sys, prompt)
    if '0' in codes:
        codes.remove('0')
    
    print(f"parent: {parent_code} prompt = {prompt}")
    print(f"CODES: {codes}")
    ready_codes = [item for item in codes if len(item) == level]
    incomplete_codes = [item for item in codes if len(item) != level]
    print(f"Ready Codes: {ready_codes} ; Incomplete Codes: {incomplete_codes}")
    print("**********")
    
    if codes == [] or codes == ['']:
        return ['X'*level]
    elif incomplete_codes:
        return ready_codes + list(map(functools.partial(get_codes_for_note, tree, note, level), incomplete_codes))
    else:
        return codes
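
The walk passes every returned code back into the tree lookup, so a code the model invents may not correspond to any node. A possible guard, sketched below, drops anything that was not among the offered options before recursing. `keep_offered_codes` is a hypothetical helper, and it assumes the `code: description|code: description` format produced by `get_options`.

```python
def keep_offered_codes(codes, categories):
    """Drop any returned code that was not among the offered options."""
    # Each option looks like 'code: description'; keep the part before the first colon.
    offered = {option.split(":", 1)[0].strip() for option in categories.split("|")}
    return [code for code in codes if code in offered]


categories = (
    "001-009: INTESTINAL INFECTIOUS DISEASES |010-018: TUBERCULOSIS "
    "|042: Human immunodeficiency virus [HIV] disease"
)
filtered = keep_offered_codes(['010-018', '042', '999'], categories)
print(filtered)  # ['010-018', '042']
```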

---
#### Get ICD9 Codes

In [None]:
### SIMPLE TEST ###

res = flatten(get_codes_for_note(tree, "Tuberculosis of the bones and joints and HIV", 3, "ROOT"))
print(list(set(res)))

#### END SIMPLE TEST ###

In [None]:
df['Generated'] = ""  # placeholder column, filled in by the loop below
for index, row in df.iterrows():

    # Parse Note
    note = ast.literal_eval(row['TEXT'])[0]
    # print(f"Note: {note}")
    # Get Codes
    result = list(set(flatten(get_codes_for_note(tree, note, 3, "ROOT")))) # Change level here if needed

    # Add result to DF
    df.at[index, 'Generated'] = str(result)

In [None]:
# View Results

display(df[['ICD9_CODE', 'Generated']].head(10))

---
#### Score Results

In [11]:
# Scoring Functions

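def recall_score(truth, generated):
    actual_list = ast.literal_eval(truth)
    generated_list = ast.literal_eval(generated)

    # Guard against notes with no ground-truth codes (mirrors precision_score)
    if len(actual_list) == 0:
        return 0

    similar = len(set(actual_list) & set(generated_list))

    return similar / len(actual_list)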

def precision_score(truth, generated):
    actual_list = ast.literal_eval(truth)
    generated_list = ast.literal_eval(generated)

    if len(generated_list) == 0:
        return 0

    similar = len(set(actual_list) & set(generated_list))

    return similar / len(generated_list)

def f1_score(truth, generated):
    precision = precision_score(truth, generated)
    recall = recall_score(truth, generated)

    if precision + recall == 0:
        return 0
    else:
        return 2 * (precision * recall) / (precision + recall)


# Note Parsing Functions
def format_icd9(x):
    new_codes = []
    code_list = ast.literal_eval(x)
    for code in code_list:
        new_codes.append(f"{code:0>3}")  # left-pad to three characters, e.g. 42 -> '042'

    return str(new_codes)
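
As a worked example of the metrics above, the sketch below scores one hypothetical note whose true codes and generated codes are made up for illustration. It works on plain lists and uses set sizes as denominators, which differs slightly from the functions above if a list contains duplicates.

```python
def score(truth, generated):
    """Set-based precision, recall, and F1 for one note."""
    truth_set, generated_set = set(truth), set(generated)
    overlap = len(truth_set & generated_set)
    precision = overlap / len(generated_set) if generated_set else 0.0
    recall = overlap / len(truth_set) if truth_set else 0.0
    total = precision + recall
    f1 = 2 * precision * recall / total if total else 0.0
    return precision, recall, f1


# Hypothetical codes: 2 of 4 generated codes are correct, 2 of 3 true codes are found.
precision, recall, f1 = score(["010", "042", "486"], ["010", "042", "038", "599"])
print(f"precision={precision:.2f} recall={recall:.2f} f1={f1:.2f}")
# precision=0.50 recall=0.67 f1=0.57
```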

In [None]:
# Grade final code outputs

results = pd.DataFrame()

# Score against the zero-padded truth codes so they are comparable with the generated ones
results['ICD9_CODE'] = df['ICD9_CODE'].apply(format_icd9)
results['Generated'] = df['Generated']
results['Recall'] = results.apply(lambda x: recall_score(x['ICD9_CODE'], x['Generated']), axis=1)
results['Precision'] = results.apply(lambda x: precision_score(x['ICD9_CODE'], x['Generated']), axis=1)
results['F1 Score'] = results.apply(lambda x: f1_score(x['ICD9_CODE'], x['Generated']), axis=1)

display(results[['Recall', 'Precision', 'F1 Score']].mean(axis=0))

---
#### Limitations

- The gpt-4o-mini model was used for demo purposes to save on cost; however, this model is not great at generalizing 'up' the ICD-9 code tree. Predictions at the Chapter (L1) and Block (L2) levels are poor, which leads to poor final results.
    - Check out [the fine-tuning modules](./04a_aoai_ft_data_prep.ipynb) in this repo for techniques to address this problem.