In [1]:
# Test run: deduplicate the bio_med_research dataset
import pandas as pd
import os
from transformers import AutoTokenizer, AutoModel
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import xml.etree.ElementTree as ET
import json
from tqdm import tqdm
import pickle

In [2]:
# Only needed when running on Google Colab: mount Google Drive and make the
# project folder on the drive the current working directory.
# Skip this cell in any other environment.
from google.colab import drive

drive.mount('/content/drive')
os.chdir('/content/drive/MyDrive/bionlp')

Mounted at /content/drive


In [3]:
# Go to the model directory. The path is relative to the current working
# directory, so this cell is not idempotent: running it a second time without
# restarting tries to enter a nested 'MedImageInsights' folder.
# The relative "../..." paths used in later cells are resolved from here.
os.chdir('MedImageInsights')

In [4]:
# Root folder that is walked recursively for .csv files to deduplicate.
# Relative path: resolved against the current working directory.
directory = "../deduplicated_data/self_medquad"

In [5]:
# install necessary package
!pip install mup
!pip install fvcore

Collecting mup
  Downloading mup-1.0.0.tar.gz (28 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: mup
  Building wheel for mup (setup.py) ... [?25l[?25hdone
  Created wheel for mup: filename=mup-1.0.0-py3-none-any.whl size=23629 sha256=d94b79c3879b21015225a72e93aaab5b209d0d9b5caa3fbe155eb12b1b66c549
  Stored in directory: /root/.cache/pip/wheels/f4/c8/88/3c23a3d10c50053b6552d2d30aee5b53ba89a47f742420036c
Successfully built mup
Installing collected packages: mup
Successfully installed mup-1.0.0
Collecting fvcore
  Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting yacs>=0.1.6 (from fvcore)
  Downloading yacs-0.1.8-py3-none-any.whl.metadata (639 bytes)
Collecting iopath>=0.1.7 (from fvcore)
  Downloading iopath-0.1.10.tar.gz (42 kB)
[2K     [90m━━

In [6]:
# load model
# This import lives here instead of the imports cell -- presumably because the
# module is found relative to the MedImageInsights working directory entered by
# the earlier os.chdir cell (TODO confirm).
from medimageinsightmodel import MedImageInsight

# NOTE(review): "medimageinsigt" (without the "h") looks like a typo, but it is
# passed through verbatim as a file name -- check the checkpoint file on disk
# before "fixing" the spelling.
classifier = MedImageInsight(
    model_dir="2024.09.27",
    vision_model_name="medimageinsigt-v1.0.0.pt",
    language_model_name="language_model.pth"
)

# `classifier` is the module-level object that get_embeddings() calls .encode() on.
classifier.load_model()



Model loaded successfully on device: cuda


In [7]:
# Collect the path of every .csv file found anywhere below `directory`.
all_csv_files = [
    os.path.join(folder, name)
    for folder, _subfolders, names in os.walk(directory)
    for name in names
    if name.endswith(".csv")
]

In [8]:
# Read every discovered CSV into a DataFrame, keyed by its file path.
data = {csv_path: pd.read_csv(csv_path) for csv_path in all_csv_files}

# deduplicate across all datasets

In [9]:
# functions for deduplication
def get_embeddings(texts, batch_size = 64):
    """Embed ``texts`` with the module-level ``classifier`` and stack the results.

    The input is sent to ``classifier.encode`` in consecutive slices of at most
    ``batch_size`` items, and the ``'text_embeddings'`` entry of every result is
    collected in input order.

    Args:
        texts: Sliceable sequence of texts to embed.
        batch_size: Maximum number of texts passed to a single ``encode`` call.

    Returns:
        ``np.ndarray`` built from all collected embeddings.
    """
    batches = (
        texts[start:start + batch_size]
        for start in range(0, len(texts), batch_size)
    )
    collected = []
    for batch in batches:
        encoded = classifier.encode(texts = batch)
        collected.extend(encoded['text_embeddings'])
    return np.array(collected)

def compute_similarity(embeddings, threshold = 0.9):
    """Flag near-duplicate rows of ``embeddings`` using one full similarity matrix.

    Rows are visited in order. A row that has not been flagged yet is kept, and
    every other row whose cosine similarity to it exceeds ``threshold`` is
    flagged for removal.

    Args:
        embeddings: 2-D collection of embedding vectors accepted by
            ``cosine_similarity``.
        threshold: Similarity strictly above this value counts as a duplicate.

    Returns:
        Set of row positions to remove.
    """
    pairwise = cosine_similarity(embeddings)
    np.fill_diagonal(pairwise, 0)  # a row must never match itself

    flagged = set()
    for row_idx, row in enumerate(pairwise):
        if row_idx in flagged:
            # Already a duplicate of an earlier kept row; it cannot act as a reference.
            continue
        flagged.update(np.flatnonzero(row > threshold))
    return flagged

def compute_similarity_chunked(embeddings, threshold=0.9, chunk_size=8000):
    """Flag near-duplicate rows of ``embeddings``, comparing ``chunk_size`` rows at a time.

    Only a ``chunk x all`` similarity block is held in memory at once. For each
    row that is not flagged yet, every *later* row whose cosine similarity
    exceeds ``threshold`` is flagged, so the first occurrence of a duplicate
    group is the one that survives.

    Args:
        embeddings: 2-D collection of embedding vectors.
        threshold: Similarity strictly above this value counts as a duplicate.
        chunk_size: Number of rows compared against the full set per step.

    Returns:
        Set of row positions to remove.
    """
    flagged = set()
    total = len(embeddings)
    for chunk_start in range(0, total, chunk_size):
        # Similarity of this slice of rows against every row.
        block = cosine_similarity(embeddings[chunk_start:chunk_start + chunk_size], embeddings)

        for offset, row in enumerate(block):
            global_idx = chunk_start + offset  # position in the full embedding list
            if global_idx in flagged:
                continue
            hits = np.where(row > threshold)[0]
            # Keep only later rows: this skips the self-match and earlier rows.
            flagged.update(hits[hits > global_idx])
    return flagged

def compute_similarity_between_datasets(embeddings1, embeddings2, threshold = 0.9):
    """Flag entries of ``embeddings2`` that are near-duplicates of any entry of ``embeddings1``.

    Every pair is scored individually with ``cosine_similarity``, so this is
    slow for large inputs. Note that the returned positions refer to
    ``embeddings2``, whereas ``compute_similarity_between_datasets_chunked``
    returns positions in ``embeddings1``.

    Args:
        embeddings1: Reference vectors (each must support ``reshape``).
        embeddings2: Candidate vectors (each must support ``reshape``).
        threshold: Similarity strictly above this value counts as a duplicate.

    Returns:
        Set of positions in ``embeddings2`` to remove.
    """
    return {
        candidate_idx
        for reference in embeddings1
        for candidate_idx, candidate in enumerate(embeddings2)
        if cosine_similarity(reference.reshape(1, -1), candidate.reshape(1, -1))[0][0] > threshold
    }

def compute_similarity_between_datasets_chunked(embeddings1, embeddings2, threshold=0.9, chunk_size1=8000, chunk_size2=8000):
    """
    Compute cosine similarity between two datasets in chunks to reduce memory usage.
    Flags entries of embeddings1 that are highly similar to any entry of embeddings2.

    Args:
        embeddings1: Vectors of the dataset being filtered (sliceable, 2-D).
        embeddings2: Reference vectors to compare against (sliceable, 2-D).
        threshold: Similarity strictly above this value counts as a duplicate.
        chunk_size1: Rows of embeddings1 processed per outer step.
        chunk_size2: Rows of embeddings2 processed per inner step.

    Returns:
        Set of integer positions in embeddings1 to remove.
    """
    to_remove = set()
    n1, n2 = len(embeddings1), len(embeddings2)

    for i in range(0, n1, chunk_size1):
        # Get a chunk from embeddings1
        chunk_embeddings1 = embeddings1[i:i + chunk_size1]

        for j in range(0, n2, chunk_size2):
            # Get a chunk from embeddings2
            chunk_embeddings2 = embeddings2[j:j + chunk_size2]

            # Compute cosine similarity for the two chunks
            similarity_matrix = cosine_similarity(chunk_embeddings1, chunk_embeddings2)

            # Rows of the current embeddings1 chunk that match anything in this
            # embeddings2 chunk; shift by i to map back to positions in embeddings1.
            matched_rows = np.flatnonzero((similarity_matrix > threshold).any(axis=1))
            to_remove.update(i + int(row_idx) for row_idx in matched_rows)

    return to_remove

def deduplication_within_dataset_qa(dataset, threshold = 0.9):
    """Remove near-duplicate rows from one QA dataset, first by question, then by answer.

    Args:
        dataset: DataFrame with "question" and "answer" columns.
        threshold: Cosine similarity strictly above this value counts as a duplicate.

    Returns:
        Tuple ``(deduplicated, removed_by_question, removed_by_answer)``:
        the deduplicated DataFrame with a fresh 0..n-1 index, the row positions
        dropped in the question pass (positions in ``dataset``), and the row
        positions dropped in the answer pass (positions in the frame that
        remained after the question pass).
    """
    # The similarity helper returns row *positions*, while DataFrame.drop works on
    # index *labels*. Resetting the index first makes the two coincide even when the
    # caller passes a filtered or re-indexed frame.
    working = dataset.reset_index(drop=True)

    question_embeddings = get_embeddings(working["question"].tolist())
    to_remove_questions = compute_similarity_chunked(question_embeddings, threshold)
    working = working.drop(index = list(to_remove_questions)).reset_index(drop=True)

    # Answers are embedded only for rows that survived the question pass.
    answer_embeddings = get_embeddings(working["answer"].tolist())
    to_remove_answers = compute_similarity_chunked(answer_embeddings, threshold)
    working = working.drop(index = list(to_remove_answers)).reset_index(drop=True)

    return working, list(to_remove_questions), list(to_remove_answers)


def deduplicate_across_datasets_qa(new_dataset, old_question_embeddings_saved, old_answer_embeddings_saved, threshold = 0.9):
    """Drop rows of ``new_dataset`` whose question or answer duplicates an older dataset.

    Args:
        new_dataset: DataFrame with "question" and "answer" columns.
        old_question_embeddings_saved: Iterable of embedding collections, one per
            previously processed dataset, holding question embeddings.
        old_answer_embeddings_saved: Same layout, holding answer embeddings.
        threshold: Cosine similarity strictly above this value counts as a duplicate.

    Returns:
        Tuple ``(deduplicated, removed_by_question, removed_by_answer)``: the
        filtered DataFrame with a fresh 0..n-1 index, plus the row positions in
        ``new_dataset`` that matched an old question / an old answer.
    """
    # Flatten the per-dataset collections into one list of vectors each.
    old_question_embeddings = [vec for saved in old_question_embeddings_saved for vec in saved]
    old_answer_embeddings = [vec for saved in old_answer_embeddings_saved for vec in saved]

    # The similarity helper returns row *positions*, while DataFrame.drop works on
    # index *labels*; a fresh RangeIndex makes the two coincide.
    working = new_dataset.reset_index(drop=True)

    # Generate embeddings for new dataset questions and answers
    new_question_embeddings = get_embeddings(working["question"].tolist())
    new_answer_embeddings = get_embeddings(working["answer"].tolist())

    # Forward the caller's threshold (it used to be silently ignored, so the
    # helper's default was always applied).
    to_remove_questions = compute_similarity_between_datasets_chunked(
        new_question_embeddings, old_question_embeddings, threshold=threshold
    )
    to_remove_answers = compute_similarity_between_datasets_chunked(
        new_answer_embeddings, old_answer_embeddings, threshold=threshold
    )

    # A row is dropped when either its question or its answer matched.
    to_remove = to_remove_questions.union(to_remove_answers)
    deduplicated_new_dataset = working.drop(index=list(to_remove)).reset_index(drop=True)

    return deduplicated_new_dataset, list(to_remove_questions), list(to_remove_answers)



In [10]:
# Load the saved question/answer embeddings of the previously processed datasets.
# old_questions[i] and old_answers[i] belong to OLD_DATASET_NAMES[i].
EMBEDDINGS_DIR = "../deduplicated_embeddings/QAs"
OLD_DATASET_NAMES = [
    "medicationqa",
    # pubmed1,2,3
    "pubmed1",
    "pubmed2",
    "pubmed3",
    # medmcqa
    "medmcqa_train",
    "medmcqa_dev",
    "medmcqa_test",
    # medqa
    "medqa_train",
    "medqa_dev",
    "medqa_test",
    # trec
    "trec_train1",
    "trec_train2",
    "trec_test",
]


def _load_pickled_embeddings(path):
    """Return the object stored in the pickle file at ``path``."""
    # SECURITY: pickle.load can execute arbitrary code while loading. Only use it
    # on embedding files produced by this project, never on untrusted files.
    with open(path, "rb") as f:
        return pickle.load(f)


old_questions = []
old_answers = []
for dataset_name in OLD_DATASET_NAMES:
    old_questions.append(
        _load_pickled_embeddings(os.path.join(EMBEDDINGS_DIR, f"{dataset_name}_question_embeddings.pkl"))
    )
    old_answers.append(
        _load_pickled_embeddings(os.path.join(EMBEDDINGS_DIR, f"{dataset_name}_answer_embeddings.pkl"))
    )

In [13]:
# remove na for all data
def clean_dataframe(df):
    """Return the rows of ``df`` that have both a non-blank question and a non-blank answer.

    Side effect: the "question" and "answer" columns of the passed-in ``df`` are
    overwritten in place with missing values replaced by "" and all values cast
    to ``str``.

    Args:
        df: DataFrame with "question" and "answer" columns.

    Returns:
        Filtered DataFrame with a fresh 0..n-1 index.
    """
    for column in ("question", "answer"):
        df[column] = df[column].fillna("").astype(str)

    # A value made only of whitespace counts as empty.
    has_question = df["question"].str.strip() != ""
    has_answer = df["answer"].str.strip() != ""
    return df[has_question & has_answer].reset_index(drop=True)

# Clean every loaded frame; keys are the same CSV paths as in `data`.
cleaned_data = {
    csv_path: clean_dataframe(data[csv_path])
    for csv_path in tqdm(data,desc = "Cleaning data")
}

Cleaning data: 100%|██████████| 11268/11268 [00:21<00:00, 519.08it/s]


In [14]:
# Deduplicate each CSV on its own (questions first, then answers).
deduplicated_dict = {}

# Iterate over the *cleaned* frames: using `data` here would feed the rows with
# empty questions/answers, which the cleaning cell removed, back into the model.
for k in tqdm(cleaned_data, desc = "Deduplicating"):
    full_deduplicated_dataset, q_to_remove, a_to_remove = deduplication_within_dataset_qa(cleaned_data[k])
    deduplicated_dict[k] = full_deduplicated_dataset

Deduplicating: 100%|██████████| 11268/11268 [09:04<00:00, 20.69it/s]


In [15]:
# save data
# Each frame is written to <OUTPUT_ROOT>/<parent folder of the source CSV>/<source file name>.
OUTPUT_ROOT = "../deduplicated_data/MedQuAD"

for k in tqdm(deduplicated_dict, desc = "Saving data"):
    subdir = os.path.basename(os.path.dirname(k))
    target_dir = os.path.join(OUTPUT_ROOT, subdir)
    # exist_ok avoids the separate exists() check and the race between check and create.
    os.makedirs(target_dir, exist_ok=True)
    deduplicated_dict[k].to_csv(os.path.join(target_dir, os.path.basename(k)), index = False)

Saving data: 100%|██████████| 11268/11268 [01:06<00:00, 170.43it/s]
