# Introduction 

This notebook replicates a simple chat template with continuous chat. The model understands and remembers the context according to its capacity.

This is a OPT 350M model fine-tuned to Chat Alpaca dataset (https://huggingface.co/datasets/flpelerin/ChatAlpaca-10k). Find the fine-tuning notebook in the `assistant_sft` directory.

**NOTE: The notebook uses a customized streamer for text streaming.**

In [1]:
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer,
    pipeline,
    logging,
)

from streaming_utils import TextStreamer

In [2]:
model = AutoModelForCausalLM.from_pretrained(
    '../assistant_sft/outputs/opt_350m_chat_alpaca/best_model/'
)
tokenizer = AutoTokenizer.from_pretrained(
    '../assistant_sft/outputs/opt_350m_chat_alpaca/best_model/'
)

In [3]:
streamer = TextStreamer(
    tokenizer, 
    skip_prompt=True, 
    skip_special_tokens=True, 
    truncate_before_pattern=['\[\/'],
    truncate=True
)

In [4]:
print(tokenizer.eos_token)

</s>


In [5]:
logging.set_verbosity(logging.CRITICAL)

In [6]:
# template = """</s>[INST] {prompt} [/INST]"""
eos_string = tokenizer.eos_token
history = None

In [7]:
# print(template)

In [8]:
while True:
    question=input("Question: ")
    inputs = ''

    if history is None:
        template = """</s>[INST] {prompt} [/INST]"""
    else:
        template = """[INST] {prompt} [/INST]"""

    prompt = history + ' ' + template.format(prompt=question, inputs=inputs) if history is not None else template.format(prompt=question, inputs=inputs)

    # print(f"PROMPT: {prompt}")

    prompt_tokenized = tokenizer(prompt, return_tensors='pt')['input_ids']
    
    output_tokenized = model.generate(
        input_ids=prompt_tokenized, 
        max_length=len(prompt_tokenized[0])+256,
        temperature=0.7,
        top_k=40,
        top_p=0.1,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
        streamer=streamer
    )
    answer = tokenizer.decode(token_ids=output_tokenized[0][len(prompt_tokenized[0]):]).strip()
    
    if eos_string in answer:
        answer = answer.split(eos_string)[0].strip()
    if '[/' in answer:
        answer = answer.split('[/')[0].strip()

    history = ' '.join([prompt, answer, eos_string])
    # print(f"ANSWER: {answer}\n")
    # print(f"HISTORY: {history}\n")
    print('#' * 50)

Question:  What is 2+2?


2+2 is 2.
##################################################


KeyboardInterrupt: Interrupted by user