In [1]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from langchain.chains.summarize import load_summarize_chain
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
from langchain.chains import LLMRequestsChain, LLMChain
from langchain import LLMChain
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains.mapreduce import MapReduceChain
from langchain.prompts import PromptTemplate
from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModel, AutoConfig
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
from torch.mps import empty_cache
import torch
import sys

class GLM(LLM):
    """LangChain ``LLM`` wrapper around a locally loaded chat model.

    The tokenizer and model are not loaded at construction time; call
    :meth:`load_model` first. Text generation is delegated to the model's own
    ``chat(tokenizer, prompt, history=..., ...)`` method.
    """

    # Forwarded to ``model.chat`` as ``max_length``.
    max_token: int = 2048
    # Sampling temperature forwarded to ``model.chat``.
    temperature: float = 0.8
    # Nucleus-sampling threshold forwarded to ``model.chat``. Annotated like the
    # sibling fields so that it is declared as a regular model field.
    top_p: float = 0.9
    # Both are ``None`` until ``load_model`` has been called.
    tokenizer: object = None
    model: object = None
    # Maximum number of trailing history entries forwarded to ``model.chat``;
    # a value <= 0 disables history entirely.
    history_len: int = 1024

    def __init__(self, **kwargs: Any):
        # Forward keyword overrides (e.g. ``temperature=0.1``) to the base class;
        # ``GLM()`` with no arguments keeps working exactly as before.
        super().__init__(**kwargs)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM implementation."""
        return "GLM"

    def load_model(self, llm_device="gpu", model_name_or_path=None):
        """Load the tokenizer and model weights from ``model_name_or_path``.

        Args:
            llm_device: ``"cpu"`` keeps the model in float32 on the CPU and
                ``"mps"`` moves a half-precision model to the MPS backend. Any
                other value, including the default ``"gpu"``, uses half
                precision on CUDA (the original behaviour).
            model_name_or_path: Checkpoint location passed to ``from_pretrained``.

        Raises:
            ValueError: If ``model_name_or_path`` is empty or ``None``.
        """
        if not model_name_or_path:
            raise ValueError("model_name_or_path must be provided")

        # NOTE(review): trust_remote_code=True executes Python code shipped with
        # the checkpoint; only point this at checkpoints you trust (consider
        # pinning a `revision`, as the transformers warning suggests).
        model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
        model = AutoModel.from_pretrained(
            model_name_or_path, config=model_config, trust_remote_code=True
        )

        device = str(llm_device or "gpu").lower()
        if device == "cpu":
            model = model.float()
        elif device == "mps":
            model = model.half().to("mps")
        else:
            model = model.half().cuda()
        self.model = model

    def _call(
        self,
        prompt: str,
        history: Optional[List[Tuple[str, str]]] = None,
        stop: Optional[List[str]] = None,
    ) -> str:
        """Generate a reply to ``prompt`` using the loaded model.

        Args:
            prompt: The text sent to the model.
            history: Earlier conversation entries forwarded to ``model.chat``;
                presumably ``(query, response)`` pairs - verify against the
                model's ``chat`` implementation. ``None`` means no history.
            stop: Optional stop sequences. The reply is truncated at the
                earliest occurrence of any of them.

        Returns:
            The model's reply text.

        Raises:
            RuntimeError: If ``load_model`` has not been called yet.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("GLM model is not loaded; call load_model() first")

        # A ``None`` default replaces the shared mutable ``[]`` default.
        past_turns = history if history else []
        if self.history_len > 0:
            recent_turns = past_turns[-self.history_len:]
        else:
            recent_turns = []

        response, _ = self.model.chat(
            self.tokenizer,
            prompt,
            history=recent_turns,
            max_length=self.max_token,
            temperature=self.temperature,
            top_p=self.top_p,
        )

        if stop:
            # Honour LangChain's stop contract: cut at the earliest stop sequence.
            cut = len(response)
            for token in stop:
                if not token:
                    continue
                position = response.find(token)
                if position != -1:
                    cut = min(cut, position)
            response = response[:cut]
        return response

    
# Checkpoint location handed to GLM.load_model below. This is a hard-coded
# absolute path, so adjust it for your environment.
modelpath = "/root/model/chatglm-6b"
# Adds the checkpoint directory to the import search path.
# NOTE(review): presumably so Python modules bundled with the checkpoint can be
# imported - confirm whether this is still required.
sys.path.append(modelpath)
# Create the wrapper first, then load tokenizer and weights into it.
llm = GLM()
llm.load_model(model_name_or_path = modelpath)

Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [10]:
# Sample input text (Chinese company-history excerpt) from which the company's
# English and Chinese names are to be extracted.
customer_review = """公司系经上海市人民政府沪府体改审(2000)019号《关于同意设立上海置信电气股份有限公司的批复》同意,由原上海置信电气工业有限公司整体变更而来。经中国证券监督管理委员会证监发行字(2003)113号文核准,公司于2003年9月18日公开发行人民币普通股股票2,500万股。
    2020年4月。公司名称由“上海置信电气股份有限公司”变更为“国网英大股份有限公司”,英文名称由“Shanghai Zhixin Electric Co.,Ltd.”变更为“State Grid Yingda CO.,LTD.
"""

# First-attempt prompt: asks for JSON with keys `name` and `ch_name` in free-form
# instructions. `{text}` is the only placeholder and is filled in when the
# template is formatted.
review_template = """\
For the following text, extract the following information:

公司英文名

公司中文名


Format the output as JSON with the following keys:
name
ch_name

text: {text}
"""

In [11]:
from langchain.prompts import ChatPromptTemplate

# Build a chat prompt from the plain-text template and fill its {text} slot
# with the sample input.
prompt_template = ChatPromptTemplate.from_template(review_template)
messages = prompt_template.format_messages(text=customer_review)
# The GLM wrapper takes a plain string prompt, so pass only the text content
# of the first formatted message.
response = llm(messages[0].content)
print(response)

- 公司英文名：Shanghai Zhixin Electric Co., Ltd.
- 公司中文名：State Grid Yingda CO.,LTD.


In [13]:
from langchain.output_parsers import ResponseSchema, StructuredOutputParser

# One schema per field the model must return; the description is embedded in the
# generated format instructions.
enname_schema = ResponseSchema(name="name", description="公司英文名")
zhname_schema = ResponseSchema(name="zhname", description="公司中文名")
response_schemas = [enname_schema, zhname_schema]

# The parser both produces the format instructions for the prompt and later
# parses the model's reply.
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()
print(format_instructions)


The output should be a markdown code snippet formatted in the following schema, including the leading and trailing "```json" and "```":

```json
{
	"name": string  // 公司英文名
	"zhname": string  // 公司中文名
}
```


In [15]:
# Second prompt: same extraction task, but the output format is supplied through
# the {format_instructions} placeholder instead of free-form wording.
review_template2 = """\
For the following text, extract the following information:

公司英文名

公司中文名

text: {text}

{format_instructions}
"""

prompt = ChatPromptTemplate.from_template(template=review_template2)

# Fill both placeholders; this rebinds `messages` from the earlier cell.
messages = prompt.format_messages(text=customer_review, 
                                format_instructions=format_instructions)

In [16]:
# Show the fully rendered prompt text that will be sent to the model.
print(messages[0].content)

For the following text, extract the following information:

公司英文名

公司中文名

text: 公司系经上海市人民政府沪府体改审(2000)019号《关于同意设立上海置信电气股份有限公司的批复》同意,由原上海置信电气工业有限公司整体变更而来。经中国证券监督管理委员会证监发行字(2003)113号文核准,公司于2003年9月18日公开发行人民币普通股股票2,500万股。
    2020年4月。公司名称由“上海置信电气股份有限公司”变更为“国网英大股份有限公司”,英文名称由“Shanghai Zhixin Electric Co.,Ltd.”变更为“State Grid Yingda CO.,LTD.


The output should be a markdown code snippet formatted in the following schema, including the leading and trailing "```json" and "```":

```json
{
	"name": string  // 公司英文名
	"zhname": string  // 公司中文名
}
```



In [19]:
# Send the rendered prompt to the model, parse the reply with the structured
# output parser, and print the parsed result.
print(output_parser.parse(llm(messages[0].content)))

{'name': 'State Grid Yingda Co., Ltd.', 'zhname': '国网英大股份有限公司'}
