References:

- https://www.dialog-21.ru/media/5806/gusevi112.pdf
- https://arxiv.org/pdf/2105.08206.pdf

In [1]:
pwd

'/home/demid-vm/home/demid-vm/test_dir'

### Tagger inference

In [3]:
import torch
from transformers import AutoTokenizer, pipeline

# First stage: a token-classification model that labels spans of the input
# (e.g. "equal" / "delete" / "replace" in the sample output below).
tagger_model_name = "IlyaGusev/rubertconv_toxic_editor"

device = "cuda" if torch.cuda.is_available() else "cpu"
# transformers pipelines take a device index: 0 = first GPU, -1 = CPU.
device_num = 0 if device == "cuda" else -1
tagger_tokenizer = AutoTokenizer.from_pretrained(tagger_model_name)
tagger_pipe = pipeline(
    "token-classification",
    model=tagger_model_name,
    # Reuse the tokenizer instance loaded above instead of having the
    # pipeline load a second copy from the same checkpoint name.
    tokenizer=tagger_tokenizer,
    framework="pt",
    device=device_num,
    # Group sub-word token predictions into word-level spans.
    aggregation_strategy="max"
)

In [4]:
# Sanity-check the tagger on one hand-written sentence. The pipeline takes a
# list of texts and returns one list of span dicts per text.
text = "Ёпта, меня зовут придурок и я живу в жопе"
batch_texts = [text]
tagger_predictions = tagger_pipe(batch_texts, batch_size=1)
sample_predictions = tagger_predictions[0]
print(sample_predictions)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


[{'entity_group': 'delete', 'score': 0.57686865, 'word': 'Ёпта,', 'start': 0, 'end': 5}, {'entity_group': 'equal', 'score': 0.9623802, 'word': 'меня зовут', 'start': 6, 'end': 16}, {'entity_group': 'replace', 'score': 0.69702196, 'word': 'придурок', 'start': 17, 'end': 25}, {'entity_group': 'equal', 'score': 0.86356926, 'word': 'и я живу', 'start': 26, 'end': 34}, {'entity_group': 'replace', 'score': 0.8265923, 'word': 'в жопе', 'start': 35, 'end': 41}]


### Template building

In [5]:
# Build the generator template from the tagger spans: "delete" spans are
# dropped, "replace" spans become the tagger's mask token, and every other
# span keeps its text, cut at the first pad token if one leaked into it.
template_parts = []
for group in sample_predictions:
    tag = group["entity_group"]
    phrase = group["word"]
    pad_start = phrase.find(tagger_tokenizer.pad_token)
    phrase = phrase if pad_start == -1 else phrase[:pad_start]
    if tag == "delete":
        continue
    kept = tagger_tokenizer.mask_token if tag == "replace" else phrase
    template_parts.append(kept.strip())
template = " ".join(template_parts)
print(template)

меня зовут [MASK] и я живу [MASK]


In [6]:
MASK_TEMPLATE = " <extra_id_{}> "


def convert_template_to_t5(template, orig_mask_token):
    """Replace each mask token in ``template`` with a numbered T5 sentinel.

    The i-th (0-based) occurrence of ``orig_mask_token`` becomes
    ``<extra_id_i>``. When at least one mask was replaced, runs of
    whitespace are collapsed to single spaces (``MASK_TEMPLATE`` pads the
    sentinel with spaces on both sides).

    Args:
        template: Template text, e.g. ``"меня зовут [MASK] и я живу [MASK]"``.
        orig_mask_token: The mask token to replace (e.g. ``"[MASK]"``).

    Returns:
        The template with sentinels, e.g.
        ``"меня зовут <extra_id_0> и я живу <extra_id_1>"``. A template with
        no mask token is returned unchanged.

    Raises:
        ValueError: if ``orig_mask_token`` is empty (it would match everywhere).
    """
    if not orig_mask_token:
        raise ValueError("orig_mask_token must be a non-empty string")
    # Split once instead of find/replace with a running offset: an offset
    # into the string before it was rewritten and whitespace-collapsed can
    # land past a later mask, which would then be silently left in place.
    pieces = template.split(orig_mask_token)
    if len(pieces) == 1:
        return template
    result = pieces[0]
    for mask_num, piece in enumerate(pieces[1:]):
        result += MASK_TEMPLATE.format(mask_num) + piece
    return " ".join(result.split())

# Swap every tagger mask for a numbered T5 sentinel (<extra_id_0>, <extra_id_1>, ...).
template = convert_template_to_t5(template, orig_mask_token=tagger_tokenizer.mask_token)
print(template)

меня зовут <extra_id_0> и я живу <extra_id_1>


### Template filling

In [7]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Second stage: a seq2seq model that writes text for each <extra_id_N>
# sentinel in the template.
gen_model_name = "IlyaGusev/sber_rut5_filler"
device = "cuda" if torch.cuda.is_available() else "cpu"
gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
gen_model = AutoModelForSeq2SeqLM.from_pretrained(gen_model_name)
gen_model = gen_model.to(device)

In [8]:
# Generate fillers: the tokenizer pairs the original text with the sentinel
# template, and the seq2seq model emits text for each sentinel.
encoded = gen_tokenizer(
    text,
    text_pair=template,
    add_special_tokens=True,
    max_length=200,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
).to(gen_model.device)
input_ids = encoded["input_ids"]

output_ids = gen_model.generate(
    input_ids=input_ids,
    # The input is padded to max_length, so pass the padding mask explicitly
    # rather than relying on generate() to infer it from the pad token.
    attention_mask=encoded.get("attention_mask"),
    num_beams=5,
    num_return_sequences=1,
    max_length=300,
    repetition_penalty=2.5
)
output_ids = output_ids[0]
fillers = gen_tokenizer.decode(output_ids, skip_special_tokens=True)
print(fillers)

# Splice the generated spans back into the template. The decoded output has
# the form "<extra_id_0> span0 <extra_id_1> span1 ..." (see the printed
# fillers below), so span i is the text between sentinel i and sentinel i+1.
mask_count = template.count("extra_id")
target = template
for mask_num in range(mask_count):
    current_mask = MASK_TEMPLATE.format(mask_num).strip()
    next_mask = MASK_TEMPLATE.format(mask_num + 1).strip()
    mask_index = fillers.find(current_mask)
    if mask_index == -1:
        # The generator skipped this sentinel: leave the slot empty instead
        # of slicing from an offset derived from a failed find().
        filler = ""
    else:
        start_index = mask_index + len(current_mask)
        end_index = fillers.find(next_mask, start_index)
        if end_index == -1:
            # No following sentinel: the span runs to the end. Slicing with
            # the -1 returned by find() would drop the last character.
            end_index = len(fillers)
        filler = fillers[start_index:end_index]
    target = target.replace(current_mask, filler)
target = " ".join(target.split())
target = target.replace(" ,", ",")
print(target)

<extra_id_0> нехороший человек <extra_id_1> в беде <extra_id_2>
меня зовут нехороший человек и я живу в беде


In [9]:
# Display the final detoxified sentence for the sample input.
target

'меня зовут нехороший человек и я живу в беде'

In [10]:
import pandas as pd
from tqdm import tqdm

In [11]:
import gc
# Run a full garbage-collection pass (presumably to release memory left over
# from the demo cells before the batch run). The cell displays the number of
# unreachable objects gc.collect() found.
gc.collect()

9

In [12]:
# Load the evaluation split (tab-separated); only the `toxic_comment`
# column is used as model input below.
df = pd.read_csv('test.tsv.txt', sep='\t')
toxic_inputs = df.toxic_comment.tolist()

### Inference for toxic inputs

In [13]:
# Run the full tagger -> template -> filler pipeline over every test comment.
# The generator (gen_tokenizer / gen_model), MASK_TEMPLATE and
# convert_template_to_t5 come from the cells above and are reused here.
# Re-loading the seq2seq model from disk inside the loop would repeat that
# cost on every iteration.


def _build_template(predictions):
    """Turn one text's tagger span predictions into a mask-token template.

    ``delete`` spans are dropped, ``replace`` spans become the tagger's mask
    token, and other spans keep their text (cut at the first pad token, if
    any). Non-empty pieces are joined with single spaces.
    """
    pad_token = tagger_tokenizer.pad_token
    mask_token = tagger_tokenizer.mask_token
    pieces = []
    for group in predictions:
        tag = group["entity_group"]
        if tag == "delete":
            continue
        if tag == "replace":
            phrase = mask_token
        else:
            phrase = group["word"]
            pad_index = phrase.find(pad_token)
            if pad_index != -1:
                phrase = phrase[:pad_index]
        phrase = phrase.strip()
        if phrase:  # a span that was only padding contributes nothing
            pieces.append(phrase)
    return " ".join(pieces)


def _fill_template(template, fillers):
    """Replace each <extra_id_N> in ``template`` with the generator's span N.

    ``fillers`` is the decoded generator output, shaped like
    "<extra_id_0> span0 <extra_id_1> span1 ...". Returns the filled text with
    whitespace collapsed and " ," tightened to ",".
    """
    target = template
    for mask_num in range(template.count("extra_id")):
        current_mask = MASK_TEMPLATE.format(mask_num).strip()
        next_mask = MASK_TEMPLATE.format(mask_num + 1).strip()
        mask_index = fillers.find(current_mask)
        if mask_index == -1:
            # The generator skipped this sentinel: leave the slot empty.
            filler = ""
        else:
            start_index = mask_index + len(current_mask)
            end_index = fillers.find(next_mask, start_index)
            if end_index == -1:
                # No following sentinel: the span runs to the end (slicing
                # with find()'s -1 would drop the last character).
                end_index = len(fillers)
            filler = fillers[start_index:end_index]
        target = target.replace(current_mask, filler)
    target = " ".join(target.split())
    return target.replace(" ,", ",")


def _detoxify(text):
    """Detoxify one text: tag spans, build a T5 template, generate, splice."""
    predictions = tagger_pipe([text], batch_size=1)[0]
    template = convert_template_to_t5(
        _build_template(predictions), tagger_tokenizer.mask_token
    )
    encoded = gen_tokenizer(
        text,
        text_pair=template,
        add_special_tokens=True,
        max_length=200,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).to(gen_model.device)
    output_ids = gen_model.generate(
        input_ids=encoded["input_ids"],
        # Inputs are padded to max_length; pass the padding mask explicitly.
        attention_mask=encoded.get("attention_mask"),
        num_beams=5,
        num_return_sequences=1,
        max_length=300,
        repetition_penalty=2.5,
    )
    fillers = gen_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return _fill_template(template, fillers)


detox_list = []
# Append as we go so partial results survive an interrupted run.
for text in tqdm(toxic_inputs):
    detox_list.append(_detoxify(text))

100%|█████████████████████████████████████████████████████████████████████████████████| 875/875 [55:05<00:00,  3.78s/it]


In [16]:
import os

# Persist one detoxified sentence per line. Each sentence was whitespace-
# normalised with split()/join in the inference loop, so it cannot contain
# an embedded newline. Write UTF-8 explicitly: the output is Cyrillic and
# open()'s default encoding is platform-dependent.
output_path = 'output/rubertconf_toxic_1.txt'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as file:
    file.writelines(sentence + '\n' for sentence in detox_list)