## Fine Tuning for Medical Coding  
#### Part 1: Data Preparation  

---

**Goal for this Notebook**
- Prepare a dataset for fine-tuning a model on L1 (chapter-level) classification. Fine-tuning at this higher level avoids the challenges associated with the long-tail problem. This exercise fine-tunes a model for multi-class classification with 17 label options.
  
<small>[_Click here for a complete list of ICD9 Chapters_](https://en.wikipedia.org/wiki/List_of_ICD-9_codes)</small>

**Approach**
  
The dataset is built from the ICD9 code tree, which supplies the text used to describe each chapter. A single chapter yields many training rows: the chapter name, the names of its descendant codes, and supplemental information from UMLS for each of them.
<small>
```markdown
--> Chapter name / description  
    --> UMLS concept atoms  
    --> UMLS concept definitions  
    --> Children, grandchildren, etc node names / descriptions  
        --> UMLS concept atoms  
        --> UMLS concept definitions  
```
</small>  
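
The expansion outlined above can be illustrated on a toy tree. This is a minimal sketch, not the notebook's implementation: the `Node` dataclass, its `synonyms` field (a placeholder for UMLS atoms and definitions), and the `expand_rows` helper are hypothetical, and the synonym text is made up for illustration.

```python
from dataclasses import dataclass, field
from typing import Iterator, List, Tuple


@dataclass
class Node:
    """Hypothetical stand-in for an ICD9 tree node."""
    code: str
    description: str
    synonyms: List[str] = field(default_factory=list)  # placeholder for UMLS atoms/definitions
    children: List["Node"] = field(default_factory=list)


def expand_rows(node: Node, chapter_code: str) -> Iterator[Tuple[str, str]]:
    """Yield (text, chapter_code) pairs for a node and all of its descendants."""
    yield node.description, chapter_code
    for text in node.synonyms:
        yield text, chapter_code
    for child in node.children:
        # Every descendant is labeled with the chapter code, not its own code
        yield from expand_rows(child, chapter_code)


chapter = Node(
    "001-139", "Infectious And Parasitic Diseases",
    children=[
        Node("001-009", "Intestinal Infectious Diseases",
             children=[Node("001", "Cholera", synonyms=["cholera infection"])]),
    ],
)

rows = list(expand_rows(chapter, chapter.code))
for text, label in rows:
    print(f"{label}\t{text}")
```

Each printed row pairs one piece of text with the same chapter label, which is what lets a single chapter contribute many training examples.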

**Data**

The dataset must be in JSON Lines (JSONL) format, with one JSON object per line, as follows:
<small>
```json
{"messages": [{"role": "system", "content": "<SYSTEM MSG>"}, {"role": "user", "content": "<PROMPT>"}, {"role": "assistant", "content": "<CODE>"}]}
{"messages": [{"role": "system", "content": "<SYSTEM MSG>"}, {"role": "user", "content": "<PROMPT>"}, {"role": "assistant", "content": "<CODE>"}]}
{"messages": [{"role": "system", "content": "<SYSTEM MSG>"}, {"role": "user", "content": "<PROMPT>"}, {"role": "assistant", "content": "<CODE>"}]}
```
</small>
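
As a rough illustration of the format above, one row can be built and checked with the standard `json` module. This is a sketch under stated assumptions: `make_record` and `is_valid_line` are hypothetical helper names, and the sample prompt and code are placeholders.

```python
import json


def make_record(system_msg: str, prompt: str, code: str) -> str:
    """Serialize one training example as a single JSON Lines row."""
    record = {
        "messages": [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": code},
        ]
    }
    # json.dumps escapes newlines inside strings, so each record stays on one line
    return json.dumps(record)


def is_valid_line(line: str) -> bool:
    """Check that a line parses and has the system/user/assistant role order."""
    try:
        messages = json.loads(line)["messages"]
        roles = [m["role"] for m in messages]
    except (ValueError, KeyError, TypeError):
        return False
    return roles == ["system", "user", "assistant"]


line = make_record("<SYSTEM MSG>", "cholera", "001-139")
print(line)
print(is_valid_line(line))
```

Running a check like `is_valid_line` over the finished file is a cheap way to catch malformed rows before submitting a fine-tuning job.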

In [None]:
from azure.ai.textanalytics import TextAnalyticsClient
from azure.core.credentials import AzureKeyCredential
from src.icd9_tree import ICD9
from dotenv import load_dotenv, find_dotenv
from textwrap import dedent

import pandas as pd
import json
import requests
import os
import re

load_dotenv(find_dotenv(), override=True)

In [4]:
# Authenticate the Text Analytics client using your key and endpoint
key = os.getenv("LANGUAGE_KEY")
endpoint = os.getenv("LANGUAGE_ENDPOINT")

ta_credential = AzureKeyCredential(key)
client = TextAnalyticsClient(
        endpoint=endpoint, 
        credential=ta_credential)

---
#### Setup Code Tree

In [None]:
# Read ICD9 codes in as a tree and view top level. These 'Chapter' codes will be the labels for our fine tuned model.

tree = ICD9('icd9_codes_full.json')
chapter_codes = []
# list of top level codes (e.g., '001-139', ...)
toplevelnodes = tree.children
for node in toplevelnodes:
    if node.code[0] not in ['E', 'V']:
        print(node.code, node.description)
        chapter_codes.append(node.code)

---
#### Establish Helper Functions

In [6]:
# Function to get the UMLS CUI(s) (Concept Unique Identifiers) for a given text
# This function uses Azure Text Analytics for Health

def get_umls_concepts(client, documents):
    umls_concepts = []
    poller = client.begin_analyze_healthcare_entities(documents)
    result = poller.result()

    # Skip documents the service could not process
    docs = [doc for doc in result if not doc.is_error]

    for doc in docs:
        for entity in doc.entities:
            # Keep only symptom/diagnosis entities that are linked to external vocabularies
            if entity.data_sources and entity.category in ['SymptomOrSign', 'Diagnosis']:
                for data_source in entity.data_sources:
                    if data_source.name == "UMLS":
                        umls_concepts.append(data_source.entity_id)

    return umls_concepts

In [7]:
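# Function to get the UMLS atoms (synonym names) for a CUI

def get_umls_atoms(cuid):
    synonyms = []
    sabs = ['ICD10', 'ICD10CM', 'ICD9CM', 'SNOMEDCT_US', 'MDR']
    atom_uri = f"https://uts-ws.nlm.nih.gov/rest/content/2024AA/CUI/{cuid}/atoms"
    page = 0
    try:
        while True:
            page += 1
            atom_query = {'apiKey': os.getenv("UMLS_API_KEY"), 'pageNumber': page, 'language': 'ENG', 'sabs': ','.join(sabs)}
            a = requests.get(atom_uri, params=atom_query)
            a.encoding = 'utf-8'

            # A non-200 response means there are no (more) atoms for this CUI
            if a.status_code != 200:
                break

            all_atoms = a.json()

            for atom in all_atoms['result']:
                # Strip parenthesized or bracketed qualifiers, e.g. "(disorder)"
                synonyms.append(re.sub(r"[\(\[].*?[\)\]]", "", atom['name']).lower().rstrip())

            # Stop after the last reported page (assumes the response carries a
            # 'pageCount' field; if it is absent, stop after the first page)
            if page >= all_atoms.get('pageCount', page):
                break

    except Exception as except_error:
        print(except_error)
        return

    # De-duplicate the names collected across all pages
    return list(set(synonyms))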
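
The name clean-up applied to each atom can be tried in isolation. The sketch below reuses the same regular expression as `get_umls_atoms`; the `clean_atom_name` helper and the sample names are illustrative assumptions, not output from the UMLS API.

```python
import re

# Same pattern as in get_umls_atoms: a '(' or '[' up to the nearest ')' or ']'
QUALIFIER = re.compile(r"[\(\[].*?[\)\]]")


def clean_atom_name(name: str) -> str:
    """Drop bracketed qualifiers, lowercase, and trim trailing whitespace."""
    return QUALIFIER.sub("", name).lower().rstrip()


# Illustrative atom names that differ only by qualifier or case
names = ["Cholera (disorder)", "Cholera [Ambiguous]", "CHOLERA"]
cleaned = sorted({clean_atom_name(n) for n in names})
print(cleaned)  # ['cholera']
```

Normalizing before de-duplication is what keeps near-identical synonyms from inflating the training set.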

In [8]:
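# Function to get the list of UMLS definitions for a CUI

def umls_define(cuid):
    definitions = []
    umls_uri = f"https://uts-ws.nlm.nih.gov/rest/content/current/CUI/{cuid}/definitions"
    root_sources = ['CSP','NCI','MSH','PDQ', 'MTH', 'HPO', 'DXP', 'SNMI', 'SNOMEDCT_US', 'ICD10CM', 'ICD10', 'ICD9CM', 'MDR']
    page = 0
    try:
        while True:
            page += 1
            query = {'apiKey': os.getenv("UMLS_API_KEY"), 'pageNumber': page}
            a = requests.get(umls_uri, params=query)
            a.encoding = 'utf-8'

            # A non-200 response means there are no (more) definitions for this CUI
            if a.status_code != 200:
                break
            result = a.json()

            for value in result['result']:
                # Keep definitions only from the selected source vocabularies
                if value['rootSource'] in root_sources:
                    definitions.append(value['value'].lower().rstrip())

            # Stop after the last reported page (assumes the response carries a
            # 'pageCount' field; if it is absent, stop after the first page)
            if page >= result.get('pageCount', page):
                break

    except Exception as except_error:
        print(except_error)
        return

    # De-duplicate the definitions collected across all pages
    return list(set(definitions))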

In [None]:
# Combine the above functions to write one JSON line for the description and for each atom/definition of a given ICD9 code

def generate_dataset(system_msg, description, l1_code, az_ta_cli, training_file):
    # Append mode, so rows accumulate across calls; the with-block closes the file
    with open(training_file, "a", encoding="utf-8") as file:

        # Row 1: the node description itself
        desc_data = {"messages":[ {"role": "system", "content": system_msg}, {"role": "user", "content": description}, {"role": "assistant", "content": l1_code} ] }
        file.write(json.dumps(desc_data) + "\n")

        umls_concepts = get_umls_concepts(az_ta_cli, [description])
        for cuid in umls_concepts:

            # One row per UMLS atom (None or empty results are skipped)
            atoms = get_umls_atoms(cuid)
            if atoms:
                for atom in atoms:
                    atom_data = {"messages":[ {"role": "system", "content": system_msg}, {"role": "user", "content": atom}, {"role": "assistant", "content": l1_code} ] }
                    file.write(json.dumps(atom_data) + "\n")

            # One row per UMLS definition (None or empty results are skipped)
            definitions = umls_define(cuid)
            if definitions:
                for definition in definitions:
                    def_data = {"messages":[ {"role": "system", "content": system_msg}, {"role": "user", "content": definition}, {"role": "assistant", "content": l1_code} ] }
                    file.write(json.dumps(def_data) + "\n")


### Test ###
# open('data/ft/training_data.jsonl', 'w').close()
# generate_dataset("System", "Diabetes", "001-139", client, "data/ft/training_data.jsonl")

---
#### Create Fine Tuning Training Dataset

In [None]:
# system prompt will be a constant in all examples

sys = 'Classify the following text into an ICD9 code chapter. The text is a clinical note from a patient medical record. ### You must choose from the following semi-colon delimited list of codes:{0} ### RESPOND ONLY WITH A CODE FROM THE LIST ABOVE.'.format('; '.join(chapter_codes))
print(dedent(sys))

In [None]:
# Build the dataset - woohoo! 
# (NOTE: This may take a while)
# TODO: Make this more efficient

output_file = "data/ft/training_data.jsonl"
open(output_file, 'w').close()

for L1_node in tree.children:
    if L1_node.code[0] not in ['E', 'V']:
        # Get all json lines at the L1 level
        # print(f"L1: {L1_node.code} - {L1_node.description}")
        generate_dataset(sys, L1_node.description, L1_node.code, client, output_file)
        for L2_node in L1_node.children:
            # print(f"L2: {L2_node.code} - {L2_node.description}")
            generate_dataset(sys, L2_node.description, L1_node.code, client, output_file)
            for L3_node in L2_node.children:
                # print(f"L3: {L3_node.code} - {L3_node.description}")
                generate_dataset(sys, L3_node.description, L1_node.code, client, output_file)
                for L4_node in L3_node.children:
                    # print(f"L4: {L4_node.code} - {L4_node.description}")
                    # generate_dataset(sys, L4_node.description, L1_node.code, client, output_file)
                    for L5_node in L4_node.children:
                        # print(f"L5: {L5_node.code} - {L5_node.description}")
                        # generate_dataset(sys, L5_node.description, L1_node.code, client, output_file)
                        pass
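
Because the nested loops above only emit rows down to L3, the same walk could be written as a depth-limited recursive generator. This is a minimal sketch, assuming nodes expose `code`, `description`, and `children` as they do in the loops above; the `Node` class and the `walk` helper are hypothetical stand-ins and are not part of `src.icd9_tree`.

```python
from dataclasses import dataclass, field
from typing import Iterator, List


@dataclass
class Node:
    """Hypothetical stand-in exposing the attributes the loops above rely on."""
    code: str
    description: str
    children: List["Node"] = field(default_factory=list)


def walk(node: Node, max_depth: int, depth: int = 1) -> Iterator[Node]:
    """Yield a node and its descendants down to max_depth (the start node is depth 1)."""
    yield node
    if depth < max_depth:
        for child in node.children:
            yield from walk(child, max_depth, depth + 1)


l4 = Node("001.0", "Cholera due to vibrio cholerae")
l3 = Node("001", "Cholera", [l4])
l2 = Node("001-009", "Intestinal Infectious Diseases", [l3])
l1 = Node("001-139", "Infectious And Parasitic Diseases", [l2])

visited = [n.code for n in walk(l1, max_depth=3)]
print(visited)  # ['001-139', '001-009', '001']
```

In the notebook, each yielded node would be passed to `generate_dataset` with the chapter's code as the label, and changing `max_depth` would replace commenting loops in and out.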


---
#### Examine the Dataset

In [None]:
df = pd.read_json(output_file, lines=True)
display(df.head(2))

print(df.shape)

In [None]:
# Check the label distribution across chapters (no more long-tail problem)

df['code'] = df['messages'].apply(lambda x: x[2]['content'])
print(df['code'].value_counts())
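
The same label count can be reproduced without pandas, which is handy for a quick check outside the notebook. This is a sketch under stated assumptions: `label_counts` is a hypothetical helper, and the sample rows are made up for illustration rather than taken from the generated file. It looks the label up by role instead of by position in the `messages` list.

```python
import json
from collections import Counter
from typing import Iterable

# Illustrative rows in the training-file format (not real output)
sample_lines = [
    '{"messages": [{"role": "system", "content": "s"}, {"role": "user", "content": "cholera"}, {"role": "assistant", "content": "001-139"}]}',
    '{"messages": [{"role": "system", "content": "s"}, {"role": "user", "content": "typhoid fever"}, {"role": "assistant", "content": "001-139"}]}',
    '{"messages": [{"role": "system", "content": "s"}, {"role": "user", "content": "asthma"}, {"role": "assistant", "content": "460-519"}]}',
]


def label_counts(lines: Iterable[str]) -> Counter:
    """Count training rows per chapter label."""
    counts = Counter()
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        messages = json.loads(line)["messages"]
        # The assistant turn holds the chapter label
        label = next(m["content"] for m in messages if m["role"] == "assistant")
        counts[label] += 1
    return counts


counts = label_counts(sample_lines)
print(counts.most_common())
```

To run it against the real file, pass an open file handle, for example `label_counts(open(output_file, encoding="utf-8"))`.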