In [1]:
import torch
from transformers import AutoTokenizer, AutoModel
import random
import numpy as np
from sklearn.manifold import TSNE
import umap
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import plotly.express as px
import pandas as pd
import faiss

# Seed every random number generator the notebook touches (NumPy, the stdlib
# `random` module and PyTorch) with one shared value so reruns are repeatable.
SEED = 0
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x116ccd150>

In [2]:
# Device that the encoder and every input batch are moved to; set to "cuda" to run on a GPU.
device = "cpu"  # "cuda"

In [3]:
# Do not truncate long string cells when DataFrames are displayed (the neighbour
# lists shown later are long comma-separated strings).
pd.set_option("display.max_colwidth", None)

In [4]:
"""
References:
- https://arxiv.org/pdf/1908.10084
- https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens
- https://huggingface.co/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext
- https://huggingface.co/NeuML/pubmedbert-base-embeddings
"""

'\nReferences:\n- https://arxiv.org/pdf/1908.10084\n- https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens\n- https://huggingface.co/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext\n- https://huggingface.co/NeuML/pubmedbert-base-embeddings\n'

In [None]:
"""
Set options here
"""
# Hugging Face checkpoint used to embed the texts.
encoding_model_str = "NeuML/pubmedbert-base-embeddings"
# Alternative checkpoints kept for reference (assign one above to switch):
#   "google-bert/bert-base-uncased"
#   "sentence-transformers/bert-base-nli-mean-tokens"
#   "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
#   "Lianglab/PharmBERT-uncased"

# Input text files derived from the Canada Vigilance adverse reaction database extract:
# https://www.canada.ca/en/health-canada/services/drugs-health-products/medeffect-canada/adverse-reaction-database/canada-vigilance-online-database-data-extract.html
base_dir = "./"
text_path_dict = {
    dataset_name: base_dir + dataset_name + ".csv"
    for dataset_name in ("AlternateDrugs", "diseases", "drugnames")
}

# Which of the files above to embed, and how many (shuffled) rows of it to keep.
text_path = text_path_dict["diseases"]
limit_num_rows_text = 1500

In [10]:
# Read one text per line. The encoding is pinned so decoding does not depend on
# the platform's locale default.
with open(text_path, "r", encoding="utf-8") as f:
    raw_lines = f.readlines()[1:]  # first line is the header row

# Normalise whitespace/case and drop blank lines, which would otherwise be
# embedded as empty strings and show up as meaningless neighbours.
texts = [
    cleaned
    for cleaned in (line.strip().lower() for line in raw_lines)
    if cleaned
]

# Random subsample: shuffle (seeded at the top of the notebook), then keep the
# first limit_num_rows_text entries.
random.shuffle(texts)
texts = texts[:limit_num_rows_text]

In [11]:
# Load the tokenizer and the encoder for the checkpoint named by encoding_model_str.
tokenizer = AutoTokenizer.from_pretrained(encoding_model_str)
model = AutoModel.from_pretrained(encoding_model_str)

In [12]:
# Move the encoder onto the configured device.
model = model.to(device)

In [13]:
# Tokenize every text in a single call, padding to the longest sequence.
# truncation=True only takes effect when a maximum length is known: without an
# explicit max_length the tokenizer warns and silently disables truncation if the
# checkpoint declares none. Use the encoder's position-embedding limit, falling
# back to 512 (the usual BERT-base limit) if the config does not expose it.
max_seq_length = getattr(model.config, "max_position_embeddings", 512)
encoded_inputs = tokenizer(
    texts,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=max_seq_length,
)

# Add a trailing singleton axis so the mask broadcasts over the hidden dimension
# when padded positions are zeroed out during pooling.
attention_mask = encoded_inputs["attention_mask"][..., None]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [14]:
# Run the encoder over the tokenized texts in fixed-size batches and collect the
# last hidden state of every batch on the CPU.
batch_size = 16
hidden_state_chunks = []
with torch.no_grad():  # inference only, no autograd graph needed
    for batch_start in range(0, len(texts), batch_size):
        batch_stop = batch_start + batch_size
        batch_inputs = {
            field: tensor[batch_start:batch_stop].to(device)
            for field, tensor in encoded_inputs.items()
        }
        last_hidden_state = model(**batch_inputs)[0]  # first model output: last hidden state
        hidden_state_chunks.append(last_hidden_state.cpu())

# Stitch the per-batch results back together along the batch axis.
outputs = torch.cat(hidden_state_chunks, dim=0)

In [15]:
# Masked mean pooling: average the token vectors of each text while ignoring
# padded positions (mask value 0).
pooling_mask = attention_mask.to("cpu")
masked_token_sum = (outputs.to("cpu") * pooling_mask).sum(dim=1)
# Clamp the token count so an all-padding row cannot cause a division by zero.
real_token_count = pooling_mask.sum(dim=1).clamp(min=1e-9)
sentence_embedding = masked_token_sum / real_token_count

In [16]:
# Display the pooled embeddings (one row per input text).
sentence_embedding

tensor([[ 0.0018,  0.1961,  0.4001,  ..., -0.6743,  0.1901, -0.2501],
        [ 0.0904, -0.0036,  0.4530,  ...,  0.3199,  0.1284, -0.3126],
        [-0.4115,  0.3865, -0.5249,  ..., -0.3305,  0.1948,  0.0268],
        ...,
        [ 0.3898, -0.1245,  0.1752,  ...,  0.4537,  0.7214,  0.5138],
        [ 0.5157,  0.3204, -0.1786,  ...,  0.0685,  0.0088, -0.2578],
        [-0.5059,  0.2664, -0.2065,  ..., -0.5653, -0.0713, -0.0543]])

In [17]:
# Sanity check: expect (number of texts, hidden size).
sentence_embedding.shape

torch.Size([1500, 768])

In [18]:
# Scale every embedding to unit L2 norm so that a plain dot product between two
# rows equals their cosine similarity.
squared_lengths = sentence_embedding.pow(2).sum(dim=1, keepdim=True)
# Clamp so an all-zero embedding does not divide by zero.
norms = squared_lengths.sqrt().clamp(min=1e-9)
sentence_embedding_normalized = sentence_embedding / norms

In [19]:
# Visualization, stage 1: compress the embeddings with PCA before the 2-D projection.
# Number of principal components to keep.
pca_components = 50

In [20]:
# Fit PCA on the unit-normalized embeddings and project them onto pca_components
# dimensions. svd_solver="full" requests the exact full SVD rather than an approximation.
pca = PCA(n_components=pca_components, svd_solver="full")
X_pca = pca.fit_transform(sentence_embedding_normalized.numpy())

In [21]:
# then use UMAP (or tSNE)
tsne = umap.UMAP(random_state=42)  # TSNE(n_components=2)
X_tsne = tsne.fit_transform(X_pca)

  warn(
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [22]:
# Plotting table: one row per text, holding its 2-D coordinates and the text itself.
X_tsne_pd = pd.DataFrame(X_tsne, columns=["tsne_x", "tsne_y"]).assign(text=texts)

In [23]:
# Peek at the first rows of the projection table.
X_tsne_pd.head()

Unnamed: 0,tsne_x,tsne_y,text
0,-1.356653,6.192302,virilism
1,0.163219,9.104291,still's disease
2,-2.982886,9.869576,choroid plexus carcinoma
3,0.041707,8.407229,anaphylactic shock
4,-1.55976,8.532484,labour pain


In [24]:
# K-means clustering in the full embedding space (not the reduced 2-D coordinates),
# so the cluster labels reflect neighbourhoods in the original dimensions.
n_clusters = 50  # better to oversegment than undersegment! choose a number slightly more than expected

# n_init is passed explicitly: scikit-learn warns that the implicit default (10
# restarts) is changing, which would silently alter the clustering between
# library versions. Pinning it keeps the current behaviour and removes the warning.
kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
kmeans_classes = kmeans.fit_predict(sentence_embedding_normalized.numpy())
X_tsne_pd["KMeans_classes"] = kmeans_classes
kmeans_classes.shape

  super()._check_params_vs_input(X, default_n_init=10)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


(1500,)

In [25]:
# Interactive scatter of the 2-D projection: points are coloured by K-means
# cluster label and show their source text on hover.
fig = px.scatter(
    X_tsne_pd, x="tsne_x", y="tsne_y", hover_data=["text"], color="KMeans_classes"
)
fig.show()

In [26]:
# Inspect the members of a single K-means cluster (first 20 rows).
class_num = 9
X_tsne_pd.loc[X_tsne_pd["KMeans_classes"].eq(class_num)].head(20)

Unnamed: 0,tsne_x,tsne_y,text,KMeans_classes
84,2.761922,5.664856,endometrial cancer metastatic,9
149,2.869312,5.215249,cervix carcinoma,9
203,3.232354,5.445651,transitional cell carcinoma metastatic,9
239,2.780092,5.665688,breast cancer metastatic,9
302,2.525391,5.598411,leiomyosarcoma recurrent,9
345,2.7141,5.728059,metastases to liver,9
368,3.262393,5.491288,metastatic carcinoma of the bladder,9
401,2.916399,5.468676,vulvar adenocarcinoma,9
452,2.742719,5.513125,lung cancer metastatic,9
492,1.255798,8.837572,regional chemotherapy,9


In [27]:
# Cluster sizes, largest first.
X_tsne_pd["KMeans_classes"].value_counts()

KMeans_classes
28    66
1     60
11    52
34    52
48    52
27    50
33    50
49    47
45    42
7     40
25    39
9     37
20    34
0     34
41    33
2     33
37    32
43    31
12    31
17    30
32    29
13    29
4     28
6     28
21    28
10    27
22    26
46    25
35    25
36    25
18    24
5     24
19    23
39    22
24    22
15    22
26    22
44    21
3     21
40    21
29    20
38    19
30    19
23    18
14    17
42    17
47    16
16    13
8     13
31    11
Name: count, dtype: int64

In [28]:
# Use cosine similarity for RAG -> this is the Retrieval part of Retrieval-Augmented Generation.
# Cosine similarity is just a dot product once the vectors are scaled to unit norm.
# For each embedding, collect the top-K most similar *other* texts.

topK = 10
topK_similar_texts = []
topK_similar_texts_cosines = []

print(sentence_embedding_normalized.shape)  # already normalized before

all_texts = X_tsne_pd["text"].to_numpy()
num_texts = sentence_embedding_normalized.shape[0]
# Once the query itself is excluded, at most num_texts - 1 neighbours exist.
num_neighbors = max(min(topK, num_texts - 1), 0)

for user_query_embedding_idx in range(num_texts):
    user_query = sentence_embedding_normalized[
        user_query_embedding_idx : user_query_embedding_idx + 1
    ]
    dot_product = torch.mm(sentence_embedding_normalized, user_query.T).flatten()

    # Exclude the query by *index*, not by rank position: with duplicate texts
    # (or ties at cosine 1.0) the top-ranked hit is not guaranteed to be the query
    # itself, so dropping "the first element" could keep the query and discard a
    # genuine neighbour.
    dot_product[user_query_embedding_idx] = float("-inf")

    # topk returns the values already sorted in descending order.
    top_cosines, topK_matches = torch.topk(dot_product, k=num_neighbors)

    topK_similar_texts.append(", ".join(all_texts[topK_matches.numpy()].tolist()))
    # Fixed three-decimal formatting; slicing str(value) mangles negative numbers
    # and values printed in scientific notation.
    topK_similar_texts_cosines.append(
        ", ".join(f"{cosine:.3f}" for cosine in top_cosines.tolist())
    )

X_tsne_pd["topK_similar_texts"] = topK_similar_texts
X_tsne_pd["topK_similar_cosines"] = topK_similar_texts_cosines

torch.Size([1500, 768])


In [29]:
# Inspect the neighbour columns added by the brute-force cosine search.
X_tsne_pd.head()

Unnamed: 0,tsne_x,tsne_y,text,KMeans_classes,topK_similar_texts,topK_similar_cosines
0,-1.356653,6.192302,virilism,28,"cytomegalovirus viraemia, parametritis, phlebitis infective, bacterial vaginosis, meningoencephalitis viral, penile infection, venogram, meningococcal immunisation, gangrene, genital infection fungal","0.638, 0.467, 0.461, 0.461, 0.422, 0.415, 0.401, 0.398, 0.398, 0.397"
1,0.163219,9.104291,still's disease,48,"albright's disease, coeliac disease, glanzmann's disease, venoocclusive disease, rheumatic heart disease, gaucher's disease, bronchopulmonary disease, hodgkin's disease stage iv, dandruff, periodontal disease","0.499, 0.497, 0.478, 0.452, 0.443, 0.440, 0.410, 0.398, 0.395, 0.395"
2,-2.982886,9.869576,choroid plexus carcinoma,18,"central nervous system lymphoma, cns ventriculitis, polypoidal choroidal vasculopathy, cerebrospinal fluid leakage, meningioma malignant, metastases to central nervous system, central nervous system fungal infection, choroidal neovascularisation, intracranial mass, optic glioma","0.452, 0.421, 0.418, 0.397, 0.382, 0.376, 0.373, 0.373, 0.368, 0.365"
3,0.041707,8.407229,anaphylactic shock,6,"anaphylaxis treatment, shock haemorrhagic, toxic shock syndrome streptococcal, hypersensitivity vasculitis, accidental overdose, drug reaction with eosinophilia and systemic symptoms, animal bite, syncope, cytokine storm, intentional overdose","0.623, 0.572, 0.522, 0.468, 0.436, 0.401, 0.397, 0.390, 0.387, 0.381"
4,-1.55976,8.532484,labour pain,14,"uterine pain, oesophageal pain, perineal pain, neck pain, axillary pain, chest pain, abdominal pain, injection site pain, postpartum haemorrhage, tendon pain","0.772, 0.684, 0.642, 0.635, 0.629, 0.600, 0.599, 0.567, 0.545, 0.525"


In [30]:
# Use a vector similarity-search library such as FAISS for efficient retrieval.
# Build an inner-product index: it yields cosine similarity only because the embeddings are normalized to unit length before they are indexed.

In [31]:
# NumPy float32 version of the unit-normalized embeddings, used below to build
# and query the FAISS index.
embeddings = sentence_embedding_normalized.numpy().astype("float32")
embeddings.shape

(1500, 768)

In [32]:
# Flat inner-product index sized to the embedding dimension; with unit-norm
# vectors the inner product equals the cosine similarity.
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)
# Persist the populated index to disk.
faiss.write_index(index, "./faiss_vector_db_inner_prod.index")

In [33]:
# Reload the persisted index from disk; the search below runs against this reloaded copy.
index = faiss.read_index("./faiss_vector_db_inner_prod.index")

In [34]:
topK = 10
topK_similar_texts = []
topK_similar_texts_cosines = []

all_texts = X_tsne_pd["text"].to_numpy()

# Get the matches for every query in one vectorized FAISS call. One extra hit is
# requested because the query itself is normally among its own results.
similarities, indices = index.search(embeddings, topK + 1)

for user_query_embedding_idx in range(indices.shape[0]):
    row_indices = indices[user_query_embedding_idx]
    row_similarities = similarities[user_query_embedding_idx]

    # Drop the query by *index* rather than assuming it is the first hit: with
    # duplicate texts the first hit may be a different row with the same cosine.
    # Also drop -1 labels, which FAISS uses as padding when fewer than the
    # requested number of results exist (they would otherwise index the last text).
    keep = (row_indices != user_query_embedding_idx) & (row_indices >= 0)
    topK_matches = row_indices[keep][:topK]
    topK_cosines = row_similarities[keep][:topK]

    topK_similar_texts.append(", ".join(all_texts[topK_matches].tolist()))
    # Fixed three-decimal formatting; slicing str(value) mangles negative numbers
    # and values printed in scientific notation.
    topK_similar_texts_cosines.append(
        ", ".join(f"{cosine:.3f}" for cosine in topK_cosines.tolist())
    )

X_tsne_pd["faiss_topK_similar_texts"] = topK_similar_texts
X_tsne_pd["faiss_topK_similar_cosines"] = topK_similar_texts_cosines

In [35]:
# Compare the FAISS neighbour columns with the brute-force ones side by side.
X_tsne_pd.head()

Unnamed: 0,tsne_x,tsne_y,text,KMeans_classes,topK_similar_texts,topK_similar_cosines,faiss_topK_similar_texts,faiss_topK_similar_cosines
0,-1.356653,6.192302,virilism,28,"cytomegalovirus viraemia, parametritis, phlebitis infective, bacterial vaginosis, meningoencephalitis viral, penile infection, venogram, meningococcal immunisation, gangrene, genital infection fungal","0.638, 0.467, 0.461, 0.461, 0.422, 0.415, 0.401, 0.398, 0.398, 0.397","cytomegalovirus viraemia, parametritis, phlebitis infective, bacterial vaginosis, meningoencephalitis viral, penile infection, venogram, meningococcal immunisation, gangrene, genital infection fungal","0.638, 0.467, 0.461, 0.461, 0.422, 0.415, 0.401, 0.398, 0.398, 0.397"
1,0.163219,9.104291,still's disease,48,"albright's disease, coeliac disease, glanzmann's disease, venoocclusive disease, rheumatic heart disease, gaucher's disease, bronchopulmonary disease, hodgkin's disease stage iv, dandruff, periodontal disease","0.499, 0.497, 0.478, 0.452, 0.443, 0.440, 0.410, 0.398, 0.395, 0.395","albright's disease, coeliac disease, glanzmann's disease, venoocclusive disease, rheumatic heart disease, gaucher's disease, bronchopulmonary disease, hodgkin's disease stage iv, dandruff, periodontal disease","0.499, 0.497, 0.478, 0.452, 0.443, 0.440, 0.410, 0.398, 0.395, 0.395"
2,-2.982886,9.869576,choroid plexus carcinoma,18,"central nervous system lymphoma, cns ventriculitis, polypoidal choroidal vasculopathy, cerebrospinal fluid leakage, meningioma malignant, metastases to central nervous system, central nervous system fungal infection, choroidal neovascularisation, intracranial mass, optic glioma","0.452, 0.421, 0.418, 0.397, 0.382, 0.376, 0.373, 0.373, 0.368, 0.365","central nervous system lymphoma, cns ventriculitis, polypoidal choroidal vasculopathy, cerebrospinal fluid leakage, meningioma malignant, metastases to central nervous system, central nervous system fungal infection, choroidal neovascularisation, intracranial mass, optic glioma","0.452, 0.421, 0.418, 0.397, 0.382, 0.376, 0.373, 0.373, 0.368, 0.365"
3,0.041707,8.407229,anaphylactic shock,6,"anaphylaxis treatment, shock haemorrhagic, toxic shock syndrome streptococcal, hypersensitivity vasculitis, accidental overdose, drug reaction with eosinophilia and systemic symptoms, animal bite, syncope, cytokine storm, intentional overdose","0.623, 0.572, 0.522, 0.468, 0.436, 0.401, 0.397, 0.390, 0.387, 0.381","anaphylaxis treatment, shock haemorrhagic, toxic shock syndrome streptococcal, hypersensitivity vasculitis, accidental overdose, drug reaction with eosinophilia and systemic symptoms, animal bite, syncope, cytokine storm, intentional overdose","0.623, 0.572, 0.522, 0.468, 0.436, 0.401, 0.397, 0.390, 0.387, 0.381"
4,-1.55976,8.532484,labour pain,14,"uterine pain, oesophageal pain, perineal pain, neck pain, axillary pain, chest pain, abdominal pain, injection site pain, postpartum haemorrhage, tendon pain","0.772, 0.684, 0.642, 0.635, 0.629, 0.600, 0.599, 0.567, 0.545, 0.525","uterine pain, oesophageal pain, perineal pain, neck pain, axillary pain, chest pain, abdominal pain, injection site pain, postpartum haemorrhage, tendon pain","0.772, 0.684, 0.642, 0.635, 0.629, 0.600, 0.599, 0.567, 0.545, 0.525"
