In [1]:
#librerias necesarias
import os
import pandas as pd
import numpy as np
import faiss
from faiss import write_index
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
from config import CFG

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def read_processed_data(with_na=False, n_samples=None):
    # List the files in the processed_data directory
    files = os.listdir('dataset/processed_data')

    # Read the files into a dataframe
    for idx, file in enumerate(files):
        if idx == 0:
            df = pd.read_csv('dataset/processed_data/' + file, na_values=['', ' ', 'No information found.'])
        else:
            df = pd.concat([df, pd.read_csv('dataset/processed_data/' + file, na_values=['', ' ', 'No information found.'])], ignore_index=True)
    
    if not with_na:
        df = df.dropna()

    if n_samples is not None:
        df = df.sample(n_samples)

    return df

df = read_processed_data(with_na = CFG.with_na, n_samples=CFG.n_samples)

In [3]:
df

Unnamed: 0,question,question_id,question_type,answer
0,What is (are) A guide to clinical trials for c...,0000001-1,information,"If you have cancer, a clinical trial may be an..."
5,Do you have information about A guide to herba...,0000002-1,information,Herbal remedies are plants used like a medicin...
6,What is (are) A1C test ?,0000003-1,information,A1C is a lab test that shows the average level...
8,What is (are) Aarskog syndrome ?,0000004-1,information,Aarskog syndrome is a very rare disease that a...
9,What causes Aarskog syndrome ?,0000004-2,causes,Aarskog syndrome is a genetic disorder that is...
...,...,...,...,...
47242,What is (are) Parasites - Zoonotic Hookworm ?,0000440-1,information,"There are many different species of hookworms,..."
47243,Who is at risk for Parasites - Zoonotic Hookwo...,0000440-2,susceptibility,Dog and cat hookworms are found throughout the...
47244,How to diagnose Parasites - Zoonotic Hookworm ?,0000440-5,exams and tests,Cutaneous larva migrans (CLM) is a clinical di...
47245,What are the treatments for Parasites - Zoonot...,0000440-6,treatment,The zoonotic hookworm larvae that cause cutane...


In [4]:
class TextDataset(Dataset):
    def __init__(self, df):# Input is a pandas dataframe
        self.questions = df.question.tolist()
        self.question_ids = df.question_id.tolist()
        self.question_types = df.question_type.tolist()
        self.answers = df.answer.tolist()
        
    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        return {'Q': self.questions[idx], 
                'Q_id': self.question_ids[idx], 
                'Q_T': self.question_types[idx], 
                'A': self.answers[idx]}
    

def collate_fn(batch, tokenizer=AutoTokenizer.from_pretrained(CFG.embedding_model)):
    # Extrae las preguntas de los elementos del batch
    questions = [item['Q'] for item in batch]
    
    # Tokeniza las preguntas en un lote
    tokenized_questions = tokenizer(
        questions,
        return_tensors='pt',
        truncation=True,
        padding=True,
        max_length=512
    )
    
    # No hay necesidad de usar pad_sequence aquí, ya que tokenizer maneja el padding
    return {
        "input_ids": tokenized_questions['input_ids'],
        "attention_mask": tokenized_questions['attention_mask']
    }



def get_bert_embeddings(ds, batch_size=CFG.batch_size):
    dataloader = DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    model = AutoModel.from_pretrained(CFG.embedding_model)
    model = model.to(CFG.device)
    model.eval()
    embeddings = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            input_ids = batch['input_ids'].to(CFG.device)
            attention_mask = batch['attention_mask'].to(CFG.device)
            outputs = model(input_ids, attention_mask)
            last_hidden_state = outputs.last_hidden_state
            cls_embedding = last_hidden_state[:, 0, :]
            embeddings.append(cls_embedding.cpu().numpy())
    return np.concatenate(embeddings)


documents = TextDataset(df)

In [5]:
# Función para crear el índice FAISS
def create_faiss_index(embeddings):
  dimension = embeddings.shape[1]
  index = faiss.IndexFlatL2(dimension)
  index.add(embeddings)
  return index

embeddings = get_bert_embeddings(documents, CFG.batch_size)
index = create_faiss_index(embeddings)# Crea el índice FAISS con los embeddings

100%|██████████| 359/359 [00:35<00:00, 10.21it/s]


In [6]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [7]:
# Función para obtener los embeddings de una consulta de texto
def get_query_embedding(query_text, device = CFG.device):
    tokenizer = AutoTokenizer.from_pretrained(CFG.embedding_model)
    model = AutoModel.from_pretrained(CFG.embedding_model).to(device)
    inputs = tokenizer(query_text, return_tensors='pt', truncation=True, padding=True, max_length=512).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    query_embedding = outputs.last_hidden_state.mean(1).squeeze().cpu().numpy()
    return query_embedding

# Ejemplo de consulta
query_text = "abdominal mass ?"
query_embedding = get_query_embedding(query_text)
query_vector = np.expand_dims(query_embedding, axis=0)
# print(f'Query vector: {query_vector}')
# Realiza la búsqueda en el índice FAISS
D, I = index.search(query_vector, k=10)  # Busca los 5 documentos más similares

print("Consulta realizada con el texto:", query_text)
print("Documentos más similares:")
for i, idx in enumerate(I[0], start=1):
    print(f"{i}: Documento {idx} con una distancia de {D[0][i-1]}")
    print("Contenido del documento:", documents[idx]['Q'])
    print("Respuesta obtenida:", documents[idx]['A'].replace('\n', ' ') )
    # Devolver el answer de cada pregunta

Consulta realizada con el texto: abdominal mass ?
Documentos más similares:
1: Documento 40 con una distancia de 9.776694297790527
Contenido del documento: What causes Abdominal mass ?
Respuesta obtenida: Several conditions can cause an abdominal mass: Abdominal aortic aneurysm  can cause a pulsating mass around the navel.,  Bladder distention  (urinary bladder over-filled with fluid) can cause a firm mass in the center of the lower abdomen above the pelvic bones. In extreme cases, it can reach as far up as the navel.,  Cholecystitis  can cause a very tender mass that is felt below the liver in the right-upper quadrant (occasionally).,  Colon cancer  can cause a mass almost anywhere in the abdomen.,  Crohn disease  or bowel obstruction can cause many tender, sausage-shaped masses anywhere in the abdomen.,  Diverticulitis  can cause a mass that is usually located in the left-lower quadrant., Gallbladder tumor can cause a tender, irregularly shaped mass in the right-upper quadrant.,  Hyd