In [1]:
from datasets import load_dataset
import torch
from torch.utils.data import Dataset, DataLoader
from torchmetrics import F1Score, Precision, Accuracy
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
from transformers import BertTokenizer, BertForSequenceClassification, BertModel
from peft import get_peft_model, LoraConfig, TaskType
import pytorch_lightning as pl
import tqdm
from torch.nn import functional as F
from pytorch_lightning.loggers import TensorBoardLogger
from torch.nn.utils.rnn import pad_sequence
from huggingface_hub.hf_api import HfFolder
from pytorch_lightning.callbacks import ModelCheckpoint
from random import shuffle
HfFolder.save_token('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')

In [2]:
import logging
logging.disable(logging.WARNING)

In [3]:
import warnings
warnings.filterwarnings('ignore')

In [4]:
miracl_ru = load_dataset('miracl/miracl', 'ru')

In [5]:
miracl_en = load_dataset('miracl/miracl', 'en')

In [6]:
miracl_id = load_dataset('miracl/miracl', 'id')

In [7]:
miracl_es = load_dataset('miracl/miracl', 'es')

In [8]:
miracl_fi = load_dataset('miracl/miracl', 'fi')

In [98]:
#training set:
train_data = []

for data in miracl_ru['train']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        train_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        train_data.append(temp_dict)

for data in miracl_en['train']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        train_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        train_data.append(temp_dict)

for data in miracl_es['train']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        train_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        train_data.append(temp_dict)

for data in miracl_fi['train']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        train_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        train_data.append(temp_dict)

for data in miracl_id['train']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        train_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        train_data.append(temp_dict)

shuffle(train_data)

In [99]:
class QADataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        question = self.data[idx]['question']
        answer = self.data[idx]['answer']
        label = self.data[idx]['label']

        # Токенизируем вопрос и ответ отдельно
        question_inputs = self.tokenizer(
            question, 
            padding='max_length', 
            truncation=True, 
            max_length=self.max_length, 
            return_tensors="pt"
        )
        
        answer_inputs = self.tokenizer(
            answer, 
            padding='max_length', 
            truncation=True, 
            max_length=self.max_length, 
            return_tensors="pt"
        )

        # Убираем лишнее измерение
        question_inputs = {key: val.squeeze(0) for key, val in question_inputs.items()}
        answer_inputs = {key: val.squeeze(0) for key, val in answer_inputs.items()}

        # Возвращаем словарь с правильными ключами
        return {
            'question_input_ids': question_inputs['input_ids'],
            'question_attention_mask': question_inputs['attention_mask'],
            'answer_input_ids': answer_inputs['input_ids'],
            'answer_attention_mask': answer_inputs['attention_mask'],
            'labels': torch.tensor(label, dtype=torch.long)
        }


In [103]:
class DualEncoderModel(pl.LightningModule):
    def __init__(self, model_name='bert-base-multilingual-cased', learning_rate=1e-5, lora_r=128, lora_alpha=128):
        super().__init__()
        self.learning_rate = learning_rate
        self.threshold = 0.6

        # Инициализация энкодеров вопросов и ответов (BERT)
        self.question_encoder = BertModel.from_pretrained(model_name)
        self.answer_encoder = BertModel.from_pretrained(model_name)

        # Настройка LoRA для обоих энкодеров
        lora_config = LoraConfig(
            task_type=TaskType.FEATURE_EXTRACTION,  # Задача на извлечение признаков
            r=lora_r,
            lora_alpha=lora_alpha,
            bias="none"
        )
        self.question_encoder = get_peft_model(self.question_encoder, lora_config)
        self.answer_encoder = get_peft_model(self.answer_encoder, lora_config)

        # Метрики
        self.f1_metric = F1Score(num_classes=2, task='binary')
        self.precision_metric = Precision(num_classes=2, task='binary')
        self.accuracy_metric = Accuracy(num_classes=2, task='binary')

    def forward(self, question_inputs, answer_inputs):
        # Кодирование вопросов и ответов с использованием энкодеров
        question_embeddings = self.question_encoder(
            input_ids=question_inputs['input_ids'],
            attention_mask=question_inputs['attention_mask'],
            return_dict=True
        ).pooler_output  # Используем pooler_output для представления

        answer_embeddings = self.answer_encoder(
            input_ids=answer_inputs['input_ids'],
            attention_mask=answer_inputs['attention_mask'],
            return_dict=True
        ).pooler_output

        # Нормализуем эмбеддинги для вычисления косинусного сходства
        question_embeddings = F.normalize(question_embeddings, p=2, dim=1)
        answer_embeddings = F.normalize(answer_embeddings, p=2, dim=1)

        # Вычисляем косинусное сходство между вопросом и ответом
        cosine_similarity = torch.matmul(question_embeddings, answer_embeddings.T)
        return cosine_similarity
    
    def training_step(self, batch, batch_idx):
        # Входные данные для вопросов и ответов
        question_inputs = {
            'input_ids': batch['question_input_ids'],
            'attention_mask': batch['question_attention_mask']
        }
        answer_inputs = {
            'input_ids': batch['answer_input_ids'],
            'attention_mask': batch['answer_attention_mask']
        }
        labels = batch['labels']
    
        # Предсказания модели (косинусное сходство)
        similarity_scores = self(question_inputs, answer_inputs)
    
        # Переводим косинусное сходство в логиты для бинарной классификации (0 или 1)
        logits = similarity_scores.diag()  # Берем диагональ, т.к. это правильные пары
        pos_weight = torch.tensor([2]).to(logits.device)  # Штрафуем сильнее за неправильные ответы
        loss = F.binary_cross_entropy_with_logits(logits, labels.float(), pos_weight=pos_weight)

    
        # Логируем метрики
        preds = torch.where(torch.sigmoid(logits) >= self.threshold, 1, 0)
        
        accuracy = self.accuracy_metric(preds, labels)
        self.log('train_accuracy', accuracy, prog_bar=True, logger=True)

        f1 = self.f1_metric(preds, labels)
        self.log('train_f1', f1, prog_bar=True, logger=True)
        
        precision = self.accuracy_metric(preds, labels)
        self.log('train_precision', precision, prog_bar=True, logger=True)
    
        # Логируем лосс
        self.log('train_loss', loss, prog_bar=True, logger=True)
    
        return loss
            
    def validation_step(self, batch, batch_idx):
        # Входные данные для вопросов и ответов
        question_inputs = {
            'input_ids': batch['question_input_ids'],
            'attention_mask': batch['question_attention_mask']
        }
        answer_inputs = {
            'input_ids': batch['answer_input_ids'],
            'attention_mask': batch['answer_attention_mask']
        }
        labels = batch['labels']
    
        # Предсказания модели
        similarity_scores = self(question_inputs, answer_inputs)
        logits = similarity_scores.diag()
        loss = F.binary_cross_entropy_with_logits(logits, labels.float())
    
        # Логирование лосса
        self.log('val_loss', loss, prog_bar=True, logger=True)
    
        # Логирование метрик
        preds = torch.where(torch.sigmoid(logits) >= self.threshold, 1, 0)
        f1 = self.f1_metric(preds, labels)
        precision = self.precision_metric(preds, labels)
    
        self.log('val_f1', f1, prog_bar=True, logger=True)
        self.log('val_precision', precision, prog_bar=True, logger=True)
    
        return loss
    def configure_optimizers(self):
        # Оптимизатор AdamW для обучения модели
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer


In [104]:
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
train_dataset = QADataset(train_data, tokenizer, max_length=128)
train_data_loader = DataLoader(train_dataset, batch_size=128)

In [105]:
# Инициализируем логгер TensorBoard
logger = TensorBoardLogger("DE_logs", name="qa_model")
checkpoint_callback = ModelCheckpoint(monitor='val_accuracy', mode='max')

trainer = pl.Trainer(
    max_epochs=20,
    logger=logger,
    callbacks=[checkpoint_callback],
    accelerator="gpu"
)

model = DualEncoderModel()


In [114]:
import gc
torch.cuda.empty_cache()
gc.collect()


0

In [109]:
trainer.fit(model, train_data_loader)

Training: |          | 0/? [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 192.00 MiB. GPU 

In [110]:
#training set:
val_data = []

for data in miracl_ru['dev']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        val_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        val_data.append(temp_dict)

for data in miracl_en['dev']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        val_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        val_data.append(temp_dict)

for data in miracl_es['dev']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        val_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        val_data.append(temp_dict)

for data in miracl_fi['dev']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        val_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        val_data.append(temp_dict)

for data in miracl_id['dev']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        val_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        val_data.append(temp_dict)

shuffle(val_data)

In [111]:
val_dataset = QADataset(val_data, tokenizer, max_length=128)
val_data_loader = DataLoader(val_dataset, batch_size=128)

In [115]:
trainer.validate(model, val_data_loader)

Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_loss': 0.6229941248893738,
  'val_f1': 0.44277188181877136,
  'val_precision': 0.5141499638557434}]

In [None]:
!zip -r QA_logs.zip QA_logs