# Basic Chatbot

Language generation can be used to create a chatbot. The prompt given to the language model should describe the task and include the conversation history and the latest user input.

Below is a prompt template with placeholders for the conversation history and the user input. We'll try it out with `distilgpt2`, a small model that runs quickly but does not perform very well as a chatbot. In comparison, larger chat-specific language models (including ChatGPT) have been trained specifically for conversations and adapt well to different inputs.

```python
chatbot_prompt = """
You are an advanced chatbot who will try to answer the user's queries honestly.

CONVERSATION_HISTORY

User: INPUT
Chatbot:""".strip()
```
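The loops below fill this template with `str.replace` and then append each exchange to `conversation_history`. The following pure-Python sketch needs no model and shows how the prompt grows over two turns. The helpers `build_prompt` and `update_history` are hypothetical names that wrap the same string operations used in the loops, and the chatbot responses are made up.

```python
# Same template as above, repeated so this sketch is self-contained
CHATBOT_PROMPT = """
You are an advanced chatbot who will try to answer the user's queries honestly.

CONVERSATION_HISTORY

User: INPUT
Chatbot:""".strip()


def build_prompt(template: str, history: str, user_input: str) -> str:
    """Hypothetical helper: substitute the history and the latest input into the template."""
    return template.replace('CONVERSATION_HISTORY', history).replace('INPUT', user_input)


def update_history(history: str, user_input: str, response: str) -> str:
    """Hypothetical helper: append one User/Chatbot exchange, separated by a blank line."""
    return f"{history}\n\nUser: {user_input}\nChatbot: {response}".strip()


# Made-up exchanges standing in for real model responses
exchanges = [("Are you a human?", "No."),
             ("Which is faster, a falcon or a cheetah?", "A falcon.")]

history = ''
for user_input, response in exchanges:
    prompt = build_prompt(CHATBOT_PROMPT, history, user_input)
    history = update_history(history, user_input, response)

# The second-turn prompt: instruction, first exchange, then the new input
print(prompt)
```

On the first turn the history is empty, which is why the turn-0 prompt in the transcript has extra blank lines between the instruction and the first `User:` line. Because `INPUT` is replaced after the history has been inserted, a user message that literally contains `INPUT` would also be rewritten on later turns; a single-pass substitution such as `str.format` with named fields avoids that.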

Next, we'll load the tokenizer and the text generation pipeline for the `distilgpt2` model.

```python
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# The tokenizer is used below to look up the newline token ID
tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
generator = pipeline('text-generation', model='distilgpt2')
```

```
Device set to use mps:0
```

Text generation typically runs until a length limit is reached (e.g. one set with `max_new_tokens`) or an end-of-sequence special token is generated. We want generation to stop as soon as a newline character is produced, so that the chatbot replies with a single line of text. To do that, we'll define a custom `StoppingCriteria` subclass and pass it to the `generator`.

```python
from transformers import StoppingCriteria, StoppingCriteriaList

# Token ID that the distilgpt2 tokenizer assigns to a single newline character.
# Note: GPT-2's vocabulary also has a separate token for "\n\n", which this check does not catch.
newline_index = tokenizer.encode('\n')[0]


class StopOnNewLine(StoppingCriteria):
    """
    Purpose: Stop text generation when a newline character is generated.
    Inheritance: Inherits from the StoppingCriteria class.
    """

    def __call__(self, input_ids, scores, **kwargs):
        """
        Check whether the last token ID in each sequence matches the token ID for the newline character (newline_index). Returning True signals the generation process to stop; returning False lets generation continue.

        :param input_ids: A tensor of shape (batch_size, sequence_length) containing the prompt and the token IDs generated so far.
        :param scores: The model's prediction scores for the next token. Not used in this implementation.
        :param kwargs: Additional keyword arguments (not used in this implementation).
        :return: A boolean tensor of shape (batch_size,) that is True for each sequence whose last token is a newline.
        """
        return input_ids[:, -1] == newline_index
```
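To make the interaction between `max_new_tokens`, the end-of-sequence token and a stopping criterion concrete, here is a toy, model-free sketch of a generation loop. It is not how Transformers implements `generate`; the `generate` and `fake_model` functions and the token IDs other than 198 (the ID GPT-2's tokenizer assigns to `'\n'`) and 50256 (GPT-2's end-of-sequence token) are made up for illustration.

```python
from typing import Callable, List

NEWLINE_ID = 198  # ID of '\n' in GPT-2's vocabulary
EOS_ID = 50256    # GPT-2's end-of-sequence token


def generate(prompt_ids: List[int],
             next_token: Callable[[List[int]], int],
             max_new_tokens: int,
             stop: Callable[[List[int]], bool]) -> List[int]:
    """Toy loop: append tokens until EOS, a stopping criterion fires, or the length limit is hit."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        token = next_token(ids)
        ids.append(token)
        if token == EOS_ID or stop(ids):
            break
    # Like return_full_text=False: return only the newly generated tokens
    return ids[len(prompt_ids):]


def stop_on_newline(ids: List[int]) -> bool:
    # Same check as StopOnNewLine, but on a plain list holding a single sequence
    return ids[-1] == NEWLINE_ID


# Hypothetical "model" that replays a fixed sequence of made-up token IDs
scripted = iter([40, 1101, 257, NEWLINE_ID, 1212, 318])


def fake_model(ids: List[int]) -> int:
    return next(scripted)


print(generate([100, 101], fake_model, max_new_tokens=100, stop=stop_on_newline))
# → [40, 1101, 257, 198]
```

The newline token itself ends up in the generated output, which is why the loops below call `.strip()` on the decoded response.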

```python
conversation_turns = 0
# Start the conversation with a blank slate
conversation_history = ''
# Run for three conversation turns
while conversation_turns < 3:
    # Solicit input from the user
    user_input = input("User: ")

    # Integrate the conversation history and latest user input into the prompt
    prompt = chatbot_prompt.replace('CONVERSATION_HISTORY', conversation_history).replace('INPUT', user_input)

    print(f"Conversation turn: {conversation_turns}")
    print(f"Conversation history: {conversation_history}")
    print(f"Prompt: {prompt}\n")
    print("===============Previous context=================")

    # Run the generation pipeline
    generated = generator(prompt,
                          max_new_tokens=100,
                          do_sample=True,  # Use sampling for more interesting output
                          return_full_text=False,  # Only return new tokens, not the prompt
                          # Stop on a new line (or if max_new_tokens is reached)
                          stopping_criteria=StoppingCriteriaList([StopOnNewLine()]),
                          # GPT-2 has no padding token, so reuse its end-of-sequence token (ID 50256) to avoid a warning
                          pad_token_id=tokenizer.eos_token_id)

    # Get the response and output it
    response = generated[0]['generated_text'].strip()
    print(f"Chatbot: {response}\n")

    # Integrate the latest input/response into the conversation history
    conversation_history = f"""
{conversation_history}

User: {user_input}
Chatbot: {response}
""".strip()
    conversation_turns += 1
```


```
Conversation turn: 0
Conversation history: 
Prompt: You are an advanced chatbot who will try to answer the user's queries honestly.



User: Are you a human?
Chatbot:

Chatbot: Yes.

Conversation turn: 1
Conversation history: User: Are you a human?
Chatbot: Yes.
Prompt: You are an advanced chatbot who will try to answer the user's queries honestly.

User: Are you a human?
Chatbot: Yes.

User: Which is stronger an elephant or a dog?
Chatbot:

Chatbot: Yes.

Conversation turn: 2
Conversation history: User: Are you a human?
Chatbot: Yes.

User: Which is stronger an elephant or a dog?
Chatbot: Yes.
Prompt: You are an advanced chatbot who will try to answer the user's queries honestly.

User: Are you a human?
Chatbot: Yes.

User: Which is stronger an elephant or a dog?
Chatbot: Yes.

User: Which is faster a falcon or a cheetah?
Chatbot:

Chatbot: Yes.
```
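The loop passes `do_sample=True`, so each new token is drawn at random from the model's next-token distribution instead of always taking the most likely token. Two common knobs, `temperature` and `top_k`, reshape that distribution before sampling. Below is a minimal pure-Python sketch of the technique on made-up scores; the functions `top_k_softmax` and `sample_next_token` are hypothetical and are not the library's implementation.

```python
import math
import random
from typing import Dict, Optional


def top_k_softmax(logits: Dict[str, float], temperature: float = 1.0, top_k: int = 0) -> Dict[str, float]:
    """Turn raw scores into probabilities after top-k filtering and temperature scaling."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Keep only the top_k highest-scoring tokens (top_k=0 keeps all of them)
    ranked = sorted(logits.items(), key=lambda kv: kv[1], reverse=True)
    if top_k > 0:
        ranked = ranked[:top_k]
    # Dividing by the temperature sharpens (<1) or flattens (>1) the distribution
    scaled = [(tok, score / temperature) for tok, score in ranked]
    # Subtract the max before exponentiating for numerical stability
    max_score = max(s for _, s in scaled)
    weights = {tok: math.exp(s - max_score) for tok, s in scaled}
    total = sum(weights.values())
    return {tok: w / total for tok, w in weights.items()}


def sample_next_token(logits: Dict[str, float], temperature: float = 1.0,
                      top_k: int = 0, rng: Optional[random.Random] = None) -> str:
    """Draw one token from the filtered, temperature-scaled distribution."""
    rng = rng or random.Random()
    probs = top_k_softmax(logits, temperature, top_k)
    tokens = list(probs)
    return rng.choices(tokens, weights=[probs[t] for t in tokens], k=1)[0]


# Made-up next-token scores after "Chatbot:"
logits = {" Yes": 2.0, " No": 1.5, " A": 1.0, " The": 0.5}

print(top_k_softmax(logits))                          # all four tokens keep some probability
print(top_k_softmax(logits, temperature=0.5, top_k=2))  # only " Yes" and " No", more peaked
print(sample_next_token(logits, temperature=0.7, top_k=3, rng=random.Random(0)))
```

With `top_k=1` the sketch always returns the highest-scoring token, i.e. greedy decoding. In the notebook's `generator(...)` calls, the corresponding settings are the `temperature` and `top_k` keyword arguments, which the pipeline forwards to generation.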



```python
# Repurpose the chatbot as a YouTube content assistant by changing the instruction
chatbot_prompt = """
You are an advanced YouTube video content editor. You suggest engaging titles and help with content brainstorming.

CONVERSATION_HISTORY

User: INPUT
Chatbot:""".strip()

# Scripted user inputs replace the interactive input() call
user_inputs = ["Give me titles talking about Rust Programming",
               "Generate a video script description of one of the titles",
               "Elaborate more about the content of the video"]

conversation_history = ''

# Load the model explicitly and reuse the tokenizer loaded earlier
model = AutoModelForCausalLM.from_pretrained('distilgpt2')
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
for user_input in user_inputs:
    prompt = chatbot_prompt.replace('CONVERSATION_HISTORY', conversation_history).replace('INPUT', user_input)
    print("Prompt before generation:")
    print(prompt)
    print("==============================\n")
    generated = generator(prompt,
                          max_new_tokens=100,
                          do_sample=True,  # Use sampling for more interesting output
                          return_full_text=False,  # Only return new tokens, not the prompt
                          # Stop on a new line (or if max_new_tokens is reached)
                          stopping_criteria=StoppingCriteriaList([StopOnNewLine()]),
                          # GPT-2 has no padding token, so reuse its end-of-sequence token (ID 50256) to avoid a warning
                          pad_token_id=tokenizer.eos_token_id)

    # Get the response and output it
    response = generated[0]['generated_text'].strip()
    print(f"Chatbot: {response}\n")

    # Integrate the latest input/response into the conversation history
    conversation_history = f"""
{conversation_history}

User: {user_input}
Chatbot: {response}
""".strip()
```

```
Device set to use mps:0


Prompt before generation:
You are an advanced YouTue video content editor. You suggest engaging titles, and help with content brainstorming



User: Give me titles talking about Rust Programming
Chatbot:

Chatbot: I've been working with Rust development for over a year now which is still kinda nice. So my next project are my projects and I'm having a lot fun with it.

Prompt before generation:
You are an advanced YouTue video content editor. You suggest engaging titles, and help with content brainstorming

User: Give me titles talking about Rust Programming
Chatbot: I've been working with Rust development for over a year now which is still kinda nice. So my next project are my projects and I'm having a lot fun with it.

User: Generate a video script description of one of the titles
Chatbot:

Chatbot: I'm a freelance developer who is getting into Visual Studio, so there's definitely no easy way to write a video. You can also use the text or use the video editor.

Prompt before generatio
```

## Limitations

This basic approach has numerous limitations:

- The `distilgpt2` model is small and not optimised for chat, so it performs poorly in this context; in the first run it answers "Yes." to every question.
- The model has a limited context window (1,024 tokens for `distilgpt2`), and nothing checks whether the prompt is too large for it. This becomes a problem once `conversation_history` grows long: text generation then won't function as expected and may produce meaningless responses.
- Many of the text generation parameters (e.g. `temperature`, `top_k`) could be tuned to get better results.
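One way to address the context-window limitation is to keep only the most recent exchanges that fit within a token budget. The sketch below is an illustration under stated assumptions, not part of the original notebook: `trim_history` is a hypothetical helper, exchanges are assumed to be separated by a blank line (as the loops build `conversation_history`), and tokens are counted with a whitespace split purely as a stand-in so the example runs without a tokenizer. With the real model you would pass something like `lambda text: len(tokenizer.encode(text))` and choose a budget comfortably below the 1,024-token window, leaving room for the instruction, the new input and `max_new_tokens`.

```python
from typing import Callable, List


def trim_history(history: str, count_tokens: Callable[[str], int], budget: int) -> str:
    """Hypothetical helper: drop the oldest User/Chatbot exchanges until the history fits the budget."""
    # Assumes exchanges are separated by a blank line, as in conversation_history
    exchanges: List[str] = [e for e in history.split('\n\n') if e]
    while exchanges and count_tokens('\n\n'.join(exchanges)) > budget:
        exchanges.pop(0)  # Discard the oldest exchange first
    return '\n\n'.join(exchanges)


def whitespace_tokens(text: str) -> int:
    # Crude stand-in for a real tokenizer: one "token" per whitespace-separated word
    return len(text.split())


history = ("User: Are you a human?\nChatbot: Yes.\n\n"
           "User: Which is stronger an elephant or a dog?\nChatbot: Yes.\n\n"
           "User: Which is faster a falcon or a cheetah?\nChatbot: Yes.")

# Only the most recent exchange fits within 20 whitespace "tokens"
print(trim_history(history, whitespace_tokens, budget=20))
```

Dropping whole exchanges, rather than cutting the history at an arbitrary token, keeps every remaining `User:` line paired with its `Chatbot:` reply.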