# Chatbot Based on MS [GODEL](https://www.microsoft.com/en-us/research/uploads/prod/2022/05/2206.11309.pdf)

### Install Gradio and Transformers

In [None]:
! pip install transformers gradio -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.8/5.8 MB[0m [31m36.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.2/14.2 MB[0m [31m42.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m28.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m182.4/182.4 KB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.5/84.5 KB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.8/55.8 KB[0m [31m622.3 kB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m270.5/270.5 KB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.9/56.9 KB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━

### Model Setup

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

# Checkpoint identifier handed to both `from_pretrained` calls below, so the
# tokenizer and the model always come from the same source.
MODEL_NAME = "microsoft/GODEL-v1_1-large-seq2seq"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

Downloading:   0%|          | 0.00/2.37k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/37.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.49k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.95G [00:00<?, ?B/s]

### Predict Function with State

In [None]:
def predict(input, instruction, knowledge, history=None):
    """Generate the next reply and return the updated conversation history.

    The query handed to the tokenizer has the form
    ``"<instruction> [CONTEXT] <utterances joined by ' EOS '> [KNOWLEDGE] <knowledge>"``;
    the ``[KNOWLEDGE] ...`` part is left out when no knowledge is supplied.
    The reply is sampled from the module-level ``model`` and decoded with the
    module-level ``tokenizer``.

    Args:
        input: The user's latest message.
        instruction: Behaviour instruction placed in front of the dialog.
            ``None`` is treated as an empty string.
        knowledge: Optional grounding text. ``None`` or ``""`` means that no
            knowledge is attached to the query.
        history: Earlier ``(user, bot)`` turns; each turn may be a tuple or a
            list. ``None`` (the default) starts a new conversation. The
            passed-in sequence is copied, never mutated.

    Returns:
        A pair ``(history, history)``: the same updated list twice, matching
        the two outputs (chatbot display and state) of the Gradio interface.
    """
    # A missing value would otherwise be formatted as the literal text "None"
    # inside the query.
    instruction = "" if instruction is None else instruction
    knowledge = "" if knowledge is None else knowledge

    # Work on a private copy: a mutable default argument would be shared by
    # every call, and the caller's list should not change behind their back.
    history = [] if history is None else list(history)

    # Flatten the turns into one utterance sequence. Unlike sum(history, ()),
    # this is linear and accepts list turns as well as tuple turns.
    utterances = [utterance for turn in history for utterance in turn]
    utterances.append(input)
    dialog = ' EOS '.join(utterances)

    # With no knowledge the query intentionally keeps its trailing space, so
    # the text given to the tokenizer is unchanged from the original notebook.
    knowledge_part = f"[KNOWLEDGE] {knowledge}" if knowledge != "" else ""
    query = f"{instruction} [CONTEXT] {dialog} {knowledge_part}"

    # Sampling settings for generation.
    top_p = 0.9
    min_length = 8
    max_length = 64

    # Tokenize the full query (instruction + dialog + optional knowledge).
    input_ids = tokenizer.encode(query, return_tensors='pt')

    output_ids = model.generate(
        input_ids,
        min_length=min_length,
        max_length=max_length,
        top_p=top_p,
        do_sample=True,
    ).tolist()

    # Only the first generated sequence is used.
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    history.append((input, response))

    return history, history

### Gradio UI

In [None]:
import gradio as gr

# Behaviour instructions offered in the first dropdown; the selected text is
# passed to `predict` as its `instruction` argument.
INSTRUCTIONS = [
    "Instruction: given a dialog context, you need to response empathically",
    "Instruction: given a dialog context and related knowledge, you need to answer the question based on the knowledge.",
    "Instruction: given a dialog context and related knowledge, you need to response safely based on the knowledge.",
]

# Knowledge snippets offered in the second dropdown; the empty string means
# "no knowledge". Every entry needs its own trailing comma: without it, Python
# silently concatenates adjacent string literals into a single choice.
KNOWLEDGE_SNIPPETS = [
    "",
    "Carlos Alcaraz, at just 19, defeated No. 5 Casper Ruud to win the 2022 US Open",
    "Scooby-Doo is a character created in 1969 by the American animation company Hanna-Barbera.He is a male Great Dane",
    "Over-the-counter medications such as ibuprofen (Advil, Motrin IB, others), acetaminophen (Tylenol, others) and aspirin.",
    "The best Stardew Valley mods PCGamesN_0 / About SMAPI",
]

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="write something..."),
        gr.Dropdown(INSTRUCTIONS),
        gr.Dropdown(KNOWLEDGE_SNIPPETS),
        'state',  # conversation history carried between calls
    ],
    outputs=[gr.Chatbot(label="GODEL"), 'state'],
    title="Context & Knowledge Base Aware Chatbot with Behaviour Instruction",
    description="Built on [Microsoft GODEL](https://www.microsoft.com/en-us/research/project/godel/)",
)

# debug=True keeps the cell running so errors and logs stay visible;
# share=True serves the app on a temporary public URL.
demo.launch(debug=True, share=True)

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://7e3e74a4-1c47-481d.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://7e3e74a4-1c47-481d.gradio.live


