In [1]:
#!pip install ipywidgets jupyter-ui-poll
#!pip install llama-cpp-python langchain
#!pip install style-bert-vits2 alkana
#!pip install pandas black
#!pip install google-api-python-client selenium beautifulsoup4 pdfplumber

In [2]:
from warnings import simplefilter
ignore_warnings = lambda : simplefilter('ignore')

import os
import llama_cpp
import re
import json
import threading
import asyncio
from os import path

# For ipywidgets.
from IPython.display import clear_output, HTML
from ipywidgets import widgets

# my libs
from lib.infrastructure import ForgetableContext, LlamaCpp, temporal_llama_cache
from lib.uis import activate_cancel_ui, wait_for_change
from lib.utils import now, mixed2katakana, get, fix_indentation
from lib.infrastructure import LlamaInputIDManager

# llama-cpp-python
from llama_cpp import LlamaDiskCache

# style-bert-vits2
from style_bert_vits2.nlp import bert_models
from style_bert_vits2.constants import Languages
from style_bert_vits2.tts_model import TTSModel

  torchaudio.set_audio_backend("soundfile")


In [4]:
model = LlamaCpp(
    model_path="PATH_TO_YOUR_GGUF_FILE.gguf",
    n_gpu_layers=-1,
    n_batch=1024,
    n_ctx=8192,
    use_mlock=True,
    verbose=False,
    embedding=False,
)
internal_llama = model.model.model

In [81]:
# Globals.
from lib.tools import MyPythonREPL

DEFAULT_ASSISTANT_NAME = 'ミク'
DEFAULT_USER_NICKNAME = 'せんせ'

MAXLEN = 70 
assistant_name = DEFAULT_ASSISTANT_NAME
user_nickname = DEFAULT_USER_NICKNAME
context = ForgetableContext(maxlen=MAXLEN)
py = MyPythonREPL(replace_nl=False, temporal_working_directory='agent_working_dir')
login_time_stamp: str = ""

# Event loop.
loop = asyncio.get_event_loop()

input_id_manager = LlamaInputIDManager(internal_llama)

# style-bert-vits2
tts_model: TTSModel|None = None

# Set caches.
main_cache = LlamaDiskCache(".cache/cmr/main_cache", 10*10e9) # 20GB
agent_cache = LlamaDiskCache(".cache/cmr/agent_cache", 2*10e9) # 2GB

# Use your own API_KEY and CSE_ID.
GOOGLE_API_KEY = "YOUR_GOOGLE_API_KEY"
GOOGLE_CSE_ID = "YOUR_GOOGLE_CSE_ID"

In [82]:
internal_llama.cache = main_cache

In [83]:
# GUIs.
out = widgets.Output()
debug = widgets.Output(layout=widgets.Layout(width='600px', height='100px', overflow='scroll'))
field = widgets.Textarea(placeholder=f'ユーザー:', layout=widgets.Layout(width='490px', height='50px'))
user_nickname_field = widgets.Text(description='AI->you', value=user_nickname, placeholder='Your nickname', layout=widgets.Layout(width='200px'))
assistant_name_field = widgets.Text(description='You->AI', value=assistant_name, placeholder='Assistant name', layout=widgets.Layout(width='200px'))
button = widgets.Button(description='📤', button_style='success', layout=widgets.Layout(width='50px', height='50px'))
reset_button = widgets.Button(description='Reset', layout=widgets.Layout(width='120px'))
retrieve = widgets.Button(description='Undo', layout=widgets.Layout(width='120px'))
create_voice = widgets.Button(description='Synthesize voice', layout=widgets.Layout(width='120px'))
buttons = [button, reset_button, retrieve, create_voice]
voice_player = widgets.Output()
dropdown: widgets.Dropdown|None = None

In [84]:
# TTS model loader


def is_tts_model_dir(path):
    if not os.path.isdir(path):
        return False
    files_to_check = [
        f"{os.path.basename(path)}.safetensors",
        "config.json",
        "style_vectors.npy",
    ]
    folder_files = os.listdir(path)
    return all(file in folder_files for file in files_to_check)



@debug.capture()
def load_tts_models(model_path):
    import gc; gc.collect()
    
    bert_models.load_model(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm")
    bert_models.load_tokenizer(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm")
    
    return TTSModel(
        model_path=os.path.join(model_path, f"{os.path.basename(model_path)}.safetensors"),
        config_path=os.path.join(model_path, "config.json"),
        style_vec_path=os.path.join(model_path, "style_vectors.npy"),
        device="cpu"
    )

ASSET_ROOT = "style_bert_vits2_models"

async def capture_model_selection_change():
    global tts_model
    while True:
        selected = await wait_for_change(dropdown, "value")
        tts_model = load_tts_models(os.path.join(ASSET_ROOT, selected))
        debug.append_stdout(f"tts_model({selected}) loaded.")

tts_model_names = [item for item in os.listdir(ASSET_ROOT) if is_tts_model_dir(os.path.join(ASSET_ROOT, item))]
if len(tts_model_names) != 0:
    tts_model = load_tts_models(os.path.join(ASSET_ROOT, tts_model_names[0]))
dropdown = widgets.Dropdown(description="TTS model", options=tts_model_names, value=tts_model_names[0]) if len(tts_model_names) != 0 else widgets.Dropdown(description="TTS model", options=[])
loop.create_task(capture_model_selection_change())

<Task pending name='Task-210' coro=<capture_model_selection_change() running at /var/folders/wz/g2cjflgj1g1bkvbnp4127h5h0000gn/T/ipykernel_91869/3237276757.py:33>>

In [85]:
def reload_instruction():
    global instructions
    inst_file = 'notebooks/prompt_gen_13_cmr.txt'
    with open(path.join(LLAMA_DIR, inst_file) , 'r') as f:
        lines = f.readlines()
        instructions = ''.join(lines)

reload_instruction()

In [86]:
def ctx2str(context, skip: int = 0) -> str:
    """ignore front items of number specified by 'skip'"""
    text = ''
    for item in context.context()[skip:]:
        role = get(item, 'role')
        content = get(item, 'content')
        code = get(item, 'code')
        code_output = get(item, 'code_output')
        search_query = get(item, 'search_query')
        search_result = get(item, 'search_result')

        item_text = content

        def add_block(block_name, block_content) -> str:
            return """
```{block_name}
{block_content}
```
""".format(block_name=block_name, block_content=block_content)
        
        if code: item_text += add_block("python", code)
        if code_output: item_text += add_block("output", code_output)
        if search_query: item_text += add_block("google", search_query)
        if search_result: item_text += add_block("result", search_result)

        text += {
            assistant_name: f"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{item_text}<|END_OF_TURN_TOKEN|>",
            "ユーザー": f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{item_text}<|END_OF_TURN_TOKEN|>",
        }[role]
        
    return text

In [87]:
def export_log():
    import datetime
    global context
    filename = str(datetime.datetime.now()).replace(' ', '_') + '.txt'
    with open(f'log/{filename}', 'w') as f:
       f.writelines(str(context))


In [88]:
tool_description = {
    "google": "検索ワードを引数としてgoogle検索を行う。リアルタイム情報を得る場合には必ずこのツールを使用すること。",
    "python": "Pythonコードを記述・実行します。コードは適宜改行すること。変数や関数はセッション内で共有されます。結果は必ずprintあるいはplt.show()で出力すること。",
}
available_tools = list(tool_description.keys())

In [89]:
def compile_instruction(instructions):
    return instructions.format(
        now=login_time_stamp, 
        assistant=assistant_name, 
        user_nickname=user_nickname,
        tool_names=available_tools,
        tool_description='\n'.join([f"{name}: {description}" for name, description in tool_description.items()]),
    )

def make_ppt(
    instructions: str, 
    context,
    skip: int = 0,
) -> str:
    # llama_cpp's tokenizer seems to add bos token automatically?
    # https://gist.github.com/kohya-ss/37f4c5ef8171cbb2b6cc1f4fd7999b89
    # instructions = "<BOS_TOKEN>" + instructions
    return compile_instruction(instructions) + ctx2str(context, skip) + "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"

In [90]:
def assistant_speaks(text) -> None:
    global tts_model

    if tts_model is None:
        return
    
    def replace_text(text, replacement_dict):
        for key, value in replacement_dict.items():
            text = text.replace(key, value)
        return text
    
    text = mixed2katakana(text)
    text = re.sub(r'[(（].*?[)）]', '　', text)

    text = replace_text(text, {
        'えへへ': 'えへっ',
        '言って': 'いって',
        '℃': '度',
        'kg': 'キロ',
        '♪': '。',
        '：': '。',
        ':': '。',
    })
    

    try:
        from IPython.display import Audio
        if len(text) > 0:
            with debug: sr, wav = tts_model.infer(text, length=1.15)
        else:
            return
        audio = Audio(wav, rate=sr, autoplay=False)
        voice_player.clear_output(wait=True)
        with voice_player: display(audio)
    except BaseException as e:
        with debug: print(e)

In [91]:
# LangChain.agent

ignore_warnings()

from langchain.tools import Tool
from lib.lang_chain_agent_cmr import create_agent_executor
from lib.lang_chain_tools import GoogleSearchOpenable

google = GoogleSearchOpenable(
    n_results=3, 
    api_key=GOOGLE_API_KEY,
    cse_id=GOOGLE_CSE_ID,
)

tools = [
    Tool(
        name="google", 
        description="検索ワードを指定し、Web検索を行うツールです。検索後はselectで各記事を開く事ができます。", 
        func=google.set
    ),
    Tool(
        name="select", 
        description="googleで検索した番号をAction Inputに[検索結果1]のように指定して、ページを開きます。", 
        func=google.open
    ),
]

agent_executor = create_agent_executor(
    model, 
    tools, 
    max_iterations=6,
    return_intermediate_steps=True,
    model_kwargs={'temperature': 0.1, 'max_tokens': 1500},
)

@temporal_llama_cache(internal_llama, agent_cache)
def exec_agent(model, request: str) -> tuple:
    global google, agent_executor
    def cleanup_agent_output(text: str) -> str:
        return text.split('Question')[0].split('Thought')[0].strip()
    
    google.unset()
    
    response = agent_executor.invoke(request)
    
    output = cleanup_agent_output(response['output'])
    ref = google.references()
    
    return output, ref

In [92]:
def format_to_html(context) -> str:
    def embed_image_to_tag(image_binary) -> str:
        import base64
        encoded_image = base64.b64encode(image_binary).decode('utf-8')
        html_image_tag = f'<img src="data:image/jpeg;base64,{encoded_image}" />'
        return html_image_tag

    def replace_charref(text) -> str:
        return text.replace('<', '&lt;').replace('>', '&gt;').replace(' ', '&nbsp;')

    messages: list[str] = []
    for message in context.history():
        role = message['role']
        content = message['content']

        name: str
        text: str
        text_template = """<div style="background-color: {color}; word-wrap: break-word; color: black; padding: 10px; border-radius: 20px;">{content}</div>"""
        
        if role == 'ユーザー':
            content = replace_charref(content)
            content = content.replace('\n', '</br>')
            text = text_template.format(content=content, color='#BBFFBB')
            name = '<font color=#888888>ユーザー</font>'
        elif role == assistant_name:
            text =replace_charref(content)

            references = get(message, 'references')
            code = get(message, 'code')
            code_output = get(message, 'code_output')
            image_output = get(message, 'image_output')
            search_result = get(message, 'search_result')
            search_query = get(message, 'search_query')

            if search_result:
                text += '<div style="background-color: #999999; color: black;">' + f'検索(query="{search_query}")' + '</div>'
                text += '<div style="background-color: #FFFFFF; color: black;">' + search_result + '</div>'
            if references:
                text += '<div style="background-color: #999999; color: black;">参考</div>'
                text += '<div style="background-color: #FFFFFF; color: black;">' + ''.join([f'<a href="{url}">⚫︎</a>' for url in references]) + '</div>'
            if code:
                header = '<div style="background-color: #999999; color: black;">{text}</div>'
                text += header.format(text='Python')
                text += '<pre><code>' + code + '</code></pre>'
                text += header.format(text='Output')
                text += '<pre><code>' + (code_output if code_output else 'Empty stdout/stderr.') + '</code></pre>'
            if image_output:
                text += '</br>' + embed_image_to_tag(image_binary=image_output) + '</br>'

            # Other code blocks.
            regex = re.compile(r'```(\w+)?[ \n](.*?)\n?```', re.DOTALL)
            text = re.sub(
                regex, 
                r'<div style="background-color: #999999; color: black;">\1</div><pre><code>\2</code></pre>', 
                text
            )
            
            
            text = text.replace('\n', '</br>')
            text = text_template.format(content=text, color='#FFEEBB')
            name = f'<font color=#888888><div style="text-align:right">{role}</div></font>'
        
        messages.append(f'{name}{text}')
        
    return ''.join(messages)

In [93]:
def print_context():
    global context
    
    clear_output(wait=True)
    
    html_text="""<!DOCTYPE html>
<html>

<head>
  <meta charset="utf-8" />
  <style>
    #wrapper {{
      display: flex;
      flex-direction: column-reverse;
      height: 1600px;
      width: 800px;
      overflow-y: scroll;
    }}

    /* Custom Scrollbar CSS */
    #wrapper::-webkit-scrollbar {{
      width: 10px;
    }}

    #wrapper::-webkit-scrollbar-track {{
      background: #f1f1f1;
    }}

    #wrapper::-webkit-scrollbar-thumb {{
      background: #888;
    }}

    #wrapper::-webkit-scrollbar-thumb:hover {{
      background: #555;
    }}
    
  </style>
</head>

<body>
    <div style="display: flex; align-items: flex-end;">
      <div id="wrapper">
        <div id="contents">
    {content}
        </div>
      </div>
    </div>
</body>

</html>
"""
    html_text = html_text.format(
        content=format_to_html(context), 
    )

    display(HTML(html_text))

In [94]:
guessing_image: widgets.Image

with open('guessing.gif', 'rb') as f:
    guessing_image = widgets.Image(value=f.read(), width=50, height=50)

def set_guessing_image(show: bool) -> None:
    global guessing_image
    with open('guessing.gif' if show else 'empty.png', 'rb') as f:
        guessing_image.value = f.read() 
        guessing_image.width = 50
        guessing_image.height = 50

set_guessing_image(False)

In [95]:
def set_buttons(disabled: bool) -> None:
    global buttons
    for b in buttons:
        b.disabled = disabled

def disable_uis(func):
    from functools import wraps
    @wraps(func)
    def wrapper(*args, **kwargs):
        set_buttons(disabled=True)
        set_guessing_image(True)
        result = func(*args, **kwargs)
        set_guessing_image(False)
        set_buttons(disabled=False)
        return result
    return wrapper

In [96]:
def initialize(sender=None) -> None:
    global context, field, assistant_name_field, user_nickname_field, login_time_stamp

    context.reset()
    field.value = ''
    user_nickname_field.disabled = False
    assistant_name_field.disabled = False
    login_time_stamp = now()
    
    with out:
        print_context()

    from IPython.display import Audio
    voice_player.clear_output(wait=True)
    with voice_player:
        display(Audio(b''))

initialize()

def submit_names_action(sender=None) -> None:
    global user_nickname, user_nickname_field, assistant_name, assistant_name_field

    user_nickname = user_nickname_field.value
    assistant_name = assistant_name_field.value
    user_nickname_field.disabled = True
    assistant_name_field.disabled = True

@out.capture()
def retrieve_latest_input(sender = None):
    global field, context

    if len(context) <= 0:
        field.value = ''
        print_context()
        return

    last_item = context.force_pop_front(stringize=False)
    if last_item['role'] == 'ユーザー':
        field.value = last_item['content']
    print_context()
    return

In [97]:
def force_save_state():
    global internal_llama
    prompt = make_ppt(instructions, context)
    prompt_tokens = ((internal_llama.tokenize(prompt.encode("utf-8"), special=True)))
    #
    # save to main_cache
    # using prompt_tokens(+completion_tokens) as key of the cache depends on implementation of llama_cpp's _create_complesion.
    # The code: https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama.py#L954 
    #
    internal_llama.cache[prompt_tokens] = internal_llama.save_state()

In [98]:
def predict_stream(additional_stop_tokens: list[str] = []):
    global model, instructions, context

    prompt = make_ppt(instructions, context)
    
    streamer = model.stream(
        input=prompt,
        temperature=0.2,
        max_tokens=1500,
        stop=additional_stop_tokens
    )

    pattern = re.compile(
        r'(.*?```\n?({tools}).*?```)'.format(tools='|'.join(available_tools)), 
        re.DOTALL
    )

    output = ''
    for token in streamer:
        output += token
        
        match = re.search(pattern, output)
        if match:
            output = match.group(1)
            yield output
            # Because of break, token generation doesn't complete properly.
            # The inproper interruption results in no activation of diskcache saving that normally happens after completion.
            # So, we need to add code to force saving LlamaState.
            force_save_state()
            break 
        else:
            yield output.strip()
        
        

def predict_stream_with_display(additional_stop_tokens: list[str] = []):
    with debug:
        for output in predict_stream(additional_stop_tokens):
            context.push_message(assistant_name, output)
            with out: print_context()
            context.force_pop_front()
            
    context.push_message(assistant_name, output)
    with out: print_context()
    context.force_pop_front()
    return output

In [99]:
def message_chain(reply):
    global assistant_name, model, context
    context.push_message(assistant_name, reply)
    with out: print_context()


def python_chain(python_code):

    # Reformat code.
    import black
    python_code = fix_indentation(python_code)
    try: python_code = black.format_str(python_code, mode=black.Mode())
    except: pass
    
    # Run code.
    py.unset(keep_locals=True)
    py.run(python_code)
    code, code_output, image_output = py.result()

    # Register into context.
    last_item = context.force_pop_front(stringize=False)
    last_item['code'] = code
    last_item['code_output'] = code_output if code_output else "Empty stdout/stderr."
    last_item['image_output'] = image_output
    context.push(last_item)
    with out: print_context()

    # Add reaction to the execution results.
    reply = predict_stream_with_display(additional_stop_tokens=["```"])
    context.push_message(assistant_name, reply)
    with out: print_context()



def search_chain(search_query):
    global context, assistant_name, model

    # Show confirmation of search to the user.
    context.push_message(assistant_name, '🔍「' + search_query + '」で検索します。')
    with out: print_context()
    with out: cancel = activate_cancel_ui(wait_sec=10)
    if cancel:
        retrieve_latest_input()
        return

    
    # Register search result to the context.
    with out: print_context()
    with debug:
        search_result, referred_urls = exec_agent(model, f'"{search_query}"で検索して内容をまとめてください。検索ワードは指示通りとし、変更してはいけません。ふたつは検索結果を開き内容を確認すること。')
    context.force_pop_front()
    last_item = context.force_pop_front(stringize=False)
    last_item['search_query'] = search_query
    last_item['search_result'] = search_result
    last_item['references'] = referred_urls
    context.push(last_item)
    with out: print_context()

    
    # Add reaction to the search results.
    reply = predict_stream_with_display(additional_stop_tokens=["```"])
    context.push_message(assistant_name, reply)
    with out: print_context()

In [100]:
# Main loop.
@disable_uis
def main(sender=None) -> None:
    global  context, field, submit_names

    submit_names_action()

    user_message = field.value
    if user_message != '':
        field.value = ''
        context.push_message('ユーザー', user_message)
    with out: print_context()
        
    output = predict_stream_with_display()

    # Parse tool.
    regex = re.compile(
        r'```({tools})(.*)```'.format(tools='|'.join(available_tools)), 
        re.DOTALL
    )
    tool = re.search(regex, output)

    # Remove tool tags.
    cleaned_output = re.sub(regex, '', output).rstrip()
    message_chain(cleaned_output)
    
    # Passing to extra chain.
    if tool:
        tool_type, tool_input = tool.group(1), tool.group(2).strip()
        {
            'python': python_chain,
            'google': search_chain,
        }[tool_type](tool_input)

    export_log()

In [101]:
@disable_uis
def create_voice_action(sender=None):
    try:
        text = get(context.history()[-1], "content")
        assistant_speaks(text)
    except:
        pass

In [102]:
# Define UI actions.
button.on_click(main)
reset_button.on_click(initialize)
retrieve.on_click(retrieve_latest_input)
create_voice.on_click(create_voice_action)

In [103]:
def show_guis() -> None:
    html = """<h2>{default_name}ちゃんとおしゃべり!(仮)</h2>
おしゃべりやウェブ検索・Python実行を利用したQ&Aができます。</br>
agent_working_dir: Python実行時のワーキングディレクトリ。</br>
style_bert_vits2_models: TTSモデルを入れるディレクトリ。</br>
(モデルのあるサブディレクトリ名と.safetensorsの名前は一致している必要があります)</br>
</br>
""".format(default_name=DEFAULT_ASSISTANT_NAME)
    HBox = widgets.HBox
    display(
        HTML(html),
        HBox([user_nickname_field, assistant_name_field]),
        out,
        HBox([field, button, guessing_image]),
        HBox([retrieve, reset_button, create_voice]),
        dropdown,
        voice_player,
        debug,
    )

In [104]:
ignore_warnings()
show_guis()

HBox(children=(Text(value='せんせ', description='AI->you', layout=Layout(width='200px'), placeholder='Your nickna…

Output()

HBox(children=(Textarea(value='', layout=Layout(height='50px', width='490px'), placeholder='ユーザー:'), Button(bu…

HBox(children=(Button(description='Undo', layout=Layout(width='120px'), style=ButtonStyle()), Button(descripti…

Dropdown(description='TTS model', options=(), value=None)

Output()

Output(layout=Layout(height='100px', overflow='scroll', width='600px'))