## Prerequisite
1. Jupyter must be installed.<br>
<code>\$ pip install jupyterlab</code><br>
2. Enable ipywidgets with one of the following commands:<br>
    2-1. jupyter-lab<br>
<code>\$ jupyter labextension install @jupyter-widgets/jupyterlab-manager</code><br>
    2-2. jupyter notebook / Google Colab<br>
<code>\$ jupyter nbextension enable --py widgetsnbextension</code>
3. Install dependencies.<br>
<code>\$ pip install -r requirements.txt</code><br>
4. *(Optional)* You may want to raise the websocket message size limit (10 MB by default) to send large files or play longer sounds.<br>
<code>\$ jupyter notebook --generate-config</code><br>
After running the command, edit the following line in the generated config file:<br>
<code>c.ServerApp.tornado_settings = {"websocket_max_message_size":100\*1024\*1024} # Your preference</code><br>
<br>
<a href=https://github.com/yamikumo-DSD/chat_cmr/tree/main>GitHub link for the repository</a>
## ⇩⇩ RUN ALL CELLS BELOW ⇩⇩
Run cells all the way down to the bottom!<br>
<font color=#FFFF00>**CAUTION!**</font> The first run will take a while because a model has to be downloaded.

In [1]:
from warnings import simplefilter
def ignore_warnings() -> None:
    """Silence all Python warnings for the rest of the notebook session."""
    simplefilter('ignore')

ignore_warnings()

import os
import re
import asyncio
import dataclasses
from style_bert_vits2.tts_model import TTSModel

from IPython.display import clear_output, HTML
from ipywidgets import widgets

import lib.markdown as md
from lib.infrastructure import ForgetableContext, LlamaCpp
from lib.uis import wait_for_change, Switchable
from lib.utils import now, mixed2katakana, replace_text

import session_states as session
import agent_tools as tools
from global_settings import *
from prompt_builders import *

In [2]:
def make_default_dirs() -> None:
    """Create the working directories this notebook relies on, if missing.

    Covers GGUF_DIR, AGENT_WORKING_DIR, LOG_DIR, TTS_ASSET_ROOT and CACHE_DIR.
    Directories that already exist are left untouched.
    """
    default_dirs = [GGUF_DIR, AGENT_WORKING_DIR, LOG_DIR, TTS_ASSET_ROOT, CACHE_DIR]
    for directory in default_dirs:
        if os.path.isdir(directory):
            continue
        print(f"Directory \"{directory}\" does not exist. Creating.")
        try:
            # exist_ok tolerates the directory appearing between isdir() and here.
            os.makedirs(directory, exist_ok=True)
        except FileExistsError as e:
            # Still raised when a non-directory file occupies the path.
            print(e)

make_default_dirs()

In [3]:
def reload_instruction(instruction_file_path: str) -> str:
    """Read a system-prompt template file and return its whole text.

    The file is decoded as UTF-8 explicitly so templates containing
    non-ASCII (e.g. Japanese) text load identically regardless of the
    operating system's locale encoding.
    """
    with open(instruction_file_path, 'r', encoding='utf-8') as f:
        return f.read()

# Initial prompt builder (Llama-3 Instruct); change_prompt_builder swaps it when the template selector changes.
prompt_builder = Llama3PromptBuilder(reload_instruction("system_prompt_template/l3_sys_ppt_gen_16.txt"), session.context)

def change_prompt_builder(args):
    """Rebuild the global ``prompt_builder`` for the newly selected chat template.

    Registered below as an observer of ``session.template_selector``'s "value";
    ``args`` is the change event and is not used.  Only the builder for the
    selected template is constructed, so only that template file is read and
    a missing file for some other template does not break the switch.

    Raises:
        KeyError: if the selected label has no registered prompt builder.
    """
    global prompt_builder
    # Selector label -> (prompt builder class, system prompt template path).
    builder_specs = {
        "Command R": (CommandRPromptBuilder, "./system_prompt_template/cmdr_sys_ppt_gen_16.txt"),
        "Llama-3 Instruct": (Llama3PromptBuilder, "./system_prompt_template/l3_sys_ppt_gen_16.txt"),
        "ChatML": (ChatMLPromptBuilder, "./system_prompt_template/chatml_sys_ppt_gen_16.txt"),
        "Llama-2 Instruct": (Llama2PromptBuilder, "./system_prompt_template/l2_sys_ppt_gen_16.txt"),
        "Llama-2 Instruct JA": (JaCommMSPromptBuilder, "./system_prompt_template/ja_community_ms_sys_ppt_gen_16.txt"),
        "Gemma 2 Instruct": (Gemma2Instruct, "./system_prompt_template/gemma_2_it_sys_ppt_gen_16.txt"),
    }
    builder_cls, template_path = builder_specs[session.template_selector.value]
    prompt_builder = builder_cls(reload_instruction(template_path), session.context)

session.template_selector.observe(change_prompt_builder, "value")

In [4]:
# Updates GUIs.
def get_bigger(args):
    """Resize the input field to one row per line of text, with a minimum of 2 rows.

    Observer for ``session.field``'s value; ``args`` (the change event) is unused.
    """
    line_breaks = session.field.value.count('\n')
    session.field.rows = max(line_breaks + 1, 2)

session.field.observe(get_bigger, 'value')


def fix_max_gen_tokens(args) -> None:
    """Clamp the max-generation-tokens option to at most ``session.n_ctx // 2``.

    Observer for ``define_max_gen_tokens``; ``args`` (the change event) is unused.
    """
    max_tokens_widget = session.llama_cpp_options.define_max_gen_tokens
    upper_bound = session.n_ctx // 2
    max_tokens_widget.value = min(max_tokens_widget.value, upper_bound)

session.llama_cpp_options.define_max_gen_tokens.observe(fix_max_gen_tokens, "value")

def set_guessing_image(show: bool) -> None:
    """Show ``images/guessing.gif`` (show=True) or ``images/empty.png`` at 50x50.

    Creates the ``session.guessing_image`` widget on first use, then loads the
    chosen picture into it.
    """
    image_path = 'images/guessing.gif' if show else 'images/empty.png'
    with open(image_path, 'rb') as image_file:
        picture = image_file.read()

    if not session.guessing_image:
        session.guessing_image = widgets.Image(value=picture, width=50, height=50)

    session.guessing_image.value = picture
    session.guessing_image.width = 50
    session.guessing_image.height = 50

set_guessing_image(False)


def set_buttons(disabled: bool) -> None:
    """Apply the same ``disabled`` flag to every widget in ``session.buttons``."""
    for button in session.buttons:
        button.disabled = disabled

def disable_uis(func):
    """Decorator that locks the UI while ``func`` runs.

    Before the call, all buttons are disabled and the 'guessing' image is
    shown.  Afterwards both are restored in a ``finally`` block, so an
    exception raised by ``func`` cannot leave the UI permanently locked.
    The wrapped function's return value is passed through.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        set_buttons(disabled=True)
        set_guessing_image(True)
        try:
            return func(*args, **kwargs)
        finally:
            set_guessing_image(False)
            set_buttons(disabled=False)
    return wrapper

In [5]:
def refresh_gguf_list(sender=None):
    """Rescan GGUF_DIR and refresh the model selector.

    Stores the ``*.gguf`` file names, sorted by name, in ``session.ggufs``.
    Selects the first one, or shows a single placeholder entry when there
    are none.  ``sender`` is the clicked button and is unused.
    """
    placeholder = "No gguf in dir"
    selector = session.llama_cpp_options.gguf_selector
    # os.listdir order is arbitrary; sort so the list and its default are stable.
    session.ggufs = sorted(item for item in os.listdir(GGUF_DIR) if item.endswith(".gguf"))
    ggufs = session.ggufs

    selector.options = ggufs if ggufs else [placeholder]
    selector.value = ggufs[0] if ggufs else placeholder

    
def _load_model(gguf_path: str) -> None:
    """Load the GGUF model at ``gguf_path`` into ``session.model``.

    Settings come from ``session.llama_cpp_options``: context size, max
    generation tokens, GPU layers, flash attention and KV quantization.
    ``session.n_ctx``, ``session.max_gen_tokens``, ``session.active_gguf``
    and the GGUF viewer are updated only after a model object was actually
    created.  Invalid settings and ``ValueError``s from ``LlamaCpp`` are
    printed to ``session.debug`` instead of being raised.
    """
    options = session.llama_cpp_options
    n_ctx = options.define_n_ctx.value
    max_gen_tokens = options.define_max_gen_tokens.value

    # Validate the values that are about to be applied (not the ones stored
    # in the session for the previously loaded model).
    if max_gen_tokens >= n_ctx:
        with session.debug:
            print("max_gen_tokens must be smaller than n_ctx")
        return

    model = None
    try:
        with session.debug:
            model = LlamaCpp(
                model_path=gguf_path,
                n_gpu_layers=options.n_gpu_layers.value,
                n_batch=1024,
                n_ctx=n_ctx,
                use_mlock=True,
                flash_attn=options.flash_attention.value,
                verbose=True,
                embedding=False,
                type_k=8 if options.quantize_kv.value else None, # Default is FP16
                type_v=8 if options.quantize_kv.value else None, # Default is FP16
            )
    except ValueError as e:
        with session.debug:
            print(e)

    # NOTE(review): an ipywidgets Output used as a context manager may display
    # and suppress an exception itself; checking `model` covers that path too.
    if model is None:
        return

    session.model = model
    session.n_ctx = n_ctx
    session.max_gen_tokens = max_gen_tokens
    session.active_gguf = os.path.basename(gguf_path)
    session.set_gguf_viewer(session.active_gguf)
    
def _unload_model() -> None:
    """Release the current model, if any, and clear model-related session state.

    On llama-cpp-python >= 0.2.78 the underlying model is closed explicitly
    (``close()``, introduced in abetlen/llama-cpp-python#1513).  On older
    versions, dropping the reference and running the garbage collector is
    all that is done.
    """
    import gc
    import llama_cpp
    from packaging.version import Version

    if session.model is None:
        return

    # https://github.com/abetlen/llama-cpp-python/pull/1513
    if Version(llama_cpp.__version__) >= Version("0.2.78"):
        session.model.llama.close()

    session.model = None
    session.active_gguf = ""
    session.unset_gguf_viewer()
    gc.collect()
    
@disable_uis
def reload_model(sender=None) -> None:
    """Button handler: swap in the model chosen in the GGUF selector, then reset the chat."""
    _unload_model()
    gguf_path = os.path.join(GGUF_DIR, session.llama_cpp_options.gguf_selector.value)
    _load_model(gguf_path)
    initialize()
    
@disable_uis
def unload_model(sender=None) -> None:
    """Button handler: free the loaded model and reset the chat view."""
    _unload_model()
    initialize()

## Download default model
### Remove the cell below to disable the default model download on the first run.
This step is skipped automatically if the directory <code>./ggufs</code> already contains a <code>.gguf</code> model.<br>

In [6]:
def download_default_model(url: str) -> None:
    if len(session.ggufs) > 0: return
        
    print(f"Directory \"{GGUF_DIR}\" is empty. Downloading default model.")
    MAX_CONTINUE = 10
    for i in range(MAX_CONTINUE):
        answer = input("Do you want to skip? [yes(y)/no(n)]" if i == 0 else "Could not recognize your answer. Do you want to skip? [yes(y)/no(n)]")
        if answer in ["yes", "Yes", "YES", "y", "Y"]: 
            return
        elif answer in ["no", "No", "NO", "n", "N"]: 
            !wget -P {GGUF_DIR} {url}
            return

download_default_model("https://huggingface.co/internlm/internlm2_5-7b-chat-gguf/resolve/main/internlm2_5-7b-chat-q6_k.gguf")
refresh_gguf_list()

In [7]:
def format_to_html(context) -> str:
    """Render the chat history held by ``context`` as HTML chat bubbles.

    Each item of ``context.history()`` is a dict with a "role" key.  It is
    drawn as a small name badge (the first letter of the speaker's name)
    followed by a coloured speech bubble:

    * "User": ``content`` converted from markdown.
    * ``session.assistant_name``: markdown ``content``, or, for tool calls,
      the search query or syntax-highlighted Python code.
    * ``TOOL_AGENT_NAME``: tool results (search references, Python stdout
      and image).
    * ``FILE_UPLOADER_NAME``: an uploaded image.

    Items with any other role are skipped.  A tool call or result for a tool
    without a dedicated renderer gets a generic bubble (tool name plus
    escaped payload) instead of breaking the whole render.  Plain-text
    payloads that are not markdown-converted are HTML-escaped, so e.g.
    ``<class 'int'>`` in Python output is shown literally rather than
    swallowed as a tag.
    """
    import base64
    import html

    from pygments import highlight
    from pygments.formatters import HtmlFormatter
    from pygments.lexers import Python3Lexer

    USER_MESSAGE_BG_COLOR = "#BBFFBB"
    AI_MESSAGE_BG_COLOR = "#FFEEBB"

    text_template = """<div style="background-color: {color}; min-width: 30px; width: fit-content; word-wrap: break-word; color: black; padding: 10px; border-radius: 0 20px 20px 20px;">{content}</div>"""
    name_template = """<div style="background-color: #999999; color: white; width: 15px; height: 15px; padding: 10px; border-radius: 5px 5px 0 0; text-align: center;">{initial}</div>"""
    white_box_template = '<div style="background-color: #FFFFFF; color: black;">{content}</div>'

    def header(text: str) -> str:
        """Grey title bar shown at the top of a tool bubble."""
        return f'<div style="background-color: #999999; color: black; border-radius: 5px 5px 0 0">{text}</div>'

    def bubble(initial: str, content: str, color: str) -> str:
        """Name badge followed by a speech bubble holding ``content`` (already HTML)."""
        return name_template.format(initial=initial) + text_template.format(content=content, color=color)

    def embed_image_to_tag(image_binary, width=None, height=None, alt=None) -> str:
        """Inline ``image_binary`` as a base64 data-URI ``<img>`` tag."""
        encoded_image = base64.b64encode(image_binary).decode('utf-8')
        html_image_tag = f'<img src="data:image/jpeg;base64,{encoded_image}" '
        if width:
            html_image_tag += f'width="{width}" '
        if height:
            html_image_tag += f'height="{height}" '
        if alt:
            # Captions are free text; escape quotes so they cannot break the attribute.
            html_image_tag += f'alt="{html.escape(str(alt))}" '
        return html_image_tag + '/>'

    def render_tool_fallback(initial: str, tool_name, payload, color: str) -> str:
        """Generic bubble for tools (or actions) without a dedicated renderer."""
        body = header(html.escape(str(tool_name)))
        if payload:
            body += '<pre><code>' + html.escape(str(payload)) + '</code></pre>'
        return bubble(initial, body, color)

    def render_user_item(item) -> str:
        return bubble(session.user_nickname[0], md.convert(item["content"]), USER_MESSAGE_BG_COLOR)

    def render_assistant_item(item) -> str:
        initial = session.assistant_name[0]
        tool = item.get("tool")
        if not tool:
            return bubble(initial, md.convert(item["content"]), AI_MESSAGE_BG_COLOR)

        tool_name = tool.get("name")
        tool_input = tool.get("input")
        is_call = tool.get("action") == "call"

        if is_call and tool_name == tools.web_search.name:
            body = header("Search") + white_box_template.format(content=html.escape(str(tool_input)))
        elif is_call and tool_name == tools.exec_python.name:
            body = header("Python") + highlight(tool_input, Python3Lexer(), HtmlFormatter())
        else:
            return render_tool_fallback(initial, tool_name, tool_input, AI_MESSAGE_BG_COLOR)
        return bubble(initial, body, AI_MESSAGE_BG_COLOR)

    def render_tool_agent_item(item) -> str:
        initial = session.assistant_name[0]
        tool = item.get("tool") or {}
        tool_name = tool.get("name")
        tool_output = tool.get("output")

        if tool_name == tools.web_search.name:
            references = tool_output.get("references") or []
            links = '<br>'.join(
                f'✅<a href="{html.escape(url)}">︎{html.escape(url[:50])}...</a>' for url in references
            )
            body = header("Documents matched the query") + white_box_template.format(content=links)
        elif tool_name == tools.exec_python.name:
            stdout = tool_output.get("stdout")
            image = tool_output.get("image")
            body = header('Output')
            if stdout:
                body += '<pre><code>' + html.escape(str(stdout)) + '</code></pre>'
            # Gate on the image itself (embedding a missing image would fail).
            if image:
                body += embed_image_to_tag(image_binary=image) + '<br>'
        else:
            return render_tool_fallback(initial, tool_name, tool_output, AI_MESSAGE_BG_COLOR)
        return bubble(initial, body, AI_MESSAGE_BG_COLOR)

    def render_file_uploader_item(item) -> str:
        image_output = item.get('image_output')
        body = ""
        if image_output:
            body += embed_image_to_tag(image_binary=image_output, width=300, alt=item.get('caption')) + '<br>'
        return bubble(session.user_nickname[0], body, USER_MESSAGE_BG_COLOR)

    messages: list[str] = []
    for message in context.history():
        role = message['role']
        # if/elif (not a dict) keeps "User" taking precedence over a same-named assistant.
        if role == "User":
            messages.append(render_user_item(message))
        elif role == session.assistant_name:
            messages.append(render_assistant_item(message))
        elif role == TOOL_AGENT_NAME:
            messages.append(render_tool_agent_item(message))
        elif role == FILE_UPLOADER_NAME:
            messages.append(render_file_uploader_item(message))

    return ''.join(messages)
    
def print_context():
    """Re-render the whole chat history as a fixed-size, scrollable HTML panel.

    Clears the current output (``wait=True`` defers the clear until new
    output arrives, avoiding flicker) and displays an HTML page whose
    ``#wrapper`` box is HEIGHT x WIDTH pixels with a vertical scrollbar.
    The message markup comes from ``format_to_html(session.context)``.
    """
    clear_output(wait=True)

    # Size of the chat panel in CSS pixels.
    HEIGHT, WIDTH = 1100, 800
    
    # Literal CSS braces are doubled ({{ }}) because this is an f-string.
    # NOTE(review): `flex-direction: column-reverse` is presumably there to keep
    # the panel scrolled to the newest message — confirm.
    html_text = f"""<!DOCTYPE html>
<html>

<head>
  <meta charset="utf-8" />
  <style>
    #wrapper {{
      display: flex;
      flex-direction: column-reverse;
      height: {HEIGHT}px;
      width: {WIDTH}px;
      overflow-y: scroll;
    }}

    /* Custom Scrollbar CSS */
    #wrapper::-webkit-scrollbar {{
      width: 5px;
    }}

    #wrapper::-webkit-scrollbar-track {{
      background: #555555;
    }}

    #wrapper::-webkit-scrollbar-thumb {{
      background: #888;
    }}

    #wrapper::-webkit-scrollbar-thumb:hover {{
      background: #555;
    }}
    
  </style>
</head>

<body>
    <div style="display: flex; align-items: flex-end;">
      <div id="wrapper">
        <div id="contents">
    {format_to_html(session.context)}
        </div>
      </div>
    </div>
</body>

</html>
"""

    # NOTE(review): `display` is not imported in this file; presumably the
    # IPython builtin available in notebook namespaces — confirm.
    display(HTML(html_text))

def update_display() -> None:
    """Redraw the chat panel inside the ``session.out`` output widget."""
    chat_output = session.out
    with chat_output:
        print_context()

In [8]:
# TTS model loader

def is_tts_model_dir(path):
    """Return True if ``path`` is a directory holding a complete TTS model.

    A complete model directory contains ``<dirname>.safetensors``,
    ``config.json`` and ``style_vectors.npy``.
    """
    if not os.path.isdir(path):
        return False
    required = {
        f"{os.path.basename(path)}.safetensors",
        "config.json",
        "style_vectors.npy",
    }
    return required.issubset(os.listdir(path))



@session.debug.capture()
def load_tts_models(model_path):
    """Load the Style-BERT-VITS2 voice model stored in directory ``model_path``.

    The directory must contain ``<dirname>.safetensors``, ``config.json`` and
    ``style_vectors.npy``.  The Japanese BERT model and tokenizer are loaded
    first on every call, and the voice model is placed on the CPU.  The
    decorator captures this function's output into ``session.debug``.
    """
    from style_bert_vits2.nlp import bert_models
    from style_bert_vits2.constants import Languages
    import gc
    gc.collect()

    bert_name = "ku-nlp/deberta-v2-large-japanese-char-wwm"
    bert_models.load_model(Languages.JP, bert_name)
    bert_models.load_tokenizer(Languages.JP, bert_name)

    model_name = os.path.basename(model_path)
    return TTSModel(
        model_path=os.path.join(model_path, f"{model_name}.safetensors"),
        config_path=os.path.join(model_path, "config.json"),
        style_vec_path=os.path.join(model_path, "style_vectors.npy"),
        device="cpu",
    )


async def capture_model_selection_change():
    """Loop forever: whenever the TTS dropdown changes, load the newly selected voice."""
    while True:
        model_name = await wait_for_change(session.dropdown, "value")
        model_dir = os.path.join(TTS_ASSET_ROOT, model_name)
        session.tts_model = load_tts_models(model_dir)
        session.debug.append_stdout(f"tts_model({model_name}) loaded.\n")

# Installed TTS voices: subdirectories of TTS_ASSET_ROOT that hold a complete model.
tts_model_names = [item for item in os.listdir(TTS_ASSET_ROOT) if is_tts_model_dir(os.path.join(TTS_ASSET_ROOT, item))]
if tts_model_names:
    # Load the first voice up front so speech works without touching the dropdown.
    session.tts_model = load_tts_models(os.path.join(TTS_ASSET_ROOT, tts_model_names[0]))

# Build the kwargs once so the layout applies whether or not any voice exists.
# A conditional expression around two constructor calls would attach it to one branch only.
_tts_dropdown_kwargs = dict(
    description="TTS model",
    options=tts_model_names,
    layout=widgets.Layout(max_width="300px", width="100%"),
)
if tts_model_names:
    _tts_dropdown_kwargs["value"] = tts_model_names[0]
session.dropdown = widgets.Dropdown(**_tts_dropdown_kwargs)

# Reload the voice whenever the dropdown changes (the trailing ';' suppresses the Task repr in the notebook).
session.loop.create_task(capture_model_selection_change());

In [9]:
def time_stamp() -> str:
    """Return the current local time formatted like ``2024-07-05 Fri. 14:03(JST)``.

    Every weekday abbreviation ends with a period.  The zone name is the one
    in effect at this instant, so it is daylight-saving aware.
    """
    import datetime

    WEEKDAYS = ('Mon.', 'Tue.', 'Wed.', 'Thu.', 'Fri.', 'Sat.', 'Sun.')
    # astimezone() attaches the local zone, so tzname() reflects DST.
    current = datetime.datetime.now().astimezone()
    return "{}-{:02}-{:02} {} {:02}:{:02}({})".format(
        current.year,
        current.month,
        current.day,
        WEEKDAYS[current.weekday()],
        current.hour,
        current.minute,
        current.tzname(),
    )

In [10]:
def load_sbv2_dict(dict_path: str = "sbv2_dict.json") -> dict:
    """Load the TTS text-replacement dictionary from a JSON file.

    Returns an empty dict if ``dict_path`` is not an existing file.  The file
    is decoded as UTF-8 explicitly; it typically holds Japanese readings, and
    the platform default encoding (e.g. cp1252/cp932 on Windows) would
    mis-decode it.
    """
    import json
    if not os.path.isfile(dict_path):
        return {}
    with open(dict_path, "r", encoding="utf-8") as f:
        return json.load(f)


def assistant_speaks(text, autoplay: bool = False):
    """Synthesize ``text`` with the active TTS model and return an IPython ``Audio``.

    Returns None when no TTS model is loaded or when nothing is left to read
    after preprocessing.  Preprocessing has three steps:
    1. convert the text to katakana with ``mixed2katakana``;
    2. replace parenthesised asides (ASCII or full-width brackets) with a
       full-width space;
    3. apply the replacement dictionary from ``load_sbv2_dict()``.

    The waveform is written to CACHE_DIR/temp.wav and re-encoded as
    CACHE_DIR/temp.mp3.  Compressing matters because Jupyter's websocket
    messages are capped at 10MB by default, and long WAV payloads can
    exceed that cap (https://github.com/jupyter/notebook/issues/3468).
    """
    import numpy as np
    from IPython.display import Audio
    from pydub import AudioSegment
    from scipy.io import wavfile

    if session.tts_model is None:
        return

    spoken = mixed2katakana(text)
    spoken = re.sub(r'[(（].*?[)）]', '　', spoken)
    spoken = replace_text(spoken, load_sbv2_dict())
    if not spoken:
        return

    sampling_rate, waveform = session.tts_model.infer(spoken, length=session.voice_length.value)

    wav_path = os.path.join(CACHE_DIR, "temp.wav")
    mp3_path = os.path.join(CACHE_DIR, "temp.mp3")
    wavfile.write(wav_path, sampling_rate, waveform.astype(np.int16))
    AudioSegment.from_wav(wav_path).export(mp3_path, format='mp3')

    return Audio(mp3_path, autoplay=autoplay)



            
@disable_uis
def create_voice_action(sender=None):
    """Button handler: voice the latest chat item in ``session.voice_player``.

    The player is always cleared.  New audio is displayed only when the
    history is non-empty, the latest item has text, and synthesis produced
    something.  Synthesis errors are reported in ``session.debug`` rather
    than silently discarded, and KeyboardInterrupt/SystemExit are no longer
    swallowed.
    """
    history = session.context.history()
    text = history[-1].get("content") if history else None

    audio_widget = None
    if text:
        try:
            with session.debug:
                audio_widget = assistant_speaks(text)
        except Exception as e:
            with session.debug:
                print(f"Voice synthesis failed: {e!r}")

    session.voice_player.clear_output(wait=True)
    if audio_widget is not None:
        with session.voice_player:
            display(audio_widget)

In [11]:
def initialize(sender=None) -> None:
    """Reset the chat and re-enable the session settings.

    Clears the history, the input field and the voice player, and zeroes
    the context window.  Re-enables the name/preamble/template/tool widgets
    and the buttons, hides the 'guessing' image, and redraws the chat.
    Finally it sets ``session.initialized`` so the next message calls
    ``start_session`` again.
    """
    session.context.reset()
    session.field.value = ''
    settings_widgets = (
        session.user_nickname_field,
        session.assistant_name_field,
        session.user_preamble,
        session.template_selector,
    )
    for widget in settings_widgets:
        widget.disabled = False
    set_buttons(disabled=False)
    set_guessing_image(show=False)
    tools.tool_selector.disabled = False

    session.context_window = 0

    update_display()

    from IPython.display import Audio
    session.voice_player.clear_output(wait=True)
    with session.voice_player:
        display(Audio(b''))

    session.initialized = True

initialize()

def start_session(sender=None) -> None:
    """Freeze the session settings when the first message is sent.

    Empty name fields get DEFAULT_USER_NICKNAME / DEFAULT_ASSISTANT_NAME.
    If both names are equal they are suffixed with "_1" and "_2".  The names
    are copied into the session and the settings widgets are disabled.  The
    login time stamp is recorded and ``session.initialized`` is cleared.
    """
    nickname_field = session.user_nickname_field
    assistant_field = session.assistant_name_field

    if nickname_field.value == "":
        nickname_field.value = DEFAULT_USER_NICKNAME
    if assistant_field.value == "":
        assistant_field.value = DEFAULT_ASSISTANT_NAME
    if nickname_field.value == assistant_field.value:
        # Keep the two participants' names distinct.
        base_name = nickname_field.value
        nickname_field.value, assistant_field.value = f"{base_name}_1", f"{base_name}_2"

    session.user_nickname = nickname_field.value
    session.assistant_name = assistant_field.value
    for widget in (nickname_field, assistant_field, session.user_preamble, session.template_selector):
        widget.disabled = True
    session.login_time_stamp = time_stamp()
    tools.tool_selector.disabled = True

    session.initialized = False

In [12]:
def interrupt_generation() -> None:
    """Log that generation was interrupted; nothing is actually aborted.

    Per the linked discussions, llama-cpp-python provides no way to abort an
    in-flight generation:
        https://github.com/abetlen/llama-cpp-python/pull/733
        https://github.com/abetlen/llama-cpp-python/issues/599
    """
    debug_output = session.debug
    with debug_output:
        print("Generation interrupted.")


        
def predict_stream(force_direct_answer: bool = False) -> dict:
    """Stream a grammar-constrained completion for the current prompt.

    This is a generator.  After every streamed token it yields
    ``{"tool": tool, "tool_input": tool_input}`` parsed from the text so far:

    * ``tool`` is the name inside ``<tool>...</tool>``, or None until that
      tag is complete.
    * ``tool_input`` is the stripped text after ``<tool_input>``, with any
      (partial) closing tag removed.

    Grammar selection:

    * ``force_direct_answer``: only ``direct_answer`` may be produced.
    * non-empty ``generation_params.prefix``: ``direct_answer`` must start
      with the prefix.  The prefix is removed from ``tool_input`` unless
      ``append_prefix`` is checked.
    * otherwise: any tool returned by ``session.tools()``.
    """
    from llama_cpp import LlamaGrammar

    params = session.generation_params
    prefix = params.prefix.value

    def gbnf_literal(text: str) -> str:
        """Escape ``text`` for use inside a double-quoted GBNF string literal."""
        return (text.replace('\\', '\\\\').replace('"', '\\"')
                    .replace('\n', '\\n').replace('\r', '\\r'))

    any_string = r'([^\n]|"\n")+'

    tools_gbnf = "(" + "|".join(
        [f'"<tool>{tool.name}</tool>" "<tool_input>"{tool.grammar}"</tool_input>"' for tool in session.tools()]
    ) + ")"
    tool_use_gbnf = f"root ::= {tools_gbnf}"

    direct_answer_gbnf = f'''root ::= "<tool>" name "</tool>" "<tool_input>" input "</tool_input>"
name ::= "direct_answer"
input ::= string
string ::= {any_string}
'''

    direct_answer_with_prefix_gbnf = f'''root ::= "<tool>" name "</tool>" "<tool_input>{gbnf_literal(prefix)}" input "</tool_input>"
name ::= "direct_answer"
input ::= string
string ::= {any_string}
'''
    grammar: LlamaGrammar
    if force_direct_answer:
        grammar = LlamaGrammar.from_string(direct_answer_gbnf)
    elif prefix != "":
        grammar = LlamaGrammar.from_string(direct_answer_with_prefix_gbnf)
    else:
        grammar = LlamaGrammar.from_string(tool_use_gbnf)
    # Only the prefixed grammar injects the prefix, so only then may it be stripped.
    prefix_injected = not force_direct_answer and prefix != ""

    prompt = prompt_builder.build()

    with session.debug:
        print(f'prompt = """{prompt}"""')

    streamer = session.model.stream(
        input=prompt,
        temperature=params.temperature.value,
        top_p=params.top_p.value,
        top_k=params.top_k.value,
        frequency_penalty=params.frequency_penalty.value,
        presence_penalty=params.presence_penalty.value,
        repeat_penalty=params.repeat_penalty.value,
        max_tokens=session.max_gen_tokens,
        grammar=grammar,
        stop=["</tool_input>"]+prompt_builder.stops(),
    )

    # Compiled once per call; tool names are escaped so they match literally.
    tool_tag_pattern = re.compile(
        r"<tool>({tools})</tool>".format(tools='|'.join(re.escape(tool.name) for tool in session.tools()))
    )
    tool_input_pattern = re.compile("<tool_input>(.*)", re.DOTALL)

    def rstrip_seq(text: str, seq: str) -> str:
        """Drop ``seq``, or a partial start of it still being streamed, from the end of ``text``."""
        for i in range(len(seq)):
            incomplete_seq = seq[:len(seq) - i]
            if text.endswith(incomplete_seq):
                return text[:-len(incomplete_seq)]
        return text

    def lstrip_seq(text: str, seq: str) -> str:
        """Drop ``seq`` from the start of ``text``.

        While ``seq`` itself is still being streamed (``text`` is a prefix of
        it), everything is dropped.  Text that merely shares a few leading
        characters with ``seq`` is left intact.
        """
        if text.startswith(seq):
            return text[len(seq):]
        if seq.startswith(text):
            return ""
        return text

    text = ""
    for token in streamer:
        text += token
        tool_match = tool_tag_pattern.search(text)
        tool = tool_match.group(1) if tool_match else None
        input_match = tool_input_pattern.search(text)
        tool_input = input_match.group(1) if input_match else ""
        tool_input = rstrip_seq(tool_input, "</tool_input>")
        if prefix_injected and not params.append_prefix.value:
            tool_input = lstrip_seq(tool_input, prefix)
        yield {"tool": tool, "tool_input": tool_input.strip()}


def predict_stream_with_display(force_direct_answer: bool = False):
    """Run ``predict_stream`` while live-rendering the partial answer.

    Each partial response is pushed onto the context as a temporary preview
    item, rendered, and popped again.  The pop is guaranteed even if
    rendering is interrupted.  After streaming, the final text is shown
    once more the same way.  Ctrl-C (KeyboardInterrupt) stops streaming
    early.

    Returns:
        dict: the last ``{"tool": ..., "tool_input": ...}`` yielded.  If
        nothing was yielded, a ``direct_answer`` with empty input.
    """
    direct_answer_name = tools.direct_answer.name
    tool_call_names = [tool.name for tool in tools.tools() if tool.name != direct_answer_name]
    # Fallback so the cleanup below and the caller always have a response,
    # even when generation yields nothing or is interrupted before the first token.
    response = {"tool": direct_answer_name, "tool_input": ""}
    try:
        with session.debug:
            for response in predict_stream(force_direct_answer):
                if response["tool"] in tool_call_names:
                    preview = {
                        "role": session.assistant_name,
                        "content": "",
                        "tool": {
                            "action": "call",
                            "name": response["tool"],
                            "input": response["tool_input"],
                        }
                    }
                else:
                    preview = {
                        "role": session.assistant_name,
                        "content": response["tool_input"],
                    }
                session.context.push(preview)
                try:
                    update_display()
                finally:
                    # Always remove the temporary preview item, even on Ctrl-C.
                    session.context.force_pop_front()
    except KeyboardInterrupt:
        interrupt_generation()
    finally:
        session.context.push_message(session.assistant_name, response["tool_input"])
        update_display()
        session.context.force_pop_front()
    return response

In [13]:
def shift_context_by_item(n_items: int = 1) -> None:
    """
    Discard the oldest in-window chat item(s) from the model's KV cache, keeping the system prompt.

    The context window is the last ``session.context_window`` history entries.
    For each of its first ``n_items`` items, this counts the tokens of the
    rendered system prompt (+1 for BOS) and of the rendered item.  It then
    calls ``kv_cache_seq_ltrim`` with n_keep = system-prompt tokens and
    n_discard = item tokens.  The debug log reports this as "discarded from
    position n_keep".
    Calling this function without decrementing context_window afterwards makes context shifting fail.
    The caller must decrement ``session.context_window`` only AFTER this call,
    because the window slice below depends on its current value.

    Implementation of StreamingLLM in oobabooga/text-generation-webui:
    https://github.com/oobabooga/text-generation-webui/blob/main/modules/cache_utils.py#L24
    According to this implementation, prefix (=n_sys_ppt_token) corresponds to the "Attention Sinks" in terms of StreamingLLM.
    
    Args:
        n_items (int, default 1): Number of items to shift, but NOT the number of "tokens".
    Returns:
        None
    """
    from lib.infrastructure import kv_cache_seq_ltrim
    
    for i in range(n_items):
        # NOTE(review): with context_window == 0 the slice [-0:] is the whole
        # history; push_context only calls this when context_window >= 2.
        trimmed_history = session.context.history()[-session.context_window:]
        # The i-th item of the same (unchanged) window: successive iterations
        # walk forward through the oldest items.
        oldest_item = trimmed_history[i]
        
        n_sys_ppt_tokens = session.model.token_count(prompt_builder.render_instruction()) + 1 # +1 for bos token.
        n_oldest_ctx_tokens = session.model.token_count(prompt_builder.render_item(oldest_item))
        
        # Keep the system-prompt tokens in place; drop the tokens of oldest_item that follow them.
        kv_cache_seq_ltrim(
            model=session.model.llama, 
            n_keep=n_sys_ppt_tokens,
            n_discard=n_oldest_ctx_tokens,
        )
        
        with session.debug: 
            print(f"{n_oldest_ctx_tokens} tokens discarded from position {n_sys_ppt_tokens}.")
            print(f"discarded: {prompt_builder.render_item(oldest_item)}")
    

def push_context(item, auto_shift_kv: bool) -> None:
    """Append ``item`` to the chat context, optionally keeping the prompt within budget.

    Args:
        item: new context item to push.
        auto_shift_kv (bool): if true, grow the context window by one.  Then,
            while the built prompt needs at least ``n_ctx - max_gen_tokens``
            tokens, shift the oldest window item out of the KV cache.  The
            window never shrinks below one item; when it cannot shrink
            further it is reset to 0 and shifting stops.
    Returns:
        None
    """
    session.context.push(item)
    if not auto_shift_kv:
        return

    session.context_window += 1
    token_budget = session.n_ctx - session.max_gen_tokens

    while session.model.token_count(prompt_builder.build()) >= token_budget:
        if session.context_window < 2:
            session.context_window = 0
            break
        # shift_context_by_item reads context_window, so decrement only afterwards.
        shift_context_by_item(1)
        session.context_window -= 1


def retrieve_latest_input(sender=None):
    """Button handler: take back the most recently pushed chat item.

    The context window shrinks by one (never below 0), and the latest item
    is popped.  If that item was the user's, its text goes back into the
    input field.  With an empty context the input field is cleared instead.

    Context-window bookkeeping when an item is retrieved:
              |        |
    [0, 1, 2, 3, 4, 5, 6]
              |     | <- shrink context window
    [0, 1, 2, 3, 4, 5]
    """
    session.context_window = max(session.context_window - 1, 0)

    if len(session.context) <= 0:
        session.field.value = ''
    else:
        latest = session.context.force_pop_front(stringize=False)
        if latest['role'] == "User":
            session.field.value = latest['content']

    update_display()

In [14]:
def model_not_loaded_error() -> None:
    """Post an assistant message asking the user to load a model, then redraw."""
    notice = {
        "role": session.assistant_name,
        "content": "Model is not loaded yet. Please select and load a local model before inference.",
    }
    push_context(notice, auto_shift_kv=False)
    update_display()

In [15]:
# Main loop.
@disable_uis
def main(sender=None) -> None:
    """Send-button handler: run one user turn.

    1. Validate the input.  The message must be non-empty and under
       MAX_USER_TOKENS, and a model must be loaded.
    2. Push the user's message and generate a response.
    3. A direct answer (or a response without a completed tool tag) is
       stored as the reply.
    4. Otherwise the tool call is recorded and run via ``tools.run_tool``.
       Its result is recorded and a follow-up direct answer is generated.

    The conversation is written to a timestamped UTF-8 file in LOG_DIR
    after each completed turn.
    """
    from lib.utils import reformat_python_code

    def push_context_wrapper(item) -> None:
        push_context(item, auto_shift_kv=True)
        update_display()

    def export_log():
        import datetime
        filename = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S.txt")
        # Explicit UTF-8: chat logs routinely contain non-ASCII text.
        with open(os.path.join(LOG_DIR, filename), 'w', encoding='utf-8') as f:
            f.write(str(session.context))

    if session.model is None:
        model_not_loaded_error()
        return

    # Register user's input.
    user_message = session.field.value
    if session.model.token_count(user_message) >= MAX_USER_TOKENS:
        with session.debug:
            print(f"Message is too long (limit: {MAX_USER_TOKENS} tokens).")
        return
    if user_message == "":
        return
    if session.initialized:
        start_session()

    session.field.value = ''
    push_context_wrapper({
        "role": "User",
        "content": user_message
    })

    response = predict_stream_with_display()
    tool_name = response["tool"]

    # A response whose <tool> tag never completed (e.g. interrupted early) is
    # treated as a direct answer instead of dispatching a tool named None.
    if tool_name is None or tool_name == tools.direct_answer.name:
        push_context_wrapper({
            "role": session.assistant_name,
            "content": response["tool_input"],
        })
        export_log()
        return

    # Generated Python is normalised before being recorded and executed.
    if tool_name == tools.exec_python.name:
        response["tool_input"] = reformat_python_code(response["tool_input"])

    # Register tool call in the context.
    push_context_wrapper({
        "role": session.assistant_name,
        "content": "",
        "tool": {
            "action": "call",
            "name": tool_name,
            "input": response["tool_input"],
        }
    })

    # Run tools.
    push_context_wrapper({
        "role": TOOL_AGENT_NAME,
        "content": "",
        "tool": {
            "action": "return",
            "name": tool_name,
            "output": tools.run_tool(tool_name, response["tool_input"]),
        }
    })

    # Add reaction to the tool results.
    reaction = predict_stream_with_display(force_direct_answer=True)
    push_context_wrapper({
        "role": session.assistant_name,
        "content": reaction["tool_input"]
    })

    export_log()

In [16]:
# Multimodal

def get_caption(img):
    """Describe an image with a freshly loaded Florence-2 Large model.

    ``img`` may be raw encoded image bytes (decoded and converted to RGB
    here) or an already-open PIL image.  The model is created without an
    accelerator, used once and released.  The work runs inside
    ``with session.debug``.  If that context suppresses an exception, the
    result is never assigned; ``i2t`` and ``caption`` are pre-initialised
    so an empty caption is returned instead of raising NameError.
    """
    import io
    from lib.multimodal import Florence2Large
    from PIL import Image

    i2t = None
    caption = ""
    with session.debug:
        if isinstance(img, bytes):
            img = Image.open(io.BytesIO(img))
            img = img.convert("RGB")

        i2t = Florence2Large(use_accelerator=False)
        i2t.load_model()
        caption = i2t.get_caption(img)

    # Drop our reference so the captioning model can be freed.
    del i2t
    return caption


@disable_uis
def generate_caption(args) -> None:
    """Upload observer: add the uploaded image and its caption to the chat.

    The first uploaded file is decoded, converted to RGB, and re-encoded as
    JPEG.  It is pushed as a FILE_UPLOADER_NAME item together with a caption
    from ``get_caption``.  Without a loaded model, an error message is shown
    instead.  The upload widget is always cleared afterwards.  ``args`` is
    the change event and is unused.
    """
    from PIL import Image
    import io

    # Clearing the upload below re-triggers this observer; the empty value ends it here.
    if len(session.upload_file.value) == 0:
        return

    if session.model is None:
        model_not_loaded_error()
        session.upload_file.value = tuple()
        return

    if session.initialized:
        start_session()

    try:
        img_bytes = session.upload_file.value[0]["content"].tobytes()
        img = Image.open(io.BytesIO(img_bytes))
        img = img.convert("RGB")

        buffer = io.BytesIO()
        img.save(buffer, format="JPEG")
        img_binary = buffer.getvalue()

        push_context({
            "role": FILE_UPLOADER_NAME,
            "content": "",
            "image_output": img_binary,
            "caption": get_caption(img),
        }, auto_shift_kv=True)
    finally:
        session.upload_file.value = tuple()
        update_display()

In [17]:
# Define UI actions.
_ui_click_handlers = (
    (session.button, main),
    (session.reset_button, initialize),
    (session.retrieve, retrieve_latest_input),
    (session.create_voice, create_voice_action),
    (session.load_button, reload_model),
    (session.unload_button, unload_model),
    (session.reflesh_list_button, refresh_gguf_list),
)
for _button, _handler in _ui_click_handlers:
    _button.on_click(_handler)
# The uploader reacts to value changes rather than clicks.
session.upload_file.observe(generate_caption, names='value')

In [18]:
# Combine widgets.

HBox = widgets.HBox

    
def show_top() -> None:
    """Render the "About This App" panel: feature list and directory layout.

    Displays an HTML snippet via IPython's ``display``. The directory names
    are interpolated from ``GGUF_DIR``, ``AGENT_WORKING_DIR`` and
    ``TTS_ASSET_ROOT``.
    """
    # Line breaks use the void <br> tag; "</br>" is not valid HTML and only
    # renders as a break through browser error recovery.
    html = f"""<h1>Integrative LLM Chat UI for Jupyter</h1>
<h3>Features</h3>
1. voice synthesis<br>
2. web search-based RAG<br>
3. python env<br>
4. image-to-text model-based image recognition<br>
<br>
To maximize quality, I strongly recommend using models optimized for RAG/tools.<br>
If you want to use an unsupported prompt template, define and register a subclass of <code>PromptBuilderBase</code>.<br>
<h3>Directories</h3>
<code>./{GGUF_DIR}</code>: Directory to put GGUF models.<br>
<code>./{AGENT_WORKING_DIR}</code>: Working directory when executing Python code.<br>
<code>./{TTS_ASSET_ROOT}</code>: Directory to put TTS models.<br>
(The name of the subdirectory containing a model must match its *.safetensors file name.)<br>
<br>
"""
    display(HTML(html))


def show_gguf_loader() -> None:
    """Display the GGUF loader panel with a live RAM usage gauge.

    Lays out the llama.cpp options from ``session.llama_cpp_options`` together
    with the load/unload/refresh buttons and ``session.active_gguf_viewer``.
    A background thread keeps a text RAM bar up to date.
    """
    options = session.llama_cpp_options

    def output_memory_usage(
        output_widget: widgets.Output, interval: float = 2.0, width: int = 20
    ) -> None:
        """Start a daemon thread that keeps *output_widget* showing RAM usage.

        Args:
            output_widget: Widget whose ``outputs`` are replaced on each tick.
            interval: Seconds to sleep between refreshes.
            width: Number of character cells in the bar.
        """
        from threading import Thread

        def job() -> None:
            from psutil import virtual_memory
            import time

            while True:
                # Take one snapshot per tick so the bar and the number agree.
                percent = virtual_memory().percent
                filled = int(width * (percent / 100))
                bar = "[" + "|" * filled + " " * (width - filled) + "]"
                output_widget.outputs = ({
                    'name': 'stdout',
                    'text': f"RAM {bar}{percent}%",
                    'output_type': 'stream',
                },)
                # Refresh first, then wait, so the gauge is populated immediately.
                time.sleep(interval)

        # daemon=True: the loop never exits. A non-daemon thread would keep the
        # interpreter alive at shutdown, and one thread is started per call.
        Thread(target=job, daemon=True).start()

    mem_viewer = widgets.Output()
    output_memory_usage(mem_viewer)

    display(
        HBox([options.define_n_ctx, options.define_max_gen_tokens, options.n_gpu_layers]),
        HBox([options.flash_attention, options.quantize_kv]),
        HBox([options.gguf_selector, session.load_button, session.unload_button, session.reflesh_list_button]),
        session.active_gguf_viewer,
        mem_viewer,
    )

def show_session_options() -> None:
    """Display the session option widgets, in this order:

    - ``template_selector`` and ``streamingllm`` on one row;
    - the user-nickname and assistant-name fields on one row;
    - the tool selector;
    - the user preamble.
    """
    template_row = HBox([session.template_selector, session.streamingllm])
    names_row = HBox([session.user_nickname_field, session.assistant_name_field])
    tool_panel = tools.tool_selector.display()
    preamble_row = HBox([session.user_preamble])
    display(template_row, names_row, tool_panel, preamble_row)

def show_main_context_window() -> None:
    """Display ``session.out``, the widget used as the main context window."""
    context_view = session.out
    display(context_view)

def show_generation_params() -> None:
    """Display the generation-parameter widgets, two per horizontal row."""
    gp = session.generation_params
    rows = (
        [gp.temperature, gp.top_k],
        [gp.top_p, gp.frequency_penalty],
        [gp.presence_penalty, gp.repeat_penalty],
        [gp.prefix, gp.append_prefix],
    )
    display(*[HBox(row) for row in rows])

def show_input_field() -> None:
    """Display the input area in two rows.

    The first row holds the text field, the send button and
    ``session.guessing_image``. The second holds the upload, retrieve and
    reset controls.
    """
    prompt_row = HBox([session.field, session.button, session.guessing_image])
    controls_row = HBox([session.upload_file, session.retrieve, session.reset_button])
    display(prompt_row, controls_row)

def show_voice_synthesis() -> None:
    """Display the text-to-speech panel, top to bottom:

    - the voice player;
    - the dropdown and voice-length controls;
    - a hint about the word-replacement dictionary;
    - the create button.
    """
    panel = [
        session.voice_player,
        HBox([session.dropdown, session.voice_length]),
        HTML("Edit <code>sbv2_dict.json</code> to configure words replacement."),
        session.create_voice,
    ]
    display(*panel)

def show_debug_window() -> None:
    """Display the debug output widget followed by a button that clears it."""
    clear_button = widgets.Button(description="Clear")
    # on_click passes the clicked button; it is not needed here.
    clear_button.on_click(lambda _sender: session.debug.clear_output())
    display(session.debug, clear_button)

In [19]:
ignore_warnings()

# One Output widget per collapsible panel.
top_output = widgets.Output()
gguf_load_output = widgets.Output()
session_option_output = widgets.Output()
generation_params_output = widgets.Output()
voice_synthesis_output = widgets.Output()
debug_output = widgets.Output()

# Render each panel's content into its Output widget, in order.
for _panel, _render in (
    (top_output, show_top),
    (gguf_load_output, show_gguf_loader),
    (session_option_output, show_session_options),
    (generation_params_output, show_generation_params),
    (voice_synthesis_output, show_voice_synthesis),
    (debug_output, show_debug_window),
):
    with _panel:
        _render()
del _panel, _render

# Page layout, top to bottom:
# - settings accordion;
# - main context window;
# - generation params accordion;
# - input field;
# - text-to-speech accordion.
display(widgets.Accordion(
    children=[top_output, gguf_load_output, session_option_output, debug_output],
    titles=["About This App", "GGUF Loader", "Session Options", "Debug"],
))
show_main_context_window()
display(widgets.Accordion(children=[generation_params_output], titles=["Generation Params"]))
show_input_field()
display(widgets.Accordion(children=[voice_synthesis_output], titles=["Text-to-Speech"]))

Accordion(children=(Output(), Output(), Output(), Output()), titles=('About This App', 'GGUF Loader', 'Session…

Output()

Accordion(children=(Output(),), titles=('Generation Params',))

HBox(children=(Textarea(value='', layout=Layout(height='auto', max_width='700px', width='100%'), placeholder='…

HBox(children=(FileUpload(value=(), accept='.png,.jpg,.jpeg,.gif,.bmp', description='Image', layout=Layout(wid…

Accordion(children=(Output(),), titles=('Text-to-Speech',))