In [1]:
from transformers import  AutoTokenizer, AutoModelForSeq2SeqLM
from dimweb_persona_bot.dataloaders.seq2seq_samplers.seq2seq_samplers_hypothesis_2 import H2Seq2SeqInferencePersonaSampleV1
from dimweb_persona_bot.hyperparameters.causal_modeling_hyperparameters import H2PersonaChatHyperparametersV1
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_path = "dim/bart-base-15or5dmk"
device = "cuda"

model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path)

hyperparameters = H2PersonaChatHyperparametersV1(
	model_name="facebook/bart-base",
	model_architecture="seq2seq",
	chat_history_pair_length=3,
	persona_max_length=14, 
	chat_max_length=19,
)

In [6]:
tokenizer.encode("<c_sep>")

[0, 50265, 2]

In [27]:
sample = H2Seq2SeqInferencePersonaSampleV1(
	tokenizer=tokenizer,
	hyperparameters=hyperparameters,
	dataset_sample={
		"persona": [
      		'I like chocolate ice cream.', 
        	"Sometimes I feel lonely.", 
        	"I like to play video games."
        ],
		"history": [
    	  	"Hi, do you like ice cream?", 
       		'i do like ice cream but i prefer chocolate ice cream',
			"I feel lonely",
			'i feel lonely when i play video games',
			"Do you wanna play with me? I think we can have fun together.",
    ]
	}
).get_sample()

for key in sample.keys():
    sample[key] = torch.tensor(sample[key]).unsqueeze(0).to(device)
    
# model(**sample)
answer = model.generate(**sample, max_length=20)
tokenizer.batch_decode(answer, skip_special_tokens=True,)

['i do not like ice cream but i love chocolate']

In [3]:
# ['i feel lonely when i play video games']

In [29]:
answer = model.generate(
    **sample, 
    max_length=20, 
    penalty_alpha=0.2, 
    top_k=6
)
tokenizer.batch_decode(answer, skip_special_tokens=True,)

['i think so too. what do you like to eat?']

In [3]:
import random 
class DialogBotV1:
    """
    uses greedy decoding
	"""
    def __init__(self, 
        model, 
        tokenizer, 
        hyperparameters,
        history=None,
        persona=None,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.hyperparameters = hyperparameters
        
        if history is None:
            self.history = []
        self.history = history
        
        if persona is None:
            self.persona = []
        self.persona = persona
    
    def chat(self, message):
        self.history.append(message)
        
        random.shuffle(self.persona)
        sample = H2Seq2SeqInferencePersonaSampleV1(
            tokenizer=self.tokenizer,
            hyperparameters=self.hyperparameters,
            dataset_sample={
                "persona": self.persona,
                "history": self.history,
            }
        ).get_sample()

        for key in sample.keys():
            sample[key] = torch.tensor(sample[key]).unsqueeze(0).to(device)
        
        answer = self.generate_responce(sample)
        answer = self.tokenizer.batch_decode(answer, skip_special_tokens=True,)
        self.history.append(answer[0])
        return answer[0]
    
    def single_chat(self, message):
        random.shuffle(self.persona)
        temp_history = self.history.copy()
        temp_history.append(message)
        
        sample = H2Seq2SeqInferencePersonaSampleV1(
            tokenizer=self.tokenizer,
            hyperparameters=self.hyperparameters,
            dataset_sample={
                "persona": self.persona,
                "history": temp_history,
            }
        ).get_sample()

        for key in sample.keys():
            sample[key] = torch.tensor(sample[key]).unsqueeze(0).to(device)
        
        answer = self.generate_responce(sample)
        answer = self.tokenizer.batch_decode(answer, skip_special_tokens=True,)
        return answer[0]
    
    def generate_responce(self, sample):
        return self.model.generate(**sample, max_length=20)
    
    def start_chat(self):
        while True:
            message = input("You: ")
            if message == "exit":
                break
            answer = self.chat(message)
            print("Bot:", answer)

In [96]:
bot = DialogBotV1(
    model=model,
    tokenizer=tokenizer,
    hyperparameters=hyperparameters,
    history=[
        'Hi, how are you doing?',
        "i'm doing well. how are you?",
        "I'am fixing a bug right now",
        "oh wow that's interesting. what bug are you fixing?",
	],
    persona=[
        "I'm a computer science fresher.",
        "I like racing games.",
        "Sometimes I write code for fun.",
	]
)

# bot.start_chat()
# Oh, It's disgusting bug. The button on website sometimes is black and sometimes is red. What do you think?

In [107]:
response = bot.single_chat(
	"The button sometimes is black and sometimes is red.",
)
print(response)
print(bot.persona)
print(bot.history)

what do you do for a living?
['I like racing games.', 'Sometimes I write code for fun.', "I'm a computer science fresher."]
['Hi, how are you doing?', "i'm doing well. how are you?", "I'am fixing a bug right now", "oh wow that's interesting. what bug are you fixing?"]


In [32]:
bot.history

['Hi, how are you doing?',
 "i'm doing well. how are you?",
 "I'am fixing a bug right now",
 'what bug are you fixing?',
 "Oh, It's disgusting bug. The button on website sometimes is black and sometimes is red. What do you think?",
 'what do you do for a living?']

In [4]:
class DialogBotV2(DialogBotV1):
    """
    uses Contrastive Search 
    """
    
    def generate_responce(self, sample):
        return self.model.generate(
            **sample, 
            # max_length=20,
            max_new_tokens=20, 
            penalty_alpha=0.1, 
            top_k=6
        )
        
bot2 = DialogBotV2(model=model,
    tokenizer=tokenizer,
    hyperparameters=hyperparameters,
    history=[
        'Hi, how are you doing?',
        "i'm doing well. how are you?",
        "I'am fixing a bug right now",
        "oh wow that's interesting. what bug are you fixing?",
	],
    persona=[
        "I'm a junior frontend developer.", 'I like racing games.', 'Sometimes I write code for fun.', "I'm a computer science fresher."
	]
)
    

In [5]:
response = bot2.single_chat(
	"I don't wanna talking about it.",
)
print(response)
print(bot2.persona)
print(bot2.history)

ok what do you do for fun?
['I like racing games.', "I'm a computer science fresher.", "I'm a junior frontend developer.", 'Sometimes I write code for fun.']
['Hi, how are you doing?', "i'm doing well. how are you?", "I'am fixing a bug right now", "oh wow that's interesting. what bug are you fixing?"]


In [53]:
bot2.history

['Hi, how are you doing?',
 "i'm doing well. how are you?",
 "I'am fixing a bug right now",
 "oh wow that's interesting. what bug are you fixing?",
 "i'm fixing a computer that needs fixed",
 "oh wow that's pretty cool. what do you do for fun?"]

In [35]:
bot2.start_chat()

Bot: i'm doing well. how are you?
Bot: oh wow that's interesting. what bug are you fixing?
Bot: i like red hot chilli peppers


In [36]:
bot2.history

['Hi, how are you doing?',
 "i'm doing well. how are you?",
 "I'am fixing a bug right now",
 "oh wow that's interesting. what bug are you fixing?",
 "Oh, It's disgusting bug. The button on website sometimes is black and sometimes is red. What do you think?",
 'i like red hot chilli peppers']