In [None]:
def controlled_generation(prompt, threshold=0.7, max_len=50):
    """Sample a continuation of ``prompt`` and keep extending it while it stays "similar".

    A first continuation is sampled with nucleus sampling (``top_p=0.9``).
    The decoded text is then scanned position by position; whenever the
    cosine similarity between the mean prompt embedding and the embedding of
    the current piece reaches ``threshold``, extra tokens are sampled and
    appended until the similarity drops below ``threshold`` or the decoded
    text is at least ``max_len`` long.

    Relies on module-level ``tokenizer``, ``model`` and ``torch``.

    Args:
        prompt: Text used to seed generation.
        threshold: Cosine-similarity value at or above which extension kicks in.
        max_len: Added to the prompt's token count to form ``max_length`` for
            the first sample; also compared against ``len(generated_text)``
            in the extension loop.

    Returns:
        The decoded output sequence with special tokens skipped.
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    def sample(ids, length_limit):
        # One place for the sampling configuration shared by both generate calls.
        return model.generate(
            input_ids=ids,
            max_length=length_limit,
            do_sample=True,
            top_k=0,
            top_p=0.9,
            temperature=1,
            no_repeat_ngram_size=2,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )

    def similarity_to_prompt(ids):
        # NOTE(review): in the extension loop `ids` comes from `[:, -1]` and so
        # has one dimension fewer than the ids embedded elsewhere; confirm that
        # `mean(dim=1)` still averages over the intended axis in that case.
        pooled = model.transformer.wte(ids).mean(dim=1)
        return torch.nn.functional.cosine_similarity(prompt_embedding, pooled, dim=-1)

    output = sample(input_ids, max_len + len(input_ids[0]))
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    prompt_embedding = model.transformer.wte(input_ids).mean(dim=1)

    # The number of iterations is fixed from the first decoded text, while the
    # indexing below always reads the most recent `generated_text`.
    # NOTE(review): this walks the decoded text element by element (characters,
    # if `decode` returns a str), not the generated token ids — confirm intent.
    initial_length = len(generated_text)
    for position in range(initial_length):
        piece_ids = tokenizer.encode(generated_text[position], return_tensors='pt')
        cosine_sim = similarity_to_prompt(piece_ids)

        if cosine_sim < threshold:
            continue

        # NOTE(review): `max_len` is compared with the length of the decoded
        # text here, whereas above it is added to a token count — confirm units.
        while cosine_sim >= threshold and len(generated_text) < max_len:
            # Re-sample from everything but the last token and keep only the
            # final position of the result.
            sampled = sample(output[:, :-1], output.shape[-1] + 1)[:, -1]
            cosine_sim = similarity_to_prompt(sampled)

            output = torch.cat([output, sampled.unsqueeze(0)], dim=-1)
            generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

    return generated_text

In [None]:
from telebot import TeleBot
import codecs
from collections import defaultdict
import datetime
import os
import sys
from transformers import HfArgumentParser
from dataclasses import dataclass, field

from utils.gen_answer_persona import GenAnswerPersona
from utils.arguments_inference import InteractionArguments


# Should have no parent classes with fields without defaults (if one big `args` needed)
@dataclass
class TelegaArguments(InteractionArguments):
    """Arguments of ``InteractionArguments`` extended with the Telegram bot token."""

    # Token passed to ``TeleBot``; defaults to an empty string.
    bot_token: str = field(default="", metadata={"help": "Bot token."})

# Set up the command-line argument parser. A single argument ending in
# ".json" is read as a JSON config file; anything else is parsed as CLI flags.
parser = HfArgumentParser(TelegaArguments)
args = (
    parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json")
    else parser.parse_args_into_dataclasses()
)[0]

# Fall back to the model checkpoint when no tokenizer is named explicitly.
if args.tokenizer_name is None:
    args.tokenizer_name = args.model_checkpoint

# The per-bot log file name is suffixed with the last five token characters.
log_file = f"log_file_{args.bot_token[-5:]}.log"
log_fail_file = "log_fail_file.log"

bot = TeleBot(args.bot_token, threaded=False)

Chat = GenAnswerPersona(args=args)

# Per-chat message history, keyed by chat id.
act_rep = defaultdict(list)
win = args.max_history


# Command handlers
@bot.message_handler(commands=["len_context"])
def handle_intent_list(message):
    """Reply to ``/len_context`` with the current length of ``Chat.history``."""
    history_length = len(Chat.history)
    bot.send_message(message.chat.id, f"Длина контекста: {history_length}")


@bot.message_handler(commands=["start"])
def handle_intent_list(message):
    """Reply to ``/start`` with a fixed greeting."""
    bot.send_message(message.chat.id, "Привет. Давай поболтаем о чем-нибудь")
    

@bot.message_handler(commands=["del_context"])
def handle_intent_list(message):
    """Reply to ``/del_context`` and reset this chat's entry in ``act_rep``."""
    chat_id = message.chat.id
    # The confirmation is sent first; the stored history is replaced afterwards.
    bot.send_message(chat_id, "Контекст удален")
    act_rep[chat_id] = list()


@bot.message_handler(commands=["context"])
def handle_intent_list(message):
    """Reply to ``/context`` with this chat's stored messages joined by `` || ``."""
    chat_id = message.chat.id
    stored_messages = act_rep[chat_id]
    joined = " || ".join(stored_messages)
    bot.send_message(chat_id, "Context: " + joined)


@bot.message_handler(content_types=["text"])
def handle_text(message):
    """Answer an incoming text message and log the exchange.

    The user's text is appended to this chat's history in ``act_rep``, a reply
    is requested from ``Chat``, the reply is appended to the same history, one
    ``;``-separated line is appended to ``log_file``, and the reply is sent
    back to the chat.

    Args:
        message: The incoming Telegram message object passed in by the bot.
    """
    chat_id = message.chat.id
    context = message.text
    act_rep[chat_id].append(context)
    chat_id_history = act_rep[chat_id]

    bot.send_chat_action(chat_id, "typing")

    # Previously `answer` was never assigned (the generation call was commented
    # out), so every text message failed with NameError.
    # NOTE(review): the input is assumed to be a batch (list) holding this
    # chat's history; confirm against GenAnswerPersona.get_reply_batched.
    # A copy is passed so the model call cannot alias the stored history.
    answers, _add_info = Chat.get_reply_batched(
        [list(chat_id_history)], persona_code=args.persona_code
    )
    answer = answers[0]

    act_rep[chat_id].append(answer)

    # Column layout is kept as before (the message text appears twice) so
    # existing log consumers keep working.
    log_fields = [
        datetime.date.today().strftime("%Y-%m-%d"),
        str(message.from_user.id),
        str(chat_id),
        str(message.message_id),
        context,
        message.text,
        answer,
    ]
    with codecs.open(log_file, "a", encoding="utf-8") as f:
        f.write(";".join(log_fields) + "\n")

    bot.send_message(
        chat_id,
        text=answer,
    )


if __name__ == "__main__":
    # Keep the bot polling, restarting after a failure. The loop gives up once
    # more than 100 errors have been counted.
    VMAX = 0
    while VMAX <= 100:
        try:
            bot.polling(none_stop=True, interval=0, timeout=2000)
        except Exception as exc:
            # Only ordinary errors are retried: KeyboardInterrupt / SystemExit
            # now propagate, so the process can be stopped with Ctrl-C.
            VMAX += 1
            print(f"Error: {exc!r}")
