In [1]:
from llama_cpp import Llama
import streamlit as st
from langchain.llms.base import LLM
from llama_index import LLMPredictor, LangchainEmbedding, ServiceContext, PromptHelper
from typing import Optional, List, Mapping, Any
from langchain.embeddings.huggingface import HuggingFaceEmbeddings


In [2]:

# --- Model settings ---
# Identifier reported by CustomLLM as its model name.
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
# Local model location handed to llama_cpp.Llama as model_path.
MODEL_PATH = "/home/yumk/Desktop/trained_qlora/llama2hf"
# Number of threads handed to llama_cpp.Llama as n_threads.
NUM_THREADS = 8

# --- PromptHelper settings (passed positionally to PromptHelper, in this order) ---
max_input_size = 2048      # maximum input size
num_output = 256           # number of output tokens
chunk_overlap_ratio = 0.8  # chunk overlap, expressed as a ratio

In [3]:
# Build the PromptHelper and the embedding model that are later handed to the llama_index
# ServiceContext.
#
# PromptHelper receives max_input_size, num_output and chunk_overlap_ratio (in that order).
# NOTE(review): presumably it uses these to keep the query plus context inside the model's context
# window, and chunk_overlap_ratio is the overlap between consecutive chunks rather than a factor the
# text is shrunk by -- confirm against the installed llama_index version.
#
# The try/except only guards construction: if PromptHelper(...) raises with the configured ratio,
# it is rebuilt once with a ratio of 0.2. It does not retry individual prompts at query time, and
# a failure of the second attempt propagates.

try:
    prompt_helper = PromptHelper(max_input_size, num_output, chunk_overlap_ratio)
except Exception as e:
    chunk_overlap_ratio = 0.2  # Fallback overlap ratio; this rebinds the module-level setting
    prompt_helper = PromptHelper(max_input_size, num_output, chunk_overlap_ratio)
    
# Embedding model: HuggingFaceEmbeddings with its default arguments, wrapped for llama_index.
embed_model = LangchainEmbedding(HuggingFaceEmbeddings())


@st.cache_resource
def _load_llama():
    """Load the llama.cpp model from MODEL_PATH once and reuse it.

    Loading the model is expensive, so the instance is kept in Streamlit's
    resource cache instead of being rebuilt for every prompt.

    NOTE(review): the cached instance is shared by everything running in this
    process; that is fine for a single-user local app, but confirm before
    serving several concurrent sessions.
    NOTE(review): no n_ctx is passed, so llama_cpp's default context size
    applies -- confirm it is compatible with max_input_size.
    """
    return Llama(model_path=MODEL_PATH, n_threads=NUM_THREADS)


class CustomLLM(LLM):
    """LangChain LLM wrapper around a local llama.cpp model.

    Each prompt is wrapped in a "Human: ... Assistant: " template, completed
    by the model, and the exchange is appended to
    ``st.session_state.messages`` so the Streamlit UI can render it.
    """

    model_name = MODEL_NAME

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt`` and return only the new text.

        Args:
            prompt: The user's message, without any chat template applied.
            stop: Optional extra stop sequences; "Human:" is always included.
            run_manager: Accepted for compatibility with LangChain versions
                that pass a callback manager; not used.
            **kwargs: Accepted for compatibility with LangChain; not used.

        Returns:
            The generated text that follows the templated prompt.

        Side effects:
            Appends the user prompt and the assistant reply to
            ``st.session_state.messages``; that list must already exist
            (``init()`` creates it).
        """
        p = f"Human: {prompt} Assistant: "
        # Always stop at the next "Human:" turn, and honor any caller-supplied stops too.
        stop_sequences = ["Human:"] + [s for s in (stop or []) if s != "Human:"]
        llm = _load_llama()
        output = llm(p, max_tokens=512, stop=stop_sequences, echo=True)['choices'][0]['text']
        # echo=True returns the prompt followed by the completion; keep only the text after the prompt.
        response = output[len(p):]
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.session_state.messages.append({"role": "assistant", "content": response})
        # The LLM contract (and the -> str annotation) requires the generated text to be returned.
        return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Parameters that identify this LLM: the configured model name."""
        return {"name_of_model": self.model_name}

    @property
    def _llm_type(self) -> str:
        """Type tag for this LLM implementation."""
        return "custom"




Downloading (…)a8e1d/.gitattributes:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading (…)_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading (…)b20bca8e1d/README.md:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

Downloading (…)0bca8e1d/config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading (…)ce_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading (…)e1d/data_config.json:   0%|          | 0.00/39.3k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Downloading (…)nce_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading (…)a8e1d/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

Downloading (…)8e1d/train_script.py:   0%|          | 0.00/13.1k [00:00<?, ?B/s]

Downloading (…)b20bca8e1d/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)bca8e1d/modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

In [5]:
# Wrap the custom LLM for llama_index, then bundle it with the prompt helper and the
# embedding model into a single service context.
llm_predictor = LLMPredictor(llm=CustomLLM())
service_context = ServiceContext.from_defaults(
    llm_predictor=llm_predictor,
    prompt_helper=prompt_helper,
    embed_model=embed_model,
)


def clear_convo() -> None:
    """Reset the stored conversation to an empty message list."""
    st.session_state["messages"] = list()


def init() -> None:
    """Configure the Streamlit page and make sure the message history exists.

    Sets the page title and icon, adds the sidebar title, and creates an empty
    ``messages`` list in ``st.session_state`` if one is not there yet. An
    existing conversation is left untouched.
    """
    # The shortcode must be exactly ':robot_face:'; a trailing space prevents
    # Streamlit from recognising it as an emoji shortcode.
    st.set_page_config(page_title='Local LLama', page_icon=':robot_face:')
    st.sidebar.title('Local LLama')
    if 'messages' not in st.session_state:
        st.session_state['messages'] = []


if __name__ == '__main__':
    # Streamlit entry point: set up the page, handle this run's input, then render the chat.
    init()

    @st.cache_resource
    def get_llm():
        """Return a CustomLLM instance; the function is wrapped in st.cache_resource."""
        return CustomLLM()

    # Sidebar control that empties the stored conversation.
    clear_button = st.sidebar.button("Clear Conversation", key="clear")
    if clear_button:
        clear_convo()

    user_input = st.chat_input("Say something")
    if user_input:
        # _call appends both the user prompt and the reply to st.session_state.messages.
        llm = get_llm()
        llm._call(prompt=user_input)

    # Replay every stored message in its role's chat bubble.
    for message in st.session_state.messages:
        role, content = message["role"], message["content"]
        with st.chat_message(role):
            st.markdown(content)

AttributeError: module 'streamlit' has no attribute 'init_session_state'