In [1]:
# One run of test to deduplicate the bio_med_research dataset
import pandas as pd
import os
from transformers import AutoTokenizer, AutoModel
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
# Initialize classifier
import xml.etree.ElementTree as ET
import json
from tqdm import tqdm

In [2]:
# if use colab, run this part
from google.colab import drive

drive.mount('/content/drive')
os.chdir('/content/drive/MyDrive/bionlp')

Mounted at /content/drive


In [3]:
# go to model dir
os.chdir('MedImageInsights')

In [4]:
# install necessary package
!pip install mup
!pip install fvcore

Collecting mup
  Downloading mup-1.0.0.tar.gz (28 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: mup
  Building wheel for mup (setup.py) ... [?25l[?25hdone
  Created wheel for mup: filename=mup-1.0.0-py3-none-any.whl size=23629 sha256=e9ffdccbd647c5fe3ee5c20ff8681d8a06fa6b9608ebaae5027348995a55de17
  Stored in directory: /root/.cache/pip/wheels/f4/c8/88/3c23a3d10c50053b6552d2d30aee5b53ba89a47f742420036c
Successfully built mup
Installing collected packages: mup
Successfully installed mup-1.0.0
Collecting fvcore
  Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting yacs>=0.1.6 (from fvcore)
  Downloading yacs-0.1.8-py3-none-any.whl.metadata (639 bytes)
Collecting iopath>=0.1.7 (from fvcore)
  Downloading iopath-0.1.10.tar.gz (42 kB)
[2K     [90m━━

In [5]:
# load model
from medimageinsightmodel import MedImageInsight

classifier = MedImageInsight(
    model_dir="2024.09.27",
    vision_model_name="medimageinsigt-v1.0.0.pt",
    language_model_name="language_model.pth"
)

classifier.load_model()



Model loaded successfully on device: cuda


## Caluclate Existing Embeddings

In [6]:
import os
import numpy as np
import pickle  # To save/load embeddings efficiently

def calculate_and_save_embeddings(dataset, dataset_name, save_dir="embeddings_cache", batch_size=128):
    """
    Compute and save embeddings for a QA dataset.

    Args:
        dataset (pd.DataFrame): Dataset containing "question" and "answer" columns.
        dataset_name (str): Name of the dataset for unique file identification.
        save_dir (str): Directory where embeddings will be saved.
        batch_size (int): Batch size for generating embeddings.

    Returns:
        dict: A dictionary containing question and answer embeddings.
    """
    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # File paths for embeddings
    question_embedding_file = os.path.join(save_dir, f"{dataset_name}_question_embeddings.pkl")
    answer_embedding_file = os.path.join(save_dir, f"{dataset_name}_answer_embeddings.pkl")

    # Check if embeddings already exist
    if os.path.exists(question_embedding_file) and os.path.exists(answer_embedding_file):
        print(f"Loading cached embeddings for {dataset_name}...")
        with open(question_embedding_file, "rb") as qf:
            question_embeddings = pickle.load(qf)
        with open(answer_embedding_file, "rb") as af:
            answer_embeddings = pickle.load(af)
    else:
        # Compute embeddings for questions
        print(f"Generating question embeddings for {dataset_name}...")
        questions = dataset["question"].tolist()
        question_embeddings = []
        for i in tqdm(range(0, len(questions), batch_size), desc="Question Embeddings"):
            batch_questions = questions[i:i + batch_size]
            question_embeddings.extend(classifier.encode(texts=batch_questions)["text_embeddings"])
        question_embeddings = np.array(question_embeddings)

        # Save question embeddings
        with open(question_embedding_file, "wb") as qf:
            pickle.dump(question_embeddings, qf)
        print(f"Saved question embeddings for {dataset_name}.")

        # Compute embeddings for answers
        print(f"Generating answer embeddings for {dataset_name}...")
        answers = dataset["answer"].tolist()
        answer_embeddings = []
        for i in tqdm(range(0, len(answers), batch_size), desc="Answer Embeddings"):
            batch_answers = answers[i:i + batch_size]
            answer_embeddings.extend(classifier.encode(texts=batch_answers)["text_embeddings"])
        answer_embeddings = np.array(answer_embeddings)

        # Save answer embeddings
        with open(answer_embedding_file, "wb") as af:
            pickle.dump(answer_embeddings, af)
        print(f"Saved answer embeddings for {dataset_name}.")

    return {"questions": question_embeddings, "answers": answer_embeddings}


In [7]:
trec_test = pd.read_csv("../deduplicated_data/QAs/LiveQA/trec_qa_test_fulltext_deduplicated.csv")

In [8]:
calculate_and_save_embeddings(trec_test, "trec_test", save_dir="../deduplicated_embeddings/QAs", batch_size=128)

Generating question embeddings for trec_test...


Question Embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.74s/it]


Saved question embeddings for trec_test.
Generating answer embeddings for trec_test...


Answer Embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.15it/s]

Saved answer embeddings for trec_test.





{'questions': array([[-0.00854761, -0.01788489,  0.02210569, ...,  0.01640709,
          0.00962725, -0.02104766],
        [-0.02988225,  0.06119742, -0.0220231 , ..., -0.02271864,
          0.04223474,  0.0197355 ],
        [-0.04223972,  0.03976909, -0.00833748, ..., -0.01961781,
          0.01816254,  0.00156547],
        ...,
        [ 0.00255125, -0.01433059, -0.02745977, ..., -0.0250255 ,
         -0.02510097, -0.01447551],
        [ 0.00769721,  0.01768944, -0.00075004, ..., -0.03420612,
         -0.01338638,  0.00608168],
        [-0.01469727,  0.03507981, -0.00156332, ...,  0.02157284,
         -0.00573132,  0.03931956]], dtype=float32),
 'answers': array([[-0.03538248, -0.00768419,  0.0133515 , ...,  0.03127733,
          0.02833509,  0.01364348],
        [-0.00613328,  0.02429301, -0.0034427 , ..., -0.03475191,
         -0.02888328, -0.002819  ],
        [-0.00713882,  0.01198535, -0.0566122 , ...,  0.02565822,
          0.00561776, -0.01705137],
        ...,
        [-0.012

In [None]:
trec_train1 = pd.read_csv("../deduplicated_data/QAs/LiveQA/trec_qa_train1_fulltext_deduplicated.csv")

In [None]:
calculate_and_save_embeddings(trec_train1, "trec_train1", save_dir="../deduplicated_embeddings/QAs", batch_size=128)

Generating question embeddings for trec_train1...


Question Embeddings: 100%|██████████| 2/2 [00:02<00:00,  1.17s/it]


Saved question embeddings for trec_train1.
Generating answer embeddings for trec_train1...


Answer Embeddings: 100%|██████████| 2/2 [00:02<00:00,  1.01s/it]

Saved answer embeddings for trec_train1.





{'questions': array([[-0.03138409,  0.0085791 , -0.00895555, ..., -0.04114044,
         -0.02161177, -0.00745709],
        [ 0.0017005 ,  0.09026409,  0.02877223, ..., -0.00430799,
          0.01211306, -0.01728003],
        [-0.01438718,  0.01423368, -0.03137173, ...,  0.01432422,
         -0.01381881,  0.03266332],
        ...,
        [-0.00100056,  0.06271383, -0.01148444, ...,  0.03190619,
         -0.01825087,  0.0001637 ],
        [-0.00770128,  0.00543365, -0.00874834, ...,  0.01995905,
          0.02506645,  0.00431239],
        [-0.03477092,  0.03620756, -0.00251215, ...,  0.0102073 ,
         -0.02318753, -0.02381962]], dtype=float32),
 'answers': array([[-0.00402346,  0.02634857,  0.01679761, ..., -0.00670597,
         -0.00429125, -0.00377662],
        [-0.0141833 ,  0.06576258, -0.02740972, ..., -0.02598694,
          0.01726627, -0.01747934],
        [ 0.00042881, -0.01156661, -0.02231102, ...,  0.00377204,
          0.03326612,  0.01611725],
        ...,
        [ 0.001

In [None]:
trec_train2 = pd.read_csv("../deduplicated_data/QAs/LiveQA/trec_qa_train2_fulltext_deduplicated.csv")

In [None]:
calculate_and_save_embeddings(trec_train2, "trec_train2", save_dir="../deduplicated_embeddings/QAs", batch_size=128)

Generating question embeddings for trec_train2...


Question Embeddings: 100%|██████████| 2/2 [00:01<00:00,  1.61it/s]


Saved question embeddings for trec_train2.
Generating answer embeddings for trec_train2...


Answer Embeddings: 100%|██████████| 2/2 [00:01<00:00,  1.27it/s]

Saved answer embeddings for trec_train2.





{'questions': array([[-5.61103374e-02, -2.94192377e-02,  2.26631686e-02, ...,
         -1.13030421e-02,  1.22457137e-02,  3.31664598e-03],
        [-4.09332924e-02, -2.83247903e-02, -1.67781692e-02, ...,
         -2.67210249e-02, -1.40400175e-02,  4.48641628e-02],
        [-8.66916869e-03,  3.95595049e-03, -3.06588560e-02, ...,
          4.47792094e-03,  8.76315683e-03,  2.55480111e-02],
        ...,
        [-2.59803701e-02,  5.77381626e-02, -1.79289468e-02, ...,
          1.42500096e-03,  3.10253538e-02, -1.61092486e-02],
        [-2.66397391e-02,  3.72499451e-02, -1.92779768e-02, ...,
         -1.90123823e-02, -2.50128396e-02,  3.77133526e-02],
        [-7.25916587e-03,  2.71126954e-03, -4.23971713e-02, ...,
          3.39236744e-02, -1.24145523e-02,  2.62199719e-05]], dtype=float32),
 'answers': array([[-3.39468271e-02, -2.58655995e-02, -7.07719615e-03, ...,
         -6.32266700e-03, -1.49055412e-02,  6.81720441e-03],
        [-7.40444884e-02,  1.14480965e-02, -1.74963137e-03, ...,

In [None]:
deduplicated_medqa_train = pd.read_csv("../deduplicated_data/QAs/MedQA-USMLE/medicationqa_train_deduplicated.csv")

In [None]:
calculate_and_save_embeddings(deduplicated_medqa_train, "medqa_train", save_dir="../deduplicated_embeddings/QAs", batch_size=128)

Generating question embeddings for medqa_train...


Question Embeddings: 100%|██████████| 74/74 [01:05<00:00,  1.12it/s]


Saved question embeddings for medqa_train.
Generating answer embeddings for medqa_train...


Answer Embeddings: 100%|██████████| 74/74 [00:55<00:00,  1.34it/s]

Saved answer embeddings for medqa_train.





{'questions': array([[-0.01920012, -0.00789331, -0.03105436, ..., -0.00961958,
         -0.00903517,  0.0014642 ],
        [-0.00546508, -0.00583425,  0.00680326, ...,  0.0203853 ,
         -0.0036516 ,  0.00322377],
        [ 0.02283401,  0.02300522, -0.00895497, ..., -0.00109803,
          0.01188346, -0.00956821],
        ...,
        [ 0.01013726, -0.00840279, -0.01759079, ..., -0.01372691,
         -0.00095596,  0.01392212],
        [-0.00766191,  0.02635055,  0.01182984, ..., -0.00734856,
         -0.02743114, -0.01689253],
        [-0.01830177,  0.01427505, -0.00038729, ..., -0.00184337,
         -0.01421148, -0.00056265]], dtype=float32),
 'answers': array([[-1.2976246e-02,  2.8769536e-02, -1.7270874e-02, ...,
          3.2288939e-02,  2.1586902e-03, -1.4232512e-02],
        [-1.1601702e-02,  2.7308626e-02,  1.2843061e-02, ...,
         -3.2742750e-02,  8.5733337e-03,  4.2165298e-02],
        [-9.6395276e-03,  2.4367317e-03,  4.1730651e-05, ...,
          2.1742405e-02, -2.0052

In [None]:
deduplicated_medqa_dev = pd.read_csv("../deduplicated_data/QAs/MedQA-USMLE/medqa_dev_deduplicated.csv")
calculate_and_save_embeddings(deduplicated_medqa_dev, "medqa_dev", save_dir="../deduplicated_embeddings/QAs", batch_size=128)

Generating question embeddings for medqa_dev...


Question Embeddings: 100%|██████████| 9/9 [00:07<00:00,  1.14it/s]


Saved question embeddings for medqa_dev.
Generating answer embeddings for medqa_dev...


Answer Embeddings: 100%|██████████| 9/9 [00:06<00:00,  1.32it/s]

Saved answer embeddings for medqa_dev.





{'questions': array([[-0.01426829,  0.02189184, -0.01237418, ..., -0.00538224,
          0.00711382, -0.0095661 ],
        [ 0.02538575, -0.01957065, -0.00494494, ..., -0.00383362,
          0.00719011, -0.02451274],
        [-0.01017492,  0.02959767,  0.00577046, ...,  0.00182964,
          0.02881611, -0.04058651],
        ...,
        [-0.01921252, -0.04238871, -0.03082264, ...,  0.00705667,
          0.02378668, -0.05210292],
        [-0.04144587,  0.0007426 , -0.01640335, ..., -0.02845823,
         -0.02630954, -0.04495118],
        [ 0.01656986,  0.03027669, -0.02626771, ..., -0.05504988,
         -0.02626285, -0.04265226]], dtype=float32),
 'answers': array([[-0.02596904,  0.01135569, -0.03700222, ...,  0.06005554,
          0.0228253 , -0.01678431],
        [-0.03584465,  0.0066256 ,  0.03030788, ...,  0.04410596,
          0.00286488, -0.00241645],
        [-0.02225801,  0.0277004 ,  0.02534864, ...,  0.0106915 ,
          0.02450419, -0.01461822],
        ...,
        [-0.001

In [None]:
deduplicated_medqa_test = pd.read_csv("../deduplicated_data/QAs/MedQA-USMLE/medqa_test_deduplicated.csv")
calculate_and_save_embeddings(deduplicated_medqa_test, "medqa_test", save_dir="../deduplicated_embeddings/QAs", batch_size=128)

Generating question embeddings for medqa_test...


Question Embeddings: 100%|██████████| 9/9 [00:07<00:00,  1.13it/s]


Saved question embeddings for medqa_test.
Generating answer embeddings for medqa_test...


Answer Embeddings: 100%|██████████| 9/9 [00:06<00:00,  1.34it/s]

Saved answer embeddings for medqa_test.





{'questions': array([[ 0.03382285,  0.00109251, -0.01978934, ...,  0.0046124 ,
         -0.00622189, -0.00409079],
        [ 0.00508343,  0.00938223,  0.00950611, ..., -0.0378918 ,
          0.02640384, -0.04462594],
        [ 0.00671429,  0.03353358, -0.01778153, ..., -0.0463046 ,
         -0.01510314, -0.02009208],
        ...,
        [-0.02249827,  0.02032656, -0.02561955, ..., -0.0123793 ,
         -0.0119698 , -0.00687133],
        [ 0.02484576, -0.03710643, -0.02562299, ...,  0.01056139,
         -0.02896243, -0.01811273],
        [ 0.01265016, -0.00121804, -0.03718533, ..., -0.04060804,
         -0.00270018, -0.0479387 ]], dtype=float32),
 'answers': array([[-0.00759999,  0.06355494,  0.00238086, ...,  0.01088175,
          0.03227371, -0.01178831],
        [ 0.00239567, -0.05118917, -0.00570694, ...,  0.00379949,
          0.01057235, -0.00518155],
        [ 0.00386819,  0.02520763,  0.01636413, ..., -0.01436383,
         -0.02558199, -0.03280509],
        ...,
        [ 0.008

In [None]:
deduplicated_medmcqa_test = pd.read_csv("../deduplicated_data/QAs/MedMCQA/medmcqa_test_fulltext_deduplicated.csv")

In [None]:
calculate_and_save_embeddings(deduplicated_medmcqa_test, "medmcqa_test", save_dir="../deduplicated_embeddings/QAs", batch_size=128)

Generating question embeddings for medmcqa_test...


Question Embeddings: 100%|██████████| 43/43 [00:30<00:00,  1.43it/s]


Saved question embeddings for medmcqa_test.
Generating answer embeddings for medmcqa_test...


Answer Embeddings: 100%|██████████| 43/43 [00:29<00:00,  1.44it/s]

Saved answer embeddings for medmcqa_test.





{'questions': array([[ 0.00076395, -0.02654904, -0.03134664, ...,  0.02570958,
         -0.00195837, -0.00630343],
        [ 0.00524337,  0.033416  ,  0.00307591, ...,  0.04220593,
         -0.01449368,  0.0308999 ],
        [-0.00434002, -0.02685545,  0.01362809, ...,  0.05724397,
         -0.00692272, -0.0030829 ],
        ...,
        [-0.02461233,  0.01374935, -0.00677833, ..., -0.00967417,
         -0.01530647, -0.00684253],
        [-0.00531393,  0.01521857, -0.0279128 , ...,  0.02822973,
          0.00548543, -0.00014811],
        [ 0.01891935, -0.02109549, -0.02800494, ...,  0.01564752,
         -0.00385373, -0.00310911]], dtype=float32),
 'answers': array([[-0.02031539, -0.00073444, -0.02261897, ..., -0.00154823,
         -0.02488664,  0.0226727 ],
        [-0.0212907 ,  0.00392494, -0.03102093, ...,  0.00871337,
         -0.03402397,  0.03308565],
        [ 0.0137716 , -0.03382061, -0.02171203, ..., -0.01962037,
         -0.01894022,  0.00477085],
        ...,
        [-0.021

In [None]:
deduplicated_medmcqa_dev = pd.read_csv("../deduplicated_data/QAs/MedMCQA/medmcqa_dev_fulltext_deduplicated.csv")

In [None]:
calculate_and_save_embeddings(deduplicated_medmcqa_dev, "medmcqa_dev", save_dir="../deduplicated_embeddings/QAs", batch_size=128)

Generating question embeddings for medmcqa_dev...


Question Embeddings: 100%|██████████| 31/31 [00:21<00:00,  1.43it/s]


Saved question embeddings for medmcqa_dev.
Generating answer embeddings for medmcqa_dev...


Answer Embeddings: 100%|██████████| 31/31 [00:25<00:00,  1.24it/s]

Saved answer embeddings for medmcqa_dev.





{'questions': array([[-2.56071370e-02,  1.91394035e-02,  1.27475783e-02, ...,
          6.34452626e-02, -1.70708690e-02,  2.03857757e-02],
        [ 5.93451085e-03,  5.27069345e-02, -1.12807238e-02, ...,
         -1.53421902e-03, -4.03792597e-03,  5.80840185e-03],
        [-2.25096289e-02, -2.07897816e-02,  9.05894209e-03, ...,
          1.06330588e-02, -2.68628467e-02,  1.02449032e-02],
        ...,
        [ 1.20566385e-02,  4.27294001e-02, -3.73669937e-02, ...,
          5.58599308e-02, -4.34714146e-02, -1.61970966e-02],
        [ 3.84220891e-02,  1.79261365e-03, -3.20613049e-02, ...,
         -9.69613343e-03, -1.58840474e-02,  1.38213523e-02],
        [-3.43796909e-02, -3.55549928e-05, -1.28169907e-02, ...,
          2.41738465e-02,  6.28327113e-03,  3.56595479e-02]], dtype=float32),
 'answers': array([[-0.00665213,  0.00178398,  0.01581716, ...,  0.02604667,
         -0.04338642,  0.01888935],
        [-0.02049993,  0.01377321,  0.01967613, ...,  0.04060692,
          0.00947692, 

In [None]:
deduplicated_medmcqa_train = pd.read_csv("../deduplicated_data/QAs/MedMCQA/medmcqa_train_fulltext_deduplicated.csv")

In [None]:
calculate_and_save_embeddings(deduplicated_medmcqa_train, "medmcqa_train", save_dir="../deduplicated_embeddings/QAs", batch_size=128)

Generating question embeddings for medmcqa_train...


Question Embeddings: 100%|██████████| 1116/1116 [13:23<00:00,  1.39it/s]


Saved question embeddings for medmcqa_train.
Generating answer embeddings for medmcqa_train...


Answer Embeddings: 100%|██████████| 1116/1116 [15:48<00:00,  1.18it/s]


Saved answer embeddings for medmcqa_train.


{'questions': array([[ 0.0089588 ,  0.02806156,  0.03427921, ..., -0.03710175,
         -0.04952196,  0.0055507 ],
        [-0.01757121, -0.00402611, -0.02513397, ...,  0.05821564,
          0.05154986, -0.01282142],
        [-0.02906533,  0.02274898, -0.03687871, ..., -0.00077901,
          0.02364248,  0.00395479],
        ...,
        [-0.00888756,  0.00367781,  0.0301006 , ...,  0.02007772,
         -0.02150984,  0.02527643],
        [-0.02117271,  0.04881136, -0.0177855 , ..., -0.00178686,
          0.02066536, -0.00442233],
        [ 0.00765537, -0.00361644, -0.01151004, ...,  0.00502331,
          0.00157674,  0.01298617]], dtype=float32),
 'answers': array([[-0.04601239,  0.04181859, -0.01415728, ..., -0.01211543,
          0.00663799,  0.05835623],
        [ 0.01110867, -0.01535145, -0.02623153, ...,  0.00427459,
         -0.01035658,  0.0333827 ],
        [ 0.00476568,  0.00990281,  0.02104012, ..., -0.01777902,
          0.02766381,  0.04678949],
        ...,
        [-0.015