In [12]:
from datasets import load_dataset
import torch
from torch.utils.data import Dataset, DataLoader
from torchmetrics import F1Score, Precision
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
from transformers import BertTokenizer, BertForSequenceClassification
from peft import get_peft_model, LoraConfig, TaskType
import pytorch_lightning as pl
import tqdm
from torch.nn import functional as F
from pytorch_lightning.loggers import TensorBoardLogger
from torch.nn.utils.rnn import pad_sequence
from huggingface_hub.hf_api import HfFolder
from pytorch_lightning.callbacks import ModelCheckpoint
from random import shuffle
from .config import secret_key
HfFolder.save_token(secret_key)

In [7]:
import logging
logging.disable(logging.WARNING)

In [8]:
import warnings
warnings.filterwarnings('ignore')

In [9]:
miracl_ru = load_dataset('miracl/miracl', 'ru')

In [2]:
miracl_en = load_dataset('miracl/miracl', 'en')

Generating dev split: 0 examples [00:00, ? examples/s]

Loading dataset shards:   0%|          | 0/28 [00:00<?, ?it/s]

Generating testB split: 0 examples [00:00, ? examples/s]

Loading dataset shards:   0%|          | 0/28 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Loading dataset shards:   0%|          | 0/28 [00:00<?, ?it/s]

Generating testA split: 0 examples [00:00, ? examples/s]

Loading dataset shards:   0%|          | 0/28 [00:00<?, ?it/s]

In [3]:
miracl_id = load_dataset('miracl/miracl', 'id')

(…)-id/topics/topics.miracl-v1.0-id-dev.tsv:   0%|          | 0.00/42.0k [00:00<?, ?B/s]

(…).0-id/qrels/qrels.miracl-v1.0-id-dev.tsv:   0%|          | 0.00/177k [00:00<?, ?B/s]

(…)/topics/topics.miracl-v1.0-id-test-b.tsv:   0%|          | 0.00/24.8k [00:00<?, ?B/s]

(…)d/topics/topics.miracl-v1.0-id-train.tsv:   0%|          | 0.00/175k [00:00<?, ?B/s]

(…)-id/qrels/qrels.miracl-v1.0-id-train.tsv:   0%|          | 0.00/758k [00:00<?, ?B/s]

(…)/topics/topics.miracl-v1.0-id-test-a.tsv:   0%|          | 0.00/33.8k [00:00<?, ?B/s]

Generating dev split: 0 examples [00:00, ? examples/s]

docs-0.jsonl.gz:   0%|          | 0.00/68.5M [00:00<?, ?B/s]

docs-1.jsonl.gz:   0%|          | 0.00/39.7M [00:00<?, ?B/s]

docs-2.jsonl.gz:   0%|          | 0.00/61.4M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating testB split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating testA split: 0 examples [00:00, ? examples/s]

In [4]:
miracl_es = load_dataset('miracl/miracl', 'es')

(…)-es/topics/topics.miracl-v1.0-es-dev.tsv:   0%|          | 0.00/39.6k [00:00<?, ?B/s]

(…).0-es/qrels/qrels.miracl-v1.0-es-dev.tsv:   0%|          | 0.00/158k [00:00<?, ?B/s]

(…)/topics/topics.miracl-v1.0-es-test-b.tsv:   0%|          | 0.00/92.4k [00:00<?, ?B/s]

(…)s/topics/topics.miracl-v1.0-es-train.tsv:   0%|          | 0.00/131k [00:00<?, ?B/s]

(…)-es/qrels/qrels.miracl-v1.0-es-train.tsv:   0%|          | 0.00/526k [00:00<?, ?B/s]

Generating dev split: 0 examples [00:00, ? examples/s]

Downloading data:   0%|          | 0/21 [00:00<?, ?files/s]

docs-0.jsonl.gz:   0%|          | 0.00/92.9M [00:00<?, ?B/s]

docs-1.jsonl.gz:   0%|          | 0.00/89.6M [00:00<?, ?B/s]

docs-2.jsonl.gz:   0%|          | 0.00/85.4M [00:00<?, ?B/s]

docs-3.jsonl.gz:   0%|          | 0.00/83.3M [00:00<?, ?B/s]

docs-4.jsonl.gz:   0%|          | 0.00/80.8M [00:00<?, ?B/s]

docs-5.jsonl.gz:   0%|          | 0.00/79.4M [00:00<?, ?B/s]

docs-6.jsonl.gz:   0%|          | 0.00/79.2M [00:00<?, ?B/s]

docs-7.jsonl.gz:   0%|          | 0.00/78.5M [00:00<?, ?B/s]

docs-8.jsonl.gz:   0%|          | 0.00/77.6M [00:00<?, ?B/s]

docs-9.jsonl.gz:   0%|          | 0.00/66.3M [00:00<?, ?B/s]

docs-10.jsonl.gz:   0%|          | 0.00/62.6M [00:00<?, ?B/s]

docs-11.jsonl.gz:   0%|          | 0.00/65.7M [00:00<?, ?B/s]

docs-12.jsonl.gz:   0%|          | 0.00/72.6M [00:00<?, ?B/s]

docs-13.jsonl.gz:   0%|          | 0.00/75.5M [00:00<?, ?B/s]

docs-14.jsonl.gz:   0%|          | 0.00/76.5M [00:00<?, ?B/s]

docs-15.jsonl.gz:   0%|          | 0.00/73.0M [00:00<?, ?B/s]

docs-16.jsonl.gz:   0%|          | 0.00/74.2M [00:00<?, ?B/s]

docs-17.jsonl.gz:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

docs-18.jsonl.gz:   0%|          | 0.00/77.9M [00:00<?, ?B/s]

docs-19.jsonl.gz:   0%|          | 0.00/79.9M [00:00<?, ?B/s]

docs-20.jsonl.gz:   0%|          | 0.00/57.2M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating testB split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [5]:
miracl_fi = load_dataset('miracl/miracl', 'fi')

(…)-fi/topics/topics.miracl-v1.0-fi-dev.tsv:   0%|          | 0.00/58.1k [00:00<?, ?B/s]

(…).0-fi/qrels/qrels.miracl-v1.0-fi-dev.tsv:   0%|          | 0.00/226k [00:00<?, ?B/s]

(…)/topics/topics.miracl-v1.0-fi-test-b.tsv:   0%|          | 0.00/33.5k [00:00<?, ?B/s]

(…)i/topics/topics.miracl-v1.0-fi-train.tsv:   0%|          | 0.00/130k [00:00<?, ?B/s]

(…)-fi/qrels/qrels.miracl-v1.0-fi-train.tsv:   0%|          | 0.00/382k [00:00<?, ?B/s]

(…)/topics/topics.miracl-v1.0-fi-test-a.tsv:   0%|          | 0.00/47.9k [00:00<?, ?B/s]

Generating dev split: 0 examples [00:00, ? examples/s]

docs-0.jsonl.gz:   0%|          | 0.00/79.7M [00:00<?, ?B/s]

docs-1.jsonl.gz:   0%|          | 0.00/71.2M [00:00<?, ?B/s]

docs-2.jsonl.gz:   0%|          | 0.00/68.0M [00:00<?, ?B/s]

docs-3.jsonl.gz:   0%|          | 0.00/51.1M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating testB split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating testA split: 0 examples [00:00, ? examples/s]

In [13]:
#training set:
train_data = []

for data in miracl_ru['train']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        train_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        train_data.append(temp_dict)

for data in miracl_en['train']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        train_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        train_data.append(temp_dict)

for data in miracl_es['train']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        train_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        train_data.append(temp_dict)

for data in miracl_fi['train']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        train_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        train_data.append(temp_dict)

for data in miracl_id['train']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        train_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        train_data.append(temp_dict)

shuffle(train_data)

In [23]:
class QADataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        question = self.data[idx]['question']
        answer = self.data[idx]['answer']
        label = self.data[idx]['label']

        inputs = self.tokenizer(
            question, answer, 
            padding='max_length', 
            truncation=True, 
            max_length=self.max_length, 
            return_tensors="pt"
        )

        inputs = {key: val.squeeze(0) for key, val in inputs.items()}  # убираем лишнюю размерность
        inputs['labels'] = torch.tensor(label, dtype=torch.long)
        return inputs


In [24]:
class QAClassificationModel(pl.LightningModule):
    def __init__(self, model_name='bert-base-multilingual-cased', learning_rate=1e-5):
        super().__init__()
        self.learning_rate = learning_rate
        self.save_hyperparameters()
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR

        # Инициализируем mBERT
        self.model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)

        # Настройка LoRA
        lora_config = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            r=32,
            lora_alpha=24,
            bias="all"
        )
        self.model = get_peft_model(self.model, lora_config)

        # Инициализируем F1-метрику для бинарной классификации
        self.f1_metric = F1Score(num_classes=2, average='macro', task='binary')
        self.precision_metric = Precision(num_classes=2, task='binary')

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        return self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels)

    def training_step(self, batch, batch_idx):
        outputs = self(
            input_ids=batch['input_ids'], 
            attention_mask=batch['attention_mask'], 
            token_type_ids=batch['token_type_ids'], 
            labels=batch['labels']
        )
        loss = outputs.loss
        
        
        self.log('train_loss', loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self(
            input_ids=batch['input_ids'], 
            attention_mask=batch['attention_mask'], 
            token_type_ids=batch['token_type_ids'], 
            labels=batch['labels']
        )
        loss = outputs.loss

        # Предсказания и реальные значения для расчета метрики
        preds = torch.argmax(outputs.logits, dim=1)
        labels = batch['labels']
        
        self.log('val_loss', loss, prog_bar=True, logger=True)
        
        f1 = self.f1_metric(preds, labels)
        self.log('val_f1', f1, prog_bar=True, logger=True)
        precision = self.precision_metric(preds, labels)
        self.log('val_precision', precision, prog_bar=True, logger=True)

        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)


In [25]:
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
train_dataset = QADataset(train_data, tokenizer, max_length=128)
train_data_loader = DataLoader(train_dataset, batch_size=128)


In [26]:
# Инициализируем логгер TensorBoard
logger = TensorBoardLogger("QA_logs", name="qa_model")
checkpoint_callback = ModelCheckpoint(monitor='val_accuracy', mode='max')

trainer = pl.Trainer(
    max_epochs=25,
    logger=logger,
    callbacks=[checkpoint_callback],
    accelerator="gpu"
)

model = QAClassificationModel()


In [27]:
trainer.fit(model, train_data_loader)

Training: |          | 0/? [00:00<?, ?it/s]

In [28]:
#training set:
val_data = []
for data in miracl_ru['dev']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        val_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        val_data.append(temp_dict)

for data in miracl_en['dev']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        val_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        val_data.append(temp_dict)

for data in miracl_es['dev']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        val_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        val_data.append(temp_dict)

for data in miracl_id['dev']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        val_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        val_data.append(temp_dict)

for data in miracl_fi['dev']:
    for neg in data['negative_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = neg['text']
        temp_dict['label'] = 0
        val_data.append(temp_dict)
    for pos in data['positive_passages']:
        temp_dict = {}
        temp_dict['question'] = data['query']
        temp_dict['answer'] = pos['text']
        temp_dict['label'] = 1
        val_data.append(temp_dict)

shuffle(val_data)

In [29]:
val_dataset = QADataset(val_data, tokenizer, max_length=128)
val_data_loader = DataLoader(val_dataset, batch_size=128)

In [30]:
trainer.validate(model, val_data_loader)

Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_loss': 0.46312230825424194,
  'val_f1': 0.5759850740432739,
  'val_precision': 0.6667466759681702}]

In [31]:
!zip -r QA_logs.zip QA_logs

  adding: QA_logs/ (stored 0%)
  adding: QA_logs/qa_model/ (stored 0%)
  adding: QA_logs/qa_model/version_0/ (stored 0%)
  adding: QA_logs/qa_model/version_0/events.out.tfevents.1728953375.rentserver.2914.2 (deflated 68%)
  adding: QA_logs/qa_model/version_0/hparams.yaml (deflated 3%)
  adding: QA_logs/qa_model/version_0/events.out.tfevents.1728976319.rentserver.2914.3 (deflated 28%)
