In [1]:
# Show the current working directory, move to its parent, then show it again.
# NOTE(review): `%cd ..` is relative to wherever the kernel currently is, so
# re-running this cell moves up one more level each time. The relative paths
# used later (dataset/, trialgpt_retrieval/, results/) presumably expect the
# repository root as the working directory — confirm.
%pwd
%cd ..
%pwd

/Users/agustinlopez/Repositories/TrialGPT


'/Users/agustinlopez/Repositories/TrialGPT'

In [2]:
# Print the version of the installed ipykernel package.
import ipykernel
print(ipykernel.__version__)

6.28.0


In [3]:
import json
import os
from rank_bm25 import BM25Okapi
import tqdm
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
import faiss
from beir.datasets.data_loader import GenericDataLoader
from nltk import word_tokenize
import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)


def preprocess_entry(entry):
    """Return the ``(title, text)`` pair of a corpus entry as strings.

    Args:
        entry: Mapping that may contain "title" and "text" keys.

    Returns:
        A 2-tuple ``(title, text)``. A missing key yields ``""``; a value that
        is not already a ``str`` is converted with ``str()``.
    """
    def _as_str(value):
        # Leave real strings untouched; stringify anything else.
        return value if isinstance(value, str) else str(value)

    return _as_str(entry.get("title", "")), _as_str(entry.get("text", ""))


def process_bm25_chunk(chunk, tokenized_corpus, corpus_nctids):
    """Tokenize corpus entries for BM25 and append the results in place.

    The token list of an entry is built from, in order: the lower-cased title
    tokens repeated 3 times, each disease's lower-cased tokens repeated 2
    times, and the lower-cased text tokens once.

    Args:
        chunk: Iterable of entry dicts. Each needs an "_id" key; "title",
            "text" and ``metadata["diseases_list"]`` are optional.
        tokenized_corpus: List that receives one token list per entry.
        corpus_nctids: List that receives each entry's "_id", in the same
            order as ``tokenized_corpus``.

    Raises:
        KeyError: If an entry has no "_id" key.
    """
    for entry in chunk:
        # Use the normalized text as well as the title, so a missing or
        # non-string "text" field does not crash tokenization.
        title, text = preprocess_entry(entry)
        tokens = word_tokenize(title.lower()) * 3

        # Entries without metadata or without a disease list contribute no
        # disease tokens instead of raising KeyError.
        metadata = entry.get("metadata") or {}
        for disease in metadata.get("diseases_list") or []:
            tokens += word_tokenize(disease.lower()) * 2

        tokens += word_tokenize(text.lower())

        # Append to both outputs only after the entry was fully processed, so
        # the two lists cannot get out of step if tokenization fails.
        corpus_nctids.append(entry["_id"])
        tokenized_corpus.append(tokens)


def process_medcpt_chunk(chunk, embeds, corpus_nctids, tokenizer, model):
    """Embed corpus entries one at a time and append the results in place.

    Args:
        chunk: Iterable of entry dicts with an "_id" key.
        embeds: List that receives one NumPy vector per entry.
        corpus_nctids: List that receives each entry's "_id", in the same
            order as ``embeds``.
        tokenizer: Callable used to encode the ``[title, text]`` pair.
        model: Callable whose output exposes ``last_hidden_state``.
    """
    # Inference only: gradients are never needed here.
    with torch.no_grad():
        for record in chunk:
            corpus_nctids.append(record["_id"])
            title, text = preprocess_entry(record)

            # A batch of one (title, text) pair, truncated to 512 tokens.
            batch = tokenizer(
                [[title, text]],
                truncation=True,
                padding=True,
                return_tensors='pt',
                max_length=512,
            )

            # Keep the hidden state at sequence position 0 as the embedding.
            first_token_states = model(**batch).last_hidden_state[:, 0, :]
            vector = first_token_states[0].cpu().numpy()
            embeds.append(vector)


def read_first_n_entries(file_path, n=20):
    """Read up to ``n`` JSON records from the start of a JSON Lines file.

    Args:
        file_path: Path of a UTF-8 encoded file with one JSON document per
            line.
        n: Maximum number of records to return. Values <= 0 yield ``[]``.

    Returns:
        A list with the first ``n`` parsed records, or fewer if the file is
        shorter. Blank lines are skipped and do not count towards ``n``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    entries = []
    # JSON text is UTF-8; do not depend on the platform default encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            if len(entries) >= n:
                break
            if not line.strip():
                # Tolerate empty lines, such as a trailing newline at EOF.
                continue
            entries.append(json.loads(line))
    return entries


def get_bm25_corpus_index(corpus):
    """Build the BM25 index for ``corpus``, caching the tokenized corpus.

    The tokenized corpus is cached in
    ``trialgpt_retrieval/bm25_corpus_{corpus}.json``. If that file exists it
    is reused; otherwise it is built from the first 20 records of
    ``dataset/{corpus}/corpus.jsonl`` and written to the cache.

    NOTE: only the first 20 corpus records are indexed, and an existing cache
    file is reused as-is regardless of how it was built.

    Args:
        corpus: Corpus name used to derive the dataset and cache paths.

    Returns:
        ``(bm25, corpus_nctids)``: the ``BM25Okapi`` index and the list of
        document ids in index order.
    """
    corpus_path = f"trialgpt_retrieval/bm25_corpus_{corpus}.json"

    if os.path.exists(corpus_path):
        # Context manager so the cache file handle is closed promptly.
        with open(corpus_path) as f:
            corpus_data = json.load(f)
        tokenized_corpus = corpus_data["tokenized_corpus"]
        corpus_nctids = corpus_data["corpus_nctids"]
    else:
        tokenized_corpus = []
        corpus_nctids = []

        # Read only the first 20 records of corpus.jsonl.
        first_20_entries = read_first_n_entries(f"dataset/{corpus}/corpus.jsonl", n=20)
        process_bm25_chunk(first_20_entries, tokenized_corpus, corpus_nctids)

        corpus_data = {
            "tokenized_corpus": tokenized_corpus,
            "corpus_nctids": corpus_nctids,
        }

        # Make sure the cache directory exists before writing into it.
        os.makedirs(os.path.dirname(corpus_path), exist_ok=True)
        with open(corpus_path, "w") as f:
            json.dump(corpus_data, f, indent=4)

    bm25 = BM25Okapi(tokenized_corpus)
    return bm25, corpus_nctids


def get_medcpt_corpus_index(corpus):
    """Build the MedCPT Faiss index for ``corpus``, caching the embeddings.

    Embeddings are cached in ``trialgpt_retrieval/{corpus}_embeds.npy`` and
    the matching ids in ``trialgpt_retrieval/{corpus}_nctids.json``. The cache
    is used only when both files exist; otherwise the first 20 records of
    ``dataset/{corpus}/corpus.jsonl`` are embedded with
    ``ncbi/MedCPT-Article-Encoder`` and both files are (re)written.

    NOTE: only the first 20 corpus records are indexed, and an existing cache
    is reused as-is regardless of how it was built.

    Args:
        corpus: Corpus name used to derive the dataset and cache paths.

    Returns:
        ``(index, corpus_nctids)``: a ``faiss.IndexFlatIP`` of dimension 768
        holding the embeddings, and the list of document ids in index order.
    """
    corpus_path = f"trialgpt_retrieval/{corpus}_embeds.npy"
    nctids_path = f"trialgpt_retrieval/{corpus}_nctids.json"

    # Both halves of the cache are required: embeddings without their ids
    # (or the reverse) cannot be used, so rebuild in that case.
    if os.path.exists(corpus_path) and os.path.exists(nctids_path):
        embeds = np.load(corpus_path)
        with open(nctids_path) as f:
            corpus_nctids = json.load(f)
    else:
        print(f"Building MedCPT corpus index for {corpus}")
        embeds = []
        corpus_nctids = []

        print("Loading MedCPT model and tokenizer...")
        model = AutoModel.from_pretrained("ncbi/MedCPT-Article-Encoder")
        tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Article-Encoder")
        print("MedCPT model and tokenizer loaded successfully.")

        # Read only the first 20 records of corpus.jsonl.
        first_20_entries = read_first_n_entries(f"dataset/{corpus}/corpus.jsonl", n=20)
        process_medcpt_chunk(first_20_entries, embeds, corpus_nctids, tokenizer, model)

        embeds = np.array(embeds)

        # Make sure the cache directory exists before writing into it.
        os.makedirs(os.path.dirname(corpus_path), exist_ok=True)
        np.save(corpus_path, embeds)
        with open(nctids_path, "w") as f:
            json.dump(corpus_nctids, f, indent=4)

    # Inner-product index; 768 is the embedding dimension used here.
    index = faiss.IndexFlatIP(768)
    index.add(embeds)
    return index, corpus_nctids


def preprocess_corpus(corpus_texts, max_length=200):
    """Split over-long texts into consecutive fixed-size pieces.

    Args:
        corpus_texts: Iterable of texts.
        max_length: Maximum number of characters per returned piece.

    Returns:
        A flat list. A text of at most ``max_length`` characters is kept
        as-is; a longer text is replaced by its consecutive slices of
        ``max_length`` characters. The result can therefore contain more
        items than ``corpus_texts`` and is not index-aligned with it.
    """
    pieces = []
    for text in corpus_texts:
        if len(text) <= max_length:
            pieces.append(text)
            continue
        # Break the text into smaller pieces when it is too long.
        pieces.extend(
            text[start:start + max_length]
            for start in range(0, len(text), max_length)
        )
    return pieces

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from tqdm.autonotebook import tqdm

In [4]:
# Hybrid retrieval: for every query that has relevance judgements, retrieve
# trials with BM25 and MedCPT per condition, fuse the two rankings with a
# reciprocal-rank style score, and save the top-N trial ids per query.

# --- Run configuration -------------------------------------------------------
corpus = "metabase"
q_type = "gpt-4-turbo"
k = 20          # smoothing constant in the fusion score 1 / (rank + k)
bm25_wt = 1     # only tested for > 0 below: switches the BM25 contribution on/off
medcpt_wt = 1   # only tested for > 0 below: switches the MedCPT contribution on/off
N = 2000        # results requested per retriever, and kept per query

# Fail fast on an unsupported query type. Otherwise `conditions` would be
# undefined (or silently reused from the previous query) inside the loop.
if not (q_type in ["raw", "human_summary"] or "turbo" in q_type or "Clinician" in q_type):
    raise ValueError(f"Unsupported q_type: {q_type!r}")

# Load the qrels
_, _, qrels = GenericDataLoader(data_folder=f"dataset/{corpus}/").load(split="test")

# Load all types of queries
with open(f"dataset/{corpus}/id2queries.json") as f:
    id2queries = json.load(f)

# Chunk each query text separately so every query id keeps its own pieces.
# preprocess_corpus() returns a flat list that grows whenever a text is split,
# so indexing one shared list by position would pair ids with the wrong text.
# Only plain-string variants are chunked; structured variants (e.g. the dict
# read by the "turbo" branch below) are used directly from `id2queries`.
id2queries_preprocessed = {}
for query_id, variants in id2queries.items():
    variant_text = variants[q_type] if q_type in variants else ""
    id2queries_preprocessed[query_id] = (
        preprocess_corpus([variant_text]) if isinstance(variant_text, str) else []
    )

# Load the indices
bm25, bm25_nctids = get_bm25_corpus_index(corpus)
medcpt, medcpt_nctids = get_medcpt_corpus_index(corpus)

# Load the query encoder for MedCPT
model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder")
tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")

# Conduct the searches, saving the top N trials per query
output_path = f"results/qid2nctids_results_{q_type}_{corpus}_k{k}_bm25wt{bm25_wt}_medcptwt{medcpt_wt}_N{N}.json"
os.makedirs(os.path.dirname(output_path), exist_ok=True)

qid2nctids = {}
recalls = []

with open(f"dataset/{corpus}/queries.jsonl", "r") as f:
    query_lines = f.readlines()

# `tqdm` is the progress-bar class here: the preceding cell runs
# `from tqdm.autonotebook import tqdm`, which rebinds the name from the module
# to the class, so `tqdm.tqdm(...)` would raise AttributeError on a fresh run.
for line in tqdm(query_lines):
    if not line.strip():
        continue

    entry = json.loads(line)
    qid = entry["_id"]

    if qid not in qrels:
        continue

    truth_sum = sum(qrels[qid].values())

    # get the keyword list
    if q_type in ["raw", "human_summary"]:
        conditions = id2queries_preprocessed[qid]
    elif "turbo" in q_type:
        conditions = id2queries[qid][q_type]["conditions"]
    else:  # "Clinician" in q_type, guaranteed by the validation above
        conditions = id2queries[qid].get(q_type, [])

    nctid2score = {}

    if len(conditions) > 0:
        # a list of nctid lists for the bm25 retriever
        bm25_condition_top_nctids = []

        for condition in conditions:
            tokens = word_tokenize(condition.lower())
            top_nctids = bm25.get_top_n(tokens, bm25_nctids, n=N)
            bm25_condition_top_nctids.append(top_nctids)

        # doing MedCPT retrieval
        with torch.no_grad():
            encoded = tokenizer(
                conditions,
                truncation=True,
                padding=True,
                return_tensors='pt',
                max_length=256,
            )

            # encode the queries (use the [CLS] last hidden states as the representations)
            embeds = model(**encoded).last_hidden_state[:, 0, :].cpu().numpy()

            # search the Faiss index
            scores, inds = medcpt.search(embeds, k=N)

        # Faiss fills result slots with label -1 when the index holds fewer
        # than k vectors (see the Faiss search documentation). Drop those
        # slots: `medcpt_nctids[-1]` would otherwise credit the last trial
        # once for every empty slot.
        medcpt_condition_top_nctids = [
            [medcpt_nctids[ind] for ind in ind_list if ind >= 0]
            for ind_list in inds
        ]

        for condition_idx, (bm25_top_nctids, medcpt_top_nctids) in enumerate(
                zip(bm25_condition_top_nctids, medcpt_condition_top_nctids)):

            # Later conditions count less: weight 1, 1/2, 1/3, ...
            condition_weight = 1 / (condition_idx + 1)

            if bm25_wt > 0:
                for rank, nctid in enumerate(bm25_top_nctids):
                    nctid2score[nctid] = nctid2score.get(nctid, 0) + (1 / (rank + k)) * condition_weight

            if medcpt_wt > 0:
                for rank, nctid in enumerate(medcpt_top_nctids):
                    nctid2score[nctid] = nctid2score.get(nctid, 0) + (1 / (rank + k)) * condition_weight

    # Highest fused score first; keep at most N trial ids.
    ranked_trials = sorted(nctid2score.items(), key=lambda x: -x[1])
    top_nctids = [nctid for nctid, _ in ranked_trials[:N]]
    qid2nctids[qid] = top_nctids

    # Recall is undefined when the judged labels of a query sum to zero;
    # skip it instead of dividing by zero.
    if truth_sum > 0:
        actual_sum = sum(qrels[qid].get(nctid, 0) for nctid in top_nctids)
        recalls.append(actual_sum / truth_sum)

with open(output_path, "w") as f:
    json.dump(qid2nctids, f, indent=4)

100%|██████████| 137/137 [00:00<00:00, 17695.85it/s]


Building MedCPT corpus index for metabase
Loading MedCPT model and tokenizer...
MedCPT model and tokenizer loaded successfully.


100%|██████████| 7/7 [00:00<00:00, 10.31it/s]
