In [None]:
# One run of test to deduplicate the bio_med_research dataset
import pandas as pd
import os
from transformers import AutoTokenizer, AutoModel
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import xml.etree.ElementTree as ET
import json
from tqdm import tqdm
import pickle

In [None]:
# Colab only: mount Google Drive and switch to the project folder.
# The import is guarded so that "Restart & Run All" also works outside Colab;
# in that case the working directory is left untouched.
try:
    from google.colab import drive
except ImportError:
    drive = None

if drive is not None:
    drive.mount('/content/drive')
    os.chdir('/content/drive/MyDrive/bionlp')

Mounted at /content/drive


In [None]:
# go to model dir
# Guarded so that re-running this cell does not fail: the relative chdir only
# succeeds while the current directory is still the parent of MedImageInsights.
if os.path.basename(os.getcwd()) != 'MedImageInsights':
    os.chdir('MedImageInsights')

In [None]:
# Root folder of the raw QA datasets to deduplicate, relative to the current
# working directory; a dataset sub-folder (e.g. "/MedMCQA") is appended to it below.
directory = "../dataset/QAs"

In [None]:
# install necessary package
!pip install mup
!pip install fvcore

Collecting mup
  Downloading mup-1.0.0.tar.gz (28 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: mup
  Building wheel for mup (setup.py) ... [?25l[?25hdone
  Created wheel for mup: filename=mup-1.0.0-py3-none-any.whl size=23629 sha256=e173466e45ca07a301df9b7086aabecbc82b1dc184916beffcdfbee22a807040
  Stored in directory: /root/.cache/pip/wheels/f4/c8/88/3c23a3d10c50053b6552d2d30aee5b53ba89a47f742420036c
Successfully built mup
Installing collected packages: mup
Successfully installed mup-1.0.0
Collecting fvcore
  Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting yacs>=0.1.6 (from fvcore)
  Downloading yacs-0.1.8-py3-none-any.whl.metadata (639 bytes)
Collecting iopath>=0.1.7 (from fvcore)
  Downloading iopath-0.1.10.tar.gz (42 kB)
[2K     [90m━━

In [None]:
# load model
# NOTE(review): this import is presumably resolved relative to the working
# directory set by the earlier os.chdir('MedImageInsights'), which would be why
# it is not in the top import cell — confirm.
from medimageinsightmodel import MedImageInsight

# `classifier` is a notebook-level global; get_embeddings() below reads it directly.
# NOTE(review): the vision checkpoint name is spelled "medimageinsigt" (no "h");
# confirm it matches the file on disk before "fixing" the spelling.
classifier = MedImageInsight(
    model_dir="2024.09.27",
    vision_model_name="medimageinsigt-v1.0.0.pt",
    language_model_name="language_model.pth"
)

classifier.load_model()



Model loaded successfully on device: cuda


In [None]:
# loading dataset
def parse_xml(file):
    """Read an XML document and tabulate its <sentence> elements.

    Only <sentence> elements that are direct children of the document root are
    collected. For each one, the `id` and `text` attributes are read (a missing
    attribute yields None).

    Args:
        file: Path or file object accepted by `xml.etree.ElementTree.parse`.

    Returns:
        pandas.DataFrame with one row per sentence and the columns
        "sentence_id" and "sentence_text"; an empty DataFrame when the root
        has no <sentence> children.
    """
    document_root = ET.parse(file).getroot()
    records = [
        {"sentence_id": node.get('id'), "sentence_text": node.get('text')}
        for node in document_root.findall('sentence')
    ]
    return pd.DataFrame(records)


def _collect_files(path, extension, description):
    """Walk `path` recursively and return the paths of all files ending in `extension`."""
    matches = []
    for root, dirs, files in tqdm(os.walk(path), desc = description):
        for file in tqdm(files, desc = "Processing file"):
            if file.endswith(extension):
                matches.append(os.path.join(root, file))
    return matches


def _read_jsonl(path):
    """Load a JSON-lines file (one JSON document per line) into a DataFrame."""
    print("current file: ", path)
    with open(path, "r", encoding = "utf-8") as handle:
        # Blank lines (e.g. a trailing newline at EOF) are skipped instead of
        # being handed to json.loads, which would raise on them.
        records = [json.loads(line) for line in handle if line.strip()]
    return pd.DataFrame(records)


def _read_json(path):
    """Load a single JSON document into a DataFrame."""
    with open(path, "r", encoding = "utf-8") as handle:
        return pd.DataFrame(json.load(handle))


def load_dataset(path, filetype = "csv"):
    """Load every file of one type found under `path` (searched recursively).

    Args:
        path: Directory to walk.
        filetype: One of "csv", "xml", "jsonl" or "json". Files are selected by
            the matching extension ("." + filetype).

    Returns:
        dict mapping each matching file path to a pandas.DataFrame with that
        file's content. Empty dict when no file matches.

    Raises:
        ValueError: if `filetype` is not one of the supported values
            (previously this case fell through and returned None).
    """
    readers = {
        "csv": pd.read_csv,
        "xml": parse_xml,
        "jsonl": _read_jsonl,
        "json": _read_json,
    }
    if filetype not in readers:
        raise ValueError(
            f"Unsupported filetype {filetype!r}; expected one of {sorted(readers)}"
        )

    reader = readers[filetype]
    all_files = _collect_files(path, "." + filetype, f"Loading {filetype.upper()} files")
    return {f: reader(f) for f in all_files}



In [None]:
# functions for deduplication
def get_embeddings(texts, batch_size = 64):
    """Encode `texts` in batches with the notebook-level `classifier`.

    Args:
        texts: Sliceable sequence of texts (sliced as texts[a:b]).
        batch_size: Number of texts sent to `classifier.encode` per call.

    Returns:
        numpy array built from the 'text_embeddings' entries of every batch,
        in input order.
    """
    collected = []
    batch_starts = range(0, len(texts), batch_size)
    for start in tqdm(batch_starts, desc = "Generating embeddings"):
        encoded = classifier.encode(texts = texts[start:start + batch_size])
        collected.extend(encoded['text_embeddings'])
    return np.array(collected)

def compute_similarity(embeddings, threshold = 0.9):
    """Find near-duplicate rows using the full pairwise cosine-similarity matrix.

    Rows are visited in order. A row that has not been flagged yet marks every
    row whose similarity to it exceeds `threshold` for removal; rows that are
    already flagged are skipped.

    Args:
        embeddings: 2-D array-like of embedding vectors, one per row.
        threshold: Similarity above which two rows count as duplicates.

    Returns:
        set of row indices to remove.
    """
    pairwise = cosine_similarity(embeddings)
    np.fill_diagonal(pairwise, 0)  # a row must not match itself

    to_remove = set()
    for row_idx, row in enumerate(pairwise):
        if row_idx not in to_remove:
            to_remove.update(np.where(row > threshold)[0])

    return to_remove

def compute_similarity_chunked(embeddings, threshold=0.9, chunk_size=8000):
    """Find near-duplicate rows, computing cosine similarity one chunk of rows at a time.

    Each chunk of `chunk_size` rows is compared against all embeddings, so only
    a (chunk_size x n) similarity block is held at once. A row that has not been
    flagged yet marks every LATER row whose similarity to it exceeds
    `threshold`; already-flagged rows are skipped.

    Args:
        embeddings: 2-D array-like of embedding vectors, one per row.
        threshold: Similarity above which two rows count as duplicates.
        chunk_size: Number of rows per similarity block.

    Returns:
        set of row indices to remove.
    """
    total = len(embeddings)
    to_remove = set()
    for chunk_start in tqdm(range(0, total, chunk_size), desc= "Calcuating Similarity"):
        # Similarity of this chunk's rows against every embedding.
        block = cosine_similarity(embeddings[chunk_start:chunk_start + chunk_size], embeddings)

        for offset in range(block.shape[0]):
            global_idx = chunk_start + offset  # position in the full embeddings array
            if global_idx in to_remove:
                continue

            hits = np.where(block[offset] > threshold)[0]
            # Keep only later rows: this drops the self-match and never flags
            # a row that was visited earlier.
            to_remove.update(hits[hits > global_idx])

    return to_remove

def compute_similarity_between_datasets(embeddings1, embeddings2, threshold = 0.9):
    """Flag entries of `embeddings2` that are too similar to any entry of `embeddings1`.

    Note that the returned indices refer to `embeddings2`, whereas
    compute_similarity_between_datasets_chunked returns indices into its first
    argument.

    Each row of `embeddings1` is compared against all of `embeddings2` in a
    single cosine_similarity call instead of one call per pair.

    Args:
        embeddings1: Sequence of embedding vectors to compare from.
        embeddings2: Sequence of embedding vectors whose indices are reported.
        threshold: Similarity above which an entry of `embeddings2` is flagged.

    Returns:
        set of int indices into `embeddings2`.
    """
    to_remove = set()
    # cosine_similarity rejects empty inputs; with nothing to compare there is
    # nothing to flag.
    if len(embeddings1) == 0 or len(embeddings2) == 0:
        return to_remove

    reference = np.asarray(embeddings2)
    for i in tqdm(range(len(embeddings1)), desc = "Computing similarity"):
        query = np.asarray(embeddings1[i]).reshape(1, -1)
        similarities = cosine_similarity(query, reference)[0]
        to_remove.update(np.where(similarities > threshold)[0].tolist())
    return to_remove

def compute_similarity_between_datasets_chunked(embeddings1, embeddings2, threshold=0.9, chunk_size1=8000, chunk_size2=8000):
    """Flag entries of `embeddings1` that are too similar to any entry of `embeddings2`.

    Both inputs are processed in chunks, so only a (chunk_size1 x chunk_size2)
    similarity block is held at once.

    Args:
        embeddings1: Sliceable sequence of embedding vectors; its indices are reported.
        embeddings2: Sliceable sequence of embedding vectors to compare against.
        threshold: Similarity above which an entry of `embeddings1` is flagged.
        chunk_size1: Rows of `embeddings1` per block.
        chunk_size2: Rows of `embeddings2` per block.

    Returns:
        set of int indices into `embeddings1`.
    """
    size1, size2 = len(embeddings1), len(embeddings2)
    to_remove = set()

    for start1 in tqdm(range(0, size1, chunk_size1), desc="Processing dataset1 in chunks"):
        rows = embeddings1[start1:start1 + chunk_size1]

        for start2 in range(0, size2, chunk_size2):
            cols = embeddings2[start2:start2 + chunk_size2]
            scores = cosine_similarity(rows, cols)

            # Rows of this block with at least one match above the threshold,
            # shifted back to positions in the full embeddings1.
            flagged = np.flatnonzero((scores > threshold).any(axis=1))
            to_remove.update((flagged + start1).tolist())

    return to_remove

def deduplication_within_dataset_qa(dataset, threshold = 0.9):
    """Remove near-duplicate rows of a QA DataFrame, first by question, then by answer.

    Rows are dropped by POSITION, so the result is correct whatever index the
    input DataFrame carries (the similarity helpers return positions, not
    index labels).

    Args:
        dataset: DataFrame with "question" and "answer" columns.
        threshold: Cosine similarity above which two texts count as duplicates.

    Returns:
        Tuple (new_dataset, removed_question_positions, removed_answer_positions):
        - new_dataset: deduplicated DataFrame with a fresh 0..n-1 index.
        - removed_question_positions: list of row positions in `dataset`.
        - removed_answer_positions: list of row positions in the dataset as it
          was AFTER the question pass (not positions in the original `dataset`).
    """
    questions = dataset["question"].tolist()

    question_embeddings = get_embeddings(questions)
    to_remove_questions = compute_similarity_chunked(question_embeddings, threshold)

    keep_after_questions = np.ones(len(dataset), dtype=bool)
    keep_after_questions[np.array(list(to_remove_questions), dtype=np.intp)] = False
    new_dataset = dataset.iloc[keep_after_questions].reset_index(drop=True)

    # Answers are embedded only for the rows that survived the question pass.
    answers = new_dataset["answer"].tolist()
    answer_embeddings = get_embeddings(answers)
    to_remove_answers = compute_similarity_chunked(answer_embeddings, threshold)

    keep_after_answers = np.ones(len(new_dataset), dtype=bool)
    keep_after_answers[np.array(list(to_remove_answers), dtype=np.intp)] = False
    new_dataset = new_dataset.iloc[keep_after_answers].reset_index(drop=True)
    return new_dataset, list(to_remove_questions), list(to_remove_answers)


def deduplicate_across_datasets_qa(new_dataset, old_question_embeddings_saved, old_answer_embeddings_saved, threshold = 0.9):
    """Drop rows of `new_dataset` whose question or answer duplicates an older dataset.

    Args:
        new_dataset: DataFrame with "question" and "answer" columns.
        old_question_embeddings_saved: Iterable of per-dataset collections of
            question embedding vectors; they are flattened into one list.
        old_answer_embeddings_saved: Same, for answer embeddings.
        threshold: Cosine similarity above which a new text counts as a
            duplicate of an old one. It is forwarded to the similarity search
            (it used to be ignored, so the default 0.9 was always applied).

    Returns:
        Tuple (deduplicated_new_dataset, removed_by_question, removed_by_answer):
        the filtered DataFrame with a fresh 0..n-1 index, and two lists of row
        positions in `new_dataset` flagged through questions and through
        answers respectively. A row is dropped if it appears in either list.
    """
    # Flatten the per-dataset embedding collections into one list of vectors each.
    old_question_embeddings = []
    old_answer_embeddings = []
    for old_embed in old_question_embeddings_saved:
        old_question_embeddings.extend(old_embed)
    for old_embed in old_answer_embeddings_saved:
        old_answer_embeddings.extend(old_embed)

    # Generate embeddings for new dataset questions and answers
    new_question_embeddings = get_embeddings(new_dataset["question"].tolist())
    new_answer_embeddings = get_embeddings(new_dataset["answer"].tolist())

    # Positions of new rows that are too close to any old question / answer
    to_remove_questions = compute_similarity_between_datasets_chunked(
        new_question_embeddings, old_question_embeddings, threshold=threshold
    )
    to_remove_answers = compute_similarity_between_datasets_chunked(
        new_answer_embeddings, old_answer_embeddings, threshold=threshold
    )

    # Combine removal positions
    to_remove = to_remove_questions.union(to_remove_answers)

    # Drop by position rather than by index label, so the result does not
    # depend on the index `new_dataset` happens to carry.
    keep = np.ones(len(new_dataset), dtype=bool)
    keep[np.array(sorted(to_remove), dtype=np.intp)] = False
    deduplicated_new_dataset = new_dataset.iloc[keep].reset_index(drop=True)

    return deduplicated_new_dataset, list(to_remove_questions), list(to_remove_answers)



In [None]:
# Load the already-deduplicated MedicationQA and PubMedQA CSVs.
# Paths are relative to the current working directory.
deduplicated_medicationqa = pd.read_csv("../deduplicated_data/QAs/MedicationQA/medicationqa_train_fulltext_deduplicated.csv")
deduplicated_pubmed1 = pd.read_csv("../deduplicated_data/QAs/PubMedQA/ori_pqaa_deduplicated.csv")
deduplicated_pubmed2 = pd.read_csv("../deduplicated_data/QAs/PubMedQA/ori_pqau_deduplicated.csv")
deduplicated_pubmed3 = pd.read_csv("../deduplicated_data/QAs/PubMedQA/ori_pqal_deduplicated.csv")

## Deduplicate MedMCQA

In [None]:
# Load every .jsonl file found under the MedMCQA folder.
# The result is a dict mapping each file path to a DataFrame of its records.
medmcqa = load_dataset(path = directory + "/MedMCQA", filetype = "jsonl")

Loading JSONL files: 0it [00:00, ?it/s]
Processing file: 100%|██████████| 2/2 [00:00<00:00, 22982.49it/s]
Loading JSONL files: 1it [00:01,  1.38s/it]
Processing file: 100%|██████████| 3/3 [00:00<00:00, 33200.30it/s]
Loading JSONL files: 2it [00:01,  1.26it/s]


current file:  ../dataset/QAs/MedMCQA/data/test.jsonl
current file:  ../dataset/QAs/MedMCQA/data/train.jsonl
current file:  ../dataset/QAs/MedMCQA/data/dev.jsonl


In [None]:
# Show which files were loaded, then pick the three splits by their path keys.
# The keys are the paths produced by load_dataset, so they depend on `directory`.
print("Available files" + str(medmcqa.keys()))
medmcqa_train = medmcqa["../dataset/QAs/MedMCQA/data/train.jsonl"]
medmcqa_dev = medmcqa["../dataset/QAs/MedMCQA/data/dev.jsonl"]
medmcqa_test = medmcqa["../dataset/QAs/MedMCQA/data/test.jsonl"]

Available filesdict_keys(['../dataset/QAs/MedMCQA/data/test.jsonl', '../dataset/QAs/MedMCQA/data/train.jsonl', '../dataset/QAs/MedMCQA/data/dev.jsonl'])


In [None]:
def process_medmcqa(df, mode = 'train'):
    """Add an 'answer' column that spells out the multiple-choice options of each row.

    For every mode except "test", the answer text also states the correct
    option (`cop`) and its explanation (`exp`). In "test" mode only the four
    choices are listed, and the `cop` / `exp` columns are not read.

    The column is assigned by position, so rows are filled correctly whatever
    index `df` carries (writing through the loop counter as an index label
    would add or overwrite the wrong rows on a non-default index).

    Args:
        df: DataFrame with columns opa, opb, opc, opd (plus cop and exp unless
            mode is "test"). It is modified in place.
        mode: "test" to omit the correct answer; any other value includes it.

    Returns:
        The same DataFrame object, with the 'answer' column added or replaced.
    """
    include_solution = mode != "test"
    answers = []
    for row in df.itertuples():
        answer_row = f"The choices are: A) {row.opa}, B) {row.opb}, C) {row.opc}, D) {row.opd}."
        if include_solution:
            answer_row += f" The correct answer is {row.cop}, because {row.exp}"
        answers.append(answer_row)

    # Reusing df.index makes the assignment positional; dtype=object keeps the
    # column type the original None-initialised column had.
    df['answer'] = pd.Series(answers, index=df.index, dtype=object)

    return df


In [None]:
# Build the 'answer' text column for each split. process_medmcqa writes the
# column into the frame it is given and returns that same frame; only the
# "test" mode leaves out the correct option and explanation.
medmcqa_train = process_medmcqa(medmcqa_train, mode = 'train')
medmcqa_dev = process_medmcqa(medmcqa_dev, mode = 'dev')
medmcqa_test = process_medmcqa(medmcqa_test, mode = 'test')

In [None]:
# self deduplication first
# Each split is deduplicated against itself (questions, then answers) with the
# default threshold. The printed pair is (rows removed via questions, rows
# removed via answers) for that split.
medmcqa_train_self_dedup, removed_questions_self_train, removed_answers_self_train = deduplication_within_dataset_qa(medmcqa_train)
print(len(removed_questions_self_train), len(removed_answers_self_train))
medmcqa_dev_self_dedup, removed_questions_self_dev, removed_answers_self_dev = deduplication_within_dataset_qa(medmcqa_dev)
print(len(removed_questions_self_dev), len(removed_answers_self_dev))
medmcqa_test_self_dedup, removed_questions_self_test, removed_answers_self_test = deduplication_within_dataset_qa(medmcqa_test)
print(len(removed_questions_self_test), len(removed_answers_self_test))

Generating embeddings: 100%|██████████| 2857/2857 [17:00<00:00,  2.80it/s]
Calcuating Similarity: 100%|██████████| 23/23 [02:18<00:00,  6.00s/it]
Generating embeddings: 100%|██████████| 2488/2488 [17:34<00:00,  2.36it/s]
Calcuating Similarity: 100%|██████████| 20/20 [01:44<00:00,  5.22s/it]


23598 15601


Generating embeddings: 100%|██████████| 66/66 [00:23<00:00,  2.78it/s]
Calcuating Similarity: 100%|██████████| 1/1 [00:00<00:00,  7.89it/s]
Generating embeddings: 100%|██████████| 65/65 [00:25<00:00,  2.51it/s]
Calcuating Similarity: 100%|██████████| 1/1 [00:00<00:00,  7.64it/s]


30 163


Generating embeddings: 100%|██████████| 97/97 [00:34<00:00,  2.85it/s]
Calcuating Similarity: 100%|██████████| 1/1 [00:00<00:00,  4.21it/s]
Generating embeddings: 100%|██████████| 96/96 [00:34<00:00,  2.79it/s]
Calcuating Similarity: 100%|██████████| 1/1 [00:00<00:00,  4.36it/s]

13 674





In [None]:
# Row counts of the train / dev / test splits after self-deduplication.
len(medmcqa_train_self_dedup), len(medmcqa_dev_self_dedup), len(medmcqa_test_self_dedup)

(143623, 3990, 5463)

## Now, we deduplicate against the existing datasets

In [None]:
# Checkpoint the self-deduplicated splits to CSV (without the index column) so
# the cross-dataset step can start from disk instead of recomputing embeddings.
# NOTE(review): to_csv does not create missing folders; the MedMCQA output
# directory is assumed to exist already.
medmcqa_test_self_dedup.to_csv("../deduplicated_data/QAs/MedMCQA/medmcqa_test_fulltext_deduplicated_self.csv", index = False)
medmcqa_dev_self_dedup.to_csv("../deduplicated_data/QAs/MedMCQA/medmcqa_dev_fulltext_deduplicated_self.csv", index = False)
medmcqa_train_self_dedup.to_csv("../deduplicated_data/QAs/MedMCQA/medmcqa_train_fulltext_deduplicated_self.csv", index = False)

In [None]:
# Load the self-deduplicated splits back from the CSV checkpoints written above.
# This replaces the in-memory frames of the same names.
medmcqa_test_self_dedup = pd.read_csv("../deduplicated_data/QAs/MedMCQA/medmcqa_test_fulltext_deduplicated_self.csv")
medmcqa_dev_self_dedup = pd.read_csv("../deduplicated_data/QAs/MedMCQA/medmcqa_dev_fulltext_deduplicated_self.csv")
medmcqa_train_self_dedup = pd.read_csv("../deduplicated_data/QAs/MedMCQA/medmcqa_train_fulltext_deduplicated_self.csv")

In [None]:
def _load_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    # NOTE: pickle.load can execute arbitrary code; only load files produced by
    # this project.
    with open(path, "rb") as f:
        return pickle.load(f)


# Saved question / answer embeddings of the datasets deduplicated earlier.
medication_qa_q_embed = _load_pickle("../deduplicated_embeddings/QAs/medicationqa_question_embeddings.pkl")
medication_qa_a_embed = _load_pickle("../deduplicated_embeddings/QAs/medicationqa_answer_embeddings.pkl")

#pubmed1,2,3
pubmed1_q_embed = _load_pickle("../deduplicated_embeddings/QAs/pubmed1_question_embeddings.pkl")
pubmed1_a_embed = _load_pickle("../deduplicated_embeddings/QAs/pubmed1_answer_embeddings.pkl")
pubmed2_q_embed = _load_pickle("../deduplicated_embeddings/QAs/pubmed2_question_embeddings.pkl")
pubmed2_a_embed = _load_pickle("../deduplicated_embeddings/QAs/pubmed2_answer_embeddings.pkl")
pubmed3_q_embed = _load_pickle("../deduplicated_embeddings/QAs/pubmed3_question_embeddings.pkl")
pubmed3_a_embed = _load_pickle("../deduplicated_embeddings/QAs/pubmed3_answer_embeddings.pkl")

# One entry per old dataset, in the same order for questions and answers.
old_questions = [medication_qa_q_embed, pubmed1_q_embed, pubmed2_q_embed, pubmed3_q_embed]
old_answers = [medication_qa_a_embed, pubmed1_a_embed, pubmed2_a_embed, pubmed3_a_embed]

In [None]:
# Deduplicate the test split against the previously deduplicated datasets.
# deduplicate_across_datasets_qa takes (new_dataset, old question embeddings,
# old answer embeddings[, threshold]); the old DataFrames themselves are not an
# argument, only their saved embeddings.
full_medmcqa_test_self_dedup, removed_questions_full_test, removed_answers_full_test = deduplicate_across_datasets_qa(medmcqa_test_self_dedup, old_questions, old_answers)


Generating embeddings: 100%|██████████| 86/86 [00:31<00:00,  2.73it/s]
Generating embeddings: 100%|██████████| 86/86 [00:30<00:00,  2.86it/s]
Processing dataset1 in chunks: 100%|██████████| 1/1 [00:10<00:00, 10.92s/it]
Processing dataset1 in chunks: 100%|██████████| 1/1 [00:10<00:00, 10.39s/it]


In [None]:
# Number of test rows flagged via question similarity and via answer similarity
# against the older datasets.
print(len(removed_questions_full_test), len(removed_answers_full_test))

3 0


In [None]:
# Save the fully deduplicated test split (no index column).
full_medmcqa_test_self_dedup.to_csv("../deduplicated_data/QAs/MedMCQA/medmcqa_test_fulltext_deduplicated.csv", index = False)

In [None]:
# Add the MedMCQA test-split embeddings to the pool of "old" embeddings, so the
# dev split is also deduplicated against the test split.
# NOTE(review): no cell in the visible part of this notebook writes these two
# .pkl files; they must already exist on disk — confirm where they are produced.
# NOTE(review): the medication_qa_* names are reused here although the content
# is MedMCQA test embeddings; the earlier MedicationQA values are overwritten.
# NOTE: re-running this cell appends the same embeddings a second time.
with open("../deduplicated_embeddings/QAs/medmcqa_test_question_embeddings.pkl", "rb") as f:
    medication_qa_q_embed = pickle.load(f)
    old_questions.append(medication_qa_q_embed)

with open("../deduplicated_embeddings/QAs/medmcqa_test_answer_embeddings.pkl", "rb") as f:
    medication_qa_a_embed = pickle.load(f)
    old_answers.append(medication_qa_a_embed)

In [None]:
# Deduplicate the dev split against all embeddings collected so far in
# old_questions / old_answers (default threshold).
full_medmcqa_dev_self_dedup, removed_questions_full_dev, removed_answers_full_dev = deduplicate_across_datasets_qa(medmcqa_dev_self_dedup, old_questions, old_answers)


Generating embeddings: 100%|██████████| 63/63 [00:22<00:00,  2.81it/s]
Generating embeddings: 100%|██████████| 63/63 [00:24<00:00,  2.55it/s]
Processing dataset1 in chunks: 100%|██████████| 1/1 [00:09<00:00,  9.08s/it]
Processing dataset1 in chunks: 100%|██████████| 1/1 [00:08<00:00,  8.40s/it]


In [None]:
# Report how many dev rows were flagged via questions / via answers, then save
# the fully deduplicated dev split (no index column).
print(len(removed_questions_full_dev), len(removed_answers_full_dev))
full_medmcqa_dev_self_dedup.to_csv("../deduplicated_data/QAs/MedMCQA/medmcqa_dev_fulltext_deduplicated.csv", index = False)

5 20


In [None]:
# Add the MedMCQA dev-split embeddings to the pool of "old" embeddings, so the
# train split is also deduplicated against the dev split.
# NOTE(review): no cell in the visible part of this notebook writes these two
# .pkl files; they must already exist on disk — confirm where they are produced.
# NOTE(review): the medication_qa_* names are reused here although the content
# is MedMCQA dev embeddings.
# NOTE: re-running this cell appends the same embeddings a second time.
with open("../deduplicated_embeddings/QAs/medmcqa_dev_question_embeddings.pkl", "rb") as f:
    medication_qa_q_embed = pickle.load(f)
    old_questions.append(medication_qa_q_embed)

with open("../deduplicated_embeddings/QAs/medmcqa_dev_answer_embeddings.pkl", "rb") as f:
    medication_qa_a_embed = pickle.load(f)
    old_answers.append(medication_qa_a_embed)

In [None]:
# Deduplicate the train split against all embeddings collected so far in
# old_questions / old_answers (default threshold).
full_medmcqa_train_self_dedup, removed_questions_full_train, removed_answers_full_train = deduplicate_across_datasets_qa(medmcqa_train_self_dedup, old_questions, old_answers)


Generating embeddings: 100%|██████████| 2245/2245 [13:04<00:00,  2.86it/s]
Generating embeddings: 100%|██████████| 2245/2245 [15:48<00:00,  2.37it/s]
Processing dataset1 in chunks: 100%|██████████| 18/18 [04:23<00:00, 14.66s/it]
Processing dataset1 in chunks: 100%|██████████| 18/18 [04:25<00:00, 14.73s/it]


In [None]:
# Report how many train rows were flagged via questions / via answers, then
# save the fully deduplicated train split (no index column).
print(len(removed_questions_full_train), len(removed_answers_full_train))
full_medmcqa_train_self_dedup.to_csv("../deduplicated_data/QAs/MedMCQA/medmcqa_train_fulltext_deduplicated.csv", index = False)

222 614
