In [1]:
import os # new code
import sys # new code
import math
import json
import time
import psutil
import torch
import pickle    
import signal # new code
import random
import logging
import numpy as np
import transformers
import pandas as pd
from tqdm import tqdm
from torch import optim
from collections import Counter
import matplotlib.pyplot as plt

from datasets import Dataset, load_dataset
from datetime import datetime
from types import SimpleNamespace
from transformers import AutoConfig, AutoTokenizer, DataCollatorWithPadding
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import EarlyStoppingCallback, TrainerCallback
from transformers import get_scheduler
from sentence_transformers.readers import InputExample
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.datasets import SentenceLabelDataset
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction
from sentence_transformers.losses import TripletLoss, MultipleNegativesRankingLoss, GISTEmbedLoss
from sentence_transformers import SentenceTransformer, LoggingHandler, losses, util, SentenceTransformerTrainingArguments, SentenceTransformerTrainer


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
file_path = '../../../datasets/dataset_training'

# All three splits share one dataset stem and build date; only the split name differs.
# Change the dataset type here.
_DATASET_STEM = 'triplet_dataset_v50_250_queries_10positives_50hn'
_DATASET_DATE = '08112025'

TRAIN_DATASET_PATH = f'{file_path}/{_DATASET_STEM}_train_{_DATASET_DATE}.csv'
EVAL_DATASET_PATH = f'{file_path}/{_DATASET_STEM}_eval_{_DATASET_DATE}.csv'
TEST_DATASET_PATH = f'{file_path}/{_DATASET_STEM}_test_{_DATASET_DATE}.csv'
    

def build_huggingface_dataset(train_path=None, eval_path=None, test_path=None,
                              tokenizer_path='../../../../shekhar_tanwar/ICD-ICD-Triplet/model/e5-large-v2-20250331143312-finetuned-icd-v30/',
                              batch_size=8192, num_proc=None):
    """Load the train/eval/test triplet CSVs and tag each row with its token length.

    Each split keeps only the ``anchor``, ``positives`` and ``negatives`` columns and
    gains a ``triplet_length`` column. That column holds the largest token count
    among the row's anchor, positive and negative. Counts exclude special tokens
    and are taken without truncation.

    Args:
        train_path, eval_path, test_path: CSV paths. ``None`` falls back to the
            module-level ``TRAIN_DATASET_PATH`` / ``EVAL_DATASET_PATH`` /
            ``TEST_DATASET_PATH``.
        tokenizer_path: Directory passed to ``AutoTokenizer.from_pretrained``.
        batch_size: Rows per ``Dataset.map`` batch. Larger values amortise
            tokenizer overhead; tune to the available RAM.
        num_proc: Worker processes for ``Dataset.map``. ``None`` uses all CPUs
            but two, with a minimum of 2.

    Returns:
        ``(train_dataset, eval_dataset, test_dataset)``.
    """
    data_files = {
        'train': train_path or TRAIN_DATASET_PATH,
        'eval': eval_path or EVAL_DATASET_PATH,
        'test': test_path or TEST_DATASET_PATH,
    }
    # Non-streaming load. A single call builds every split; loading each split
    # separately re-resolved the same builder three times.
    splits = load_dataset('csv', data_files=data_files)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)

    if num_proc is None:
        # os.cpu_count() returns None when the CPU count cannot be determined.
        num_proc = max(2, (os.cpu_count() or 4) - 2)

    columns = ['anchor', 'positives', 'negatives']

    def add_triplet_length(batch):
        """Return ``{'triplet_length': [...]}``: per-row max token count across the three roles."""
        n = len(batch['anchor'])
        # Flatten role by role (all anchors, then all positives, then all
        # negatives) so the tokenizer runs once per batch. Missing values
        # tokenize as the empty string.
        texts = [
            '' if value is None else str(value)
            for column in columns
            for value in batch[column]
        ]
        out = tokenizer(
            texts,
            add_special_tokens=False,
            return_length=True,
            padding=False,
            truncation=False,
        )
        # Row r of the (3, n) matrix holds the lengths for role r.
        lengths = np.asarray(out['length'], dtype=np.int32).reshape(3, n)
        return {'triplet_length': lengths.max(axis=0).tolist()}

    prepared = []
    for split_name, label in (('train', 'Train'), ('eval', 'Eval'), ('test', 'Test')):
        split = splits[split_name].select_columns(columns)
        prepared.append(split.map(
            add_triplet_length,
            batched=True,
            batch_size=batch_size,
            num_proc=num_proc,
            desc=f'Computing {label} triplet lengths',
        ))

    return tuple(prepared)
    

In [3]:
# Load the three CSV splits and add the per-row triplet_length column.
train_dataset, eval_dataset, test_dataset = build_huggingface_dataset()

Generating train split: 51118300 examples [01:20, 635570.51 examples/s]
Generating eval split: 231500 examples [00:01, 154148.57 examples/s]
Generating test split: 417100 examples [00:00, 656751.28 examples/s]
Computing Train triplet lengths (num_proc=94): 100%|██████████| 51118300/51118300 [01:31<00:00, 558957.82 examples/s] 
Computing Eval triplet lengths (num_proc=94): 100%|██████████| 231500/231500 [00:01<00:00, 183011.28 examples/s]
Computing Test triplet lengths (num_proc=94): 100%|██████████| 417100/417100 [00:01<00:00, 300773.23 examples/s]


In [51]:
from datasets import concatenate_datasets

def get_batches_by_length(data, boundaries=(64, 128), seed=42, num_proc=None):
    """Reorder ``data`` into length buckets, shuffling within each bucket.

    Rows are grouped by ``triplet_length`` into buckets delimited by
    ``boundaries``. Bucket i holds ``boundaries[i-1] < length <= boundaries[i]``;
    the first bucket is unbounded below and the last is unbounded above. Each
    bucket is shuffled with ``seed``, and the buckets are concatenated
    shortest-first, so consecutive batches contain similarly sized triplets.

    Args:
        data: A ``datasets.Dataset`` with an integer ``triplet_length`` column.
        boundaries: Inclusive upper edges of every bucket except the last.
        seed: Shuffle seed, applied independently to each bucket.
        num_proc: Optional worker count forwarded to ``Dataset.filter``.

    Returns:
        A ``datasets.Dataset`` with the same rows as ``data``, reordered.
    """
    edges = [float('-inf'), *sorted(boundaries), float('inf')]
    buckets = []
    for low, high in zip(edges[:-1], edges[1:]):
        # Still one pass per bucket. A batched filter that reads only the
        # length column avoids decoding every text column row by row.
        bucket = data.filter(
            lambda lengths, low=low, high=high: [low < length <= high for length in lengths],
            batched=True,
            input_columns='triplet_length',
            num_proc=num_proc,
        )
        buckets.append(bucket.shuffle(seed=seed))

    return concatenate_datasets(buckets)

# Re-bucket every split by triplet length (short to long, shuffled within each bucket).
train_dataset, eval_dataset, test_dataset = (
    get_batches_by_length(split) for split in (train_dataset, eval_dataset, test_dataset)
)

    

Filter: 100%|██████████| 51118300/51118300 [02:20<00:00, 363816.51 examples/s]
Filter: 100%|██████████| 51118300/51118300 [02:07<00:00, 401391.29 examples/s]
Filter: 100%|██████████| 51118300/51118300 [02:05<00:00, 406623.60 examples/s]


In [6]:
output_dir = '../../../datasets/dataset_training/triplet_v50_splits'

# Persist each bucketed split as an on-disk dataset under output_dir/<split>.
for split_name, split_dataset in (('train', train_dataset), ('eval', eval_dataset), ('test', test_dataset)):
    split_dataset.save_to_disk(f'{output_dir}/{split_name}')

Saving the dataset (14/14 shards): 100%|██████████| 51118300/51118300 [01:16<00:00, 667699.83 examples/s] 
Saving the dataset (1/1 shards): 100%|██████████| 231500/231500 [00:00<00:00, 534506.25 examples/s] 
Saving the dataset (1/1 shards): 100%|██████████| 417100/417100 [00:00<00:00, 560993.86 examples/s] 


In [40]:
import pandas as pd

file_path = '../../../datasets/dataset_training'

TRAIN_DATASET_PATH = f'{file_path}/triplet_dataset_v50_250_queries_10positives_50hn_train_08112025.csv' # change the dataset type here
EVAL_DATASET_PATH = f'{file_path}/triplet_dataset_v50_250_queries_10positives_50hn_eval_08112025.csv'
TEST_DATASET_PATH = f'{file_path}/triplet_dataset_v50_250_queries_10positives_50hn_test_08112025.csv'

split_paths = {'train': TRAIN_DATASET_PATH, 'eval': EVAL_DATASET_PATH, 'test': TEST_DATASET_PATH}

# .iloc[:, 1:] drops the leading index column written by the to_csv calls below.
split_frames = {name: pd.read_csv(path).iloc[:, 1:] for name, path in split_paths.items()}
for name, frame in split_frames.items():
    print(f'{name} data shape : {frame.shape}')

rows_before = {name: len(frame) for name, frame in split_frames.items()}
for name in split_frames:
    # Replace one split at a time, so at most one raw frame coexists with its cleaned copy.
    split_frames[name] = split_frames[name].dropna()
for name, frame in split_frames.items():
    print(f'{name} data shape : {frame.shape}')

# Overwrite a source CSV only when dropna actually removed rows. Rewriting an
# already-clean 51M-row file is slow and needlessly risks the source data.
# The index is still written, so readers' .iloc[:, 1:] stays valid.
for name, path in split_paths.items():
    if len(split_frames[name]) != rows_before[name]:
        split_frames[name].to_csv(path)

train_data, eval_data, test_data = split_frames['train'], split_frames['eval'], split_frames['test']


train data shape : (51119000, 4)
eval data shape : (231500, 4)
test data shape : (417100, 4)
train data shape : (51118300, 4)
eval data shape : (231500, 4)
test data shape : (417100, 4)


In [3]:
#train_dataset, eval_dataset, test_dataset = build_huggingface_dataset()

In [4]:
file_path = '../../../datasets/dataset_training'

TRAIN_DATASET_PATH = f'{file_path}/triplet_dataset_v50_250_queries_10positives_50hn_train_08112025.csv' # change the dataset type here
EVAL_DATASET_PATH = f'{file_path}/triplet_dataset_v50_250_queries_10positives_50hn_eval_08112025.csv'
TEST_DATASET_PATH = f'{file_path}/triplet_dataset_v50_250_queries_10positives_50hn_test_08112025.csv'

# The CSVs were saved together with their index; .iloc[:, 1:] discards that column.
train_data, eval_data, test_data = (
    pd.read_csv(split_path).iloc[:, 1:]
    for split_path in (TRAIN_DATASET_PATH, EVAL_DATASET_PATH, TEST_DATASET_PATH)
)

print(f'train data shape : {train_data.shape}')
print(f'eval data shape : {eval_data.shape}')
print(f'test data shape : {test_data.shape}')


train data shape : (51118300, 4)
eval data shape : (231500, 4)
test data shape : (417100, 4)


In [37]:
# Peek at the first rows of the reloaded training split.
train_data.head()

Unnamed: 0,specialty,anchor,positives,negatives
0,acupuncturist_acupuncturist,acupuncture for swelling,Pain in right lower limb NOS,"Abscess of bursa, wrist"
1,acupuncturist_acupuncturist,acupuncture for swelling,Pain in right lower limb NOS,"Effusion, unspecified hand"
2,acupuncturist_acupuncturist,acupuncture for swelling,Pain in right lower limb NOS,Whitlow
3,acupuncturist_acupuncturist,acupuncture for swelling,Pain in right lower limb NOS,Other acne
4,acupuncturist_acupuncturist,acupuncture for swelling,Pain in right lower limb NOS,"Localized swelling, mass and lump, neck"


In [38]:
# Drop every row that has any missing value (re-running this is a no-op).
train_data = train_data.dropna(axis=0, how='any')

In [39]:
# Row/column count after dropna; compare with the shape printed when the CSV was read.
train_data.shape

(51118300, 4)

In [None]:
# Persist the NaN-free training split back to its source CSV, with the index
# written (readers drop it again via .iloc[:, 1:]). The previous call passed an
# empty path, which pandas cannot open.
train_data.to_csv(TRAIN_DATASET_PATH)

In [1]:
import os
import json
import time
import openai
import pickle
import warnings
import requests
import pandas as pd
from tqdm import tqdm
from typing import List
from collections import defaultdict
from azureml.core import Workspace
from azure.identity import DefaultAzureCredential
from azureml.core.authentication import ServicePrincipalAuthentication

from langchain.chains import LLMChain
from langchain_openai import AzureChatOpenAI
from langchain_core.prompts import  PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.prompts.few_shot import FewShotPromptTemplate

from langchain.output_parsers import PydanticOutputParser, JsonOutputKeyToolsParser, CommaSeparatedListOutputParser
pd.options.display.max_rows = None   # show every row when displaying DataFrames
warnings.simplefilter("ignore")      # silence all warnings for the rest of the session
ws = Workspace.from_config()         # workspace handle; initialize_llm() builds its own


For example, replace imports like: `from langchain_core.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  exec(code_obj, self.user_global_ns, self.user_ns)


# COMBINING A LIST OF DICTIONARIES INTO A SINGLE DICTIONARY

In [2]:
def load_splits_chunks(file_path : str, output_file_path : str):
    """Merge per-specialty JSON split files into one dictionary and save it.

    Each ``*.json`` file in ``file_path`` is read as a mapping
    ``{specialty: {query: [icd_code, ...]}}``. Every specialty in every file is
    merged; queries with an empty code list are dropped. If a specialty appears
    in several files, the file that sorts last wins.

    Args:
        file_path: Directory holding the split JSON files. A trailing slash is optional.
        output_file_path: Where the merged, filtered mapping is written (JSON, indent 4).

    Returns:
        ``(all_split_files_dict, specialty_query_codes_dict)``: the raw content of
        each file (in sorted filename order) and the merged, filtered mapping.
    """
    # Sorted so the merged key order is reproducible; os.listdir order is arbitrary.
    json_files = sorted(name for name in os.listdir(file_path) if name.endswith('.json'))

    all_split_files_dict = []
    for name in json_files:
        with open(os.path.join(file_path, name), 'r') as handle:
            all_split_files_dict.append(json.load(handle))

    # Filter out queries with no ICD codes. Every specialty in a file is
    # merged; reading only the first key silently lost the others.
    specialty_query_codes_dict = {}
    for data in all_split_files_dict:
        for specialty, query_codes_dict in data.items():
            specialty_query_codes_dict[specialty] = {
                query: codes
                for query, codes in query_codes_dict.items()
                if len(codes) > 0
            }

    with open(output_file_path, 'w') as handle:
        json.dump(specialty_query_codes_dict, handle, indent = 4)

    return all_split_files_dict, specialty_query_codes_dict

In [3]:
file_path = '../../../datasets/datasets_augmented/final_dataset_v40/icd_filtered/filtered_icd_codes/'
output_file_path = '../../../datasets/datasets_augmented/final_dataset_v40/specialty_verification/filtered_specialty_query_dict.json'

# Merge the filtered per-specialty split files into one dict (also written to output_file_path).
all_split_files_dict, specialty_query_codes_dict = load_splits_chunks(file_path, output_file_path)

In [4]:
# Specialty names, in the merged dictionary's key order.
all_specialties = [*specialty_query_codes_dict]

# CREATING SPLITS FOR FILTERED DICTIONARY

In [5]:

def split_input_data(input_file : str, output_dir : str, num_chunks : int):
    """Split a ``{specialty: ...}`` JSON file into ``num_chunks`` JSON chunk files.

    Specialties are assigned to chunks in file order, ``ceil(len / num_chunks)``
    per chunk; trailing chunks may be empty. Chunk ``i`` is written to
    ``<output_dir>/splits/<stem>_split_<i>.json``, where ``<stem>`` is the input
    file name up to its first ``.``. If every expected chunk file already exists
    and none is older than ``input_file``, the chunks are loaded from disk
    instead of being regenerated.

    Args:
        input_file: Path to the source JSON object.
        output_dir: Directory that receives the ``splits`` subdirectory. A trailing slash is optional.
        num_chunks: Number of chunks to produce (at least 1).

    Returns:
        List of ``num_chunks`` dicts, chunk 0 first.

    Raises:
        ValueError: if ``num_chunks`` is less than 1.
    """
    if num_chunks < 1:
        raise ValueError(f'num_chunks must be >= 1, got {num_chunks}')

    splits_dir = os.path.join(output_dir, 'splits')
    stem = os.path.basename(input_file).split('.')[0]
    chunk_files = [os.path.join(splits_dir, f'{stem}_split_{i}.json') for i in range(num_chunks)]

    # Reuse cached chunks only when all of them exist and are at least as new as
    # the input. The previous check iterated over the characters of the
    # directory *string*, so it never found any files.
    input_mtime = os.path.getmtime(input_file) if os.path.exists(input_file) else float('-inf')
    if all(os.path.exists(path) and os.path.getmtime(path) >= input_mtime for path in chunk_files):
        all_split_files = []
        for chunk_file in chunk_files:
            with open(chunk_file, 'r') as f:
                all_split_files.append(json.load(f))
        return all_split_files

    os.makedirs(splits_dir, exist_ok=True)
    with open(input_file, 'r') as f:
        specialty_data = json.load(f)

    specialties = list(specialty_data.keys())
    # Ceiling division, so every specialty lands in some chunk.
    chunk_size = len(specialties) // num_chunks + (1 if len(specialties) % num_chunks else 0)

    all_split_files = []
    for i, chunk_file in enumerate(chunk_files):
        chunk_specialties = specialties[i * chunk_size:(i + 1) * chunk_size]
        chunk_data = {specialty: specialty_data[specialty] for specialty in chunk_specialties}
        with open(chunk_file, 'w') as f:
            json.dump(chunk_data, f, indent=4)
        all_split_files.append(chunk_data)

    return all_split_files

In [6]:
input_file_filtered_specialty_query_code_dict = '../../../datasets/datasets_augmented/final_dataset_v40/specialty_verification/filtered_specialty_query_dict.json'
output_dir_filtered_specialty_query_dict = '../../../datasets/datasets_augmented/final_dataset_v40/specialty_verification/filtered_specialty_query_splits/'
num_chunks = 4

# Split the merged specialty -> query -> codes dict into num_chunks JSON files.
print('Creating Split for specialty_query_dict')
all_split_files_specialty_query_code_dict = split_input_data(
    input_file_filtered_specialty_query_code_dict,
    output_dir_filtered_specialty_query_dict,
    num_chunks,
)

Creating Split for specialty_query_dict


# CREATING SPECIALTY QUERY CODE DESCRIPTION DICT

In [7]:
def get_icd_dataset(icd_reference_file : str):
    """Load the ICD-10 reference CSV and build a code -> description lookup.

    The first CSV column is dropped and exact duplicate rows are removed. Then
    columns 13 and 14 of the remainder are taken as code and description. Both
    are whitespace-stripped, and only the first row for each code is kept.

    Args:
        icd_reference_file: Path to the ICD-10 reference CSV.

    Returns:
        ``(dataset_icd, icd_reference_lookup)``: a DataFrame with columns
        ``ICD_Codes`` and ``Description``, and a dict mapping code to description.
    """
    # The previous body overwrote this argument with a hard-coded path, so the
    # parameter was silently ignored.
    # NOTE(review): columns are selected by position, which assumes code and
    # description sit at positions 13 and 14 after the leading column is
    # dropped. Confirm against the icd10.csv header.
    dataset_icd = (
        pd.read_csv(icd_reference_file)
        .iloc[:, 1:]
        .drop_duplicates()
        .iloc[:, 13:15]
        .set_axis(['ICD_Codes', 'Description'], axis=1)
    )
    # assign() builds a new frame, avoiding chained assignment on a slice.
    dataset_icd = dataset_icd.assign(
        ICD_Codes=dataset_icd['ICD_Codes'].str.strip(),
        Description=dataset_icd['Description'].str.strip(),
    ).drop_duplicates(subset=['ICD_Codes'], keep='first')

    icd_reference_lookup = dict(zip(dataset_icd['ICD_Codes'], dataset_icd['Description']))

    return dataset_icd, icd_reference_lookup

In [8]:
icd_reference_file = '../../../../shekhar_tanwar/ICD-ICD-Triplet/dataset/icd10.csv'
# Code -> description lookup, used to annotate each query's ICD codes.
dataset_icd, icd_reference_lookup = get_icd_dataset(icd_reference_file)

In [9]:

def get_query_icd_code_description_dataset(icd_reference_lookup : dict,  specialty_query_codes_dict : dict,
                                           output_file_path : str = '../../../datasets/datasets_augmented/final_dataset_v40/icd_filtered/specialty_query_code_desciption_splits/filtered_specialty_query_code_desciption_dict.json'):
    """Attach ICD-10 descriptions to every query's codes, per specialty.

    For each specialty and query, every code present in ``icd_reference_lookup``
    becomes the string ``"<code> : <description>"``. Codes missing from the
    lookup are dropped.

    Args:
        icd_reference_lookup: Mapping from ICD-10 code to description.
        specialty_query_codes_dict: ``{specialty: {query: [code, ...]}}``.
        output_file_path: Where the result is written as JSON (indent 4). Its
            parent directory is created if needed; ``None`` skips writing.

    Returns:
        ``(specialty_query_code_desciption_dict, problematic_specialty_list)``:
        the annotated mapping, and the specialties whose entries raised while
        being processed (these are left out of the mapping).
    """
    specialty_query_code_desciption_dict = {}
    problematic_specialty_list = []

    for medical_specialty_subspecialty, query_codes_dict in tqdm(
        specialty_query_codes_dict.items(), total=len(specialty_query_codes_dict)
    ):
        try:
            query_code_description_dict = {
                medical_query: [
                    code + " : " + icd_reference_lookup[code]
                    for code in retrieved_codes_gpt41
                    if code in icd_reference_lookup
                ]
                for medical_query, retrieved_codes_gpt41 in query_codes_dict.items()
            }
        except Exception:
            # Best effort: record the malformed specialty and continue. The
            # previous handler appended an undefined name and raised NameError.
            problematic_specialty_list.append(medical_specialty_subspecialty)
            continue
        specialty_query_code_desciption_dict[medical_specialty_subspecialty] = query_code_description_dict

    if output_file_path is not None:
        output_dir = os.path.dirname(output_file_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_file_path, 'w') as file:
            json.dump(specialty_query_code_desciption_dict, file, indent = 4)

    return specialty_query_code_desciption_dict, problematic_specialty_list
        

In [10]:
# Annotate every query's codes with descriptions; the problematic-specialty list is discarded.
specialty_query_code_desciption_dict, _ = get_query_icd_code_description_dataset(icd_reference_lookup, specialty_query_codes_dict)

100%|██████████| 591/591 [00:01<00:00, 511.17it/s]


# SPECIALTY VERIFICATION

In [13]:
class BearerAuth(requests.auth.AuthBase):
    """``requests`` auth hook that sets ``authorization: Bearer <token>`` on each request."""

    def __init__(self, token):
        # Raw access token string (without the "Bearer" prefix).
        self.token = token

    def __call__(self, r):
        bearer_value = " ".join(("Bearer", self.token))
        r.headers["authorization"] = bearer_value
        return r

In [14]:
def initialize_llm(model_name) -> AzureChatOpenAI:
    """Create an ``AzureChatOpenAI`` client for a supported GPT deployment.

    The workspace name and subscription id come from the Azure ML workspace key
    vault. The chat API key must be supplied through the ``AZURE_OPENAI_API_KEY``
    environment variable.

    Args:
        model_name: ``"gpt-4o"`` or ``"gpt-4.1"``; also used as the deployment name.

    Returns:
        An ``AzureChatOpenAI`` configured with max_tokens=4000, temperature=0.9
        and seed 1337.

    Raises:
        ValueError: if ``model_name`` is not a supported deployment.
        RuntimeError: if ``AZURE_OPENAI_API_KEY`` is not set.
    """
    # model_name -> (chat completions API version, management API version)
    api_versions = {
        "gpt-4o": ("2024-05-01-preview", "2023-10-01-preview"),
        "gpt-4.1": ("2024-12-01-preview", "2024-12-01-preview"),
    }
    if model_name not in api_versions:
        # Previously an unknown name fell through and failed with UnboundLocalError on `model`.
        raise ValueError(f"Unsupported model_name {model_name!r}; expected one of {sorted(api_versions)}")
    chat_api_version, management_api_version = api_versions[model_name]

    # SECURITY: a literal API key used to be hard-coded here. Read it from the
    # environment instead. The old key is in notebook history and should be rotated.
    api_key = os.environ.get("AZURE_OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("Set the AZURE_OPENAI_API_KEY environment variable before calling initialize_llm().")

    ws = Workspace.from_config()
    keyvault = ws.get_default_keyvault()
    credential = DefaultAzureCredential()
    workspacename = keyvault.get_secret("project-workspace-name")
    access_token = credential.get_token("https://cognitiveservices.azure.com/.default")
    os.environ["AZURE_OPENAI_KEY"] = access_token.token
    openai.api_type = "azure_ad"
    os.environ["AZURE_OPENAI_ENDPOINT"] = f"https://{workspacename}openai.openai.azure.com/"
    openai.api_version = "2023-07-01-preview"
    subscriptionId = keyvault.get_secret("project-subscription-id")

    # Default endpoint; an API base already set in the environment takes precedence.
    os.environ.setdefault("AZURE_OPENAI_API_BASE", "https://xqrojjmb2wjlqopopenai.openai.azure.com/")
    os.environ["AZURE_OPENAI_API_VERSION"] = chat_api_version
    os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = model_name

    # Lists the account's deployments via the management API.
    # NOTE(review): the response is never inspected, so this call does not affect the returned model.
    url = (
        f"https://management.azure.com/subscriptions/{subscriptionId}/resourceGroups/{workspacename}-common"
        f"/providers/Microsoft.CognitiveServices/accounts/{workspacename}openai/deployments"
        f"?api-version={management_api_version}"
    )
    management_token = credential.get_token("https://management.azure.com/.default")
    requests.get(url, auth=BearerAuth(management_token.token))

    print(f'Initializing Model : {os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"]}')
    model = AzureChatOpenAI(
        deployment_name=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],
        azure_endpoint=os.environ["AZURE_OPENAI_API_BASE"],
        openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
        openai_api_key=api_key,
        max_tokens=4000,
        temperature=0.9,
        model_kwargs={"seed": 1337}
    )
    print(f'Model {model_name} Initialized')

    return model


In [41]:
def get_specialty_verification(model : AzureChatOpenAI, medical_specialty_subspecialty : str, medical_query : str, icd_code_description_list : list):
    """Ask the chat model whether a specialty is relevant to a query and its ICD-10 codes.

    The prompt requests exactly one label: ``relevant``, ``non-relevant`` or
    ``CANNOT_DECIDE``. The reply is parsed with ``CommaSeparatedListOutputParser``,
    so the result is a list of strings, e.g. ``['relevant']``.

    Args:
        model: Chat model, e.g. from ``initialize_llm``.
        medical_specialty_subspecialty: Reference ``specialty_subspecialty`` label to check.
        medical_query: The medical search query.
        icd_code_description_list: ``"<code> : <description>"`` strings for the query;
            interpolated into the prompt as-is.

    Returns:
        The parsed model reply, as a list of strings.
    """
    # The model is told to answer with a single bare label; the comma-separated
    # list parser turns that reply into a one-element list.
    output_parser = CommaSeparatedListOutputParser()

    system_prompt = """ You are a certified medical coder who assigns ICD-10 codes.

            **Goal:**  
            Given (1) a medical search query, (2) a user-supplied list of **ICD-10 code: description** pairs, and (3) a reference medical **specialty_subspecialty**, identify whether the reference **specialty_subspecialty** is **relevant** or **non-relevant** to the medical query and the ICD-10 code-description list. The result should be **non-relevant** if the reference **specialty_subspecialty** is either too generic for the query/ICD-10 list’s intent or clearly unrelated to that query and code list.
            
            **How to decide relevance:**  
            - **Understand the query’s clinical intent:** Determine what condition, symptom, or scenario the query is describing.  
            - **Consider the ICD-10 code list context:** The ICD-10 codes and descriptions are chosen to be consistent with each other and with the query, representing a specific clinical scenario. Use this combined context to inform your decision.  
            - **Match with the specialty_subspecialty:** Evaluate whether a practitioner of the reference **specialty_subspecialty** typically addresses the query’s scenario:  
              - The specialty_subspecialty pair may consist of a broad specialty and a more focused subspecialty. Emphasize the general **specialty** domain. If the scenario falls under that general field (even if not an exact subspecialty match), it can be considered relevant.  
              - If the scenario (query + codes) **does not fall under** the domain of the reference specialty_subspecialty — for example, the specialty_subspecialty is overly broad/vague for this specific case, or it pertains to a different field of medicine — then label it **non-relevant**.  
              - If the scenario **does** fall under the clinical domain of that specialty_subspecialty (i.e. a provider of that type would reasonably handle such cases), then label it **relevant**.  
            - **Multiple possible specialties:** There may be cases where the query and codes could belong to more than one specialty. You are **only** checking the given reference specialty_subspecialty. If the given specialty_subspecialty is one appropriate choice for this scenario, mark it **relevant** (even if other specialties could also be involved).  
            - **If unsure:** If you cannot confidently determine relevance from the information provided, label the result as `CANNOT_DECIDE`.
            
            **Response format (strict):**  
            Return **only** a single label as the answer: `relevant`, `non-relevant`, or `CANNOT_DECIDE` (use `CANNOT_DECIDE` only if you truly cannot decide). Do **not** include explanations, reasoning, or any additional text.
            
            **Inputs (to be inserted at runtime):**  
            medical_query: *{medical_query}*  
            icd_code_description_list: *{icd_code_description_list}*  
            medical_specialty_subspecialty: *{medical_specialty_subspecialty}*  
            
            **Few-shot guidance (examples):**
            
            Example 1:  
            medical_query: **aging and decreased independence and mobility**  
            icd_code_description_list: **['Z74.3 : Need for continuous supervision', 'Z73.89 : Other problems related to life management difficulty', 'Z73.6 : Limitation of activities due to disability', 'Z60.0 : Phase of life problem', 'Z74.2 : Need for assistance at home and no other household member able to render care', 'Z74.1 : Need for assistance with personal care', 'Z74.09 : Other reduced mobility', 'R54 : Senile debility', 'Z91.81 : History of falling']**  
            medical_specialty_subspecialty: **adult companion_adult companion**  
            Expected output: **relevant**
            
            Example 2:  
            medical_query: **acupuncture for headaches**  
            icd_code_description_list: **['R51.9 : Headache, unspecified', 'G44.209 : Tension-type headache, unspecified, not intractable', 'G43.009 : Migraine without aura NOS', 'G44.89 : Other headache syndrome', 'G43.909 : Migraine NOS', 'G43.709 : Chronic migraine without aura NOS']**  
            medical_specialty_subspecialty: **anesthesiology_addiction medicine**  
            Expected output: **non-relevant**
            
            Example 3:  
            medical_query: **specialty biologic and injectable therapies in healthcarer**  
            icd_code_description_list: **['T88.59XA : Other complications of anesthesia, initial encounter', 'T41.1X5A : Adverse effect of intravenous anesthetics, initial encounter', 'T88.7XXA : Unspecified adverse effect of drug or medicament, initial encounter']**  
            medical_specialty_subspecialty: **cliniccenter_student health**  
            Expected output: **non-relevant**
            
            Example 4:  
            medical_query: **itchy scalp after workplace exposure**  
            icd_code_description_list: **['L23.9 : Allergic contact dermatitis, unspecified cause', 'L28.0 : Circumscribed neurodermatitis', 'L25.9 : Unspecified contact dermatitis, unspecified cause', 'L23.8 : Allergic contact dermatitis due to other agents', 'L23.5 : Allergic contact dermatitis due to plastic', 'L29.8 : Other pruritus', 'R21 : Rash and other nonspecific skin eruption', 'L50.9 : Urticaria, unspecified', 'L24.9 : Irritant contact dermatitis, unspecified cause', 'L27.2 : Dermatitis due to ingested food', 'L24.0 : Irritant contact dermatitis due to detergents']**  
            medical_specialty_subspecialty: **cardiologist**  
            Expected output: **non-relevant**
            
            Example 5:  
            medical_query: **acne worsening at job**  
            icd_code_description_list: **['L70.1 : Acne conglobata', 'L21.9 : Seborrheic dermatitis, unspecified', 'L30.9 : Eczema NOS', 'L25.9 : Unspecified contact dermatitis, unspecified cause', 'L70.0 : Acne vulgaris', 'L23.5 : Allergic contact dermatitis due to plastic', 'L71.9 : Rosacea, unspecified', 'L24.9 : Irritant contact dermatitis, unspecified cause', 'L70.8 : Other acne', 'L70.9 : Acne, unspecified', 'L24.0 : Irritant contact dermatitis due to detergents', 'L71.0 : Perioral dermatitis']**  
            medical_specialty_subspecialty: **dermatopathology_occupational medicine**  
            Expected output: **relevant**

            Example 6:  
            medical_query: **how does diabetes affect brain tumour**  
            icd_code_description_list: **['L70.1 : Acne conglobata', 'L21.9 : Seborrheic dermatitis, unspecified', 'L30.9 : Eczema NOS', 'L25.9 : Unspecified contact dermatitis, unspecified cause', 'L70.0 : Acne vulgaris', 'L23.5 : Allergic contact dermatitis due to plastic', 'L71.9 : Rosacea, unspecified', 'L24.9 : Irritant contact dermatitis, unspecified cause', 'L70.8 : Other acne', 'L70.9 : Acne, unspecified', 'L24.0 : Irritant contact dermatitis due to detergents', 'L71.0 : Perioral dermatitis']**  
            medical_specialty_subspecialty: **neurology**  
            Expected output: **CANNOT_DECIDE**
            
            
            **Remember:** Provide **only** the label (`relevant`, `non-relevant`, or `CANNOT_DECIDE`) as the answer. Do not add any explanation or extra text.
    """

    prompt_template = PromptTemplate.from_template(template=system_prompt)
    chain = prompt_template | model | output_parser

    # Only the three placeholders used by the template are supplied.
    return chain.invoke(input={
        "medical_specialty_subspecialty": medical_specialty_subspecialty,
        "medical_query": medical_query,
        "icd_code_description_list": icd_code_description_list,
    })


In [42]:
# Build the chat client used for the specialty-verification prompts.
model_name = "gpt-4.1"
model = initialize_llm(model_name)


Initializing Model : gpt-4.1
Model gpt-4.1 Initialized


In [46]:
SPECIALTY_INDEX = 540   # position of the specialty to spot-check
QUERY_INDEX = 6         # position of the query within that specialty

reference_medical_specialty_subspecialty = list(specialty_query_code_desciption_dict)[SPECIALTY_INDEX]
query_code_descriptions = specialty_query_code_desciption_dict[reference_medical_specialty_subspecialty]
medical_query = list(query_code_descriptions)[QUERY_INDEX]
icd_code_description_list = query_code_descriptions[medical_query]

print(f'specialty_subspecialty : {reference_medical_specialty_subspecialty}')
print(f'query : {medical_query}')
print(f'icd_code_description_list : {icd_code_description_list}')

specialty_subspecialty : registered nurse_ophthalmic
query : how does diabetes affect vision
icd_code_description_list : ['E11.339 : Type 2 diabetes mellitus with moderate nonproliferative diabetic retinopathy without macular edema', 'H43.10 : Vitreous hemorrhage, unspecified eye', 'E11.319 : Type 2 diabetes w unsp diabetic rtnop w/o macular edema', 'E11.36 : Type 2 diabetes mellitus with diabetic cataract', 'E11.349 : Type 2 diab w severe nonprlf diab rtnop w/o macular edema', 'H35.0 : Background retinopathy and retinal vascular changes', 'E08.311 : Diabetes mellitus due to underlying condition with unspecified diabetic retinopathy with macular edema', 'E13.311 : Other specified diabetes mellitus with unspecified diabetic retinopathy with macular edema', 'E11.329 : Type 2 diabetes mellitus with mild nonproliferative diabetic retinopathy without macular edema', 'E11.39 : Type 2 diabetes mellitus with other diabetic ophthalmic complication', 'E11.311 : Type 2 diabetes w unsp diabetic re

In [47]:
result = get_specialty_verification(
    model,
    reference_medical_specialty_subspecialty,
    medical_query,
    icd_code_description_list,
)



In [48]:
# Parsed label list returned by get_specialty_verification.
result

['relevant']