In [10]:
!pip install -q datasets
!pip install -q gensim
from datasets import load_dataset

ds = load_dataset("medalpaca/medical_meadow_medical_flashcards")
print('Structure of data: ', ds)

Structure of data:  DatasetDict({
    train: Dataset({
        features: ['input', 'output', 'instruction'],
        num_rows: 33955
    })
})


The data is a ```DatasetDict``` with a single split, ```train```, whose value is a ```Dataset```. That ```Dataset``` has three ```features``` (columns), ```input```, ```output```, and ```instruction```, and ```num_rows``` gives its number of rows.

The ```num_rows``` value shows that there are 33,955 examples, each consisting of a question, its answer, and an instruction.

Let us look at the first example of each field:

In [11]:
# Fetch a single row as a dict instead of materialising three full columns just to read index 0.
first_example = ds['train'][0]
print('Input :', first_example['input'])
print('Output :', first_example['output'])
print('Instruction :', first_example['instruction'])

Input : What is the relationship between very low Mg2+ levels, PTH levels, and Ca2+ levels?
Output : Very low Mg2+ levels correspond to low PTH levels which in turn results in low Ca2+ levels.
Instruction : Answer this question truthfully


- The ```input``` values are questions.
- The ```output``` values are the answers.
- The ```instruction``` values are the instructions for answering the questions.

To prepare the answers for training, we replace every non-word character with a space, lowercase the text, and split it on whitespace into tokens.

In [12]:
import re
import numpy as np

# Replace every non-word character (\W) with a space, lowercase, then split on whitespace.
# Caveat: this also strips characters such as '+', '/' and '-', which can carry meaning in
# medical text (e.g. 'Mg2+' becomes 'mg2'). The raw string avoids Python's invalid-escape warning.
tokenized_sentences = [re.sub(r'\W', ' ', sentence).lower().split() for sentence in ds['train']['output']]
for tokens in tokenized_sentences[:5]:
  print(tokens)

# Token count per answer, computed once and reused for all three statistics.
# A minimum of 0 means some answers produce no tokens at all.
sentence_lengths = [len(tokens) for tokens in tokenized_sentences]
print('Minimum sentence len: ', min(sentence_lengths))
print('Maximum sentence len: ', max(sentence_lengths))
print('Average sentence len: ', np.mean(sentence_lengths))

['very', 'low', 'mg2', 'levels', 'correspond', 'to', 'low', 'pth', 'levels', 'which', 'in', 'turn', 'results', 'in', 'low', 'ca2', 'levels']
['low', 'estradiol', 'production', 'leads', 'to', 'genitourinary', 'syndrome', 'of', 'menopause', 'atrophic', 'vaginitis']
['low', 'rem', 'sleep', 'latency', 'and', 'experiencing', 'hallucinations', 'sleep', 'paralysis', 'suggests', 'narcolepsy']
['pth', 'independent', 'hypercalcemia', 'which', 'can', 'be', 'caused', 'by', 'cancer', 'granulomatous', 'disease', 'or', 'vitamin', 'd', 'intoxication']
['the', 'level', 'of', 'anti', 'müllerian', 'hormone', 'is', 'directly', 'related', 'to', 'ovarian', 'reserve', 'a', 'lower', 'level', 'indicates', 'a', 'lower', 'ovarian', 'reserve']
Minimum sentence len:  0
Maximum sentence len:  247
Average sentence len:  54.24835812104256


We now train a Word2Vec model on the tokenized ```output``` texts.

In [13]:
from gensim.models.word2vec import Word2Vec

# Train 100-dimensional embeddings with a 10-token context window, keeping only words
# that appear at least twice in the tokenized answers.
# NOTE(review): gensim trains with several worker threads by default, so vectors can
# differ slightly between runs; confirm against the gensim docs if exact reproducibility matters.
model = Word2Vec(
    sentences=tokenized_sentences,
    vector_size=100,
    window=10,
    min_count=2,
)
print('Learnt vectors: ', len(model.wv))

Learnt vectors:  17679


Some things to try:

- query different terms
- inspect the most similar words for each term

In [14]:
# Other query terms worth trying: 'sickness', 'fever', 'cure'.
term = 'drugs'

# Five nearest neighbours of `term` in the Word2Vec space, as (word, cosine similarity) pairs.
sims = model.wv.most_similar(positive=[term], topn=5)
sims

[('agents', 0.8524566888809204),
 ('inhibitors', 0.7399193644523621),
 ('medications', 0.7008272409439087),
 ('drug', 0.6907287240028381),
 ('chemicals', 0.6843883395195007)]

In [15]:
!pip install -q hnswlib
!pip install -q sentence-transformers
import hnswlib
from sentence_transformers import SentenceTransformer

corpus = ds['train']['output'][:100]    # create a corpus of 100 docs
model = SentenceTransformer('all-mpnet-base-v2')
corpus_embeddings = model.encode(corpus, convert_to_tensor=True)


index = hnswlib.Index(space='cosine', dim=corpus_embeddings.size(1))

In [16]:
import os

index_path = './hnswlib.index'

# The index is built over the FULL output column, and each vector's label is its row position
# in this list, so `corpus` must be the full column in both branches below. Otherwise later
# lookups such as `corpus[label]` would hit the wrong document or run off the end.
corpus = ds['train']['output']
model = SentenceTransformer('all-mpnet-base-v2')
# Create the index object here so this cell does not depend on a half-initialised index
# left over from an earlier cell.
index = hnswlib.Index(space='cosine', dim=model.get_sentence_embedding_dimension())

if not os.path.exists(index_path):
  print("Creating a HNSWLIB index")
  corpus_embeddings = model.encode(corpus, convert_to_tensor=True)
  index.init_index(max_elements = corpus_embeddings.size(0), ef_construction=128, M=64)
  index.add_items(corpus_embeddings.cpu(), list(range(len(corpus_embeddings))))
  print("Saving index to:", index_path)
  index.save_index(index_path)
else:
  index.load_index(index_path)

# Guard against a stale index file that was built from a different corpus.
assert index.get_current_count() == len(corpus), "hnswlib.index does not match the corpus; delete it to rebuild"
# Query-time search breadth. It is not stored in the saved file, so set it after build/load.
# hnswlib's docs recommend ef > k (the query cell uses k=16).
index.set_ef(64)


In [17]:
# Use the first question in the dataset as the query.
query = ds['train'][0]['input']
print('Query : ', query)
query_embedding = model.encode(query, convert_to_tensor=True)
top_k = 16  # number of candidates retrieved (these are re-ranked later)
ids, distances = index.knn_query(query_embedding.cpu(), k=top_k)
# knn_query returns one row per query; row 0 holds this query's labels and cosine distances.
# Labels are numpy integers, so convert to int before indexing the dataset by row.
best_id = int(ids[0][np.argmin(distances[0])])
out = ds['train'][best_id]['output']
out

Query :  What is the relationship between very low Mg2+ levels, PTH levels, and Ca2+ levels?


'Very low Mg2+ levels correspond to low PTH levels which in turn results in low Ca2+ levels.'

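## Re-ranking

Re-score the retrieved candidates with a cross-encoder, which scores each (query, document) pair, and keep the best few.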

In [25]:
from sentence_transformers import CrossEncoder
xenc_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# Resolve candidates against the full output column. The HNSW labels are row positions in the
# whole dataset, whereas `corpus` may only hold the 100-document sample (for example when the
# index was loaded from disk rather than rebuilt).
all_outputs = ds['train']['output']
candidates = [all_outputs[int(label)] for label in ids[0]]

# The cross-encoder scores each (query, document) pair; higher means more relevant.
model_inputs = [(query, doc) for doc in candidates]
cross_scores = xenc_model.predict(model_inputs)
print("Cross-encoder model re-ranking results")
print(f"Query: \"{query}\"")
print("-------------------------")
# Show the three highest-scoring candidates, best first.
for idx in np.argsort(-cross_scores)[:3]:
  print(f"Score: {cross_scores[idx]:.4f}\nDocument: \"{candidates[idx]}\"")

Cross-encoder model re-ranking results
Query: "What is the relationship between very low Mg2+ levels, PTH levels, and Ca2+ levels? "
-------------------------
Score: 8.8953
Document: " Very low Mg2+ levels correspond to low PTH levels which in turn results in low Ca2+ levels."
Score: 4.0440
Document: " Low Ca2+ and low PTH is seen in primary hypoparathyroidism."
Score: -4.2314
Document: " PTH-independent hypercalcemia, which can be caused by cancer, granulomatous disease, or vitamin D intoxication."


Next step: use the top few re-ranked documents as context and prompt an LLM to generate an answer.