In [1]:
from langchain.llms.base import LLM
from llama_index import LLMPredictor
from typing import Optional, List, Mapping, Any
from llama_index import SimpleDirectoryReader, LangchainEmbedding, ServiceContext, Document, VectorStoreIndex
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter,SpacyTextSplitter
from llama_index.node_parser import SimpleNodeParser
from transformers import AutoModel, AutoTokenizer
!export PYTORCH_CUDA_ALLOC_CONF="0.0"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = AutoTokenizer.from_pretrained("/home/user/imported_models/chatglm-6b-20230419",trust_remote_code=True)
model = AutoModel.from_pretrained("/home/user/imported_models/chatglm-6b-20230419", trust_remote_code=True).half().cuda()
model = model.eval()

Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
Loading checkpoint shards: 100%|██████████| 8/8 [00:08<00:00,  1.12s/it]


In [3]:
class CustomLLM(LLM):
    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        response, history = model.chat(tokenizer, prompt, history=[])
        return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"name_of_model": "chatglm-6b"}

    @property
    def _llm_type(self) -> str:
        return "custom"

In [5]:
llm_predictor = LLMPredictor(llm=CustomLLM())
# text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=100, chunk_overlap=20)
# parser = SimpleNodeParser(text_splitter=text_splitter)
# documents = SimpleDirectoryReader(input_files=['./datalevel.txt']).load_data()
# nodes = parser.get_nodes_from_documents(documents)
texts = open('./datalevel.txt', 'r', encoding='utf-8').read().split('\n\n')
documents = [Document(text) for text in texts]
embed_model = LangchainEmbedding(HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
))
service_context = ServiceContext.from_defaults(embed_model=embed_model, llm_predictor=llm_predictor)

In [25]:
from llama_index import VectorStoreIndex
index = VectorStoreIndex.from_documents(documents, service_context=service_context)
query_engine = index.as_query_engine(similarity_top_k=5)
query = "请说明客户信息表中，身份证号，吸烟史，是否患有糖尿病等属性属于什么安全级别?按照\"属性：安全级别数字\"的方式输出"
result = query_engine.query(query)
print(result)

属性： 安全级别数字

身份证号： 安全等级4

吸烟史： 安全等级3

是否患有糖尿病： 安全等级3


In [30]:
from llama_index import Prompt

QA_PROMPT_TMPL = (
    "{context_str}"
    "\n\n"
    "根据以上信息，回答下面的问题："
    "Q: {query_str}\n"
    )
qa_template = Prompt(QA_PROMPT_TMPL)
query_engine = index.as_query_engine(similarity_top_k=5, refine_template=qa_template)
query = "客户信息表（手机号码，身份证号，吸烟史，是否患有糖尿病）中的属性，安全级别都是多少？按照\"属性：安全级别数字\"的方式输出"
result = query_engine.query(query)
print(result)

手机号码 安全级别 3
身份证号 安全级别 3
吸烟史 安全级别 3
是否患有糖尿病 安全级别 2
