## Prerequisite
1. Jupyter must be installed.<br>
<code>\$ pip install jupyterlab</code><br>
2. Enable ipywidgets with one of the following commands:<br>
    2-1. jupyter-lab<br>
<code>\$ jupyter labextension install @jupyter-widgets/jupyterlab-manager</code><br>
    2-2. jupyter notebook / Google Colab<br>
<code>\$ jupyter nbextension enable --py widgetsnbextension</code>
3. Install dependencies.<br>
<code>\$ pip install -r requirements.txt</code>

## ⇩⇩ RUN ALL CELLS BELOW ⇩⇩
Run cells all the way down to the bottom!<br>
<font color=#FFFF00>**CAUTION!**</font> The first run takes a while because it downloads a model.

In [None]:
from warnings import simplefilter
def ignore_warnings():
    """Silence all Python warnings by installing a catch-all 'ignore' filter."""
    simplefilter('ignore')


ignore_warnings()

import os
import re
import asyncio
import dataclasses
from style_bert_vits2.tts_model import TTSModel

from IPython.display import clear_output, HTML
from ipywidgets import widgets

import lib.markdown as md
from lib.infrastructure import ForgetableContext, LlamaCpp
from lib.uis import wait_for_change
from lib.utils import now, mixed2katakana, get, replace_text

import session_states as session
from global_settings import *
from prompt_builders import *

### Remove the cell below to disable the default model download on first run.
This step is skipped automatically if the directory <code>./ggufs</code> is not empty.

In [None]:
if len(session.ggufs) == 0:
    # IMO, Llama-3-8B is the minimum requirement that can run Python/RAG/tools.
    # I don't believe Llama-2-based small model can run this app🤔
    print(f"Directory \"{GGUF_DIR}\" is empty. Downloading default model.")
    LLAMA_3_8B_GGUF_HF_URL = "https://huggingface.co/QuantFactory/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct.Q5_K_M.gguf"
    !wget -P {GGUF_DIR} {LLAMA_3_8B_GGUF_HF_URL}
    session.ggufs.append("Meta-Llama-3-8B-Instruct.Q5_K_M.gguf")
    session.llama_cpp_options.gguf_selector.options = session.ggufs
    session.llama_cpp_options.gguf_selector.value = session.ggufs[0]

In [None]:
def reload_instruction(instruction_file_path: str) -> str:
    """Read a system-prompt template file and return its whole text.

    Args:
        instruction_file_path: Path to the template file.

    Returns:
        The file contents as a single string (newlines normalized to ``\\n``).
    """
    # Templates may contain non-ASCII text; decode as UTF-8 explicitly instead
    # of relying on the platform's locale-dependent default encoding.
    with open(instruction_file_path, 'r', encoding='utf-8') as f:
        return f.read()

# Default prompt builder (Llama-3 template); the template selector replaces it later.
_default_instruction = reload_instruction("system_prompt_template/l3_sys_ppt_gen_16.txt")
prompt_builder = Llama3PromptBuilder(_default_instruction, session.context)

In [None]:
# Updates GUIs.
    
def get_bigger(args):
    """Grow the input textarea with its content: one row per line, at least two rows.

    Observer for ``session.field``'s ``value``; ``args`` (the change event) is unused.
    """
    line_breaks = session.field.value.count('\n')
    session.field.rows = max(line_breaks + 1, 2)


session.field.observe(get_bigger, 'value')

def change_prompt_builder(args):
    """Replace the global ``prompt_builder`` to match the selected template.

    Observer for ``session.template_selector``'s ``value``. ``args`` (the change
    event) is unused; the current selector value is read directly.

    Raises:
        KeyError: the selector holds a label not listed below.
    """
    global prompt_builder
    # Selector label -> (builder class, system-prompt template path).
    # Only the selected builder is constructed, so a switch reads one template
    # file instead of all three.
    builders = {
        "Command R Template": (CommandRPromptBuilder, "./system_prompt_template/cmdr_sys_ppt_gen_16.txt"),
        "Llama-3 Template": (Llama3PromptBuilder, "./system_prompt_template/l3_sys_ppt_gen_16.txt"),
        "ChatML Template": (ChatMLPromptBuilder, "./system_prompt_template/chatml_sys_ppt_gen_16.txt"),
    }
    builder_cls, template_path = builders[session.template_selector.value]
    prompt_builder = builder_cls(reload_instruction(template_path), session.context)

session.template_selector.observe(change_prompt_builder, "value")


def set_guessing_image(show: bool) -> None:
    """Show or hide the "thinking" indicator next to the send button.

    Creates the 50x50 ``session.guessing_image`` widget on first use and
    updates its image afterwards.

    Args:
        show: True shows ``images/guessing.gif``; False shows ``images/empty.png``.
    """
    # Read the image once; the original flow read it twice on the first call.
    with open('images/guessing.gif' if show else 'images/empty.png', 'rb') as f:
        image_bytes = f.read()

    if not session.guessing_image:
        session.guessing_image = widgets.Image(value=image_bytes, width=50, height=50)
        return

    session.guessing_image.value = image_bytes
    session.guessing_image.width = 50
    session.guessing_image.height = 50

set_guessing_image(False)


def set_buttons(disabled: bool) -> None:
    """Enable or disable every button registered in ``session.buttons``.

    Args:
        disabled: Value assigned to each button's ``disabled`` flag.
    """
    for button in session.buttons:
        button.disabled = disabled

def disable_uis(func):
    """Decorator that locks the UI while ``func`` runs.

    Before the call, every button in ``session.buttons`` is disabled and the
    "thinking" indicator is shown. Afterwards, the indicator is hidden and the
    buttons are re-enabled. The restore runs in ``finally``, so an exception
    raised by ``func`` still propagates but can no longer leave the UI locked.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        set_buttons(disabled=True)
        set_guessing_image(True)
        try:
            return func(*args, **kwargs)
        finally:
            set_guessing_image(False)
            set_buttons(disabled=False)
    return wrapper

In [None]:
def _load_model(gguf_path: str) -> None:
    """Load the GGUF model at ``gguf_path`` into ``session.model``.

    The context length is read from the ``define_n_ctx`` widget. It must be
    larger than ``MAX_GENERATION_TOKENS``, because the prompt budget used
    elsewhere is ``n_ctx - MAX_GENERATION_TOKENS``. The value is validated
    before it is stored in ``session.n_ctx``.

    An invalid context length, or a ``ValueError`` raised by ``LlamaCpp``, is
    printed to ``session.debug``. In that case ``session.model`` is not assigned.

    Args:
        gguf_path: Path to the ``.gguf`` file.
    """
    try:
        n_ctx = session.llama_cpp_options.define_n_ctx.value
        # Validate the newly requested value before committing it, so an
        # invalid widget value never reaches session.n_ctx or LlamaCpp.
        if n_ctx <= MAX_GENERATION_TOKENS:
            raise ValueError(
                f"n_ctx ({n_ctx}) must be greater than MAX_GENERATION_TOKENS ({MAX_GENERATION_TOKENS})."
            )
        session.n_ctx = n_ctx

        session.model = LlamaCpp(
            model_path=gguf_path,
            n_gpu_layers=session.llama_cpp_options.n_gpu_layers.value,
            n_batch=1024,
            n_ctx=session.n_ctx,
            use_mlock=True,
            flash_attn=session.llama_cpp_options.flash_attention.value,
            verbose=False,
            embedding=False,
        )
    except ValueError as e:
        with session.debug:
            print(e)
    
def _unload_model() -> None:
    """Drop the loaded model (if any) and run the garbage collector.

    No-op when ``session.model`` is already None.
    """
    import gc

    if session.model is None:
        return

    # Explicitly freeing the native model would need Llama.close():
    # https://github.com/abetlen/llama-cpp-python/pull/1513
    # It is left disabled; the model is released by dropping the last
    # reference and collecting.
    # session.model.llama.close()

    session.model = None
    gc.collect()
    
@disable_uis
def reload_model(sender=None) -> None:
    """Button handler: replace the current model with the GGUF chosen in the selector, then reset the chat."""
    _unload_model()
    selected_gguf = session.llama_cpp_options.gguf_selector.value
    _load_model(os.path.join(GGUF_DIR, selected_gguf))
    initialize()

@disable_uis
def unload_model(sender=None) -> None:
    """Button handler: release the current model and reset the chat."""
    _unload_model()
    initialize()

In [None]:
def format_to_html(context) -> str:
    """Render the chat history of ``context`` as an HTML fragment.

    User messages get a green bubble. Every other role gets a yellow bubble,
    plus sections for a search query, referenced documents, Python code
    (syntax-highlighted with pygments), code output and an image output.
    Message ``content`` goes through ``md.convert``.

    Plain-text fields that are not Markdown (search query, reference URLs,
    code output, avatar initials) are HTML-escaped. Characters such as ``<``
    and ``&`` are therefore shown literally instead of being parsed as markup;
    for example, ``print(type(1))`` prints ``<class 'int'>``.

    Args:
        context: Object whose ``history()`` returns message dicts with at least
            ``role`` and ``content`` keys.

    Returns:
        The concatenated HTML for all messages.
    """
    from html import escape
    from pygments import highlight
    from pygments.lexers import Python3Lexer
    from pygments.formatters import HtmlFormatter

    def embed_image_to_tag(image_binary) -> str:
        import base64
        encoded_image = base64.b64encode(image_binary).decode('utf-8')
        return f'<img src="data:image/jpeg;base64,{encoded_image}" />'

    def header(title: str) -> str:
        return f'<div style="background-color: #999999; color: black;">{title}</div>'

    USER_MESSAGE_BG_COLOR = "#BBFFBB"
    AI_MESSAGE_BG_COLOR = "#FFEEBB"
    text_template = """<div style="background-color: {color}; word-wrap: break-word; color: black; padding: 10px; border-radius: 20px;">{content}</div>"""
    avatar_template = """<div style="background-color: #999999; color: white; width: 15px; height: 15px; padding: 10px; border-radius: 5px; text-align: center;">{initial}</div>"""
    plain_box = '<div style="background-color: #FFFFFF; color: black;">{}</div>'

    messages: list[str] = []
    for message in context.history():
        role = message['role']
        text = md.convert(message['content'])

        if role == "User":
            text = text_template.format(content=text, color=USER_MESSAGE_BG_COLOR)
        else:
            references = get(message, 'references')
            code = get(message, 'code')
            code_output = get(message, 'code_output')
            image_output = get(message, 'image_output')
            search_query = get(message, 'search_query')
            # 'search_result' is intentionally not rendered.

            if search_query:
                text += header("Google")
                text += plain_box.format(escape(search_query))
            if references:
                links = '<br>'.join(
                    f'✅<a href="{escape(url, quote=True)}">{escape(url[:50])}...</a>'
                    for url in references
                )
                text += header("Documents matched the query")
                text += plain_box.format(links)
            if code:
                text += header("Python")
                text += highlight(code, Python3Lexer(), HtmlFormatter())
            if code_output:
                text += header('Output')
                text += '<pre><code>' + escape(code_output) + '</code></pre>'
            if image_output:
                text += '</br>' + embed_image_to_tag(image_binary=image_output) + '</br>'

            text = text_template.format(content=text, color=AI_MESSAGE_BG_COLOR)

        if role == "User":
            name = avatar_template.format(initial=escape(session.user_nickname[0]))
        elif role == session.assistant_name:
            name = avatar_template.format(initial=escape(session.assistant_name[0]))
        else:
            name = "</br>"
        messages.append(f'{name}{text}')

    return ''.join(messages)
    
def print_context():
    """Clear the current output area and render the whole conversation as HTML.

    The messages come from ``format_to_html(session.context)``. They are placed
    inside a fixed-size (HEIGHT x WIDTH px) ``#wrapper`` div with a vertical
    scrollbar. ``flex-direction: column-reverse`` is presumably there so the
    scroll position starts at the newest (bottom) content.

    Meant to be called inside an output widget (``update_display`` calls it
    within ``session.out``).
    """
    # wait=True defers the clear until new output is available (reduces flicker).
    clear_output(wait=True)

    HEIGHT, WIDTH = 1100, 800
    
    # This is an f-string, so literal CSS braces are doubled ({{ }}).
    html_text = f"""<!DOCTYPE html>
<html>

<head>
  <meta charset="utf-8" />
  <style>
    #wrapper {{
      display: flex;
      flex-direction: column-reverse;
      height: {HEIGHT}px;
      width: {WIDTH}px;
      overflow-y: scroll;
    }}

    /* Custom Scrollbar CSS */
    #wrapper::-webkit-scrollbar {{
      width: 10px;
    }}

    #wrapper::-webkit-scrollbar-track {{
      background: #f1f1f1;
    }}

    #wrapper::-webkit-scrollbar-thumb {{
      background: #888;
    }}

    #wrapper::-webkit-scrollbar-thumb:hover {{
      background: #555;
    }}
    
  </style>
</head>

<body>
    <div style="display: flex; align-items: flex-end;">
      <div id="wrapper">
        <div id="contents">
    {format_to_html(session.context)}
        </div>
      </div>
    </div>
</body>

</html>
"""

    # NOTE: `display` is not imported in this file; it relies on the
    # IPython/Jupyter builtin of the same name.
    display(HTML(html_text))

def update_display() -> None:
    """Re-render the conversation inside the ``session.out`` output widget."""
    output_area = session.out
    with output_area:
        print_context()

In [None]:
# TTS model loader

def is_tts_model_dir(path):
    """Return True if ``path`` is a directory holding a complete TTS model.

    Required files: ``<directory name>.safetensors`` (named after the directory
    itself), ``config.json`` and ``style_vectors.npy``.
    """
    if os.path.isdir(path):
        required = {
            f"{os.path.basename(path)}.safetensors",
            "config.json",
            "style_vectors.npy",
        }
        return required.issubset(os.listdir(path))
    return False



@session.debug.capture()
def load_tts_models(model_path):
    """Load the Japanese BERT assets and a Style-BERT-VITS2 ``TTSModel`` from ``model_path``.

    ``model_path`` must contain ``<directory name>.safetensors``, ``config.json``
    and ``style_vectors.npy``. The model is created on the CPU. Output is
    captured into ``session.debug`` by the decorator.
    """
    from style_bert_vits2.nlp import bert_models
    from style_bert_vits2.constants import Languages
    import gc

    gc.collect()

    bert_repo = "ku-nlp/deberta-v2-large-japanese-char-wwm"
    bert_models.load_model(Languages.JP, bert_repo)
    bert_models.load_tokenizer(Languages.JP, bert_repo)

    model_name = os.path.basename(model_path)
    return TTSModel(
        model_path=os.path.join(model_path, f"{model_name}.safetensors"),
        config_path=os.path.join(model_path, "config.json"),
        style_vec_path=os.path.join(model_path, "style_vectors.npy"),
        device="cpu",
    )


async def capture_model_selection_change():
    """Background task: reload the TTS model whenever the dropdown selection changes.

    Loops forever, awaiting each change of ``session.dropdown.value``.
    """
    while True:
        model_name = await wait_for_change(session.dropdown, "value")
        model_dir = os.path.join(TTS_ASSET_ROOT, model_name)
        session.tts_model = load_tts_models(model_dir)
        session.debug.append_stdout(f"tts_model({model_name}) loaded.\n")

# Every subdirectory of TTS_ASSET_ROOT that holds a complete TTS model.
tts_model_names = [
    entry
    for entry in os.listdir(TTS_ASSET_ROOT)
    if is_tts_model_dir(os.path.join(TTS_ASSET_ROOT, entry))
]

_dropdown_kwargs = {"description": "TTS model", "options": tts_model_names}
if tts_model_names:
    # Preload the first model and preselect it in the dropdown.
    session.tts_model = load_tts_models(os.path.join(TTS_ASSET_ROOT, tts_model_names[0]))
    _dropdown_kwargs["value"] = tts_model_names[0]
session.dropdown = widgets.Dropdown(**_dropdown_kwargs)
session.loop.create_task(capture_model_selection_change());

In [None]:
def export_log():
    """Write ``str(session.context)`` to ``LOG_DIR/<YYYY-mm-dd-HH-MM-SS>.txt``.

    Creates ``LOG_DIR`` if it does not exist.
    """
    import datetime
    filename = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S.txt")
    os.makedirs(LOG_DIR, exist_ok=True)
    # The log can contain non-ASCII text; write UTF-8 explicitly rather than
    # relying on the locale-dependent default encoding.
    with open(os.path.join(LOG_DIR, filename), 'w', encoding='utf-8') as f:
        # One write; writelines() on a str iterates it character by character.
        f.write(str(session.context))
        
def time_stamp(moment=None) -> str:
    """Format a timestamp like ``"2024-01-02 Tue. 03:04(UTC)"``.

    Args:
        moment: ``datetime.datetime`` to format; defaults to the current local
            time. A naive value is interpreted as local time. An aware value
            keeps its own time zone.

    Returns:
        ``"YYYY-MM-DD <Weekday abbrev.> HH:MM(<time zone name>)"``.
    """
    import datetime

    if moment is None:
        moment = datetime.datetime.now()
    if moment.tzinfo is None:
        # Attach the local zone so tzname() reports the DST name when DST is
        # in effect (time.tzname[0] is always the standard-time name).
        moment = moment.astimezone()

    weekday = ['Mon.', 'Tue.', 'Wed.', 'Thu.', 'Fri.', 'Sat.', 'Sun.']
    return "{}-{:02}-{:02} {} {:02}:{:02}({})".format(
        moment.year,
        moment.month,
        moment.day,
        weekday[moment.weekday()],
        moment.hour,
        moment.minute,
        moment.tzname(),
    )

In [None]:
def load_sbv2_dict(dict_path: str = "sbv2_dict.json") -> dict:
    """Load the TTS reading dictionary from a JSON file.

    Args:
        dict_path: Path to the JSON file.

    Returns:
        The parsed JSON object, or ``{}`` when ``dict_path`` is not a file.
    """
    import json
    if not os.path.isfile(dict_path):
        return {}
    # The dictionary holds non-ASCII (e.g. Japanese) text; JSON is UTF-8.
    with open(dict_path, "r", encoding="utf-8") as f:
        return json.load(f)

def assistant_speaks(text) -> None:
    """Synthesize ``text`` with ``session.tts_model`` and show a player in ``session.voice_player``.

    Preprocessing steps:
      1. ``mixed2katakana``.
      2. Replace parenthesised spans (ASCII or full-width brackets) with a
         full-width space.
      3. Apply the dictionary from ``load_sbv2_dict()``.

    Does nothing when no TTS model is loaded or the preprocessed text is empty.
    Errors from synthesis or playback are printed to ``session.debug``.
    """
    if session.tts_model is None:
        return

    text = mixed2katakana(text)
    # Parenthesised asides are presumably not meant to be read aloud.
    text = re.sub(r'[(（].*?[)）]', '　', text)
    text = replace_text(text, load_sbv2_dict())
    if not text:
        return

    try:
        from IPython.display import Audio
        with session.debug:
            sr, wav = session.tts_model.infer(
                text,
                length=session.voice_length.value
            )
        audio = Audio(wav, rate=sr, autoplay=False)
        session.voice_player.clear_output(wait=True)
        with session.voice_player:
            display(audio)
    except Exception as e:
        # Exception (not BaseException) so KeyboardInterrupt / SystemExit propagate.
        with session.debug:
            print(e)
            
@disable_uis
def create_voice_action(sender=None):
    """Button handler: read the content of the latest history item aloud.

    Does nothing when the history is empty or the latest item has no content.
    Unexpected errors are printed to ``session.debug`` instead of being
    silently discarded.
    """
    try:
        history = session.context.history()
        if not history:
            return
        text = get(history[-1], "content")
        if not text:
            return
        assistant_speaks(text)
    except Exception as e:
        with session.debug:
            print(e)

In [None]:
def initialize(sender=None) -> None:
    """Reset the chat session and UI.

    Steps:
      1. Clear the history and the input field.
      2. Re-enable the name/preamble/template inputs and all buttons, and hide
         the thinking indicator.
      3. Start a new login timestamp and an empty context window.
      4. Redraw the conversation and replace the voice player with empty audio.
    """
    from IPython.display import Audio

    session.context.reset()
    session.field.value = ''
    for control in (
        session.user_nickname_field,
        session.assistant_name_field,
        session.user_preamble,
        session.template_selector,
    ):
        control.disabled = False
    set_buttons(disabled=False)
    set_guessing_image(show=False)

    session.login_time_stamp = time_stamp()
    session.context_window = 0

    update_display()

    player = session.voice_player
    player.clear_output(wait=True)
    with player:
        display(Audio(b''))

initialize()

def submit_names_action(sender=None) -> None:
    """Fix the user/assistant names for this session and lock the session inputs.

    Blank names fall back to ``DEFAULT_USER_NICKNAME`` / ``DEFAULT_ASSISTANT_NAME``.
    Identical names get ``_1`` / ``_2`` suffixes so the roles stay distinguishable.
    """
    user_field = session.user_nickname_field
    assistant_field = session.assistant_name_field

    if user_field.value == "":
        user_field.value = DEFAULT_USER_NICKNAME
    if assistant_field.value == "":
        assistant_field.value = DEFAULT_ASSISTANT_NAME

    if user_field.value == assistant_field.value:
        base_name = user_field.value
        user_field.value, assistant_field.value = f"{base_name}_1", f"{base_name}_2"

    session.user_nickname = user_field.value
    session.assistant_name = assistant_field.value
    for control in (user_field, assistant_field, session.user_preamble, session.template_selector):
        control.disabled = True

In [None]:
def interrupt_generation() -> None:
    """Log that generation was interrupted.

    llama-cpp-python offers no API to abort a running completion, so this only
    writes a notice to ``session.debug``. Callers stop consuming the stream
    themselves.

    Related PRs:
        https://github.com/abetlen/llama-cpp-python/pull/733
        https://github.com/abetlen/llama-cpp-python/issues/599
    """
    debug_area = session.debug
    with debug_area:
        print("Generation interrupted.")

def predict_stream(additional_stop_tokens: "list[str] | None" = None):
    """Stream a completion for the current prompt, yielding the accumulated text.

    Args:
        additional_stop_tokens: Extra stop strings passed to the model;
            ``None`` (the default) means no extra stop strings.

    Yields:
        The stripped text generated so far, after each token. Once the output
        contains a complete fenced tool call (```<tool name> ... ```), the text
        up to and including that fence is yielded and the stream stops. A single
        reply therefore carries at most one tool call.
    """
    stop = [] if additional_stop_tokens is None else list(additional_stop_tokens)
    prompt = prompt_builder.build()
    generation_params = session.generation_params

    streamer = session.model.stream(
        input=prompt,
        temperature=generation_params.temperature.value,
        top_p=generation_params.top_p.value,
        top_k=generation_params.top_k.value,
        frequency_penalty=generation_params.frequency_penalty.value,
        presence_penalty=generation_params.presence_penalty.value,
        repeat_penalty=generation_params.repeat_penalty.value,
        max_tokens=MAX_GENERATION_TOKENS,
        stop=stop,
    )

    # Escape tool names so regex metacharacters in them match literally.
    # With no tools, "({tools})" would become "()" and match any fenced block,
    # so tool detection is skipped entirely in that case.
    tool_names = [re.escape(tool.name) for tool in session.tools]
    pattern = (
        re.compile(r'(.*?```\n?({tools}).*?```)'.format(tools='|'.join(tool_names)), re.DOTALL)
        if tool_names else None
    )

    output = ''
    for token in streamer:
        output += token
        match = pattern.search(output) if pattern is not None else None
        if match:
            yield match.group(1)
            interrupt_generation()
            break
        yield output.strip()
        
        

def predict_stream_with_display(additional_stop_tokens: "list[str] | None" = None) -> str:
    """Stream a reply and redraw the chat after every chunk.

    Each partial reply is pushed to the context temporarily, rendered, and
    popped again. The final (possibly partial) reply is rendered once more the
    same way. The caller decides whether to push it permanently.

    A ``KeyboardInterrupt`` stops generation via ``interrupt_generation()``.
    Any other exception is printed to ``session.debug``. In both cases the text
    produced so far is returned.

    Args:
        additional_stop_tokens: Extra stop strings; ``None`` means none.

    Returns:
        The generated text (empty string if nothing was produced).
    """
    stop_tokens = [] if additional_stop_tokens is None else list(additional_stop_tokens)
    output = ''  # Bound up front so a failure before the first chunk cannot leave it unset.

    def show_provisional(text: str) -> None:
        # Push, render, and always pop, so an interrupt during rendering
        # cannot leave a stray provisional message in the context.
        session.context.push_message(session.assistant_name, text)
        try:
            update_display()
        finally:
            session.context.force_pop_front()

    try:
        with session.debug:
            for output in predict_stream(stop_tokens):
                show_provisional(output)
    except KeyboardInterrupt:
        interrupt_generation()
    except Exception as e:
        with session.debug:
            print(f"Generation failed: {e!r}")

    show_provisional(output)
    return output

In [None]:
def shift_context_by_item(n_items: int = 1) -> None:
    """Evict the oldest ``n_items`` in-window items from the KV cache, keeping the system prompt.

    Must be called BEFORE ``session.context_window`` is decremented, because
    the items to evict are taken from the current window.

    Implementation of StreamingLLM in oobabooga/text-generation-webui:
    https://github.com/oobabooga/text-generation-webui/blob/main/modules/cache_utils.py#L24
    In that implementation, the kept prefix (= n_sys_ppt_tokens) corresponds to
    the "Attention Sinks" of StreamingLLM.

    Args:
        n_items (int, default 1): Number of items (NOT tokens) to evict. It is
            clamped to the number of items in the window.
    Returns:
        None
    """
    from lib.infrastructure import kv_cache_seq_ltrim

    history = session.context.history()
    # Slice from len - window explicitly: history[-0:] would select the whole
    # history instead of nothing when the window is empty.
    window_items = history[max(len(history) - session.context_window, 0):]

    # The system prompt (+1 for the BOS token) is constant for this call, so
    # tokenize it once instead of once per evicted item.
    n_sys_ppt_tokens = session.model.token_count(prompt_builder.render_instruction()) + 1

    for oldest_item in window_items[:n_items]:
        # NOTE(review): tokenizing an item on its own assumes its token count
        # matches its span inside the full prompt; confirm for each template.
        n_oldest_ctx_tokens = session.model.token_count(prompt_builder.render_item(oldest_item))

        # The same n_keep is used on every iteration. This presumes that
        # ltrim moves the remaining cells down to position n_sys_ppt_tokens.
        kv_cache_seq_ltrim(
            model=session.model.llama,
            n_keep=n_sys_ppt_tokens,
            n_discard=n_oldest_ctx_tokens,
        )

        with session.debug:
            print(f"{n_oldest_ctx_tokens} tokens discarded from position {n_sys_ppt_tokens}.")
            print(f"discarded: {prompt_builder.render_item(oldest_item)}")
    

def push_context(item, auto_shift_kv: bool) -> None:
    """Append ``item`` to the conversation and, optionally, keep the prompt within budget.

    With ``auto_shift_kv``, the context window grows by one item. While the
    built prompt has at least ``n_ctx - MAX_GENERATION_TOKENS`` tokens, the
    oldest in-window item is evicted from the KV cache and the window shrinks.
    The eviction stops early if the window becomes empty.

    Args:
        item: New item to push.
        auto_shift_kv (bool): If true, shift the KV cache and the window as needed.
    Returns:
        None
    """
    session.context.push(item)

    if not auto_shift_kv:
        return

    session.context_window += 1

    # Prompt budget that still leaves room for a full generation. The while
    # condition alone decides whether shifting is needed, so the whole prompt
    # is tokenized only once in the common no-overflow case.
    budget = session.n_ctx - MAX_GENERATION_TOKENS
    while session.model.token_count(prompt_builder.build()) >= budget:
        # Order matters: shift_context_by_item reads context_window internally,
        # so do not decrement context_window before calling it.
        shift_context_by_item(1)
        session.context_window -= 1
        if session.context_window == 0:
            break



def retrieve_latest_input(sender=None):
    """Undo the most recent context item.

    The popped item goes back into the input field if it was a user message.
    The context window shrinks by one but never goes below zero. With an empty
    context, the input field is cleared instead.

    Scheme for context management when retrieving item:
              |        |
    [0, 1, 2, 3, 4, 5, 6]
              |     | <- shrink context window
    [0, 1, 2, 3, 4, 5]
    """
    session.context_window = max(0, session.context_window - 1)

    if len(session.context) > 0:
        popped = session.context.force_pop_front(stringize=False)
        if popped['role'] == "User":
            session.field.value = popped['content']
    else:
        session.field.value = ''

    update_display()

In [None]:
def python_chain(python_code) -> None:
    """Execute tool-call Python code and push its result into the context.

    Empty code is rejected with a ``PYTHON_RUNTIME_NAME`` message. Otherwise
    the code runs in ``session.py`` (presumably keeping earlier locals, per
    ``keep_locals=True``). Its text output (or a placeholder when there is
    none) and image output are pushed as a ``PYTHON_RUNTIME_NAME`` item.
    """
    if python_code != "":
        session.py.unset(keep_locals=True)
        session.py.run(python_code)
        _, code_output, image_output = session.py.result()
        item = {
            "role": PYTHON_RUNTIME_NAME,
            "content": "",
            "code_output": code_output.strip() if code_output else "Empty stdout/stderr.",
            "image_output": image_output,
        }
    else:
        item = {
            "role": PYTHON_RUNTIME_NAME,
            "content": "Empty code is not allowed.",
        }

    push_context(item, auto_shift_kv=True)
    update_display()



def search_chain(search_query: str) -> None:
    """Run a web search for ``search_query`` and push the results into the context.

    An empty query is rejected with a ``SEARCH_AGENT_NAME`` message. Otherwise
    the three most relevant chunks from 20 DuckDuckGo results are formatted as
    numbered documents. They are pushed together with the de-duplicated
    reference URLs, kept in first-seen order.
    """
    # Importing lib.rag may print model-download dialogs on first use;
    # route them to the debug area.
    with session.debug:
        from lib.rag import pick_relevant_web_documents, MultilingualE5Small

    if search_query == "":
        push_context({
            "role": SEARCH_AGENT_NAME,
            "content": "Empty search query is not allowed.",
        }, auto_shift_kv=True)
        update_display()
        return

    documents = pick_relevant_web_documents(
        search_query,
        embedding=MultilingualE5Small(),
        engine="duckduckgo",
        n_relevant_chunks=3,
        n_search_results=20,
    )

    if documents:
        search_result = "\n\n".join(
            f"# Document Num: {i}\n"
            f"# Document Title: {doc['title']}\n"
            f"# Document URL: {doc['url']}\n"
            f"# Document Content: {doc['content']}"
            for i, doc in enumerate(documents, start=1)
        ).rstrip()
    else:
        search_result = "No result hits."

    # De-duplicate while preserving first-seen order. set() order is arbitrary
    # and can change between runs because of string hash randomization.
    referred_urls = list(dict.fromkeys(doc["url"] for doc in documents))

    push_context({
        "role": SEARCH_AGENT_NAME,
        "content": "",
        "references": referred_urls,
        "search_result": search_result,
    }, auto_shift_kv=True)
    update_display()

In [None]:
# Main loop.
@disable_uis
def main(sender=None) -> None:
    """Send-button handler: run one user -> assistant (-> tool -> assistant) turn.

    Steps:
      1. If no model is loaded, push a notice and stop.
      2. Read ``session.field``. Ignore the input if it is empty. Also ignore it
         if it is at least ``MAX_USER_TOKENS`` tokens long; that case is
         reported to ``session.debug``.
      3. Lock the session inputs, push the user message and stream the
         assistant reply.
      4. If the reply contains a fenced tool call (```python / ```google):
         push the reply with the extracted call, run the tool, and stream a
         follow-up reply to the tool result.
      5. Export the conversation log.

    Raises:
        NotImplementedError: the reply invokes a tool other than python/google.
    """
    from lib.utils import reformat_python_code

    if session.model is None:
        push_context({
            "role": session.assistant_name,
            "content": "Model is not loaded yet. Please select and load a local model before inference."
        }, auto_shift_kv=False)
        update_display()
        return

    # Register user's input. The cheap emptiness check runs before tokenizing.
    user_message = session.field.value
    if user_message == "":
        return
    if session.model.token_count(user_message) >= MAX_USER_TOKENS:
        with session.debug:
            print(f"Input is too long (limit: {MAX_USER_TOKENS} tokens).")
        return
    submit_names_action()

    session.field.value = ''
    push_context({
        "role": "User",
        "content": user_message
    }, auto_shift_kv=True)
    update_display()

    output = predict_stream_with_display()

    # Parse tool. Names are regex-escaped. With no tools, nothing is parsed
    # (an empty alternation would match any fenced block).
    tool_names = [tool.name for tool in session.tools]
    tool_match = None
    if tool_names:
        regex = re.compile(
            r'```\n?({tools})(.*)\n?```'.format(tools='|'.join(map(re.escape, tool_names))),
            re.DOTALL,
        )
        tool_match = regex.search(output)

    if tool_match is None:
        push_context({
            "role": session.assistant_name,
            "content": output,
        }, auto_shift_kv=True)
        update_display()
    else:
        stripped = regex.sub('', output).rstrip()
        tool_type, tool_input = tool_match.group(1), tool_match.group(2).strip()

        if tool_type == "python":
            reformatted = reformat_python_code(tool_input)
            push_context({
                "role": session.assistant_name,
                "content": stripped,
                "code": reformatted,
            }, auto_shift_kv=True)
            update_display()
            python_chain(reformatted)
        elif tool_type == "google":
            push_context({
                "role": session.assistant_name,
                "content": stripped,
                "search_query": tool_input,
            }, auto_shift_kv=True)
            update_display()
            search_chain(tool_input)
        else:
            raise NotImplementedError(f"Tool {tool_type} is not implemented yet.")

        # Add reaction to the tool results; stop before the model opens another tool call.
        reaction = predict_stream_with_display(
            additional_stop_tokens=[f"```{name}" for name in tool_names]
        )
        push_context({
            "role": session.assistant_name,
            "content": reaction
        }, auto_shift_kv=True)
        update_display()

    export_log()

In [None]:
# Define UI actions.
for _button, _handler in (
    (session.button, main),
    (session.reset_button, initialize),
    (session.retrieve, retrieve_latest_input),
    (session.create_voice, create_voice_action),
    (session.load_button, reload_model),
    (session.unload_button, unload_model),
):
    _button.on_click(_handler)

In [None]:
def show_guis(show_debug: bool = False) -> None:
    """Display the complete chat application UI in the notebook.

    Sections, top to bottom:
      - the introduction text;
      - session options (template, streamingllm control, names, preamble);
      - the llama-cpp-python GGUF loader and the chat output area;
      - generation parameters with the input row;
      - the voice-synthesis controls.

    Args:
        show_debug: If True, the ``session.debug`` output widget is appended at
            the bottom; otherwise an empty HTML placeholder takes its place.
    """
    intro_html = f"""<h1>Integrative LLM Chat UI for Jupyter</h1>
This is an LLM-powered chat interface integrated with voice synthesis model, web search-based RAG, and python environment.</br>
For better results, I strongly recommend you to select a model large enough or trained for tool use.</br>
If you wanna use unsupported prompt template, define <code>PromptBuilder</code> class yourself.</br>
<br>
<code>./{GGUF_DIR}</code>: Directory to put GGUF models.</br>
<code>./{AGENT_WORKING_DIR}</code>: Working directory when executing Python code.</br>
<code>./{TTS_ASSET_ROOT}</code>: Directory to put TTS models.<br> 
(The subdirectory name containing the model should match the *.safetensors file name)</br>
</br>
"""
    row = widgets.HBox
    params = session.generation_params
    loader = session.llama_cpp_options

    sections = [
        HTML(intro_html),

        HTML("<h3>Session Options</h3>"),
        row([session.template_selector, session.streamingllm]),
        row([session.user_nickname_field, session.assistant_name_field]),
        session.user_preamble,

        HTML("<h3>llama-cpp-python GGUF Loader</h3>"),
        row([loader.define_n_ctx, loader.n_gpu_layers, loader.flash_attention]),
        row([loader.gguf_selector, session.load_button, session.unload_button]),
        session.out,

        HTML("<h3>Generation Params</h3>"),
        row([params.temperature, params.top_k]),
        row([params.top_p, params.frequency_penalty]),
        row([params.presence_penalty, params.repeat_penalty]),
        row([session.field, session.button, session.guessing_image]),
        row([session.retrieve, session.reset_button]),

        HTML("<h3>Voice Synthesis (last item)</h3>"),
        row([session.dropdown, session.voice_length]),
        session.create_voice,
        session.voice_player,
        session.debug if show_debug else HTML(""),
    ]
    display(*sections)

ignore_warnings()
show_guis(show_debug=False)