## Fine Tuning for Medical Coding  
#### Part 1: Data Preparation  

---

**Goal for this Notebook**
- Prepare a dataset to fine-tune a model for L1 (chapter) level classification. Fine-tuning at this 'higher' level avoids the challenges associated with the long-tail problem. This exercise fine-tunes a model for multi-label classification with 17 label options.
  
<small>[_Click here for a complete list of ICD9 Chapters_](https://en.wikipedia.org/wiki/List_of_ICD-9_codes)</small>

**Approach**
  
The dataset is built by walking the ICD9 code tree to collect descriptions for each chapter classification. A single chapter yields many training rows: its own ICD code name, the names of its child codes, and supplemental information from UMLS for each of those names.
<small>
```markdown
--> Chapter name / description  
    --> UMLS concept atoms  
    --> UMLS concept definitions  
    --> Children, grandchildren, etc node names / descriptions  
        --> UMLS concept atoms  
        --> UMLS concept definitions  
```
</small>  
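
The expansion outlined above can be sketched as a depth-limited walk that labels every descendant description with its chapter. This is a minimal sketch, not the notebook's implementation: `Node` is a hypothetical stand-in for the nodes returned by `src.icd9_tree`, `collect_rows` is an illustrative helper, and the toy descriptions are placeholders. The UMLS enrichment step is omitted.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    """Hypothetical stand-in for an ICD9 tree node (code, description, children)."""
    code: str
    description: str
    children: list = field(default_factory=list)

def collect_rows(node, chapter, max_depth, depth=0):
    """Return one row per node down to max_depth, all labeled with the chapter."""
    rows = [{"description": node.description.strip(), "chapter": chapter}]
    if depth < max_depth:
        for child in node.children:
            rows.extend(collect_rows(child, chapter, max_depth, depth + 1))
    return rows

# Toy tree: one chapter, one block, one category (illustrative descriptions only).
leaf = Node("001", "CHOLERA")
block = Node("001-009", "INTESTINAL INFECTIOUS DISEASES", [leaf])
chapter = Node("001-139", "INFECTIOUS AND PARASITIC DISEASES", [block])

rows = collect_rows(chapter, chapter.description, max_depth=2)
for row in rows:
    print(row)
```

With `max_depth=2` the walk covers the chapter plus two levels below it, which corresponds to the L1 to L3 range this notebook trains on.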

**Data**

The dataset must be written in JSON Lines format, with one chat example (system, user, and assistant messages) per line:
<small>
```json
{"messages": [{"role": "system", "content": "<SYSTEM MSG>"}, {"role": "user", "content": "<PROMPT>"}, {"role": "assistant", "content": "<CODE>"}]}
{"messages": [{"role": "system", "content": "<SYSTEM MSG>"}, {"role": "user", "content": "<PROMPT>"}, {"role": "assistant", "content": "<CODE>"}]}
{"messages": [{"role": "system", "content": "<SYSTEM MSG>"}, {"role": "user", "content": "<PROMPT>"}, {"role": "assistant", "content": "<CODE>"}]}
```
</small>
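
As a rough illustration of that layout, one record can be assembled as a plain dictionary and serialized with the standard `json` module. The helper names `to_chat_record` and `to_jsonl` are made up for this sketch; the notebook itself builds the same structure with pandas further down.

```python
import json

def to_chat_record(system_msg, prompt, label):
    """Wrap one training example in the chat-style 'messages' structure."""
    return {"messages": [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": label},
    ]}

def to_jsonl(records):
    """Serialize records as JSON Lines: one JSON object per line."""
    return "\n".join(json.dumps(r, ensure_ascii=False) for r in records)

records = [to_chat_record("<SYSTEM MSG>", "<PROMPT>", "<CODE>")]
jsonl_text = to_jsonl(records)
print(jsonl_text)
```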

In [166]:
from azure.ai.textanalytics import TextAnalyticsClient
from azure.core.credentials import AzureKeyCredential
from src.icd9_tree import ICD9
from dotenv import load_dotenv, find_dotenv
from textwrap import dedent

import pandas as pd
import numpy as np
import random
import requests
import os
import re

load_dotenv(find_dotenv(), override=True)
pd.set_option('display.max_colwidth', None)

In [2]:
# Authenticate the client using your key and endpoint
key = os.getenv("LANGUAGE_KEY")
endpoint = os.getenv("LANGUAGE_ENDPOINT")

ta_credential = AzureKeyCredential(key)
client = TextAnalyticsClient(
        endpoint=endpoint, 
        credential=ta_credential)

---
#### Setup Code Tree

In [None]:
# Read ICD9 codes in as a tree and view top level. These 'Chapter' codes will be the labels for our fine tuned model.

tree = ICD9('icd9_codes_full.json')
chapter_codes = []
# list of top level codes (e.g., '001-139', ...)
toplevelnodes = tree.children
for node in toplevelnodes:
    if node.code[0] not in ['E', 'V']:
        print(node.code, node.description)
        chapter_codes.append(node.description)

---
#### Establish Helper Functions

In [None]:
# Get the code for a chapter (or one of its direct children) from its description
def get_chapter_code(description):
    target = description.strip()
    for node in tree.children:
        if node.description.strip() == target:
            return node.code
        for child in node.children:
            if child.description.strip() == target:
                return child.code
    # No chapter or direct child matched the description
    return None

print(get_chapter_code('COMPLICATIONS OF PREGNANCY, CHILDBIRTH, AND THE PUERPERIUM'))

In [10]:
# Function to get the UMLS CUID(s) for a given text
# This function uses Azure Text Analytics for Health

def get_umls_concepts(client, documents):
    umls_concepts = []
    poller = client.begin_analyze_healthcare_entities(documents)
    result = poller.result()

    docs = [doc for doc in result if not doc.is_error]

    for doc in docs:
        for entity in doc.entities:
            # Keep only symptom and diagnosis entities that carry linked data sources
            if entity.data_sources and entity.category in ['SymptomOrSign', 'Diagnosis']:
                for data_source in entity.data_sources:
                    if data_source.name == "UMLS":
                        umls_concepts.append(data_source.entity_id)

    return umls_concepts

In [11]:
# Function to get the UMLS atoms from a cuid

def get_umls_atoms(cuid):
    synonyms = []
    sabs = ['ICD10', 'ICD10CM', 'ICD9CM', 'SNOMEDCT_US', 'MDR']
    atom_uri = f"https://uts-ws.nlm.nih.gov/rest/content/2024AA/CUI/{cuid}/atoms"
    page = 0
    try:
        while True:
            page += 1
            atom_query = {'apiKey': os.getenv("UMLS_API_KEY"), 'pageNumber': page, 'language': 'ENG', 'sabs': ','.join(sabs)}
            a = requests.get(atom_uri, params=atom_query)
            a.encoding = 'utf-8'

            # Stop paging on any non-200 response (for example, past the last page)
            if a.status_code != 200:
                break

            all_atoms = a.json().get('result', [])
            if not all_atoms:
                break

            for atom in all_atoms:
                # Drop parenthesized or bracketed qualifiers, then normalize case
                synonyms.append(re.sub(r"[\(\[].*?[\)\]]", "", atom['name']).lower().rstrip())

    except Exception as except_error:
        print(except_error)

    # Return whatever was collected, de-duplicated (empty list if nothing was found)
    return list(set(synonyms))

In [12]:
# Function to get UMLS definition list from a cuid

def umls_define(cuid):
    definitions = []
    umls_uri = f"https://uts-ws.nlm.nih.gov/rest/content/current/CUI/{cuid}/definitions"
    root_sources = ['CSP', 'NCI', 'MSH', 'PDQ', 'MTH', 'HPO', 'DXP', 'SNMI', 'SNOMEDCT_US', 'ICD10CM', 'ICD10', 'ICD9CM', 'MDR']
    page = 0
    try:
        while True:
            page += 1
            query = {'apiKey': os.getenv("UMLS_API_KEY"), 'pageNumber': page}
            a = requests.get(umls_uri, params=query)
            a.encoding = 'utf-8'

            # Stop paging on any non-200 response (for example, past the last page)
            if a.status_code != 200:
                break

            result = a.json().get('result', [])
            if not result:
                break

            for value in result:
                # Keep definitions only from the selected source vocabularies
                if value['rootSource'] in root_sources:
                    definitions.append(value['value'].lower().rstrip())

    except Exception as except_error:
        print(except_error)

    # Return whatever was collected, de-duplicated (empty list if nothing was found)
    return list(set(definitions))

In [None]:
# generate pd dataset

def generate_dataset(description, chapter, az_ta_cli, dataset_list):
    dataset_list.append({'description': description, 'chapter': chapter})

    umls_concepts = get_umls_concepts(az_ta_cli, [description])
    for cuid in umls_concepts:
        
        atoms = get_umls_atoms(cuid)
        if atoms:
            for atom in atoms:
                dataset_list.append({'description': atom, 'chapter': chapter})

        definitions = umls_define(cuid)
        if definitions:
            for definition in definitions:
                dataset_list.append({'description': definition, 'chapter': chapter})
    return

# Test
test_list = []
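# The chapter description doubles as the label, as in the dataset build below
test_description = 'COMPLICATIONS OF PREGNANCY, CHILDBIRTH, AND THE PUERPERIUM'
generate_dataset(test_description, test_description, client, test_list)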
print(test_list)

---
#### Create Fine Tuning Training Dataset

In [190]:
# Build the dataset 
# NOTE: This may take a while (15 min per 1500 samples)
# TODO: Make this more efficient

ft_df_list = []

for L1_node in tree.children:
    if L1_node.code[0] not in ['E', 'V']:
        # print(f"L1: {L1_node.code} - {L1_node.description}")
        generate_dataset(L1_node.description, L1_node.description, client, ft_df_list)
        for L2_node in L1_node.children:
            # print(f"L2: {L2_node.code} - {L2_node.description}")
            generate_dataset(L2_node.description, L1_node.description, client, ft_df_list)
            for L3_node in L2_node.children:
                # print(f"L3: {L3_node.code} - {L3_node.description}")
                generate_dataset(L3_node.description, L1_node.description, client, ft_df_list)
                for L4_node in L3_node.children:
                    # print(f"L4: {L4_node.code} - {L4_node.description}")
                    # generate_dataset(L4_node.description, L1_node.description, client, ft_df_list)
                    for L5_node in L4_node.children:
                        # print(f"L5: {L5_node.code} - {L5_node.description}")
                        # generate_dataset(L5_node.description, L1_node.description, client, ft_df_list)
                        pass
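
One possible way to address the efficiency TODO above is to memoize the per-concept lookups, since the same UMLS concept can be linked from many node descriptions. The sketch below uses `functools.lru_cache` around a hypothetical stand-in function; `fetch_terms`, `cached_terms`, and the CUI strings are made up for illustration and no network call is made. Whether this helps in practice depends on how often concepts repeat across the tree, which the notebook does not measure.

```python
from functools import lru_cache

call_log = []  # records which CUIs actually triggered a (simulated) remote lookup

def fetch_terms(cuid):
    """Hypothetical stand-in for a slow remote lookup such as get_umls_atoms."""
    call_log.append(cuid)
    return [f"term for {cuid}"]

@lru_cache(maxsize=None)
def cached_terms(cuid):
    # Return a tuple so callers cannot mutate the cached value.
    return tuple(fetch_terms(cuid))

# The first CUI appears twice but is only fetched once.
for cuid in ["C0000001", "C0000002", "C0000001"]:
    cached_terms(cuid)

print(call_log)  # ['C0000001', 'C0000002']
```

The same wrapper pattern could be applied to `get_umls_atoms` and `umls_define`, provided their results are converted to tuples before caching.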

In [None]:
# Load to dataframe and examine data
ft_df = pd.DataFrame(ft_df_list)
ft_df.chapter = ft_df.chapter.apply(lambda x: x.strip())
print(ft_df.shape)
print(ft_df.dtypes)

In [214]:
# Helper for multi-label examples: each new row joins one random description from each of
# `code_count` randomly chosen chapters. It is called with label counts from 2 to 12,
# weighted toward 6 labels per example.

def multi_sample(code_count, sample_count):
    new_rows = []
    for _ in range(sample_count):
        code_samples = list(map(str.strip, random.sample(chapter_codes, code_count)))
        item = {'description': '', 'chapter': ';'.join(code_samples)}
        description_list = []
        for chapter in code_samples:
            # Multi-label rows never match a single chapter name, so only original rows are sampled
            sample = ft_df[ft_df['chapter']==chapter].sample(1)
            description_list.append(sample['description'].values[0])

        item['description'] = ','.join(description_list)
        new_rows.append(item)
    return new_rows


In [215]:
# Add multi-label examples to the dataframe
# TODO: Add data according to the distribution of the MIMIC-III dataset; this is a rough estimate

sample_multiplier = 100

ft_df = pd.concat([ft_df, pd.DataFrame(multi_sample(2,1*sample_multiplier))], ignore_index=True)
ft_df = pd.concat([ft_df, pd.DataFrame(multi_sample(3,2*sample_multiplier))], ignore_index=True)
ft_df = pd.concat([ft_df, pd.DataFrame(multi_sample(4,3*sample_multiplier))], ignore_index=True)
ft_df = pd.concat([ft_df, pd.DataFrame(multi_sample(5,5*sample_multiplier))], ignore_index=True)
ft_df = pd.concat([ft_df, pd.DataFrame(multi_sample(6,10*sample_multiplier))], ignore_index=True)
ft_df = pd.concat([ft_df, pd.DataFrame(multi_sample(7,6*sample_multiplier))], ignore_index=True)
ft_df = pd.concat([ft_df, pd.DataFrame(multi_sample(8,5*sample_multiplier))], ignore_index=True)
ft_df = pd.concat([ft_df, pd.DataFrame(multi_sample(9,4*sample_multiplier))], ignore_index=True)
ft_df = pd.concat([ft_df, pd.DataFrame(multi_sample(10,3*sample_multiplier))], ignore_index=True)
ft_df = pd.concat([ft_df, pd.DataFrame(multi_sample(11,2*sample_multiplier))], ignore_index=True)
ft_df = pd.concat([ft_df, pd.DataFrame(multi_sample(12,1*sample_multiplier))], ignore_index=True)
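
The weights passed to `multi_sample` above can also be written as a single schedule, which makes the shape of the distribution easier to inspect. This is a sketch only: the `schedule` mapping copies the weights from the calls above, while `sample_label_sets`, the seed, and the placeholder chapter names are assumptions introduced for illustration. It draws label sets without touching the dataframe.

```python
import random

# Label count -> weight, copied from the multi_sample calls above.
schedule = {2: 1, 3: 2, 4: 3, 5: 5, 6: 10, 7: 6, 8: 5, 9: 4, 10: 3, 11: 2, 12: 1}

def sample_label_sets(labels, schedule, multiplier, seed=0):
    """Draw weight * multiplier label sets for each label count in the schedule."""
    rng = random.Random(seed)  # local generator so the sketch is reproducible
    label_sets = []
    for code_count, weight in schedule.items():
        for _ in range(weight * multiplier):
            label_sets.append(rng.sample(labels, code_count))
    return label_sets

labels = [f"CHAPTER_{i}" for i in range(17)]  # placeholder names for the 17 chapters
label_sets = sample_label_sets(labels, schedule, multiplier=1)
print(len(label_sets))  # 42
```

The weights sum to 42, so `sample_multiplier = 100` adds 4,200 multi-label rows in total.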

In [None]:
print(ft_df.shape)
display(ft_df.sample(3))

In [None]:
# define system prompt
# system prompt will be a constant in all examples
sys = 'Classify the following text into an ICD9 code chapter. The text is a clinical note from a patient medical record. ### You must choose from the following semi-colon delimited list of codes:{0} ### RESPOND ONLY WITH A CODE FROM THE LIST ABOVE.'.format('; '.join(chapter_codes))
print(dedent(sys))

In [None]:
# apply formatting to each row
ft_df["chapter"] = ft_df.chapter.apply(lambda x: {"role": "assistant", "content": x})
ft_df["description"] = ft_df.description.apply(lambda x: {"role": "user", "content": x})
ft_df['sys'] = sys
ft_df["sys"] = ft_df.sys.apply(lambda x: {"role": "system", "content": x})
ft_df = ft_df.reindex(columns=['sys', 'description', 'chapter'])

# Wrap each row's three message dicts in a single {"messages": [...]} record
out_df = ft_df.apply(lambda x: {"messages": x.values}, axis=1)

display(out_df.head(10))

In [219]:
# write to file
output_file_name = "data/ft/training_data_L1toL3_multi.jsonl"
out_df.to_json(output_file_name, orient="records", lines=True)

---
#### Examine the Dataset

In [None]:
df = pd.read_json(output_file_name, lines=True)
print(df.shape)

In [None]:
# confirm value counts
# messages are ordered [system, user, assistant]; the label is the assistant content
df['code'] = df['messages'].apply(lambda x: x[2]['content'])
print(df['code'].value_counts())
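
Beyond value counts, it may be worth checking that every line of the file has the expected message structure before submitting it for fine-tuning. The following is a minimal sketch that validates JSON Lines text held in memory; `validate_jsonl` and `EXPECTED_ROLES` are illustrative names, and the checks cover only the role order described in the Data section, not any service-specific requirements.

```python
import json

EXPECTED_ROLES = ["system", "user", "assistant"]

def validate_jsonl(text):
    """Return (line_number, problem) tuples; an empty list means every line passed."""
    problems = []
    for line_no, line in enumerate(text.splitlines(), start=1):
        if not line.strip():
            continue  # ignore blank lines
        try:
            record = json.loads(line)
        except json.JSONDecodeError as err:
            problems.append((line_no, f"invalid JSON: {err.msg}"))
            continue
        messages = record.get("messages") if isinstance(record, dict) else None
        if not isinstance(messages, list):
            problems.append((line_no, "missing 'messages' list"))
            continue
        roles = [m.get("role") if isinstance(m, dict) else None for m in messages]
        if roles != EXPECTED_ROLES:
            problems.append((line_no, f"unexpected roles: {roles}"))
    return problems

good = json.dumps({"messages": [
    {"role": "system", "content": "<SYSTEM MSG>"},
    {"role": "user", "content": "<PROMPT>"},
    {"role": "assistant", "content": "<CODE>"},
]})
bad = json.dumps({"messages": [{"role": "user", "content": "<PROMPT>"}]})

problems = validate_jsonl(good + "\n" + bad)
print(problems)  # [(2, "unexpected roles: ['user']")]
```

To check the written file, its contents could be read with `open(output_file_name, encoding="utf-8").read()` and passed to the same function.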