<a href="https://colab.research.google.com/github/pavelpryadokhin/Transformer-BERT-GPT/blob/main/BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Let's look at the SQuAD dataset and fine-tune a BERT model to answer questions from a given context.

#BERT SQuAD

In [None]:
import torch
import json
import pandas as pd
import numpy as np
import random
import matplotlib.pyplot as plt
import re
import string
import collections
from transformers import BertTokenizerFast, BertForQuestionAnswering
from transformers.tokenization_utils_base import BatchEncoding
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Suppress distracting warnings
import warnings
warnings.filterwarnings("ignore")

pd.set_option('max_colwidth', 500)
%matplotlib inline

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')



#Data preprocessing

Download the training dataset.

In [None]:
!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json
# !wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json

--2025-02-05 13:44:43--  https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.111.153, 185.199.108.153, 185.199.109.153, ...
Connecting to rajpurkar.github.io (rajpurkar.github.io)|185.199.111.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 42123633 (40M) [application/json]
Saving to: ‘train-v2.0.json’


2025-02-05 13:44:44 (202 MB/s) - ‘train-v2.0.json’ saved [42123633/42123633]



In [None]:
def squad_load_from_json(json_file_path,shape=2000):
    with open(json_file_path, "r") as f:
        json_data = json.load(f)['data']
        questions = []
        answers = []
        corpuses = []
        for category in json_data[:shape]:
            for paragraph in category['paragraphs']:
                context = paragraph['context']
                for qa in paragraph['qas']:
                    question = qa['question']
                    ans_list = qa['plausible_answers'] if qa['is_impossible'] else qa['answers']
                    ans_tuple = None
                    for  ans in ans_list:
                        ans_tuple = (ans['answer_start'], ans['answer_start']+len(ans['text']), ans['text'])
                    if ans_tuple:
                        if ans_tuple[1]<200: # keep only answers that end within the first 200 characters, since the texts will be truncated
                            corpuses.append(context)
                            questions.append(question)
                            answers.append(ans_tuple)

        df = pd.DataFrame(data={'question':questions, 'answer':answers, 'corpus':corpuses})
        return df.sample(n=min(shape, len(df)), random_state=7).reset_index(drop=True)
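
To make the nested layout that `squad_load_from_json` walks through more concrete, here is a minimal sketch on a hand-made record. The record is invented for illustration (it is not taken from `train-v2.0.json`), and `extract_spans` is a hypothetical helper that mirrors only the span arithmetic of the loader: it keeps every answer, whereas the loader keeps the last one per question.

```python
import json

# A tiny hand-made record in the SQuAD v2.0 layout (illustrative, not real data).
sample = json.loads("""
{"data": [{"title": "Toy", "paragraphs": [{
    "context": "BERT was released by Google in 2018.",
    "qas": [
        {"question": "Who released BERT?", "is_impossible": false,
         "answers": [{"text": "Google", "answer_start": 21}]},
        {"question": "Who released GPT?", "is_impossible": true,
         "answers": [],
         "plausible_answers": [{"text": "Google", "answer_start": 21}]}
    ]}]}]}
""")

def extract_spans(data):
    """Collect (question, (start, end, text), context) rows from SQuAD-style data."""
    rows = []
    for article in data["data"]:
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                # Unanswerable questions carry only "plausible_answers".
                key = "plausible_answers" if qa["is_impossible"] else "answers"
                for ans in qa[key]:
                    start = ans["answer_start"]
                    end = start + len(ans["text"])  # exclusive character offset
                    rows.append((qa["question"], (start, end, ans["text"]), context))
    return rows

rows = extract_spans(sample)
for question, (start, end, text), context in rows:
    # The character span must reproduce the answer text exactly.
    assert context[start:end] == text
    print(question, "->", (start, end, text))
```

Note that, like the loader above, this treats the `plausible_answers` of unanswerable questions as ordinary answers, so the model never learns to say that a question has no answer.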

In [None]:
train_dataset = squad_load_from_json("train-v2.0.json")
train_dataset.head()

Unnamed: 0,question,answer,corpus
0,By how much did the top tier increase minimum wage in December 2007?,"(66, 78, 12.5 percent)","The Sichuan government raised the minimum wage in the province by 12.5 percent at the end of December 2007. The monthly minimum wage went up from 400 to 450 yuan, with a minimum of 4.9 yuan per hour for part-time work, effective 26 December 2007. The government also reduced the four-tier minimum wage structure to three. The top tier mandates a minimum of 650 yuan per month, or 7.1 yuan per hour. National law allows each province to set minimum wages independently, but with a floor of 450 yua..."
1,"What store was founded in Paris, in 1883, which sold a wide variety of products?","(22, 35, Au Bon Marché)","A novelty shop called Au Bon Marché had been founded in Paris in 1838 to sell lace, ribbons, sheets, mattresses, buttons, umbrellas and other assorted goods. It originally had four departments, twelve employees, and a floor space of three hundred meters. The entrepreneur Aristide Boucicaut became a partner in 1852, and changed the marketing plan, instituting fixed prices and guarantees that allowed exchanges and refunds, advertising, and a much wider variety of merchandise. The annual income..."
2,"What are gender identity, ethnic identity, and occupational identity aspects of?","(178, 198, one's total identity)","The inclusiveness of Weinreich's definition (above) directs attention to the totality of one's identity at a given phase in time, and assists in elucidating component aspects of one's total identity, such as one's gender identity, ethnic identity, occupational identity and so on. The definition readily applies to the young child, to the adolescent, to the young adult, and to the older adult in various phases of the life cycle. Depending on whether one is a young child or an adult at the heig..."
3,From which word meaning anointed one does Christos originate?,"(27, 38, Christianos)","The Greek word Χριστιανός (Christianos), meaning ""follower of Christ"", comes from Χριστός (Christos), meaning ""anointed one"", with an adjectival ending borrowed from Latin to denote adhering to, or even belonging to, as in slave ownership. In the Greek Septuagint, christos was used to translate the Hebrew מָשִׁיחַ (Mašíaḥ, messiah), meaning ""[one who is] anointed."" In other European languages, equivalent words to Christian are likewise derived from the Greek, such as Chrétien in French and C..."
4,Where did the 1896 Cincinnati Red Stockings relocate to?,"(187, 193, Boston)","Major League Baseball is especially well known for red teams. The Cincinnati Red Stockings are the oldest professional baseball team, dating back to 1869. The franchise soon relocated to Boston and is now the Atlanta Braves, but its name survives as the origin for both the Cincinnati Reds and Boston Red Sox. During the 1950s when red was strongly associated with communism, the modern Cincinnati team was known as the ""Redlegs"" and the term was used on baseball cards. After the red scare faded..."


In [None]:
# validation_dataset = squad_load_from_json("dev-v2.0.json",4000)
# validation_dataset.head()

#Tokenization

In [None]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')


First, find the token indices at which each answer starts and ends.

In [None]:
def calculate_tokenized_ans_indices(dataset,max_length=128):
    ans_tok_start = []
    ans_tok_end = []
    ans_tok_text = []

    for idx, (ans_text_start, ans_text_end, ans_text) in enumerate(dataset['answer'].values):
        encoding = tokenizer.encode_plus(
            text=dataset['corpus'].values[idx],
            text_pair=dataset['question'].values[idx],
            max_length=max_length,
            padding='max_length',
            truncation=True)

        # Map the character offsets of the answer to token indices within the corpus (sequence 0)
        ans_start = encoding.char_to_token(0, ans_text_start)
        ans_end = encoding.char_to_token(0, ans_text_end - 1)

        # Handle answers that were cut off by truncation
        if ans_start is None:
            # The whole answer was truncated away: mark it with an out-of-range position
            ans_start = ans_end = tokenizer.model_max_length
        elif ans_end is None:
            # Only the tail was truncated: end at the last corpus token, just before the first [SEP]
            ans_end = min(encoding['input_ids'].index(tokenizer.sep_token_id) - 1, max_length-1)

        # Convert the answer tokens back to a string
        ans_text_tok =  tokenizer.convert_tokens_to_string(
            tokenizer.convert_ids_to_tokens(encoding['input_ids'][ans_start:ans_end + 1]))

        ans_tok_start.append(ans_start)
        ans_tok_end.append(ans_end)
        ans_tok_text.append(ans_text_tok)

    dataset['ans_start_tok'] = ans_tok_start
    dataset['ans_end_tok'] = ans_tok_end
    dataset['ans_tok_text'] = ans_tok_text

    return dataset
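
The role of `char_to_token` can be illustrated without a real tokenizer. The sketch below is a toy: it splits on whitespace (BERT actually uses WordPiece and adds special tokens), records the character offsets of each token, and looks up which token covers a given character. `toy_tokenize` and `toy_char_to_token` are hypothetical helpers, not part of `transformers`.

```python
import re

def toy_tokenize(text):
    """Return (token, start, end) triples; end is an exclusive character offset."""
    return [(m.group(), m.start(), m.end()) for m in re.finditer(r"\S+", text)]

def toy_char_to_token(offsets, char_index, max_tokens=None):
    """Index of the token covering char_index, or None if it was truncated away."""
    for i, (_, start, end) in enumerate(offsets[:max_tokens]):
        if start <= char_index < end:
            return i
    return None

context = "The franchise soon relocated to Boston and is now the Atlanta Braves"
answer = "Boston"
char_start = context.index(answer)
char_end = char_start + len(answer)  # exclusive, as in the 'answer' column

offsets = toy_tokenize(context)
tok_start = toy_char_to_token(offsets, char_start)
# The end offset is exclusive, so the last answer character is char_end - 1.
tok_end = toy_char_to_token(offsets, char_end - 1)
print(tok_start, tok_end)

# Keeping only the first 3 tokens cuts the answer off, so the lookup returns None.
truncated = toy_char_to_token(offsets, char_start, max_tokens=3)
print(truncated)
```

The `None` result in the truncated case corresponds to the situation that the `if ans_start is None` branch above has to handle.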

In [None]:
train_dataset = calculate_tokenized_ans_indices(train_dataset)
train_dataset.head()

Unnamed: 0,question,answer,corpus,ans_start_tok,ans_end_tok,ans_tok_text
0,By how much did the top tier increase minimum wage in December 2007?,"(66, 78, 12.5 percent)","The Sichuan government raised the minimum wage in the province by 12.5 percent at the end of December 2007. The monthly minimum wage went up from 400 to 450 yuan, with a minimum of 4.9 yuan per hour for part-time work, effective 26 December 2007. The government also reduced the four-tier minimum wage structure to three. The top tier mandates a minimum of 650 yuan per month, or 7.1 yuan per hour. National law allows each province to set minimum wages independently, but with a floor of 450 yua...",12,15,12. 5 percent
1,"What store was founded in Paris, in 1883, which sold a wide variety of products?","(22, 35, Au Bon Marché)","A novelty shop called Au Bon Marché had been founded in Paris in 1838 to sell lace, ribbons, sheets, mattresses, buttons, umbrellas and other assorted goods. It originally had four departments, twelve employees, and a floor space of three hundred meters. The entrepreneur Aristide Boucicaut became a partner in 1852, and changed the marketing plan, instituting fixed prices and guarantees that allowed exchanges and refunds, advertising, and a much wider variety of merchandise. The annual income...",5,7,au bon marche
2,"What are gender identity, ethnic identity, and occupational identity aspects of?","(178, 198, one's total identity)","The inclusiveness of Weinreich's definition (above) directs attention to the totality of one's identity at a given phase in time, and assists in elucidating component aspects of one's total identity, such as one's gender identity, ethnic identity, occupational identity and so on. The definition readily applies to the young child, to the adolescent, to the young adult, and to the older adult in various phases of the life cycle. Depending on whether one is a young child or an adult at the heig...",43,47,one ' s total identity
3,From which word meaning anointed one does Christos originate?,"(27, 38, Christianos)","The Greek word Χριστιανός (Christianos), meaning ""follower of Christ"", comes from Χριστός (Christos), meaning ""anointed one"", with an adjectival ending borrowed from Latin to denote adhering to, or even belonging to, as in slave ownership. In the Greek Septuagint, christos was used to translate the Hebrew מָשִׁיחַ (Mašíaḥ, messiah), meaning ""[one who is] anointed."" In other European languages, equivalent words to Christian are likewise derived from the Greek, such as Chrétien in French and C...",13,14,christianos
4,Where did the 1896 Cincinnati Red Stockings relocate to?,"(187, 193, Boston)","Major League Baseball is especially well known for red teams. The Cincinnati Red Stockings are the oldest professional baseball team, dating back to 1869. The franchise soon relocated to Boston and is now the Atlanta Braves, but its name survives as the origin for both the Cincinnati Reds and Boston Red Sox. During the 1950s when red was strongly associated with communism, the modern Cincinnati team was known as the ""Redlegs"" and the term was used on baseball cards. After the red scare faded...",33,33,boston


In [None]:
# validation_dataset = calculate_tokenized_ans_indices(validation_dataset)
# validation_dataset.head()

In [None]:
# train_dataset.drop(index=36523, inplace=True)

In [None]:
assert all(0 <= pos < 128 for pos in train_dataset['ans_start_tok'])
assert all(0 <= pos < 128 for pos in train_dataset['ans_end_tok'])
# assert all(0 <= pos < 128 for pos in validation_dataset['ans_start_tok'])
# assert all(0 <= pos < 128 for pos in validation_dataset['ans_end_tok'])

Tokenize the dataset.

In [None]:
class SQuAD_Dataset(Dataset):
    def __init__(self, data: pd.DataFrame):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        encoding = tokenizer(text=self.data['corpus'].values[idx], text_pair=self.data['question'].values[idx], max_length=128, padding='max_length', truncation=True, return_tensors='pt').to(device)
        return {
            'ans_start_tok': torch.tensor(self.data['ans_start_tok'].values[idx], dtype=torch.long, device=device),
            'ans_end_tok': torch.tensor(self.data['ans_end_tok'].values[idx], dtype=torch.long, device=device),
            'input_ids': encoding['input_ids'][0],
            'attention_mask': encoding['attention_mask'][0],
            'token_type_ids': encoding['token_type_ids'][0]}

Prepare the data for training.

In [None]:
BATCH_SIZE=4

train_squad_dataset = SQuAD_Dataset(data=train_dataset)
train_dataloader = DataLoader(train_squad_dataset, batch_size=BATCH_SIZE)

#Training

In [None]:
bert_model =  BertForQuestionAnswering.from_pretrained('bert-base-uncased').to(device)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
def train(model, train_dataloader, learning_rate = 1e-5, epochs = 2):
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    train_losses = []

    for epoch in range(epochs):
        model.train()
        t_losses = []
        for batch in tqdm(train_dataloader):
            optimizer.zero_grad()
            output = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'],
                        token_type_ids=batch['token_type_ids'], start_positions=batch['ans_start_tok'],
                        end_positions=batch['ans_end_tok'])

            loss = output[0]
            loss.backward()
            optimizer.step()
            t_losses.append(loss.item())

        train_loss = np.mean(t_losses)
        train_losses.append(train_loss)
        print(f"Epoch {epoch:3}: Loss = {train_loss:.5f}")
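
`output[0]` is the loss that `BertForQuestionAnswering` returns when `start_positions` and `end_positions` are passed. As far as we can tell from the `transformers` implementation, it is the mean of two cross-entropy terms, one over the start logits and one over the end logits. The pure-Python sketch below reproduces that arithmetic for a single example with made-up logits; `cross_entropy` and `qa_loss` are illustrative names, not library functions.

```python
import math

def cross_entropy(logits, target):
    """Negative log-softmax of the target position (numerically stable)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]

def qa_loss(start_logits, end_logits, start_pos, end_pos):
    """Average of the start and end cross-entropy terms for one example."""
    return 0.5 * (cross_entropy(start_logits, start_pos)
                  + cross_entropy(end_logits, end_pos))

# Made-up logits over a 4-token sequence; the answer spans tokens 1..2.
start_logits = [0.1, 2.0, 0.3, -1.0]
end_logits = [-0.5, 0.2, 1.5, 0.0]
loss = qa_loss(start_logits, end_logits, start_pos=1, end_pos=2)
print(f"loss = {loss:.5f}")

# Uniform logits over 128 positions give ln(128), the "knows nothing" baseline.
uniform = qa_loss([0.0] * 128, [0.0] * 128, 0, 0)
print(f"uniform baseline = {uniform:.5f}")
```

For a sequence length of 128 the uniform baseline is ln 128 ≈ 4.85, which gives a rough reference point for the per-epoch losses printed by `train`.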

In [None]:
train(bert_model, train_dataloader)

100%|██████████| 500/500 [59:30<00:00,  7.14s/it]


Epoch   0: Loss = 2.63779


100%|██████████| 500/500 [58:39<00:00,  7.04s/it]


Epoch   1: Loss = 1.40802


In [None]:
TEXT = "Alexander Sergeyevich Pushkin (1799–1837) was a Russian poet, novelist, and playwright, regarded as the founder of modern Russian literature. He is best known for his works such as 'Eugene Onegin','Ruslan and Lyudmila' and numerous lyrical poems. Pushkin had a significant influence on subsequent generations of writers and poets, and his work continues to be studied and valued for its artistic expressiveness and depth of thought."
def chat(question, text=TEXT):
    # Note: the question is passed first here, the reverse of the (corpus, question) order used in training
    inputs = tokenizer(question, text, max_length=512, padding='max_length', truncation=True, return_tensors='pt').to(device)
    bert_model.eval()  # switch off dropout for inference
    with torch.no_grad():  # no gradients are needed when answering
        output = bert_model(**inputs)
    start_index = output.start_logits.argmax()
    end_index = output.end_logits.argmax() + 1  # +1 so that the slice includes the end token
    answer = tokenizer.decode(inputs.input_ids[0, start_index:end_index])
    print(answer)
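
Taking the two argmaxes independently can produce an end index that precedes the start index, in which case the decoded slice is empty. A common alternative, sketched here as an assumption and not as what `chat` does, is to score every pair with `start <= end` and a bounded length, then keep the best one. `best_span` and `max_answer_len` are hypothetical names, and the logits are made up.

```python
def best_span(start_logits, end_logits, max_answer_len=30):
    """Return the inclusive (start, end) pair with the highest summed logit, start <= end."""
    best, best_score = (0, 0), float("-inf")
    for s, s_logit in enumerate(start_logits):
        # Only consider ends at or after the start, within the length limit.
        for e in range(s, min(s + max_answer_len, len(end_logits))):
            score = s_logit + end_logits[e]
            if score > best_score:
                best, best_score = (s, e), score
    return best

# Made-up logits for which the independent argmaxes come out in the wrong order.
start_logits = [0.1, 0.2, 3.0, 0.5]
end_logits = [0.0, 2.5, 2.0, 0.3]

naive = (start_logits.index(max(start_logits)), end_logits.index(max(end_logits)))
span = best_span(start_logits, end_logits)
print("independent argmax:", naive)
print("best valid span:", span)
```

With an inclusive `(start, end)` pair like this, the tokens to decode would be `input_ids[0, start:end + 1]`.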



In [None]:
chat( "Who is Alexander Pushkin and what is he known for?")

( 1799 – 1837 ) was a russian poet,


In [None]:
chat('Which poet is mentioned in the text?')

alexander sergeyevich pushkin
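
To go beyond eyeballing such answers, one option is to score them with exact match and token-level F1 in the style of the SQuAD evaluation. The sketch below is a simplified version under our own assumptions, not the official script; the reference answers are hypothetical and were written by hand for this example.

```python
import collections
import re
import string

def normalize_answer(text):
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    exclude = set(string.punctuation)
    text = "".join(ch for ch in text.lower() if ch not in exclude)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(prediction, reference):
    """1 if the normalized strings are identical, else 0."""
    return int(normalize_answer(prediction) == normalize_answer(reference))

def f1_score(prediction, reference):
    """Harmonic mean of token precision and recall after normalization."""
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    if not pred_tokens or not ref_tokens:
        # Both empty counts as a match; one empty counts as a miss.
        return float(pred_tokens == ref_tokens)
    common = collections.Counter(pred_tokens) & collections.Counter(ref_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

# Model output from the cell above versus a hand-written (hypothetical) reference.
em = exact_match("alexander sergeyevich pushkin", "Alexander Sergeyevich Pushkin")
# A partial answer versus a longer hypothetical reference.
f1_partial = f1_score("a russian poet", "Russian poet, novelist, and playwright")
print(em, round(f1_partial, 3))
```

Averaging these two scores over a held-out set such as `dev-v2.0.json` would give a more reliable picture than a couple of hand-picked questions.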


#Summary

We fine-tuned a BERT model to answer questions from a given context. Because of limited compute, we had to restrict the training set to 2,000 question-answer pairs, and the model was trained for only two epochs. Even so, it already gives reasonable answers to simple questions.