In [22]:
import torch
from torch.utils.data import DataLoader, Dataset
from sentence_transformers import SentenceTransformer, losses, InputExample
import pandas as pd
from torch.utils.data import DataLoader
import pandas as pd
import os
import random

In [25]:
def load_medical_qa_dataset(n_samples=None):
    
    files = os.listdir('dataset/processed_data')

    # Read the files into a dataframe
    for idx, file in enumerate(files):
        if idx == 0:
            df = pd.read_csv('dataset/processed_data/' + file, na_values=['', ' ', 'No information found.'])
        else:
            df = pd.concat([df, pd.read_csv('dataset/processed_data/' + file, na_values=['', ' ', 'No information found.'])], ignore_index=True)

    # Drop rows with missing values
    df = df.dropna()

    # Select a subset of the data if n_samples is specified
    if n_samples:
        df = df.sample(n_samples)

    return df


df = load_medical_qa_dataset(n_samples=None)

df

Unnamed: 0,question,question_id,question_type,answer
0,What is (are) A guide to clinical trials for c...,0000001-1,information,"If you have cancer, a clinical trial may be an..."
5,Do you have information about A guide to herba...,0000002-1,information,Herbal remedies are plants used like a medicin...
6,What is (are) A1C test ?,0000003-1,information,A1C is a lab test that shows the average level...
8,What is (are) Aarskog syndrome ?,0000004-1,information,Aarskog syndrome is a very rare disease that a...
9,What causes Aarskog syndrome ?,0000004-2,causes,Aarskog syndrome is a genetic disorder that is...
...,...,...,...,...
47242,What is (are) Parasites - Zoonotic Hookworm ?,0000440-1,information,"There are many different species of hookworms,..."
47243,Who is at risk for Parasites - Zoonotic Hookwo...,0000440-2,susceptibility,Dog and cat hookworms are found throughout the...
47244,How to diagnose Parasites - Zoonotic Hookworm ?,0000440-5,exams and tests,Cutaneous larva migrans (CLM) is a clinical di...
47245,What are the treatments for Parasites - Zoonot...,0000440-6,treatment,The zoonotic hookworm larvae that cause cutane...


In [12]:
df.question_type.unique()

array(['susceptibility', 'treatment', 'symptoms', 'precautions',
       'dietary', 'indication', 'prevention', 'side effects',
       'genetic changes', 'information', 'causes', 'other information',
       'storage and disposal', 'inheritance', 'outlook',
       'exams and tests', 'usage', 'frequency', 'brand names',
       'complications', 'when to contact a medical professional',
       'support groups', 'emergency or overdose', 'how effective is it',
       'brand names of combination products',
       'interactions with medications',
       'interactions with herbs and supplements', 'stages',
       'how does it work'], dtype=object)

In [14]:
class QADataset(Dataset):
    def __init__(self, questions, answers, question_type):
        self.questions = questions
        self.answers = answers
        self.question_type = question_type

        self.label_idx = {'susceptibility': 0, 
                          'treatment': 1,
                           'symptoms': 2, 
                           'precautions': 3,
                           'dietary': 4,
                            'indication': 5,
                            'prevention': 6,
                            'side effects': 7,
                            'genetic changes': 8,
                            'information': 9,
                            'causes': 10,
                            'other information': 11,
                            'storage and disposal': 12,
                            'inheritance': 13,
                            'outlook': 14,
                            'exams and tests': 15,
                            'usage': 16,
                            'frequency': 17,
                            'brand names': 18,
                            'complications': 19,
                            'when to contact a medical professional': 20,
                            'considerations': 21,
                            'forget a dose': 22,
                            'research': 23,
                            'important warning': 24,
                            'support groups': 25,
                            'emergency or overdose': 26,
                            'how effective is it': 27,
                            'brand names of combination products': 28,
                            'interactions with medications': 29,
                            'interactions with herbs and supplements': 30,
                            'stages': 31,
                            'how does it work': 32}
        
     
    
    def __len__(self):
        return len(self.questions)
    
    def __getitem__(self, idx):
        return InputExample(texts=[self.questions[idx], self.answers[idx]], label=self.label_idx[self.question_type[idx]])
    

# Create a dataset from the dataframe
dataset = QADataset(df['question'].values, df['answer'].values, df['question_type'].values)

# Ver ejemplo aleatorio
n = random.randint(0, len(dataset))
print(dataset[n])


<InputExample> label: 5, texts: Who should get Rifabutin and why is it prescribed ?; Rifabutin helps to prevent or slow the spread of Mycobacterium avium complex disease (MAC; a bacterial infection that may cause serious symptoms) in patients with human immunodeficiency virus (HIV) infection. It is also used in combination with other medications to eliminate  H. pylori , a bacteria that causes ulcers. Rifabutin is in a class of medications called antimycobacterials. It works by killing the bacteria that cause infection. Antibiotics such as rifabutin will not work for colds, flu, or other viral infections. Using antibiotics when they are not needed increases your risk of getting an infection later that resists antibiotic treatment.


In [21]:
from sentence_transformers import SentenceTransformer, losses, models
import torch
from torch.utils.data import DataLoader

# Carga el modelo
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Asume 'dataset' es una lista de InputExample
loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=model.smart_batching_collate)

# Define la función de pérdida
train_loss = losses.CosineSimilarityLoss(model=model)

# Configuración del optimizador
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

# Ciclo de entrenamiento
model.train()
for epoch in range(1):  # Ajusta el número de épocas según sea necesario
    for batch in loader:
        optimizer.zero_grad()

        print(batch)
        
        
        # Las entradas ya están en el formato correcto gracias a smart_batching_collate
        features, labels = batch
        # features contiene dos listas: una para cada texto en los InputExamples
        sentences = [features[0], features[1]]

        # Obtiene las incrustaciones de las oraciones
        embeddings = model(sentences[0], sentences[1])

        # Calcula la pérdida
        loss = train_loss(embeddings[0], embeddings[1], labels)
        loss.backward()
        optimizer.step()
        print(f"Epoch: {epoch}, Loss: {loss.item()}")

# Para guardar el modelo, utiliza model.save()
model.save("path_to_save_your_finetuned_model")
print("Modelo fine-tuneado guardado")


([{'input_ids': tensor([[  101,  2129,  2000,  4652, 17419,  2594, 27641,  1029,   102,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0],
        [  101,  2054,  2003,  1006,  2024,  1007,  5939,  7352,  3170, 10092,
          2828,  1015,  1029,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0],
        [  101,  2054,  2024,  1996,  8030,  1997, 11265,  8458, 29166, 12419,
          2791, 24471,  3981,  2854, 12859,  3617, 15451, 14192,  3370,  1029,
           102,     0,     0],
        [  101,  2054,  2003,  1996,  2895,  1997, 19395, 11253, 14973,  2378,
          1998,  2129,  2515,  2009,  2147,  1029,   102,     0,     0,     0,
             0,     0,     0],
        [  101,  2054,  2024,  1996,  8030,  1997, 27281,  3536,  8715,  1029,
           102,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0],
        [  101,  2054,  2003,  1006,  2

TypeError: Sequential.forward() takes 2 positional arguments but 3 were given

In [25]:
# Carga el modelo
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=model.smart_batching_collate)

# Define la loss function
train_loss = losses.CosineSimilarityLoss(model)

# Configuración del optimizador
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

# Ciclo de entrenamiento
model.train()
for epoch in range(1):  # Ajusta el número de épocas según sea necesario
    for batch in loader:
        optimizer.zero_grad()
        texts = [example.texts for example in batch]
        texts = list(map(list, zip(*texts)))  # Transpone la lista de textos
        encodings = model.encode(texts, convert_to_tensor=True, show_progress_bar=False)
        loss = train_loss(encodings[0], encodings[1])
        loss.backward()
        optimizer.step()
        print(f"Loss: {loss.item()}")

# Guarda el modelo fine-tuneado
# model.save("path_to_save_your_finetuned_model")
print("Modelo fine-tuneado guardado")


AttributeError: 'list' object has no attribute 'texts'