In [1]:
import os
import json
import gzip
import random

import torch
import pickle
import logging

import numpy as np
import pandas as pd
from tqdm import tqdm
from typing import List
from pathlib import Path


from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer, util


  from .autonotebook import tqdm as notebook_tqdm


In [7]:
class Distributed_Sentence_Embeddings:
    """ Generates distributed embedding for descriptions
        Args:
            model_path : path to the Sentence Transformer Model
            batch_size : batch size to be used for building the index of code : embedding pairs
            text_descriptions : all anchor positive negative text
        Returns:
            text_sentence_embedding_dictionary : dictionary with all sentences ( anchor + positives + negatives ) and respective embedding
    """
    
    def __init__(self, model_path : str, batch_size : int):
        """ initialize model_path and reference code_lookup dataset
        """
        self.model_path = model_path
        self.batch_size = batch_size
        self.text_sentence_embedding_dictionary : dict =  {}
        self.model = None
    def load_model(self):
        """ loads SentenceTransformer model
        """
        self.model = SentenceTransformer(self.model_path, trust_remote_code=True).cuda()
        return self.model
        
    def setup_multi_processing_pool(self):
        """ determines the number of GPUs available on the compute
        """
        
        print("Setting Up Multi Processing Pool")
        device_count = torch.cuda.device_count()
        pool_list = [f"cuda:{device}" for device in range(device_count)]
        return pool_list
    
    
    def generate_sentence_embedding_pair(self, text_descriptions : List[str]):
        """ Creates a dictionary of code and respective embeddings
        """
        
        print("Generating Embeddings")

        torch.cuda.empty_cache()

        model = self.load_model()

        pool_list = self.setup_multi_processing_pool()

        pool = self.model.start_multi_process_pool(pool_list)

        print("Computing embeddings using the multi-process pool")
        embeddings = self.model.encode_multi_process(sentences = text_descriptions, pool = pool, batch_size = self.batch_size)

        self.model.stop_multi_process_pool(pool)
        
        for sentence, embedding in zip(text_descriptions, embeddings):
            self.text_sentence_embedding_dictionary[sentence] = embedding

        print("Sentence : Embedding Mapping Created")
        return self.text_sentence_embedding_dictionary


In [8]:
class DataLoader:
    """ DataLoader class to load the UES keywords list and the ICD Datasets loaded
    """
    def __init__(self, ues_dataset_path : str, icd_dataset_path : str):
        self.ues_dataset_path = ues_dataset_path
        self.icd_dataset_path = icd_dataset_path
        
    def load_ues_keywords(self):

        ues_dataset = pd.read_csv(self.ues_dataset_path).iloc[:,1:]
        ues_dataset.columns = ['ues_keywords']
        return ues_dataset
    
    def load_icd_dataset(self)-> dict:
    
        def select_description(df : pd.DataFrame):

            max_length_indices = df.groupby('codes')['length'].idxmax()

            return df.loc[max_length_indices].reset_index(drop = True)
    
        icd_dataset = pd.read_csv(self.icd_dataset_path).iloc[:,14:16]
        icd_dataset.columns = ['codes','description']
        icd_dataset = icd_dataset.drop_duplicates()

        icd_dataset['description'] = icd_dataset['description'].apply(lambda x : x.strip())
        icd_dataset['length'] = icd_dataset['description'].apply(lambda x : len(x))

        icd_dataset_filtered = select_description(icd_dataset)

        icd_reference_lookup = {}
        for row in icd_dataset.itertuples():
            icd_reference_lookup[row.codes.strip()] = row.description.strip()

        print(f'Total ICD Coes : {len(icd_reference_lookup.keys())}')

        return icd_reference_lookup 

In [19]:
def get_icd_code_embeddings(icd_reference_lookup : dict, embedding_object : Distributed_Sentence_Embeddings):
    code_embeddings = {}
    for code, description in tqdm(icd_reference_lookup.items()):
        code_embeddings[code] = embedding_object.model.encode(description)
        
    return code_embeddings

In [24]:
import numpy as np
import faiss
from typing import Dict, List, Tuple
import pickle
import json

class ICDRetriever:
    """
    Efficient ICD code retrieval using FAISS for cosine similarity search.
    """

    def __init__(self, code_embeddings: Dict[str, np.ndarray]):
        """
        Initialize the retriever with ICD code embeddings.

        Args:
            code_embeddings: Dict with ICD codes as keys and embeddings as values
        """
        self.code_embeddings = code_embeddings
        self.icd_codes = list(code_embeddings.keys())
        self.embedding_dim = None
        self.index = None
        self._build_index()

    def _build_index(self):
        """Build FAISS index for efficient similarity search."""
        print("Building FAISS index...")

        # Convert embeddings to numpy array
        embeddings_list = [self.code_embeddings[code] for code in self.icd_codes]
        embeddings_array = np.array(embeddings_list, dtype=np.float32)

        # Normalize embeddings for cosine similarity
        faiss.normalize_L2(embeddings_array)

        self.embedding_dim = embeddings_array.shape[1]

        # Create FAISS index (using Inner Product for normalized vectors = cosine similarity)
        self.index = faiss.IndexFlatIP(self.embedding_dim)
        self.index.add(embeddings_array)

        print(f"Index built with {len(self.icd_codes)} ICD codes, embedding dimension: {self.embedding_dim}")

    def search_top_k(self, query_embedding: np.ndarray, k: int = 10) -> List[Tuple[str, float]]:
        """
        Find top-k most similar ICD codes for a query embedding.

        Args:
            query_embedding: Query embedding vector
            k: Number of top results to return

        Returns:
            List of tuples (icd_code, similarity_score)
        """
        # Ensure query embedding is the right shape and type
        query_embedding = np.array(query_embedding, dtype=np.float32).reshape(1, -1)

        # Normalize query embedding
        faiss.normalize_L2(query_embedding)

        # Search
        similarities, indices = self.index.search(query_embedding, k)

        # Convert results to list of tuples
        results = []
        for sim, idx in zip(similarities[0], indices[0]):
            if idx != -1:  # Valid index
                results.append((self.icd_codes[idx], float(sim)))

        return results

    def batch_search(self, query_embeddings: Dict[str, np.ndarray], k: int = 10) -> Dict[str, List[Tuple[str, float]]]:
        """
        Batch search for multiple queries.

        Args:
            query_embeddings: Dict with query descriptions as keys and embeddings as values
            k: Number of top results per query

        Returns:
            Dict with query descriptions as keys and list of (icd_code, similarity) tuples as values
        """
        print(f"Processing {len(query_embeddings)} queries...")

        results = {}
        for i, (query_desc, query_emb) in enumerate(query_embeddings.items()):
            results[query_desc] = self.search_top_k(query_emb, k)

            # Progress indicator
            if (i + 1) % 1000 == 0:
                print(f"Processed {i + 1}/{len(query_embeddings)} queries")

        return results

    def save_index(self, filepath: str):
        """Save the FAISS index to disk."""
        faiss.write_index(self.index, filepath)

        # Save metadata
        metadata = {
            'icd_codes': self.icd_codes,
            'embedding_dim': self.embedding_dim
        }
        with open(filepath + '.metadata', 'w') as f:
            json.dump(metadata, f)

    def load_index(self, filepath: str):
        """Load FAISS index from disk."""
        self.index = faiss.read_index(filepath)

        # Load metadata
        with open(filepath + '.metadata', 'r') as f:
            metadata = json.load(f)

        self.icd_codes = metadata['icd_codes']
        self.embedding_dim = metadata['embedding_dim']



In [34]:
def get_retrieval_results_model(icd_reference_lookup : dict, model_path : str, model_name : str, ues_dataset : pd.DataFrame):
    
    batch_size = 16
    embedding_object = Distributed_Sentence_Embeddings(model_path = model_path, batch_size = batch_size)
    print(f'Model Selected : {model_name}')
    text_sentence_embedding_dictionary = embedding_object.generate_sentence_embedding_pair(text_descriptions = list(ues_dataset['ues_keywords']))
    
    code_embeddings = get_icd_code_embeddings(icd_reference_lookup = icd_reference_lookup, embedding_object = embedding_object)
    
    retriever = ICDRetriever(code_embeddings)
    results = retriever.batch_search(text_sentence_embedding_dictionary, k=10)
    print(f"\nSaving results for {model_name}")
    with open(f'icd_retrieval_results_{model_name}.json', 'w') as f:
        json.dump(results, f, indent=4)
    

In [35]:
def main():
    
    ues_dataset_path = '../../datasets/dataset_evaluation/ues_search_queries.csv'
    icd_dataset_path = '../../../shekhar_tanwar/ICD-ICD-Triplet/dataset/icd10.csv'

    data_loader = DataLoader(ues_dataset_path = ues_dataset_path, icd_dataset_path = icd_dataset_path)

    ues_dataset = data_loader.load_ues_keywords()
    icd_reference_lookup = data_loader.load_icd_dataset()
    with open('icd_reference_lookup.pkl', 'wb') as file:
        pickle.dump(icd_reference_lookup, file)
    
    model_path_a = '../../../shekhar_tanwar/ICD-ICD-Triplet/model/e5-large-v2-20250331143312-finetuned-icd-v30/'
    
    get_retrieval_results_model(icd_reference_lookup = icd_reference_lookup, model_path =  model_path_a, model_name = 'e5_v30', ues_dataset = ues_dataset)
    

    #model_path_b = '../model/e5-large-v2-20250604214304-finetuned-icd-v40/'
    
    #get_retrieval_results_model(model_path =  model_path_b, model_name = 'e5_v40')
    
    

    

In [36]:
if __name__ == "__main__":
    main()

Total ICD Coes : 97064
Model Selected : e5_v30
Generating Embeddings
Setting Up Multi Processing Pool
Computing embeddings using the multi-process pool
Sentence : Embedding Mapping Created


100%|██████████| 97064/97064 [2:02:19<00:00, 13.22it/s]   


Building FAISS index...
Index built with 97064 ICD codes, embedding dimension: 1024
Processing 16804 queries...
Processed 1000/16804 queries
Processed 2000/16804 queries
Processed 3000/16804 queries
Processed 4000/16804 queries
Processed 5000/16804 queries
Processed 6000/16804 queries
Processed 7000/16804 queries
Processed 8000/16804 queries
Processed 9000/16804 queries
Processed 10000/16804 queries
Processed 11000/16804 queries
Processed 12000/16804 queries
Processed 13000/16804 queries
Processed 14000/16804 queries
Processed 15000/16804 queries
Processed 16000/16804 queries

Saving results for e5_v30


# Generating Summary For UES 

In [2]:
import json
import pickle
import pandas as pd


In [9]:
def error_analysis_summary(icd_reference_lookup_path : str, icd_retrieval_results_e5_v30_path : str, icd_retrieval_results_e5_v40_path : str) -> pd.DataFrame:
    print(f'Loading ICD Reference Lookup')
    with open(icd_reference_lookup_path, 'rb') as file:
        icd_reference_lookup = pickle.load(file)
        
    print(f'Loading Retrieved Results : E5_V30 ')
    with open(icd_retrieval_results_e5_v30_path, 'r') as file:
        icd_retrieval_results_e5_v30 = json.load(file)
    
    print(f'Loading Retrieved Results : E5_V40 ')
    with open(icd_retrieval_results_e5_v40_path, 'r') as file:
        icd_retrieval_results_e5_v40 = json.load(file)    
    
    all_queries = list(icd_retrieval_results_e5_v30.keys())
    data_list = []
    for query in all_queries:
        result_v30 = icd_retrieval_results_e5_v30.get(query)
        result_v40 = icd_retrieval_results_e5_v40.get(query)    

        retrieved_codes_v30 = []
        retrieved_codes_v40 = []
        for item_v30, item_v40 in zip(result_v30, result_v40):
            code_v30 = item_v30[0]
            code_v40 = item_v40[0]

            
            retrieved_codes_v30.append(code_v30)
            retrieved_codes_v40.append(code_v40)

        common_codes = set(retrieved_codes_v30).intersection(retrieved_codes_v40)
        #common_codes = [code + ' : ' +  icd_reference_lookup.get(code) for code in common_codes]
        common_codes = ['  '.join(common_codes)]

        model_a_diff_model_b = set(retrieved_codes_v30).difference(set(retrieved_codes_v40))
        model_b_diff_model_a = set(retrieved_codes_v40).difference(set(retrieved_codes_v30))


        #model_a_diff_model_b = [code + ' : ' +  icd_reference_lookup.get(code) for code in model_a_diff_model_b]
        #model_b_diff_model_a = [code + ' : ' +  icd_reference_lookup.get(code) for code in model_b_diff_model_a]
        
        model_a_diff_model_b = [','.join(model_a_diff_model_b)]
        model_b_diff_model_a = [','.join(model_b_diff_model_a)]

        result = pd.DataFrame(list(zip([query], common_codes,model_a_diff_model_b, model_b_diff_model_a )), columns = ['UES_Keyword','Common_Codes','E5_V30_Diff_E5_V40','E5_V40_Diff_E5_V30'])
        data_list.append(result)
    
    final_dataset =  pd.concat(data_list)
    final_dataset = final_dataset.reset_index().iloc[:,1:]
    return final_dataset

In [79]:

icd_reference_lookup_path = 'icd_reference_lookup.pkl'
icd_retrieval_results_e5_v30_path = './icd_retrieval_results_e5_v30.json'
icd_retrieval_results_e5_v40_path = './icd_retrieval_results_e5_v40.json'

final_dataset = error_analysis_summary(icd_reference_lookup_path = icd_reference_lookup_path, 
                                       icd_retrieval_results_e5_v30_path = icd_retrieval_results_e5_v30_path, 
                                       icd_retrieval_results_e5_v40_path = icd_retrieval_results_e5_v40_path,
                                       )

Loading ICD Reference Lookup
Loading Retrieved Results : E5_V30 
Loading Retrieved Results : E5_V40 


In [64]:
temp = final_dataset.iloc[0,1].split(' ')
print(temp)
list(set([code[:5] for code in temp if code != '']))

['H52.4', '', 'H52.2', '', 'H52.1', '', 'H52.6', '', 'H52.7']


['H52.4', 'H52.2', 'H52.1', 'H52.6', 'H52.7']

In [84]:
groups_level_4 = []
groups_level_3 = []
groups_level_2 = []

group_e5_v30_diff_e5_v40_level4 = []
group_e5_v30_diff_e5_v40_level3 = []
group_e5_v30_diff_e5_v40_level2 = []

group_e5_v40_diff_e5_v30_level4 = []
group_e5_v40_diff_e5_v30_level3 = []
group_e5_v40_diff_e5_v30_level2 = []

for index, row in final_dataset.iterrows():

    
    e5_v30_diff_e5_v40 = row['E5_V30_Diff_E5_V40']
    e5_v30_diff_e5_v40 = e5_v30_diff_e5_v40.split(',')

    e5_v30_diff_e5_v40_level4 = [code[:5] for code in e5_v30_diff_e5_v40 if code != '']
    e5_v30_diff_e5_v40_level4 = [str(code).replace('.','') if code[-1] == '.' else code for code in e5_v30_diff_e5_v40_level4]
    e5_v30_diff_e5_v40_level3 = [code[:4] for code in e5_v30_diff_e5_v40 if code != '']
    e5_v30_diff_e5_v40_level3 = [str(code).replace('.','') if code[-1] == '.' else code for code in e5_v30_diff_e5_v40_level3]
    e5_v30_diff_e5_v40_level2 = [code[:3] for code in e5_v30_diff_e5_v40 if code != '']
    e5_v30_diff_e5_v40_level2 = [str(code).replace('.','') if code[-1] == '.' else code for code in e5_v30_diff_e5_v40_level2]


    group_e5_v30_diff_e5_v40_level4.append(' '.join(list(set(e5_v30_diff_e5_v40_level4))))
    group_e5_v30_diff_e5_v40_level3.append(' '.join(list(set(e5_v30_diff_e5_v40_level3))))
    group_e5_v30_diff_e5_v40_level2.append(' '.join(list(set(e5_v30_diff_e5_v40_level2))))


    e5_v40_diff_e5_v30 = row['E5_V40_Diff_E5_V30']
    e5_v40_diff_e5_v30 = e5_v40_diff_e5_v30.split(',')

    e5_v40_diff_e5_v30_level4 = [code[:5] for code in e5_v40_diff_e5_v30 if code != '']
    e5_v40_diff_e5_v30_level4 = [str(code).replace('.','') if code[-1] == '.' else code for code in e5_v40_diff_e5_v30_level4]
    e5_v40_diff_e5_v30_level3 = [code[:4] for code in e5_v40_diff_e5_v30 if code != '']
    e5_v40_diff_e5_v30_level3 = [str(code).replace('.','') if code[-1] == '.' else code for code in e5_v40_diff_e5_v30_level3]
    e5_v40_diff_e5_v30_level2 = [code[:3] for code in e5_v40_diff_e5_v30 if code != '']
    e5_v40_diff_e5_v30_level2 = [str(code).replace('.','') if code[-1] == '.' else code for code in e5_v40_diff_e5_v30_level2]


    group_e5_v40_diff_e5_v30_level4.append(' '.join(list(set(e5_v40_diff_e5_v30_level4))))
    group_e5_v40_diff_e5_v30_level3.append(' '.join(list(set(e5_v40_diff_e5_v30_level3))))
    group_e5_v40_diff_e5_v30_level2.append(' '.join(list(set(e5_v40_diff_e5_v30_level2))))
    
    
    common_codes = row['Common_Codes']
    common_codes = common_codes.split(' ')

    common_codes_level4 = [code[:5] for code in common_codes if code != '']
    common_codes_level3 = [code[:4] for code in common_codes if code != '']
    common_codes_level2 = [code[:3] for code in common_codes if code != '']
    
    groups_level_4.append(' '.join(list(set(common_codes_level4))))
    groups_level_3.append(' '.join(list(set(common_codes_level3))))
    groups_level_2.append(' '.join(list(set(common_codes_level2))))
    

final_dataset['Common_Codes_Group_Level4'] = groups_level_4
final_dataset['Common_Codes_Group_Level3'] = groups_level_3
final_dataset['Common_Codes_Group_Level2'] = groups_level_2

final_dataset['E5_V30_Diff_E5_V40_Group_Level4'] = group_e5_v30_diff_e5_v40_level4
final_dataset['E5_V30_Diff_E5_V40_Group_Level3'] = group_e5_v30_diff_e5_v40_level3
final_dataset['E5_V30_Diff_E5_V40_Group_Level2'] = group_e5_v30_diff_e5_v40_level2


final_dataset['E5_V40_Diff_E5_V30_Group_Level4'] = group_e5_v40_diff_e5_v30_level4
final_dataset['E5_V40_Diff_E5_V30_Group_Level3'] = group_e5_v40_diff_e5_v30_level3
final_dataset['E5_V40_Diff_E5_V30_Group_Level2'] = group_e5_v40_diff_e5_v30_level2

In [85]:
final_dataset

Unnamed: 0,UES_Keyword,Common_Codes,E5_V30_Diff_E5_V40,E5_V40_Diff_E5_V30,Common_Codes_Group_Level4,Common_Codes_Group_Level3,Common_Codes_Group_Level2,E5_V30_Diff_E5_V40_Group_Level4,E5_V30_Diff_E5_V40_Group_Level3,E5_V30_Diff_E5_V40_Group_Level2,E5_V40_Diff_E5_V30_Group_Level4,E5_V40_Diff_E5_V30_Group_Level3,E5_V40_Diff_E5_V30_Group_Level2
0,refractive error,H52.4 H52.2 H52.1 H52.6 H52.7,"H52,Q12.4,H53.02,Q12.8,H52.5","H53.11,H52.00,H50.2,H52.31,H52.0",H52.4 H52.2 H52.1 H52.6 H52.7,H52.,H52,H52 Q12.4 H53.0 Q12.8 H52.5,Q12 H52 H53,Q12 H52 H53,H52.3 H53.1 H52.0 H50.2,H50 H52 H53,H50 H52 H53
1,sharp pains,M79.60,"M79.64,M79.639,M79.643,G89,M79.642,G89.1,M79.6...","M79.7,R51,R20.1,M91.3,M79.609,M79.2,R20.2,R52,...",M79.6,M79.,M79,G89.1 M79.6 G89,M79 G89,M79 G89,M79.7 R51 M79.6 R20.1 M91.3 M79.2 R20.2 R52 R20.9,M79 R51 M91 R20 R52,M79 R51 M91 R20 R52
2,impaired physical function,M62.81 Z74.0 R41.843 Z74.09,"R53.2,Z73.6,R94.2,Z72.3,M99.0,R26","R53.8,R53.1,R53,R54,R29.6,R29.890",R41.8 M62.8 Z74.0,M62. Z74. R41.,Z74 R41 M62,R53.2 Z73.6 R94.2 Z72.3 M99.0 R26,R94 Z72 R26 R53 M99 Z73,R94 Z72 R26 R53 M99 Z73,R29.8 R53.8 R53.1 R53 R54 R29.6,R54 R53 R29,R54 R53 R29
3,physiological phenomenon,R29.2,"H53.1,R40.4,R06,R94.3,R20.8,F01-F09,R20,H53.16...","R00-R09,R53.83,R29.818,G90.4,R29.90,R29.81,R94...",R29.2,R29.,R29,H53.1 R40.4 R06 R94.3 R20.8 F01-F R20 R62,R06 R94 F01- R40 R20 H53 R62,F01 R06 R94 R40 R20 H53 R62,R29.8 R53.8 R29.9 R79.8 G90.4 R94.0 R00-R R20.3,R94 G90 R00- R20 R53 R29 R79,R94 R00 G90 R20 R53 R29 R79
4,bone marrow,Z52.3,"Z52.2,Z94.6,T86.02,D47.02,Z52.29,Z94.81,Z52.21...","D46.4,D72.822,D70.9,D46.9,Z52.001,Z52.091,Z52....",Z52.3,Z52.,Z52,Z52.2 Z94.6 D47.0 T86.0 Z94.8,Z52 D47 T86 Z94,Z52 D47 T86 Z94,D46.4 Z52.0 D70.9 D46.9 D72.8 Z67.1,Z52 D70 D72 D46 Z67,Z52 D70 D72 D46 Z67
...,...,...,...,...,...,...,...,...,...,...,...,...,...
16799,functional decline,R41.843 Z74.0 R54 R41.81 Z74.09,"Z73.6,R94.2,E34.8,M62.84,R46.4","R41.82,Z74.01,R53.1,R62.7,R41.84",R41.8 R54 Z74.0,R54 Z74. R41.,Z74 R41 R54,Z73.6 R94.2 E34.8 M62.8 R46.4,E34 R94 R46 M62 Z73,E34 R94 R46 M62 Z73,R41.8 Z74.0 R53.1 R62.7,Z74 R41 R53 R62,Z74 R41 R53 R62
16800,esophagus disorders,K20-K31 K22 K20 K23 K22.8 K22.9 K20.8,"K20.90,K21,H50.51","K21.00,K20.9,K20.0",K22 K20 K23 K22.8 K20-K K22.9 K20.8,K20- K22. K22 K20. K20 K23,K23 K22 K20,H50.5 K21 K20.9,H50 K21 K20,H50 K21 K20,K20.0 K21.0 K20.9,K21 K20,K21 K20
16801,bipolar type 1,F31.31 F31.81 F31.9 F31.8 F31.62 F31.6,"F30,F31.0,F31.32,F31.7","F31.61,F31.60,F31.77,F31.64",F31.8 F31.3 F31.6 F31.9,F31.,F31,F30 F31.3 F31.0 F31.7,F30 F31,F30 F31,F31.6 F31.7,F31,F31
16802,chiropractic care,,"M45-M49,M54,M80-M94,M48,M90.8,M54.2,M94.2,M48....","M99.07,M99.04,M99.06,M99.03,M99.05,M99.09,M99....",,,,M54 M48.8 M48 M90.8 M45-M M54.2 M80-M M94.2 M9...,M54 M45- M48 M42 M80- M90 M94,M45 M54 M48 M42 M94 M90 M80,M99.0,M99,M99


In [86]:
final_dataset.to_csv('error_analysis_ues_e5_v30_e5_v40_v3.csv')