# LLM-Assisted Medical Code Extraction for Healthcare

Traditionally, healthcare providers have relied on various tools and methodologies to estimate patient risk scores. However, these conventional methods often fall short in addressing the complexity and variability inherent in patient data. 

In this notebook, we show how to implement a [distillation flow](https://www.datacamp.com/blog/distillation-llm): a large teacher model, `llama3.1-405b`, generates labeled training samples that teach a smaller model the ICD-10 code extraction task, aiming for comparable accuracy at lower cost.

LLM distillation replicates the performance of a large model on a specific task by transferring its capabilities to a smaller model. This lets developers approach the results of much larger models, such as GPT-4, at lower computational cost and with faster inference, but only for the targeted task.

We then fine-tune the smaller model, `llama3-8b`, with [Cortex Fine-tuning](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-finetuning) to extract ICD-10 codes.
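The distillation flow described above can be sketched in plain Python. This is a minimal illustration, not Cortex code: `teacher` is a hypothetical stand-in for a call to the large model, and `INSTRUCTION` mirrors the ICD-10 prompt used in the fine-tuning cell later in this notebook. The point it shows is that the student is trained on the teacher's completions, and the exact same prompt text has to be sent again at inference time.

```python
from typing import Callable, Dict, List

# Instruction appended to every document (mirrors the notebook's fine-tuning prompt).
INSTRUCTION = (
    " Given this medical transcript, list the unique major ICD10-CM diagnosis code "
    "in this format ONLY: X##.#. Please provide the response in the form of a list "
    "ONLY: []. DO NOT include any other text. "
)


def with_instruction(doc_text: str) -> str:
    """Build the prompt exactly as it is sent at training and at inference time."""
    return doc_text + INSTRUCTION


def distill_dataset(docs: List[str], teacher: Callable[[str], str]) -> List[Dict[str, str]]:
    """Label each document with the (hypothetical) teacher and return prompt/completion pairs."""
    records = []
    for doc in docs:
        prompt = with_instruction(doc)
        records.append({"prompt": prompt, "completion": teacher(prompt)})
    return records


def fake_teacher(prompt: str) -> str:
    """Placeholder for the large model; always returns the same label."""
    return "[E11.9]"


pairs = distill_dataset(["Patient with type 2 diabetes."], fake_teacher)
print(pairs[0]["completion"])  # [E11.9]
```

In the notebook itself, the `LLAMA_OUTPUT_ICD` table plays the role of `records`, and the `PROMPT`/`COMPLETION` columns of the fine-tuning query play the role of each dictionary's keys.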

# Import packages

In [None]:
# Import python packages
import streamlit as st
import pandas as pd

# We can also use Snowpark for our analyses!
from snowflake.snowpark.context import get_active_session
session = get_active_session()

# Data Engineering

Let's take a look at the synthetic medical files (PDFs) we'll be using for this demo.

In [None]:
SELECT * FROM DIRECTORY(@PUBLIC.REPORTS_DATA);

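Next, we use Snowflake's native Cortex AI functions ([PARSE_DOCUMENT](https://docs.snowflake.com/en/sql-reference/functions/parse_document-snowflake-cortex) and [SPLIT_TEXT_RECURSIVE_CHARACTER](https://docs.snowflake.com/en/sql-reference/functions/split_text_recursive_character-snowflake-cortex)) to parse the PDFs and split the text into chunks. `PARSE_DOCUMENT` in `layout` mode extracts each PDF's content with its layout preserved, and `SPLIT_TEXT_RECURSIVE_CHARACTER` splits that content into chunks of up to 4,000 characters with a 400-character overlap, so context that spans a chunk boundary is not lost. The second statement flattens the resulting arrays into one row per chunk.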

In [None]:
CREATE OR REPLACE TABLE DOCS_CHUNKS_TABLE AS
    SELECT
        RELATIVE_PATH, 
        SNOWFLAKE.CORTEX.SPLIT_TEXT_RECURSIVE_CHARACTER(
            to_variant(SNOWFLAKE.CORTEX.PARSE_DOCUMENT(
                @PUBLIC.REPORTS_DATA,
                RELATIVE_PATH,
                {'mode': 'layout'}
            )):content, 'markdown', 4000, 400) as chunks
from DIRECTORY(@PUBLIC.REPORTS_DATA);

CREATE OR REPLACE TABLE DOCS_CHUNKS_TABLE AS
SELECT RELATIVE_PATH, c.value::string as CHUNK 
FROM DOCS_CHUNKS_TABLE f, 
    LATERAL FLATTEN(INPUT => f.chunks) c;
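To make the `4000` and `400` arguments above concrete, here is a simplified fixed-window splitter in Python. It is a sketch under a simplifying assumption, not the Cortex implementation: `SPLIT_TEXT_RECURSIVE_CHARACTER` first tries to break on separators (such as markdown headings and paragraphs), whereas this version cuts purely by character count. It only shows how consecutive chunks share `overlap` characters of context.

```python
from typing import List


def split_with_overlap(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Split text into windows of at most chunk_size characters.

    Each window starts (chunk_size - overlap) characters after the previous one,
    so neighboring chunks share `overlap` characters.
    """
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be >= 0 and smaller than chunk_size")
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reaches the end of the text
    return chunks


chunks = split_with_overlap("abcdefghij", chunk_size=4, overlap=1)
print(chunks)  # ['abcd', 'defg', 'ghij']
```

With the notebook's settings, a 9,000-character document would yield three chunks of 4,000, 4,000, and 1,800 characters under this simplified scheme; the real function may cut at different points because it prefers natural separators.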

In [None]:
SELECT * FROM DOCS_CHUNKS_TABLE;

Let's see the number of chunks created per doc.

In [None]:
SELECT RELATIVE_PATH, COUNT(*) AS NUM_CHUNKS
FROM DOCS_CHUNKS_TABLE
GROUP BY RELATIVE_PATH
ORDER BY NUM_CHUNKS DESC;
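If it helps to see the flatten-and-count steps above outside SQL, the `LATERAL FLATTEN` plus `GROUP BY` pattern maps onto pandas `explode` followed by a group count. The sample data below is made up for illustration.

```python
import pandas as pd

# One row per document with an array of chunks (like the first DOCS_CHUNKS_TABLE).
docs = pd.DataFrame({
    "RELATIVE_PATH": ["report_a.pdf", "report_b.pdf"],
    "CHUNKS": [["chunk 1", "chunk 2", "chunk 3"], ["chunk 1"]],
})

# explode() plays the role of LATERAL FLATTEN: one output row per array element.
flat = docs.explode("CHUNKS").rename(columns={"CHUNKS": "CHUNK"})

# Equivalent of COUNT(*) ... GROUP BY RELATIVE_PATH ORDER BY NUM_CHUNKS DESC.
counts = (
    flat.groupby("RELATIVE_PATH")
    .size()
    .reset_index(name="NUM_CHUNKS")
    .sort_values("NUM_CHUNKS", ascending=False)
)
print(counts)
```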

Next, we reassemble each PDF's chunks into a single `DOC_TEXT` value and add empty `REPORT` and `SPECIALTY` columns, which we populate in the next step.

In [None]:
CREATE OR REPLACE TRANSIENT TABLE DOCS_AND_TEXT AS
SELECT
    RELATIVE_PATH,
    LISTAGG(CHUNK, ' ') AS DOC_TEXT,
    NULL AS REPORT,
    NULL AS SPECIALTY
FROM DOCS_CHUNKS_TABLE
WHERE RELATIVE_PATH LIKE '%.pdf'
GROUP BY ALL;

select * from docs_and_text limit 5;

In [None]:
SELECT * FROM DOCS_AND_TEXT;

Now we can extract each document's type (`REPORT`) and medical specialty (`SPECIALTY`) with Snowflake's native LLM function, [COMPLETE](https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex), using the smaller `llama3.1-8b` model.

In [None]:
UPDATE DOCS_AND_TEXT AS L
SET L.REPORT = R.REPORT,
    L.SPECIALTY = R.SPECIALTY
FROM (
    SELECT
        RELATIVE_PATH,
        SNOWFLAKE.CORTEX.COMPLETE('llama3.1-8b', CONCAT(DOC_TEXT, ' In less than 5 words, how would you best describe the type of the document content? Do not provide explanation. Remove special characters.')) AS REPORT,
        SNOWFLAKE.CORTEX.COMPLETE('llama3.1-8b', CONCAT(DOC_TEXT, ' What is the medical specialty? Do not provide explanation. Remove special characters.')) AS SPECIALTY
    FROM DOCS_AND_TEXT
) AS R
WHERE L.RELATIVE_PATH = R.RELATIVE_PATH;
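Each `COMPLETE` call above appends an instruction to the full document text with `CONCAT`. A minimal Python sketch of that pattern follows; `build_prompt` and `clean_label` are hypothetical helpers, not part of Cortex. The sketch puts an explicit separator between document and instruction so the two do not run together, and it strips special characters from the response, mirroring the "Remove special characters" instruction in case the model ignores it.

```python
import re


def build_prompt(doc_text: str, instruction: str) -> str:
    """Append an instruction to the document with a clear blank-line separator."""
    return f"{doc_text.rstrip()}\n\n{instruction.strip()}"


def clean_label(response: str) -> str:
    """Keep only letters, digits and spaces, then collapse repeated whitespace."""
    alnum_only = re.sub(r"[^A-Za-z0-9 ]+", " ", response)
    return " ".join(alnum_only.split())


prompt = build_prompt("Chief complaint: chest pain.", "What is the medical specialty?")
print(clean_label(" **Cardiology**\n"))  # Cardiology
```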

In [None]:
SELECT * FROM DOCS_AND_TEXT LIMIT 5;

# Generating training labels with `llama3.1-405b`

Here, the large teacher model `llama3.1-405b` extracts ICD-10 codes from each document through Cortex `COMPLETE`. Its outputs are stored in the `LLAMA_OUTPUT_ICD` table and become the training labels for the smaller model.

In [None]:
CREATE OR REPLACE TABLE LLAMA_OUTPUT_ICD AS
SELECT
    RELATIVE_PATH,
    DOC_TEXT,
    REPORT,
    SPECIALTY,
    SNOWFLAKE.CORTEX.COMPLETE('llama3.1-405b', CONCAT(DOC_TEXT, 'Given this medical transcript, list the unique major ICD10-CM diagnosis code in this format ONLY: X##.#. Please provide the response in the form of a list ONLY: []. DO NOT include any other text.')) AS AI_ICD10_CODE
FROM DOCS_AND_TEXT;
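Because `COMPLETE` returns free text, it is worth checking that the teacher's answers follow the requested list format before they become training labels. The sketch below uses hypothetical helpers and assumes codes follow the ICD-10-CM shape of a letter, a digit, a digit or letter, and an optional dot-suffix (for example `E11.9` or `S72.001A`). It is deliberately looser than the prompt's `X##.#` template so that codes such as `I10` are not dropped.

```python
import re
from typing import List, Optional

# Letter + digit + digit-or-letter, optionally "." and 1-4 characters (e.g. E11.9, I10, S72.001A).
ICD10_CM = re.compile(r"\b[A-Z][0-9][0-9A-Z](?:\.[0-9A-Z]{1,4})?\b")


def parse_icd_codes(output: Optional[str]) -> List[str]:
    """Return the unique ICD-10-CM-looking codes in a model response, in first-seen order."""
    return list(dict.fromkeys(ICD10_CM.findall(output or "")))


def is_well_formed(output: Optional[str]) -> bool:
    """True if the response is a bracketed list containing at least one code."""
    text = (output or "").strip()
    return text.startswith("[") and text.endswith("]") and bool(parse_icd_codes(text))


print(parse_icd_codes("['E11.9', 'I10', 'E11.9']"))  # ['E11.9', 'I10']
print(is_well_formed("The codes are E11.9 and I10."))  # False
```

For example, you could load `LLAMA_OUTPUT_ICD` with `session.table(...).to_pandas()` and drop rows where `is_well_formed` returns `False` before creating the training split.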


In [None]:
SELECT * FROM LLAMA_OUTPUT_ICD;

### Prepare the fine-tuning data

Let's split the data 70/15/15 into training, validation, and test sets, a common pattern in model training workflows. The training and validation sets feed the fine-tuning job; the test set is held out.

In [None]:
CREATE OR REPLACE TEMPORARY TABLE TEMP_SPLIT_TABLE AS
WITH NUMBERED_ROWS AS (
    SELECT
        *,
        ROW_NUMBER() OVER (ORDER BY RANDOM()) AS ROW_NUM,
        COUNT(*) OVER() AS TOTAL_ROWS
    FROM LLAMA_OUTPUT_ICD
)
SELECT
    *,
    CASE
        WHEN ROW_NUM <= TOTAL_ROWS * 0.7 THEN 'train'
        WHEN ROW_NUM <= TOTAL_ROWS * 0.85 THEN 'val'
        ELSE 'test'
    END AS SPLIT
FROM NUMBERED_ROWS;

CREATE OR REPLACE TABLE CODEEXTRACTION_TRAINING AS
SELECT
    RELATIVE_PATH,
    DOC_TEXT,
    REPORT,
    SPECIALTY,
    AI_ICD10_CODE
FROM TEMP_SPLIT_TABLE
WHERE SPLIT = 'train';

CREATE OR REPLACE TABLE CODEEXTRACTION_TEST AS
SELECT
    RELATIVE_PATH,
    DOC_TEXT,
    REPORT,
    SPECIALTY,
    AI_ICD10_CODE
FROM TEMP_SPLIT_TABLE
WHERE SPLIT = 'test';

CREATE OR REPLACE TABLE CODEEXTRACTION_VAL AS
SELECT
    RELATIVE_PATH,
    DOC_TEXT,
    REPORT,
    SPECIALTY,
    AI_ICD10_CODE
FROM TEMP_SPLIT_TABLE
WHERE SPLIT = 'val';
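The boundary logic of a split like this is easy to get wrong, so here is a small Python model of the intended 70/15/15 assignment. It is a sketch, not the notebook's code: rows are numbered from 1 after shuffling (standing in for `ORDER BY RANDOM()`, with a fixed seed for reproducibility), the first 70% go to `train`, the next 15% to `val`, and the rest to `test`. Using `<=` on both cut points keeps a row that lands exactly on a boundary (for example row 7 of 10) in the earlier split; with a strict `<` followed by a strict `>`, such a row matches neither condition and falls through to `test`.

```python
import random
from typing import Dict, List


def assign_split(row_num: int, total_rows: int) -> str:
    """Assign a 1-based row number to train/val/test using 70/15/15 cut points."""
    if row_num <= total_rows * 0.7:
        return "train"
    if row_num <= total_rows * 0.85:
        return "val"
    return "test"


def split_rows(items: List[str], seed: int = 0) -> Dict[str, List[str]]:
    """Shuffle items (like ORDER BY RANDOM()) and group them by split."""
    shuffled = items[:]
    random.Random(seed).shuffle(shuffled)
    total = len(shuffled)
    splits: Dict[str, List[str]] = {"train": [], "val": [], "test": []}
    for row_num, item in enumerate(shuffled, start=1):
        splits[assign_split(row_num, total)].append(item)
    return splits


docs = [f"doc_{i}.pdf" for i in range(20)]
sizes = {name: len(rows) for name, rows in split_rows(docs).items()}
print(sizes)  # {'train': 14, 'val': 3, 'test': 3}
```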

In [None]:
SELECT * FROM CODEEXTRACTION_TRAINING LIMIT 2;

(Optional) To get a baseline, you can also look at the output of the smaller model, `llama3-8b`, before fine-tuning.

In [None]:
CREATE OR REPLACE TABLE LLAMA38B_ICDOUTPUT AS
SELECT
    RELATIVE_PATH,
    DOC_TEXT,
    REPORT,
    SPECIALTY,
    AI_ICD10_CODE,
    SNOWFLAKE.CORTEX.COMPLETE('llama3-8b', CONCAT(DOC_TEXT, 'Given this medical transcript, list the unique major ICD10-CM diagnosis code in this format ONLY: X##.#. Please provide the response in the form of a list ONLY: []. DO NOT include any other text.')) AS LLAMA38B_ICD10_CODE
FROM LLAMA_OUTPUT_ICD;

In [None]:
SELECT * FROM llama38b_ICDOutput;
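To turn the side-by-side baseline table into a number, one option is to treat the teacher's codes as reference labels and score each baseline answer by set overlap. The sketch below is a hypothetical helper, not something Cortex provides. It extracts ICD-10-CM-looking codes with a simple pattern (letter, digit, digit-or-letter, optional dot-suffix) and computes micro-averaged precision, recall, and F1 over all documents.

```python
import re
from typing import Iterable, Optional, Set, Tuple

CODE = re.compile(r"\b[A-Z][0-9][0-9A-Z](?:\.[0-9A-Z]{1,4})?\b")


def codes(text: Optional[str]) -> Set[str]:
    """Extract the set of ICD-10-CM-looking codes from a model response."""
    return set(CODE.findall(text or ""))


def micro_prf(pairs: Iterable[Tuple[str, str]]) -> Tuple[float, float, float]:
    """Micro-averaged precision, recall and F1 over (reference, prediction) pairs."""
    tp = fp = fn = 0
    for reference, prediction in pairs:
        ref, pred = codes(reference), codes(prediction)
        tp += len(ref & pred)   # codes both models agree on
        fp += len(pred - ref)   # extra codes in the prediction
        fn += len(ref - pred)   # reference codes the prediction missed
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


rows = [
    ("['E11.9', 'I10']", "['E11.9']"),        # one reference code missed
    ("['J45.909']", "['J45.909', 'R05.9']"),  # one extra predicted code
]
print(micro_prf(rows))  # each value is 2/3 (about 0.667)
```

In the notebook you might call it as `micro_prf(zip(df["AI_ICD10_CODE"], df["LLAMA38B_ICD10_CODE"]))` after loading `LLAMA38B_ICDOUTPUT` into a DataFrame `df` with `session.table(...).to_pandas()`. The same function can later score the fine-tuned model's `FT_ICD10_CODE` against `AI_ICD10_CODE` once the two tables are joined on `RELATIVE_PATH`.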

### Start the fine-tuning job

[Cortex Fine-tuning](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-finetuning) lets you use parameter-efficient fine-tuning (PEFT) to create customized adapters for pre-trained models on more specialized tasks. If you want to avoid the high cost of training a large model from scratch but need better latency and results than prompt engineering or even retrieval-augmented generation (RAG) provide, fine-tuning an existing model is an option. Fine-tuning uses examples to adjust the model's behavior and improve its performance on domain-specific tasks.
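To see why adapter-based fine-tuning is cheap, consider LoRA (low-rank adaptation), one widely used PEFT technique. This is an illustrative NumPy sketch: the description above does not say which PEFT method Cortex Fine-tuning uses internally, and the layer sizes are assumptions. The pre-trained weight matrix `W` stays frozen and only two small matrices, `A` and `B`, are trained, so the learned update `B @ A` has rank at most `r`.

```python
import numpy as np


def lora_forward(x, W, A, B, alpha: float, r: int):
    """y = x W^T + (alpha / r) * x A^T B^T, with W frozen and A, B trainable."""
    return x @ W.T + (alpha / r) * (x @ A.T @ B.T)


def trainable_params(d_out: int, d_in: int, r: int) -> tuple:
    """Parameter counts for full fine-tuning vs. a rank-r LoRA adapter on one layer."""
    full = d_out * d_in
    lora = r * d_in + d_out * r  # A is (r, d_in), B is (d_out, r)
    return full, lora


rng = np.random.default_rng(0)
d_in, d_out, r = 16, 8, 2
W = rng.normal(size=(d_out, d_in))     # frozen pre-trained weights
A = rng.normal(size=(r, d_in)) * 0.01  # trainable down-projection
B = np.zeros((d_out, r))               # trainable up-projection, zero-initialized
x = rng.normal(size=(4, d_in))         # a batch of 4 inputs

# With B initialized to zero, the adapted layer starts out identical to the base layer.
assert np.allclose(lora_forward(x, W, A, B, alpha=16, r=r), x @ W.T)

print(trainable_params(4096, 4096, 8))  # (16777216, 65536)
```

For a single 4096 x 4096 layer, a rank-8 adapter trains 65,536 parameters instead of about 16.8 million, which is the kind of saving that makes adapter fine-tuning attractive.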

In [None]:
DROP MODEL IF EXISTS FINETUNE_llama38b_ICD10;

In [None]:
SELECT SNOWFLAKE.CORTEX.FINETUNE(
    'CREATE', 
    -- Custom model name, make sure name below is unique
    'FINETUNE_llama38b_ICD10',
    -- Base model name
    'llama3-8b',
    -- Training data query
    'SELECT doc_text || '' Given this medical transcript, list the unique major ICD10-CM diagnosis code in this format ONLY: X##.#. Please provide the response in the form of a list ONLY: []. DO NOT include any other text. '' AS PROMPT, AI_ICD10_Code AS COMPLETION FROM codeextraction_training',
    -- Validation data query
    'SELECT doc_text || '' Given this medical transcript, list the unique major ICD10-CM diagnosis code in this format ONLY: X##.#. Please provide the response in the form of a list ONLY: []. DO NOT include any other text. '' AS PROMPT, AI_ICD10_Code AS COMPLETION FROM codeextraction_val'
);
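The training and validation queries are passed to `FINETUNE` as string literals, which is why the instruction text inside them is wrapped in doubled single quotes (`''`). If you prefer to assemble this call from Python, for example to keep a single copy of the prompt, a small helper can handle the escaping. `sql_literal` and `finetune_create_sql` are hypothetical names for this sketch; only the argument order ('CREATE', model name, base model, training query, validation query) is taken from the cell above.

```python
def sql_literal(text: str) -> str:
    """Wrap text in single quotes, doubling embedded single quotes (standard SQL escaping)."""
    return "'" + text.replace("'", "''") + "'"


def finetune_create_sql(model_name: str, base_model: str,
                        train_query: str, val_query: str) -> str:
    """Build a SNOWFLAKE.CORTEX.FINETUNE('CREATE', ...) statement like the cell above."""
    args = ("CREATE", model_name, base_model, train_query, val_query)
    return "SELECT SNOWFLAKE.CORTEX.FINETUNE(" + ", ".join(sql_literal(a) for a in args) + ")"


# Shortened instruction for readability; in practice use the full ICD-10 prompt.
instruction = " List the ICD10-CM codes. "
train_query = (
    f"SELECT doc_text || {sql_literal(instruction)} AS PROMPT, "
    "AI_ICD10_CODE AS COMPLETION FROM codeextraction_training"
)
val_query = train_query.replace("codeextraction_training", "codeextraction_val")

sql = finetune_create_sql("FINETUNE_llama38b_ICD10", "llama3-8b", train_query, val_query)
print(sql)
```

In a Snowflake notebook you could then submit the job with `session.sql(sql).collect()`. Keeping the instruction in one Python variable also makes it easier to send exactly the same text at inference time.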

# STOP

Replace `<>` in the query below with the job ID returned by the previous cell.

In [None]:
-- Check the status of the fine-tuning job.
-- Replace <> with the job ID returned by the previous cell.
SELECT SNOWFLAKE.CORTEX.FINETUNE(
    'DESCRIBE',
    '<>'
);

# STOP - PROCEED ONLY WHEN THE STATUS FIELD FOR THE JOB CHANGES TO `SUCCESS` IN THE PREVIOUS CELL
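Rather than re-running the `DESCRIBE` cell by hand, you could poll it. The sketch below assumes the `DESCRIBE` result is a JSON string with a `status` field and treats `SUCCESS`, `ERROR`, and `CANCELLED` as terminal states; check the Cortex Fine-tuning documentation for the exact field names and status values. `fetch_description` is a hypothetical callable, for example one that runs the `DESCRIBE` query through `session.sql(...).collect()` and returns the first column of the first row.

```python
import json
import time
from typing import Callable

TERMINAL_STATUSES = {"SUCCESS", "ERROR", "CANCELLED"}  # assumed terminal states


def wait_for_job(fetch_description: Callable[[], str], poll_seconds: float = 60.0,
                 max_polls: int = 240) -> dict:
    """Poll a fine-tuning job until it reaches a terminal status, then return its details."""
    for _ in range(max_polls):
        details = json.loads(fetch_description())
        if details.get("status") in TERMINAL_STATUSES:
            return details
        time.sleep(poll_seconds)
    raise TimeoutError("fine-tuning job did not finish within the polling budget")


# Usage with a fake job that finishes on the third poll.
responses = iter(['{"status": "PENDING"}', '{"status": "IN_PROGRESS"}', '{"status": "SUCCESS"}'])
result = wait_for_job(lambda: next(responses), poll_seconds=0)
print(result["status"])  # SUCCESS
```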

# Inference with fine-tuned model

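Now we can run the fine-tuned model over the training documents to extract ICD-10 codes. To estimate how well it generalizes to unseen documents, you can run the same query against `CODEEXTRACTION_TEST` instead.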

In [None]:
CREATE OR REPLACE TABLE LLAMA38B_ICD_CODES AS
SELECT
    RELATIVE_PATH,
    DOC_TEXT,
    REPORT,
    SPECIALTY,
    SNOWFLAKE.CORTEX.COMPLETE('FINETUNE_llama38b_ICD10', CONCAT(DOC_TEXT, ' Given this medical transcript, list the unique major ICD10-CM diagnosis code in this format ONLY: X##.#. Please provide the response in the form of a list ONLY: []. DO NOT include any other text. ')) AS FT_ICD10_CODE
FROM CODEEXTRACTION_TRAINING;

In [None]:
llama38b_ICD_Code_FT_df = session.table('llama38b_ICD_Codes').to_pandas()
llama38b_ICD_Code_FT_df

# Takeaways from using LLM fine-tuning
1. **HIGHER ACCURACY** from the larger model's labels, which are generated only once for training
2. **LOWER COST** from serving a smaller model in production
3. **HIGHER THROUGHPUT** from serving a smaller model in production

# **Bonus:**
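You can use Streamlit to build a simple interactive app that displays the extracted ICD-10 codes alongside patient risk scores. Note that the app below uses synthetic patients and randomly generated placeholder risk scores; it illustrates the workflow rather than computing a clinical risk score.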

In [None]:
# Import python packages
import streamlit as st
from snowflake.snowpark.context import get_active_session
import pandas as pd
import numpy as np
import re
import random

st.title('❄️ Medical Coding Assistant ❄️')
st.subheader(
    "Explore patient risk scores alongside ICD-10 codes extracted by a Llama 3 model fine-tuned with Cortex AI"
)

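# Get the active Snowpark session and load the fine-tuned model's output
session = get_active_session()
reports_df = session.table("llama38b_ICD_Codes").to_pandas()

# Pattern for ICD-10-CM-style codes, e.g. E11.9, I10, S72.001A
ICD10_PATTERN = re.compile(r"\b[A-Z][0-9][0-9A-Z](?:\.[0-9A-Z]{1,4})?\b")

def load_icd_data(df):
    """Parse the fine-tuned model's output into a {code: description} dict.

    Lines shaped like "[E11.9] - Type 2 diabetes" keep their description;
    bare lists such as "['E11.9', 'I10']" get a placeholder description.
    """
    cleaned = {}
    for output in df['FT_ICD10_CODE'].dropna():
        for line in str(output).splitlines():
            match = re.match(r"^\[(.*?)\]\s*-?\s*(.*)$", line.strip())
            if match and match.group(2):
                cleaned[match.group(1)] = match.group(2)
            else:
                for code in ICD10_PATTERN.findall(line):
                    cleaned.setdefault(code, 'Description not available')
    return cleaned

cleaned_dict = load_icd_data(reports_df)

# Stop early if nothing could be parsed; random.choice() below needs at least one code
if not cleaned_dict:
    st.warning("No ICD-10 codes could be parsed from the fine-tuned model output.")
    st.stop()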

def calculate_risk_score(icd_code):
    """Calculates a random risk score for a given ICD code."""
    return np.random.rand(1)[0] * 10

@st.cache_data
def create_patient_dataframe(cleaned_dict):
    patient_data = {
        'patient_id': [1, 2, 3, 4, 5],
        'name': ['Ava Lee', 'Jane Smith', 'Alice Johnson', 'Ella Rose', 'Riley Green'],
        'age': [45, 62, 30, 50, 40]
    }
    patient_df = pd.DataFrame(patient_data)
    patient_df['icd_code'] = [random.choice(list(cleaned_dict.keys())) for _ in range(len(patient_df))]
    patient_df['risk_score'] = patient_df['icd_code'].apply(calculate_risk_score)
    return patient_df

patient_df = create_patient_dataframe(cleaned_dict)

def display_icd_code_with_explanation(icd_code):
    """Displays the ICD code with its description."""
    explanation = cleaned_dict.get(icd_code, 'Explanation not available')
    return f"ICD Code: {icd_code} - {explanation}"

# Select a patient by name
patient_name = st.selectbox("Select Patient", patient_df['name'])

# Get the selected patient’s data
selected_patient = patient_df[patient_df['name'] == patient_name].iloc[0]
patient_icd_code = selected_patient['icd_code']
patient_risk_score = selected_patient['risk_score']

# Display ICD code and risk score
st.write("---")
st.subheader(display_icd_code_with_explanation(patient_icd_code))
st.write(f"**Risk Score:** {patient_risk_score:.2f}")
st.write("---")

# Associated Medical Reports section
st.subheader("Associated Medical Reports")

# Filter reports based on the patient's ICD code (literal match, since "." is a regex wildcard)
filtered_reports = reports_df[reports_df['FT_ICD10_CODE'].str.contains(patient_icd_code, case=False, na=False, regex=False)]

if not filtered_reports.empty:
    with st.expander("View Associated Reports"):
        for _, report in filtered_reports.iterrows():
            st.write(f"**Report Name:** {report['RELATIVE_PATH']}")
            st.write(f"**Report Description:** {report['REPORT']}")
            st.write(f"**Specialty:** {report['SPECIALTY']}")
            st.write(f"**Extracted Text:** {report['DOC_TEXT']}")
            st.write("---")
else:
    with st.expander("No Reports Found"):
        st.write("No associated reports found for this ICD code.")