In [None]:
!pip install gradio sentencepiece

In [None]:
import gradio as gr
import json
import re
import torch
from transformers import GPT2Tokenizer, T5ForConditionalGeneration
from IPython.display import IFrame

In [None]:
# Torch device that the models and input tensors are moved to.
# Uncomment the CUDA line (and comment out the CPU line) to run on a GPU.
# device = "cuda:0"
device = "cpu"
# Host used to build the URL of the IFrame that embeds the Gradio UI.
# NOTE(review): hard-coded address — presumably the machine serving this notebook; confirm before reuse.
HOST_IP = "192.168.31.167"
# Port passed to demo.launch() and used in the IFrame URL.
GRADIO_PORT = 7860

## FRED-T5-large-FT

In [None]:
# Checkpoint directory to load. Earlier candidates are kept as commented-out
# alternatives; only one `path` assignment should be active at a time.
# path = "/home/jovyan/wdc1/models/FRED-T5-large"
# path = "/home/jovyan/models/3_fred-t5/checkpoint-11000"
path = "/home/jovyan/models/7_fred-t5-large/checkpoint-35000"
# Bind the global `tokenizer` / `model` used by predict(); '</s>' is set
# explicitly as the end-of-sequence token.
tokenizer = GPT2Tokenizer.from_pretrained(path, eos_token='</s>')
model = T5ForConditionalGeneration.from_pretrained(path).to(device)

## ruT5-base-FT

In [None]:
# NOTE(review): T5ForConditionalGeneration is already imported in the imports
# cell at the top of the notebook; only T5Tokenizer is new here.
from transformers import T5ForConditionalGeneration, T5Tokenizer
path = "/home/jovyan/models/8_ruT5-base/checkpoint-17000/"
# Rebinds the global `model` / `tokenizer` used by predict(); if the FRED-T5
# cell was run before this one, its model and tokenizer are replaced here.
model = T5ForConditionalGeneration.from_pretrained(path).to(device)
tokenizer = T5Tokenizer.from_pretrained(path)
# Register the newline character as an additional tokenizer token.
# NOTE(review): presumably the checkpoint was fine-tuned with this extra token — confirm.
tokenizer.add_tokens("\n")

## Common code (run after loading one of the models above)

In [None]:
def predict(text):
    """Generate the model's raw output string for a prompt.

    Uses the notebook-level ``tokenizer``, ``model`` and ``device``.

    Args:
        text: Prompt string to encode and feed to the model.

    Returns:
        The decoded first generated sequence with its first token skipped
        (presumably the decoder start token — verify for the loaded model).
    """
    token_ids = tokenizer.encode(text)
    # Wrap in a list to form a batch of one before moving to the device.
    batch = torch.tensor([token_ids]).to(device)
    with torch.no_grad():
        generated = model.generate(
            batch,
            max_new_tokens=50,
            eos_token_id=tokenizer.eos_token_id,
            early_stopping=True,
        )
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence[1:])


# Smoke test: run the loaded model on a hand-written prompt with three masked spans.
predict("<SC1>Было у отца [3]<extra_id_0> сына, но не было даже [2- 3]<extra_id_1> пиджаков с блёстками за [142 990]<extra_id_2> руб.")

In [None]:
# Load the evaluation examples. The encoding is given explicitly because the
# file holds Cyrillic text and the platform default encoding is not always UTF-8.
with open("examples.json", encoding="utf-8") as f:
    test_examples = json.load(f)

In [None]:
# One token is, in order of preference: a single sentence punctuation mark,
# a chunk starting with a Cyrillic letter, a chunk starting with a digit, or
# a run of other non-space characters; trailing whitespace stays attached.
re_tokens = re.compile(r"(?:[.,!?]|[а-яА-Я]\S*|\d\S*(?:\.\d+)?|[^а-яА-Я\d\s]+)\s*")


def tokenize(text):
    """Split ``text`` into the list of substrings matched by ``re_tokens``.

    Each token keeps the whitespace that follows it; whitespace at the very
    start of the text is not part of any token.
    """
    return [found.group(0) for found in re_tokens.finditer(text)]


def strip_numbers(s):
    """Split long all-digit words into groups of three digits.

    The string is split on whitespace; every word made only of digits and
    longer than three characters is broken into a leading group of 1-3 digits
    followed by groups of exactly three. All pieces are joined back with
    single spaces.

    Args:
        s: Input string.

    Returns:
        The regrouped string, e.g. ``"142990"`` becomes ``"142 990"``.
    """
    pieces = []
    for word in s.split():
        if not word.isdigit() or len(word) <= 3:
            pieces.append(word)
            continue
        # Length of the leading group so that the remainder is a multiple of 3.
        head = len(word) - 3 * ((len(word) - 1) // 3)
        pieces.append(word[:head])
        pieces.extend(word[start:start + 3] for start in range(head, len(word), 3))
    return " ".join(pieces)


def construct_prompt(text):
    """Build the model prompt by marking every span that needs normalising.

    The text is tokenized with ``tokenize``. Consecutive tokens that contain a
    Latin letter or a digit are merged into one span, which is emitted as
    ``[span]<extra_id_N>`` (N counts from 0) followed by the span's trailing
    non-word characters. The span text is passed through ``strip_numbers``.
    All other tokens are copied unchanged. The result starts with ``<SC1>``.

    Args:
        text: Raw input sentence(s).

    Returns:
        The prompt string.
    """
    result = "<SC1>"
    etid = 0
    token_to_add = ""
    # The trailing "" sentinel flushes a span that runs to the end of the text.
    for token in tokenize(text) + [""]:
        if re.search(r"[a-zA-Z\d]", token):
            token_to_add += token
            continue
        if token_to_add:
            # Separate the span body from its trailing non-word characters.
            # fullmatch + DOTALL keeps spans that contain a line break intact;
            # a MULTILINE "$" would stop at the first newline and drop the rest.
            body, tail = re.fullmatch(r"(.+?)(\W*)", token_to_add, re.S).groups()
            result += f"[{strip_numbers(body)}]<extra_id_{etid}>{tail}"
            etid += 1
            token_to_add = ""
        result += token
    return result


# Demo: show the prompt produced for a sentence mixing Latin words, numbers and a time.
construct_prompt('я купил iphone 12X за 142 990 руб без 3-x часов 12:00, и т.д.')

In [None]:
def construct_answer(prompt: str, prediction: str) -> str:
    """Merge the model prediction back into the prompt to get the final text.

    Every ``[original]<extra_id_N>`` placeholder in ``prompt`` is replaced by
    the text predicted for ``<extra_id_N>``. When the prediction has no entry
    for N, the bracketed original text is kept instead. The ``<SC1>`` prefix
    is removed from the result.

    A predicted entry is only recognised when it is followed by another
    ``<extra_id_K>`` marker or by ``</s>``, so the last entry of an output
    that lacks ``</s>`` is ignored.

    Args:
        prompt: Prompt in the format produced by ``construct_prompt``.
        prediction: Raw decoded model output.

    Returns:
        The prompt text with all placeholders substituted.
    """
    re_prompt = re.compile(r"\[([^\]]+)\]<extra_id_(\d+)>")
    re_pred = re.compile(r"<extra_id_(\d+)>(.+?)(?=<extra_id_\d+>|</s>)")
    # Newlines are flattened so that "." in re_pred can cross line breaks.
    flat_prediction = prediction.replace("\n", " ")
    pred_data = {found[1]: found[2].strip() for found in re_pred.finditer(flat_prediction)}
    # Single pass: substituted text is never rescanned, so model output cannot
    # be mistaken for a new placeholder.
    answer = re_prompt.sub(lambda found: pred_data.get(found[2], found[1]), prompt)
    return answer.replace("<SC1>", "")
        
# Demo: merge a hand-written four-slot prediction back into its prompt.
construct_answer(
    '<SC1>Было у отца [3]<extra_id_0> сына. Старшему было [35]<extra_id_1>, среднему - не меньше [33]<extra_id_2>, а младший на [4]<extra_id_3> младше всех. Бывает.',
    """<extra_id_0>  три
 <extra_id_1>  тридцать пять
 <extra_id_2>  тридцати трех
 <extra_id_3>  четыре
</s>"""
)

In [None]:
def norm(message, history):
    """Chat callback: normalise ``message`` and stream the result.

    Yields twice: first an interim message showing the prompt with a
    placeholder for the prediction, then the final message with the prompt,
    the raw prediction and the reconstructed answer.

    Args:
        message: Text entered by the user.
        history: Chat history supplied by the caller; not used here.
    """
    model_prompt = construct_prompt(message)
    # NOTE(review): this interim message opens its code fence before "Prompt:",
    # unlike the final message below — confirm whether that is intended.
    yield f"```Prompt:\n{model_prompt}\nPrediction:\n...```\n..."
    raw_prediction = predict(model_prompt)
    final_text = construct_answer(model_prompt, raw_prediction)
    yield f"Prompt:\n```{model_prompt}```\nPrediction:\n```\n{raw_prediction}\n```\n{final_text}"


# Build a chat UI around norm(); the keys of test_examples are offered as example inputs.
demo = gr.ChatInterface(norm, stop_btn=None, examples=list(test_examples.keys())).queue()
# Serve on GRADIO_PORT without inline rendering.
# NOTE(review): server_name="0.0.0.0" presumably makes the app reachable from other machines on the network — confirm this is intended.
demo.launch(inline=False, server_name="0.0.0.0", server_port=GRADIO_PORT, inbrowser=True)
# Embed the running app in the notebook output via its HOST_IP:GRADIO_PORT URL.
IFrame(src=f"http://{HOST_IP}:{GRADIO_PORT}", width='100%', height='500px')

In [None]:
# Examples are processed one at a time: batch generation surprisingly gave bad
# results with encoder-decoder architectures.
# A matching example prints only the expected text; a mismatch prints the input,
# the prompt, the expected text, the produced answer and the raw prediction.
for lm_text, gt in test_examples.items():
    prompt = construct_prompt(lm_text)
    prediction = predict(prompt)
    answer = construct_answer(prompt, prediction)
    if gt != answer:
        print(f"{lm_text}\n{prompt}\n{gt}\n{answer}\n{prediction}\n")
        continue
    print(f"{gt}\n")