In [1]:
import pandas as pd
import chromadb
from tqdm import tqdm
from transformers import pipeline

# Initialize the client

In [2]:
client = chromadb.PersistentClient(path="idx/icd11",
                                   settings=chromadb.Settings(allow_reset=True))

In [3]:
client.heartbeat()
client.reset()

True

# Load the data

In [4]:
icd = pd.read_csv("data/ICD-11-MMS-en-cleaned.csv.zip").fillna('')
icd.head()

Unnamed: 0,Code,Title,ClassKind,DepthInKind,IsResidual,BrowserLink,isLeaf
0,,Certain infectious or parasitic diseases,chapter,1,False,https://icd.who.int/browse11/l-m/en#/http%3A%2...,False
1,,- Gastroenteritis or colitis of infectious origin,block,1,False,https://icd.who.int/browse11/l-m/en#/http%3A%2...,False
2,,- - Bacterial intestinal infections,block,2,False,https://icd.who.int/browse11/l-m/en#/http%3A%2...,False
3,1A00,- - - Cholera,category,1,False,https://icd.who.int/browse11/l-m/en#/http%3A%2...,True
4,1A01,- - - Intestinal infection due to other Vibrio,category,1,False,https://icd.who.int/browse11/l-m/en#/http%3A%2...,True


In [5]:
titles = icd["Title"].tolist()
titles = [t.lstrip('- ') for t in titles if t]
titles[:5]

['Certain infectious or parasitic diseases',
 'Gastroenteritis or colitis of infectious origin',
 'Bacterial intestinal infections',
 'Cholera',
 'Intestinal infection due to other Vibrio']

In [6]:
max_len_chars = max(len(t) for t in titles)
max_len_words = max(len(t.split()) for t in titles)
max_len_chars, max_len_words

(212, 30)

In [7]:
icd["cleaned_title"] = icd["Title"].str.lstrip('- ')

# Load the vectorizer
Using the pipelines discussed in https://stackoverflow.com/questions/64685243/getting-sentence-embedding-from-huggingface-feature-extraction-pipeline

In [8]:
extractor = pipeline(model="AkshatSurolia/ICD-10-Code-Prediction", task="feature-extraction")

config.json:   0%|          | 0.00/744k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/482M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/370 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [9]:
result = extractor("Bacterial intestinal infections", return_tensors=True)
result.shape

torch.Size([1, 7, 768])

In [10]:
sentence_embedding = result[0].mean(dim=0)
sentence_embedding.shape

torch.Size([768])

In [11]:
sentence_embedding2 = result[0][0]
sentence_embedding2.shape

torch.Size([768])

# Create the index

In [12]:
collection = client.create_collection(
    name="icd11_titles", 
    # metadata={"hnsw:space": "cosine"},
)

In [15]:
# client.reset()
docs = []
ids = []
embeddings = []

def append_to_collection():
    collection.upsert(documents=docs, ids=ids, embeddings=embeddings)
    docs.clear()
    ids.clear()
    embeddings.clear()

    
for n, row in tqdm(icd.iterrows(), total=len(icd)):
    if not row['Code']:
        continue
    docs.append(doc := row['cleaned_title'])
    embeddings.append(extractor(doc, return_tensors=True)[0].mean(dim=0).numpy().tolist())
    ids.append(row['Code'])
    if len(docs) >= 1000:
        append_to_collection()
append_to_collection()

100%|██████████| 36044/36044 [40:41<00:00, 14.76it/s]  


In [16]:
# check the index
embedding = extractor("- - - - Gastroduodenal motor or secretory disorders, unspecified", return_tensors=True)[0].mean(dim=0).numpy().tolist()
results = collection.query(query_embeddings=embedding, n_results=5)
results

{'ids': [['DA41.Z', 'DA41', 'DD90.Z', 'DD90', 'DA41.Y']],
 'distances': [[551.697509765625,
   748.6067504882812,
   756.9035034179688,
   898.1651000976562,
   904.1014404296875]],
 'metadatas': [[None, None, None, None, None]],
 'embeddings': None,
 'documents': [['Gastroduodenal motor or secretory disorders, unspecified',
   'Gastroduodenal motor or secretory disorders',
   'Functional oesophageal or gastroduodenal disorders, unspecified',
   'Functional oesophageal or gastroduodenal disorders',
   'Other specified gastroduodenal motor or secretory disorders']],
 'uris': None,
 'data': None}