In [3]:
import os

# Redirect the Hugging Face caches to project storage. The variables are set
# before `datasets` is imported, matching the original statement order
# (presumably so the library picks them up when it is first loaded).
os.environ.update(
    {
        "HF_HOME": "/projects/sciences/computing/sheju347/.cache/huggingface",
        "HF_HUB_CACHE": "/projects/sciences/computing/sheju347/.cache/huggingface/hub",
    }
)
# Optional overrides, currently disabled:
# os.environ["TRANSFORMERS_CACHE"] = "/projects/sciences/computing/sheju347/.cache/transformers"
# os.environ["HF_DATASETS_CACHE"] = "/projects/sciences/computing/sheju347/.cache/datasets"

from datasets import load_dataset

# Load the "train" split of the MedRAG PubMed corpus.
ds = load_dataset("MedRAG/pubmed", split="train")

print("done")

done


In [4]:

# Work on the first 10,000 records only, and keep just the "contents" field
# of each record as the text to embed.
data_list = ds.select(range(10000))
textList = [record["contents"] for record in data_list]

In [5]:
from transformers import AutoTokenizer, AutoModel
import torch

# Checkpoint identifier shared by the tokenizer and the encoder.
model_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# eval() returns the module itself, so it can be chained onto the load;
# the model is used for inference only.
model = AutoModel.from_pretrained(model_name).eval()

def embed_text(text, device="cpu"):
    """Embed one piece of text as the [CLS] vector of the module-level model.

    Args:
        text: The string to embed. Input longer than 512 tokens is truncated.
        device: Torch device on which to run the forward pass. The
            module-level ``model`` is moved to this device as a side effect,
            so that the model and its inputs always live on the same device.

    Returns:
        A 1-D NumPy array holding the last-layer hidden state of the first
        token ([CLS]) of the single input sequence.
    """
    # The inputs are moved to `device` below; the model must be there too,
    # otherwise any non-default device raises a device-mismatch error.
    # This is a no-op when the model is already on `device`.
    model.to(device)
    inputs = tokenizer(text, truncation=True, return_tensors="pt", max_length=512)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():  # inference only, no autograd bookkeeping needed
        outputs = model(**inputs)
    cls_embedding = outputs.last_hidden_state[0, 0, :]  # [CLS] token of the first (and only) sequence
    return cls_embedding.cpu().numpy()
    

In [None]:
import numpy as np
import tqdm

# Embed every text and actually collect the vectors (previously each result
# was computed and then dropped, leaving `embeddings` empty).
embedding_rows = []
for text in tqdm.tqdm(textList, desc = "embedding"):
    embedding_rows.append(embed_text(text))

# Stack the per-text vectors into one (num_texts, dim) float32 matrix so the
# following cells can read `.shape` and hand it to FAISS.
embeddings = np.stack(embedding_rows).astype("float32")
# Show the shape rather than dumping the whole matrix.
print(embeddings.shape)

In [None]:
import numpy as np

# `embeddings` may be a list of vectors or an array; np.asarray handles both
# (a plain list has no `.shape`). Fail with a clear message if it is not a
# 2-D (num_texts, dim) matrix, e.g. because no vectors were collected.
embedding_matrix = np.asarray(embeddings)
if embedding_matrix.ndim != 2:
    raise ValueError(
        "expected a 2-D (num_texts, dim) embedding matrix, "
        f"got shape {embedding_matrix.shape}"
    )
dim = embedding_matrix.shape[1]
print(dim)

In [None]:
import faiss  # was used without ever being imported in this notebook
import numpy as np

# FAISS expects a C-contiguous float32 matrix of shape (num_vectors, dim).
embeddings = np.ascontiguousarray(embeddings, dtype='float32')
if embeddings.ndim != 2:
    raise ValueError(
        "expected a 2-D (num_vectors, dim) embedding matrix, "
        f"got shape {embeddings.shape}"
    )
# Take the dimensionality from the matrix itself so this cell does not depend
# on `dim` having been computed by another cell.
dim = embeddings.shape[1]

# Number of k-means clusters (inverted lists). More clusters means fewer
# vectors scanned per probed cluster; recall is governed by `nprobe` below.
nlist = 10
quantizer = faiss.IndexFlatL2(dim)  # coarse quantizer assigning vectors to clusters
index = faiss.IndexIVFFlat(quantizer, dim, nlist)
index.train(embeddings) # K-means clustering
index.add(embeddings)
# Search 10 clusters per query. With nlist == 10 this probes every cluster,
# i.e. the search is exhaustive; lower nprobe or raise nlist to trade recall
# for speed.
index.nprobe = 10

faiss.write_index(index, "/projects/sciences/computing/sheju347/RAG/pubmed_faiss_ivf.index")
print("done")

In [None]:
# Reload the index from the file it was written to.
index_path = "/projects/sciences/computing/sheju347/RAG/pubmed_faiss_ivf.index"
index = faiss.read_index(index_path)



In [None]:
query = "Influence of a new virostatic compound on the induction of enzymes in rat liver. The virostatic compound N,N-diethyl-4-[2-(2-oxo-3-tetradecyl-1-imidazolidinyl)-ethyl]-1-piperazinecarboxamide-hydrochloride (5531) was analyzed as to its effect on the induction of tryptophan-pyrrolase and tyrosineaminotransferase in rat liver. 1. The basic activity of the enzymes was not influenced by the substance either in normal or in adrenalectomized animals"
# embed_text returns a single 1-D vector, but FAISS searches a batch of
# queries: a float32 matrix of shape (num_queries, dim). Reshape to one row.
query_embedding = embed_text(query).astype("float32").reshape(1, -1)
# D holds the distances and I the indices of the 3 nearest stored vectors.
D, I = index.search(query_embedding, k=3) # k-nearest neighbor search
print(D, I)