In [16]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./mlm_bert")

tokenizer.tokenize('и шэджагъуэ дыгъапIэр къохьэлъэкIыу хуабэ хъуми')

['и',
 'шэджагъуэ',
 'ды',
 '##гъа',
 '##пIэр',
 'къохьэ',
 '##лъэ',
 '##кIыу',
 'хуабэ',
 'хъуми']

In [17]:
import torch
from transformers import BartConfig, BartForConditionalGeneration, Trainer, TrainingArguments, BertModel

# Определение конфигурации модели BART
config = BartConfig(
    vocab_size=10000,
    max_position_embeddings=512,
    encoder_layers=8,
    encoder_ffn_dim=3072,
    encoder_attention_heads=16,
    decoder_layers=8,
    decoder_ffn_dim=3072,
    decoder_attention_heads=16,
    activation_function='gelu',
    d_model=384,
    dropout=0.1,
    use_cache=True,
    pad_token_id=0,
    bos_token_id=1,
    eos_token_id=2,
)

# Создание модели BART с пустыми весами
model = BartForConditionalGeneration(config)

# Создание модели BART с весами BERT
# bert_model = BertModel.from_pretrained('./mlm_bert')
# model.model.encoder.embed_tokens.weight = bert_model.embeddings.word_embeddings.weight
# model.model.shared.weight = bert_model.embeddings.word_embeddings.weight

In [18]:
import random


def generate_typo(word):
    typo_type = random.choice(['swap', 'delete', 'insert'])
    if len(word) < 2:
        typo_type = 'insert'

    if typo_type == 'swap':
        idx = random.randint(0, len(word) - 2)
        word = word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]

    elif typo_type == 'delete':
        idx = random.randint(0, len(word) - 1)
        word = word[:idx] + word[idx + 1:]

    elif typo_type == 'insert':
        idx = random.randint(0, len(word))
        alphabet = (
            'АаБбВвГгДдЕеЁёЖжЗзИиЙйКкЛлМмНнОоПпРрСсТтУуФфХхЦцЧчШшЩщЪъЫыЬьЭэЮюЯяIi1'
            '-.,:; -!?–…«»1234567890)(№*×><'
            'IIIьььъъъi111'
        )
        random_letter = random.choice(alphabet)
        word = word[:idx] + random_letter + word[idx:]

    return word


def generate_similar_char_error(word):
    similar_letters = {
        'п': 'II',
        'пI': 'тШ',
        'гы': 'гЫ',
        'жы': 'жь',
        'шы': 'шь',
        'П': 'ТТ',
        'Ш': 'III',
        'ш': 'III',
        'ПI': 'ПIГ',
        'жэ': 'жо',
        'пэ': 'пы',
        'жь': 'жъ',
        'ий': 'нй',
        'пс': 'лс',
        'эм': 'эи',
        'щ': 'шщ',
        'къ': 'кь',
        'Къ': 'Жъ',
        'пл': 'нл',
        'им': 'нм',
        'ти': 'тн',
        'гъщ': 'гъц',
        'хуи': 'хун',
        'щх': 'шх',
    }
    for key, value in similar_letters.items():
        if key in word:
            word = word.replace(key, value)
            break
    return word


def generate_grammatical_suffix_error(word):
    if word.endswith('къым'):
        return word[:-4] + 'кым'
    elif word.endswith('мкIэ'):
        return word[:-4] + 'мкэ'
    elif word.endswith('ым'):
        return word[:-2] + 'ып'
    return word


def generate_grammatical_prefix_error(word):
    if word.startswith('зэры'):
        return 'зари' + word[4:]
    elif word.startswith('къых'):
        return 'кыху' + word[4:]
    return word


def gen_incorr_word(word):
    incorrect_word = word

    for func in [
        generate_grammatical_prefix_error, 
        generate_grammatical_suffix_error, 
        generate_similar_char_error,
        generate_typo
    ]:
        incorrect_word = func(incorrect_word)
        if incorrect_word != word:
            return incorrect_word

    return incorrect_word

In [19]:
import pandas as pd

words_df = pd.read_csv('../data/processed/word_freqs/freq_1000000_oshhamaho.csv')
# words_df = words_df[words_df['freq'] > 10]

words_df['incorrect'] = words_df['word'].apply(gen_incorr_word)
words_df

Unnamed: 0,word,freq,incorrect
0,",",428272,"б,"
1,.,297212,.о
2,и,133661,№и
3,–,92993,-–
4,«,38086,б«
...,...,...,...
483333,уихьэжын,1,уихьэжьн
483334,дутIыпщыжри,1,дутIыIIщыжри
483335,скIэригъэкIыркъым,1,скIэригъэкIыркым
483336,Таксир,1,Тасир


In [20]:
# Токенизация данных
train_encodings = tokenizer(words_df['incorrect'].tolist(), truncation=True, padding=True)
train_labels = tokenizer(words_df['word'].tolist(), truncation=True, padding=True)

In [21]:
train_labels[100].ids

[2,
 995,
 3,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0]

In [22]:
# Создание датасета
class GrammarCorrectionDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels['input_ids'][idx])
        return item

    def __len__(self):
        return len(self.labels['input_ids'])


dataset = GrammarCorrectionDataset(train_encodings, train_labels)

In [23]:
from sklearn.model_selection import train_test_split

# Разбиение dataset на train_dataset и eval_dataset
train_dataset, eval_dataset = train_test_split(dataset, test_size=0.005, random_state=42)

In [24]:
train_dataset[100]

{'input_ids': tensor([   2,  545,  124,  657, 4260,  115,    3,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0]),
 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 'attention

In [25]:
# Определение аргументов обучения
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=1,
    per_device_train_batch_size=16,
    logging_steps=100,
    eval_steps=100,
    save_steps=5000,
    evaluation_strategy='steps',
    learning_rate=1e-6,
    # fp16=True,
    save_total_limit=3,
    # gradient_accumulation_steps=4,
)

# Создание объекта Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

In [26]:
# Обучение модели
trainer.train()

Step,Training Loss,Validation Loss
100,9.2135,9.162376


KeyboardInterrupt: 

In [None]:
# Сохранение модели
model.save_pretrained('grammar_correction_model_bart')
tokenizer.save_pretrained('grammar_correction_model_bart')

In [224]:
model.to(torch.device("cpu"))


def correct_grammar(sentence):
    # Токенизация входного предложения
    print(tokenizer.tokenize(sentence))

    input_ids = tokenizer.encode(sentence, return_tensors='pt')
    print(input_ids)
    
    # Генерация исправленного предложения
    output_ids = model.generate(input_ids, max_length=128, num_beams=4, min_length=20, early_stopping=True)
    # output_ids = model.generate(input_ids, max_length=128)
    print(output_ids)

    # Декодирование сгенерированного предложения
    corrected_sentence = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return corrected_sentence

In [225]:
correct_grammar('IIсалъэ')

['I', '##I', '##са', '##лъэ']
tensor([[  2,  25, 102, 708, 216,   3]])
tensor([[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 2]])


''