In [1]:
from app.exllamav2_chat.model import ChatExllamaV2Model

In [2]:
model_path = "elyza/exl2/"

chat_model = ChatExllamaV2Model.from_model_dir(
    model_path,
    cache_max_seq_len=4096,
    system_message_template="[INST] <<SYS>>\n{}\n<</SYS>>\n",
    human_message_template="{}[/INST]",
    ai_message_template="{}",
    temperature=0.0001,
    # max_new_tokens=1024,
    max_new_tokens=128,
    repetition_penalty=1.15,
    low_memory=True,
    cache_8bit=True,
)

In [3]:
from langchain.prompts import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    AIMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough


system_template = "あなたは誠実で優秀な日本人のアシスタントです。"

prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template("{query}"),
        AIMessagePromptTemplate.from_template(" "),
    ]
)

chain = {"query": RunnablePassthrough()} | prompt | chat_model | StrOutputParser()

In [4]:
from typing import Any
from langchain_core.callbacks.base import BaseCallbackHandler


def handler_print(token: str):
    print(token, sep="", end="")


class StreamingCallbackHandlerSimple(BaseCallbackHandler):
    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        handler_print(token)

In [5]:
from langchain_core.runnables.config import RunnableConfig


config = RunnableConfig(callbacks=[StreamingCallbackHandlerSimple()])

In [9]:
query = "富士山の高さは？正確に"
async for s in chain.astream(query):
    print(s, end="", flush=True)

富士山の高さは3,776mで、富士山が立地し、南北朝方向に長く、東西にややせた形状を成り、平野部は少々見る。

富士山は、三保山・大石田山と共に「富士山群」を成り、富士五湖の水源地を形作る。

富士山は、古来よき神社の聖地と信じれらる。

富士山は、登山者